Skip to main content

ferray_core/array/
reductions.rs

1// ferray-core: Reduction methods for Array<T, D>
2//
3// Provides NumPy-equivalent reduction methods directly on Array and ArrayView:
4//   sum, prod, min, max, mean, var, std, any, all
5//
6// Each reduction has a whole-array variant and an axis variant:
7//   .sum()             -> T
8//   .sum_axis(Axis(0)) -> Array<T, IxDyn>
9//
10// These methods complement the lower-level fold_axis primitive in methods.rs
11// and the free-function reductions in ferray-stats. The instance-method form
12// matches NumPy's `arr.sum()` ergonomics so users don't need to import a
13// separate crate just to compute a sum.
14//
15// ## REQ status (reductions, NumPy parity)
16//  - sum / prod / sum_axis / prod_axis accumulator promotion — SHIPPED (#780):
17//    the `ReduceAcc` trait (this file) maps narrow signed ints → i64, narrow
18//    unsigned ints → u64, and bool → i64 before reducing, so narrow-int
19//    reductions never overflow and the result dtype matches numpy
20//    (numpy/_core/fromnumeric.py:2321-2327). Consumers: `Array::sum`/`prod`/
21//    `sum_axis`/`prod_axis` and the `ArrayView` mirrors (all in this file).
22//  - min / max / mean / var / std / any / all — SHIPPED (#368), NaN-propagating.
23//  - empty-array min/max raising ValueError — NOT-STARTED (open blocker #782).
24//  - cumsum / cumprod live as ferray-stats free functions; their narrow-int
25//    promotion is a separate ferray-stats blocker that reuses `ReduceAcc`.
26//  - argmax / argmin (REQ-40, REQ-41) — SHIPPED: `Array::argmax`/`argmin`
27//    (flattened, returning `Option<i64>`) and `argmax_axis`/`argmin_axis`
28//    (returning `Array<i64, IxDyn>`) in this file. First-occurrence on ties and
29//    NaN-first propagation (numpy/_core/fromnumeric.py:1222 `argmax`,
30//    :1261-1262 ties; :1322 `argmin`, :1361-1362 ties). Empty flattened form
31//    returns `None` (mirroring `min`/`max`'s `Option` analog; the ferray-python
32//    boundary maps `None`→`ValueError` as numpy does). Result index dtype is
33//    `i64` (ferray's `intp` analog), independent of element dtype. Consumers:
34//    the boundary methods themselves are the public API surface (like
35//    `sum`/`min`), exercised by the `ArrayView` mirror and the in-file tests.
36//  - integer/bool mean → f64 (REQ-42) — SHIPPED: the `MeanAcc` trait (this
37//    file) maps bool/integer element types to an `f64` accumulator-and-result,
38//    while `f32`/`f64`/complex stay themselves, so `Array::<i32, _>::mean()`
39//    returns `f64` matching numpy, which casts bool/unsigned/signed int to
40//    float64 before averaging (numpy/_core/_methods.py:124-127). `mean`/
41//    `mean_axis` and the `ArrayView::mean` mirror are now bounded by `MeanAcc`
42//    instead of `Float`; existing `f32`/`f64` means are unchanged
43//    (`MeanAcc::Mean == Self`). Consumers: `Array::var`/`std` (`self.mean()?`)
44//    and the `ArrayView` mirror, all in this file.
45//
46// See: https://github.com/dollspace-gay/ferray/issues/368, /issues/780
47
48use num_traits::Float;
49
50use crate::array::owned::Array;
51use crate::array::view::ArrayView;
52use crate::dimension::{Axis, Dimension, IxDyn};
53use crate::dtype::Element;
54use crate::error::FerrayResult;
55
56// ---------------------------------------------------------------------------
57// ReduceAcc — NumPy's sum/prod/cumsum/cumprod accumulator-and-result dtype.
58// ---------------------------------------------------------------------------
59
60/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
61/// return) `sum` / `prod` / `cumsum` / `cumprod` over it.
62///
63/// NumPy promotes any integer dtype of *less precision than the default
64/// platform integer* before reducing, so a narrow-int reduction can never
65/// overflow and the result dtype is the platform integer:
66///
67/// > "The dtype of `a` is used by default unless `a` has an integer dtype of
68/// > less precision than the default platform integer.  In that case, if `a`
69/// > is signed then the platform integer is used while if `a` is unsigned then
70/// > an unsigned integer of the same precision as the platform integer is
71/// > used."
72/// > — `numpy/_core/fromnumeric.py:2321-2327` (sum), `:3306-3312` (prod)
73///
74/// The reduction itself is `umr_sum = um.add.reduce` /
75/// `umr_prod = um.multiply.reduce` (`numpy/_core/_methods.py:20-21`), i.e. the
76/// add/multiply ufunc whose *loop dtype* is the promoted accumulator.
77///
78/// The mapping (platform integer = 64-bit, matching ferray's `intp`/`int64`):
79///   - `i8 / i16 / i32 → i64`,  `i64 → i64`,  `i128 → i128`
80///   - `u8 / u16 / u32 → u64`,  `u64 → u64`,  `u128 → u128`
81///   - `bool → i64` (NumPy reduces bool as the platform integer, counting `true`)
82///   - `f32 → f32`, `f64 → f64`, complex stays itself (no promotion)
83///
84/// Wider-or-equal dtypes map to themselves, so existing `f64`/`i64`/complex
85/// reductions are unchanged — only narrow-int callers observe the promoted
86/// return type.
87pub trait ReduceAcc: Element + Copy {
88    /// The accumulator-and-result element type for reductions over `Self`.
89    type Acc: Element + Copy + std::ops::Add<Output = Self::Acc> + std::ops::Mul<Output = Self::Acc>;
90
91    /// Widen one element into the accumulator type before reducing, matching
92    /// NumPy's promotion of the loop dtype (`true → 1` for `bool`).
93    fn widen(self) -> Self::Acc;
94}
95
96macro_rules! impl_reduce_acc {
97    ($($t:ty => $acc:ty),* $(,)?) => {
98        $(
99            impl ReduceAcc for $t {
100                type Acc = $acc;
101                #[inline]
102                fn widen(self) -> $acc {
103                    self as $acc
104                }
105            }
106        )*
107    };
108}
109
110// Narrow signed ints promote to i64; i64/i128 stay themselves.
111impl_reduce_acc! {
112    i8 => i64, i16 => i64, i32 => i64, i64 => i64, i128 => i128,
113    u8 => u64, u16 => u64, u32 => u64, u64 => u64, u128 => u128,
114    f32 => f32, f64 => f64,
115}
116
117// bool reduces as the platform integer, counting `true` (numpy:
118// `np.sum(np.array([True, True, True])).dtype == int64`). `as i64` maps
119// false→0, true→1.
120impl ReduceAcc for bool {
121    type Acc = i64;
122    #[inline]
123    fn widen(self) -> i64 {
124        i64::from(self)
125    }
126}
127
128// Complex stays itself — numpy never promotes a complex reduction.
129impl ReduceAcc for num_complex::Complex<f32> {
130    type Acc = num_complex::Complex<f32>;
131    #[inline]
132    fn widen(self) -> Self {
133        self
134    }
135}
136
137impl ReduceAcc for num_complex::Complex<f64> {
138    type Acc = num_complex::Complex<f64>;
139    #[inline]
140    fn widen(self) -> Self {
141        self
142    }
143}
144
145// ---------------------------------------------------------------------------
146// MeanAcc — NumPy's mean accumulator-and-result dtype.
147// ---------------------------------------------------------------------------
148
149/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
150/// return) `mean` over it.
151///
152/// NumPy casts a bool / unsigned-int / signed-int input to `float64` before
153/// averaging:
154///
155/// > "Cast bool, unsigned int, and int to float64 by default ...
156/// > `dtype = mu.dtype('f8')`"
157/// > — `numpy/_core/_methods.py:124-127`
158///
159/// so `np.mean(np.array([1, 2, 3], np.int32))` is `float64 2.0` and
160/// `np.mean([True, False, True])` is `float64 0.6666…`. Floating-point inputs
161/// keep their own dtype (`f32`→`f32`, `f64`→`f64`), and complex stays itself.
162///
163/// The mapping:
164///   - `bool / i8.. / u8.. → f64`
165///   - `f32 → f32`, `f64 → f64` (unchanged — `Mean == Self`)
166///   - `Complex<f32> → Complex<f32>`, `Complex<f64> → Complex<f64>`
167pub trait MeanAcc: Element + Copy {
168    /// The accumulator-and-result element type for `mean` over `Self`.
169    type Mean: Element
170        + Copy
171        + std::ops::Add<Output = Self::Mean>
172        + std::ops::Div<Output = Self::Mean>;
173
174    /// Widen one element into the mean accumulator type, matching NumPy's
175    /// pre-average cast (`true → 1.0`, `false → 0.0` for `bool`).
176    fn widen_mean(self) -> Self::Mean;
177
178    /// Construct the divisor (element count `n`) in the accumulator type.
179    fn count(n: usize) -> Self::Mean;
180}
181
182macro_rules! impl_mean_acc_to_f64 {
183    ($($t:ty),* $(,)?) => {
184        $(
185            impl MeanAcc for $t {
186                type Mean = f64;
187                #[inline]
188                fn widen_mean(self) -> f64 {
189                    self as f64
190                }
191                #[inline]
192                fn count(n: usize) -> f64 {
193                    n as f64
194                }
195            }
196        )*
197    };
198}
199
200// bool / all integer dtypes average in f64 (numpy/_core/_methods.py:124-127).
201impl_mean_acc_to_f64!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
202
203impl MeanAcc for bool {
204    type Mean = f64;
205    #[inline]
206    fn widen_mean(self) -> f64 {
207        if self { 1.0 } else { 0.0 }
208    }
209    #[inline]
210    fn count(n: usize) -> f64 {
211        n as f64
212    }
213}
214
215// Floating-point inputs keep their own dtype (Mean == Self).
216impl MeanAcc for f32 {
217    type Mean = f32;
218    #[inline]
219    fn widen_mean(self) -> f32 {
220        self
221    }
222    #[inline]
223    fn count(n: usize) -> f32 {
224        n as f32
225    }
226}
227
228impl MeanAcc for f64 {
229    type Mean = f64;
230    #[inline]
231    fn widen_mean(self) -> f64 {
232        self
233    }
234    #[inline]
235    fn count(n: usize) -> f64 {
236        n as f64
237    }
238}
239
240impl MeanAcc for num_complex::Complex<f32> {
241    type Mean = num_complex::Complex<f32>;
242    #[inline]
243    fn widen_mean(self) -> Self {
244        self
245    }
246    #[inline]
247    fn count(n: usize) -> Self {
248        num_complex::Complex::new(n as f32, 0.0)
249    }
250}
251
252impl MeanAcc for num_complex::Complex<f64> {
253    type Mean = num_complex::Complex<f64>;
254    #[inline]
255    fn widen_mean(self) -> Self {
256        self
257    }
258    #[inline]
259    fn count(n: usize) -> Self {
260        num_complex::Complex::new(n as f64, 0.0)
261    }
262}
263
/// Index of the minimum (`take_min == true`) or maximum (`take_min == false`)
/// element produced by `iter`, counted from zero.
///
/// Semantics follow NumPy's `argmax`/`argmin`
/// (`numpy/_core/fromnumeric.py:1222`, `:1322`):
///   - ties resolve to the *first* occurrence (`:1261-1262`, `:1361-1362`);
///   - if a NaN is present, the index of the *first* NaN is returned
///     (live oracle numpy 2.4.5: `np.argmax([1.0, nan, 3.0, nan]) == 1`).
///
/// An empty iterator yields `None`, the same `Option` analog `min`/`max` use;
/// the ferray-python boundary maps `None`→`ValueError` as numpy does.
#[inline]
fn arg_reduce<T: PartialOrd + Copy>(iter: impl Iterator<Item = T>, take_min: bool) -> Option<i64> {
    use std::cmp::Ordering;

    // The only comparison outcome that lets a candidate displace the current
    // winner. Requiring a strict `Less`/`Greater` keeps the earliest index on
    // ties.
    let wanted = if take_min {
        Ordering::Less
    } else {
        Ordering::Greater
    };

    let mut winner: Option<(i64, T)> = None;
    for (pos, value) in iter.enumerate() {
        let pos = pos as i64;
        // A value that is unordered against itself is NaN: it wins on the
        // spot and nothing after it is inspected.
        if value.partial_cmp(&value).is_none() {
            return Some(pos);
        }
        let improves = match winner {
            // The first ordinary element always becomes the initial winner.
            None => true,
            Some((_, current)) => value.partial_cmp(&current) == Some(wanted),
        };
        if improves {
            winner = Some((pos, value));
        }
    }
    winner.map(|(pos, _)| pos)
}
305
/// One step of the value min/max fold, with `NumPy`-style NaN propagation.
///
/// `acc` is the running result, `x` the next element; `take_min` selects
/// minimum (`true`) or maximum (`false`).
///
/// NaN is sticky: if `acc` is NaN it is returned unchanged, otherwise if `x`
/// is NaN then `x` is returned, so once a NaN enters the fold every later
/// step yields NaN. NaN is recognised generically as a value for which
/// `v.partial_cmp(&v)` is `None` (`Complex` types would also fail that
/// reflexivity test, but they don't implement `PartialOrd`, so they never
/// reach this function).
///
/// When the two operands compare `Ordering::Equal`, the NEW element `x` is
/// returned rather than the accumulator. This mirrors `numpy`'s
/// `maximum.reduce`/`minimum.reduce` (`numpy/_core/_methods.py:38-44`,
/// `umr_maximum`/`umr_minimum`), whose scalar `maximum(a, b)`/`minimum(a, b)`
/// loops return the *later* operand on ties. The difference is observable
/// only for signed zeros (`+0.0 == -0.0`), where the LAST seen zero's sign bit
/// survives; any other equal pair is identical, so the choice is invisible.
///
/// This is the VALUE reduce only: `argmin`/`argmax` resolve ties to the first
/// occurrence and run on a separate code path that never calls `reduce_step`.
#[inline]
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    use std::cmp::Ordering;

    // A NaN accumulator stays put ...
    if acc.partial_cmp(&acc).is_none() {
        return acc;
    }
    // ... and a NaN element takes over an ordinary accumulator.
    if x.partial_cmp(&x).is_none() {
        return x;
    }
    match x.partial_cmp(&acc) {
        // Tie: keep the LAST operand (numpy maximum/minimum.reduce semantics).
        Some(Ordering::Equal) => x,
        Some(Ordering::Less) if take_min => x,
        Some(Ordering::Greater) if !take_min => x,
        // `x` is worse than `acc`, or the two are incomparable.
        _ => acc,
    }
}
340
341// ---------------------------------------------------------------------------
// Sum / Prod — require ReduceAcc; the result is the promoted ReduceAcc::Acc,
// accumulated from Element::zero / Element::one
343// ---------------------------------------------------------------------------
344
345impl<T, D> Array<T, D>
346where
347    T: Element + Copy,
348    D: Dimension,
349{
350    /// Sum of all elements (whole-array reduction).
351    ///
352    /// The result type is the NumPy reduction accumulator
353    /// [`ReduceAcc::Acc`]: narrow signed ints widen to `i64`, narrow unsigned
354    /// ints to `u64`, `bool` to `i64`, and `f32`/`f64`/complex stay
355    /// themselves. This means a narrow-int sum can never overflow and its
356    /// dtype matches `np.sum`'s promoted result
357    /// (`numpy/_core/fromnumeric.py:2321-2327`).
358    ///
359    /// Returns `Acc::zero()` for an empty array.
360    ///
361    /// # Examples
362    /// ```
363    /// # use ferray_core::Array;
364    /// # use ferray_core::dimension::Ix1;
365    /// let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
366    /// assert_eq!(a.sum(), 6.0);
367    /// // i8 sums promote to i64 and never overflow (numpy parity):
368    /// let b = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
369    /// assert_eq!(b.sum(), 300_i64);
370    /// ```
371    pub fn sum(&self) -> <T as ReduceAcc>::Acc
372    where
373        T: ReduceAcc,
374    {
375        let mut acc = <T as ReduceAcc>::Acc::zero();
376        for &x in self.iter() {
377            acc = acc + x.widen();
378        }
379        acc
380    }
381
382    /// Sum along the given axis. Returns an array with one fewer dimension,
383    /// whose element type is the promoted [`ReduceAcc::Acc`] (same numpy
384    /// narrow-int promotion as the whole-array [`Array::sum`]).
385    ///
386    /// # Errors
387    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
388    pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
389    where
390        T: ReduceAcc,
391        D::NdarrayDim: ndarray::RemoveAxis,
392    {
393        let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
394        widened.fold_axis(axis, <T as ReduceAcc>::Acc::zero(), |acc, &x| *acc + x)
395    }
396
397    /// Product of all elements.
398    ///
399    /// The result type is the promoted [`ReduceAcc::Acc`] (same numpy
400    /// narrow-int promotion as [`Array::sum`]; see
401    /// `numpy/_core/fromnumeric.py:3306-3312`).
402    ///
403    /// Returns `Acc::one()` for an empty array.
404    pub fn prod(&self) -> <T as ReduceAcc>::Acc
405    where
406        T: ReduceAcc,
407    {
408        let mut acc = <T as ReduceAcc>::Acc::one();
409        for &x in self.iter() {
410            acc = acc * x.widen();
411        }
412        acc
413    }
414
415    /// Product along the given axis. Element type is the promoted
416    /// [`ReduceAcc::Acc`].
417    pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
418    where
419        T: ReduceAcc,
420        D::NdarrayDim: ndarray::RemoveAxis,
421    {
422        let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
423        widened.fold_axis(axis, <T as ReduceAcc>::Acc::one(), |acc, &x| *acc * x)
424    }
425}
426
427// ---------------------------------------------------------------------------
428// Min / Max — require PartialOrd
429// ---------------------------------------------------------------------------
430
431impl<T, D> Array<T, D>
432where
433    T: Element + Copy + PartialOrd,
434    D: Dimension,
435{
436    /// Minimum value across the entire array.
437    ///
438    /// Returns `None` if the array is empty. NaN values follow `NumPy` semantics:
439    /// once a NaN is seen the result stays NaN, detected via self-comparison
440    /// (`x.partial_cmp(&x).is_none()`).
441    pub fn min(&self) -> Option<T> {
442        let mut iter = self.iter().copied();
443        let first = iter.next()?;
444        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
445    }
446
447    /// Maximum value across the entire array.
448    ///
449    /// Returns `None` if the array is empty. NaN values propagate per `NumPy`.
450    pub fn max(&self) -> Option<T> {
451        let mut iter = self.iter().copied();
452        let first = iter.next()?;
453        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
454    }
455
456    /// Minimum value along an axis.
457    ///
458    /// # Errors
459    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
460    /// `FerrayError::ShapeMismatch` if the resulting axis would be empty.
461    pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
462    where
463        D::NdarrayDim: ndarray::RemoveAxis,
464    {
465        // Use the first element along the axis as init by sentinel: pull the
466        // first lane and fold the rest. fold_axis applies init to every lane,
467        // but min has no neutral identity for arbitrary T. We sidestep by
468        // folding starting from any element of `self` — the per-lane init is
469        // overwritten by the first comparison, which is correct iff every lane
470        // has at least one element. Empty axes would yield uninitialized data.
471        let ndim = self.ndim();
472        if axis.index() >= ndim {
473            return Err(crate::error::FerrayError::axis_out_of_bounds(
474                axis.index(),
475                ndim,
476            ));
477        }
478        if self.shape()[axis.index()] == 0 {
479            return Err(crate::error::FerrayError::shape_mismatch(
480                "cannot compute min along empty axis",
481            ));
482        }
483        // Manual lane iteration: fold_axis can't be used here because min has
484        // no neutral identity that works for arbitrary `T: PartialOrd` (no
485        // T::infinity for ints).
486        self.fold_axis_min_max(axis, true)
487    }
488
489    /// Maximum value along an axis.
490    ///
491    /// See [`Array::min_axis`] for error semantics.
492    pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
493    where
494        D::NdarrayDim: ndarray::RemoveAxis,
495    {
496        let ndim = self.ndim();
497        if axis.index() >= ndim {
498            return Err(crate::error::FerrayError::axis_out_of_bounds(
499                axis.index(),
500                ndim,
501            ));
502        }
503        if self.shape()[axis.index()] == 0 {
504            return Err(crate::error::FerrayError::shape_mismatch(
505                "cannot compute max along empty axis",
506            ));
507        }
508        self.fold_axis_min_max(axis, false)
509    }
510
511    /// Flat index of the maximum element (whole-array reduction).
512    ///
513    /// Returns `None` for an empty array (the `Option` analog `min`/`max`
514    /// use; the ferray-python boundary maps `None`→`ValueError`, matching
515    /// `np.argmax([])`). On ties the **first** occurrence wins, and when any
516    /// NaN is present the index of the **first** NaN is returned (NaN-first),
517    /// matching `np.argmax` (`numpy/_core/fromnumeric.py:1222`, ties at
518    /// `:1261-1262`; live oracle `np.argmax([1.0, nan, 3.0, nan]) == 1`). The
519    /// index type is `i64` (ferray's `intp` analog), independent of `T`.
520    pub fn argmax(&self) -> Option<i64> {
521        arg_reduce(self.iter().copied(), false)
522    }
523
524    /// Flat index of the minimum element (whole-array reduction).
525    ///
526    /// Mirror of [`Array::argmax`] with min substituted for max: first
527    /// occurrence on ties, NaN-first, `None` on empty, `i64` index
528    /// (`numpy/_core/fromnumeric.py:1322`, ties at `:1361-1362`; live oracle
529    /// `np.argmin([1.0, nan, 3.0]) == 1`).
530    pub fn argmin(&self) -> Option<i64> {
531        arg_reduce(self.iter().copied(), true)
532    }
533
534    /// Indices of the maxima along `axis`, as an `Array<i64, IxDyn>` with the
535    /// reduced axis removed. First-occurrence on ties, NaN-first per lane
536    /// (`numpy/_core/fromnumeric.py:1222`).
537    ///
538    /// # Errors
539    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
540    /// `FerrayError::ShapeMismatch` if the reduced axis is empty (matching
541    /// numpy's `ValueError` on an empty argmax axis).
542    pub fn argmax_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
543    where
544        D::NdarrayDim: ndarray::RemoveAxis,
545    {
546        self.arg_axis(axis, false)
547    }
548
549    /// Indices of the minima along `axis`. See [`Array::argmax_axis`].
550    pub fn argmin_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
551    where
552        D::NdarrayDim: ndarray::RemoveAxis,
553    {
554        self.arg_axis(axis, true)
555    }
556
557    /// Internal: per-lane arg-reduction along `axis`. Each lane is a 1D view
558    /// orthogonal to `axis`; the reduced index is the position within the lane.
559    fn arg_axis(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<i64, IxDyn>>
560    where
561        D::NdarrayDim: ndarray::RemoveAxis,
562    {
563        let ndim = self.ndim();
564        if axis.index() >= ndim {
565            return Err(crate::error::FerrayError::axis_out_of_bounds(
566                axis.index(),
567                ndim,
568            ));
569        }
570        if self.shape()[axis.index()] == 0 {
571            return Err(crate::error::FerrayError::shape_mismatch(
572                "attempt to get argmax/argmin of an empty axis",
573            ));
574        }
575        let nd_axis = ndarray::Axis(axis.index());
576        let lanes = self.inner.lanes(nd_axis);
577        let mut out: Vec<i64> = Vec::with_capacity(lanes.into_iter().len());
578        for lane in self.inner.lanes(nd_axis) {
579            // Lane is non-empty (empty axis already rejected), so arg_reduce
580            // returns Some; default 0 is unreachable but keeps the code panic-free.
581            let idx = arg_reduce(lane.iter().copied(), take_min).unwrap_or(0);
582            out.push(idx);
583        }
584        let mut out_shape: Vec<usize> = self.shape().to_vec();
585        out_shape.remove(axis.index());
586        Array::from_vec(IxDyn::from(&out_shape[..]), out)
587    }
588
589    /// Internal: per-lane min/max via manual lane iteration. Avoids the
590    /// init-bias problem of `fold_axis` (which applies a single init to every
591    /// lane, even though min/max have no identity element).
592    fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
593    where
594        D::NdarrayDim: ndarray::RemoveAxis,
595    {
596        let nd_axis = ndarray::Axis(axis.index());
597        // Use ndarray's lane iteration directly via the inner ndarray::ArrayBase.
598        // Each lane is a 1D view orthogonal to the chosen axis.
599        let lanes = self.inner.lanes(nd_axis);
600        let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
601        for lane in self.inner.lanes(nd_axis) {
602            let mut iter = lane.iter().copied();
603            let first = iter.next().unwrap(); // safe: empty axis already rejected
604            let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
605            out.push(result);
606        }
607
608        // Output shape: drop the reduced axis from the input shape.
609        let mut out_shape: Vec<usize> = self.shape().to_vec();
610        out_shape.remove(axis.index());
611        Array::from_vec(IxDyn::from(&out_shape[..]), out)
612    }
613}
614
615// ---------------------------------------------------------------------------
// Mean — requires MeanAcc (bool / int inputs average in f64);
// Var / Std — additionally require Float
617// ---------------------------------------------------------------------------
618
619impl<T, D> Array<T, D>
620where
621    T: MeanAcc,
622    D: Dimension,
623{
624    /// Arithmetic mean of all elements. Returns `None` for an empty array.
625    ///
626    /// The result type is the NumPy mean accumulator [`MeanAcc::Mean`]:
627    /// bool / integer inputs average in (and return) `f64`, while `f32`/`f64`/
628    /// complex keep their own dtype. This matches numpy, which casts bool /
629    /// unsigned / signed int to `float64` before averaging
630    /// (`numpy/_core/_methods.py:124-127`), so `Array::<i32, _>::mean()` is
631    /// `Some(f64)` and `Array::<bool, _>::mean()` is `Some(f64)` (e.g.
632    /// `0.666…`), matching `np.mean`.
633    pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
634        let n = self.size();
635        if n == 0 {
636            return None;
637        }
638        let sum = self
639            .iter()
640            .copied()
641            .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
642        Some(sum / <T as MeanAcc>::count(n))
643    }
644
645    /// Mean along an axis. Element type is the promoted [`MeanAcc::Mean`]
646    /// (bool / integer lanes average in `f64`; `f32`/`f64` stay themselves).
647    pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<<T as MeanAcc>::Mean, IxDyn>>
648    where
649        <T as MeanAcc>::Mean: ReduceAcc<Acc = <T as MeanAcc>::Mean>,
650        D::NdarrayDim: ndarray::RemoveAxis,
651    {
652        let ndim = self.ndim();
653        if axis.index() >= ndim {
654            return Err(crate::error::FerrayError::axis_out_of_bounds(
655                axis.index(),
656                ndim,
657            ));
658        }
659        let n = self.shape()[axis.index()];
660        if n == 0 {
661            return Err(crate::error::FerrayError::shape_mismatch(
662                "cannot compute mean along empty axis",
663            ));
664        }
665        // Widen each element into the mean accumulator, then sum along the axis
666        // and divide by the lane length.
667        let widened = self.map_to::<<T as MeanAcc>::Mean>(MeanAcc::widen_mean);
668        let sums = widened.sum_axis(axis)?;
669        let n_t = <T as MeanAcc>::count(n);
670        Ok(sums.mapv(|x| x / n_t))
671    }
672}
673
674impl<T, D> Array<T, D>
675where
676    T: Element + Float + MeanAcc<Mean = T>,
677    D: Dimension,
678{
679    /// Variance with `ddof` degrees of freedom (Bessel's correction = 1).
680    ///
681    /// Returns `None` for an empty array, or when `ddof >= n`.
682    pub fn var(&self, ddof: usize) -> Option<T> {
683        let n = self.size();
684        if n == 0 || ddof >= n {
685            return None;
686        }
687        let mean = self.mean()?;
688        let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
689            acc + (x - mean) * (x - mean)
690        });
691        Some(sum_sq / T::from(n - ddof).unwrap())
692    }
693
694    /// Standard deviation with `ddof` degrees of freedom.
695    pub fn std(&self, ddof: usize) -> Option<T> {
696        self.var(ddof).map(num_traits::Float::sqrt)
697    }
698}
699
700// ---------------------------------------------------------------------------
701// any / all — for bool arrays
702// ---------------------------------------------------------------------------
703
704impl<D> Array<bool, D>
705where
706    D: Dimension,
707{
708    /// Returns `true` if any element is `true`.
709    pub fn any(&self) -> bool {
710        self.iter().any(|&x| x)
711    }
712
713    /// Returns `true` if all elements are `true`. Vacuously `true` for empty arrays.
714    pub fn all(&self) -> bool {
715        self.iter().all(|&x| x)
716    }
717}
718
719// ---------------------------------------------------------------------------
720// ArrayView mirrors — same methods on borrowed views
721// ---------------------------------------------------------------------------
722
723impl<T, D> ArrayView<'_, T, D>
724where
725    T: Element + Copy,
726    D: Dimension,
727{
728    /// Sum of all elements. See [`Array::sum`] — returns the promoted
729    /// [`ReduceAcc::Acc`].
730    pub fn sum(&self) -> <T as ReduceAcc>::Acc
731    where
732        T: ReduceAcc,
733    {
734        let mut acc = <T as ReduceAcc>::Acc::zero();
735        for &x in self.iter() {
736            acc = acc + x.widen();
737        }
738        acc
739    }
740
741    /// Product of all elements. See [`Array::prod`] — returns the promoted
742    /// [`ReduceAcc::Acc`].
743    pub fn prod(&self) -> <T as ReduceAcc>::Acc
744    where
745        T: ReduceAcc,
746    {
747        let mut acc = <T as ReduceAcc>::Acc::one();
748        for &x in self.iter() {
749            acc = acc * x.widen();
750        }
751        acc
752    }
753}
754
755impl<T, D> ArrayView<'_, T, D>
756where
757    T: Element + Copy + PartialOrd,
758    D: Dimension,
759{
760    /// Minimum value. See [`Array::min`].
761    pub fn min(&self) -> Option<T> {
762        let mut iter = self.iter().copied();
763        let first = iter.next()?;
764        Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
765    }
766
767    /// Maximum value. See [`Array::max`].
768    pub fn max(&self) -> Option<T> {
769        let mut iter = self.iter().copied();
770        let first = iter.next()?;
771        Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
772    }
773
774    /// Flat index of the maximum element. See [`Array::argmax`].
775    pub fn argmax(&self) -> Option<i64> {
776        arg_reduce(self.iter().copied(), false)
777    }
778
779    /// Flat index of the minimum element. See [`Array::argmin`].
780    pub fn argmin(&self) -> Option<i64> {
781        arg_reduce(self.iter().copied(), true)
782    }
783}
784
785impl<T, D> ArrayView<'_, T, D>
786where
787    T: MeanAcc,
788    D: Dimension,
789{
790    /// Mean. See [`Array::mean`] — returns the promoted [`MeanAcc::Mean`].
791    pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
792        let n = self.size();
793        if n == 0 {
794            return None;
795        }
796        let sum = self
797            .iter()
798            .copied()
799            .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
800        Some(sum / <T as MeanAcc>::count(n))
801    }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1-D `f64` array whose single axis has `data.len()` entries.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    /// Builds a `rows` x `cols` `f64` array from `data`.
    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        let shape = Ix2::new([rows, cols]);
        Array::from_vec(shape, data).unwrap()
    }

    // ----- sum / prod -----

    #[test]
    fn sum_1d() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let total = values.sum();
        assert_eq!(total, 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let empty = arr1(Vec::new());
        assert_eq!(empty.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let grid = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // Reducing axis 0 adds the two rows: [1+4, 2+5, 3+6] = [5, 7, 9]
        let over_rows = grid.sum_axis(Axis(0)).unwrap();
        assert_eq!(over_rows.shape(), &[3]);
        let over_rows_vals: Vec<_> = over_rows.iter().copied().collect();
        assert_eq!(over_rows_vals, vec![5.0, 7.0, 9.0]);

        // Reducing axis 1 adds within each row: [1+2+3, 4+5+6] = [6, 15]
        let over_cols = grid.sum_axis(Axis(1)).unwrap();
        assert_eq!(over_cols.shape(), &[2]);
        let over_cols_vals: Vec<_> = over_cols.iter().copied().collect();
        assert_eq!(over_cols_vals, vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let product = values.prod();
        assert_eq!(product, 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let empty = arr1(Vec::new());
        assert_eq!(empty.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let grid = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // Axis 0: column-wise products [1*4, 2*5, 3*6]
        let over_rows = grid.prod_axis(Axis(0)).unwrap();
        let over_rows_vals: Vec<_> = over_rows.iter().copied().collect();
        assert_eq!(over_rows_vals, vec![4.0, 10.0, 18.0]);

        // Axis 1: row-wise products [1*2*3, 4*5*6]
        let over_cols = grid.prod_axis(Axis(1)).unwrap();
        let over_cols_vals: Vec<_> = over_cols.iter().copied().collect();
        assert_eq!(over_cols_vals, vec![6.0, 120.0]);
    }

    // ----- min / max -----

    #[test]
    fn min_max_1d() {
        let values = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        let (lowest, highest) = (values.min(), values.max());
        assert_eq!(lowest, Some(1.0));
        assert_eq!(highest, Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let empty = arr1(Vec::new());
        assert!(empty.min().is_none());
        assert!(empty.max().is_none());
    }

    #[test]
    fn min_max_int() {
        let data = vec![3_i32, -1, 4, -5, 2];
        let shape = Ix1::new([data.len()]);
        let ints: Array<i32, Ix1> = Array::from_vec(shape, data).unwrap();
        assert_eq!(ints.min(), Some(-5));
        assert_eq!(ints.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let grid = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);

        // Axis 0 gives one min/max per column.
        let col_mins: Vec<_> = grid.min_axis(Axis(0)).unwrap().iter().copied().collect();
        assert_eq!(col_mins, vec![1.0, 2.0, 3.0]);
        let col_maxes: Vec<_> = grid.max_axis(Axis(0)).unwrap().iter().copied().collect();
        assert_eq!(col_maxes, vec![4.0, 5.0, 6.0]);

        // Axis 1 gives one min/max per row.
        let row_mins: Vec<_> = grid.min_axis(Axis(1)).unwrap().iter().copied().collect();
        assert_eq!(row_mins, vec![1.0, 2.0]);
        let row_maxes: Vec<_> = grid.max_axis(Axis(1)).unwrap().iter().copied().collect();
        assert_eq!(row_maxes, vec![5.0, 6.0]);
    }

    // ----- mean / var / std -----

    #[test]
    fn mean_1d() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let average = values.mean();
        assert_eq!(average, Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let empty = arr1(Vec::new());
        assert!(empty.mean().is_none());
    }

    #[test]
    fn mean_axis_2d() {
        let grid = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let col_means: Vec<_> = grid.mean_axis(Axis(0)).unwrap().iter().copied().collect();
        assert_eq!(col_means, vec![2.5, 3.5, 4.5]);

        let row_means: Vec<_> = grid.mean_axis(Axis(1)).unwrap().iter().copied().collect();
        assert_eq!(row_means, vec![2.0, 5.0]);
    }

    #[test]
    fn var_population() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // With ddof=0 the squared deviations from the mean 3 are 4+1+0+1+4 = 10,
        // divided by n = 5, giving 2.
        let population = values.var(0);
        assert_eq!(population, Some(2.0));
    }

    #[test]
    fn var_sample() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // With ddof=1 the same sum of squares is divided by n - 1: 10/4 = 2.5.
        let sample = values.var(1);
        assert_eq!(sample, Some(2.5));
    }

    #[test]
    fn std_basic() {
        let values = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let expected = 2.0_f64.sqrt();
        let got = values.std(0).unwrap();
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let pair = arr1(vec![1.0, 2.0]);
        // ddof equal to, and greater than, the element count.
        assert_eq!(pair.var(2), None);
        assert_eq!(pair.var(5), None);
    }

    // ----- any / all -----

    #[test]
    fn any_all_bool() {
        /// Builds a 1-D `bool` array from `flags`.
        fn bools(flags: Vec<bool>) -> Array<bool, Ix1> {
            let shape = Ix1::new([flags.len()]);
            Array::from_vec(shape, flags).unwrap()
        }

        // (input, expected `all()`, expected `any()`)
        let cases = [
            (bools(vec![true, true, true]), true, true),
            (bools(vec![true, false, true]), false, true),
            (bools(vec![false, false, false]), false, false),
            // Vacuous truth for empty
            (bools(vec![]), true, false),
        ];

        for (flags, want_all, want_any) in &cases {
            assert_eq!(flags.all(), *want_all);
            assert_eq!(flags.any(), *want_any);
        }
    }

    // ----- ArrayView mirrors -----

    #[test]
    fn view_sum_min_max_mean() {
        let owner = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let window = owner.view();
        assert_eq!(window.sum(), 10.0);
        assert_eq!(window.min(), Some(1.0));
        assert_eq!(window.max(), Some(4.0));
        assert_eq!(window.mean(), Some(2.5));
    }

    #[test]
    fn nan_propagates_in_min_max() {
        let placements = [
            // NaN somewhere in the middle
            vec![1.0, f64::NAN, 3.0],
            // NaN at the start
            vec![f64::NAN, 1.0, 3.0],
            // NaN at the end
            vec![1.0, 3.0, f64::NAN],
        ];

        for data in placements.iter() {
            let values = arr1(data.clone());
            assert!(values.min().unwrap().is_nan());
            assert!(values.max().unwrap().is_nan());
        }
    }
}
1014}