ferray_core/array/reductions.rs
1// ferray-core: Reduction methods for Array<T, D>
2//
3// Provides NumPy-equivalent reduction methods directly on Array and ArrayView:
4// sum, prod, min, max, mean, var, std, any, all
5//
6// Each reduction has a whole-array variant and an axis variant:
7// .sum() -> T
8// .sum_axis(Axis(0)) -> Array<T, IxDyn>
9//
10// These methods complement the lower-level fold_axis primitive in methods.rs
11// and the free-function reductions in ferray-stats. The instance-method form
12// matches NumPy's `arr.sum()` ergonomics so users don't need to import a
13// separate crate just to compute a sum.
14//
15// ## REQ status (reductions, NumPy parity)
16// - sum / prod / sum_axis / prod_axis accumulator promotion — SHIPPED (#780):
17// the `ReduceAcc` trait (this file) maps narrow signed ints → i64, narrow
18// unsigned ints → u64, and bool → i64 before reducing, so narrow-int
19// reductions never overflow and the result dtype matches numpy
20// (numpy/_core/fromnumeric.py:2321-2327). Consumers: `Array::sum`/`prod`/
21// `sum_axis`/`prod_axis` and the `ArrayView` mirrors (all in this file).
22// - min / max / mean / var / std / any / all — SHIPPED (#368), NaN-propagating.
23// - empty-array min/max raising ValueError — NOT-STARTED (open blocker #782).
24// - cumsum / cumprod live as ferray-stats free functions; their narrow-int
25// promotion is a separate ferray-stats blocker that reuses `ReduceAcc`.
26// - argmax / argmin (REQ-40, REQ-41) — SHIPPED: `Array::argmax`/`argmin`
27// (flattened, returning `Option<i64>`) and `argmax_axis`/`argmin_axis`
28// (returning `Array<i64, IxDyn>`) in this file. First-occurrence on ties and
29// NaN-first propagation (numpy/_core/fromnumeric.py:1222 `argmax`,
30// :1261-1262 ties; :1322 `argmin`, :1361-1362 ties). Empty flattened form
31// returns `None` (mirroring `min`/`max`'s `Option` analog; the ferray-python
32// boundary maps `None`→`ValueError` as numpy does). Result index dtype is
33// `i64` (ferray's `intp` analog), independent of element dtype. Consumers:
34// the boundary methods themselves are the public API surface (like
35// `sum`/`min`), exercised by the `ArrayView` mirror and the in-file tests.
36// - integer/bool mean → f64 (REQ-42) — SHIPPED: the `MeanAcc` trait (this
37// file) maps bool/integer element types to an `f64` accumulator-and-result,
38// while `f32`/`f64`/complex stay themselves, so `Array::<i32, _>::mean()`
39// returns `f64` matching numpy, which casts bool/unsigned/signed int to
40// float64 before averaging (numpy/_core/_methods.py:124-127). `mean`/
41// `mean_axis` and the `ArrayView::mean` mirror are now bounded by `MeanAcc`
42// instead of `Float`; existing `f32`/`f64` means are unchanged
43// (`MeanAcc::Mean == Self`). Consumers: `Array::var`/`std` (`self.mean()?`)
44// and the `ArrayView` mirror, all in this file.
45//
46// See: https://github.com/dollspace-gay/ferray/issues/368, /issues/780
47
48use num_traits::Float;
49
50use crate::array::owned::Array;
51use crate::array::view::ArrayView;
52use crate::dimension::{Axis, Dimension, IxDyn};
53use crate::dtype::Element;
54use crate::error::FerrayResult;
55
56// ---------------------------------------------------------------------------
57// ReduceAcc — NumPy's sum/prod/cumsum/cumprod accumulator-and-result dtype.
58// ---------------------------------------------------------------------------
59
60/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
61/// return) `sum` / `prod` / `cumsum` / `cumprod` over it.
62///
63/// NumPy promotes any integer dtype of *less precision than the default
64/// platform integer* before reducing, so a narrow-int reduction can never
65/// overflow and the result dtype is the platform integer:
66///
67/// > "The dtype of `a` is used by default unless `a` has an integer dtype of
68/// > less precision than the default platform integer. In that case, if `a`
69/// > is signed then the platform integer is used while if `a` is unsigned then
70/// > an unsigned integer of the same precision as the platform integer is
71/// > used."
72/// > — `numpy/_core/fromnumeric.py:2321-2327` (sum), `:3306-3312` (prod)
73///
74/// The reduction itself is `umr_sum = um.add.reduce` /
75/// `umr_prod = um.multiply.reduce` (`numpy/_core/_methods.py:20-21`), i.e. the
76/// add/multiply ufunc whose *loop dtype* is the promoted accumulator.
77///
78/// The mapping (platform integer = 64-bit, matching ferray's `intp`/`int64`):
79/// - `i8 / i16 / i32 → i64`, `i64 → i64`, `i128 → i128`
80/// - `u8 / u16 / u32 → u64`, `u64 → u64`, `u128 → u128`
81/// - `bool → i64` (NumPy reduces bool as the platform integer, counting `true`)
82/// - `f32 → f32`, `f64 → f64`, complex stays itself (no promotion)
83///
84/// Wider-or-equal dtypes map to themselves, so existing `f64`/`i64`/complex
85/// reductions are unchanged — only narrow-int callers observe the promoted
86/// return type.
87pub trait ReduceAcc: Element + Copy {
88 /// The accumulator-and-result element type for reductions over `Self`.
89 type Acc: Element + Copy + std::ops::Add<Output = Self::Acc> + std::ops::Mul<Output = Self::Acc>;
90
91 /// Widen one element into the accumulator type before reducing, matching
92 /// NumPy's promotion of the loop dtype (`true → 1` for `bool`).
93 fn widen(self) -> Self::Acc;
94}
95
96macro_rules! impl_reduce_acc {
97 ($($t:ty => $acc:ty),* $(,)?) => {
98 $(
99 impl ReduceAcc for $t {
100 type Acc = $acc;
101 #[inline]
102 fn widen(self) -> $acc {
103 self as $acc
104 }
105 }
106 )*
107 };
108}
109
110// Narrow signed ints promote to i64; i64/i128 stay themselves.
111impl_reduce_acc! {
112 i8 => i64, i16 => i64, i32 => i64, i64 => i64, i128 => i128,
113 u8 => u64, u16 => u64, u32 => u64, u64 => u64, u128 => u128,
114 f32 => f32, f64 => f64,
115}
116
117// bool reduces as the platform integer, counting `true` (numpy:
118// `np.sum(np.array([True, True, True])).dtype == int64`). `as i64` maps
119// false→0, true→1.
120impl ReduceAcc for bool {
121 type Acc = i64;
122 #[inline]
123 fn widen(self) -> i64 {
124 i64::from(self)
125 }
126}
127
128// Complex stays itself — numpy never promotes a complex reduction.
129impl ReduceAcc for num_complex::Complex<f32> {
130 type Acc = num_complex::Complex<f32>;
131 #[inline]
132 fn widen(self) -> Self {
133 self
134 }
135}
136
137impl ReduceAcc for num_complex::Complex<f64> {
138 type Acc = num_complex::Complex<f64>;
139 #[inline]
140 fn widen(self) -> Self {
141 self
142 }
143}
144
145// ---------------------------------------------------------------------------
146// MeanAcc — NumPy's mean accumulator-and-result dtype.
147// ---------------------------------------------------------------------------
148
149/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
150/// return) `mean` over it.
151///
152/// NumPy casts a bool / unsigned-int / signed-int input to `float64` before
153/// averaging:
154///
155/// > "Cast bool, unsigned int, and int to float64 by default ...
156/// > `dtype = mu.dtype('f8')`"
157/// > — `numpy/_core/_methods.py:124-127`
158///
159/// so `np.mean(np.array([1, 2, 3], np.int32))` is `float64 2.0` and
160/// `np.mean([True, False, True])` is `float64 0.6666…`. Floating-point inputs
161/// keep their own dtype (`f32`→`f32`, `f64`→`f64`), and complex stays itself.
162///
163/// The mapping:
164/// - `bool / i8.. / u8.. → f64`
165/// - `f32 → f32`, `f64 → f64` (unchanged — `Mean == Self`)
166/// - `Complex<f32> → Complex<f32>`, `Complex<f64> → Complex<f64>`
167pub trait MeanAcc: Element + Copy {
168 /// The accumulator-and-result element type for `mean` over `Self`.
169 type Mean: Element
170 + Copy
171 + std::ops::Add<Output = Self::Mean>
172 + std::ops::Div<Output = Self::Mean>;
173
174 /// Widen one element into the mean accumulator type, matching NumPy's
175 /// pre-average cast (`true → 1.0`, `false → 0.0` for `bool`).
176 fn widen_mean(self) -> Self::Mean;
177
178 /// Construct the divisor (element count `n`) in the accumulator type.
179 fn count(n: usize) -> Self::Mean;
180}
181
182macro_rules! impl_mean_acc_to_f64 {
183 ($($t:ty),* $(,)?) => {
184 $(
185 impl MeanAcc for $t {
186 type Mean = f64;
187 #[inline]
188 fn widen_mean(self) -> f64 {
189 self as f64
190 }
191 #[inline]
192 fn count(n: usize) -> f64 {
193 n as f64
194 }
195 }
196 )*
197 };
198}
199
200// bool / all integer dtypes average in f64 (numpy/_core/_methods.py:124-127).
201impl_mean_acc_to_f64!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
202
203impl MeanAcc for bool {
204 type Mean = f64;
205 #[inline]
206 fn widen_mean(self) -> f64 {
207 if self { 1.0 } else { 0.0 }
208 }
209 #[inline]
210 fn count(n: usize) -> f64 {
211 n as f64
212 }
213}
214
215// Floating-point inputs keep their own dtype (Mean == Self).
216impl MeanAcc for f32 {
217 type Mean = f32;
218 #[inline]
219 fn widen_mean(self) -> f32 {
220 self
221 }
222 #[inline]
223 fn count(n: usize) -> f32 {
224 n as f32
225 }
226}
227
228impl MeanAcc for f64 {
229 type Mean = f64;
230 #[inline]
231 fn widen_mean(self) -> f64 {
232 self
233 }
234 #[inline]
235 fn count(n: usize) -> f64 {
236 n as f64
237 }
238}
239
240impl MeanAcc for num_complex::Complex<f32> {
241 type Mean = num_complex::Complex<f32>;
242 #[inline]
243 fn widen_mean(self) -> Self {
244 self
245 }
246 #[inline]
247 fn count(n: usize) -> Self {
248 num_complex::Complex::new(n as f32, 0.0)
249 }
250}
251
252impl MeanAcc for num_complex::Complex<f64> {
253 type Mean = num_complex::Complex<f64>;
254 #[inline]
255 fn widen_mean(self) -> Self {
256 self
257 }
258 #[inline]
259 fn count(n: usize) -> Self {
260 num_complex::Complex::new(n as f64, 0.0)
261 }
262}
263
/// First-occurrence, NaN-first arg-reduction over a flat element iterator.
///
/// Mirrors NumPy's `argmax`/`argmin` (`numpy/_core/fromnumeric.py:1222`,
/// `:1322`): on ties the *first* occurrence wins (`:1261-1262`, `:1361-1362`),
/// and when any NaN is present the index of the *first* NaN is returned
/// (NaN-propagating, NaN-first — live oracle numpy 2.4.5:
/// `np.argmax([1.0, nan, 3.0, nan]) == 1`).
///
/// * `iter` — the elements, in flattened order; positions are counted from 0.
/// * `take_min` — `true` selects the minimum (`argmin`), `false` the maximum
///   (`argmax`).
///
/// Returns `None` for an empty iterator, the `Option` analog `min`/`max` use;
/// the ferray-python boundary maps `None`→`ValueError` as numpy does.
#[inline]
fn arg_reduce<T: PartialOrd + Copy>(iter: impl Iterator<Item = T>, take_min: bool) -> Option<i64> {
    // The only ordering that lets a later element displace the incumbent.
    // Requiring a strict Less/Greater is what makes the first occurrence win
    // on ties.
    let wanted = if take_min {
        std::cmp::Ordering::Less
    } else {
        std::cmp::Ordering::Greater
    };

    // Index and value of the current winner travel together, so they can
    // never get out of sync.
    let mut best: Option<(usize, T)> = None;
    for (idx, x) in iter.enumerate() {
        // NaN-first: a value that is not comparable with itself is a NaN; the
        // first one seen wins immediately and is never beaten.
        if x.partial_cmp(&x).is_none() {
            return Some(idx as i64);
        }
        match best {
            // Incumbent is at least as good: keep it.
            Some((_, incumbent)) if x.partial_cmp(&incumbent) != Some(wanted) => {}
            // First element, or a strictly better one.
            _ => best = Some((idx, x)),
        }
    }
    best.map(|(idx, _)| idx as i64)
}
305
/// One step of a NaN-propagating min/max fold, following `NumPy` semantics.
///
/// * `acc` — the running result so far.
/// * `x` — the next element.
/// * `take_min` — `true` folds towards the minimum, `false` towards the
///   maximum.
///
/// NaN is sticky: if `acc` is already NaN it is returned unchanged, and
/// otherwise a NaN `x` becomes the result. NaN is recognised generically by
/// `v.partial_cmp(&v).is_none()`, which holds exactly for values that break
/// `PartialOrd` reflexivity (`Complex` would too, but it has no `PartialOrd`
/// impl, so it never reaches this function).
///
/// When the two operands compare `Ordering::Equal` the NEW element `x` is
/// returned rather than the accumulator. That matches `numpy`'s
/// `maximum.reduce`/`minimum.reduce` (`numpy/_core/_methods.py:38-44`,
/// `umr_maximum`/`umr_minimum`), whose scalar `maximum(a, b)`/`minimum(a, b)`
/// loops hand back the *later* operand on ties. The difference is visible only
/// for signed zeros (`+0.0 == -0.0`), where numpy keeps the sign bit of the
/// LAST zero seen; any other equal pair is indistinguishable. This is the
/// VALUE reduction — `argmin`/`argmax` resolve ties to the first occurrence
/// and run on a separate code path that never calls `reduce_step`.
#[inline]
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    // A NaN accumulator stays NaN forever.
    if acc.partial_cmp(&acc).is_none() {
        return acc;
    }
    // A NaN element poisons the accumulator.
    if x.partial_cmp(&x).is_none() {
        return x;
    }
    let take_new = match x.partial_cmp(&acc) {
        Some(std::cmp::Ordering::Less) => take_min,
        Some(std::cmp::Ordering::Greater) => !take_min,
        // Tie: keep the LAST operand (numpy maximum/minimum.reduce semantics).
        Some(std::cmp::Ordering::Equal) => true,
        None => false,
    };
    if take_new {
        x
    } else {
        acc
    }
}
340
341// ---------------------------------------------------------------------------
// Sum / Prod — require `ReduceAcc`; accumulate in the promoted `Acc` type,
// starting from `Element::zero` / `Element::one`
343// ---------------------------------------------------------------------------
344
345impl<T, D> Array<T, D>
346where
347 T: Element + Copy,
348 D: Dimension,
349{
350 /// Sum of all elements (whole-array reduction).
351 ///
352 /// The result type is the NumPy reduction accumulator
353 /// [`ReduceAcc::Acc`]: narrow signed ints widen to `i64`, narrow unsigned
354 /// ints to `u64`, `bool` to `i64`, and `f32`/`f64`/complex stay
355 /// themselves. This means a narrow-int sum can never overflow and its
356 /// dtype matches `np.sum`'s promoted result
357 /// (`numpy/_core/fromnumeric.py:2321-2327`).
358 ///
359 /// Returns `Acc::zero()` for an empty array.
360 ///
361 /// # Examples
362 /// ```
363 /// # use ferray_core::Array;
364 /// # use ferray_core::dimension::Ix1;
365 /// let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
366 /// assert_eq!(a.sum(), 6.0);
367 /// // i8 sums promote to i64 and never overflow (numpy parity):
368 /// let b = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
369 /// assert_eq!(b.sum(), 300_i64);
370 /// ```
371 pub fn sum(&self) -> <T as ReduceAcc>::Acc
372 where
373 T: ReduceAcc,
374 {
375 let mut acc = <T as ReduceAcc>::Acc::zero();
376 for &x in self.iter() {
377 acc = acc + x.widen();
378 }
379 acc
380 }
381
382 /// Sum along the given axis. Returns an array with one fewer dimension,
383 /// whose element type is the promoted [`ReduceAcc::Acc`] (same numpy
384 /// narrow-int promotion as the whole-array [`Array::sum`]).
385 ///
386 /// # Errors
387 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
388 pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
389 where
390 T: ReduceAcc,
391 D::NdarrayDim: ndarray::RemoveAxis,
392 {
393 let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
394 widened.fold_axis(axis, <T as ReduceAcc>::Acc::zero(), |acc, &x| *acc + x)
395 }
396
397 /// Product of all elements.
398 ///
399 /// The result type is the promoted [`ReduceAcc::Acc`] (same numpy
400 /// narrow-int promotion as [`Array::sum`]; see
401 /// `numpy/_core/fromnumeric.py:3306-3312`).
402 ///
403 /// Returns `Acc::one()` for an empty array.
404 pub fn prod(&self) -> <T as ReduceAcc>::Acc
405 where
406 T: ReduceAcc,
407 {
408 let mut acc = <T as ReduceAcc>::Acc::one();
409 for &x in self.iter() {
410 acc = acc * x.widen();
411 }
412 acc
413 }
414
415 /// Product along the given axis. Element type is the promoted
416 /// [`ReduceAcc::Acc`].
417 pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
418 where
419 T: ReduceAcc,
420 D::NdarrayDim: ndarray::RemoveAxis,
421 {
422 let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
423 widened.fold_axis(axis, <T as ReduceAcc>::Acc::one(), |acc, &x| *acc * x)
424 }
425}
426
427// ---------------------------------------------------------------------------
428// Min / Max — require PartialOrd
429// ---------------------------------------------------------------------------
430
431impl<T, D> Array<T, D>
432where
433 T: Element + Copy + PartialOrd,
434 D: Dimension,
435{
436 /// Minimum value across the entire array.
437 ///
438 /// Returns `None` if the array is empty. NaN values follow `NumPy` semantics:
439 /// once a NaN is seen the result stays NaN, detected via self-comparison
440 /// (`x.partial_cmp(&x).is_none()`).
441 pub fn min(&self) -> Option<T> {
442 let mut iter = self.iter().copied();
443 let first = iter.next()?;
444 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
445 }
446
447 /// Maximum value across the entire array.
448 ///
449 /// Returns `None` if the array is empty. NaN values propagate per `NumPy`.
450 pub fn max(&self) -> Option<T> {
451 let mut iter = self.iter().copied();
452 let first = iter.next()?;
453 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
454 }
455
456 /// Minimum value along an axis.
457 ///
458 /// # Errors
459 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
460 /// `FerrayError::ShapeMismatch` if the resulting axis would be empty.
461 pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
462 where
463 D::NdarrayDim: ndarray::RemoveAxis,
464 {
465 // Use the first element along the axis as init by sentinel: pull the
466 // first lane and fold the rest. fold_axis applies init to every lane,
467 // but min has no neutral identity for arbitrary T. We sidestep by
468 // folding starting from any element of `self` — the per-lane init is
469 // overwritten by the first comparison, which is correct iff every lane
470 // has at least one element. Empty axes would yield uninitialized data.
471 let ndim = self.ndim();
472 if axis.index() >= ndim {
473 return Err(crate::error::FerrayError::axis_out_of_bounds(
474 axis.index(),
475 ndim,
476 ));
477 }
478 if self.shape()[axis.index()] == 0 {
479 return Err(crate::error::FerrayError::shape_mismatch(
480 "cannot compute min along empty axis",
481 ));
482 }
483 // Manual lane iteration: fold_axis can't be used here because min has
484 // no neutral identity that works for arbitrary `T: PartialOrd` (no
485 // T::infinity for ints).
486 self.fold_axis_min_max(axis, true)
487 }
488
489 /// Maximum value along an axis.
490 ///
491 /// See [`Array::min_axis`] for error semantics.
492 pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
493 where
494 D::NdarrayDim: ndarray::RemoveAxis,
495 {
496 let ndim = self.ndim();
497 if axis.index() >= ndim {
498 return Err(crate::error::FerrayError::axis_out_of_bounds(
499 axis.index(),
500 ndim,
501 ));
502 }
503 if self.shape()[axis.index()] == 0 {
504 return Err(crate::error::FerrayError::shape_mismatch(
505 "cannot compute max along empty axis",
506 ));
507 }
508 self.fold_axis_min_max(axis, false)
509 }
510
511 /// Flat index of the maximum element (whole-array reduction).
512 ///
513 /// Returns `None` for an empty array (the `Option` analog `min`/`max`
514 /// use; the ferray-python boundary maps `None`→`ValueError`, matching
515 /// `np.argmax([])`). On ties the **first** occurrence wins, and when any
516 /// NaN is present the index of the **first** NaN is returned (NaN-first),
517 /// matching `np.argmax` (`numpy/_core/fromnumeric.py:1222`, ties at
518 /// `:1261-1262`; live oracle `np.argmax([1.0, nan, 3.0, nan]) == 1`). The
519 /// index type is `i64` (ferray's `intp` analog), independent of `T`.
520 pub fn argmax(&self) -> Option<i64> {
521 arg_reduce(self.iter().copied(), false)
522 }
523
524 /// Flat index of the minimum element (whole-array reduction).
525 ///
526 /// Mirror of [`Array::argmax`] with min substituted for max: first
527 /// occurrence on ties, NaN-first, `None` on empty, `i64` index
528 /// (`numpy/_core/fromnumeric.py:1322`, ties at `:1361-1362`; live oracle
529 /// `np.argmin([1.0, nan, 3.0]) == 1`).
530 pub fn argmin(&self) -> Option<i64> {
531 arg_reduce(self.iter().copied(), true)
532 }
533
534 /// Indices of the maxima along `axis`, as an `Array<i64, IxDyn>` with the
535 /// reduced axis removed. First-occurrence on ties, NaN-first per lane
536 /// (`numpy/_core/fromnumeric.py:1222`).
537 ///
538 /// # Errors
539 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
540 /// `FerrayError::ShapeMismatch` if the reduced axis is empty (matching
541 /// numpy's `ValueError` on an empty argmax axis).
542 pub fn argmax_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
543 where
544 D::NdarrayDim: ndarray::RemoveAxis,
545 {
546 self.arg_axis(axis, false)
547 }
548
549 /// Indices of the minima along `axis`. See [`Array::argmax_axis`].
550 pub fn argmin_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
551 where
552 D::NdarrayDim: ndarray::RemoveAxis,
553 {
554 self.arg_axis(axis, true)
555 }
556
557 /// Internal: per-lane arg-reduction along `axis`. Each lane is a 1D view
558 /// orthogonal to `axis`; the reduced index is the position within the lane.
559 fn arg_axis(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<i64, IxDyn>>
560 where
561 D::NdarrayDim: ndarray::RemoveAxis,
562 {
563 let ndim = self.ndim();
564 if axis.index() >= ndim {
565 return Err(crate::error::FerrayError::axis_out_of_bounds(
566 axis.index(),
567 ndim,
568 ));
569 }
570 if self.shape()[axis.index()] == 0 {
571 return Err(crate::error::FerrayError::shape_mismatch(
572 "attempt to get argmax/argmin of an empty axis",
573 ));
574 }
575 let nd_axis = ndarray::Axis(axis.index());
576 let lanes = self.inner.lanes(nd_axis);
577 let mut out: Vec<i64> = Vec::with_capacity(lanes.into_iter().len());
578 for lane in self.inner.lanes(nd_axis) {
579 // Lane is non-empty (empty axis already rejected), so arg_reduce
580 // returns Some; default 0 is unreachable but keeps the code panic-free.
581 let idx = arg_reduce(lane.iter().copied(), take_min).unwrap_or(0);
582 out.push(idx);
583 }
584 let mut out_shape: Vec<usize> = self.shape().to_vec();
585 out_shape.remove(axis.index());
586 Array::from_vec(IxDyn::from(&out_shape[..]), out)
587 }
588
589 /// Internal: per-lane min/max via manual lane iteration. Avoids the
590 /// init-bias problem of `fold_axis` (which applies a single init to every
591 /// lane, even though min/max have no identity element).
592 fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
593 where
594 D::NdarrayDim: ndarray::RemoveAxis,
595 {
596 let nd_axis = ndarray::Axis(axis.index());
597 // Use ndarray's lane iteration directly via the inner ndarray::ArrayBase.
598 // Each lane is a 1D view orthogonal to the chosen axis.
599 let lanes = self.inner.lanes(nd_axis);
600 let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
601 for lane in self.inner.lanes(nd_axis) {
602 let mut iter = lane.iter().copied();
603 let first = iter.next().unwrap(); // safe: empty axis already rejected
604 let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
605 out.push(result);
606 }
607
608 // Output shape: drop the reduced axis from the input shape.
609 let mut out_shape: Vec<usize> = self.shape().to_vec();
610 out_shape.remove(axis.index());
611 Array::from_vec(IxDyn::from(&out_shape[..]), out)
612 }
613}
614
615// ---------------------------------------------------------------------------
// Mean / Var / Std — mean requires `MeanAcc`; var/std additionally require Float
617// ---------------------------------------------------------------------------
618
619impl<T, D> Array<T, D>
620where
621 T: MeanAcc,
622 D: Dimension,
623{
624 /// Arithmetic mean of all elements. Returns `None` for an empty array.
625 ///
626 /// The result type is the NumPy mean accumulator [`MeanAcc::Mean`]:
627 /// bool / integer inputs average in (and return) `f64`, while `f32`/`f64`/
628 /// complex keep their own dtype. This matches numpy, which casts bool /
629 /// unsigned / signed int to `float64` before averaging
630 /// (`numpy/_core/_methods.py:124-127`), so `Array::<i32, _>::mean()` is
631 /// `Some(f64)` and `Array::<bool, _>::mean()` is `Some(f64)` (e.g.
632 /// `0.666…`), matching `np.mean`.
633 pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
634 let n = self.size();
635 if n == 0 {
636 return None;
637 }
638 let sum = self
639 .iter()
640 .copied()
641 .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
642 Some(sum / <T as MeanAcc>::count(n))
643 }
644
645 /// Mean along an axis. Element type is the promoted [`MeanAcc::Mean`]
646 /// (bool / integer lanes average in `f64`; `f32`/`f64` stay themselves).
647 pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<<T as MeanAcc>::Mean, IxDyn>>
648 where
649 <T as MeanAcc>::Mean: ReduceAcc<Acc = <T as MeanAcc>::Mean>,
650 D::NdarrayDim: ndarray::RemoveAxis,
651 {
652 let ndim = self.ndim();
653 if axis.index() >= ndim {
654 return Err(crate::error::FerrayError::axis_out_of_bounds(
655 axis.index(),
656 ndim,
657 ));
658 }
659 let n = self.shape()[axis.index()];
660 if n == 0 {
661 return Err(crate::error::FerrayError::shape_mismatch(
662 "cannot compute mean along empty axis",
663 ));
664 }
665 // Widen each element into the mean accumulator, then sum along the axis
666 // and divide by the lane length.
667 let widened = self.map_to::<<T as MeanAcc>::Mean>(MeanAcc::widen_mean);
668 let sums = widened.sum_axis(axis)?;
669 let n_t = <T as MeanAcc>::count(n);
670 Ok(sums.mapv(|x| x / n_t))
671 }
672}
673
674impl<T, D> Array<T, D>
675where
676 T: Element + Float + MeanAcc<Mean = T>,
677 D: Dimension,
678{
679 /// Variance with `ddof` degrees of freedom (Bessel's correction = 1).
680 ///
681 /// Returns `None` for an empty array, or when `ddof >= n`.
682 pub fn var(&self, ddof: usize) -> Option<T> {
683 let n = self.size();
684 if n == 0 || ddof >= n {
685 return None;
686 }
687 let mean = self.mean()?;
688 let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
689 acc + (x - mean) * (x - mean)
690 });
691 Some(sum_sq / T::from(n - ddof).unwrap())
692 }
693
694 /// Standard deviation with `ddof` degrees of freedom.
695 pub fn std(&self, ddof: usize) -> Option<T> {
696 self.var(ddof).map(num_traits::Float::sqrt)
697 }
698}
699
700// ---------------------------------------------------------------------------
701// any / all — for bool arrays
702// ---------------------------------------------------------------------------
703
704impl<D> Array<bool, D>
705where
706 D: Dimension,
707{
708 /// Returns `true` if any element is `true`.
709 pub fn any(&self) -> bool {
710 self.iter().any(|&x| x)
711 }
712
713 /// Returns `true` if all elements are `true`. Vacuously `true` for empty arrays.
714 pub fn all(&self) -> bool {
715 self.iter().all(|&x| x)
716 }
717}
718
719// ---------------------------------------------------------------------------
720// ArrayView mirrors — same methods on borrowed views
721// ---------------------------------------------------------------------------
722
723impl<T, D> ArrayView<'_, T, D>
724where
725 T: Element + Copy,
726 D: Dimension,
727{
728 /// Sum of all elements. See [`Array::sum`] — returns the promoted
729 /// [`ReduceAcc::Acc`].
730 pub fn sum(&self) -> <T as ReduceAcc>::Acc
731 where
732 T: ReduceAcc,
733 {
734 let mut acc = <T as ReduceAcc>::Acc::zero();
735 for &x in self.iter() {
736 acc = acc + x.widen();
737 }
738 acc
739 }
740
741 /// Product of all elements. See [`Array::prod`] — returns the promoted
742 /// [`ReduceAcc::Acc`].
743 pub fn prod(&self) -> <T as ReduceAcc>::Acc
744 where
745 T: ReduceAcc,
746 {
747 let mut acc = <T as ReduceAcc>::Acc::one();
748 for &x in self.iter() {
749 acc = acc * x.widen();
750 }
751 acc
752 }
753}
754
755impl<T, D> ArrayView<'_, T, D>
756where
757 T: Element + Copy + PartialOrd,
758 D: Dimension,
759{
760 /// Minimum value. See [`Array::min`].
761 pub fn min(&self) -> Option<T> {
762 let mut iter = self.iter().copied();
763 let first = iter.next()?;
764 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
765 }
766
767 /// Maximum value. See [`Array::max`].
768 pub fn max(&self) -> Option<T> {
769 let mut iter = self.iter().copied();
770 let first = iter.next()?;
771 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
772 }
773
774 /// Flat index of the maximum element. See [`Array::argmax`].
775 pub fn argmax(&self) -> Option<i64> {
776 arg_reduce(self.iter().copied(), false)
777 }
778
779 /// Flat index of the minimum element. See [`Array::argmin`].
780 pub fn argmin(&self) -> Option<i64> {
781 arg_reduce(self.iter().copied(), true)
782 }
783}
784
785impl<T, D> ArrayView<'_, T, D>
786where
787 T: MeanAcc,
788 D: Dimension,
789{
790 /// Mean. See [`Array::mean`] — returns the promoted [`MeanAcc::Mean`].
791 pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
792 let n = self.size();
793 if n == 0 {
794 return None;
795 }
796 let sum = self
797 .iter()
798 .copied()
799 .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
800 Some(sum / <T as MeanAcc>::count(n))
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1D `f64` array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a `rows x cols` `f64` array from `data`.
    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // ----- sum / prod -----

    #[test]
    fn sum_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.sum(), 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Sum across rows (axis 0): [1+4, 2+5, 3+6] = [5, 7, 9]
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);

        // Sum across columns (axis 1): [1+2+3, 4+5+6] = [6, 15]
        let s1 = a.sum_axis(Axis(1)).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.prod(), 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p0 = a.prod_axis(Axis(0)).unwrap();
        assert_eq!(
            p0.iter().copied().collect::<Vec<_>>(),
            vec![4.0, 10.0, 18.0]
        );

        let p1 = a.prod_axis(Axis(1)).unwrap();
        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
    }

    #[test]
    fn narrow_int_sum_prod_promote_to_i64() {
        // 300 and 1_000_000 both overflow i8; the i64 accumulator holds them.
        let a = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
        assert_eq!(a.sum(), 300_i64);
        assert_eq!(a.prod(), 1_000_000_i64);
    }

    #[test]
    fn bool_sum_counts_true() {
        let a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        assert_eq!(a.sum(), 2_i64);
    }

    // ----- min / max -----

    #[test]
    fn min_max_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.min(), Some(1.0));
        assert_eq!(a.max(), Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.min(), None);
        assert_eq!(a.max(), None);
    }

    #[test]
    fn min_max_int() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
        assert_eq!(a.min(), Some(-5));
        assert_eq!(a.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        // axis 0: min/max per column
        let mn0 = a.min_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
        let mx0 = a.max_axis(Axis(0)).unwrap();
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);

        // axis 1: min/max per row
        let mn1 = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
        let mx1 = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
    }

    #[test]
    fn min_max_ties_keep_last_signed_zero() {
        // +0.0 == -0.0, so the later operand's sign bit must survive.
        assert!(arr1(vec![0.0, -0.0]).max().unwrap().is_sign_negative());
        assert!(arr1(vec![-0.0, 0.0]).max().unwrap().is_sign_positive());
        assert!(arr1(vec![0.0, -0.0]).min().unwrap().is_sign_negative());
        assert!(arr1(vec![-0.0, 0.0]).min().unwrap().is_sign_positive());
    }

    #[test]
    fn min_max_axis_rejects_bad_axis() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        assert!(a.min_axis(Axis(1)).is_err());
        assert!(a.max_axis(Axis(1)).is_err());

        let empty = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert!(empty.min_axis(Axis(0)).is_err());
        assert!(empty.max_axis(Axis(0)).is_err());
    }

    // ----- argmax / argmin -----

    #[test]
    fn argmax_argmin_1d_first_occurrence() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.argmax(), Some(5));
        // 1.0 appears at indices 1 and 3; the first one wins.
        assert_eq!(a.argmin(), Some(1));
    }

    #[test]
    fn argmax_argmin_nan_first() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, f64::NAN]);
        assert_eq!(a.argmax(), Some(1));
        assert_eq!(a.argmin(), Some(1));
    }

    #[test]
    fn argmax_argmin_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.argmax(), None);
        assert_eq!(a.argmin(), None);
    }

    #[test]
    fn argmax_argmin_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        // axis 0: per column — (1,4), (5,2), (3,6)
        let mx0 = a.argmax_axis(Axis(0)).unwrap();
        assert_eq!(mx0.shape(), &[3]);
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![1, 0, 1]);
        let mn0 = a.argmin_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![0, 1, 0]);

        // axis 1: per row — [1,5,3], [4,2,6]
        let mx1 = a.argmax_axis(Axis(1)).unwrap();
        assert_eq!(mx1.shape(), &[2]);
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![1, 2]);
        let mn1 = a.argmin_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![0, 1]);
    }

    #[test]
    fn arg_axis_rejects_bad_axis() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        assert!(a.argmax_axis(Axis(3)).is_err());
        assert!(a.argmin_axis(Axis(3)).is_err());

        let empty = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert!(empty.argmax_axis(Axis(0)).is_err());
        assert!(empty.argmin_axis(Axis(0)).is_err());
    }

    // ----- mean / var / std -----

    #[test]
    fn mean_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.mean(), Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.mean(), None);
    }

    #[test]
    fn mean_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
        let m1 = a.mean_axis(Axis(1)).unwrap();
        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
    }

    #[test]
    fn mean_axis_rejects_bad_axis() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        assert!(a.mean_axis(Axis(1)).is_err());

        let empty = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert!(empty.mean_axis(Axis(0)).is_err());
    }

    #[test]
    fn integer_and_bool_mean_are_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert_eq!(a.mean(), Some(2.0_f64));

        let b = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false])
            .unwrap();
        assert_eq!(b.mean(), Some(0.5_f64));
    }

    #[test]
    fn var_population() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // population variance (ddof=0): ((1-3)^2+(2-3)^2+(3-3)^2+(4-3)^2+(5-3)^2)/5 = 10/5 = 2
        assert_eq!(a.var(0), Some(2.0));
    }

    #[test]
    fn var_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // sample variance (ddof=1): 10/4 = 2.5
        assert_eq!(a.var(1), Some(2.5));
    }

    #[test]
    fn std_basic() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(0).unwrap();
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let a = arr1(vec![1.0, 2.0]);
        assert_eq!(a.var(2), None);
        assert_eq!(a.var(5), None);
    }

    // ----- any / all -----

    #[test]
    fn any_all_bool() {
        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let false_arr =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();

        assert!(true_arr.all());
        assert!(true_arr.any());

        assert!(!mixed.all());
        assert!(mixed.any());

        assert!(!false_arr.all());
        assert!(!false_arr.any());

        // Vacuous truth for empty
        assert!(empty.all());
        assert!(!empty.any());
    }

    // ----- ArrayView mirrors -----

    #[test]
    fn view_sum_min_max_mean() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let v = a.view();
        assert_eq!(v.sum(), 10.0);
        assert_eq!(v.min(), Some(1.0));
        assert_eq!(v.max(), Some(4.0));
        assert_eq!(v.mean(), Some(2.5));
    }

    #[test]
    fn view_prod_argmax_argmin() {
        let a = arr1(vec![2.0, 5.0, 1.0, 5.0]);
        let v = a.view();
        assert_eq!(v.prod(), 50.0);
        // 5.0 appears at indices 1 and 3; the first one wins.
        assert_eq!(v.argmax(), Some(1));
        assert_eq!(v.argmin(), Some(2));
    }

    #[test]
    fn nan_propagates_in_min_max() {
        // NaN somewhere in the middle
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(a.min().unwrap().is_nan());
        assert!(a.max().unwrap().is_nan());

        // NaN at the start
        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
        assert!(b.min().unwrap().is_nan());
        assert!(b.max().unwrap().is_nan());

        // NaN at the end
        let c = arr1(vec![1.0, 3.0, f64::NAN]);
        assert!(c.min().unwrap().is_nan());
        assert!(c.max().unwrap().is_nan());
    }
}