#![allow(unsafe_code)]
use crate::Transform;
use std::marker::PhantomData;
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// RGB-to-HSV converter over a flat buffer of colour components of type `T`.
///
/// The converter itself stores no pixel data; it only remembers which code
/// path `rgb_to_hsv` should take.
pub struct SimdColorConvert<T> {
/// Decided once in `new`; when true, `rgb_to_hsv` may dispatch to the
/// `f32`-specific routine instead of the generic scalar one.
use_simd: bool,
/// Zero-sized marker that ties the converter to element type `T` without
/// storing a value of that type.
_phantom: PhantomData<T>,
}
impl<T> SimdColorConvert<T>
where
    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
    /// Creates a converter and decides which conversion path it will use.
    ///
    /// The `f32` fast path is enabled only on x86_64, only when the CPU
    /// reports AVX2 support, and only when `T` is exactly `f32`. In every
    /// other case the generic scalar path is used.
    pub fn new() -> Self {
        #[cfg(target_arch = "x86_64")]
        let use_simd = is_x86_feature_detected!("avx2") && Self::is_f32();
        #[cfg(not(target_arch = "x86_64"))]
        let use_simd = false;
        Self {
            use_simd,
            _phantom: PhantomData,
        }
    }

    /// Returns `true` only when `T` is exactly `f32`.
    ///
    /// A size check is not enough: another 4-byte `Float` implementation
    /// would pass it and then be reinterpreted as `f32`.
    fn is_f32() -> bool {
        std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
    }

    /// Converts a buffer of consecutive `[r, g, b]` triples to `[h, s, v]` in place.
    ///
    /// Hue is written divided by 360, saturation is `delta / max` (0 when
    /// `max` is 0), and value is the largest of the three components.
    /// Elements after the last complete triple are left untouched.
    pub fn rgb_to_hsv(&self, rgb_data: &mut [T]) {
        if self.use_simd && Self::is_f32() {
            #[cfg(target_arch = "x86_64")]
            {
                // SAFETY: `is_f32` has just proven through `TypeId` that `T` is
                // exactly `f32`, so the element type, size and alignment are
                // identical. The pointer and length come from a live exclusive
                // slice, and `rgb_data` is not touched again while `as_f32` is used.
                let as_f32 = unsafe {
                    std::slice::from_raw_parts_mut(
                        rgb_data.as_mut_ptr().cast::<f32>(),
                        rgb_data.len(),
                    )
                };
                self.rgb_to_hsv_f32_simd(as_f32);
                return;
            }
        }
        self.rgb_to_hsv_scalar(rgb_data);
    }

    /// `f32`-specialised conversion used when the fast path is enabled.
    ///
    /// Despite its name this routine contains no explicit vector intrinsics;
    /// it processes one pixel per iteration using concrete `f32` arithmetic
    /// and performs the same computation as `rgb_to_hsv_scalar`.
    #[cfg(target_arch = "x86_64")]
    fn rgb_to_hsv_f32_simd(&self, rgb_data: &mut [f32]) {
        for pixel in rgb_data.chunks_exact_mut(3) {
            let (r, g, b) = (pixel[0], pixel[1], pixel[2]);
            let max_val = r.max(g.max(b));
            let min_val = r.min(g.min(b));
            let delta = max_val - min_val;

            let saturation = if max_val == 0.0 { 0.0 } else { delta / max_val };

            // Hue in degrees; which sector applies depends on the dominant channel.
            let hue_degrees = if delta == 0.0 {
                0.0
            } else if max_val == r {
                60.0 * (((g - b) / delta) % 6.0)
            } else if max_val == g {
                60.0 * ((b - r) / delta + 2.0)
            } else {
                60.0 * ((r - g) / delta + 4.0)
            };
            // Float `%` keeps the sign of the dividend, so wrap negatives manually.
            let hue_degrees = if hue_degrees < 0.0 {
                hue_degrees + 360.0
            } else {
                hue_degrees
            };

            pixel[0] = hue_degrees / 360.0;
            pixel[1] = saturation;
            pixel[2] = max_val;
        }
    }

    /// Generic conversion that works for any `T` satisfying the impl bounds.
    fn rgb_to_hsv_scalar(&self, rgb_data: &mut [T]) {
        // Build each numeric constant once instead of once per pixel. The
        // conversion order and the last-resort fallbacks are unchanged: try
        // the floating-point literal, then the integer literal, then `fallback`.
        let constant = |float_value: f64, int_value: i32, fallback: T| -> T {
            T::from(float_value).unwrap_or_else(|| T::from(int_value).unwrap_or(fallback))
        };
        let zero = T::zero();
        let two = constant(2.0, 2, zero);
        let four = constant(4.0, 4, zero);
        let six = constant(6.0, 6, zero);
        let sixty = constant(60.0, 60, zero);
        let three_sixty = constant(360.0, 360, T::one());

        for pixel in rgb_data.chunks_exact_mut(3) {
            let (r, g, b) = (pixel[0], pixel[1], pixel[2]);
            let max_val = r.max(g.max(b));
            let min_val = r.min(g.min(b));
            let delta = max_val - min_val;

            let saturation = if max_val == zero {
                zero
            } else {
                delta / max_val
            };

            // Hue in degrees; which sector applies depends on the dominant channel.
            let hue_degrees = if delta == zero {
                zero
            } else if max_val == r {
                sixty * (((g - b) / delta) % six)
            } else if max_val == g {
                sixty * ((b - r) / delta + two)
            } else {
                sixty * ((r - g) / delta + four)
            };
            // `%` keeps the sign of the dividend, so wrap negatives manually.
            let hue_degrees = if hue_degrees < zero {
                hue_degrees + three_sixty
            } else {
                hue_degrees
            };

            pixel[0] = hue_degrees / three_sixty;
            pixel[1] = saturation;
            pixel[2] = max_val;
        }
    }
}
impl<T> Default for SimdColorConvert<T>
where
T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<T> Transform<T> for SimdColorConvert<T>
where
    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
    /// Converts the feature tensor of `sample` from RGB to HSV; labels pass through.
    ///
    /// The feature data is copied, converted in place by `rgb_to_hsv` (which
    /// treats consecutive triples as one pixel), and rebuilt into a tensor
    /// with the original dimensions.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when the feature tensor cannot be
    /// viewed as a slice, and propagates any error from `Tensor::from_vec`.
    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
        let (features, labels) = sample;

        let source = features.as_slice().ok_or_else(|| {
            TensorError::invalid_argument(
                "Unable to access tensor data for color conversion".to_string(),
            )
        })?;

        // Work on an owned copy so the input tensor is never mutated.
        let mut hsv = source.to_vec();
        self.rgb_to_hsv(&mut hsv);

        let converted = Tensor::from_vec(hsv, features.shape().dims())?;
        Ok((converted, labels))
    }
}
/// Histogram configuration for `f32` data: bin count, value range, and
/// whether the AVX2 code path may be used.
pub struct SimdHistogram {
/// Number of bins in every histogram this instance produces.
bins: usize,
/// Lower end of the histogram range; smaller samples are clamped up to it.
min_val: f32,
/// Upper end of the histogram range; larger samples are clamped down to it.
max_val: f32,
/// Result of the AVX2 runtime detection done in `new` (always `false`
/// on targets other than x86_64).
use_simd: bool,
}
impl SimdHistogram {
pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
#[cfg(target_arch = "x86_64")]
let use_simd = is_x86_feature_detected!("avx2");
#[cfg(not(target_arch = "x86_64"))]
let use_simd = false;
Self {
bins,
min_val,
max_val,
use_simd,
}
}
pub fn is_simd_enabled(&self) -> bool {
self.use_simd
}
pub fn compute(&self, tensor: &Tensor<f32>) -> Result<Vec<u32>> {
let data = tensor
.as_slice()
.ok_or_else(|| TensorError::InvalidOperation {
operation: "histogram_compute".to_string(),
reason: "Cannot get tensor slice".to_string(),
context: None,
})?;
let mut histogram = vec![0u32; self.bins];
let bin_width = (self.max_val - self.min_val) / self.bins as f32;
#[cfg(target_arch = "x86_64")]
if self.use_simd && data.len() >= 8 {
self.compute_simd_f32(data, &mut histogram, bin_width);
} else {
self.compute_scalar(data, &mut histogram, bin_width);
}
#[cfg(not(target_arch = "x86_64"))]
self.compute_scalar(data, &mut histogram, bin_width);
Ok(histogram)
}
#[cfg(target_arch = "x86_64")]
fn compute_simd_f32(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
unsafe {
let min_vec = _mm256_set1_ps(self.min_val);
let max_vec = _mm256_set1_ps(self.max_val);
let bin_width_vec = _mm256_set1_ps(bin_width);
let bins_minus_one = _mm256_set1_epi32((self.bins - 1) as i32);
let zero_vec = _mm256_setzero_si256();
let chunks = data.chunks_exact(8);
let remainder = chunks.remainder();
for chunk in chunks {
let values = _mm256_loadu_ps(chunk.as_ptr());
let clamped = _mm256_max_ps(_mm256_min_ps(values, max_vec), min_vec);
let normalized = _mm256_sub_ps(clamped, min_vec);
let bin_indices_f = _mm256_div_ps(normalized, bin_width_vec);
let bin_indices = _mm256_cvttps_epi32(bin_indices_f);
let clamped_indices =
_mm256_max_epi32(_mm256_min_epi32(bin_indices, bins_minus_one), zero_vec);
let indices: [i32; 8] = std::mem::transmute(clamped_indices);
for &idx in &indices {
histogram[idx as usize] += 1;
}
}
self.compute_scalar(remainder, histogram, bin_width);
}
}
fn compute_scalar(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
for &value in data {
let clamped = value.clamp(self.min_val, self.max_val);
let bin_idx = if clamped == self.max_val {
self.bins - 1
} else {
((clamped - self.min_val) / bin_width) as usize
};
let bin_idx = bin_idx.min(self.bins - 1);
histogram[bin_idx] += 1;
}
}
}
impl Transform<f32> for SimdHistogram {
    /// Pass-through transform: returns the sample exactly as it was given.
    ///
    /// No histogram is computed here; use `compute` for that.
    fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let (features, labels) = sample;
        Ok((features, labels))
    }
}
/// Thin wrapper that owns a `SimdHistogram` and exposes histogram
/// computation on `f32` tensors.
pub struct SimdHistogramTransform {
/// The histogram engine that `apply_with_histogram` delegates to.
histogram_computer: SimdHistogram,
}
impl SimdHistogramTransform {
    /// Creates a wrapper around a `SimdHistogram` built from the same arguments.
    pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
        let histogram_computer = SimdHistogram::new(bins, min_val, max_val);
        Self { histogram_computer }
    }

    /// Computes the histogram of `input` by delegating to the wrapped
    /// `SimdHistogram::compute`, returning its result unchanged.
    pub fn apply_with_histogram(&self, input: &Tensor<f32>) -> Result<Vec<u32>> {
        let engine = &self.histogram_computer;
        engine.compute(input)
    }
}