ferray_core/array/reductions.rs
1// ferray-core: Reduction methods for Array<T, D>
2//
3// Provides NumPy-equivalent reduction methods directly on Array and ArrayView:
4// sum, prod, min, max, mean, var, std, any, all
5//
6// Each reduction has a whole-array variant and an axis variant:
7// .sum() -> T
8// .sum_axis(Axis(0)) -> Array<T, IxDyn>
9//
10// These methods complement the lower-level fold_axis primitive in methods.rs
11// and the free-function reductions in ferray-stats. The instance-method form
12// matches NumPy's `arr.sum()` ergonomics so users don't need to import a
13// separate crate just to compute a sum.
14//
15// ## REQ status (reductions, NumPy parity)
16// - sum / prod / sum_axis / prod_axis accumulator promotion — SHIPPED (#780):
17// the `ReduceAcc` trait (this file) maps narrow signed ints → i64, narrow
18// unsigned ints → u64, and bool → i64 before reducing, so narrow-int
19// reductions never overflow and the result dtype matches numpy
20// (numpy/_core/fromnumeric.py:2321-2327). Consumers: `Array::sum`/`prod`/
21// `sum_axis`/`prod_axis` and the `ArrayView` mirrors (all in this file).
22// - min / max / mean / var / std / any / all — SHIPPED (#368), NaN-propagating.
23// - empty-array min/max raising ValueError — SHIPPED (#782, resolved R-DEV-4):
24// `Array::min`/`max` return `Option<T>` = `None` on an empty array (the
25// idiomatic Rust analog of numpy's no-identity error — same boundary contract
26// as `argmax`/`argmin`, REQ-40/41; test `min_max_empty_returns_none` in this
27// file). The ferray-python boundary maps `None` -> `ValueError`, matching
28// numpy's `ValueError` ("zero-size array to reduction operation minimum which
29// has no identity", `fromnumeric.py:3150` `min`/`:3266` `amin`). Verified:
30// `fr.min(fr.array([], dtype=fr.float64))` raises `ValueError` vs numpy 2.4.5.
31// - cumsum / cumprod live as ferray-stats free functions; their narrow-int
32// promotion is a separate ferray-stats blocker that reuses `ReduceAcc`.
33// - argmax / argmin (REQ-40, REQ-41) — SHIPPED: `Array::argmax`/`argmin`
34// (flattened, returning `Option<i64>`) and `argmax_axis`/`argmin_axis`
35// (returning `Array<i64, IxDyn>`) in this file. First-occurrence on ties and
36// NaN-first propagation (numpy/_core/fromnumeric.py:1222 `argmax`,
37// :1261-1262 ties; :1322 `argmin`, :1361-1362 ties). Empty flattened form
38// returns `None` (mirroring `min`/`max`'s `Option` analog; the ferray-python
39// boundary maps `None`→`ValueError` as numpy does). Result index dtype is
40// `i64` (ferray's `intp` analog), independent of element dtype. Consumers:
41// the boundary methods themselves are the public API surface (like
42// `sum`/`min`), exercised by the `ArrayView` mirror and the in-file tests.
43// - integer/bool mean → f64 (REQ-42) — SHIPPED: the `MeanAcc` trait (this
44// file) maps bool/integer element types to an `f64` accumulator-and-result,
45// while `f32`/`f64`/complex stay themselves, so `Array::<i32, _>::mean()`
46// returns `f64` matching numpy, which casts bool/unsigned/signed int to
47// float64 before averaging (numpy/_core/_methods.py:124-127). `mean`/
48// `mean_axis` and the `ArrayView::mean` mirror are now bounded by `MeanAcc`
49// instead of `Float`; existing `f32`/`f64` means are unchanged
50// (`MeanAcc::Mean == Self`). Consumers: `Array::var`/`std` (`self.mean()?`)
51// and the `ArrayView` mirror, all in this file.
52//
53// See: https://github.com/dollspace-gay/ferray/issues/368, /issues/780
54
55use num_traits::Float;
56
57use crate::array::owned::Array;
58use crate::array::view::ArrayView;
59use crate::dimension::{Axis, Dimension, IxDyn};
60use crate::dtype::Element;
61use crate::error::FerrayResult;
62
63// ---------------------------------------------------------------------------
64// ReduceAcc — NumPy's sum/prod/cumsum/cumprod accumulator-and-result dtype.
65// ---------------------------------------------------------------------------
66
67/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
68/// return) `sum` / `prod` / `cumsum` / `cumprod` over it.
69///
70/// NumPy promotes any integer dtype of *less precision than the default
71/// platform integer* before reducing, so a narrow-int reduction can never
72/// overflow and the result dtype is the platform integer:
73///
74/// > "The dtype of `a` is used by default unless `a` has an integer dtype of
75/// > less precision than the default platform integer. In that case, if `a`
76/// > is signed then the platform integer is used while if `a` is unsigned then
77/// > an unsigned integer of the same precision as the platform integer is
78/// > used."
79/// > — `numpy/_core/fromnumeric.py:2321-2327` (sum), `:3306-3312` (prod)
80///
81/// The reduction itself is `umr_sum = um.add.reduce` /
82/// `umr_prod = um.multiply.reduce` (`numpy/_core/_methods.py:20-21`), i.e. the
83/// add/multiply ufunc whose *loop dtype* is the promoted accumulator.
84///
85/// The mapping (platform integer = 64-bit, matching ferray's `intp`/`int64`):
86/// - `i8 / i16 / i32 → i64`, `i64 → i64`, `i128 → i128`
87/// - `u8 / u16 / u32 → u64`, `u64 → u64`, `u128 → u128`
88/// - `bool → i64` (NumPy reduces bool as the platform integer, counting `true`)
89/// - `f32 → f32`, `f64 → f64`, complex stays itself (no promotion)
90///
91/// Wider-or-equal dtypes map to themselves, so existing `f64`/`i64`/complex
92/// reductions are unchanged — only narrow-int callers observe the promoted
93/// return type.
94pub trait ReduceAcc: Element + Copy {
95 /// The accumulator-and-result element type for reductions over `Self`.
96 type Acc: Element + Copy + std::ops::Add<Output = Self::Acc> + std::ops::Mul<Output = Self::Acc>;
97
98 /// Widen one element into the accumulator type before reducing, matching
99 /// NumPy's promotion of the loop dtype (`true → 1` for `bool`).
100 fn widen(self) -> Self::Acc;
101}
102
// Generates `ReduceAcc` impls from a comma-separated list of
// `element => accumulator` pairs (a trailing comma is accepted).
//
// Widening goes through `From` rather than an `as` cast, so every pair is
// checked to be lossless at compile time: a pair that would narrow or
// reinterpret the value (e.g. `i64 => i32` or `u64 => i64`) fails to compile
// instead of silently truncating or wrapping. Identity pairs such as
// `f64 => f64` resolve to the reflexive `From<T> for T` impl.
macro_rules! impl_reduce_acc {
    ($($t:ty => $acc:ty),* $(,)?) => {
        $(
            impl ReduceAcc for $t {
                type Acc = $acc;

                #[inline]
                fn widen(self) -> $acc {
                    <$acc as From<$t>>::from(self)
                }
            }
        )*
    };
}
116
117// Narrow signed ints promote to i64; i64/i128 stay themselves.
118impl_reduce_acc! {
119 i8 => i64, i16 => i64, i32 => i64, i64 => i64, i128 => i128,
120 u8 => u64, u16 => u64, u32 => u64, u64 => u64, u128 => u128,
121 f32 => f32, f64 => f64,
122}
123
124// bool reduces as the platform integer, counting `true` (numpy:
125// `np.sum(np.array([True, True, True])).dtype == int64`). `as i64` maps
126// false→0, true→1.
127impl ReduceAcc for bool {
128 type Acc = i64;
129 #[inline]
130 fn widen(self) -> i64 {
131 i64::from(self)
132 }
133}
134
135// Complex stays itself — numpy never promotes a complex reduction.
136impl ReduceAcc for num_complex::Complex<f32> {
137 type Acc = num_complex::Complex<f32>;
138 #[inline]
139 fn widen(self) -> Self {
140 self
141 }
142}
143
144impl ReduceAcc for num_complex::Complex<f64> {
145 type Acc = num_complex::Complex<f64>;
146 #[inline]
147 fn widen(self) -> Self {
148 self
149 }
150}
151
152// ---------------------------------------------------------------------------
153// MeanAcc — NumPy's mean accumulator-and-result dtype.
154// ---------------------------------------------------------------------------
155
156/// Maps an element type `T` to the type NumPy uses to *accumulate* (and
157/// return) `mean` over it.
158///
159/// NumPy casts a bool / unsigned-int / signed-int input to `float64` before
160/// averaging:
161///
162/// > "Cast bool, unsigned int, and int to float64 by default ...
163/// > `dtype = mu.dtype('f8')`"
164/// > — `numpy/_core/_methods.py:124-127`
165///
166/// so `np.mean(np.array([1, 2, 3], np.int32))` is `float64 2.0` and
167/// `np.mean([True, False, True])` is `float64 0.6666…`. Floating-point inputs
168/// keep their own dtype (`f32`→`f32`, `f64`→`f64`), and complex stays itself.
169///
170/// The mapping:
171/// - `bool / i8.. / u8.. → f64`
172/// - `f32 → f32`, `f64 → f64` (unchanged — `Mean == Self`)
173/// - `Complex<f32> → Complex<f32>`, `Complex<f64> → Complex<f64>`
174pub trait MeanAcc: Element + Copy {
175 /// The accumulator-and-result element type for `mean` over `Self`.
176 type Mean: Element
177 + Copy
178 + std::ops::Add<Output = Self::Mean>
179 + std::ops::Div<Output = Self::Mean>;
180
181 /// Widen one element into the mean accumulator type, matching NumPy's
182 /// pre-average cast (`true → 1.0`, `false → 0.0` for `bool`).
183 fn widen_mean(self) -> Self::Mean;
184
185 /// Construct the divisor (element count `n`) in the accumulator type.
186 fn count(n: usize) -> Self::Mean;
187}
188
// Generates `MeanAcc` impls whose accumulator is `f64` for each listed
// primitive type (a trailing comma is accepted). Both the element and the
// element count are converted with `as`, which rounds to the nearest `f64`
// when the source value is not exactly representable.
macro_rules! impl_mean_acc_to_f64 {
    ($($int:ty),* $(,)?) => {
        $(
            impl MeanAcc for $int {
                type Mean = f64;

                #[inline]
                fn widen_mean(self) -> Self::Mean {
                    self as f64
                }

                #[inline]
                fn count(n: usize) -> Self::Mean {
                    n as f64
                }
            }
        )*
    };
}
206
207// bool / all integer dtypes average in f64 (numpy/_core/_methods.py:124-127).
208impl_mean_acc_to_f64!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
209
210impl MeanAcc for bool {
211 type Mean = f64;
212 #[inline]
213 fn widen_mean(self) -> f64 {
214 if self { 1.0 } else { 0.0 }
215 }
216 #[inline]
217 fn count(n: usize) -> f64 {
218 n as f64
219 }
220}
221
222// Floating-point inputs keep their own dtype (Mean == Self).
223impl MeanAcc for f32 {
224 type Mean = f32;
225 #[inline]
226 fn widen_mean(self) -> f32 {
227 self
228 }
229 #[inline]
230 fn count(n: usize) -> f32 {
231 n as f32
232 }
233}
234
235impl MeanAcc for f64 {
236 type Mean = f64;
237 #[inline]
238 fn widen_mean(self) -> f64 {
239 self
240 }
241 #[inline]
242 fn count(n: usize) -> f64 {
243 n as f64
244 }
245}
246
247impl MeanAcc for num_complex::Complex<f32> {
248 type Mean = num_complex::Complex<f32>;
249 #[inline]
250 fn widen_mean(self) -> Self {
251 self
252 }
253 #[inline]
254 fn count(n: usize) -> Self {
255 num_complex::Complex::new(n as f32, 0.0)
256 }
257}
258
259impl MeanAcc for num_complex::Complex<f64> {
260 type Mean = num_complex::Complex<f64>;
261 #[inline]
262 fn widen_mean(self) -> Self {
263 self
264 }
265 #[inline]
266 fn count(n: usize) -> Self {
267 num_complex::Complex::new(n as f64, 0.0)
268 }
269}
270
/// Position of the extreme element of a flat iterator, with first-occurrence
/// tie-breaking and NaN-first propagation.
///
/// Follows NumPy's `argmax`/`argmin` (`numpy/_core/fromnumeric.py:1222`,
/// `:1322`): when several elements tie, the *earliest* one is reported
/// (`:1261-1262`, `:1361-1362`); when a NaN occurs, the index of the *first*
/// NaN is reported (live oracle numpy 2.4.5:
/// `np.argmax([1.0, nan, 3.0, nan]) == 1`).
///
/// `take_min` selects `argmin` (`true`) or `argmax` (`false`). A value counts
/// as NaN when it is not comparable to itself (`x.partial_cmp(&x)` is `None`).
///
/// Returns `None` when the iterator yields nothing — the same `Option` analog
/// `min`/`max` use; the ferray-python boundary maps `None`→`ValueError` as
/// numpy does.
#[inline]
fn arg_reduce<T: PartialOrd + Copy>(iter: impl Iterator<Item = T>, take_min: bool) -> Option<i64> {
    use std::cmp::Ordering;

    // The ordering a candidate must have relative to the current champion in
    // order to displace it. Requiring a strict ordering (never `Equal`) is
    // what makes the first occurrence win on ties.
    let wanted = if take_min {
        Ordering::Less
    } else {
        Ordering::Greater
    };

    let mut champion: Option<(i64, T)> = None;
    for (pos, value) in iter.enumerate() {
        let pos = pos as i64;
        // NaN-first: the first self-incomparable value ends the scan.
        if value.partial_cmp(&value).is_none() {
            return Some(pos);
        }
        let displaces = match champion {
            None => true,
            Some((_, held)) => value.partial_cmp(&held) == Some(wanted),
        };
        if displaces {
            champion = Some((pos, value));
        }
    }
    champion.map(|(pos, _)| pos)
}
312
/// One step of a NaN-propagating min/max fold.
///
/// `acc` is the running result, `x` the next element, and `take_min` selects
/// minimum (`true`) or maximum (`false`).
///
/// NaN handling: a value is treated as NaN when it is not comparable to
/// itself (`v.partial_cmp(&v)` is `None`). If the accumulator is already NaN
/// it is returned unchanged; otherwise a NaN `x` becomes the result. Either
/// way, once a NaN enters the fold every later step keeps returning NaN.
///
/// Tie handling: when `x` and `acc` compare `Equal`, the NEW element `x` is
/// returned rather than the accumulator. Per the note this replaces, that
/// mirrors numpy's `maximum.reduce`/`minimum.reduce`
/// (`numpy/_core/_methods.py:38-44`), which keep the later operand on ties.
/// It is observable only for signed zeros (`+0.0 == -0.0`), where the sign of
/// the LAST zero seen survives. `argmin`/`argmax` use first-occurrence on
/// ties and do not go through this function.
#[inline]
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    use std::cmp::Ordering;

    let is_unordered = |v: &T| v.partial_cmp(v).is_none();
    if is_unordered(&acc) {
        return acc;
    }
    if is_unordered(&x) {
        return x;
    }

    match x.partial_cmp(&acc) {
        // Tie: the later operand wins.
        Some(Ordering::Equal) => x,
        Some(Ordering::Less) if take_min => x,
        Some(Ordering::Greater) if !take_min => x,
        _ => acc,
    }
}
347
348// ---------------------------------------------------------------------------
// Sum / Prod — require `ReduceAcc`; accumulate in the promoted `Acc` type
350// ---------------------------------------------------------------------------
351
352impl<T, D> Array<T, D>
353where
354 T: Element + Copy,
355 D: Dimension,
356{
357 /// Sum of all elements (whole-array reduction).
358 ///
359 /// The result type is the NumPy reduction accumulator
360 /// [`ReduceAcc::Acc`]: narrow signed ints widen to `i64`, narrow unsigned
361 /// ints to `u64`, `bool` to `i64`, and `f32`/`f64`/complex stay
362 /// themselves. This means a narrow-int sum can never overflow and its
363 /// dtype matches `np.sum`'s promoted result
364 /// (`numpy/_core/fromnumeric.py:2321-2327`).
365 ///
366 /// Returns `Acc::zero()` for an empty array.
367 ///
368 /// # Examples
369 /// ```
370 /// # use ferray_core::Array;
371 /// # use ferray_core::dimension::Ix1;
372 /// let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
373 /// assert_eq!(a.sum(), 6.0);
374 /// // i8 sums promote to i64 and never overflow (numpy parity):
375 /// let b = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
376 /// assert_eq!(b.sum(), 300_i64);
377 /// ```
378 pub fn sum(&self) -> <T as ReduceAcc>::Acc
379 where
380 T: ReduceAcc,
381 {
382 let mut acc = <T as ReduceAcc>::Acc::zero();
383 for &x in self.iter() {
384 acc = acc + x.widen();
385 }
386 acc
387 }
388
389 /// Sum along the given axis. Returns an array with one fewer dimension,
390 /// whose element type is the promoted [`ReduceAcc::Acc`] (same numpy
391 /// narrow-int promotion as the whole-array [`Array::sum`]).
392 ///
393 /// # Errors
394 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
395 pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
396 where
397 T: ReduceAcc,
398 D::NdarrayDim: ndarray::RemoveAxis,
399 {
400 let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
401 widened.fold_axis(axis, <T as ReduceAcc>::Acc::zero(), |acc, &x| *acc + x)
402 }
403
404 /// Product of all elements.
405 ///
406 /// The result type is the promoted [`ReduceAcc::Acc`] (same numpy
407 /// narrow-int promotion as [`Array::sum`]; see
408 /// `numpy/_core/fromnumeric.py:3306-3312`).
409 ///
410 /// Returns `Acc::one()` for an empty array.
411 pub fn prod(&self) -> <T as ReduceAcc>::Acc
412 where
413 T: ReduceAcc,
414 {
415 let mut acc = <T as ReduceAcc>::Acc::one();
416 for &x in self.iter() {
417 acc = acc * x.widen();
418 }
419 acc
420 }
421
422 /// Product along the given axis. Element type is the promoted
423 /// [`ReduceAcc::Acc`].
424 pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
425 where
426 T: ReduceAcc,
427 D::NdarrayDim: ndarray::RemoveAxis,
428 {
429 let widened = self.map_to::<<T as ReduceAcc>::Acc>(ReduceAcc::widen);
430 widened.fold_axis(axis, <T as ReduceAcc>::Acc::one(), |acc, &x| *acc * x)
431 }
432}
433
434// ---------------------------------------------------------------------------
435// Min / Max — require PartialOrd
436// ---------------------------------------------------------------------------
437
438impl<T, D> Array<T, D>
439where
440 T: Element + Copy + PartialOrd,
441 D: Dimension,
442{
443 /// Minimum value across the entire array.
444 ///
445 /// Returns `None` if the array is empty. NaN values follow `NumPy` semantics:
446 /// once a NaN is seen the result stays NaN, detected via self-comparison
447 /// (`x.partial_cmp(&x).is_none()`).
448 pub fn min(&self) -> Option<T> {
449 let mut iter = self.iter().copied();
450 let first = iter.next()?;
451 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
452 }
453
454 /// Maximum value across the entire array.
455 ///
456 /// Returns `None` if the array is empty. NaN values propagate per `NumPy`.
457 pub fn max(&self) -> Option<T> {
458 let mut iter = self.iter().copied();
459 let first = iter.next()?;
460 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
461 }
462
463 /// Minimum value along an axis.
464 ///
465 /// # Errors
466 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
467 /// `FerrayError::ShapeMismatch` if the resulting axis would be empty.
468 pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
469 where
470 D::NdarrayDim: ndarray::RemoveAxis,
471 {
472 // Use the first element along the axis as init by sentinel: pull the
473 // first lane and fold the rest. fold_axis applies init to every lane,
474 // but min has no neutral identity for arbitrary T. We sidestep by
475 // folding starting from any element of `self` — the per-lane init is
476 // overwritten by the first comparison, which is correct iff every lane
477 // has at least one element. Empty axes would yield uninitialized data.
478 let ndim = self.ndim();
479 if axis.index() >= ndim {
480 return Err(crate::error::FerrayError::axis_out_of_bounds(
481 axis.index(),
482 ndim,
483 ));
484 }
485 if self.shape()[axis.index()] == 0 {
486 return Err(crate::error::FerrayError::shape_mismatch(
487 "cannot compute min along empty axis",
488 ));
489 }
490 // Manual lane iteration: fold_axis can't be used here because min has
491 // no neutral identity that works for arbitrary `T: PartialOrd` (no
492 // T::infinity for ints).
493 self.fold_axis_min_max(axis, true)
494 }
495
496 /// Maximum value along an axis.
497 ///
498 /// See [`Array::min_axis`] for error semantics.
499 pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
500 where
501 D::NdarrayDim: ndarray::RemoveAxis,
502 {
503 let ndim = self.ndim();
504 if axis.index() >= ndim {
505 return Err(crate::error::FerrayError::axis_out_of_bounds(
506 axis.index(),
507 ndim,
508 ));
509 }
510 if self.shape()[axis.index()] == 0 {
511 return Err(crate::error::FerrayError::shape_mismatch(
512 "cannot compute max along empty axis",
513 ));
514 }
515 self.fold_axis_min_max(axis, false)
516 }
517
518 /// Flat index of the maximum element (whole-array reduction).
519 ///
520 /// Returns `None` for an empty array (the `Option` analog `min`/`max`
521 /// use; the ferray-python boundary maps `None`→`ValueError`, matching
522 /// `np.argmax([])`). On ties the **first** occurrence wins, and when any
523 /// NaN is present the index of the **first** NaN is returned (NaN-first),
524 /// matching `np.argmax` (`numpy/_core/fromnumeric.py:1222`, ties at
525 /// `:1261-1262`; live oracle `np.argmax([1.0, nan, 3.0, nan]) == 1`). The
526 /// index type is `i64` (ferray's `intp` analog), independent of `T`.
527 pub fn argmax(&self) -> Option<i64> {
528 arg_reduce(self.iter().copied(), false)
529 }
530
531 /// Flat index of the minimum element (whole-array reduction).
532 ///
533 /// Mirror of [`Array::argmax`] with min substituted for max: first
534 /// occurrence on ties, NaN-first, `None` on empty, `i64` index
535 /// (`numpy/_core/fromnumeric.py:1322`, ties at `:1361-1362`; live oracle
536 /// `np.argmin([1.0, nan, 3.0]) == 1`).
537 pub fn argmin(&self) -> Option<i64> {
538 arg_reduce(self.iter().copied(), true)
539 }
540
541 /// Indices of the maxima along `axis`, as an `Array<i64, IxDyn>` with the
542 /// reduced axis removed. First-occurrence on ties, NaN-first per lane
543 /// (`numpy/_core/fromnumeric.py:1222`).
544 ///
545 /// # Errors
546 /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`, or
547 /// `FerrayError::ShapeMismatch` if the reduced axis is empty (matching
548 /// numpy's `ValueError` on an empty argmax axis).
549 pub fn argmax_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
550 where
551 D::NdarrayDim: ndarray::RemoveAxis,
552 {
553 self.arg_axis(axis, false)
554 }
555
556 /// Indices of the minima along `axis`. See [`Array::argmax_axis`].
557 pub fn argmin_axis(&self, axis: Axis) -> FerrayResult<Array<i64, IxDyn>>
558 where
559 D::NdarrayDim: ndarray::RemoveAxis,
560 {
561 self.arg_axis(axis, true)
562 }
563
564 /// Internal: per-lane arg-reduction along `axis`. Each lane is a 1D view
565 /// orthogonal to `axis`; the reduced index is the position within the lane.
566 fn arg_axis(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<i64, IxDyn>>
567 where
568 D::NdarrayDim: ndarray::RemoveAxis,
569 {
570 let ndim = self.ndim();
571 if axis.index() >= ndim {
572 return Err(crate::error::FerrayError::axis_out_of_bounds(
573 axis.index(),
574 ndim,
575 ));
576 }
577 if self.shape()[axis.index()] == 0 {
578 return Err(crate::error::FerrayError::shape_mismatch(
579 "attempt to get argmax/argmin of an empty axis",
580 ));
581 }
582 let nd_axis = ndarray::Axis(axis.index());
583 let lanes = self.inner.lanes(nd_axis);
584 let mut out: Vec<i64> = Vec::with_capacity(lanes.into_iter().len());
585 for lane in self.inner.lanes(nd_axis) {
586 // Lane is non-empty (empty axis already rejected), so arg_reduce
587 // returns Some; default 0 is unreachable but keeps the code panic-free.
588 let idx = arg_reduce(lane.iter().copied(), take_min).unwrap_or(0);
589 out.push(idx);
590 }
591 let mut out_shape: Vec<usize> = self.shape().to_vec();
592 out_shape.remove(axis.index());
593 Array::from_vec(IxDyn::from(&out_shape[..]), out)
594 }
595
596 /// Internal: per-lane min/max via manual lane iteration. Avoids the
597 /// init-bias problem of `fold_axis` (which applies a single init to every
598 /// lane, even though min/max have no identity element).
599 fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
600 where
601 D::NdarrayDim: ndarray::RemoveAxis,
602 {
603 let nd_axis = ndarray::Axis(axis.index());
604 // Use ndarray's lane iteration directly via the inner ndarray::ArrayBase.
605 // Each lane is a 1D view orthogonal to the chosen axis.
606 let lanes = self.inner.lanes(nd_axis);
607 let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
608 for lane in self.inner.lanes(nd_axis) {
609 let mut iter = lane.iter().copied();
610 let first = iter.next().unwrap(); // safe: empty axis already rejected
611 let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
612 out.push(result);
613 }
614
615 // Output shape: drop the reduced axis from the input shape.
616 let mut out_shape: Vec<usize> = self.shape().to_vec();
617 out_shape.remove(axis.index());
618 Array::from_vec(IxDyn::from(&out_shape[..]), out)
619 }
620}
621
622// ---------------------------------------------------------------------------
// Mean — requires `MeanAcc`; Var / Std — additionally require `Float`
624// ---------------------------------------------------------------------------
625
626impl<T, D> Array<T, D>
627where
628 T: MeanAcc,
629 D: Dimension,
630{
631 /// Arithmetic mean of all elements. Returns `None` for an empty array.
632 ///
633 /// The result type is the NumPy mean accumulator [`MeanAcc::Mean`]:
634 /// bool / integer inputs average in (and return) `f64`, while `f32`/`f64`/
635 /// complex keep their own dtype. This matches numpy, which casts bool /
636 /// unsigned / signed int to `float64` before averaging
637 /// (`numpy/_core/_methods.py:124-127`), so `Array::<i32, _>::mean()` is
638 /// `Some(f64)` and `Array::<bool, _>::mean()` is `Some(f64)` (e.g.
639 /// `0.666…`), matching `np.mean`.
640 pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
641 let n = self.size();
642 if n == 0 {
643 return None;
644 }
645 let sum = self
646 .iter()
647 .copied()
648 .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
649 Some(sum / <T as MeanAcc>::count(n))
650 }
651
652 /// Mean along an axis. Element type is the promoted [`MeanAcc::Mean`]
653 /// (bool / integer lanes average in `f64`; `f32`/`f64` stay themselves).
654 pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<<T as MeanAcc>::Mean, IxDyn>>
655 where
656 <T as MeanAcc>::Mean: ReduceAcc<Acc = <T as MeanAcc>::Mean>,
657 D::NdarrayDim: ndarray::RemoveAxis,
658 {
659 let ndim = self.ndim();
660 if axis.index() >= ndim {
661 return Err(crate::error::FerrayError::axis_out_of_bounds(
662 axis.index(),
663 ndim,
664 ));
665 }
666 let n = self.shape()[axis.index()];
667 if n == 0 {
668 return Err(crate::error::FerrayError::shape_mismatch(
669 "cannot compute mean along empty axis",
670 ));
671 }
672 // Widen each element into the mean accumulator, then sum along the axis
673 // and divide by the lane length.
674 let widened = self.map_to::<<T as MeanAcc>::Mean>(MeanAcc::widen_mean);
675 let sums = widened.sum_axis(axis)?;
676 let n_t = <T as MeanAcc>::count(n);
677 Ok(sums.mapv(|x| x / n_t))
678 }
679}
680
681impl<T, D> Array<T, D>
682where
683 T: Element + Float + MeanAcc<Mean = T>,
684 D: Dimension,
685{
686 /// Variance with `ddof` degrees of freedom (Bessel's correction = 1).
687 ///
688 /// Returns `None` for an empty array, or when `ddof >= n`.
689 pub fn var(&self, ddof: usize) -> Option<T> {
690 let n = self.size();
691 if n == 0 || ddof >= n {
692 return None;
693 }
694 let mean = self.mean()?;
695 let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
696 acc + (x - mean) * (x - mean)
697 });
698 Some(sum_sq / T::from(n - ddof).unwrap())
699 }
700
701 /// Standard deviation with `ddof` degrees of freedom.
702 pub fn std(&self, ddof: usize) -> Option<T> {
703 self.var(ddof).map(num_traits::Float::sqrt)
704 }
705}
706
707// ---------------------------------------------------------------------------
708// any / all — for bool arrays
709// ---------------------------------------------------------------------------
710
711impl<D> Array<bool, D>
712where
713 D: Dimension,
714{
715 /// Returns `true` if any element is `true`.
716 pub fn any(&self) -> bool {
717 self.iter().any(|&x| x)
718 }
719
720 /// Returns `true` if all elements are `true`. Vacuously `true` for empty arrays.
721 pub fn all(&self) -> bool {
722 self.iter().all(|&x| x)
723 }
724}
725
726// ---------------------------------------------------------------------------
727// ArrayView mirrors — same methods on borrowed views
728// ---------------------------------------------------------------------------
729
730impl<T, D> ArrayView<'_, T, D>
731where
732 T: Element + Copy,
733 D: Dimension,
734{
735 /// Sum of all elements. See [`Array::sum`] — returns the promoted
736 /// [`ReduceAcc::Acc`].
737 pub fn sum(&self) -> <T as ReduceAcc>::Acc
738 where
739 T: ReduceAcc,
740 {
741 let mut acc = <T as ReduceAcc>::Acc::zero();
742 for &x in self.iter() {
743 acc = acc + x.widen();
744 }
745 acc
746 }
747
748 /// Product of all elements. See [`Array::prod`] — returns the promoted
749 /// [`ReduceAcc::Acc`].
750 pub fn prod(&self) -> <T as ReduceAcc>::Acc
751 where
752 T: ReduceAcc,
753 {
754 let mut acc = <T as ReduceAcc>::Acc::one();
755 for &x in self.iter() {
756 acc = acc * x.widen();
757 }
758 acc
759 }
760}
761
762impl<T, D> ArrayView<'_, T, D>
763where
764 T: Element + Copy + PartialOrd,
765 D: Dimension,
766{
767 /// Minimum value. See [`Array::min`].
768 pub fn min(&self) -> Option<T> {
769 let mut iter = self.iter().copied();
770 let first = iter.next()?;
771 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
772 }
773
774 /// Maximum value. See [`Array::max`].
775 pub fn max(&self) -> Option<T> {
776 let mut iter = self.iter().copied();
777 let first = iter.next()?;
778 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
779 }
780
781 /// Flat index of the maximum element. See [`Array::argmax`].
782 pub fn argmax(&self) -> Option<i64> {
783 arg_reduce(self.iter().copied(), false)
784 }
785
786 /// Flat index of the minimum element. See [`Array::argmin`].
787 pub fn argmin(&self) -> Option<i64> {
788 arg_reduce(self.iter().copied(), true)
789 }
790}
791
792impl<T, D> ArrayView<'_, T, D>
793where
794 T: MeanAcc,
795 D: Dimension,
796{
797 /// Mean. See [`Array::mean`] — returns the promoted [`MeanAcc::Mean`].
798 pub fn mean(&self) -> Option<<T as MeanAcc>::Mean> {
799 let n = self.size();
800 if n == 0 {
801 return None;
802 }
803 let sum = self
804 .iter()
805 .copied()
806 .fold(<T as MeanAcc>::Mean::zero(), |acc, x| acc + x.widen_mean());
807 Some(sum / <T as MeanAcc>::count(n))
808 }
809}
810
#[cfg(test)]
mod tests {
    //! Unit tests for the reduction methods on `Array` and `ArrayView`.
    //!
    //! Besides the basic f64 cases, these pin the documented NumPy-parity
    //! behaviours: narrow-int / bool accumulator promotion, integer mean in
    //! f64, first-occurrence and NaN-first arg-reductions, last-operand-wins
    //! on signed-zero ties, and the axis error conditions.

    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1-D f64 array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a `rows x cols` f64 array from row-major `data`.
    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // ----- sum / prod -----

    #[test]
    fn sum_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.sum(), 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Sum across rows (axis 0): [1+4, 2+5, 3+6] = [5, 7, 9]
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);

        // Sum across columns (axis 1): [1+2+3, 4+5+6] = [6, 15]
        let s1 = a.sum_axis(Axis(1)).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.prod(), 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p0 = a.prod_axis(Axis(0)).unwrap();
        assert_eq!(
            p0.iter().copied().collect::<Vec<_>>(),
            vec![4.0, 10.0, 18.0]
        );

        let p1 = a.prod_axis(Axis(1)).unwrap();
        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
    }

    #[test]
    fn sum_prod_promote_narrow_ints() {
        // 3 * 100 and 100 * 100 both overflow i8; the i64 accumulator holds them.
        let a = Array::<i8, Ix1>::from_vec(Ix1::new([3]), vec![100, 100, 100]).unwrap();
        assert_eq!(a.sum(), 300_i64);
        let b = Array::<i8, Ix1>::from_vec(Ix1::new([2]), vec![100, 100]).unwrap();
        assert_eq!(b.prod(), 10_000_i64);

        // Unsigned narrow ints promote to u64.
        let c = Array::<u8, Ix1>::from_vec(Ix1::new([2]), vec![200, 200]).unwrap();
        assert_eq!(c.sum(), 400_u64);
    }

    #[test]
    fn sum_bool_counts_true() {
        let a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        assert_eq!(a.sum(), 2_i64);
    }

    #[test]
    fn sum_axis_promotes_narrow_ints() {
        let a = Array::<i8, Ix2>::from_vec(Ix2::new([2, 2]), vec![100, 100, 100, 100]).unwrap();
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[2]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![200_i64, 200]);
    }

    // ----- min / max -----

    #[test]
    fn min_max_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.min(), Some(1.0));
        assert_eq!(a.max(), Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.min(), None);
        assert_eq!(a.max(), None);
    }

    #[test]
    fn min_max_int() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
        assert_eq!(a.min(), Some(-5));
        assert_eq!(a.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        // axis 0: min/max per column
        let mn0 = a.min_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
        let mx0 = a.max_axis(Axis(0)).unwrap();
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);

        // axis 1: min/max per row
        let mn1 = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
        let mx1 = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
    }

    #[test]
    fn min_max_tie_keeps_last_signed_zero() {
        // +0.0 == -0.0, so the tie rule (last operand wins) decides the sign.
        let a = arr1(vec![0.0, -0.0]);
        assert!(a.max().unwrap().is_sign_negative());
        assert!(a.min().unwrap().is_sign_negative());

        let b = arr1(vec![-0.0, 0.0]);
        assert!(b.max().unwrap().is_sign_positive());
        assert!(b.min().unwrap().is_sign_positive());
    }

    #[test]
    fn min_max_axis_rejects_bad_or_empty_axis() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        assert!(a.min_axis(Axis(2)).is_err());
        assert!(a.max_axis(Axis(2)).is_err());

        let empty = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(empty.min_axis(Axis(0)).is_err());
        assert!(empty.max_axis(Axis(0)).is_err());
    }

    #[test]
    fn min_max_axis_propagates_nan_per_lane() {
        // Columns: [1, 3] and [NaN, 4]; only the second lane contains NaN.
        let a = arr2(2, 2, vec![1.0, f64::NAN, 3.0, 4.0]);
        let mn = a.min_axis(Axis(0)).unwrap();
        let mn = mn.iter().copied().collect::<Vec<_>>();
        assert_eq!(mn[0], 1.0);
        assert!(mn[1].is_nan());

        let mx = a.max_axis(Axis(0)).unwrap();
        let mx = mx.iter().copied().collect::<Vec<_>>();
        assert_eq!(mx[0], 3.0);
        assert!(mx[1].is_nan());
    }

    // ----- argmax / argmin -----

    #[test]
    fn argmax_argmin_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 0.5, 9.0, 2.0]);
        assert_eq!(a.argmax(), Some(4));
        assert_eq!(a.argmin(), Some(3));
    }

    #[test]
    fn argmax_argmin_first_occurrence_on_ties() {
        let a = arr1(vec![3.0, 1.0, 3.0, 1.0]);
        assert_eq!(a.argmax(), Some(0));
        assert_eq!(a.argmin(), Some(1));
    }

    #[test]
    fn argmax_argmin_nan_first() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, f64::NAN]);
        assert_eq!(a.argmax(), Some(1));
        assert_eq!(a.argmin(), Some(1));
    }

    #[test]
    fn argmax_argmin_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.argmax(), None);
        assert_eq!(a.argmin(), None);
    }

    #[test]
    fn argmax_argmin_axis_2d() {
        // Rows: [1, 5, 3] and [4, 2, 6].
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);

        let amax0 = a.argmax_axis(Axis(0)).unwrap();
        assert_eq!(amax0.shape(), &[3]);
        assert_eq!(amax0.iter().copied().collect::<Vec<_>>(), vec![1_i64, 0, 1]);
        let amin0 = a.argmin_axis(Axis(0)).unwrap();
        assert_eq!(amin0.iter().copied().collect::<Vec<_>>(), vec![0_i64, 1, 0]);

        let amax1 = a.argmax_axis(Axis(1)).unwrap();
        assert_eq!(amax1.shape(), &[2]);
        assert_eq!(amax1.iter().copied().collect::<Vec<_>>(), vec![1_i64, 2]);
        let amin1 = a.argmin_axis(Axis(1)).unwrap();
        assert_eq!(amin1.iter().copied().collect::<Vec<_>>(), vec![0_i64, 1]);
    }

    #[test]
    fn arg_axis_rejects_bad_or_empty_axis() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        assert!(a.argmax_axis(Axis(2)).is_err());
        assert!(a.argmin_axis(Axis(2)).is_err());

        let empty = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(empty.argmax_axis(Axis(0)).is_err());
        assert!(empty.argmin_axis(Axis(0)).is_err());
    }

    // ----- mean / var / std -----

    #[test]
    fn mean_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.mean(), Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.mean(), None);
    }

    #[test]
    fn mean_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
        let m1 = a.mean_axis(Axis(1)).unwrap();
        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
    }

    #[test]
    fn mean_int_and_bool_return_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert_eq!(a.mean(), Some(2.0_f64));

        let b = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        assert_eq!(b.mean(), Some(2.0_f64 / 3.0_f64));
    }

    #[test]
    fn mean_axis_int_returns_f64() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.shape(), &[2]);
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.0_f64, 3.0]);
    }

    #[test]
    fn mean_axis_rejects_bad_or_empty_axis() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(a.mean_axis(Axis(2)).is_err());

        let empty = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(empty.mean_axis(Axis(0)).is_err());
    }

    #[test]
    fn var_population() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // population variance (ddof=0): ((1-3)^2+(2-3)^2+(3-3)^2+(4-3)^2+(5-3)^2)/5 = 10/5 = 2
        assert_eq!(a.var(0), Some(2.0));
    }

    #[test]
    fn var_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        // sample variance (ddof=1): 10/4 = 2.5
        assert_eq!(a.var(1), Some(2.5));
    }

    #[test]
    fn std_basic() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(0).unwrap();
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let a = arr1(vec![1.0, 2.0]);
        assert_eq!(a.var(2), None);
        assert_eq!(a.var(5), None);
    }

    #[test]
    fn var_std_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.var(0), None);
        assert_eq!(a.std(0), None);
    }

    // ----- any / all -----

    #[test]
    fn any_all_bool() {
        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let false_arr =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();

        assert!(true_arr.all());
        assert!(true_arr.any());

        assert!(!mixed.all());
        assert!(mixed.any());

        assert!(!false_arr.all());
        assert!(!false_arr.any());

        // Vacuous truth for empty
        assert!(empty.all());
        assert!(!empty.any());
    }

    // ----- ArrayView mirrors -----

    #[test]
    fn view_sum_min_max_mean() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let v = a.view();
        assert_eq!(v.sum(), 10.0);
        assert_eq!(v.min(), Some(1.0));
        assert_eq!(v.max(), Some(4.0));
        assert_eq!(v.mean(), Some(2.5));
    }

    #[test]
    fn view_prod_argmax_argmin() {
        let a = arr1(vec![2.0, 5.0, 1.0, 5.0]);
        let v = a.view();
        assert_eq!(v.prod(), 50.0);
        assert_eq!(v.argmax(), Some(1));
        assert_eq!(v.argmin(), Some(2));
    }

    #[test]
    fn nan_propagates_in_min_max() {
        // NaN somewhere in the middle
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(a.min().unwrap().is_nan());
        assert!(a.max().unwrap().is_nan());

        // NaN at the start
        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
        assert!(b.min().unwrap().is_nan());
        assert!(b.max().unwrap().is_nan());

        // NaN at the end
        let c = arr1(vec![1.0, 3.0, f64::NAN]);
        assert!(c.min().unwrap().is_nan());
        assert!(c.max().unwrap().is_nan());
    }
}