Skip to main content

scirs2_stats/
descriptive_simd.rs

//! SIMD-optimized descriptive statistics functions
//!
//! This module provides SIMD-accelerated implementations of common
//! statistical functions using scirs2-core's unified SIMD operations.
5
6use crate::error::StatsResult;
7use crate::error_standardization::ErrorMessages;
8use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
11
12/// Calculate the mean of an array using SIMD operations when available
13///
14/// This function automatically selects the best implementation based on:
15/// - Array size
16/// - Available SIMD capabilities
17/// - Data alignment
18///
19/// # Arguments
20///
21/// * `x` - Input data array
22///
23/// # Returns
24///
25/// * The arithmetic mean of the input data
26///
27/// # Examples
28///
29/// ```
30/// use scirs2_core::ndarray::array;
31/// use scirs2_stats::mean_simd;
32///
33/// let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
34/// let mean = mean_simd(&data.view()).expect("Operation failed");
35/// assert!((mean - 3.0_f64).abs() < 1e-10);
36/// ```
37#[allow(dead_code)]
38pub fn mean_simd<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<F>
39where
40    F: Float + NumCast + SimdUnifiedOps,
41    D: Data<Elem = F>,
42{
43    if x.is_empty() {
44        return Err(ErrorMessages::empty_array("x"));
45    }
46
47    let n = x.len();
48    let optimizer = AutoOptimizer::new();
49
50    // Let the optimizer decide the best approach
51    let sum = if optimizer.should_use_simd(n) {
52        // Use SIMD operations for sum
53        F::simd_sum(&x.view())
54    } else {
55        // Fallback to scalar sum for small arrays
56        x.iter().fold(F::zero(), |acc, &val| acc + val)
57    };
58
59    Ok(sum / F::from(n).expect("Failed to convert to float"))
60}
61
62/// Calculate variance using SIMD operations
63///
64/// Computes the variance using Welford's algorithm with SIMD acceleration
65/// for better numerical stability.
66///
67/// # Arguments
68///
69/// * `x` - Input data array
70/// * `ddof` - Delta degrees of freedom (0 for population, 1 for sample)
71///
72/// # Returns
73///
74/// * The variance of the input data
75#[allow(dead_code)]
76pub fn variance_simd<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
77where
78    F: Float + NumCast + SimdUnifiedOps,
79    D: Data<Elem = F>,
80{
81    let n = x.len();
82    if n <= ddof {
83        return Err(ErrorMessages::insufficientdata(
84            "variance calculation",
85            ddof + 1,
86            n,
87        ));
88    }
89
90    // First compute the mean
91    let mean = mean_simd(x)?;
92
93    // Use SIMD to compute sum of squared deviations
94    let optimizer = AutoOptimizer::new();
95
96    let sum_sq_dev = if optimizer.should_use_simd(n) {
97        // Create a constant array filled with mean for SIMD subtraction
98        let mean_array = scirs2_core::ndarray::Array1::from_elem(x.len(), mean);
99
100        // Compute (x - mean)
101        let deviations = F::simd_sub(&x.view(), &mean_array.view());
102
103        // Compute (x - mean)² using element-wise multiplication
104        let squared_devs = F::simd_mul(&deviations.view(), &deviations.view());
105        F::simd_sum(&squared_devs.view())
106    } else {
107        // Scalar fallback
108        x.iter()
109            .map(|&val| {
110                let dev = val - mean;
111                dev * dev
112            })
113            .fold(F::zero(), |acc, val| acc + val)
114    };
115
116    Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
117}
118
119/// Calculate standard deviation using SIMD operations
120///
121/// # Arguments
122///
123/// * `x` - Input data array
124/// * `ddof` - Delta degrees of freedom (0 for population, 1 for sample)
125///
126/// # Returns
127///
128/// * The standard deviation of the input data
129#[allow(dead_code)]
130pub fn std_simd<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
131where
132    F: Float + NumCast + SimdUnifiedOps,
133    D: Data<Elem = F>,
134{
135    variance_simd(x, ddof).map(|var| var.sqrt())
136}
137
138/// Calculate multiple descriptive statistics in a single pass using SIMD
139///
140/// This function efficiently computes mean, variance, min, and max
141/// in a single pass through the data.
142///
143/// # Arguments
144///
145/// * `x` - Input data array
146///
147/// # Returns
148///
149/// * A tuple containing (mean, variance, min, max)
150#[allow(dead_code)]
151pub fn descriptive_stats_simd<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F)>
152where
153    F: Float + NumCast + SimdUnifiedOps,
154    D: Data<Elem = F>,
155{
156    if x.is_empty() {
157        return Err(crate::error::StatsError::InvalidArgument(
158            "Cannot compute statistics of empty array".to_string(),
159        ));
160    }
161
162    let n = x.len();
163    let capabilities = PlatformCapabilities::detect();
164    let optimizer = AutoOptimizer::new();
165
166    if optimizer.should_use_simd(n) && capabilities.simd_available {
167        // Use SIMD operations for all statistics
168        let sum = F::simd_sum(&x.view());
169        let mean = sum / F::from(n).expect("Failed to convert to float");
170
171        // For min/max, we use element reduction operations
172        let min = F::simd_min_element(&x.view());
173        let max = F::simd_max_element(&x.view());
174
175        // Variance calculation
176        let mean_array = scirs2_core::ndarray::Array1::from_elem(x.len(), mean);
177        let deviations = F::simd_sub(&x.view(), &mean_array.view());
178        let squared_devs = F::simd_mul(&deviations.view(), &deviations.view());
179        let sum_sq_dev = F::simd_sum(&squared_devs.view());
180        let variance = sum_sq_dev / F::from(n - 1).expect("Failed to convert to float");
181
182        Ok((mean, variance, min, max))
183    } else {
184        // Scalar fallback with single-pass algorithm
185        let mut sum = F::zero();
186        let mut sum_sq = F::zero();
187        let mut min = x[0];
188        let mut max = x[0];
189
190        for &val in x.iter() {
191            sum = sum + val;
192            sum_sq = sum_sq + val * val;
193            if val < min {
194                min = val;
195            }
196            if val > max {
197                max = val;
198            }
199        }
200
201        let mean = sum / F::from(n).expect("Failed to convert to float");
202        let variance = (sum_sq - sum * sum / F::from(n).expect("Failed to convert to float"))
203            / F::from(n - 1).expect("Failed to convert to float");
204
205        Ok((mean, variance, min, max))
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mean_simd() {
        let values = array![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let got = mean_simd(&values.view()).expect("Operation failed");
        assert_relative_eq!(got, 4.5, epsilon = 1e-10);
    }

    #[test]
    fn test_variance_simd() {
        // Mean is 5; squared deviations sum to 9+1+1+1+0+0+4+16 = 32.
        // Sample variance with ddof=1 is therefore 32 / 7.
        let values = array![2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let got = variance_simd(&values.view(), 1).expect("Operation failed");
        let expected = 32.0 / 7.0;
        assert_relative_eq!(got, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_descriptive_stats_simd() {
        let values = array![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max) =
            descriptive_stats_simd(&values.view()).expect("Operation failed");

        // (actual, expected); the variance is the sample (ddof = 1) variance.
        let checks = [(mean, 3.0), (var, 2.5), (min, 1.0), (max, 5.0)];
        for &(actual, expected) in checks.iter() {
            assert_relative_eq!(actual, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_simd_consistency() {
        // The SIMD-dispatching mean must agree with the plain descriptive mean.
        let values = array![1.1_f64, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9];

        let from_simd = mean_simd(&values.view()).expect("Operation failed");
        let from_scalar = crate::descriptive::mean(&values.view()).expect("Operation failed");

        assert_relative_eq!(from_simd, from_scalar, epsilon = 1e-10);
    }
}