Skip to main content

ferray_core/array/
reductions.rs

1// ferray-core: Reduction methods for Array<T, D>
2//
3// Provides NumPy-equivalent reduction methods directly on Array and ArrayView:
//   sum, prod, min, max, argmin, argmax, mean, var, std, any, all
5//
// Most reductions have a whole-array variant and an axis variant (var, std,
// any and all are whole-array only):
7//   .sum()             -> T
8//   .sum_axis(Axis(0)) -> Array<T, IxDyn>
9//
10// These methods complement the lower-level fold_axis primitive in methods.rs
11// and the free-function reductions in ferray-stats. The instance-method form
12// matches NumPy's `arr.sum()` ergonomics so users don't need to import a
13// separate crate just to compute a sum.
14//
15// ## REQ status (reductions, NumPy parity)
16//  - sum / prod / sum_axis / prod_axis accumulator promotion — SHIPPED (#780):
17//    the `ReduceAcc` trait (this file) maps narrow signed ints → i64, narrow
18//    unsigned ints → u64, and bool → i64 before reducing, so narrow-int
19//    reductions never overflow and the result dtype matches numpy
20//    (numpy/_core/fromnumeric.py:2321-2327). Consumers: `Array::sum`/`prod`/
21//    `sum_axis`/`prod_axis` and the `ArrayView` mirrors (all in this file).
22//  - min / max / mean / var / std / any / all — SHIPPED (#368), NaN-propagating.
23//  - empty-array min/max raising ValueError — SHIPPED (#782, resolved R-DEV-4):
24//    `Array::min`/`max` return `Option<T>` = `None` on an empty array (the
25//    idiomatic Rust analog of numpy's no-identity error — same boundary contract
26//    as `argmax`/`argmin`, REQ-40/41; test `min_max_empty_returns_none` in this
27//    file). The ferray-python boundary maps `None` -> `ValueError`, matching
28//    numpy's `ValueError` ("zero-size array to reduction operation minimum which
29//    has no identity", `fromnumeric.py:3150` `min`/`:3266` `amin`). Verified:
30//    `fr.min(fr.array([], dtype=fr.float64))` raises `ValueError` vs numpy 2.4.5.
31//  - cumsum / cumprod live as ferray-stats free functions; their narrow-int
32//    promotion is a separate ferray-stats blocker that reuses `ReduceAcc`.
33//  - argmax / argmin (REQ-40, REQ-41) — SHIPPED: `Array::argmax`/`argmin`
34//    (flattened, returning `Option<i64>`) and `argmax_axis`/`argmin_axis`
35//    (returning `Array<i64, IxDyn>`) in this file. First-occurrence on ties and
36//    NaN-first propagation (numpy/_core/fromnumeric.py:1222 `argmax`,
37//    :1261-1262 ties; :1322 `argmin`, :1361-1362 ties). Empty flattened form
38//    returns `None` (mirroring `min`/`max`'s `Option` analog; the ferray-python
39//    boundary maps `None`→`ValueError` as numpy does). Result index dtype is
40//    `i64` (ferray's `intp` analog), independent of element dtype. Consumers:
41//    the boundary methods themselves are the public API surface (like
42//    `sum`/`min`), exercised by the `ArrayView` mirror and the in-file tests.
43//  - integer/bool mean → f64 (REQ-42) — SHIPPED: the `MeanAcc` trait (this
44//    file) maps bool/integer element types to an `f64` accumulator-and-result,
45//    while `f32`/`f64`/complex stay themselves, so `Array::<i32, _>::mean()`
46//    returns `f64` matching numpy, which casts bool/unsigned/signed int to
47//    float64 before averaging (numpy/_core/_methods.py:124-127). `mean`/
48//    `mean_axis` and the `ArrayView::mean` mirror are now bounded by `MeanAcc`
49//    instead of `Float`; existing `f32`/`f64` means are unchanged
50//    (`MeanAcc::Mean == Self`). Consumers: `Array::var`/`std` (`self.mean()?`)
51//    and the `ArrayView` mirror, all in this file.
52//
53// See: https://github.com/dollspace-gay/ferray/issues/368, /issues/780
54
55use num_traits::Float;
56
57use crate::array::owned::Array;
58use crate::array::view::ArrayView;
59use crate::dimension::{Axis, Dimension, IxDyn};
60use crate::dtype::Element;
61use crate::error::FerrayResult;
62
63// ---------------------------------------------------------------------------
64// ReduceAcc — NumPy's sum/prod/cumsum/cumprod accumulator-and-result dtype.
65// ---------------------------------------------------------------------------
66
67/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
68/// return) `sum` / `prod` / `cumsum` / `cumprod` over it.
69///
70/// NumPy promotes any integer dtype of *less precision than the default
71/// platform integer* before reducing, so a narrow-int reduction can never
72/// overflow and the result dtype is the platform integer:
73///
74/// > "The dtype of `a` is used by default unless `a` has an integer dtype of
75/// > less precision than the default platform integer.  In that case, if `a`
76/// > is signed then the platform integer is used while if `a` is unsigned then
77/// > an unsigned integer of the same precision as the platform integer is
78/// > used."
79/// > — `numpy/_core/fromnumeric.py:2321-2327` (sum), `:3306-3312` (prod)
80///
81/// The reduction itself is `umr_sum = um.add.reduce` /
82/// `umr_prod = um.multiply.reduce` (`numpy/_core/_methods.py:20-21`), i.e. the
83/// add/multiply ufunc whose *loop dtype* is the promoted accumulator.
84///
85/// The mapping (platform integer = 64-bit, matching ferray's `intp`/`int64`):
86///   - `i8 / i16 / i32 → i64`,  `i64 → i64`,  `i128 → i128`
87///   - `u8 / u16 / u32 → u64`,  `u64 → u64`,  `u128 → u128`
88///   - `bool → i64` (NumPy reduces bool as the platform integer, counting `true`)
89///   - `f32 → f32`, `f64 → f64`, complex stays itself (no promotion)
90///
91/// Wider-or-equal dtypes map to themselves, so existing `f64`/`i64`/complex
92/// reductions are unchanged — only narrow-int callers observe the promoted
93/// return type.
94pub trait ReduceAcc: Element + Copy {
95    /// The accumulator-and-result element type for reductions over `Self`.
96    type Acc: Element + Copy + std::ops::Add<Output = Self::Acc> + std::ops::Mul<Output = Self::Acc>;
97
98    /// Widen one element into the accumulator type before reducing, matching
99    /// NumPy's promotion of the loop dtype (`true → 1` for `bool`).
100    fn widen(self) -> Self::Acc;
101}
102
103macro_rules! impl_reduce_acc {
104    ($($t:ty => $acc:ty),* $(,)?) => {
105        $(
106            impl ReduceAcc for $t {
107                type Acc = $acc;
108                #[inline]
109                fn widen(self) -> $acc {
110                    self as $acc
111                }
112            }
113        )*
114    };
115}
116
117// Narrow signed ints promote to i64; i64/i128 stay themselves.
118impl_reduce_acc! {
119    i8 => i64, i16 => i64, i32 => i64, i64 => i64, i128 => i128,
120    u8 => u64, u16 => u64, u32 => u64, u64 => u64, u128 => u128,
121    f32 => f32, f64 => f64,
122}
123
124// bool reduces as the platform integer, counting `true` (numpy:
125// `np.sum(np.array([True, True, True])).dtype == int64`). `as i64` maps
126// false→0, true→1.
127impl ReduceAcc for bool {
128    type Acc = i64;
129    #[inline]
130    fn widen(self) -> i64 {
131        i64::from(self)
132    }
133}
134
135// Complex stays itself — numpy never promotes a complex reduction.
136impl ReduceAcc for num_complex::Complex<f32> {
137    type Acc = num_complex::Complex<f32>;
138    #[inline]
139    fn widen(self) -> Self {
140        self
141    }
142}
143
144impl ReduceAcc for num_complex::Complex<f64> {
145    type Acc = num_complex::Complex<f64>;
146    #[inline]
147    fn widen(self) -> Self {
148        self
149    }
150}
151
152// ---------------------------------------------------------------------------
153// MeanAcc — NumPy's mean accumulator-and-result dtype.
154// ---------------------------------------------------------------------------
155
156/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
157/// return) `mean` over it.
158///
159/// NumPy casts a bool / unsigned-int / signed-int input to `float64` before
160/// averaging:
161///
162/// > "Cast bool, unsigned int, and int to float64 by default ...
163/// > `dtype = mu.dtype('f8')`"
164/// > — `numpy/_core/_methods.py:124-127`
165///
166/// so `np.mean(np.array([1, 2, 3], np.int32))` is `float64 2.0` and
167/// `np.mean([True, False, True])` is `float64 0.6666…`. Floating-point inputs
168/// keep their own dtype (`f32`→`f32`, `f64`→`f64`), and complex stays itself.
169///
170/// The mapping:
171///   - `bool / i8.. / u8.. → f64`
172///   - `f32 → f32`, `f64 → f64` (unchanged — `Mean == Self`)
173///   - `Complex<f32> → Complex<f32>`, `Complex<f64> → Complex<f64>`
174pub trait MeanAcc: Element + Copy {
175    /// The accumulator-and-result element type for `mean` over `Self`.
176    type Mean: Element
177        + Copy
178        + std::ops::Add<Output = Self::Mean>
179        + std::ops::Div<Output = Self::Mean>;
180
181    /// Widen one element into the mean accumulator type, matching NumPy's
182    /// pre-average cast (`true → 1.0`, `false → 0.0` for `bool`).
183    fn widen_mean(self) -> Self::Mean;
184
185    /// Construct the divisor (element count `n`) in the accumulator type.
186    fn count(n: usize) -> Self::Mean;
187}
188
189macro_rules! impl_mean_acc_to_f64 {
190    ($($t:ty),* $(,)?) => {
191        $(
192            impl MeanAcc for $t {
193                type Mean = f64;
194                #[inline]
195                fn widen_mean(self) -> f64 {
196                    self as f64
197                }
198                #[inline]
199                fn count(n: usize) -> f64 {
200                    n as f64
201                }
202            }
203        )*
204    };
205}
206
207// bool / all integer dtypes average in f64 (numpy/_core/_methods.py:124-127).
208impl_mean_acc_to_f64!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
209
210impl MeanAcc for bool {
211    type Mean = f64;
212    #[inline]
213    fn widen_mean(self) -> f64 {
214        if self { 1.0 } else { 0.0 }
215    }
216    #[inline]
217    fn count(n: usize) -> f64 {
218        n as f64
219    }
220}
221
222// Floating-point inputs keep their own dtype (Mean == Self).
223impl MeanAcc for f32 {
224    type Mean = f32;
225    #[inline]
226    fn widen_mean(self) -> f32 {
227        self
228    }
229    #[inline]
230    fn count(n: usize) -> f32 {
231        n as f32
232    }
233}
234
235impl MeanAcc for f64 {
236    type Mean = f64;
237    #[inline]
238    fn widen_mean(self) -> f64 {
239        self
240    }
241    #[inline]
242    fn count(n: usize) -> f64 {
243        n as f64
244    }
245}
246
247impl MeanAcc for num_complex::Complex<f32> {
248    type Mean = num_complex::Complex<f32>;
249    #[inline]
250    fn widen_mean(self) -> Self {
251        self
252    }
253    #[inline]
254    fn count(n: usize) -> Self {
255        num_complex::Complex::new(n as f32, 0.0)
256    }
257}
258
259impl MeanAcc for num_complex::Complex<f64> {
260    type Mean = num_complex::Complex<f64>;
261    #[inline]
262    fn widen_mean(self) -> Self {
263        self
264    }
265    #[inline]
266    fn count(n: usize) -> Self {
267        num_complex::Complex::new(n as f64, 0.0)
268    }
269}
270
271/// First-occurrence, NaN-first arg-reduction over a flat element iterator.
272///
273/// Mirrors NumPy's `argmax`/`argmin` (`numpy/_core/fromnumeric.py:1222`,
274/// `:1322`): on ties the *first* occurrence wins (`:1261-1262`, `:1361-1362`),
275/// and when any NaN is present the index of the *first* NaN is returned
276/// (NaN-propagating, NaN-first — live oracle numpy 2.4.5:
277/// `np.argmax([1.0, nan, 3.0, nan]) == 1`).
278///
279/// Returns `None` for an empty iterator, the `Option` analog `min`/`max` use;
280/// the ferray-python boundary maps `None`→`ValueError` as numpy does.
281#[inline]
282fn arg_reduce<T: PartialOrd + Copy>(iter: impl Iterator<Item = T>, take_min: bool) -> Option<i64> {
283    let mut best_idx: Option<i64> = None;
284    let mut best: Option<T> = None;
285    for (i, x) in iter.enumerate() {
286        let i = i as i64;
287        // NaN-first: the first NaN seen wins immediately and is never beaten.
288        if x.partial_cmp(&x).is_none() {
289            return Some(i);
290        }
291        match best {
292            None => {
293                best = Some(x);
294                best_idx = Some(i);
295            }
296            Some(b) => {
297                // Strict comparison => first occurrence wins on ties.
298                let replace = match x.partial_cmp(&b) {
299                    Some(std::cmp::Ordering::Less) => take_min,
300                    Some(std::cmp::Ordering::Greater) => !take_min,
301                    _ => false,
302                };
303                if replace {
304                    best = Some(x);
305                    best_idx = Some(i);
306                }
307            }
308        }
309    }
310    best_idx
311}
312
313/// Generic min/max fold step that propagates NaN per `NumPy` semantics.
314///
315/// Once any NaN enters the fold, all subsequent steps return NaN. Detected
316/// generically via `x.partial_cmp(&x).is_none()`, which is true iff `x` is
317/// NaN (or any other value that violates `PartialOrd` reflexivity, e.g.
318/// `Complex` types — but those don't implement `PartialOrd` so this is moot).
319///
320/// On an equal compare (`Ordering::Equal`) the NEW element `x` is kept, not the
321/// accumulator. This mirrors `numpy`'s `maximum.reduce`/`minimum.reduce`
322/// (`numpy/_core/_methods.py:38-44`, `umr_maximum`/`umr_minimum`), whose
323/// underlying scalar `maximum(a, b)`/`minimum(a, b)` loops return the *later*
324/// operand on ties — observable only for signed zeros (`+0.0 == -0.0`), where
325/// numpy keeps the LAST seen zero's sign bit. For any non-signed-zero equal
326/// pair the values are identical, so this changes nothing. This is the VALUE
327/// min/max reduce; `argmin`/`argmax` use first-occurrence on ties and live on a
328/// separate code path (they do not call `reduce_step`).
329#[inline]
330fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
331    let acc_is_nan = acc.partial_cmp(&acc).is_none();
332    if acc_is_nan {
333        return acc;
334    }
335    let x_is_nan = x.partial_cmp(&x).is_none();
336    if x_is_nan {
337        return x;
338    }
339    match (take_min, x.partial_cmp(&acc)) {
340        (true, Some(std::cmp::Ordering::Less)) => x,
341        (false, Some(std::cmp::Ordering::Greater)) => x,
342        // Tie: keep the LAST operand (numpy maximum/minimum.reduce semantics).
343        (_, Some(std::cmp::Ordering::Equal)) => x,
344        _ => acc,
345    }
346}
347
348// ---------------------------------------------------------------------------
349// Sum / Prod (work for any Element with Add/Mul, using Element::zero/one)
350// ---------------------------------------------------------------------------
351
352impl<T, D> Array<T, D>
353where
354    T: Element + Copy,
355    D: Dimension,
356{
357    /// Sum of all elements (whole-array reduction).
358    ///
359    /// The result type is the NumPy reduction accumulator
360    /// [`ReduceAcc::Acc`]: narrow signed ints widen to `i64`, narrow unsigned
361    /// ints to `u64`, `bool` to `i64`, and `f32`/`f64`/complex stay
362    /// themselves. This means a narrow-int sum can never overflow and its
363    /// dtype matches `np.sum`'s promoted result
364    /// (`numpy/_core/fromnumeric.py:2321-2327`).
365    ///
366    /// Returns `Acc::zero()` for an empty array.
367    ///
368    /// # Examples
369    /// ```
370    /// # use ferray_core::Array;
371    /// # use ferray_core::dimension::Ix1;
372    /// let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
373    /// assert_eq!(a.sum(), 6.0);
374    /// // i8 sums promote to i64 and never overflow (numpy parity):
375    /// let b = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
376    /// assert_eq!(b.sum(), 300_i64);
377    /// ```
378    pub fn sum(&self) -> <T as ReduceAcc>::Acc
379    where
380        T: ReduceAcc,
381    {
382        let mut acc = <T as ReduceAcc>::Acc::zero();
383        for &x in self.iter() {
384            acc = acc + x.widen();
385        }
386        acc
387    }
388
389    /// Sum along the given axis. Returns an array with one fewer dimension,
390    /// whose element type is the promoted [`ReduceAcc::Acc`] (same numpy
391    /// narrow-int promotion as the whole-array [`Array::sum`]).
392    ///
393    /// # Errors
394    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
395    pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
396    where
397        T: ReduceAcc,
398        D::NdarrayDim: ndarray::RemoveAxis,
399    {
400        let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
401        widened.fold_axis(axis, <T as ReduceAcc>::Acc::zero(), |acc, &x| *acc + x)
402    }
403
404    /// Product of all elements.
405    ///
406    /// The result type is the promoted [`ReduceAcc::Acc`] (same numpy
407    /// narrow-int promotion as [`Array::sum`]; see
408    /// `numpy/_core/fromnumeric.py:3306-3312`).
409    ///
410    /// Returns `Acc::one()` for an empty array.
411    pub fn prod(&self) -> <T as ReduceAcc>::Acc
412    where
413        T: ReduceAcc,
414    {
415        let mut acc = <T as ReduceAcc>::Acc::one();
416        for &x in self.iter() {
417            acc = acc * x.widen();
418        }
419        acc
420    }
421
422    /// Product along the given axis. Element type is the promoted
423    /// [`ReduceAcc::Acc`].
424    pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
425    where
426        T: ReduceAcc,
427        D::NdarrayDim: ndarray::RemoveAxis,
428    {
429        let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
430        widened.fold_axis(axis, <T as ReduceAcc>::Acc::one(), |acc, &x| *acc * x)
431    }
432}
433
434// ---------------------------------------------------------------------------
// Min / Max / ArgMin / ArgMax — require PartialOrd
436// ---------------------------------------------------------------------------
437
438impl<T, D> Array<T, D>
439where
440    T: Element + Copy + PartialOrd,
441    D: Dimension,
442{
443    /// Minimum value across the entire array.
444    ///
445    /// Returns `None` if the array is empty. NaN values follow `NumPy` semantics:
446    /// once a NaN is seen the result stays NaN, detected via self-comparison
447    /// (`x.partial_cmp(&x).is_none()`).
448    pub fn min(&self) -> Option<T> {
449        let mut iter = self.iter().copied();
450        let first = iter.next()?;
451        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
452    }
453
454    /// Maximum value across the entire array.
455    ///
456    /// Returns `None` if the array is empty. NaN values propagate per `NumPy`.
457    pub fn max(&self) -> Option<T> {
458        let mut iter = self.iter().copied();
459        let first = iter.next()?;
460        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
461    }
462
463    /// Minimum value along an axis.
464    ///
465    /// # Errors
466    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
467    /// `FerrayError::ShapeMismatch` if the resulting axis would be empty.
468    pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
469    where
470        D::NdarrayDim: ndarray::RemoveAxis,
471    {
472        // Use the first element along the axis as init by sentinel: pull the
473        // first lane and fold the rest. fold_axis applies init to every lane,
474        // but min has no neutral identity for arbitrary T. We sidestep by
475        // folding starting from any element of `self` — the per-lane init is
476        // overwritten by the first comparison, which is correct iff every lane
477        // has at least one element. Empty axes would yield uninitialized data.
478        let ndim = self.ndim();
479        if axis.index() >= ndim {
480            return Err(crate::error::FerrayError::axis_out_of_bounds(
481                axis.index(),
482                ndim,
483            ));
484        }
485        if self.shape()[axis.index()] == 0 {
486            return Err(crate::error::FerrayError::shape_mismatch(
487                "cannot compute min along empty axis",
488            ));
489        }
490        // Manual lane iteration: fold_axis can't be used here because min has
491        // no neutral identity that works for arbitrary `T: PartialOrd` (no
492        // T::infinity for ints).
493        self.fold_axis_min_max(axis, true)
494    }
495
496    /// Maximum value along an axis.
497    ///
498    /// See [`Array::min_axis`] for error semantics.
499    pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
500    where
501        D::NdarrayDim: ndarray::RemoveAxis,
502    {
503        let ndim = self.ndim();
504        if axis.index() >= ndim {
505            return Err(crate::error::FerrayError::axis_out_of_bounds(
506                axis.index(),
507                ndim,
508            ));
509        }
510        if self.shape()[axis.index()] == 0 {
511            return Err(crate::error::FerrayError::shape_mismatch(
512                "cannot compute max along empty axis",
513            ));
514        }
515        self.fold_axis_min_max(axis, false)
516    }
517
518    /// Flat index of the maximum element (whole-array reduction).
519    ///
520    /// Returns `None` for an empty array (the `Option` analog `min`/`max`
521    /// use; the ferray-python boundary maps `None`→`ValueError`, matching
522    /// `np.argmax([])`). On ties the **first** occurrence wins, and when any
523    /// NaN is present the index of the **first** NaN is returned (NaN-first),
524    /// matching `np.argmax` (`numpy/_core/fromnumeric.py:1222`, ties at
525    /// `:1261-1262`; live oracle `np.argmax([1.0, nan, 3.0, nan]) == 1`). The
526    /// index type is `i64` (ferray's `intp` analog), independent of `T`.
527    pub fn argmax(&self) -> Option<i64> {
528        arg_reduce(self.iter().copied(), false)
529    }
530
531    /// Flat index of the minimum element (whole-array reduction).
532    ///
533    /// Mirror of [`Array::argmax`] with min substituted for max: first
534    /// occurrence on ties, NaN-first, `None` on empty, `i64` index
535    /// (`numpy/_core/fromnumeric.py:1322`, ties at `:1361-1362`; live oracle
536    /// `np.argmin([1.0, nan, 3.0]) == 1`).
537    pub fn argmin(&self) -> Option<i64> {
538        arg_reduce(self.iter().copied(), true)
539    }
540
541    /// Indices of the maxima along `axis`, as an `Array<i64, IxDyn>` with the
542    /// reduced axis removed. First-occurrence on ties, NaN-first per lane
543    /// (`numpy/_core/fromnumeric.py:1222`).
544    ///
545    /// # Errors
546    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
547    /// `FerrayError::ShapeMismatch` if the reduced axis is empty (matching
548    /// numpy's `ValueError` on an empty argmax axis).
549    pub fn argmax_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
550    where
551        D::NdarrayDim: ndarray::RemoveAxis,
552    {
553        self.arg_axis(axis, false)
554    }
555
556    /// Indices of the minima along `axis`. See [`Array::argmax_axis`].
557    pub fn argmin_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
558    where
559        D::NdarrayDim: ndarray::RemoveAxis,
560    {
561        self.arg_axis(axis, true)
562    }
563
564    /// Internal: per-lane arg-reduction along `axis`. Each lane is a 1D view
565    /// orthogonal to `axis`; the reduced index is the position within the lane.
566    fn arg_axis(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<i64, IxDyn>>
567    where
568        D::NdarrayDim: ndarray::RemoveAxis,
569    {
570        let ndim = self.ndim();
571        if axis.index() >= ndim {
572            return Err(crate::error::FerrayError::axis_out_of_bounds(
573                axis.index(),
574                ndim,
575            ));
576        }
577        if self.shape()[axis.index()] == 0 {
578            return Err(crate::error::FerrayError::shape_mismatch(
579                "attempt to get argmax/argmin of an empty axis",
580            ));
581        }
582        let nd_axis = ndarray::Axis(axis.index());
583        let lanes = self.inner.lanes(nd_axis);
584        let mut out: Vec<i64> = Vec::with_capacity(lanes.into_iter().len());
585        for lane in self.inner.lanes(nd_axis) {
586            // Lane is non-empty (empty axis already rejected), so arg_reduce
587            // returns Some; default 0 is unreachable but keeps the code panic-free.
588            let idx = arg_reduce(lane.iter().copied(), take_min).unwrap_or(0);
589            out.push(idx);
590        }
591        let mut out_shape: Vec<usize> = self.shape().to_vec();
592        out_shape.remove(axis.index());
593        Array::from_vec(IxDyn::from(&out_shape[..]), out)
594    }
595
596    /// Internal: per-lane min/max via manual lane iteration. Avoids the
597    /// init-bias problem of `fold_axis` (which applies a single init to every
598    /// lane, even though min/max have no identity element).
599    fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
600    where
601        D::NdarrayDim: ndarray::RemoveAxis,
602    {
603        let nd_axis = ndarray::Axis(axis.index());
604        // Use ndarray's lane iteration directly via the inner ndarray::ArrayBase.
605        // Each lane is a 1D view orthogonal to the chosen axis.
606        let lanes = self.inner.lanes(nd_axis);
607        let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
608        for lane in self.inner.lanes(nd_axis) {
609            let mut iter = lane.iter().copied();
610            let first = iter.next().unwrap(); // safe: empty axis already rejected
611            let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
612            out.push(result);
613        }
614
615        // Output shape: drop the reduced axis from the input shape.
616        let mut out_shape: Vec<usize> = self.shape().to_vec();
617        out_shape.remove(axis.index());
618        Array::from_vec(IxDyn::from(&out_shape[..]), out)
619    }
620}
621
622// ---------------------------------------------------------------------------
// Mean (any MeanAcc type) / Var / Std (require Float)
624// ---------------------------------------------------------------------------
625
626impl<T, D> Array<T, D>
627where
628    T: MeanAcc,
629    D: Dimension,
630{
631    /// Arithmetic mean of all elements. Returns `None` for an empty array.
632    ///
633    /// The result type is the NumPy mean accumulator [`MeanAcc::Mean`]:
634    /// bool / integer inputs average in (and return) `f64`, while `f32`/`f64`/
635    /// complex keep their own dtype. This matches numpy, which casts bool /
636    /// unsigned / signed int to `float64` before averaging
637    /// (`numpy/_core/_methods.py:124-127`), so `Array::<i32, _>::mean()` is
638    /// `Some(f64)` and `Array::<bool, _>::mean()` is `Some(f64)` (e.g.
639    /// `0.666…`), matching `np.mean`.
640    pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
641        let n = self.size();
642        if n == 0 {
643            return None;
644        }
645        let sum = self
646            .iter()
647            .copied()
648            .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
649        Some(sum / <T as MeanAcc>::count(n))
650    }
651
652    /// Mean along an axis. Element type is the promoted [`MeanAcc::Mean`]
653    /// (bool / integer lanes average in `f64`; `f32`/`f64` stay themselves).
654    pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<<T as MeanAcc>::Mean, IxDyn>>
655    where
656        <T as MeanAcc>::Mean: ReduceAcc<Acc = <T as MeanAcc>::Mean>,
657        D::NdarrayDim: ndarray::RemoveAxis,
658    {
659        let ndim = self.ndim();
660        if axis.index() >= ndim {
661            return Err(crate::error::FerrayError::axis_out_of_bounds(
662                axis.index(),
663                ndim,
664            ));
665        }
666        let n = self.shape()[axis.index()];
667        if n == 0 {
668            return Err(crate::error::FerrayError::shape_mismatch(
669                "cannot compute mean along empty axis",
670            ));
671        }
672        // Widen each element into the mean accumulator, then sum along the axis
673        // and divide by the lane length.
674        let widened = self.map_to::<<T as MeanAcc>::Mean>(MeanAcc::widen_mean);
675        let sums = widened.sum_axis(axis)?;
676        let n_t = <T as MeanAcc>::count(n);
677        Ok(sums.mapv(|x| x / n_t))
678    }
679}
680
681impl<T, D> Array<T, D>
682where
683    T: Element + Float + MeanAcc<Mean = T>,
684    D: Dimension,
685{
686    /// Variance with `ddof` degrees of freedom (Bessel's correction = 1).
687    ///
688    /// Returns `None` for an empty array, or when `ddof >= n`.
689    pub fn var(&self, ddof: usize) -> Option<T> {
690        let n = self.size();
691        if n == 0 || ddof >= n {
692            return None;
693        }
694        let mean = self.mean()?;
695        let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
696            acc + (x - mean) * (x - mean)
697        });
698        Some(sum_sq / T::from(n - ddof).unwrap())
699    }
700
701    /// Standard deviation with `ddof` degrees of freedom.
702    pub fn std(&self, ddof: usize) -> Option<T> {
703        self.var(ddof).map(num_traits::Float::sqrt)
704    }
705}
706
707// ---------------------------------------------------------------------------
708// any / all — for bool arrays
709// ---------------------------------------------------------------------------
710
711impl<D> Array<bool, D>
712where
713    D: Dimension,
714{
715    /// Returns `true` if any element is `true`.
716    pub fn any(&self) -> bool {
717        self.iter().any(|&x| x)
718    }
719
720    /// Returns `true` if all elements are `true`. Vacuously `true` for empty arrays.
721    pub fn all(&self) -> bool {
722        self.iter().all(|&x| x)
723    }
724}
725
726// ---------------------------------------------------------------------------
727// ArrayView mirrors — same methods on borrowed views
728// ---------------------------------------------------------------------------
729
730impl<T, D> ArrayView<'_, T, D>
731where
732    T: Element + Copy,
733    D: Dimension,
734{
735    /// Sum of all elements. See [`Array::sum`] — returns the promoted
736    /// [`ReduceAcc::Acc`].
737    pub fn sum(&self) -> <T as ReduceAcc>::Acc
738    where
739        T: ReduceAcc,
740    {
741        let mut acc = <T as ReduceAcc>::Acc::zero();
742        for &x in self.iter() {
743            acc = acc + x.widen();
744        }
745        acc
746    }
747
748    /// Product of all elements. See [`Array::prod`] — returns the promoted
749    /// [`ReduceAcc::Acc`].
750    pub fn prod(&self) -> <T as ReduceAcc>::Acc
751    where
752        T: ReduceAcc,
753    {
754        let mut acc = <T as ReduceAcc>::Acc::one();
755        for &x in self.iter() {
756            acc = acc * x.widen();
757        }
758        acc
759    }
760}
761
762impl<T, D> ArrayView<'_, T, D>
763where
764    T: Element + Copy + PartialOrd,
765    D: Dimension,
766{
767    /// Minimum value. See [`Array::min`].
768    pub fn min(&self) -> Option<T> {
769        let mut iter = self.iter().copied();
770        let first = iter.next()?;
771        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
772    }
773
774    /// Maximum value. See [`Array::max`].
775    pub fn max(&self) -> Option<T> {
776        let mut iter = self.iter().copied();
777        let first = iter.next()?;
778        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
779    }
780
781    /// Flat index of the maximum element. See [`Array::argmax`].
782    pub fn argmax(&self) -> Option<i64> {
783        arg_reduce(self.iter().copied(), false)
784    }
785
786    /// Flat index of the minimum element. See [`Array::argmin`].
787    pub fn argmin(&self) -> Option<i64> {
788        arg_reduce(self.iter().copied(), true)
789    }
790}
791
792impl<T, D> ArrayView<'_, T, D>
793where
794    T: MeanAcc,
795    D: Dimension,
796{
797    /// Mean. See [`Array::mean`] — returns the promoted [`MeanAcc::Mean`].
798    pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
799        let n = self.size();
800        if n == 0 {
801            return None;
802        }
803        let sum = self
804            .iter()
805            .copied()
806            .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
807        Some(sum / <T as MeanAcc>::count(n))
808    }
809}
810
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // ----- sum / prod -----

    #[test]
    fn sum_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.sum(), 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Sum across rows (axis 0): [1+4, 2+5, 3+6] = [5, 7, 9]
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);

        // Sum across columns (axis 1): [1+2+3, 4+5+6] = [6, 15]
        let s1 = a.sum_axis(Axis(1)).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.prod(), 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p0 = a.prod_axis(Axis(0)).unwrap();
        assert_eq!(
            p0.iter().copied().collect::<Vec<_>>(),
            vec![4.0, 10.0, 18.0]
        );

        let p1 = a.prod_axis(Axis(1)).unwrap();
        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
    }

    // ----- min / max -----

    #[test]
    fn min_max_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.min(), Some(1.0));
        assert_eq!(a.max(), Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.min(), None);
        assert_eq!(a.max(), None);
    }

    #[test]
    fn min_max_int() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
        assert_eq!(a.min(), Some(-5));
        assert_eq!(a.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        // axis 0: min/max per column
        let mn0 = a.min_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
        let mx0 = a.max_axis(Axis(0)).unwrap();
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);

        // axis 1: min/max per row
        let mn1 = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
        let mx1 = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
    }

    // ----- mean / var / std -----

    #[test]
    fn mean_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.mean(), Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.mean(), None);
    }

    #[test]
    fn mean_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
        let m1 = a.mean_axis(Axis(1)).unwrap();
        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
    }

    #[test]
    fn var_population() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // population variance (ddof=0): ((1-3)^2+(2-3)^2+(3-3)^2+(4-3)^2+(5-3)^2)/5 = 10/5 = 2
        assert_eq!(a.var(0), Some(2.0));
    }

    #[test]
    fn var_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // sample variance (ddof=1): 10/4 = 2.5
        assert_eq!(a.var(1), Some(2.5));
    }

    #[test]
    fn std_basic() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(0).unwrap();
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let a = arr1(vec![1.0, 2.0]);
        assert_eq!(a.var(2), None);
        assert_eq!(a.var(5), None);
    }

    // ----- any / all -----

    #[test]
    fn any_all_bool() {
        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let false_arr =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();

        assert!(true_arr.all());
        assert!(true_arr.any());

        assert!(!mixed.all());
        assert!(mixed.any());

        assert!(!false_arr.all());
        assert!(!false_arr.any());

        // Vacuous truth for empty
        assert!(empty.all());
        assert!(!empty.any());
    }

    // ----- ArrayView mirrors -----

    #[test]
    fn view_sum_min_max_mean() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let v = a.view();
        assert_eq!(v.sum(), 10.0);
        assert_eq!(v.min(), Some(1.0));
        assert_eq!(v.max(), Some(4.0));
        assert_eq!(v.mean(), Some(2.5));
    }

    #[test]
    fn view_prod() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.view().prod(), 24.0);
    }

    #[test]
    fn view_argmax_argmin() {
        // Distinct values so the expected index does not depend on how
        // `arg_reduce` breaks ties.
        let a = arr1(vec![3.0, 1.0, 4.0, 5.0, 9.0, 2.0]);
        let v = a.view();
        assert_eq!(v.argmax(), Some(4));
        assert_eq!(v.argmin(), Some(1));
    }

    #[test]
    fn view_empty_reductions() {
        // Views must honour the same empty-input contract as Array:
        // identity for sum/prod, None for min/max/argmax/argmin/mean.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let v = a.view();
        assert_eq!(v.sum(), 0.0);
        assert_eq!(v.prod(), 1.0);
        assert_eq!(v.min(), None);
        assert_eq!(v.max(), None);
        assert_eq!(v.argmax(), None);
        assert_eq!(v.argmin(), None);
        assert_eq!(v.mean(), None);
    }

    #[test]
    fn view_sum_promotes_narrow_ints() {
        // i8 widens to i64 before accumulating (#780), so 3 * 100 does not
        // wrap at i8::MAX on the view path either.
        let a = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
        assert_eq!(a.view().sum(), 300_i64);
    }

    #[test]
    fn nan_propagates_in_min_max() {
        // NaN somewhere in the middle
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(a.min().unwrap().is_nan());
        assert!(a.max().unwrap().is_nan());

        // NaN at the start
        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
        assert!(b.min().unwrap().is_nan());
        assert!(b.max().unwrap().is_nan());

        // NaN at the end
        let c = arr1(vec![1.0, 3.0, f64::NAN]);
        assert!(c.min().unwrap().is_nan());
        assert!(c.max().unwrap().is_nan());
    }

    #[test]
    fn view_nan_propagates_in_min_max() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        let v = a.view();
        assert!(v.min().unwrap().is_nan());
        assert!(v.max().unwrap().is_nan());
    }
}