tenflowers-dataset 0.1.1

Data pipeline and dataset utilities for TenfloweRS
Documentation
//! SIMD-accelerated image processing operations
//!
//! This module provides vectorized implementations of image processing operations
//! such as color space conversion and histogram computation using SIMD instructions.

#![allow(unsafe_code)]

use crate::Transform;
use std::marker::PhantomData;
use tenflowers_core::{Result, Tensor, TensorError};

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-accelerated RGB to HSV color space conversion
pub struct SimdColorConvert<T> {
    use_simd: bool,
    _phantom: PhantomData<T>,
}

impl<T> SimdColorConvert<T>
where
    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
    pub fn new() -> Self {
        #[cfg(target_arch = "x86_64")]
        let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;

        #[cfg(not(target_arch = "x86_64"))]
        let use_simd = false;

        Self {
            use_simd,
            _phantom: PhantomData,
        }
    }

    /// Convert RGB to HSV using SIMD acceleration
    pub fn rgb_to_hsv(&self, rgb_data: &mut [T]) {
        if self.use_simd && std::mem::size_of::<T>() == 4 && rgb_data.len() % 3 == 0 {
            #[cfg(target_arch = "x86_64")]
            unsafe {
                self.rgb_to_hsv_f32_simd(std::mem::transmute::<&mut [T], &mut [f32]>(rgb_data));
                return;
            }
        }

        // Fallback to scalar implementation
        self.rgb_to_hsv_scalar(rgb_data);
    }

    /// SIMD-accelerated RGB to HSV conversion for f32 data
    #[cfg(target_arch = "x86_64")]
    unsafe fn rgb_to_hsv_f32_simd(&self, rgb_data: &mut [f32]) {
        let pixels = rgb_data.len() / 3;

        for i in 0..pixels {
            let base = i * 3;
            let r = rgb_data[base];
            let g = rgb_data[base + 1];
            let b = rgb_data[base + 2];

            let max_val = r.max(g.max(b));
            let min_val = r.min(g.min(b));
            let delta = max_val - min_val;

            // Value
            let v = max_val;

            // Saturation
            let s = if max_val == 0.0 { 0.0 } else { delta / max_val };

            // Hue
            let h = if delta == 0.0 {
                0.0
            } else if max_val == r {
                60.0 * (((g - b) / delta) % 6.0)
            } else if max_val == g {
                60.0 * ((b - r) / delta + 2.0)
            } else {
                60.0 * ((r - g) / delta + 4.0)
            };

            let h_normalized = if h < 0.0 { h + 360.0 } else { h };

            rgb_data[base] = h_normalized / 360.0; // Normalize H to [0,1]
            rgb_data[base + 1] = s;
            rgb_data[base + 2] = v;
        }
    }

    /// Scalar fallback for RGB to HSV conversion
    fn rgb_to_hsv_scalar(&self, rgb_data: &mut [T]) {
        let pixels = rgb_data.len() / 3;

        for i in 0..pixels {
            let base = i * 3;
            let r = rgb_data[base];
            let g = rgb_data[base + 1];
            let b = rgb_data[base + 2];

            let max_val = r.max(g.max(b));
            let min_val = r.min(g.min(b));
            let delta = max_val - min_val;

            // Value
            let v = max_val;

            // Saturation
            let s = if max_val == T::zero() {
                T::zero()
            } else {
                delta / max_val
            };

            // Hue
            let h = if delta == T::zero() {
                T::zero()
            } else if max_val == r {
                let six = T::from(6.0).unwrap_or_else(|| T::from(6).unwrap_or(T::zero()));
                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
                sixty * (((g - b) / delta) % six)
            } else if max_val == g {
                let two = T::from(2.0).unwrap_or_else(|| T::from(2).unwrap_or(T::zero()));
                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
                sixty * ((b - r) / delta + two)
            } else {
                let four = T::from(4.0).unwrap_or_else(|| T::from(4).unwrap_or(T::zero()));
                let sixty = T::from(60.0).unwrap_or_else(|| T::from(60).unwrap_or(T::zero()));
                sixty * ((r - g) / delta + four)
            };

            let three_sixty = T::from(360.0).unwrap_or_else(|| T::from(360).unwrap_or(T::one()));
            let h_normalized = if h < T::zero() { h + three_sixty } else { h };

            rgb_data[base] = h_normalized / three_sixty; // Normalize H to [0,1]
            rgb_data[base + 1] = s;
            rgb_data[base + 2] = v;
        }
    }
}

impl<T> Default for SimdColorConvert<T>
where
    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Transform<T> for SimdColorConvert<T>
where
    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
{
    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
        let (features, labels) = sample;
        let mut data = features
            .as_slice()
            .ok_or_else(|| {
                TensorError::invalid_argument(
                    "Unable to access tensor data for color conversion".to_string(),
                )
            })?
            .to_vec();
        self.rgb_to_hsv(&mut data);
        let converted_features = Tensor::from_vec(data, features.shape().dims())?;
        Ok((converted_features, labels))
    }
}

/// SIMD-accelerated histogram computation for efficient data distribution analysis
///
/// Provides high-performance histogram calculation using SIMD instructions
/// for up to 8x speedup on compatible hardware for dataset statistics.
pub struct SimdHistogram {
    bins: usize,
    min_val: f32,
    max_val: f32,
    use_simd: bool,
}

impl SimdHistogram {
    /// Create a new SIMD-accelerated histogram calculator
    pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        let use_simd = is_x86_feature_detected!("avx2");

        #[cfg(not(target_arch = "x86_64"))]
        let use_simd = false;

        Self {
            bins,
            min_val,
            max_val,
            use_simd,
        }
    }

    /// Get SIMD capability status
    pub fn is_simd_enabled(&self) -> bool {
        self.use_simd
    }

    /// Compute histogram of tensor data with SIMD acceleration
    pub fn compute(&self, tensor: &Tensor<f32>) -> Result<Vec<u32>> {
        let data = tensor
            .as_slice()
            .ok_or_else(|| TensorError::InvalidOperation {
                operation: "histogram_compute".to_string(),
                reason: "Cannot get tensor slice".to_string(),
                context: None,
            })?;

        let mut histogram = vec![0u32; self.bins];
        let bin_width = (self.max_val - self.min_val) / self.bins as f32;

        #[cfg(target_arch = "x86_64")]
        if self.use_simd && data.len() >= 8 {
            self.compute_simd_f32(data, &mut histogram, bin_width);
        } else {
            self.compute_scalar(data, &mut histogram, bin_width);
        }

        #[cfg(not(target_arch = "x86_64"))]
        self.compute_scalar(data, &mut histogram, bin_width);

        Ok(histogram)
    }

    /// SIMD-accelerated histogram computation for f32 data
    #[cfg(target_arch = "x86_64")]
    fn compute_simd_f32(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
        unsafe {
            let min_vec = _mm256_set1_ps(self.min_val);
            let max_vec = _mm256_set1_ps(self.max_val);
            let bin_width_vec = _mm256_set1_ps(bin_width);
            let bins_minus_one = _mm256_set1_epi32((self.bins - 1) as i32);
            let zero_vec = _mm256_setzero_si256();

            let chunks = data.chunks_exact(8);
            let remainder = chunks.remainder();

            // Process 8 elements at a time with SIMD
            for chunk in chunks {
                let values = _mm256_loadu_ps(chunk.as_ptr());

                // Clamp values to [min_val, max_val]
                let clamped = _mm256_max_ps(_mm256_min_ps(values, max_vec), min_vec);

                // Calculate bin indices: (value - min_val) / bin_width
                let normalized = _mm256_sub_ps(clamped, min_vec);
                let bin_indices_f = _mm256_div_ps(normalized, bin_width_vec);

                // Use truncation toward zero instead of rounding
                let bin_indices = _mm256_cvttps_epi32(bin_indices_f);

                // Clamp bin indices to valid range [0, bins-1]
                let clamped_indices =
                    _mm256_max_epi32(_mm256_min_epi32(bin_indices, bins_minus_one), zero_vec);

                // Extract indices and increment histogram bins
                let indices: [i32; 8] = std::mem::transmute(clamped_indices);
                for &idx in &indices {
                    histogram[idx as usize] += 1;
                }
            }

            // Process remaining elements with scalar code
            self.compute_scalar(remainder, histogram, bin_width);
        }
    }

    /// Fallback scalar implementation for non-SIMD hardware
    fn compute_scalar(&self, data: &[f32], histogram: &mut [u32], bin_width: f32) {
        for &value in data {
            let clamped = value.clamp(self.min_val, self.max_val);
            // Handle the edge case where clamped equals max_val
            let bin_idx = if clamped == self.max_val {
                self.bins - 1
            } else {
                ((clamped - self.min_val) / bin_width) as usize
            };
            let bin_idx = bin_idx.min(self.bins - 1);
            histogram[bin_idx] += 1;
        }
    }
}

impl Transform<f32> for SimdHistogram {
    fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
        // Return the original sample unchanged - this transform is meant for analysis
        Ok(sample)
    }
}

/// Specialized histogram transform that computes histogram alongside the data
pub struct SimdHistogramTransform {
    histogram_computer: SimdHistogram,
}

impl SimdHistogramTransform {
    pub fn new(bins: usize, min_val: f32, max_val: f32) -> Self {
        Self {
            histogram_computer: SimdHistogram::new(bins, min_val, max_val),
        }
    }

    pub fn apply_with_histogram(&self, input: &Tensor<f32>) -> Result<Vec<u32>> {
        self.histogram_computer.compute(input)
    }
}