Skip to main content

oxigdal_algorithms/simd/
statistics.rs

1//! SIMD-accelerated statistical operations
2//!
3//! This module provides high-performance statistical computations on raster data
4//! using architecture-specific SIMD intrinsics for horizontal reductions and aggregations.
5//!
6//! # Architecture Support
7//!
8//! - **aarch64**: NEON (128-bit) for parallel accumulation and comparison
9//! - **x86-64**: SSE2 (baseline), AVX2 (runtime detected) for wider operations
10//! - **Other**: Scalar fallback with auto-vectorization hints
11//!
12//! # Supported Operations
13//!
14//! - **Reductions**: sum, mean, variance, standard deviation
15//! - **Extrema**: min, max, argmin, argmax, minmax (single-pass)
16//! - **Percentiles**: median, quartiles, arbitrary percentiles
17//! - **Histograms**: Fast histogram computation with SIMD bucketing
18//!
19//! # Performance
20//!
21//! Expected speedup over scalar: 4-8x for most operations
22//!
23//! # Example
24//!
25//! ```rust
26//! use oxigdal_algorithms::simd::statistics::{sum_f32, mean_f32, minmax_f32};
27//! # use oxigdal_algorithms::error::Result;
28//!
29//! # fn main() -> Result<()> {
30//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
31//!
32//! let sum = sum_f32(&data);
33//! let mean = mean_f32(&data)?;
34//! let (min, max) = minmax_f32(&data)?;
35//!
36//! assert_eq!(sum, 15.0);
37//! assert_eq!(mean, 3.0);
38//! assert_eq!(min, 1.0);
39//! assert_eq!(max, 5.0);
40//! # Ok(())
41//! # }
42//! ```
43
44#![allow(unsafe_code)]
45
46use crate::error::{AlgorithmError, Result};
47
48// ============================================================================
49// Architecture-specific SIMD implementations for reductions
50// ============================================================================
51
#[cfg(target_arch = "aarch64")]
/// NEON (128-bit) kernels for the `f32` reductions.
///
/// Each kernel accepts a slice of any length: whole groups of vectors are
/// processed with intrinsics and the leftover elements are folded in with
/// scalar code. The extrema kernels start from the true identities of
/// min/max (`+inf` / `-inf`), so slices that contain only infinities are
/// reported correctly; for an empty slice they return those identities.
///
/// NOTE(review): `vminq_f32`/`vmaxq_f32` and the scalar tail comparisons may
/// treat NaN differently — confirm the intended NaN policy with callers.
mod neon_impl {
    use std::arch::aarch64::*;

    /// NEON horizontal sum of float32x4_t -> f32
    #[inline(always)]
    unsafe fn hsum_f32(v: float32x4_t) -> f32 {
        // SAFETY: register-only intrinsics; NEON is part of the aarch64 baseline.
        unsafe {
            // vpaddq_f32: pairwise add [a0+a1, a2+a3, a0+a1, a2+a3]
            let pair = vpaddq_f32(v, v);
            // Another pairwise add leaves the full sum in every lane
            let sum = vpaddq_f32(pair, pair);
            vgetq_lane_f32(sum, 0)
        }
    }

    /// NEON horizontal min of float32x4_t -> f32
    #[inline(always)]
    unsafe fn hmin_f32(v: float32x4_t) -> f32 {
        // SAFETY: register-only intrinsics; NEON is part of the aarch64 baseline.
        unsafe {
            let pair = vpminq_f32(v, v);
            let min = vpminq_f32(pair, pair);
            vgetq_lane_f32(min, 0)
        }
    }

    /// NEON horizontal max of float32x4_t -> f32
    #[inline(always)]
    unsafe fn hmax_f32(v: float32x4_t) -> f32 {
        // SAFETY: register-only intrinsics; NEON is part of the aarch64 baseline.
        unsafe {
            let pair = vpmaxq_f32(v, v);
            let max = vpmaxq_f32(pair, pair);
            vgetq_lane_f32(max, 0)
        }
    }

    /// NEON-accelerated sum with 4-way accumulation.
    ///
    /// Returns `0.0` for an empty slice.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON (always true on aarch64 targets).
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn sum_f32(data: &[f32]) -> f32 {
        // SAFETY: every vector load reads 4 floats starting at an offset of at
        // most `chunks * 16 - 4`, which is within `data`; tail reads use `i < len`.
        unsafe {
            let len = data.len();
            let ptr = data.as_ptr();
            let chunks = len / 16; // Process 16 elements per iteration (4 accumulators)

            // Use 4 independent accumulators to hide latency
            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);
            let mut acc2 = vdupq_n_f32(0.0);
            let mut acc3 = vdupq_n_f32(0.0);

            for i in 0..chunks {
                let off = i * 16;
                acc0 = vaddq_f32(acc0, vld1q_f32(ptr.add(off)));
                acc1 = vaddq_f32(acc1, vld1q_f32(ptr.add(off + 4)));
                acc2 = vaddq_f32(acc2, vld1q_f32(ptr.add(off + 8)));
                acc3 = vaddq_f32(acc3, vld1q_f32(ptr.add(off + 12)));
            }

            // Combine accumulators
            let sum01 = vaddq_f32(acc0, acc1);
            let sum23 = vaddq_f32(acc2, acc3);
            let sum_all = vaddq_f32(sum01, sum23);

            let mut total = hsum_f32(sum_all);

            // Handle remainder
            let rem = chunks * 16;
            for i in rem..len {
                total += *ptr.add(i);
            }

            total
        }
    }

    /// NEON-accelerated min with 4-way comparison.
    ///
    /// Returns `f32::INFINITY` (the identity of min) for an empty slice.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON (always true on aarch64 targets).
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn min_f32(data: &[f32]) -> f32 {
        // SAFETY: every vector load reads 4 floats starting at an offset of at
        // most `chunks * 16 - 4`, which is within `data`; tail reads use `i < len`.
        unsafe {
            let len = data.len();
            let ptr = data.as_ptr();
            let chunks = len / 16;

            // Seed with +inf rather than f32::MAX so that data consisting only
            // of +inf is not reported as f32::MAX.
            let mut min0 = vdupq_n_f32(f32::INFINITY);
            let mut min1 = vdupq_n_f32(f32::INFINITY);
            let mut min2 = vdupq_n_f32(f32::INFINITY);
            let mut min3 = vdupq_n_f32(f32::INFINITY);

            for i in 0..chunks {
                let off = i * 16;
                min0 = vminq_f32(min0, vld1q_f32(ptr.add(off)));
                min1 = vminq_f32(min1, vld1q_f32(ptr.add(off + 4)));
                min2 = vminq_f32(min2, vld1q_f32(ptr.add(off + 8)));
                min3 = vminq_f32(min3, vld1q_f32(ptr.add(off + 12)));
            }

            let min01 = vminq_f32(min0, min1);
            let min23 = vminq_f32(min2, min3);
            let min_all = vminq_f32(min01, min23);

            let mut min_val = hmin_f32(min_all);

            let rem = chunks * 16;
            for i in rem..len {
                let v = *ptr.add(i);
                if v < min_val {
                    min_val = v;
                }
            }

            min_val
        }
    }

    /// NEON-accelerated max with 4-way comparison.
    ///
    /// Returns `f32::NEG_INFINITY` (the identity of max) for an empty slice.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON (always true on aarch64 targets).
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn max_f32(data: &[f32]) -> f32 {
        // SAFETY: every vector load reads 4 floats starting at an offset of at
        // most `chunks * 16 - 4`, which is within `data`; tail reads use `i < len`.
        unsafe {
            let len = data.len();
            let ptr = data.as_ptr();
            let chunks = len / 16;

            // Seed with -inf rather than f32::MIN so that data consisting only
            // of -inf is not reported as f32::MIN.
            let mut max0 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max1 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max2 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max3 = vdupq_n_f32(f32::NEG_INFINITY);

            for i in 0..chunks {
                let off = i * 16;
                max0 = vmaxq_f32(max0, vld1q_f32(ptr.add(off)));
                max1 = vmaxq_f32(max1, vld1q_f32(ptr.add(off + 4)));
                max2 = vmaxq_f32(max2, vld1q_f32(ptr.add(off + 8)));
                max3 = vmaxq_f32(max3, vld1q_f32(ptr.add(off + 12)));
            }

            let max01 = vmaxq_f32(max0, max1);
            let max23 = vmaxq_f32(max2, max3);
            let max_all = vmaxq_f32(max01, max23);

            let mut max_val = hmax_f32(max_all);

            let rem = chunks * 16;
            for i in rem..len {
                let v = *ptr.add(i);
                if v > max_val {
                    max_val = v;
                }
            }

            max_val
        }
    }

    /// NEON-accelerated minmax (single pass over the data).
    ///
    /// Returns `(f32::INFINITY, f32::NEG_INFINITY)` for an empty slice.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON (always true on aarch64 targets).
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn minmax_f32(data: &[f32]) -> (f32, f32) {
        // SAFETY: every vector load reads 4 floats starting at an offset of at
        // most `chunks * 8 - 4`, which is within `data`; tail reads use `i < len`.
        unsafe {
            let len = data.len();
            let ptr = data.as_ptr();
            let chunks = len / 8;

            // Identities of min/max, so infinite samples are preserved.
            let mut vmin0 = vdupq_n_f32(f32::INFINITY);
            let mut vmin1 = vdupq_n_f32(f32::INFINITY);
            let mut vmax0 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut vmax1 = vdupq_n_f32(f32::NEG_INFINITY);

            for i in 0..chunks {
                let off = i * 8;
                let a = vld1q_f32(ptr.add(off));
                let b = vld1q_f32(ptr.add(off + 4));
                vmin0 = vminq_f32(vmin0, a);
                vmin1 = vminq_f32(vmin1, b);
                vmax0 = vmaxq_f32(vmax0, a);
                vmax1 = vmaxq_f32(vmax1, b);
            }

            let vmin_all = vminq_f32(vmin0, vmin1);
            let vmax_all = vmaxq_f32(vmax0, vmax1);

            let mut min_val = hmin_f32(vmin_all);
            let mut max_val = hmax_f32(vmax_all);

            let rem = chunks * 8;
            for i in rem..len {
                let v = *ptr.add(i);
                if v < min_val {
                    min_val = v;
                }
                if v > max_val {
                    max_val = v;
                }
            }

            (min_val, max_val)
        }
    }

    /// NEON-accelerated sum of squared deviations from `mean`.
    ///
    /// This is the second pass of a two-pass variance: the caller supplies the
    /// mean and divides the returned sum by the element count.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON (always true on aarch64 targets).
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn variance_f32(data: &[f32], mean: f32) -> f32 {
        // SAFETY: every vector load reads 4 floats starting at an offset of at
        // most `chunks * 8 - 4`, which is within `data`; tail reads use `i < len`.
        unsafe {
            let len = data.len();
            let ptr = data.as_ptr();
            let chunks = len / 8;
            let vmean = vdupq_n_f32(mean);

            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);

            for i in 0..chunks {
                let off = i * 8;
                let a = vsubq_f32(vld1q_f32(ptr.add(off)), vmean);
                let b = vsubq_f32(vld1q_f32(ptr.add(off + 4)), vmean);
                // FMA: acc += diff * diff
                acc0 = vfmaq_f32(acc0, a, a);
                acc1 = vfmaq_f32(acc1, b, b);
            }

            let sum_vec = vaddq_f32(acc0, acc1);
            let mut sum_sq = hsum_f32(sum_vec);

            let rem = chunks * 8;
            for i in rem..len {
                let diff = *ptr.add(i) - mean;
                sum_sq += diff * diff;
            }

            sum_sq
        }
    }
}
283
/// Portable scalar fallbacks for the `f32` reductions.
///
/// Each kernel keeps `LANES` independent accumulators so the compiler can
/// auto-vectorize the main loop; elements that do not fill a whole group are
/// folded in afterwards. For an empty slice the kernels return the identity
/// of their reduction: `0.0` for sums, `+inf` for min, `-inf` for max.
mod scalar_impl {
    /// Number of independent accumulators used by every kernel.
    const LANES: usize = 8;

    /// Sums all elements using lane-wise partial sums.
    pub(crate) fn sum_f32(data: &[f32]) -> f32 {
        let mut partial = [0.0_f32; LANES];
        let mut groups = data.chunks_exact(LANES);

        for group in groups.by_ref() {
            for (slot, &value) in partial.iter_mut().zip(group) {
                *slot += value;
            }
        }

        let mut total: f32 = partial.iter().sum();
        for &value in groups.remainder() {
            total += value;
        }
        total
    }

    /// Returns the smallest element; NaN values are skipped by `f32::min`.
    ///
    /// The lanes are seeded with `+inf` (the identity of min) instead of
    /// `f32::MAX`, so a slice holding only `+inf` yields `+inf`.
    pub(crate) fn min_f32(data: &[f32]) -> f32 {
        let mut mins = [f32::INFINITY; LANES];
        let mut groups = data.chunks_exact(LANES);

        for group in groups.by_ref() {
            for (slot, &value) in mins.iter_mut().zip(group) {
                *slot = slot.min(value);
            }
        }

        let lane_min = mins.iter().copied().fold(f32::INFINITY, f32::min);
        groups.remainder().iter().copied().fold(lane_min, f32::min)
    }

    /// Returns the largest element; NaN values are skipped by `f32::max`.
    ///
    /// The lanes are seeded with `-inf` (the identity of max) instead of
    /// `f32::MIN`, so a slice holding only `-inf` yields `-inf`.
    pub(crate) fn max_f32(data: &[f32]) -> f32 {
        let mut maxs = [f32::NEG_INFINITY; LANES];
        let mut groups = data.chunks_exact(LANES);

        for group in groups.by_ref() {
            for (slot, &value) in maxs.iter_mut().zip(group) {
                *slot = slot.max(value);
            }
        }

        let lane_max = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        groups.remainder().iter().copied().fold(lane_max, f32::max)
    }

    /// Returns `(min, max)` in a single traversal of `data`.
    ///
    /// Uses the same `+inf` / `-inf` seeds as [`min_f32`] and [`max_f32`].
    pub(crate) fn minmax_f32(data: &[f32]) -> (f32, f32) {
        let mut mins = [f32::INFINITY; LANES];
        let mut maxs = [f32::NEG_INFINITY; LANES];
        let mut groups = data.chunks_exact(LANES);

        for group in groups.by_ref() {
            for ((lo, hi), &value) in mins.iter_mut().zip(maxs.iter_mut()).zip(group) {
                *lo = lo.min(value);
                *hi = hi.max(value);
            }
        }

        let mut min_val = mins.iter().copied().fold(f32::INFINITY, f32::min);
        let mut max_val = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        for &value in groups.remainder() {
            min_val = min_val.min(value);
            max_val = max_val.max(value);
        }
        (min_val, max_val)
    }

    /// Returns the sum of squared deviations of `data` from `mean`.
    ///
    /// The caller divides by the element count to obtain the variance.
    pub(crate) fn variance_f32(data: &[f32], mean: f32) -> f32 {
        let mut partial = [0.0_f32; LANES];
        let mut groups = data.chunks_exact(LANES);

        for group in groups.by_ref() {
            for (slot, &value) in partial.iter_mut().zip(group) {
                let diff = value - mean;
                *slot += diff * diff;
            }
        }

        let mut sum_squared_diff: f32 = partial.iter().sum();
        for &value in groups.remainder() {
            let diff = value - mean;
            sum_squared_diff += diff * diff;
        }
        sum_squared_diff
    }
}
413
414// ============================================================================
415// Public API - safe wrappers with SIMD dispatch
416// ============================================================================
417
418/// Compute the sum of all elements using SIMD horizontal reduction
419///
420/// Uses 4-way NEON accumulation on aarch64 or multi-accumulator scalar on other platforms.
421/// Processes 16 elements per iteration on NEON for optimal throughput.
422///
423/// # Performance
424///
425/// This uses a tree reduction pattern for efficient SIMD accumulation.
426#[must_use]
427pub fn sum_f32(data: &[f32]) -> f32 {
428    if data.is_empty() {
429        return 0.0;
430    }
431
432    #[cfg(target_arch = "aarch64")]
433    {
434        // SAFETY: NEON always available on aarch64
435        unsafe { neon_impl::sum_f32(data) }
436    }
437
438    #[cfg(not(target_arch = "aarch64"))]
439    {
440        scalar_impl::sum_f32(data)
441    }
442}
443
/// Compute the sum of all elements (f64 version)
///
/// Accumulates into four independent lane-wise partial sums so the compiler
/// can auto-vectorize the main loop, then adds the partial sums and any
/// leftover elements. This is plain (uncompensated) floating-point addition:
/// no Kahan-style error correction is applied.
///
/// Returns `0.0` for an empty slice.
#[must_use]
pub fn sum_f64(data: &[f64]) -> f64 {
    const LANES: usize = 4;

    let mut partial = [0.0_f64; LANES];
    let mut groups = data.chunks_exact(LANES);

    // `chunks_exact` guarantees every group has exactly LANES elements, which
    // removes the per-element bounds checks of an index-based loop.
    for group in groups.by_ref() {
        for (slot, &value) in partial.iter_mut().zip(group) {
            *slot += value;
        }
    }

    let mut total: f64 = partial.iter().sum();

    // Elements that did not fill a whole group.
    for &value in groups.remainder() {
        total += value;
    }

    total
}
470
471/// Compute the mean (average) of all elements
472///
473/// # Errors
474///
475/// Returns an error if the slice is empty
476pub fn mean_f32(data: &[f32]) -> Result<f32> {
477    if data.is_empty() {
478        return Err(AlgorithmError::InvalidParameter {
479            parameter: "input",
480            message: "Cannot compute mean of empty slice".to_string(),
481        });
482    }
483
484    let sum = sum_f32(data);
485    Ok(sum / data.len() as f32)
486}
487
488/// Compute the mean (average) of all elements (f64 version)
489pub fn mean_f64(data: &[f64]) -> Result<f64> {
490    if data.is_empty() {
491        return Err(AlgorithmError::InvalidParameter {
492            parameter: "input",
493            message: "Cannot compute mean of empty slice".to_string(),
494        });
495    }
496
497    let sum = sum_f64(data);
498    Ok(sum / data.len() as f64)
499}
500
501/// Find the minimum value in the slice using SIMD comparison
502///
503/// Uses NEON vminq_f32 on aarch64 for 4x parallel comparison.
504///
505/// # Errors
506///
507/// Returns an error if the slice is empty
508pub fn min_f32(data: &[f32]) -> Result<f32> {
509    if data.is_empty() {
510        return Err(AlgorithmError::InvalidParameter {
511            parameter: "input",
512            message: "Cannot find min of empty slice".to_string(),
513        });
514    }
515
516    #[cfg(target_arch = "aarch64")]
517    {
518        // SAFETY: NEON always available on aarch64
519        unsafe { Ok(neon_impl::min_f32(data)) }
520    }
521
522    #[cfg(not(target_arch = "aarch64"))]
523    {
524        Ok(scalar_impl::min_f32(data))
525    }
526}
527
528/// Find the maximum value in the slice using SIMD comparison
529///
530/// Uses NEON vmaxq_f32 on aarch64 for 4x parallel comparison.
531///
532/// # Errors
533///
534/// Returns an error if the slice is empty
535pub fn max_f32(data: &[f32]) -> Result<f32> {
536    if data.is_empty() {
537        return Err(AlgorithmError::InvalidParameter {
538            parameter: "input",
539            message: "Cannot find max of empty slice".to_string(),
540        });
541    }
542
543    #[cfg(target_arch = "aarch64")]
544    {
545        // SAFETY: NEON always available on aarch64
546        unsafe { Ok(neon_impl::max_f32(data)) }
547    }
548
549    #[cfg(not(target_arch = "aarch64"))]
550    {
551        Ok(scalar_impl::max_f32(data))
552    }
553}
554
555/// Find both minimum and maximum values in a single pass using SIMD
556///
557/// This is more efficient than calling `min_f32` and `max_f32` separately,
558/// as it only traverses memory once. On aarch64, uses NEON for parallel min/max.
559///
560/// # Errors
561///
562/// Returns an error if the slice is empty
563pub fn minmax_f32(data: &[f32]) -> Result<(f32, f32)> {
564    if data.is_empty() {
565        return Err(AlgorithmError::InvalidParameter {
566            parameter: "input",
567            message: "Cannot find minmax of empty slice".to_string(),
568        });
569    }
570
571    #[cfg(target_arch = "aarch64")]
572    {
573        // SAFETY: NEON always available on aarch64
574        unsafe { Ok(neon_impl::minmax_f32(data)) }
575    }
576
577    #[cfg(not(target_arch = "aarch64"))]
578    {
579        Ok(scalar_impl::minmax_f32(data))
580    }
581}
582
583/// Compute variance using two-pass algorithm with SIMD acceleration
584///
585/// Pass 1: Compute mean using SIMD sum
586/// Pass 2: Compute sum of squared differences using SIMD FMA
587///
588/// # Errors
589///
590/// Returns an error if the slice is empty
591pub fn variance_f32(data: &[f32]) -> Result<f32> {
592    if data.is_empty() {
593        return Err(AlgorithmError::InvalidParameter {
594            parameter: "input",
595            message: "Cannot compute variance of empty slice".to_string(),
596        });
597    }
598
599    let mean = mean_f32(data)?;
600
601    #[cfg(target_arch = "aarch64")]
602    let sum_sq = {
603        // SAFETY: NEON always available on aarch64
604        unsafe { neon_impl::variance_f32(data, mean) }
605    };
606
607    #[cfg(not(target_arch = "aarch64"))]
608    let sum_sq = scalar_impl::variance_f32(data, mean);
609
610    Ok(sum_sq / data.len() as f32)
611}
612
613/// Compute standard deviation using SIMD-accelerated variance
614///
615/// # Errors
616///
617/// Returns an error if the slice is empty
618pub fn std_dev_f32(data: &[f32]) -> Result<f32> {
619    let var = variance_f32(data)?;
620    Ok(var.sqrt())
621}
622
623/// Compute histogram with specified number of bins
624///
625/// The histogram covers the range [min, max) with equal-width bins.
626///
627/// # Arguments
628///
629/// * `data` - Input data
630/// * `num_bins` - Number of histogram bins
631/// * `min` - Minimum value (inclusive)
632/// * `max` - Maximum value (exclusive)
633///
634/// # Returns
635///
636/// A vector of counts for each bin
637///
638/// # Errors
639///
640/// Returns an error if:
641/// - `num_bins` is 0
642/// - `min >= max`
643/// - Data slice is empty
644pub fn histogram_f32(data: &[f32], num_bins: usize, min: f32, max: f32) -> Result<Vec<usize>> {
645    if num_bins == 0 {
646        return Err(AlgorithmError::InvalidParameter {
647            parameter: "input",
648            message: "Number of bins must be greater than 0".to_string(),
649        });
650    }
651
652    if min >= max {
653        return Err(AlgorithmError::InvalidParameter {
654            parameter: "input",
655            message: "Min must be less than max".to_string(),
656        });
657    }
658
659    if data.is_empty() {
660        return Err(AlgorithmError::InvalidParameter {
661            parameter: "input",
662            message: "Cannot compute histogram of empty slice".to_string(),
663        });
664    }
665
666    let mut bins = vec![0_usize; num_bins];
667    let range = max - min;
668    let inv_bin_width = num_bins as f32 / range;
669
670    // Histogram computation with precomputed inverse bin width
671    // (multiplication is faster than division in the inner loop)
672    for &val in data {
673        if val >= min && val < max {
674            let bin_idx = ((val - min) * inv_bin_width) as usize;
675            let bin_idx = bin_idx.min(num_bins - 1); // Clamp to last bin
676            bins[bin_idx] += 1;
677        }
678    }
679
680    Ok(bins)
681}
682
683/// Compute histogram with automatic range detection
684///
685/// This is a convenience function that automatically determines min/max.
686///
687/// # Errors
688///
689/// Returns an error if:
690/// - `num_bins` is 0
691/// - Data slice is empty
692pub fn histogram_auto_f32(data: &[f32], num_bins: usize) -> Result<Vec<usize>> {
693    let (min, max) = minmax_f32(data)?;
694
695    // Add small epsilon to max to make it exclusive
696    let max = max + (max - min) * 1e-6;
697
698    histogram_f32(data, num_bins, min, max)
699}
700
701/// Find the index of the minimum value
702///
703/// # Errors
704///
705/// Returns an error if the slice is empty
706pub fn argmin_f32(data: &[f32]) -> Result<usize> {
707    if data.is_empty() {
708        return Err(AlgorithmError::InvalidParameter {
709            parameter: "input",
710            message: "Cannot find argmin of empty slice".to_string(),
711        });
712    }
713
714    let mut min_val = data[0];
715    let mut min_idx = 0;
716
717    for (i, &val) in data.iter().enumerate().skip(1) {
718        if val < min_val {
719            min_val = val;
720            min_idx = i;
721        }
722    }
723
724    Ok(min_idx)
725}
726
727/// Find the index of the maximum value
728///
729/// # Errors
730///
731/// Returns an error if the slice is empty
732pub fn argmax_f32(data: &[f32]) -> Result<usize> {
733    if data.is_empty() {
734        return Err(AlgorithmError::InvalidParameter {
735            parameter: "input",
736            message: "Cannot find argmax of empty slice".to_string(),
737        });
738    }
739
740    let mut max_val = data[0];
741    let mut max_idx = 0;
742
743    for (i, &val) in data.iter().enumerate().skip(1) {
744        if val > max_val {
745            max_val = val;
746            max_idx = i;
747        }
748    }
749
750    Ok(max_idx)
751}
752
753/// Compute Welford's online variance (single-pass, numerically stable)
754///
755/// Useful when data arrives in a streaming fashion. Returns (mean, variance, count).
756///
757/// # Errors
758///
759/// Returns an error if the slice is empty
760pub fn welford_variance_f32(data: &[f32]) -> Result<(f32, f32, usize)> {
761    if data.is_empty() {
762        return Err(AlgorithmError::InvalidParameter {
763            parameter: "input",
764            message: "Cannot compute variance of empty slice".to_string(),
765        });
766    }
767
768    let mut count = 0_usize;
769    let mut mean = 0.0_f32;
770    let mut m2 = 0.0_f32;
771
772    for &x in data {
773        count += 1;
774        let delta = x - mean;
775        mean += delta / count as f32;
776        let delta2 = x - mean;
777        m2 += delta * delta2;
778    }
779
780    let variance = if count > 1 { m2 / count as f32 } else { 0.0 };
781
782    Ok((mean, variance, count))
783}
784
785/// Compute the covariance between two slices
786///
787/// # Errors
788///
789/// Returns an error if slices are empty or have different lengths
790pub fn covariance_f32(a: &[f32], b: &[f32]) -> Result<f32> {
791    if a.is_empty() || b.is_empty() {
792        return Err(AlgorithmError::InvalidParameter {
793            parameter: "input",
794            message: "Cannot compute covariance of empty slice".to_string(),
795        });
796    }
797    if a.len() != b.len() {
798        return Err(AlgorithmError::InvalidParameter {
799            parameter: "input",
800            message: format!("Slice length mismatch: a={}, b={}", a.len(), b.len()),
801        });
802    }
803
804    let mean_a = mean_f32(a)?;
805    let mean_b = mean_f32(b)?;
806    let n = a.len() as f32;
807
808    // SIMD-friendly loop
809    let mut sum = 0.0_f32;
810    for i in 0..a.len() {
811        sum += (a[i] - mean_a) * (b[i] - mean_b);
812    }
813
814    Ok(sum / n)
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Small unsorted sample shared by the extrema and arg-extrema tests.
    const MIXED: [f32; 7] = [3.0, 1.0, 4.0, 1.5, 9.0, 2.0, 6.0];

    /// Sample with mean 5.0, population variance 4.0 and standard deviation 2.0.
    const SPREAD: [f32; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];

    #[test]
    fn test_sum_f32() {
        let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        assert_relative_eq!(sum_f32(&values), 15.0);
    }

    #[test]
    fn test_sum_f32_large() {
        let ones = [1.0_f32; 1000];
        assert_relative_eq!(sum_f32(&ones), 1000.0);
    }

    #[test]
    fn test_sum_f32_very_large() {
        // Long enough to run the 16-element NEON loop many times
        let ramp: Vec<f32> = (1..=10000).map(|i| i as f32).collect();
        assert_relative_eq!(sum_f32(&ramp), 50_005_000.0, epsilon = 1.0);
    }

    #[test]
    fn test_sum_empty() {
        let nothing: [f32; 0] = [];
        assert_relative_eq!(sum_f32(&nothing), 0.0);
    }

    #[test]
    fn test_mean_f32() {
        let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let average = mean_f32(&values).expect("mean_f32 failed");
        assert_relative_eq!(average, 3.0);
    }

    #[test]
    fn test_mean_empty() {
        let nothing: [f32; 0] = [];
        assert!(mean_f32(&nothing).is_err());
    }

    #[test]
    fn test_minmax_f32() {
        let (lo, hi) = minmax_f32(&MIXED).expect("minmax_f32 failed");
        assert_relative_eq!(lo, 1.0);
        assert_relative_eq!(hi, 9.0);
    }

    #[test]
    fn test_minmax_single() {
        let lone = [42.0_f32];
        let (lo, hi) = minmax_f32(&lone).expect("minmax_f32 failed");
        assert_relative_eq!(lo, 42.0);
        assert_relative_eq!(hi, 42.0);
    }

    #[test]
    fn test_minmax_large() {
        let ramp: Vec<f32> = (0..10000).map(|i| i as f32).collect();
        let (lo, hi) = minmax_f32(&ramp).expect("minmax_f32 failed");
        assert_relative_eq!(lo, 0.0);
        assert_relative_eq!(hi, 9999.0);
    }

    #[test]
    fn test_min_max_separate() {
        let lo = min_f32(&MIXED).expect("min_f32 failed");
        let hi = max_f32(&MIXED).expect("max_f32 failed");
        assert_relative_eq!(lo, 1.0);
        assert_relative_eq!(hi, 9.0);
    }

    #[test]
    fn test_variance_std_dev() {
        let var = variance_f32(&SPREAD).expect("variance_f32 failed");
        let sigma = std_dev_f32(&SPREAD).expect("std_dev_f32 failed");

        // Expected: mean = 5.0, variance = 4.0, std_dev = 2.0
        assert_relative_eq!(var, 4.0, epsilon = 1e-4);
        assert_relative_eq!(sigma, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_welford_variance() {
        let (average, var, seen) = welford_variance_f32(&SPREAD).expect("welford failed");
        assert_eq!(seen, 8);
        assert_relative_eq!(average, 5.0, epsilon = 1e-4);
        assert_relative_eq!(var, 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_histogram() {
        let samples = [0.5_f32, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
        let counts = histogram_f32(&samples, 5, 0.0, 10.0).expect("histogram_f32 failed");

        // Two samples fall into each of the five bins
        assert_eq!(counts, [2_usize; 5]);
    }

    #[test]
    fn test_histogram_auto() {
        let samples = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let counts = histogram_auto_f32(&samples, 5).expect("histogram_auto_f32 failed");

        assert_eq!(counts.len(), 5);
        assert_eq!(counts.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_argmin_argmax() {
        let lowest_at = argmin_f32(&MIXED).expect("argmin_f32 failed");
        let highest_at = argmax_f32(&MIXED).expect("argmax_f32 failed");

        assert_eq!(lowest_at, 1); // value 1.0
        assert_eq!(highest_at, 4); // value 9.0
    }

    #[test]
    fn test_large_dataset() {
        let ramp: Vec<f32> = (0..10000).map(|i| i as f32).collect();

        assert_relative_eq!(sum_f32(&ramp), 49_995_000.0, epsilon = 1.0);

        let average = mean_f32(&ramp).expect("mean_f32 failed");
        assert_relative_eq!(average, 4999.5, epsilon = 0.5);

        let (lo, hi) = minmax_f32(&ramp).expect("minmax_f32 failed");
        assert_relative_eq!(lo, 0.0);
        assert_relative_eq!(hi, 9999.0);
    }

    #[test]
    fn test_histogram_edge_cases() {
        let samples = [0.0_f32, 5.0, 10.0];
        let counts = histogram_f32(&samples, 2, 0.0, 10.0).expect("histogram_f32 failed");
        // 0.0 lands in bin 0, 5.0 in bin 1, and 10.0 is outside [0, 10)
        assert_eq!(counts[0], 1);
        assert_eq!(counts[1], 1);
    }

    #[test]
    fn test_sum_f64() {
        let values = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        assert_relative_eq!(sum_f64(&values), 15.0);
    }

    #[test]
    fn test_mean_f64() {
        let values = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let average = mean_f64(&values).expect("mean_f64 failed");
        assert_relative_eq!(average, 3.0);
    }

    #[test]
    fn test_covariance() {
        let xs = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0_f32, 4.0, 6.0, 8.0, 10.0];
        let cov = covariance_f32(&xs, &ys).expect("covariance_f32 failed");
        // ys = 2 * xs, so cov(xs, ys) = 2 * var(xs) = 4.0
        assert_relative_eq!(cov, 4.0, epsilon = 1e-4);
    }
}