// ferray_core/array/reductions.rs

1// ferray-core: Reduction methods for Array<T, D>
2//
3// Provides NumPy-equivalent reduction methods directly on Array and ArrayView:
4//   sum, prod, min, max, mean, var, std, any, all
5//
6// Each reduction has a whole-array variant and an axis variant:
7//   .sum()             -> T
8//   .sum_axis(Axis(0)) -> Array<T, IxDyn>
9//
10// These methods complement the lower-level fold_axis primitive in methods.rs
11// and the free-function reductions in ferray-stats. The instance-method form
12// matches NumPy's `arr.sum()` ergonomics so users don't need to import a
13// separate crate just to compute a sum.
14//
15// See: https://github.com/dollspace-gay/ferray/issues/368
16
17use num_traits::Float;
18
19use crate::array::owned::Array;
20use crate::array::view::ArrayView;
21use crate::dimension::{Axis, Dimension, IxDyn};
22use crate::dtype::Element;
23use crate::error::FerrayResult;
24
/// Generic min/max fold step with NumPy-style NaN propagation.
///
/// A NaN in either operand makes the result NaN, so once a NaN enters the
/// fold every subsequent step keeps returning it. NaN is detected
/// generically: `v.partial_cmp(&v).is_none()` holds exactly when `v`
/// violates `PartialOrd` reflexivity, i.e. when it is NaN for float types.
#[inline]
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    use std::cmp::Ordering;

    // Self-comparison fails only for NaN-like values; propagate them as-is.
    if acc.partial_cmp(&acc).is_none() {
        acc
    } else if x.partial_cmp(&x).is_none() {
        x
    } else {
        let ord = x.partial_cmp(&acc);
        let replace = if take_min {
            ord == Some(Ordering::Less)
        } else {
            ord == Some(Ordering::Greater)
        };
        if replace { x } else { acc }
    }
}
47
48// ---------------------------------------------------------------------------
49// Sum / Prod (work for any Element with Add/Mul, using Element::zero/one)
50// ---------------------------------------------------------------------------
51
52impl<T, D> Array<T, D>
53where
54    T: Element + Copy,
55    D: Dimension,
56{
57    /// Sum of all elements (whole-array reduction).
58    ///
59    /// Returns `Element::zero()` for an empty array.
60    ///
61    /// # Examples
62    /// ```
63    /// # use ferray_core::Array;
64    /// # use ferray_core::dimension::Ix1;
65    /// let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
66    /// assert_eq!(a.sum(), 6.0);
67    /// ```
68    pub fn sum(&self) -> T
69    where
70        T: std::ops::Add<Output = T>,
71    {
72        let mut acc = T::zero();
73        for &x in self.iter() {
74            acc = acc + x;
75        }
76        acc
77    }
78
79    /// Sum along the given axis. Returns an array with one fewer dimension.
80    ///
81    /// # Errors
82    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
83    pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
84    where
85        T: std::ops::Add<Output = T>,
86        D::NdarrayDim: ndarray::RemoveAxis,
87    {
88        self.fold_axis(axis, T::zero(), |acc, &x| *acc + x)
89    }
90
91    /// Product of all elements.
92    ///
93    /// Returns `Element::one()` for an empty array.
94    pub fn prod(&self) -> T
95    where
96        T: std::ops::Mul<Output = T>,
97    {
98        let mut acc = T::one();
99        for &x in self.iter() {
100            acc = acc * x;
101        }
102        acc
103    }
104
105    /// Product along the given axis.
106    pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
107    where
108        T: std::ops::Mul<Output = T>,
109        D::NdarrayDim: ndarray::RemoveAxis,
110    {
111        self.fold_axis(axis, T::one(), |acc, &x| *acc * x)
112    }
113}
114
115// ---------------------------------------------------------------------------
116// Min / Max — require PartialOrd
117// ---------------------------------------------------------------------------
118
119impl<T, D> Array<T, D>
120where
121    T: Element + Copy + PartialOrd,
122    D: Dimension,
123{
124    /// Minimum value across the entire array.
125    ///
126    /// Returns `None` if the array is empty. NaN values follow NumPy semantics:
127    /// once a NaN is seen the result stays NaN, detected via self-comparison
128    /// (`x.partial_cmp(&x).is_none()`).
129    pub fn min(&self) -> Option<T> {
130        let mut iter = self.iter().copied();
131        let first = iter.next()?;
132        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
133    }
134
135    /// Maximum value across the entire array.
136    ///
137    /// Returns `None` if the array is empty. NaN values propagate per NumPy.
138    pub fn max(&self) -> Option<T> {
139        let mut iter = self.iter().copied();
140        let first = iter.next()?;
141        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
142    }
143
144    /// Minimum value along an axis.
145    ///
146    /// # Errors
147    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
148    /// `FerrayError::ShapeMismatch` if the resulting axis would be empty.
149    pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
150    where
151        D::NdarrayDim: ndarray::RemoveAxis,
152    {
153        // Use the first element along the axis as init by sentinel: pull the
154        // first lane and fold the rest. fold_axis applies init to every lane,
155        // but min has no neutral identity for arbitrary T. We sidestep by
156        // folding starting from any element of `self` — the per-lane init is
157        // overwritten by the first comparison, which is correct iff every lane
158        // has at least one element. Empty axes would yield uninitialized data.
159        let ndim = self.ndim();
160        if axis.index() >= ndim {
161            return Err(crate::error::FerrayError::axis_out_of_bounds(
162                axis.index(),
163                ndim,
164            ));
165        }
166        if self.shape()[axis.index()] == 0 {
167            return Err(crate::error::FerrayError::shape_mismatch(
168                "cannot compute min along empty axis",
169            ));
170        }
171        // Manual lane iteration: fold_axis can't be used here because min has
172        // no neutral identity that works for arbitrary `T: PartialOrd` (no
173        // T::infinity for ints).
174        self.fold_axis_min_max(axis, true)
175    }
176
177    /// Maximum value along an axis.
178    ///
179    /// See [`Array::min_axis`] for error semantics.
180    pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
181    where
182        D::NdarrayDim: ndarray::RemoveAxis,
183    {
184        let ndim = self.ndim();
185        if axis.index() >= ndim {
186            return Err(crate::error::FerrayError::axis_out_of_bounds(
187                axis.index(),
188                ndim,
189            ));
190        }
191        if self.shape()[axis.index()] == 0 {
192            return Err(crate::error::FerrayError::shape_mismatch(
193                "cannot compute max along empty axis",
194            ));
195        }
196        self.fold_axis_min_max(axis, false)
197    }
198
199    /// Internal: per-lane min/max via manual lane iteration. Avoids the
200    /// init-bias problem of fold_axis (which applies a single init to every
201    /// lane, even though min/max have no identity element).
202    fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
203    where
204        D::NdarrayDim: ndarray::RemoveAxis,
205    {
206        let nd_axis = ndarray::Axis(axis.index());
207        // Use ndarray's lane iteration directly via the inner ndarray::ArrayBase.
208        // Each lane is a 1D view orthogonal to the chosen axis.
209        let lanes = self.inner.lanes(nd_axis);
210        let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
211        for lane in self.inner.lanes(nd_axis) {
212            let mut iter = lane.iter().copied();
213            let first = iter.next().unwrap(); // safe: empty axis already rejected
214            let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
215            out.push(result);
216        }
217
218        // Output shape: drop the reduced axis from the input shape.
219        let mut out_shape: Vec<usize> = self.shape().to_vec();
220        out_shape.remove(axis.index());
221        Array::from_vec(IxDyn::from(&out_shape[..]), out)
222    }
223}
224
225// ---------------------------------------------------------------------------
226// Mean / Var / Std — require Float
227// ---------------------------------------------------------------------------
228
229impl<T, D> Array<T, D>
230where
231    T: Element + Float,
232    D: Dimension,
233{
234    /// Arithmetic mean of all elements. Returns `None` for an empty array.
235    pub fn mean(&self) -> Option<T> {
236        let n = self.size();
237        if n == 0 {
238            return None;
239        }
240        let sum: T = self
241            .iter()
242            .copied()
243            .fold(<T as Element>::zero(), |acc, x| acc + x);
244        Some(sum / T::from(n).unwrap())
245    }
246
247    /// Mean along an axis.
248    pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
249    where
250        D::NdarrayDim: ndarray::RemoveAxis,
251    {
252        let ndim = self.ndim();
253        if axis.index() >= ndim {
254            return Err(crate::error::FerrayError::axis_out_of_bounds(
255                axis.index(),
256                ndim,
257            ));
258        }
259        let n = self.shape()[axis.index()];
260        if n == 0 {
261            return Err(crate::error::FerrayError::shape_mismatch(
262                "cannot compute mean along empty axis",
263            ));
264        }
265        let sums = self.sum_axis(axis)?;
266        let n_t = T::from(n).unwrap();
267        Ok(sums.mapv(|x| x / n_t))
268    }
269
270    /// Variance with `ddof` degrees of freedom (Bessel's correction = 1).
271    ///
272    /// Returns `None` for an empty array, or when `ddof >= n`.
273    pub fn var(&self, ddof: usize) -> Option<T> {
274        let n = self.size();
275        if n == 0 || ddof >= n {
276            return None;
277        }
278        let mean = self.mean()?;
279        let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
280            acc + (x - mean) * (x - mean)
281        });
282        Some(sum_sq / T::from(n - ddof).unwrap())
283    }
284
285    /// Standard deviation with `ddof` degrees of freedom.
286    pub fn std(&self, ddof: usize) -> Option<T> {
287        self.var(ddof).map(|v| v.sqrt())
288    }
289}
290
291// ---------------------------------------------------------------------------
292// any / all — for bool arrays
293// ---------------------------------------------------------------------------
294
295impl<D> Array<bool, D>
296where
297    D: Dimension,
298{
299    /// Returns `true` if any element is `true`.
300    pub fn any(&self) -> bool {
301        self.iter().any(|&x| x)
302    }
303
304    /// Returns `true` if all elements are `true`. Vacuously `true` for empty arrays.
305    pub fn all(&self) -> bool {
306        self.iter().all(|&x| x)
307    }
308}
309
310// ---------------------------------------------------------------------------
311// ArrayView mirrors — same methods on borrowed views
312// ---------------------------------------------------------------------------
313
314impl<T, D> ArrayView<'_, T, D>
315where
316    T: Element + Copy,
317    D: Dimension,
318{
319    /// Sum of all elements. See [`Array::sum`].
320    pub fn sum(&self) -> T
321    where
322        T: std::ops::Add<Output = T>,
323    {
324        let mut acc = T::zero();
325        for &x in self.iter() {
326            acc = acc + x;
327        }
328        acc
329    }
330
331    /// Product of all elements. See [`Array::prod`].
332    pub fn prod(&self) -> T
333    where
334        T: std::ops::Mul<Output = T>,
335    {
336        let mut acc = T::one();
337        for &x in self.iter() {
338            acc = acc * x;
339        }
340        acc
341    }
342}
343
344impl<T, D> ArrayView<'_, T, D>
345where
346    T: Element + Copy + PartialOrd,
347    D: Dimension,
348{
349    /// Minimum value. See [`Array::min`].
350    pub fn min(&self) -> Option<T> {
351        let mut iter = self.iter().copied();
352        let first = iter.next()?;
353        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
354    }
355
356    /// Maximum value. See [`Array::max`].
357    pub fn max(&self) -> Option<T> {
358        let mut iter = self.iter().copied();
359        let first = iter.next()?;
360        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
361    }
362}
363
364impl<T, D> ArrayView<'_, T, D>
365where
366    T: Element + Float,
367    D: Dimension,
368{
369    /// Mean. See [`Array::mean`].
370    pub fn mean(&self) -> Option<T> {
371        let n = self.size();
372        if n == 0 {
373            return None;
374        }
375        let sum: T = self
376            .iter()
377            .copied()
378            .fold(<T as Element>::zero(), |acc, x| acc + x);
379        Some(sum / T::from(n).unwrap())
380    }
381}
382
383#[cfg(test)]
384mod tests {
385    use super::*;
386    use crate::dimension::{Ix1, Ix2};
387
388    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
389        let n = data.len();
390        Array::from_vec(Ix1::new([n]), data).unwrap()
391    }
392
393    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
394        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
395    }
396
397    // ----- sum / prod -----
398
399    #[test]
400    fn sum_1d() {
401        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
402        assert_eq!(a.sum(), 10.0);
403    }
404
405    #[test]
406    fn sum_empty_returns_zero() {
407        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
408        assert_eq!(a.sum(), 0.0);
409    }
410
411    #[test]
412    fn sum_axis_2d() {
413        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
414        // Sum across rows (axis 0): [1+4, 2+5, 3+6] = [5, 7, 9]
415        let s0 = a.sum_axis(Axis(0)).unwrap();
416        assert_eq!(s0.shape(), &[3]);
417        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);
418
419        // Sum across columns (axis 1): [1+2+3, 4+5+6] = [6, 15]
420        let s1 = a.sum_axis(Axis(1)).unwrap();
421        assert_eq!(s1.shape(), &[2]);
422        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
423    }
424
425    #[test]
426    fn prod_1d() {
427        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
428        assert_eq!(a.prod(), 24.0);
429    }
430
431    #[test]
432    fn prod_empty_returns_one() {
433        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
434        assert_eq!(a.prod(), 1.0);
435    }
436
437    #[test]
438    fn prod_axis_2d() {
439        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
440        let p0 = a.prod_axis(Axis(0)).unwrap();
441        assert_eq!(
442            p0.iter().copied().collect::<Vec<_>>(),
443            vec![4.0, 10.0, 18.0]
444        );
445
446        let p1 = a.prod_axis(Axis(1)).unwrap();
447        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
448    }
449
450    // ----- min / max -----
451
452    #[test]
453    fn min_max_1d() {
454        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
455        assert_eq!(a.min(), Some(1.0));
456        assert_eq!(a.max(), Some(9.0));
457    }
458
459    #[test]
460    fn min_max_empty_returns_none() {
461        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
462        assert_eq!(a.min(), None);
463        assert_eq!(a.max(), None);
464    }
465
466    #[test]
467    fn min_max_int() {
468        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
469        assert_eq!(a.min(), Some(-5));
470        assert_eq!(a.max(), Some(4));
471    }
472
473    #[test]
474    fn min_max_axis_2d() {
475        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
476        // axis 0: min/max per column
477        let mn0 = a.min_axis(Axis(0)).unwrap();
478        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
479        let mx0 = a.max_axis(Axis(0)).unwrap();
480        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);
481
482        // axis 1: min/max per row
483        let mn1 = a.min_axis(Axis(1)).unwrap();
484        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
485        let mx1 = a.max_axis(Axis(1)).unwrap();
486        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
487    }
488
489    // ----- mean / var / std -----
490
491    #[test]
492    fn mean_1d() {
493        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
494        assert_eq!(a.mean(), Some(2.5));
495    }
496
497    #[test]
498    fn mean_empty_returns_none() {
499        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
500        assert_eq!(a.mean(), None);
501    }
502
503    #[test]
504    fn mean_axis_2d() {
505        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
506        let m0 = a.mean_axis(Axis(0)).unwrap();
507        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
508        let m1 = a.mean_axis(Axis(1)).unwrap();
509        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
510    }
511
512    #[test]
513    fn var_population() {
514        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
515        // population variance (ddof=0): ((1-3)^2+(2-3)^2+(3-3)^2+(4-3)^2+(5-3)^2)/5 = 10/5 = 2
516        assert_eq!(a.var(0), Some(2.0));
517    }
518
519    #[test]
520    fn var_sample() {
521        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
522        // sample variance (ddof=1): 10/4 = 2.5
523        assert_eq!(a.var(1), Some(2.5));
524    }
525
526    #[test]
527    fn std_basic() {
528        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
529        let s = a.std(0).unwrap();
530        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
531    }
532
533    #[test]
534    fn var_ddof_too_large_returns_none() {
535        let a = arr1(vec![1.0, 2.0]);
536        assert_eq!(a.var(2), None);
537        assert_eq!(a.var(5), None);
538    }
539
540    // ----- any / all -----
541
542    #[test]
543    fn any_all_bool() {
544        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
545        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
546        let false_arr =
547            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
548        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
549
550        assert!(true_arr.all());
551        assert!(true_arr.any());
552
553        assert!(!mixed.all());
554        assert!(mixed.any());
555
556        assert!(!false_arr.all());
557        assert!(!false_arr.any());
558
559        // Vacuous truth for empty
560        assert!(empty.all());
561        assert!(!empty.any());
562    }
563
564    // ----- ArrayView mirrors -----
565
566    #[test]
567    fn view_sum_min_max_mean() {
568        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
569        let v = a.view();
570        assert_eq!(v.sum(), 10.0);
571        assert_eq!(v.min(), Some(1.0));
572        assert_eq!(v.max(), Some(4.0));
573        assert_eq!(v.mean(), Some(2.5));
574    }
575
576    #[test]
577    fn nan_propagates_in_min_max() {
578        // NaN somewhere in the middle
579        let a = arr1(vec![1.0, f64::NAN, 3.0]);
580        assert!(a.min().unwrap().is_nan());
581        assert!(a.max().unwrap().is_nan());
582
583        // NaN at the start
584        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
585        assert!(b.min().unwrap().is_nan());
586        assert!(b.max().unwrap().is_nan());
587
588        // NaN at the end
589        let c = arr1(vec![1.0, 3.0, f64::NAN]);
590        assert!(c.min().unwrap().is_nan());
591        assert!(c.max().unwrap().is_nan());
592    }
593}