Skip to main content

ferray_stats/reductions/
mod.rs

1// ferray-stats: Core reductions — sum, prod, min, max, argmin, argmax, mean, var, std (REQ-1, REQ-2)
2//
3// `cumulative` used to exist as an empty placeholder whose only content
4// was a note that cumulative functions live in nan_aware / mod.rs; it's
5// been deleted to stop pretending there's a dedicated module there (#162).
6//
7// ## REQ status (ferray-stats reductions, NumPy parity)
8//  - REQ-1 / REQ-2 (sum/prod/min/max/argmin/argmax/mean/var/std/ptp/average
9//    as free functions, each with an optional `axis`) — SHIPPED: `pub fn sum`,
10//    `prod`, `min`, `max`, `argmin`, `argmax`, `ptp`, `average`, `mean`, `var`,
11//    `std_` (this file), plus the `_axes` / `_into` / `_with` axis-tuple and
12//    out-buffer variants. Each takes `axis: Option<usize>` (`None` = whole
13//    array), matching numpy's `axis=None` collapse
14//    (numpy/_core/fromnumeric.py:2321 `sum`, :2160 `min`/`max`). Non-test
15//    consumers: the `#[pyfunction]` shims in `ferray-python/src/stats.rs`
16//    (`ferray_stats::sum`/`min`/`max`/`argmin`/`argmax`/`ptp`/`average`/`mean`/
17//    `var`/`std_`/`prod`), and in-crate `var`/`std_` reuse `mean`; `average`
18//    delegates to `mean` when no weights are given.
19//  - min/max/argmin/argmax/ptp NaN-propagation (#1058-#1062) — SHIPPED: the
20//    extremum reducers in `min`/`max`/`argmin`/`argmax`/`ptp` propagate NaN
21//    (a NaN anywhere in the reduced lane makes the result NaN, and argmin/
22//    argmax return the index of the first NaN), matching numpy's
23//    non-`nan*` extremum semantics (numpy/_core/_methods.py:43 `_amin`,
24//    :39 `_amax`). Audited green. Consumers: the same `ferray-python`
25//    `#[pyfunction]` shims above (`ferray_stats::min`/`max`/`argmin`/`argmax`/
26//    `ptp`).
27//  - integer/bool sum/prod accumulator promotion — SHIPPED: `sum`/`prod`
28//    reduce through `ferray_core::array::reductions::ReduceAcc` (widen narrow
29//    ints before accumulating), so narrow-int reductions do not overflow,
30//    matching numpy (numpy/_core/fromnumeric.py:2321-2327). `sum_as_f64` /
31//    `mean_as_f64` / `var_as_f64` / `std_as_f64` expose the float-result forms.
32//  - SIMD/parallel reduction kernels — SHIPPED (consumer side): `sum`/`mean`/
33//    `var` route through `parallel::pairwise_sum*` / `parallel::simd_sum_sq_diff*`
34//    (this file's `try_simd_sum_sq_diff` helper); see `parallel.rs` REQ-19.
35//  - cumsum / cumprod — SHIPPED here as free functions `pub fn cumsum` /
36//    `pub fn cumprod` (REQ-2a); consumed by `ferray_stats::cumsum`/`cumprod`
37//    in `ferray-python/src/stats.rs`.
38
39pub mod nan_aware;
40pub mod quantile;
41
42use std::any::TypeId;
43
44use ferray_core::array::reductions::ReduceAcc;
45use ferray_core::error::{FerrayError, FerrayResult};
46use ferray_core::{Array, Dimension, Element, IxDyn};
47use num_traits::Float;
48
49use crate::parallel;
50
51/// Try SIMD-accelerated fused sum of squared differences for f64 or f32 (#173).
52/// Returns sum((x - mean)²) without allocating an intermediate Vec.
53#[inline]
54fn try_simd_sum_sq_diff<T: Element + Copy + 'static>(data: &[T], mean: T) -> Option<T> {
55    if TypeId::of::<T>() == TypeId::of::<f64>() {
56        // SAFETY: TypeId check guarantees T is f64. size_of::<T>() == size_of::<f64>().
57        let f64_slice =
58            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f64>(), data.len()) };
59        let mean_f64: f64 = unsafe { std::mem::transmute_copy(&mean) };
60        let result = parallel::simd_sum_sq_diff_f64(f64_slice, mean_f64);
61        Some(unsafe { std::mem::transmute_copy(&result) })
62    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
63        // SAFETY: TypeId check guarantees T is f32. size_of::<T>() == size_of::<f32>().
64        let f32_slice =
65            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), data.len()) };
66        let mean_f32: f32 = unsafe { std::mem::transmute_copy(&mean) };
67        let result = parallel::simd_sum_sq_diff_f32(f32_slice, mean_f32);
68        Some(unsafe { std::mem::transmute_copy(&result) })
69    } else {
70        None
71    }
72}
73
74/// Try SIMD-accelerated pairwise sum for f64 or f32 (#173).
75/// Returns the sum transmuted back to T, or None if T is not f64/f32.
76#[inline]
77fn try_simd_pairwise_sum<T: Element + Copy + 'static>(data: &[T]) -> Option<T> {
78    if TypeId::of::<T>() == TypeId::of::<f64>() {
79        // SAFETY: TypeId check guarantees T is f64. size_of::<T>() == size_of::<f64>().
80        let f64_slice =
81            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f64>(), data.len()) };
82        let result = parallel::pairwise_sum_f64(f64_slice);
83        Some(unsafe { std::mem::transmute_copy(&result) })
84    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
85        // SAFETY: TypeId check guarantees T is f32. size_of::<T>() == size_of::<f32>().
86        let f32_slice =
87            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), data.len()) };
88        let result = parallel::pairwise_sum_f32(f32_slice);
89        Some(unsafe { std::mem::transmute_copy(&result) })
90    } else {
91        None
92    }
93}
94
95// ---------------------------------------------------------------------------
96// Internal axis-reduction helper
97// ---------------------------------------------------------------------------
98
/// Row-major (C-order) element strides for `shape`.
///
/// The last axis always has stride 1; each earlier axis has the product of
/// all extents that follow it. An empty `shape` yields an empty vector.
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;
    // Walk the stride slots from the second-to-last axis backwards, pairing
    // slot `i` with extent `i + 1` (the axis immediately to its right).
    let slots = strides.iter_mut().rev().skip(1);
    for (slot, &extent) in slots.zip(shape.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
108
/// Linear offset of the element at `multi`, given per-axis `strides`.
///
/// Computes `Σ multi[d] * strides[d]` over the axes both slices have in
/// common (pairs are formed with `zip`, so any surplus entries are ignored).
pub(crate) fn flat_index(multi: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0usize;
    for (&coord, &stride) in multi.iter().zip(strides) {
        offset += coord * stride;
    }
    offset
}
113
/// Advance `multi` to the next position in row-major order within `shape`.
///
/// The last axis moves fastest. Returns `true` when a new position was
/// produced; returns `false` when every axis wrapped around, in which case
/// `multi` has been reset to all zeros.
pub(crate) fn increment_multi_index(multi: &mut [usize], shape: &[usize]) -> bool {
    let mut d = multi.len();
    while d > 0 {
        d -= 1;
        let bumped = multi[d] + 1;
        if bumped < shape[d] {
            multi[d] = bumped;
            return true;
        }
        // This axis is exhausted: reset it and carry into the next one left.
        multi[d] = 0;
    }
    false
}
125
/// General single-axis reduction parameterized on the output element type.
///
/// `data` holds the elements in row-major order for `shape`. Each lane along
/// `axis` is gathered into one reusable scratch buffer and passed to `f`,
/// which collapses it to a value of type `U`. The collected outputs are
/// returned in row-major order for `shape` with `axis` removed.
///
/// Shared backbone of `reduce_axis_general` (`U = T`) and
/// `reduce_axis_general_u64` (`U = u64`), which used to carry copy-pasted
/// bodies differing only in return type (#161).
///
/// # Implementation
/// A row-major array reduced along one axis factors as
/// `(outer, axis_len, inner)`, where `outer` is the product of the extents
/// before `axis` and `inner` the product of those after it. The lane for
/// output position `(o, i)` is therefore `data[o * axis_len * inner + k * inner + i]`
/// for `k in 0..axis_len`, so no per-element multi-index bookkeeping is needed.
///
/// # Panics
/// Panics if `axis >= shape.len()`, or if `data` does not cover an index
/// implied by `shape`.
pub(crate) fn reduce_axis_typed<T, U, F>(data: &[T], shape: &[usize], axis: usize, f: F) -> Vec<U>
where
    T: Copy,
    F: Fn(&[T]) -> U,
{
    let axis_len = shape[axis];
    // Empty products are 1, which covers a leading/trailing axis and the
    // 1-D case (result is a single value).
    let outer: usize = shape[..axis].iter().product();
    let inner: usize = shape[axis + 1..].iter().product();
    // Number of elements spanned by one step along the `outer` index.
    let block = axis_len * inner;

    let mut result = Vec::with_capacity(outer * inner);
    let mut lane: Vec<T> = Vec::with_capacity(axis_len);

    for o in 0..outer {
        let block_start = o * block;
        for i in 0..inner {
            // Offset of the lane's first element; consecutive lane elements
            // are `inner` apart (the row-major stride of `axis`).
            let first = block_start + i;
            lane.clear();
            lane.extend((0..axis_len).map(|k| data[first + k * inner]));
            result.push(f(&lane));
        }
    }

    result
}
193
194/// Thin wrapper: `T -> T` reduction. Preserved for the many call sites
195/// that pass `Fn(&[T]) -> T` kernels.
196#[inline]
197pub(crate) fn reduce_axis_general<T, F>(data: &[T], shape: &[usize], axis: usize, f: F) -> Vec<T>
198where
199    T: Copy,
200    F: Fn(&[T]) -> T,
201{
202    reduce_axis_typed(data, shape, axis, f)
203}
204
/// In-place variant of [`reduce_axis_typed`]: writes each lane's result
/// directly into `dst` instead of allocating a `Vec<U>` for the output.
///
/// Used by the `*_into` reductions (#563). The only allocation made here is
/// the single scratch buffer that every lane is gathered into.
///
/// `dst` must hold exactly `out_size` elements, where `out_size` is the
/// product of `shape` with `axis` removed (1 when nothing is left); callers
/// should validate this via [`check_out_shape`] first. The length is checked
/// with `debug_assert_eq!` only: in builds without debug assertions a
/// mismatched `dst` is not detected, and only the first
/// `min(dst.len(), out_size)` slots are written.
///
/// # Implementation
/// A row-major array reduced along one axis factors as
/// `(outer, axis_len, inner)`; the lane for output position `(o, i)` starts
/// at `o * axis_len * inner + i` and its elements are `inner` apart, so the
/// gather needs no per-element multi-index bookkeeping.
///
/// # Panics
/// Panics if `axis >= shape.len()`, or if `data` does not cover an index
/// implied by `shape`.
pub(crate) fn reduce_axis_typed_into<T, U, F>(
    data: &[T],
    shape: &[usize],
    axis: usize,
    dst: &mut [U],
    f: F,
) where
    T: Copy,
    F: Fn(&[T]) -> U,
{
    let axis_len = shape[axis];
    // Empty products are 1, which covers a leading/trailing axis and the
    // 1-D case (a single output slot).
    let outer: usize = shape[..axis].iter().product();
    let inner: usize = shape[axis + 1..].iter().product();
    // Number of elements spanned by one step along the `outer` index.
    let block = axis_len * inner;

    debug_assert_eq!(
        dst.len(),
        outer * inner,
        "reduce_axis_typed_into: dst length must match the reduction's output size"
    );

    // One scratch buffer, cleared and refilled for every lane.
    let mut lane: Vec<T> = Vec::with_capacity(axis_len);
    // First-element offsets of the lanes, in row-major output order.
    let lane_starts = (0..outer).flat_map(|o| (0..inner).map(move |i| o * block + i));

    for (slot, first) in dst.iter_mut().zip(lane_starts) {
        lane.clear();
        lane.extend((0..axis_len).map(|k| data[first + k * inner]));
        *slot = f(&lane);
    }
}
279
280/// In-place T-to-T reduction wrapper around [`reduce_axis_typed_into`].
281#[inline]
282pub(crate) fn reduce_axis_general_into<T, F>(
283    data: &[T],
284    shape: &[usize],
285    axis: usize,
286    dst: &mut [T],
287    f: F,
288) where
289    T: Copy,
290    F: Fn(&[T]) -> T,
291{
292    reduce_axis_typed_into(data, shape, axis, dst, f);
293}
294
295/// Validate axis parameter and return an error if out of bounds.
296pub(crate) const fn validate_axis(axis: usize, ndim: usize) -> FerrayResult<()> {
297    if axis >= ndim {
298        Err(FerrayError::axis_out_of_bounds(axis, ndim))
299    } else {
300        Ok(())
301    }
302}
303
304/// Collect array data into a contiguous Vec in logical (row-major) order.
305pub(crate) fn collect_data<T: Element + Copy, D: Dimension>(a: &Array<T, D>) -> Vec<T> {
306    a.iter().copied().collect()
307}
308
/// Element storage that is either borrowed from its owner or held in an
/// owned copy, so that contiguous arrays can be read without allocating.
pub(crate) enum DataRef<'a, T> {
    /// A slice borrowed from the source.
    Borrowed(&'a [T]),
    /// An owned copy of the elements.
    Owned(Vec<T>),
}

/// Both variants read as a plain `[T]`, so callers can treat a `DataRef`
/// like a slice regardless of where the elements live.
impl<T> std::ops::Deref for DataRef<'_, T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match *self {
            Self::Borrowed(slice) => slice,
            Self::Owned(ref buf) => buf.as_slice(),
        }
    }
}
324
325/// Get a reference to contiguous data, or copy if strided.
326pub(crate) fn borrow_data<T: Element + Copy, D: Dimension>(a: &Array<T, D>) -> DataRef<'_, T> {
327    if let Some(slice) = a.as_slice() {
328        DataRef::Borrowed(slice)
329    } else {
330        DataRef::Owned(a.iter().copied().collect())
331    }
332}
333
334/// Build an `IxDyn` result array from output shape and data.
335pub(crate) fn make_result<T: Element>(
336    out_shape: &[usize],
337    data: Vec<T>,
338) -> FerrayResult<Array<T, IxDyn>> {
339    Array::from_vec(IxDyn::new(out_shape), data)
340}
341
342/// Validate that `out` has the expected shape and is C-contiguous,
343/// returning a mutable slice into its backing buffer.
344///
345/// Shared by every `*_into` reduction (#467, #563) so the validation
346/// surface lives in exactly one place. Broadcasting is intentionally not
347/// allowed because the destination shape is fixed by the input + axis
348/// combination — accepting a mismatched destination would silently
349/// re-route to a different reduction shape.
350pub(crate) fn check_out_shape<'a, T: Element + Copy>(
351    out: &'a mut Array<T, IxDyn>,
352    expected_shape: &[usize],
353    op_name: &str,
354) -> FerrayResult<&'a mut [T]> {
355    if out.shape() != expected_shape {
356        return Err(FerrayError::shape_mismatch(format!(
357            "{op_name}_into: out shape {:?} does not match expected reduction shape {:?}",
358            out.shape(),
359            expected_shape
360        )));
361    }
362    out.as_slice_mut().ok_or_else(|| {
363        FerrayError::invalid_value(format!("{op_name}_into: out must be C-contiguous"))
364    })
365}
366
/// Shape that results from reducing `shape` along `axis`: the same extents
/// with position `axis` dropped.
///
/// An `axis` that does not index into `shape` removes nothing.
pub(crate) fn output_shape(shape: &[usize], axis: usize) -> Vec<usize> {
    let mut kept = Vec::with_capacity(shape.len().saturating_sub(1));
    for (dim, &extent) in shape.iter().enumerate() {
        if dim != axis {
            kept.push(extent);
        }
    }
    kept
}
376
377// ---------------------------------------------------------------------------
378// Multi-axis reduction helpers (issues #457 + #458)
379// ---------------------------------------------------------------------------
380
381/// Normalize the `axes: Option<&[usize]>` argument of a multi-axis reduction:
382///
383/// - `None` or `Some(&[])` expands to all axes `[0..ndim]` (reduce everything)
384/// - Duplicate axes are an error (matches `NumPy`'s `np.sum(a, axis=(0, 0))`)
385/// - Any out-of-bounds axis is an error
386///
387/// Returns the sorted, unique axis list.
388pub(crate) fn normalize_axes(axes: Option<&[usize]>, ndim: usize) -> FerrayResult<Vec<usize>> {
389    let ax: Vec<usize> = match axes {
390        None | Some([]) => (0..ndim).collect(),
391        Some(s) => s.to_vec(),
392    };
393    for &a in &ax {
394        if a >= ndim {
395            return Err(FerrayError::axis_out_of_bounds(a, ndim));
396        }
397    }
398    let mut sorted = ax;
399    sorted.sort_unstable();
400    for w in sorted.windows(2) {
401        if w[0] == w[1] {
402            return Err(FerrayError::invalid_value(format!(
403                "duplicate axis {} in reduction axes",
404                w[0]
405            )));
406        }
407    }
408    Ok(sorted)
409}
410
/// Shape that results from reducing `shape` over every axis listed in `axes`.
///
/// `axes` is expected to be sorted and unique (as produced by
/// [`normalize_axes`]). With `keepdims = true` each reduced axis stays in the
/// result with extent 1; with `keepdims = false` it is dropped.
pub(crate) fn output_shape_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Vec<usize> {
    let mut out = Vec::with_capacity(shape.len());
    for (dim, &extent) in shape.iter().enumerate() {
        match (axes.contains(&dim), keepdims) {
            // Untouched axis: extent carries over.
            (false, _) => out.push(extent),
            // Reduced but kept as a unit-length placeholder.
            (true, true) => out.push(1),
            // Reduced and removed.
            (true, false) => {}
        }
    }
    out
}
432
433/// Reduce over a set of axes, returning `(result_data, output_shape)`.
434///
435/// `axes` must be sorted and unique (typically produced by [`normalize_axes`]).
436/// The caller supplies a reduction function `f` that collapses a lane of
437/// reduced-axis values to a single scalar.
438///
439/// The output shape honors `keepdims`.
440pub(crate) fn reduce_axes_general<T: Copy, F: Fn(&[T]) -> T>(
441    data: &[T],
442    shape: &[usize],
443    axes: &[usize],
444    keepdims: bool,
445    f: F,
446) -> (Vec<T>, Vec<usize>) {
447    let ndim = shape.len();
448    let strides = compute_strides(shape);
449
450    // Partition axes into reduce / keep.
451    let is_reduce: Vec<bool> = (0..ndim).map(|i| axes.contains(&i)).collect();
452    let keep_axes: Vec<usize> = (0..ndim).filter(|i| !is_reduce[*i]).collect();
453    let reduce_axes: Vec<usize> = (0..ndim).filter(|i| is_reduce[*i]).collect();
454    let keep_shape: Vec<usize> = keep_axes.iter().map(|&i| shape[i]).collect();
455    let reduce_shape: Vec<usize> = reduce_axes.iter().map(|&i| shape[i]).collect();
456
457    let out_size: usize = if keep_shape.is_empty() {
458        1
459    } else {
460        keep_shape.iter().product()
461    };
462    let lane_size: usize = if reduce_shape.is_empty() {
463        1
464    } else {
465        reduce_shape.iter().product()
466    };
467
468    let out_shape = output_shape_axes(shape, axes, keepdims);
469
470    let mut result = Vec::with_capacity(out_size);
471    let mut lane: Vec<T> = Vec::with_capacity(lane_size);
472    let mut keep_multi = vec![0usize; keep_shape.len()];
473    let mut reduce_multi = vec![0usize; reduce_shape.len()];
474    let mut full_multi = vec![0usize; ndim];
475
476    for _ in 0..out_size {
477        // Fill kept-axis positions from keep_multi.
478        for (i, &ax) in keep_axes.iter().enumerate() {
479            full_multi[ax] = keep_multi[i];
480        }
481
482        // Gather values along all reduced axes.
483        lane.clear();
484        reduce_multi.fill(0);
485        for _ in 0..lane_size {
486            for (i, &ax) in reduce_axes.iter().enumerate() {
487                full_multi[ax] = reduce_multi[i];
488            }
489            lane.push(data[flat_index(&full_multi, &strides)]);
490            if !reduce_shape.is_empty() {
491                increment_multi_index(&mut reduce_multi, &reduce_shape);
492            }
493        }
494
495        result.push(f(&lane));
496
497        if !keep_shape.is_empty() {
498            increment_multi_index(&mut keep_multi, &keep_shape);
499        }
500    }
501
502    (result, out_shape)
503}
504
505// ---------------------------------------------------------------------------
506// sum
507// ---------------------------------------------------------------------------
508
509/// Sum of array elements over a given axis, or over all elements if axis is None.
510///
511/// Equivalent to `numpy.sum`.
512///
513/// The result element type is the NumPy reduction accumulator
514/// [`ReduceAcc::Acc`]: narrow signed ints widen to `i64`, narrow unsigned
515/// ints to `u64`, `bool` to `i64`, and `f32`/`f64`/`i64`/complex stay
516/// themselves. So a narrow-int sum can never overflow and its dtype matches
517/// `np.sum`'s promoted result, which "uses the default platform integer ...
518/// [when] `a` has an integer dtype of less precision than the default
519/// platform integer" (`numpy/_core/fromnumeric.py:2321-2327`). `f32`/`f64`
520/// callers keep the byte-identical SIMD/pairwise fast path.
521///
522/// # Examples
523/// ```ignore
524/// let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
525/// let s = sum(&a, None).unwrap();
526/// assert_eq!(s.iter().next(), Some(&10.0));
527/// ```
528pub fn sum<T, D>(
529    a: &Array<T, D>,
530    axis: Option<usize>,
531) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
532where
533    T: ReduceAcc + Send + Sync,
534    <T as ReduceAcc>::Acc: Send + Sync + 'static,
535    D: Dimension,
536{
537    // Widen every element into the accumulator dtype `T::Acc` up front, then
538    // run the existing SIMD/pairwise/parallel summation on the promoted slice.
539    // For `f32`/`f64`/`i64` (`Acc == Self`) the widen is the identity cast and
540    // `try_simd_pairwise_sum` still dispatches the SIMD path on the same
541    // values, so those results are byte-identical to before.
542    let widened: Vec<<T as ReduceAcc>::Acc> = a.iter().map(|&x| x.widen()).collect();
543    match axis {
544        None => {
545            let total = try_simd_pairwise_sum(&widened).unwrap_or_else(|| {
546                parallel::parallel_sum(&widened, <<T as ReduceAcc>::Acc as Element>::zero())
547            });
548            make_result(&[], vec![total])
549        }
550        Some(ax) => {
551            validate_axis(ax, a.ndim())?;
552            let shape = a.shape();
553            let out_s = output_shape(shape, ax);
554            let result = reduce_axis_general(&widened, shape, ax, |lane| {
555                try_simd_pairwise_sum(lane).unwrap_or_else(|| {
556                    parallel::pairwise_sum(lane, <<T as ReduceAcc>::Acc as Element>::zero())
557                })
558            });
559            make_result(&out_s, result)
560        }
561    }
562}
563
564/// Sum of array elements, returning `f64` regardless of input type.
565///
566/// This works on integer arrays (i32, u64, etc.) without overflow risk.
567/// The result is always `Array<f64, IxDyn>`, matching `NumPy`'s behavior
568/// of promoting integer sums to a wider type.
569pub fn sum_as_f64<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<f64, IxDyn>>
570where
571    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
572    D: Dimension,
573{
574    match axis {
575        None => {
576            let total: f64 = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).sum();
577            make_result(&[], vec![total])
578        }
579        Some(ax) => {
580            validate_axis(ax, a.ndim())?;
581            let shape = a.shape();
582            let out_s = output_shape(shape, ax);
583            let f64_data: Vec<f64> = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).collect();
584            let result = reduce_axis_general(&f64_data, shape, ax, |lane| lane.iter().sum());
585            make_result(&out_s, result)
586        }
587    }
588}
589
590// ---------------------------------------------------------------------------
591// prod
592// ---------------------------------------------------------------------------
593
594/// Product of array elements over a given axis.
595///
596/// The result element type is the promoted [`ReduceAcc::Acc`] (same numpy
597/// narrow-int promotion as [`sum`]; `numpy/_core/fromnumeric.py:3306-3312`),
598/// so a narrow-int product can never overflow and its dtype matches
599/// `np.prod`. `f32`/`f64`/`i64`/complex stay themselves.
600///
601/// Equivalent to `numpy.prod`.
602pub fn prod<T, D>(
603    a: &Array<T, D>,
604    axis: Option<usize>,
605) -> FerrayResult<Array<<T as ReduceAcc>::Acc, IxDyn>>
606where
607    T: ReduceAcc + Send + Sync,
608    <T as ReduceAcc>::Acc: Send + Sync,
609    D: Dimension,
610{
611    // Widen every element into the accumulator dtype before multiplying. For
612    // `f32`/`f64`/`i64` the widen is the identity cast, so those products are
613    // byte-identical to before.
614    let widened: Vec<<T as ReduceAcc>::Acc> = a.iter().map(|&x| x.widen()).collect();
615    match axis {
616        None => {
617            let total =
618                parallel::parallel_prod(&widened, <<T as ReduceAcc>::Acc as Element>::one());
619            make_result(&[], vec![total])
620        }
621        Some(ax) => {
622            validate_axis(ax, a.ndim())?;
623            let shape = a.shape();
624            let out_s = output_shape(shape, ax);
625            let result = reduce_axis_general(&widened, shape, ax, |lane| {
626                lane.iter()
627                    .copied()
628                    .fold(<<T as ReduceAcc>::Acc as Element>::one(), |acc, x| acc * x)
629            });
630            make_result(&out_s, result)
631        }
632    }
633}
634
635// ---------------------------------------------------------------------------
636// min / max
637// ---------------------------------------------------------------------------
638
/// Type-agnostic NaN check for any `T: PartialEq`.
///
/// Under IEEE-754 a NaN is the only value that compares unequal to itself, so
/// "`x` is not equal to `x`" identifies NaN for float types. Integer and
/// `bool` values always equal themselves, so for them this is always `false`
/// and the integer min/max/argmin/argmax paths are unaffected; only float
/// types observe the NaN-propagating behaviour numpy mandates
/// (`np.maximum`/`np.minimum` propagate NaN —
/// `numpy/_core/fromnumeric.py:3076`/`:3214`).
#[inline]
#[allow(clippy::eq_op)] // Comparing `x` with itself is the intentional generic NaN test.
fn is_nan_val<T: PartialEq>(x: &T) -> bool {
    PartialEq::ne(x, x)
}
651
652/// First-NaN-aware fold step for `argmin`.
653///
654/// numpy reduces argmin via `np.minimum`, which propagates NaN: a NaN compares
655/// as the minimum, so the FIRST NaN's index wins (numpy/_core/fromnumeric.py:1362).
656/// Walking left-to-right: once the running accumulator is NaN it is kept (the
657/// first NaN's index is already recorded); a fresh NaN is adopted over any
658/// finite running value; otherwise the plain `<=` first-occurrence rule applies.
659/// For integers `is_nan_val` is always false, so this is byte-identical to the
660/// previous `if av <= bv` behaviour.
661#[inline]
662fn argmin_fold<'a, T: PartialOrd + PartialEq>(
663    acc: (usize, &'a T),
664    next: (usize, &'a T),
665) -> (usize, &'a T) {
666    let (ai, av) = acc;
667    let (bi, bv) = next;
668    if is_nan_val(av) {
669        (ai, av)
670    } else if is_nan_val(bv) {
671        (bi, bv)
672    } else if av <= bv {
673        (ai, av)
674    } else {
675        (bi, bv)
676    }
677}
678
679/// First-NaN-aware fold step for `argmax` (numpy/_core/fromnumeric.py:1262).
680///
681/// Mirror of [`argmin_fold`] with `>=` ordering: numpy reduces argmax via
682/// `np.maximum`, which propagates NaN, so the FIRST NaN's index wins. Integers
683/// never satisfy `is_nan_val`, so the integer path is byte-identical to the
684/// previous `if av >= bv` behaviour.
685#[inline]
686fn argmax_fold<'a, T: PartialOrd + PartialEq>(
687    acc: (usize, &'a T),
688    next: (usize, &'a T),
689) -> (usize, &'a T) {
690    let (ai, av) = acc;
691    let (bi, bv) = next;
692    if is_nan_val(av) {
693        (ai, av)
694    } else if is_nan_val(bv) {
695        (bi, bv)
696    } else if av >= bv {
697        (ai, av)
698    } else {
699        (bi, bv)
700    }
701}
702
703/// Minimum value of array elements over a given axis.
704///
705/// Equivalent to `numpy.min` / `numpy.amin`.
706pub fn min<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
707where
708    T: Element + PartialOrd + Copy,
709    D: Dimension,
710{
711    if a.is_empty() {
712        return Err(FerrayError::invalid_value(
713            "cannot compute min of empty array",
714        ));
715    }
716    // NaN-propagating min: if either operand is NaN (comparison returns false
717    // for both orderings), propagate NaN to match NumPy behavior.
718    let nan_min = |a: T, b: T| -> T {
719        // numpy's np.minimum propagates NaN: if EITHER operand is NaN the
720        // result is NaN (numpy/_core/fromnumeric.py:3214). For integers
721        // `is_nan_val` is always false, so this reduces to the plain order.
722        if is_nan_val(&a) {
723            a
724        } else if is_nan_val(&b) {
725            b
726        } else if a <= b {
727            a
728        } else {
729            b
730        }
731    };
732    let data = borrow_data(a);
733    match axis {
734        None => {
735            let m = data.iter().copied().reduce(nan_min).unwrap();
736            make_result(&[], vec![m])
737        }
738        Some(ax) => {
739            validate_axis(ax, a.ndim())?;
740            let shape = a.shape();
741            let out_s = output_shape(shape, ax);
742            let result = reduce_axis_general(&data, shape, ax, |lane| {
743                lane.iter().copied().reduce(nan_min).unwrap()
744            });
745            make_result(&out_s, result)
746        }
747    }
748}
749
750/// Maximum value of array elements over a given axis.
751///
752/// Equivalent to `numpy.max` / `numpy.amax`.
753pub fn max<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
754where
755    T: Element + PartialOrd + Copy,
756    D: Dimension,
757{
758    if a.is_empty() {
759        return Err(FerrayError::invalid_value(
760            "cannot compute max of empty array",
761        ));
762    }
763    // NaN-propagating max: if EITHER operand is NaN the result is NaN, to
764    // match numpy's np.maximum (numpy/_core/fromnumeric.py:3076). For
765    // integers `is_nan_val` is always false, so this reduces to plain order.
766    let nan_max = |a: T, b: T| -> T {
767        if is_nan_val(&a) {
768            a
769        } else if is_nan_val(&b) {
770            b
771        } else if a >= b {
772            a
773        } else {
774            b
775        }
776    };
777    let data = borrow_data(a);
778    match axis {
779        None => {
780            let m = data.iter().copied().reduce(nan_max).unwrap();
781            make_result(&[], vec![m])
782        }
783        Some(ax) => {
784            validate_axis(ax, a.ndim())?;
785            let shape = a.shape();
786            let out_s = output_shape(shape, ax);
787            let result = reduce_axis_general(&data, shape, ax, |lane| {
788                lane.iter().copied().reduce(nan_max).unwrap()
789            });
790            make_result(&out_s, result)
791        }
792    }
793}
794
795// ---------------------------------------------------------------------------
796// argmin / argmax
797// ---------------------------------------------------------------------------
798
799/// Index of the minimum value. For axis=None, returns the flat index.
800/// For axis=Some(ax), returns indices along that axis.
801///
802/// Equivalent to `numpy.argmin`.
803pub fn argmin<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
804where
805    T: Element + PartialOrd + Copy,
806    D: Dimension,
807{
808    if a.is_empty() {
809        return Err(FerrayError::invalid_value(
810            "cannot compute argmin of empty array",
811        ));
812    }
813    let data = borrow_data(a);
814    match axis {
815        None => {
816            let idx = data.iter().enumerate().reduce(argmin_fold).unwrap().0 as u64;
817            make_result(&[], vec![idx])
818        }
819        Some(ax) => {
820            validate_axis(ax, a.ndim())?;
821            let shape = a.shape();
822            let out_s = output_shape(shape, ax);
823            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
824                lane.iter().enumerate().reduce(argmin_fold).unwrap().0 as u64
825            });
826            make_result(&out_s, result)
827        }
828    }
829}
830
831/// Index of the maximum value.
832///
833/// Equivalent to `numpy.argmax`.
834pub fn argmax<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
835where
836    T: Element + PartialOrd + Copy,
837    D: Dimension,
838{
839    if a.is_empty() {
840        return Err(FerrayError::invalid_value(
841            "cannot compute argmax of empty array",
842        ));
843    }
844    let data = borrow_data(a);
845    match axis {
846        None => {
847            let idx = data.iter().enumerate().reduce(argmax_fold).unwrap().0 as u64;
848            make_result(&[], vec![idx])
849        }
850        Some(ax) => {
851            validate_axis(ax, a.ndim())?;
852            let shape = a.shape();
853            let out_s = output_shape(shape, ax);
854            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
855                lane.iter().enumerate().reduce(argmax_fold).unwrap().0 as u64
856            });
857            make_result(&out_s, result)
858        }
859    }
860}
861
862/// Thin wrapper: `T -> u64` reduction (for `argmin`/`argmax` /
863/// `bincount` paths). Shares its body with `reduce_axis_general`
864/// via `reduce_axis_typed`.
865#[inline]
866pub(crate) fn reduce_axis_general_u64<T, F>(
867    data: &[T],
868    shape: &[usize],
869    axis: usize,
870    f: F,
871) -> Vec<u64>
872where
873    T: Copy,
874    F: Fn(&[T]) -> u64,
875{
876    reduce_axis_typed(data, shape, axis, f)
877}
878
879// ---------------------------------------------------------------------------
// ptp / average / mean
881// ---------------------------------------------------------------------------
882
883/// Range (peak-to-peak) of array elements over a given axis.
884///
885/// Returns `max - min`. Analogous to `numpy.ptp(a, axis=...)`.
886///
887/// # Errors
888/// Returns `FerrayError::InvalidValue` if the array is empty.
889pub fn ptp<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
890where
891    T: Element + PartialOrd + Copy + std::ops::Sub<Output = T>,
892    D: Dimension,
893{
894    if a.is_empty() {
895        return Err(FerrayError::invalid_value(
896            "cannot compute ptp of empty array",
897        ));
898    }
899    let lo = min(a, axis)?;
900    let hi = max(a, axis)?;
901    let lo_data: Vec<T> = lo.iter().copied().collect();
902    let hi_data: Vec<T> = hi.iter().copied().collect();
903    let result: Vec<T> = hi_data
904        .into_iter()
905        .zip(lo_data)
906        .map(|(h, l)| h - l)
907        .collect();
908    make_result(lo.shape(), result)
909}
910
911/// Weighted average of array elements.
912///
913/// When `weights` is `None`, equivalent to [`mean`]. When `Some(w)`, computes
914/// `sum(a * w) / sum(w)` along the given axis (or over all elements when
915/// `axis=None`). The weights array must have the same shape as `a` (the
916/// 1-D-along-axis broadcasting form supported by NumPy is intentionally
917/// omitted here — call broadcast yourself first if you need it).
918///
919/// Analogous to `numpy.average`.
920///
921/// # Errors
922/// - `FerrayError::InvalidValue` if the array is empty.
923/// - `FerrayError::ShapeMismatch` if `weights.shape() != a.shape()`.
924/// - `FerrayError::InvalidValue` if the weight sum along an axis is zero.
925pub fn average<T, D>(
926    a: &Array<T, D>,
927    weights: Option<&Array<T, D>>,
928    axis: Option<usize>,
929) -> FerrayResult<Array<T, IxDyn>>
930where
931    T: Element + Float + Send + Sync,
932    D: Dimension,
933{
934    let Some(w) = weights else {
935        return mean(a, axis);
936    };
937    if a.is_empty() {
938        return Err(FerrayError::invalid_value(
939            "cannot compute average of empty array",
940        ));
941    }
942    if a.shape() != w.shape() {
943        return Err(FerrayError::shape_mismatch(format!(
944            "average: weights shape {:?} differs from array shape {:?}",
945            w.shape(),
946            a.shape(),
947        )));
948    }
949    let a_data = borrow_data(a);
950    let w_data = borrow_data(w);
951    match axis {
952        None => {
953            let mut wsum = <T as Element>::zero();
954            let mut acc = <T as Element>::zero();
955            for (&x, &wi) in a_data.iter().zip(w_data.iter()) {
956                wsum = wsum + wi;
957                acc = acc + x * wi;
958            }
959            if wsum == <T as Element>::zero() {
960                return Err(FerrayError::invalid_value("average: weights sum to zero"));
961            }
962            make_result(&[], vec![acc / wsum])
963        }
964        Some(ax) => {
965            validate_axis(ax, a.ndim())?;
966            let shape = a.shape();
967            let out_s = output_shape(shape, ax);
968            // Walk lanes for both arrays; we can reuse reduce_axis_general
969            // by zipping data + weights into a tagged buffer. Simpler path:
970            // build a side buffer of (a_lane, w_lane) per lane.
971            let outer: usize = out_s.iter().product::<usize>().max(1);
972            let lane_len = shape[ax];
973            // Use the same lane-walk as reduce_axis_general. We replicate
974            // the inner loop here so we can read both data + weights.
975            let mut result = Vec::with_capacity(outer);
976            // Compute strides for picking out the lane.
977            let mut strides = vec![1usize; shape.len()];
978            for i in (0..shape.len() - 1).rev() {
979                strides[i] = strides[i + 1] * shape[i + 1];
980            }
981            let mut idx = vec![0usize; shape.len()];
982            for _ in 0..outer {
983                let mut wsum = <T as Element>::zero();
984                let mut acc = <T as Element>::zero();
985                for j in 0..lane_len {
986                    idx[ax] = j;
987                    let mut flat = 0usize;
988                    for (d, &s) in idx.iter().zip(strides.iter()) {
989                        flat += d * s;
990                    }
991                    let x = a_data[flat];
992                    let wi = w_data[flat];
993                    wsum = wsum + wi;
994                    acc = acc + x * wi;
995                }
996                if wsum == <T as Element>::zero() {
997                    return Err(FerrayError::invalid_value(
998                        "average: weights sum to zero along axis",
999                    ));
1000                }
1001                result.push(acc / wsum);
1002                // Advance multi-index over output dims (every dim except ax).
1003                idx[ax] = 0;
1004                for d in (0..shape.len()).rev() {
1005                    if d == ax {
1006                        continue;
1007                    }
1008                    idx[d] += 1;
1009                    if idx[d] < shape[d] {
1010                        break;
1011                    }
1012                    idx[d] = 0;
1013                }
1014            }
1015            make_result(&out_s, result)
1016        }
1017    }
1018}
1019
1020/// Mean of array elements over a given axis.
1021///
1022/// Equivalent to `numpy.mean`. The result is always floating-point.
1023pub fn mean<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
1024where
1025    T: Element + Float + Send + Sync,
1026    D: Dimension,
1027{
1028    if a.is_empty() {
1029        return Err(FerrayError::invalid_value(
1030            "cannot compute mean of empty array",
1031        ));
1032    }
1033    let data = borrow_data(a);
1034    match axis {
1035        None => {
1036            let n = T::from(data.len()).unwrap();
1037            let total = try_simd_pairwise_sum(&data)
1038                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()));
1039            make_result(&[], vec![total / n])
1040        }
1041        Some(ax) => {
1042            validate_axis(ax, a.ndim())?;
1043            let shape = a.shape();
1044            let out_s = output_shape(shape, ax);
1045            let axis_len = shape[ax];
1046            let n = T::from(axis_len).unwrap();
1047            let result = reduce_axis_general(&data, shape, ax, |lane| {
1048                let total = try_simd_pairwise_sum(lane)
1049                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
1050                total / n
1051            });
1052            make_result(&out_s, result)
1053        }
1054    }
1055}
1056
1057/// Mean of array elements, returning `f64` regardless of input type.
1058///
1059/// This works on integer arrays (i32, u64, etc.) where [`mean`] would
1060/// fail because integers don't implement `Float`. The result is always
1061/// `Array<f64, IxDyn>`, matching `NumPy`'s behavior of promoting integer
1062/// means to float64.
1063///
1064/// Equivalent to `numpy.mean` for integer inputs.
1065pub fn mean_as_f64<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<f64, IxDyn>>
1066where
1067    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
1068    D: Dimension,
1069{
1070    if a.is_empty() {
1071        return Err(FerrayError::invalid_value(
1072            "cannot compute mean of empty array",
1073        ));
1074    }
1075    match axis {
1076        None => {
1077            let n = a.size() as f64;
1078            let total: f64 = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).sum();
1079            make_result(&[], vec![total / n])
1080        }
1081        Some(ax) => {
1082            validate_axis(ax, a.ndim())?;
1083            let shape = a.shape();
1084            let out_s = output_shape(shape, ax);
1085            let axis_len = shape[ax] as f64;
1086            let f64_data: Vec<f64> = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).collect();
1087            let result = reduce_axis_general(&f64_data, shape, ax, |lane| {
1088                let total: f64 = lane.iter().sum();
1089                total / axis_len
1090            });
1091            make_result(&out_s, result)
1092        }
1093    }
1094}
1095
1096// ---------------------------------------------------------------------------
1097// var
1098// ---------------------------------------------------------------------------
1099
1100/// Variance of array elements over a given axis.
1101///
1102/// `ddof` is the delta degrees of freedom (0 for population variance, 1 for sample).
1103/// Equivalent to `numpy.var`.
1104pub fn var<T, D>(a: &Array<T, D>, axis: Option<usize>, ddof: usize) -> FerrayResult<Array<T, IxDyn>>
1105where
1106    T: Element + Float + Send + Sync,
1107    D: Dimension,
1108{
1109    if a.is_empty() {
1110        return Err(FerrayError::invalid_value(
1111            "cannot compute variance of empty array",
1112        ));
1113    }
1114    let data = borrow_data(a);
1115    match axis {
1116        None => {
1117            let n = data.len();
1118            let nf = T::from(n).unwrap();
1119            let mean_val = try_simd_pairwise_sum(&data)
1120                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()))
1121                / nf;
1122            let sum_sq = try_simd_sum_sq_diff(&data, mean_val).unwrap_or_else(|| {
1123                data.iter().copied().fold(<T as Element>::zero(), |acc, x| {
1124                    let d = x - mean_val;
1125                    acc + d * d
1126                })
1127            });
1128            // numpy/_core/_methods.py:204 — `rcount = um.maximum(rcount - ddof, 0)`,
1129            // then true_divide. When ddof >= n the divisor clamps to 0 and IEEE-754
1130            // float division yields +inf (sum_sq > 0) or NaN (sum_sq == 0), NOT an
1131            // error. Saturating-sub keeps the divisor at 0 instead of underflowing.
1132            let divisor = T::from(n.saturating_sub(ddof)).unwrap_or_else(<T as Element>::zero);
1133            let var_val = sum_sq / divisor;
1134            make_result(&[], vec![var_val])
1135        }
1136        Some(ax) => {
1137            validate_axis(ax, a.ndim())?;
1138            let shape = a.shape();
1139            let out_s = output_shape(shape, ax);
1140            let axis_len = shape[ax];
1141            let nf = T::from(axis_len).unwrap();
1142            // numpy/_core/_methods.py:204 clamps `rcount = max(axis_len - ddof, 0)`
1143            // then true-divides: ddof >= axis_len yields +inf / NaN, never an error.
1144            let denom = T::from(axis_len.saturating_sub(ddof)).unwrap_or_else(<T as Element>::zero);
1145            let result = reduce_axis_general(&data, shape, ax, |lane| {
1146                let mean_val = try_simd_pairwise_sum(lane)
1147                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1148                    / nf;
1149                let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
1150                    lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
1151                        let d = x - mean_val;
1152                        acc + d * d
1153                    })
1154                });
1155                sum_sq / denom
1156            });
1157            make_result(&out_s, result)
1158        }
1159    }
1160}
1161
1162// ---------------------------------------------------------------------------
1163// std_
1164// ---------------------------------------------------------------------------
1165
1166/// Standard deviation of array elements over a given axis.
1167///
1168/// `ddof` is the delta degrees of freedom.
1169/// Equivalent to `numpy.std`.
1170pub fn std_<T, D>(
1171    a: &Array<T, D>,
1172    axis: Option<usize>,
1173    ddof: usize,
1174) -> FerrayResult<Array<T, IxDyn>>
1175where
1176    T: Element + Float + Send + Sync,
1177    D: Dimension,
1178{
1179    let v = var(a, axis, ddof)?;
1180    let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
1181    make_result(v.shape(), data)
1182}
1183
1184/// Variance with integer-to-float64 promotion.
1185///
1186/// Sibling of [`mean_as_f64`] — accepts any `T: ToPrimitive` (including
1187/// `i64`/`i32`/`u64`/etc.) and returns an `Array<f64, IxDyn>`. Matches
1188/// `NumPy`'s behaviour of promoting integer variance to f64 (#170).
1189///
1190/// `ddof` is the delta degrees of freedom (0 for population, 1 for sample).
1191pub fn var_as_f64<T, D>(
1192    a: &Array<T, D>,
1193    axis: Option<usize>,
1194    ddof: usize,
1195) -> FerrayResult<Array<f64, IxDyn>>
1196where
1197    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
1198    D: Dimension,
1199{
1200    if a.is_empty() {
1201        return Err(FerrayError::invalid_value(
1202            "cannot compute variance of empty array",
1203        ));
1204    }
1205    // Promote each element to f64 once, then run the standard
1206    // float-typed `var` on the promoted array.
1207    let promoted: Vec<f64> = a
1208        .iter()
1209        .map(|v| {
1210            v.to_f64()
1211                .expect("ToPrimitive failed during var_as_f64 promotion")
1212        })
1213        .collect();
1214    let promoted_arr = Array::<f64, _>::from_vec(a.dim().clone(), promoted)?;
1215    var(&promoted_arr, axis, ddof)
1216}
1217
1218/// Standard deviation with integer-to-float64 promotion (#170).
1219///
1220/// Sibling of [`mean_as_f64`] / [`var_as_f64`]; sqrt of the float-promoted
1221/// variance.
1222pub fn std_as_f64<T, D>(
1223    a: &Array<T, D>,
1224    axis: Option<usize>,
1225    ddof: usize,
1226) -> FerrayResult<Array<f64, IxDyn>>
1227where
1228    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
1229    D: Dimension,
1230{
1231    let v = var_as_f64(a, axis, ddof)?;
1232    let data: Vec<f64> = v.iter().map(|x| x.sqrt()).collect();
1233    make_result(v.shape(), data)
1234}
1235
1236// ---------------------------------------------------------------------------
1237// Multi-axis + keepdims public API (issues #457 + #458)
1238//
1239// These `*_axes` variants complement the existing single-axis reductions
1240// with two additional features:
1241//   1. `axes: Option<&[usize]>` — reduce over multiple axes at once
1242//      (None or empty slice means "reduce all axes")
1243//   2. `keepdims: bool` — preserve reduced axes as size 1 so the result
1244//      can broadcast back against the original array
1245//
1246// The existing single-axis `sum/prod/etc.` functions remain unchanged.
1247// ---------------------------------------------------------------------------
1248
1249/// Multi-axis sum with optional `keepdims`.
1250///
1251/// Equivalent to `numpy.sum(a, axis=axes, keepdims=keepdims)`. If `axes` is
1252/// `None` or an empty slice, reduces over all axes.
1253///
1254/// # Errors
1255/// Returns `FerrayError::AxisOutOfBounds` on any out-of-range axis, or
1256/// `FerrayError::InvalidValue` on duplicate axes.
1257pub fn sum_axes<T, D>(
1258    a: &Array<T, D>,
1259    axes: Option<&[usize]>,
1260    keepdims: bool,
1261) -> FerrayResult<Array<T, IxDyn>>
1262where
1263    T: Element + std::ops::Add<Output = T> + Copy + Send + Sync,
1264    D: Dimension,
1265{
1266    let ax = normalize_axes(axes, a.ndim())?;
1267    let data = borrow_data(a);
1268    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1269        try_simd_pairwise_sum(lane)
1270            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1271    });
1272    make_result(&out_shape, result)
1273}
1274
1275/// Multi-axis product with optional `keepdims`.
1276pub fn prod_axes<T, D>(
1277    a: &Array<T, D>,
1278    axes: Option<&[usize]>,
1279    keepdims: bool,
1280) -> FerrayResult<Array<T, IxDyn>>
1281where
1282    T: Element + std::ops::Mul<Output = T> + Copy + Send + Sync,
1283    D: Dimension,
1284{
1285    let ax = normalize_axes(axes, a.ndim())?;
1286    let data = borrow_data(a);
1287    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1288        lane.iter()
1289            .copied()
1290            .fold(<T as Element>::one(), |acc, x| acc * x)
1291    });
1292    make_result(&out_shape, result)
1293}
1294
1295/// Multi-axis minimum with optional `keepdims`. NaN-propagating.
1296pub fn min_axes<T, D>(
1297    a: &Array<T, D>,
1298    axes: Option<&[usize]>,
1299    keepdims: bool,
1300) -> FerrayResult<Array<T, IxDyn>>
1301where
1302    T: Element + PartialOrd + Copy,
1303    D: Dimension,
1304{
1305    if a.is_empty() {
1306        return Err(FerrayError::invalid_value(
1307            "cannot compute min of empty array",
1308        ));
1309    }
1310    let ax = normalize_axes(axes, a.ndim())?;
1311    let nan_min = |a: T, b: T| -> T {
1312        // NaN propagates whether it precedes or follows the finite operand
1313        // (mirror of the scalar `min` fix, #1058). For integers `is_nan_val`
1314        // is always false, so this reduces to the plain `a <= b` order.
1315        if is_nan_val(&a) {
1316            a
1317        } else if is_nan_val(&b) {
1318            b
1319        } else if a <= b {
1320            a
1321        } else {
1322            b
1323        }
1324    };
1325    let data = borrow_data(a);
1326    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1327        lane.iter().copied().reduce(nan_min).unwrap()
1328    });
1329    make_result(&out_shape, result)
1330}
1331
1332/// Multi-axis maximum with optional `keepdims`. NaN-propagating.
1333pub fn max_axes<T, D>(
1334    a: &Array<T, D>,
1335    axes: Option<&[usize]>,
1336    keepdims: bool,
1337) -> FerrayResult<Array<T, IxDyn>>
1338where
1339    T: Element + PartialOrd + Copy,
1340    D: Dimension,
1341{
1342    if a.is_empty() {
1343        return Err(FerrayError::invalid_value(
1344            "cannot compute max of empty array",
1345        ));
1346    }
1347    let ax = normalize_axes(axes, a.ndim())?;
1348    let nan_max = |a: T, b: T| -> T {
1349        // NaN propagates whether it precedes or follows the finite operand
1350        // (mirror of the scalar `max` fix, #1058). For integers `is_nan_val`
1351        // is always false, so this reduces to the plain `a >= b` order.
1352        if is_nan_val(&a) {
1353            a
1354        } else if is_nan_val(&b) {
1355            b
1356        } else if a >= b {
1357            a
1358        } else {
1359            b
1360        }
1361    };
1362    let data = borrow_data(a);
1363    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1364        lane.iter().copied().reduce(nan_max).unwrap()
1365    });
1366    make_result(&out_shape, result)
1367}
1368
1369/// Multi-axis mean with optional `keepdims`.
1370pub fn mean_axes<T, D>(
1371    a: &Array<T, D>,
1372    axes: Option<&[usize]>,
1373    keepdims: bool,
1374) -> FerrayResult<Array<T, IxDyn>>
1375where
1376    T: Element + Float + Send + Sync,
1377    D: Dimension,
1378{
1379    if a.is_empty() {
1380        return Err(FerrayError::invalid_value(
1381            "cannot compute mean of empty array",
1382        ));
1383    }
1384    let ax = normalize_axes(axes, a.ndim())?;
1385    // Compute n = product of reduced-axis lengths.
1386    let shape = a.shape();
1387    let lane_len: usize = ax.iter().map(|&i| shape[i]).product();
1388    let n = T::from(lane_len).unwrap();
1389    let data = borrow_data(a);
1390    let (result, out_shape) = reduce_axes_general(&data, shape, &ax, keepdims, |lane| {
1391        let total = try_simd_pairwise_sum(lane)
1392            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
1393        total / n
1394    });
1395    make_result(&out_shape, result)
1396}
1397
1398/// Multi-axis variance with optional `keepdims` and Bessel correction `ddof`.
1399pub fn var_axes<T, D>(
1400    a: &Array<T, D>,
1401    axes: Option<&[usize]>,
1402    ddof: usize,
1403    keepdims: bool,
1404) -> FerrayResult<Array<T, IxDyn>>
1405where
1406    T: Element + Float + Send + Sync,
1407    D: Dimension,
1408{
1409    if a.is_empty() {
1410        return Err(FerrayError::invalid_value(
1411            "cannot compute variance of empty array",
1412        ));
1413    }
1414    let ax = normalize_axes(axes, a.ndim())?;
1415    let shape = a.shape();
1416    let lane_len: usize = ax.iter().map(|&i| shape[i]).product();
1417    let nf = T::from(lane_len).unwrap_or_else(<T as Element>::zero);
1418    // numpy/_core/_methods.py:204 clamps `rcount = um.maximum(rcount - ddof, 0)`
1419    // then true-divides (:208): ddof >= reduced-axis length yields +inf
1420    // (sum_sq > 0) / NaN (sum_sq == 0), never an error. Mirrors the var/var_into
1421    // fixes in this file (saturating_sub + IEEE-754 division).
1422    let denom = T::from(lane_len.saturating_sub(ddof)).unwrap_or_else(<T as Element>::zero);
1423    let data = borrow_data(a);
1424    let (result, out_shape) = reduce_axes_general(&data, shape, &ax, keepdims, |lane| {
1425        let mean_val = try_simd_pairwise_sum(lane)
1426            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1427            / nf;
1428        let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
1429            lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
1430                let d = x - mean_val;
1431                acc + d * d
1432            })
1433        });
1434        sum_sq / denom
1435    });
1436    make_result(&out_shape, result)
1437}
1438
1439/// Multi-axis standard deviation with optional `keepdims`.
1440pub fn std_axes<T, D>(
1441    a: &Array<T, D>,
1442    axes: Option<&[usize]>,
1443    ddof: usize,
1444    keepdims: bool,
1445) -> FerrayResult<Array<T, IxDyn>>
1446where
1447    T: Element + Float + Send + Sync,
1448    D: Dimension,
1449{
1450    let v = var_axes(a, axes, ddof, keepdims)?;
1451    let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
1452    make_result(v.shape(), data)
1453}
1454
1455// ---------------------------------------------------------------------------
1456// `*_into` reductions: write into a caller-provided destination
1457//
1458// NumPy reductions accept an `out=` parameter that lets callers reuse a
1459// pre-allocated output buffer across repeated reductions on same-shaped
1460// data — common in streaming statistics, online ML metrics, etc. The
1461// ferray-stats reduction functions all allocate a fresh `Array<T, IxDyn>`
1462// internally, so we add `*_into` companion functions that take a
1463// `&mut Array<T, IxDyn>` destination and skip the allocation entirely.
1464//
1465// History:
1466//   #467 introduced the `*_into` API surface but the first implementation
1467//   still went through the allocating kernel and copied into `out`,
1468//   leaving one Vec materialization per call. #563 plumbs the destination
1469//   slice through the kernel itself via `reduce_axis_typed_into` so the
1470//   path is truly zero-alloc — only the per-lane scratch buffer
1471//   (allocated once per call, reused for every lane) and the input
1472//   contig-borrow allocation (only when the input is already non-contig)
1473//   remain.
1474// ---------------------------------------------------------------------------
1475
1476/// Sum reduction writing into a pre-allocated destination.
1477///
1478/// Equivalent to `np.sum(a, axis=axis, out=out)`. The destination must
1479/// be C-contiguous and have exactly the shape that `sum(a, axis)` would
1480/// produce; broadcasting is not supported.
1481///
1482/// # Errors
1483/// - `FerrayError::AxisOutOfBounds` if `axis` is out of range.
1484/// - `FerrayError::ShapeMismatch` if `out.shape()` does not match the
1485///   expected reduction shape.
1486/// - `FerrayError::InvalidValue` if `out` is not C-contiguous.
1487pub fn sum_into<T, D>(
1488    a: &Array<T, D>,
1489    axis: Option<usize>,
1490    out: &mut Array<T, IxDyn>,
1491) -> FerrayResult<()>
1492where
1493    T: Element + std::ops::Add<Output = T> + Copy + Send + Sync,
1494    D: Dimension,
1495{
1496    let data = borrow_data(a);
1497    match axis {
1498        None => {
1499            let dst = check_out_shape(out, &[], "sum")?;
1500            let total = try_simd_pairwise_sum(&data)
1501                .unwrap_or_else(|| parallel::parallel_sum(&data, <T as Element>::zero()));
1502            dst[0] = total;
1503            Ok(())
1504        }
1505        Some(ax) => {
1506            validate_axis(ax, a.ndim())?;
1507            let shape_vec = a.shape().to_vec();
1508            let out_s = output_shape(&shape_vec, ax);
1509            let dst = check_out_shape(out, &out_s, "sum")?;
1510            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1511                try_simd_pairwise_sum(lane)
1512                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1513            });
1514            Ok(())
1515        }
1516    }
1517}
1518
1519/// Product reduction writing into a pre-allocated destination.
1520///
1521/// See [`sum_into`] for the contract on `out`.
1522pub fn prod_into<T, D>(
1523    a: &Array<T, D>,
1524    axis: Option<usize>,
1525    out: &mut Array<T, IxDyn>,
1526) -> FerrayResult<()>
1527where
1528    T: Element + std::ops::Mul<Output = T> + Copy + Send + Sync,
1529    D: Dimension,
1530{
1531    let data = borrow_data(a);
1532    match axis {
1533        None => {
1534            let dst = check_out_shape(out, &[], "prod")?;
1535            dst[0] = parallel::parallel_prod(&data, <T as Element>::one());
1536            Ok(())
1537        }
1538        Some(ax) => {
1539            validate_axis(ax, a.ndim())?;
1540            let shape_vec = a.shape().to_vec();
1541            let out_s = output_shape(&shape_vec, ax);
1542            let dst = check_out_shape(out, &out_s, "prod")?;
1543            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1544                lane.iter()
1545                    .copied()
1546                    .fold(<T as Element>::one(), |acc, x| acc * x)
1547            });
1548            Ok(())
1549        }
1550    }
1551}
1552
1553/// Min reduction writing into a pre-allocated destination.
1554///
1555/// See [`sum_into`] for the contract on `out`. Empty input arrays return
1556/// `FerrayError::InvalidValue` because `min` of an empty set is undefined.
1557pub fn min_into<T, D>(
1558    a: &Array<T, D>,
1559    axis: Option<usize>,
1560    out: &mut Array<T, IxDyn>,
1561) -> FerrayResult<()>
1562where
1563    T: Element + PartialOrd + Copy + Send + Sync,
1564    D: Dimension,
1565{
1566    if a.is_empty() {
1567        return Err(FerrayError::invalid_value(
1568            "cannot compute min of empty array",
1569        ));
1570    }
1571    // Same NaN-propagating reducer as `min` so the two paths agree.
1572    let nan_min = |a: T, b: T| -> T {
1573        if a <= b {
1574            a
1575        } else if a > b {
1576            b
1577        } else {
1578            a
1579        }
1580    };
1581    let data = borrow_data(a);
1582    match axis {
1583        None => {
1584            let dst = check_out_shape(out, &[], "min")?;
1585            dst[0] = data.iter().copied().reduce(nan_min).unwrap();
1586            Ok(())
1587        }
1588        Some(ax) => {
1589            validate_axis(ax, a.ndim())?;
1590            let shape_vec = a.shape().to_vec();
1591            let out_s = output_shape(&shape_vec, ax);
1592            let dst = check_out_shape(out, &out_s, "min")?;
1593            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1594                lane.iter().copied().reduce(nan_min).unwrap()
1595            });
1596            Ok(())
1597        }
1598    }
1599}
1600
1601/// Max reduction writing into a pre-allocated destination.
1602///
1603/// See [`sum_into`] for the contract on `out`. Empty input arrays return
1604/// `FerrayError::InvalidValue`.
1605pub fn max_into<T, D>(
1606    a: &Array<T, D>,
1607    axis: Option<usize>,
1608    out: &mut Array<T, IxDyn>,
1609) -> FerrayResult<()>
1610where
1611    T: Element + PartialOrd + Copy + Send + Sync,
1612    D: Dimension,
1613{
1614    if a.is_empty() {
1615        return Err(FerrayError::invalid_value(
1616            "cannot compute max of empty array",
1617        ));
1618    }
1619    let nan_max = |a: T, b: T| -> T {
1620        if a >= b {
1621            a
1622        } else if a < b {
1623            b
1624        } else {
1625            a
1626        }
1627    };
1628    let data = borrow_data(a);
1629    match axis {
1630        None => {
1631            let dst = check_out_shape(out, &[], "max")?;
1632            dst[0] = data.iter().copied().reduce(nan_max).unwrap();
1633            Ok(())
1634        }
1635        Some(ax) => {
1636            validate_axis(ax, a.ndim())?;
1637            let shape_vec = a.shape().to_vec();
1638            let out_s = output_shape(&shape_vec, ax);
1639            let dst = check_out_shape(out, &out_s, "max")?;
1640            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1641                lane.iter().copied().reduce(nan_max).unwrap()
1642            });
1643            Ok(())
1644        }
1645    }
1646}
1647
1648/// Mean reduction writing into a pre-allocated destination.
1649///
1650/// See [`sum_into`] for the contract on `out`. Empty inputs return
1651/// `FerrayError::InvalidValue`.
1652pub fn mean_into<T, D>(
1653    a: &Array<T, D>,
1654    axis: Option<usize>,
1655    out: &mut Array<T, IxDyn>,
1656) -> FerrayResult<()>
1657where
1658    T: Element + Float + Send + Sync,
1659    D: Dimension,
1660{
1661    if a.is_empty() {
1662        return Err(FerrayError::invalid_value(
1663            "cannot compute mean of empty array",
1664        ));
1665    }
1666    let data = borrow_data(a);
1667    match axis {
1668        None => {
1669            let dst = check_out_shape(out, &[], "mean")?;
1670            let n = T::from(data.len()).unwrap();
1671            let total = try_simd_pairwise_sum(&data)
1672                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()));
1673            dst[0] = total / n;
1674            Ok(())
1675        }
1676        Some(ax) => {
1677            validate_axis(ax, a.ndim())?;
1678            let shape_vec = a.shape().to_vec();
1679            let out_s = output_shape(&shape_vec, ax);
1680            let axis_len = shape_vec[ax];
1681            let n = T::from(axis_len).unwrap();
1682            let dst = check_out_shape(out, &out_s, "mean")?;
1683            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1684                let total = try_simd_pairwise_sum(lane)
1685                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
1686                total / n
1687            });
1688            Ok(())
1689        }
1690    }
1691}
1692
1693/// Variance reduction writing into a pre-allocated destination.
1694///
1695/// `ddof` is the delta degrees of freedom. See [`sum_into`] for the
1696/// contract on `out`. Returns `FerrayError::InvalidValue` for empty
1697/// inputs. When `ddof >= n` the divisor clamps to 0 and the result is
1698/// `+inf` (or `NaN` when the sum-of-squares is 0), matching
1699/// `numpy.var` (`numpy/_core/_methods.py:204`); no error is raised.
1700pub fn var_into<T, D>(
1701    a: &Array<T, D>,
1702    axis: Option<usize>,
1703    ddof: usize,
1704    out: &mut Array<T, IxDyn>,
1705) -> FerrayResult<()>
1706where
1707    T: Element + Float + Send + Sync,
1708    D: Dimension,
1709{
1710    if a.is_empty() {
1711        return Err(FerrayError::invalid_value(
1712            "cannot compute variance of empty array",
1713        ));
1714    }
1715    let data = borrow_data(a);
1716    match axis {
1717        None => {
1718            let n = data.len();
1719            let dst = check_out_shape(out, &[], "var")?;
1720            let nf = T::from(n).unwrap();
1721            let mean_val = try_simd_pairwise_sum(&data)
1722                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()))
1723                / nf;
1724            let sum_sq = try_simd_sum_sq_diff(&data, mean_val).unwrap_or_else(|| {
1725                data.iter().copied().fold(<T as Element>::zero(), |acc, x| {
1726                    let d = x - mean_val;
1727                    acc + d * d
1728                })
1729            });
1730            // numpy/_core/_methods.py:204 — `rcount = um.maximum(rcount - ddof, 0)`,
1731            // then true_divide. When ddof >= n the divisor clamps to 0 and IEEE-754
1732            // float division yields +inf (sum_sq > 0) or NaN (sum_sq == 0), NOT an
1733            // error. Saturating-sub keeps the divisor at 0 instead of underflowing.
1734            let divisor = T::from(n.saturating_sub(ddof)).unwrap_or_else(<T as Element>::zero);
1735            dst[0] = sum_sq / divisor;
1736            Ok(())
1737        }
1738        Some(ax) => {
1739            validate_axis(ax, a.ndim())?;
1740            let shape_vec = a.shape().to_vec();
1741            let axis_len = shape_vec[ax];
1742            let out_s = output_shape(&shape_vec, ax);
1743            let nf = T::from(axis_len).unwrap();
1744            // numpy/_core/_methods.py:204 clamps `rcount = max(axis_len - ddof, 0)`
1745            // then true-divides: ddof >= axis_len yields +inf / NaN, never an error.
1746            let denom = T::from(axis_len.saturating_sub(ddof)).unwrap_or_else(<T as Element>::zero);
1747            let dst = check_out_shape(out, &out_s, "var")?;
1748            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1749                let mean_val = try_simd_pairwise_sum(lane)
1750                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1751                    / nf;
1752                let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
1753                    lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
1754                        let d = x - mean_val;
1755                        acc + d * d
1756                    })
1757                });
1758                sum_sq / denom
1759            });
1760            Ok(())
1761        }
1762    }
1763}
1764
1765/// Standard deviation reduction writing into a pre-allocated destination.
1766///
1767/// Computes the variance directly into `out` and then takes the
1768/// element-wise square root in place — the std value never lives in any
1769/// intermediate buffer.
1770pub fn std_into<T, D>(
1771    a: &Array<T, D>,
1772    axis: Option<usize>,
1773    ddof: usize,
1774    out: &mut Array<T, IxDyn>,
1775) -> FerrayResult<()>
1776where
1777    T: Element + Float + Send + Sync,
1778    D: Dimension,
1779{
1780    var_into(a, axis, ddof, out)?;
1781    // Variance is now in `out`; sqrt each element in place.
1782    let dst = out
1783        .as_slice_mut()
1784        .ok_or_else(|| FerrayError::invalid_value("std_into: out must be C-contiguous"))?;
1785    for slot in dst.iter_mut() {
1786        *slot = slot.sqrt();
1787    }
1788    Ok(())
1789}
1790
1791// ---------------------------------------------------------------------------
1792// `*_with` reductions: NumPy `initial=` and `where=` parameters (#459)
1793//
1794// NumPy reductions accept two extra parameters that ferray-stats was
1795// missing entirely:
1796//
1797//   - `initial`: starting value for the accumulator. Lets `np.sum([])`
1798//     return a non-zero seed, lets `np.min(arr, initial=999)` provide
1799//     a fallback for empty inputs, and lets cumulative-style code
1800//     thread an initial state through a reduction.
1801//
1802//   - `where`: same-shape boolean mask. Only positions where the mask
1803//     is `true` contribute to the reduction. This is the masked-array
1804//     equivalent without materializing a MaskedArray — `np.sum(arr,
1805//     where=arr > 0)` is the canonical "sum the positives" idiom.
1806//
// The `*_with` family adds these as `Option<T>` / `Option<&Array<bool,
// IxDyn>>` parameters. The mask may have any shape that is
// broadcast-compatible with the input (#565): same-shape masks — e.g.
// predicates produced by comparison ufuncs against the same input — are
// copied directly, and any other compatible mask is broadcast to the
// input's shape before the reduction runs.
1813// ---------------------------------------------------------------------------
1814
1815/// Prepare a `where` mask for use by a `*_with` reduction kernel.
1816///
1817/// Accepts an `Option<&Array<bool, IxDyn>>` so the mask can have any
1818/// rank independently of the input — callers with typed masks should
1819/// call `.to_dyn()` before passing in.
1820///
1821/// Returns:
1822/// - `Ok(None)` if the caller passed no mask.
1823/// - `Ok(Some(vec))` where `vec` is a row-major materialization of the
1824///   mask aligned with `a`'s flat logical order. For same-shape masks
1825///   this is a direct copy; for broadcast-compatible masks the mask is
1826///   broadcast into `a.shape()` first via
1827///   [`ferray_core::dimension::broadcast::broadcast_to`] (#565) and
1828///   then materialized.
1829/// - `Err(FerrayError::ShapeMismatch)` if the mask is not
1830///   broadcast-compatible with `a.shape()`.
1831fn prepare_where_mask<T, D>(
1832    a: &Array<T, D>,
1833    mask: Option<&Array<bool, IxDyn>>,
1834    op_name: &str,
1835) -> FerrayResult<Option<Vec<bool>>>
1836where
1837    T: Element,
1838    D: Dimension,
1839{
1840    use ferray_core::dimension::broadcast::broadcast_to;
1841
1842    let Some(m) = mask else {
1843        return Ok(None);
1844    };
1845
1846    // Fast path: the mask is already the same shape as the input.
1847    if m.shape() == a.shape() {
1848        return Ok(Some(m.iter().copied().collect()));
1849    }
1850
1851    // Broadcast path: let the core machinery check compatibility and
1852    // produce a stride-tricked view at the target shape. We then
1853    // materialize it into a flat Vec<bool> aligned with `a`'s logical
1854    // row-major order.
1855    let view = broadcast_to(m, a.shape()).map_err(|_| {
1856        FerrayError::shape_mismatch(format!(
1857            "{op_name}: where mask shape {:?} is not broadcast-compatible with array shape {:?}",
1858            m.shape(),
1859            a.shape()
1860        ))
1861    })?;
1862    Ok(Some(view.iter().copied().collect()))
1863}
1864
1865/// Walk a single axis-Some output position, gathering masked-true lane
1866/// values into `lane_buf` (cleared first). Used by every `*_with` axis
1867/// path so the lane-gather logic exists in one place.
1868///
1869/// `out_multi` is the output multi-index excluding `axis`. The function
1870/// reconstructs the corresponding input flat positions and consults the
1871/// mask in lockstep.
1872fn gather_lane_with_mask<T: Copy>(
1873    data: &[T],
1874    mask: Option<&[bool]>,
1875    shape: &[usize],
1876    strides: &[usize],
1877    axis: usize,
1878    out_multi: &[usize],
1879    lane_buf: &mut Vec<T>,
1880) {
1881    let ndim = shape.len();
1882    let axis_len = shape[axis];
1883    let mut in_multi = vec![0usize; ndim];
1884    let mut out_dim = 0;
1885    for (d, idx) in in_multi.iter_mut().enumerate() {
1886        if d == axis {
1887            *idx = 0;
1888        } else {
1889            *idx = out_multi[out_dim];
1890            out_dim += 1;
1891        }
1892    }
1893    lane_buf.clear();
1894    for k in 0..axis_len {
1895        in_multi[axis] = k;
1896        let idx = flat_index(&in_multi, strides);
1897        match mask {
1898            Some(m) if !m[idx] => {}
1899            _ => lane_buf.push(data[idx]),
1900        }
1901    }
1902}
1903
1904/// Sum reduction with `initial` and `where` parameters.
1905///
1906/// Equivalent to `np.sum(a, axis=axis, initial=initial, where=where_mask)`.
1907///
1908/// `initial` defaults to `T::zero()` when `None`. `where_mask`, when
1909/// provided, must be broadcast-compatible with `a.shape()` (#565).
1910/// Callers with typed masks (e.g. `Array<bool, Ix1>`) should call
1911/// `.to_dyn()` before passing. Positions where the broadcast mask is
1912/// `false` are skipped.
1913///
1914/// # Errors
1915/// - `FerrayError::AxisOutOfBounds` if `axis` is out of range.
1916/// - `FerrayError::ShapeMismatch` if `where_mask` is not
1917///   broadcast-compatible with `a.shape()`.
1918pub fn sum_with<T, D>(
1919    a: &Array<T, D>,
1920    axis: Option<usize>,
1921    initial: Option<T>,
1922    where_mask: Option<&Array<bool, IxDyn>>,
1923) -> FerrayResult<Array<T, IxDyn>>
1924where
1925    T: Element + std::ops::Add<Output = T> + Copy,
1926    D: Dimension,
1927{
1928    let init = initial.unwrap_or_else(<T as Element>::zero);
1929    let data = borrow_data(a);
1930    let mask_vec = prepare_where_mask(a, where_mask, "sum")?;
1931    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
1932
1933    match axis {
1934        None => {
1935            let total = match mask_slice {
1936                None => data.iter().copied().fold(init, |acc, x| acc + x),
1937                Some(mask) => data
1938                    .iter()
1939                    .zip(mask.iter())
1940                    .filter(|&(_, &m)| m)
1941                    .fold(init, |acc, (&x, _)| acc + x),
1942            };
1943            make_result(&[], vec![total])
1944        }
1945        Some(ax) => {
1946            validate_axis(ax, a.ndim())?;
1947            let shape_vec = a.shape().to_vec();
1948            let out_s = output_shape(&shape_vec, ax);
1949            let strides = compute_strides(&shape_vec);
1950
1951            let out_size: usize = if out_s.is_empty() {
1952                1
1953            } else {
1954                out_s.iter().product()
1955            };
1956            let mut result: Vec<T> = Vec::with_capacity(out_size);
1957            let mut out_multi = vec![0usize; out_s.len()];
1958            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
1959
1960            for _ in 0..out_size {
1961                gather_lane_with_mask(
1962                    &data,
1963                    mask_slice,
1964                    &shape_vec,
1965                    &strides,
1966                    ax,
1967                    &out_multi,
1968                    &mut lane_buf,
1969                );
1970                let lane_sum = lane_buf.iter().copied().fold(init, |acc, x| acc + x);
1971                result.push(lane_sum);
1972                if !out_s.is_empty() {
1973                    increment_multi_index(&mut out_multi, &out_s);
1974                }
1975            }
1976            make_result(&out_s, result)
1977        }
1978    }
1979}
1980
1981/// Product reduction with `initial` and `where` parameters.
1982///
1983/// Equivalent to `np.prod(a, axis=axis, initial=initial, where=where_mask)`.
1984/// `initial` defaults to `T::one()` when `None`. `where_mask` is
1985/// broadcast-compatible with `a.shape()` (#565).
1986pub fn prod_with<T, D>(
1987    a: &Array<T, D>,
1988    axis: Option<usize>,
1989    initial: Option<T>,
1990    where_mask: Option<&Array<bool, IxDyn>>,
1991) -> FerrayResult<Array<T, IxDyn>>
1992where
1993    T: Element + std::ops::Mul<Output = T> + Copy,
1994    D: Dimension,
1995{
1996    let init = initial.unwrap_or_else(<T as Element>::one);
1997    let data = borrow_data(a);
1998    let mask_vec = prepare_where_mask(a, where_mask, "prod")?;
1999    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2000
2001    match axis {
2002        None => {
2003            let total = match mask_slice {
2004                None => data.iter().copied().fold(init, |acc, x| acc * x),
2005                Some(mask) => data
2006                    .iter()
2007                    .zip(mask.iter())
2008                    .filter(|&(_, &m)| m)
2009                    .fold(init, |acc, (&x, _)| acc * x),
2010            };
2011            make_result(&[], vec![total])
2012        }
2013        Some(ax) => {
2014            validate_axis(ax, a.ndim())?;
2015            let shape_vec = a.shape().to_vec();
2016            let out_s = output_shape(&shape_vec, ax);
2017            let strides = compute_strides(&shape_vec);
2018
2019            let out_size: usize = if out_s.is_empty() {
2020                1
2021            } else {
2022                out_s.iter().product()
2023            };
2024            let mut result: Vec<T> = Vec::with_capacity(out_size);
2025            let mut out_multi = vec![0usize; out_s.len()];
2026            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2027
2028            for _ in 0..out_size {
2029                gather_lane_with_mask(
2030                    &data,
2031                    mask_slice,
2032                    &shape_vec,
2033                    &strides,
2034                    ax,
2035                    &out_multi,
2036                    &mut lane_buf,
2037                );
2038                let lane_prod = lane_buf.iter().copied().fold(init, |acc, x| acc * x);
2039                result.push(lane_prod);
2040                if !out_s.is_empty() {
2041                    increment_multi_index(&mut out_multi, &out_s);
2042                }
2043            }
2044            make_result(&out_s, result)
2045        }
2046    }
2047}
2048
2049/// Min reduction with `initial` and `where` parameters.
2050///
2051/// Equivalent to `np.min(a, axis=axis, initial=initial, where=where_mask)`.
2052/// Unlike plain `min`, an empty input (or a fully-masked-out lane) is
2053/// allowed when `initial` is supplied — the result is `initial`. Without
2054/// `initial`, an empty lane is an error. `where_mask` is
2055/// broadcast-compatible with `a.shape()` (#565).
2056pub fn min_with<T, D>(
2057    a: &Array<T, D>,
2058    axis: Option<usize>,
2059    initial: Option<T>,
2060    where_mask: Option<&Array<bool, IxDyn>>,
2061) -> FerrayResult<Array<T, IxDyn>>
2062where
2063    T: Element + PartialOrd + Copy,
2064    D: Dimension,
2065{
2066    let nan_min = |a: T, b: T| -> T {
2067        if a <= b {
2068            a
2069        } else if a > b {
2070            b
2071        } else {
2072            a
2073        }
2074    };
2075    let data = borrow_data(a);
2076    let mask_vec = prepare_where_mask(a, where_mask, "min")?;
2077    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2078
2079    let lane_min = |lane: &[T], initial: Option<T>| -> FerrayResult<T> {
2080        let mut iter = lane.iter().copied();
2081        let seed = match initial {
2082            Some(v) => Some(v),
2083            None => iter.next(),
2084        };
2085        match seed {
2086            Some(s) => Ok(iter.fold(s, nan_min)),
2087            None => Err(FerrayError::invalid_value(
2088                "min: empty lane and no initial value",
2089            )),
2090        }
2091    };
2092
2093    match axis {
2094        None => {
2095            let total = match mask_slice {
2096                None => lane_min(&data, initial)?,
2097                Some(mask) => {
2098                    let filtered: Vec<T> = data
2099                        .iter()
2100                        .copied()
2101                        .zip(mask.iter().copied())
2102                        .filter_map(|(x, m)| if m { Some(x) } else { None })
2103                        .collect();
2104                    lane_min(&filtered, initial)?
2105                }
2106            };
2107            make_result(&[], vec![total])
2108        }
2109        Some(ax) => {
2110            validate_axis(ax, a.ndim())?;
2111            let shape_vec = a.shape().to_vec();
2112            let out_s = output_shape(&shape_vec, ax);
2113            let strides = compute_strides(&shape_vec);
2114
2115            let out_size: usize = if out_s.is_empty() {
2116                1
2117            } else {
2118                out_s.iter().product()
2119            };
2120            let mut result: Vec<T> = Vec::with_capacity(out_size);
2121            let mut out_multi = vec![0usize; out_s.len()];
2122            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2123
2124            for _ in 0..out_size {
2125                gather_lane_with_mask(
2126                    &data,
2127                    mask_slice,
2128                    &shape_vec,
2129                    &strides,
2130                    ax,
2131                    &out_multi,
2132                    &mut lane_buf,
2133                );
2134                result.push(lane_min(&lane_buf, initial)?);
2135                if !out_s.is_empty() {
2136                    increment_multi_index(&mut out_multi, &out_s);
2137                }
2138            }
2139            make_result(&out_s, result)
2140        }
2141    }
2142}
2143
2144/// Max reduction with `initial` and `where` parameters.
2145///
2146/// Symmetric to [`min_with`]. `where_mask` is broadcast-compatible with
2147/// `a.shape()` (#565).
2148pub fn max_with<T, D>(
2149    a: &Array<T, D>,
2150    axis: Option<usize>,
2151    initial: Option<T>,
2152    where_mask: Option<&Array<bool, IxDyn>>,
2153) -> FerrayResult<Array<T, IxDyn>>
2154where
2155    T: Element + PartialOrd + Copy,
2156    D: Dimension,
2157{
2158    let nan_max = |a: T, b: T| -> T {
2159        if a >= b {
2160            a
2161        } else if a < b {
2162            b
2163        } else {
2164            a
2165        }
2166    };
2167    let data = borrow_data(a);
2168    let mask_vec = prepare_where_mask(a, where_mask, "max")?;
2169    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2170
2171    let lane_max = |lane: &[T], initial: Option<T>| -> FerrayResult<T> {
2172        let mut iter = lane.iter().copied();
2173        let seed = match initial {
2174            Some(v) => Some(v),
2175            None => iter.next(),
2176        };
2177        match seed {
2178            Some(s) => Ok(iter.fold(s, nan_max)),
2179            None => Err(FerrayError::invalid_value(
2180                "max: empty lane and no initial value",
2181            )),
2182        }
2183    };
2184
2185    match axis {
2186        None => {
2187            let total = match mask_slice {
2188                None => lane_max(&data, initial)?,
2189                Some(mask) => {
2190                    let filtered: Vec<T> = data
2191                        .iter()
2192                        .copied()
2193                        .zip(mask.iter().copied())
2194                        .filter_map(|(x, m)| if m { Some(x) } else { None })
2195                        .collect();
2196                    lane_max(&filtered, initial)?
2197                }
2198            };
2199            make_result(&[], vec![total])
2200        }
2201        Some(ax) => {
2202            validate_axis(ax, a.ndim())?;
2203            let shape_vec = a.shape().to_vec();
2204            let out_s = output_shape(&shape_vec, ax);
2205            let strides = compute_strides(&shape_vec);
2206
2207            let out_size: usize = if out_s.is_empty() {
2208                1
2209            } else {
2210                out_s.iter().product()
2211            };
2212            let mut result: Vec<T> = Vec::with_capacity(out_size);
2213            let mut out_multi = vec![0usize; out_s.len()];
2214            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2215
2216            for _ in 0..out_size {
2217                gather_lane_with_mask(
2218                    &data,
2219                    mask_slice,
2220                    &shape_vec,
2221                    &strides,
2222                    ax,
2223                    &out_multi,
2224                    &mut lane_buf,
2225                );
2226                result.push(lane_max(&lane_buf, initial)?);
2227                if !out_s.is_empty() {
2228                    increment_multi_index(&mut out_multi, &out_s);
2229                }
2230            }
2231            make_result(&out_s, result)
2232        }
2233    }
2234}
2235
2236/// Mean reduction with a `where` mask.
2237///
2238/// Equivalent to `np.mean(a, axis=axis, where=where_mask)`. The divisor
2239/// is the count of `true` positions in the (broadcast) mask, NOT the
2240/// lane length — fully-masked-out lanes return `T::nan()` (matching
2241/// `NumPy`'s "`RuntimeWarning`: Mean of empty slice" behavior, but without
2242/// the warning machinery). `initial` is intentionally not modeled
2243/// because the divisor for an "initial-bumped" mean is ambiguous in
2244/// `NumPy` too. `where_mask` is broadcast-compatible with `a.shape()`
2245/// (#565).
2246pub fn mean_where<T, D>(
2247    a: &Array<T, D>,
2248    axis: Option<usize>,
2249    where_mask: Option<&Array<bool, IxDyn>>,
2250) -> FerrayResult<Array<T, IxDyn>>
2251where
2252    T: Element + Float,
2253    D: Dimension,
2254{
2255    let data = borrow_data(a);
2256    let mask_vec = prepare_where_mask(a, where_mask, "mean")?;
2257    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2258
2259    match axis {
2260        None => {
2261            let (sum, count) = match mask_slice {
2262                None => {
2263                    let total = data
2264                        .iter()
2265                        .copied()
2266                        .fold(<T as Element>::zero(), |acc, x| acc + x);
2267                    (total, data.len())
2268                }
2269                Some(mask) => {
2270                    let mut s = <T as Element>::zero();
2271                    let mut c = 0usize;
2272                    for (&x, &m) in data.iter().zip(mask.iter()) {
2273                        if m {
2274                            s = s + x;
2275                            c += 1;
2276                        }
2277                    }
2278                    (s, c)
2279                }
2280            };
2281            let result = if count == 0 {
2282                <T as Float>::nan()
2283            } else {
2284                sum / T::from(count).unwrap()
2285            };
2286            make_result(&[], vec![result])
2287        }
2288        Some(ax) => {
2289            validate_axis(ax, a.ndim())?;
2290            let shape_vec = a.shape().to_vec();
2291            let out_s = output_shape(&shape_vec, ax);
2292            let strides = compute_strides(&shape_vec);
2293
2294            let out_size: usize = if out_s.is_empty() {
2295                1
2296            } else {
2297                out_s.iter().product()
2298            };
2299            let mut result: Vec<T> = Vec::with_capacity(out_size);
2300            let mut out_multi = vec![0usize; out_s.len()];
2301            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2302
2303            for _ in 0..out_size {
2304                gather_lane_with_mask(
2305                    &data,
2306                    mask_slice,
2307                    &shape_vec,
2308                    &strides,
2309                    ax,
2310                    &out_multi,
2311                    &mut lane_buf,
2312                );
2313                let lane_mean = if lane_buf.is_empty() {
2314                    <T as Float>::nan()
2315                } else {
2316                    let s = lane_buf
2317                        .iter()
2318                        .copied()
2319                        .fold(<T as Element>::zero(), |acc, x| acc + x);
2320                    s / T::from(lane_buf.len()).unwrap()
2321                };
2322                result.push(lane_mean);
2323                if !out_s.is_empty() {
2324                    increment_multi_index(&mut out_multi, &out_s);
2325                }
2326            }
2327            make_result(&out_s, result)
2328        }
2329    }
2330}
2331
2332/// Single-axis argmin with optional `keepdims`.
2333///
2334/// `NumPy`'s `argmin` only accepts a single axis (or `None`); this mirrors
2335/// that constraint. `keepdims` preserves the reduced axis as size 1.
2336pub fn argmin_keepdims<T, D>(
2337    a: &Array<T, D>,
2338    axis: Option<usize>,
2339    keepdims: bool,
2340) -> FerrayResult<Array<u64, IxDyn>>
2341where
2342    T: Element + PartialOrd + Copy,
2343    D: Dimension,
2344{
2345    if a.is_empty() {
2346        return Err(FerrayError::invalid_value(
2347            "cannot compute argmin of empty array",
2348        ));
2349    }
2350    let ndim = a.ndim();
2351    let ax_vec: Vec<usize> = match axis {
2352        None => (0..ndim).collect(),
2353        Some(ax) => {
2354            if ax >= ndim {
2355                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
2356            }
2357            vec![ax]
2358        }
2359    };
2360    let shape = a.shape();
2361    let data = borrow_data(a);
2362    let (result_f64, out_shape) =
2363        reduce_axes_general_u64(&data, shape, &ax_vec, keepdims, |lane| {
2364            lane.iter().enumerate().reduce(argmin_fold).unwrap().0 as u64
2365        });
2366    make_result(&out_shape, result_f64)
2367}
2368
2369/// Single-axis argmax with optional `keepdims`.
2370pub fn argmax_keepdims<T, D>(
2371    a: &Array<T, D>,
2372    axis: Option<usize>,
2373    keepdims: bool,
2374) -> FerrayResult<Array<u64, IxDyn>>
2375where
2376    T: Element + PartialOrd + Copy,
2377    D: Dimension,
2378{
2379    if a.is_empty() {
2380        return Err(FerrayError::invalid_value(
2381            "cannot compute argmax of empty array",
2382        ));
2383    }
2384    let ndim = a.ndim();
2385    let ax_vec: Vec<usize> = match axis {
2386        None => (0..ndim).collect(),
2387        Some(ax) => {
2388            if ax >= ndim {
2389                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
2390            }
2391            vec![ax]
2392        }
2393    };
2394    let shape = a.shape();
2395    let data = borrow_data(a);
2396    let (result_u64, out_shape) =
2397        reduce_axes_general_u64(&data, shape, &ax_vec, keepdims, |lane| {
2398            lane.iter().enumerate().reduce(argmax_fold).unwrap().0 as u64
2399        });
2400    make_result(&out_shape, result_u64)
2401}
2402
2403/// Multi-axis reduction that returns `u64` values (used by argmin/argmax variants).
2404pub(crate) fn reduce_axes_general_u64<T: Copy, F: Fn(&[T]) -> u64>(
2405    data: &[T],
2406    shape: &[usize],
2407    axes: &[usize],
2408    keepdims: bool,
2409    f: F,
2410) -> (Vec<u64>, Vec<usize>) {
2411    let ndim = shape.len();
2412    let strides = compute_strides(shape);
2413
2414    let is_reduce: Vec<bool> = (0..ndim).map(|i| axes.contains(&i)).collect();
2415    let keep_axes: Vec<usize> = (0..ndim).filter(|i| !is_reduce[*i]).collect();
2416    let reduce_axes: Vec<usize> = (0..ndim).filter(|i| is_reduce[*i]).collect();
2417    let keep_shape: Vec<usize> = keep_axes.iter().map(|&i| shape[i]).collect();
2418    let reduce_shape: Vec<usize> = reduce_axes.iter().map(|&i| shape[i]).collect();
2419
2420    let out_size: usize = if keep_shape.is_empty() {
2421        1
2422    } else {
2423        keep_shape.iter().product()
2424    };
2425    let lane_size: usize = if reduce_shape.is_empty() {
2426        1
2427    } else {
2428        reduce_shape.iter().product()
2429    };
2430    let out_shape = output_shape_axes(shape, axes, keepdims);
2431
2432    let mut result = Vec::with_capacity(out_size);
2433    let mut lane: Vec<T> = Vec::with_capacity(lane_size);
2434    let mut keep_multi = vec![0usize; keep_shape.len()];
2435    let mut reduce_multi = vec![0usize; reduce_shape.len()];
2436    let mut full_multi = vec![0usize; ndim];
2437
2438    for _ in 0..out_size {
2439        for (i, &ax) in keep_axes.iter().enumerate() {
2440            full_multi[ax] = keep_multi[i];
2441        }
2442
2443        lane.clear();
2444        reduce_multi.fill(0);
2445        for _ in 0..lane_size {
2446            for (i, &ax) in reduce_axes.iter().enumerate() {
2447                full_multi[ax] = reduce_multi[i];
2448            }
2449            lane.push(data[flat_index(&full_multi, &strides)]);
2450            if !reduce_shape.is_empty() {
2451                increment_multi_index(&mut reduce_multi, &reduce_shape);
2452            }
2453        }
2454
2455        result.push(f(&lane));
2456
2457        if !keep_shape.is_empty() {
2458            increment_multi_index(&mut keep_multi, &keep_shape);
2459        }
2460    }
2461
2462    (result, out_shape)
2463}
2464
2465// ---------------------------------------------------------------------------
2466// Re-export cumulative operations from ferray-ufunc for discoverability
2467// ---------------------------------------------------------------------------
2468
2469/// Cumulative sum along an axis (or flattened if axis is None).
2470///
2471/// Re-exported from `ferray_ufunc::cumsum` for convenience. The result element
2472/// type is the promoted [`ReduceAcc::Acc`] (narrow ints widen so the
2473/// accumulation never overflows), matching `np.cumsum`
2474/// (`numpy/_core/fromnumeric.py:2850-2855`).
2475pub fn cumsum<T, D>(
2476    a: &Array<T, D>,
2477    axis: Option<usize>,
2478) -> FerrayResult<Array<<T as ReduceAcc>::Acc, D>>
2479where
2480    T: ReduceAcc,
2481    D: Dimension,
2482{
2483    ferray_ufunc::cumsum(a, axis)
2484}
2485
2486/// Cumulative product along an axis (or flattened if axis is None).
2487///
2488/// Re-exported from `ferray_ufunc::cumprod` for convenience. The result element
2489/// type is the promoted [`ReduceAcc::Acc`], matching `np.cumprod`
2490/// (`numpy/_core/fromnumeric.py:3424-3429`).
2491pub fn cumprod<T, D>(
2492    a: &Array<T, D>,
2493    axis: Option<usize>,
2494) -> FerrayResult<Array<<T as ReduceAcc>::Acc, D>>
2495where
2496    T: ReduceAcc,
2497    D: Dimension,
2498{
2499    ferray_ufunc::cumprod(a, axis)
2500}
2501
2502#[cfg(test)]
2503mod tests {
2504    use super::*;
2505    use ferray_core::{Ix1, Ix2};
2506
2507    #[test]
2508    fn test_sum_1d_no_axis() {
2509        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2510        let s = sum(&a, None).unwrap();
2511        assert_eq!(s.shape(), &[]);
2512        assert_eq!(s.iter().next(), Some(&10.0));
2513    }
2514
2515    #[test]
2516    fn test_sum_2d_axis0() {
2517        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2518            .unwrap();
2519        let s = sum(&a, Some(0)).unwrap();
2520        assert_eq!(s.shape(), &[3]);
2521        let data: Vec<f64> = s.iter().copied().collect();
2522        assert_eq!(data, vec![5.0, 7.0, 9.0]);
2523    }
2524
2525    #[test]
2526    fn test_sum_2d_axis1() {
2527        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2528            .unwrap();
2529        let s = sum(&a, Some(1)).unwrap();
2530        assert_eq!(s.shape(), &[2]);
2531        let data: Vec<f64> = s.iter().copied().collect();
2532        assert_eq!(data, vec![6.0, 15.0]);
2533    }
2534
2535    #[test]
2536    fn test_prod_1d() {
2537        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2538        let p = prod(&a, None).unwrap();
2539        assert_eq!(p.iter().next(), Some(&24.0));
2540    }
2541
2542    #[test]
2543    fn test_min_max() {
2544        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
2545        let mn = min(&a, None).unwrap();
2546        let mx = max(&a, None).unwrap();
2547        assert_eq!(mn.iter().next(), Some(&1.0));
2548        assert_eq!(mx.iter().next(), Some(&4.0));
2549    }
2550
2551    #[test]
2552    fn test_argmin_argmax() {
2553        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
2554        let ami = argmin(&a, None).unwrap();
2555        let amx = argmax(&a, None).unwrap();
2556        assert_eq!(ami.iter().next(), Some(&1u64));
2557        assert_eq!(amx.iter().next(), Some(&2u64));
2558    }
2559
2560    #[test]
2561    fn test_mean() {
2562        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2563        let m = mean(&a, None).unwrap();
2564        assert!((m.iter().next().unwrap() - 2.5).abs() < 1e-12);
2565    }
2566
2567    #[test]
2568    fn test_var_population() {
2569        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2570        let v = var(&a, None, 0).unwrap();
2571        // var = ((1-2.5)^2 + (2-2.5)^2 + (3-2.5)^2 + (4-2.5)^2) / 4 = 1.25
2572        assert!((v.iter().next().unwrap() - 1.25).abs() < 1e-12);
2573    }
2574
2575    #[test]
2576    fn test_var_sample() {
2577        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2578        let v = var(&a, None, 1).unwrap();
2579        // var = 5.0 / 3.0 = 1.6666...
2580        assert!((v.iter().next().unwrap() - 5.0 / 3.0).abs() < 1e-12);
2581    }
2582
2583    #[test]
2584    fn test_std() {
2585        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2586        let s = std_(&a, None, 1).unwrap();
2587        let expected = (5.0_f64 / 3.0).sqrt();
2588        assert!((s.iter().next().unwrap() - expected).abs() < 1e-12);
2589    }
2590
2591    #[test]
2592    fn test_sum_axis_out_of_bounds() {
2593        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
2594        assert!(sum(&a, Some(1)).is_err());
2595    }
2596
2597    #[test]
2598    fn test_cumsum_reexport() {
2599        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2600        let cs = cumsum(&a, None).unwrap();
2601        assert_eq!(cs.as_slice().unwrap(), &[1, 3, 6, 10]);
2602    }
2603
2604    #[test]
2605    fn test_cumprod_reexport() {
2606        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2607        let cp = cumprod(&a, None).unwrap();
2608        assert_eq!(cp.as_slice().unwrap(), &[1, 2, 6, 24]);
2609    }
2610
2611    #[test]
2612    fn test_min_2d_axis0() {
2613        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0])
2614            .unwrap();
2615        let m = min(&a, Some(0)).unwrap();
2616        let data: Vec<f64> = m.iter().copied().collect();
2617        assert_eq!(data, vec![1.0, 1.0, 2.0]);
2618    }
2619
2620    #[test]
2621    fn test_argmin_2d_axis1() {
2622        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0])
2623            .unwrap();
2624        let ami = argmin(&a, Some(1)).unwrap();
2625        let data: Vec<u64> = ami.iter().copied().collect();
2626        assert_eq!(data, vec![1, 0]);
2627    }
2628
2629    #[test]
2630    fn test_mean_2d_axis0() {
2631        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2632            .unwrap();
2633        let m = mean(&a, Some(0)).unwrap();
2634        let data: Vec<f64> = m.iter().copied().collect();
2635        assert_eq!(data, vec![2.5, 3.5, 4.5]);
2636    }
2637
2638    #[test]
2639    fn test_sum_integer() {
2640        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
2641        let s = sum(&a, None).unwrap();
2642        assert_eq!(s.iter().next(), Some(&15));
2643    }
2644
2645    #[test]
2646    fn test_mean_as_f64_integer() {
2647        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2648        let m = mean_as_f64(&a, None).unwrap();
2649        assert!((m.iter().next().unwrap() - 2.5).abs() < 1e-12);
2650    }
2651
2652    #[test]
2653    fn test_mean_as_f64_integer_axis() {
2654        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
2655        let m = mean_as_f64(&a, Some(1)).unwrap();
2656        assert_eq!(m.shape(), &[2]);
2657        let data: Vec<f64> = m.iter().copied().collect();
2658        assert!((data[0] - 2.0).abs() < 1e-12);
2659        assert!((data[1] - 5.0).abs() < 1e-12);
2660    }
2661
2662    #[test]
2663    fn test_mean_as_f64_u8() {
2664        let a = Array::<u8, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
2665        let m = mean_as_f64(&a, None).unwrap();
2666        assert!((m.iter().next().unwrap() - 25.0).abs() < 1e-12);
2667    }
2668
2669    #[test]
2670    fn test_sum_as_f64_integer() {
2671        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2672        let s = sum_as_f64(&a, None).unwrap();
2673        assert!((s.iter().next().unwrap() - 10.0).abs() < 1e-12);
2674    }
2675
2676    #[test]
2677    fn test_sum_as_f64_large_values() {
2678        // Values that would overflow i32 if summed as i32
2679        let a =
2680            Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![i32::MAX, i32::MAX, i32::MAX]).unwrap();
2681        let s = sum_as_f64(&a, None).unwrap();
2682        let expected = 3.0 * f64::from(i32::MAX);
2683        assert!((s.iter().next().unwrap() - expected).abs() < 1.0);
2684    }
2685
2686    // -----------------------------------------------------------------------
2687    // Multi-axis + keepdims tests (issues #457 + #458)
2688    // -----------------------------------------------------------------------
2689
2690    use ferray_core::Ix3;
2691
2692    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
2693        let n = data.len();
2694        Array::<f64, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
2695    }
2696
2697    fn arr2d(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
2698        Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap()
2699    }
2700
2701    fn arr3d(s0: usize, s1: usize, s2: usize, data: Vec<f64>) -> Array<f64, Ix3> {
2702        Array::<f64, Ix3>::from_vec(Ix3::new([s0, s1, s2]), data).unwrap()
2703    }
2704
2705    #[test]
2706    fn test_normalize_axes_none_reduces_all() {
2707        let ax = normalize_axes(None, 3).unwrap();
2708        assert_eq!(ax, vec![0, 1, 2]);
2709    }
2710
2711    #[test]
2712    fn test_normalize_axes_empty_reduces_all() {
2713        let ax = normalize_axes(Some(&[]), 3).unwrap();
2714        assert_eq!(ax, vec![0, 1, 2]);
2715    }
2716
2717    #[test]
2718    fn test_normalize_axes_sorts() {
2719        let ax = normalize_axes(Some(&[2, 0]), 3).unwrap();
2720        assert_eq!(ax, vec![0, 2]);
2721    }
2722
2723    #[test]
2724    fn test_normalize_axes_rejects_duplicate() {
2725        assert!(normalize_axes(Some(&[0, 0]), 3).is_err());
2726    }
2727
2728    #[test]
2729    fn test_normalize_axes_rejects_out_of_bounds() {
2730        assert!(normalize_axes(Some(&[3]), 3).is_err());
2731    }
2732
2733    #[test]
2734    fn test_sum_axes_single_axis_matches_legacy() {
2735        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2736        let legacy = sum(&a, Some(0)).unwrap();
2737        let new = sum_axes(&a, Some(&[0]), false).unwrap();
2738        assert_eq!(legacy.shape(), new.shape());
2739        let la: Vec<f64> = legacy.iter().copied().collect();
2740        let na: Vec<f64> = new.iter().copied().collect();
2741        assert_eq!(la, na);
2742    }
2743
2744    #[test]
2745    fn test_sum_axes_keepdims_2d() {
2746        // sum((2, 3), axis=1, keepdims=True) -> shape (2, 1)
2747        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2748        let s = sum_axes(&a, Some(&[1]), true).unwrap();
2749        assert_eq!(s.shape(), &[2, 1]);
2750        let data: Vec<f64> = s.iter().copied().collect();
2751        assert_eq!(data, vec![6.0, 15.0]);
2752    }
2753
2754    #[test]
2755    fn test_sum_axes_keepdims_supports_broadcast_back() {
2756        // Canonical NumPy pattern: arr - arr.mean(axis=1, keepdims=True)
2757        // We do it with a 2x3 array.
2758        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2759        let m = mean_axes(&a, Some(&[1]), true).unwrap();
2760        assert_eq!(m.shape(), &[2, 1]);
2761        // Row 0 mean = 2.0, row 1 mean = 5.0
2762        let md: Vec<f64> = m.iter().copied().collect();
2763        assert_eq!(md, vec![2.0, 5.0]);
2764    }
2765
2766    #[test]
2767    fn test_sum_axes_multi_axis_3d() {
2768        // shape (2, 3, 4), sum over axes (0, 2) -> shape (3,)
2769        let data: Vec<f64> = (0..24).map(f64::from).collect();
2770        let a = arr3d(2, 3, 4, data);
2771        let s = sum_axes(&a, Some(&[0, 2]), false).unwrap();
2772        assert_eq!(s.shape(), &[3]);
2773        // For each j in [0,1,2], sum over i in [0,1] and k in [0,1,2,3]:
2774        //   a[i,j,k] = i*12 + j*4 + k
2775        //   sum = (0*12 + j*4) + ... + (0*12 + j*4+3) + (1*12 + j*4) + ... + (1*12 + j*4+3)
2776        //       = 2*(4*j*4 + 6) + 2*12*1
2777        // For j=0: (0+1+2+3) + (12+13+14+15) = 6 + 54 = 60
2778        // For j=1: (4+5+6+7) + (16+17+18+19) = 22 + 70 = 92
2779        // For j=2: (8+9+10+11) + (20+21+22+23) = 38 + 86 = 124
2780        let d: Vec<f64> = s.iter().copied().collect();
2781        assert_eq!(d, vec![60.0, 92.0, 124.0]);
2782    }
2783
2784    #[test]
2785    fn test_sum_axes_multi_axis_keepdims_3d() {
2786        // Same as above but keepdims -> shape (1, 3, 1)
2787        let data: Vec<f64> = (0..24).map(f64::from).collect();
2788        let a = arr3d(2, 3, 4, data);
2789        let s = sum_axes(&a, Some(&[0, 2]), true).unwrap();
2790        assert_eq!(s.shape(), &[1, 3, 1]);
2791        let d: Vec<f64> = s.iter().copied().collect();
2792        assert_eq!(d, vec![60.0, 92.0, 124.0]);
2793    }
2794
2795    #[test]
2796    fn test_sum_axes_none_reduces_all() {
2797        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2798        let s = sum_axes(&a, None, false).unwrap();
2799        assert_eq!(s.shape(), &[]);
2800        assert_eq!(s.iter().next(), Some(&21.0));
2801    }
2802
2803    #[test]
2804    fn test_sum_axes_none_keepdims() {
2805        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2806        let s = sum_axes(&a, None, true).unwrap();
2807        assert_eq!(s.shape(), &[1, 1]);
2808        assert_eq!(s.iter().next(), Some(&21.0));
2809    }
2810
2811    #[test]
2812    fn test_prod_axes_keepdims() {
2813        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2814        let p = prod_axes(&a, Some(&[1]), true).unwrap();
2815        assert_eq!(p.shape(), &[2, 1]);
2816        let d: Vec<f64> = p.iter().copied().collect();
2817        assert_eq!(d, vec![6.0, 120.0]);
2818    }
2819
2820    #[test]
2821    fn test_min_max_axes_multi_axis() {
2822        let data: Vec<f64> = (0..24).map(f64::from).collect();
2823        let a = arr3d(2, 3, 4, data);
2824        let mn = min_axes(&a, Some(&[0, 2]), false).unwrap();
2825        let mx = max_axes(&a, Some(&[0, 2]), false).unwrap();
2826        // min over i=0..2, k=0..4 for each j: min value on i=0 has smallest
2827        // indices. For j=0, values are 0..3 and 12..15 -> min=0, max=15
2828        assert_eq!(mn.iter().copied().collect::<Vec<_>>(), vec![0.0, 4.0, 8.0]);
2829        assert_eq!(
2830            mx.iter().copied().collect::<Vec<_>>(),
2831            vec![15.0, 19.0, 23.0]
2832        );
2833    }
2834
2835    #[test]
2836    fn test_mean_axes_multi_axis() {
2837        let data: Vec<f64> = (0..24).map(f64::from).collect();
2838        let a = arr3d(2, 3, 4, data);
2839        let m = mean_axes(&a, Some(&[0, 2]), false).unwrap();
2840        assert_eq!(m.shape(), &[3]);
2841        // For j=0: 60 / 8 = 7.5
2842        // For j=1: 92 / 8 = 11.5
2843        // For j=2: 124 / 8 = 15.5
2844        let d: Vec<f64> = m.iter().copied().collect();
2845        assert_eq!(d, vec![7.5, 11.5, 15.5]);
2846    }
2847
2848    #[test]
2849    fn test_var_axes_single_axis_keepdims() {
2850        let a = arr2d(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
2851        // Population variance along axis 1 (each row), keepdims = (2, 1)
2852        let v = var_axes(&a, Some(&[1]), 0, true).unwrap();
2853        assert_eq!(v.shape(), &[2, 1]);
2854        // Row 0: mean=2.5, var=((1-2.5)^2+(2-2.5)^2+(3-2.5)^2+(4-2.5)^2)/4 = 1.25
2855        // Row 1: same distribution = 1.25
2856        let d: Vec<f64> = v.iter().copied().collect();
2857        assert!((d[0] - 1.25).abs() < 1e-12);
2858        assert!((d[1] - 1.25).abs() < 1e-12);
2859    }
2860
2861    #[test]
2862    fn test_std_axes_multi_axis() {
2863        let a = arr2d(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
2864        // Population std over all: mean=4.5, var=5.25, std≈2.2913
2865        let s = std_axes(&a, None, 0, false).unwrap();
2866        assert_eq!(s.shape(), &[]);
2867        let v = *s.iter().next().unwrap();
2868        assert!((v - 5.25_f64.sqrt()).abs() < 1e-12);
2869    }
2870
2871    #[test]
2872    fn var_axes_ddof_ge_reduced_len_is_inf_or_nan_not_error() -> FerrayResult<()> {
2873        // numpy clamps the divisor to 0 (max(rcount-ddof, 0)) then true-divides,
2874        // so ddof >= reduced-axis length yields +inf (numerator > 0) or NaN
2875        // (numerator == 0), never an error. _core/_methods.py:204,208.
2876        // Live numpy (numpy 2.4):
2877        //   a = [[1,2,3],[4,5,6]]
2878        //   np.var(a, axis=(0,1), ddof=10) == inf
2879        //   np.var(a, axis=0,     ddof=5)  == [inf inf inf]
2880        //   all-equal a2 = [[2,2,2],[2,2,2]]:
2881        //   np.var(a2, axis=(0,1), ddof=10) == nan
2882        //   np.var(a2, axis=0,     ddof=5)  == [nan nan nan]
2883        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2884
2885        // Multi-axis reduction over all 6 elements, ddof=10 > 6 -> +inf.
2886        let v_all = var_axes(&a, Some(&[0, 1]), 10, false)?;
2887        assert_eq!(v_all.shape(), &[]);
2888        let got_all: f64 = v_all.iter().copied().next().unwrap_or(0.0);
2889        assert!(got_all.is_infinite() && got_all.is_sign_positive());
2890
2891        // Reduce over axis 0 (length 2), ddof=5 > 2 -> [inf, inf, inf].
2892        let v_ax0 = var_axes(&a, Some(&[0]), 5, false)?;
2893        assert_eq!(v_ax0.shape(), &[3]);
2894        assert!(
2895            v_ax0
2896                .iter()
2897                .all(|x| x.is_infinite() && x.is_sign_positive())
2898        );
2899
2900        // std_axes = sqrt(var_axes); sqrt(+inf) = +inf.
2901        let s_ax0 = std_axes(&a, Some(&[0]), 5, false)?;
2902        assert!(
2903            s_ax0
2904                .iter()
2905                .all(|x| x.is_infinite() && x.is_sign_positive())
2906        );
2907
2908        // All-equal input: numerator (sum of squared deviations) == 0, so
2909        // 0/0 -> NaN for both the multi-axis and single-axis-in-list paths.
2910        let a2 = arr2d(2, 3, vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0]);
2911        let nan_all = var_axes(&a2, Some(&[0, 1]), 10, false)?;
2912        let got_nan: f64 = nan_all.iter().copied().next().unwrap_or(0.0);
2913        assert!(got_nan.is_nan());
2914        let nan_ax0 = var_axes(&a2, Some(&[0]), 5, false)?;
2915        assert!(nan_ax0.iter().all(|x| x.is_nan()));
2916        // std of NaN variance is NaN.
2917        let snan_ax0 = std_axes(&a2, Some(&[0]), 5, false)?;
2918        assert!(snan_ax0.iter().all(|x| x.is_nan()));
2919        Ok(())
2920    }
2921
2922    #[test]
2923    fn test_argmin_keepdims_single_axis() {
2924        let a = arr2d(2, 3, vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0]);
2925        let am = argmin_keepdims(&a, Some(1), true).unwrap();
2926        assert_eq!(am.shape(), &[2, 1]);
2927        let d: Vec<u64> = am.iter().copied().collect();
2928        assert_eq!(d, vec![1, 0]);
2929    }
2930
2931    #[test]
2932    fn test_argmax_keepdims_single_axis() {
2933        let a = arr2d(2, 3, vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0]);
2934        let am = argmax_keepdims(&a, Some(1), false).unwrap();
2935        assert_eq!(am.shape(), &[2]);
2936        let d: Vec<u64> = am.iter().copied().collect();
2937        assert_eq!(d, vec![2, 1]);
2938    }
2939
2940    #[test]
2941    fn test_axes_out_of_bounds_error() {
2942        let a = arr2d(2, 3, vec![0.0; 6]);
2943        assert!(sum_axes(&a, Some(&[5]), false).is_err());
2944    }
2945
2946    #[test]
2947    fn test_axes_duplicate_error() {
2948        let a = arr3d(2, 3, 4, vec![0.0; 24]);
2949        assert!(sum_axes(&a, Some(&[0, 0]), false).is_err());
2950    }
2951
2952    // ----- Infinity tests (#177) -----
2953
2954    #[test]
2955    fn sum_with_infinity() {
2956        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
2957        let s = sum(&a, None).unwrap();
2958        assert_eq!(s.iter().next().copied().unwrap(), f64::INFINITY);
2959    }
2960
2961    #[test]
2962    fn sum_inf_minus_inf_is_nan() {
2963        let a = arr1d(vec![f64::INFINITY, f64::NEG_INFINITY]);
2964        let s = sum(&a, None).unwrap();
2965        assert!(s.iter().next().copied().unwrap().is_nan());
2966    }
2967
2968    #[test]
2969    fn mean_with_infinity() {
2970        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
2971        let m = mean(&a, None).unwrap();
2972        assert_eq!(m.iter().next().copied().unwrap(), f64::INFINITY);
2973    }
2974
2975    #[test]
2976    fn min_with_neg_infinity() {
2977        let a = arr1d(vec![1.0, f64::NEG_INFINITY, 3.0]);
2978        let m = min(&a, None).unwrap();
2979        assert_eq!(m.iter().next().copied().unwrap(), f64::NEG_INFINITY);
2980    }
2981
2982    #[test]
2983    fn max_with_infinity() {
2984        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
2985        let m = max(&a, None).unwrap();
2986        assert_eq!(m.iter().next().copied().unwrap(), f64::INFINITY);
2987    }
2988
2989    #[test]
2990    fn prod_with_zero_and_infinity_is_nan() {
2991        // 0 * INF = NaN per IEEE 754
2992        let a = arr1d(vec![0.0, f64::INFINITY, 1.0]);
2993        let p = prod(&a, None).unwrap();
2994        assert!(p.iter().next().copied().unwrap().is_nan());
2995    }
2996
2997    #[test]
2998    fn prod_with_infinity_propagates() {
2999        let a = arr1d(vec![2.0, f64::INFINITY, 3.0]);
3000        let p = prod(&a, None).unwrap();
3001        assert_eq!(p.iter().next().copied().unwrap(), f64::INFINITY);
3002    }
3003
3004    #[test]
3005    fn var_with_infinity_is_nan() {
3006        // var includes (x - mean)^2 where mean is INF; (INF - INF)^2 = NaN.
3007        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
3008        let v = var(&a, None, 0).unwrap();
3009        assert!(v.iter().next().copied().unwrap().is_nan());
3010    }
3011
3012    #[test]
3013    fn std_with_infinity_is_nan() {
3014        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
3015        let s = std_(&a, None, 0).unwrap();
3016        assert!(s.iter().next().copied().unwrap().is_nan());
3017    }
3018
3019    #[test]
3020    fn argmin_finds_neg_infinity() {
3021        let a = arr1d(vec![1.0, f64::NEG_INFINITY, 3.0, -100.0]);
3022        let i = crate::reductions::argmin(&a, None).unwrap();
3023        assert_eq!(i.iter().next().copied().unwrap(), 1);
3024    }
3025
3026    #[test]
3027    fn argmax_finds_infinity() {
3028        let a = arr1d(vec![1.0, f64::INFINITY, 1e300, 5.0]);
3029        let i = crate::reductions::argmax(&a, None).unwrap();
3030        assert_eq!(i.iter().next().copied().unwrap(), 1);
3031    }
3032
3033    #[test]
3034    fn cumsum_propagates_infinity() {
3035        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
3036        let c = cumsum(&a, None).unwrap();
3037        let v: Vec<f64> = c.iter().copied().collect();
3038        assert_eq!(v[0], 1.0);
3039        assert!(v[1].is_infinite());
3040        assert!(v[2].is_infinite());
3041    }
3042
3043    #[test]
3044    fn cumprod_inf_then_zero_yields_nan() {
3045        let a = arr1d(vec![2.0, f64::INFINITY, 0.0]);
3046        let c = cumprod(&a, None).unwrap();
3047        let v: Vec<f64> = c.iter().copied().collect();
3048        assert_eq!(v[0], 2.0);
3049        assert!(v[1].is_infinite());
3050        assert!(v[2].is_nan()); // INF * 0 = NaN
3051    }
3052
3053    #[test]
3054    fn ptp_with_infinity_is_inf() {
3055        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
3056        let p = ptp(&a, None).unwrap();
3057        assert!(p.iter().next().copied().unwrap().is_infinite());
3058    }
3059
3060    // ----- Single-element var/std with ddof (#178) -----
3061
3062    #[test]
3063    fn var_single_element_ddof0() {
3064        let a = arr1d(vec![5.0]);
3065        let v = var(&a, None, 0).unwrap();
3066        assert_eq!(v.iter().next().copied().unwrap(), 0.0);
3067    }
3068
3069    #[test]
3070    fn var_single_element_ddof1_nan() {
3071        // ddof=1 with N=1 → divisor max(N-ddof,0)=0 and sum_sq=0, so
3072        // 0.0/0.0 = NaN — matching numpy (np.var([5.0], ddof=1) == nan,
3073        // numpy/_core/_methods.py:204).
3074        let a = arr1d(vec![5.0]);
3075        let v = var(&a, None, 1).ok().filter(|r| !r.is_empty());
3076        assert!(matches!(&v, Some(r) if r.iter().all(|x| x.is_nan())));
3077    }
3078
3079    #[test]
3080    fn std_single_element_ddof0() {
3081        let a = arr1d(vec![5.0]);
3082        let s = std_(&a, None, 0).unwrap();
3083        assert_eq!(s.iter().next().copied().unwrap(), 0.0);
3084    }
3085
3086    #[test]
3087    fn std_single_element_ddof1_nan() {
3088        // Same shape as var: ddof=1 with N=1 → divisor 0, sum_sq 0, so
3089        // var = NaN and std = sqrt(NaN) = NaN — matching numpy
3090        // (np.std([5.0], ddof=1) == nan, numpy/_core/_methods.py:217-225).
3091        let a = arr1d(vec![5.0]);
3092        let s = std_(&a, None, 1).ok().filter(|r| !r.is_empty());
3093        assert!(matches!(&s, Some(r) if r.iter().all(|x| x.is_nan())));
3094    }
3095
3096    #[test]
3097    fn var_two_elements_ddof1_population_to_sample() {
3098        // Sanity: ddof=0 vs ddof=1 on the same input — Bessel's correction
3099        // should give a 2x bigger variance for N=2.
3100        let a = arr1d(vec![1.0, 3.0]);
3101        let v0 = var(&a, None, 0).unwrap();
3102        let v1 = var(&a, None, 1).unwrap();
3103        let v0_val = v0.iter().next().copied().unwrap();
3104        let v1_val = v1.iter().next().copied().unwrap();
3105        assert!((v0_val - 1.0).abs() < 1e-12); // ((1-2)^2 + (3-2)^2) / 2 = 1.0
3106        assert!((v1_val - 2.0).abs() < 1e-12); // ((1-2)^2 + (3-2)^2) / 1 = 2.0
3107        // Sample variance (ddof=1) should be 2x population (ddof=0) for N=2.
3108        assert!((v1_val / v0_val - 2.0).abs() < 1e-12);
3109    }
3110
3111    #[test]
3112    fn var_single_element_2d_ddof0() {
3113        // 1×1 array — single element along every axis
3114        use ferray_core::dimension::Ix2;
3115        let a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 1]), vec![5.0]).unwrap();
3116        let v = var(&a, None, 0).unwrap();
3117        assert_eq!(v.iter().next().copied().unwrap(), 0.0);
3118    }
3119
3120    #[test]
3121    fn var_single_element_axis_ddof_too_large_nan() {
3122        // 1×3 array with ddof=1 reduced along axis 0 (length 1): each lane has
3123        // one element, sum_sq=0, divisor max(1-1,0)=0 → 0/0 = NaN per lane,
3124        // matching numpy (np.var([[1,2,3]], axis=0, ddof=1) == [nan nan nan],
3125        // numpy/_core/_methods.py:204).
3126        use ferray_core::dimension::Ix2;
3127        let a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
3128        let v = var(&a, Some(0), 1).ok().filter(|r| !r.is_empty());
3129        assert!(matches!(&v, Some(r) if r.iter().all(|x| x.is_nan())));
3130    }
3131
3132    // ---- *_into reductions (#467) ----
3133
3134    #[test]
3135    fn sum_into_axis_writes_into_destination() {
3136        // (2, 3) sum along axis=1 → shape (2,)
3137        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3138            .unwrap();
3139        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
3140        sum_into(&a, Some(1), &mut out).unwrap();
3141        assert_eq!(out.as_slice().unwrap(), &[6.0, 15.0]);
3142    }
3143
3144    #[test]
3145    fn sum_into_no_axis_writes_scalar_destination() {
3146        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3147        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3148        sum_into(&a, None, &mut out).unwrap();
3149        assert_eq!(out.iter().next().copied().unwrap(), 10.0);
3150    }
3151
3152    #[test]
3153    fn sum_into_rejects_wrong_shape() {
3154        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3155            .unwrap();
3156        // axis=1 reduces to shape (2,), but out has shape (3,)
3157        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0; 3]).unwrap();
3158        let err = sum_into(&a, Some(1), &mut out);
3159        assert!(err.is_err());
3160    }
3161
3162    #[test]
3163    fn sum_into_rejects_axis_out_of_bounds() {
3164        let a = arr1d(vec![1.0, 2.0, 3.0]);
3165        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1]), vec![0.0]).unwrap();
3166        assert!(sum_into(&a, Some(5), &mut out).is_err());
3167    }
3168
3169    #[test]
3170    fn prod_into_basic() {
3171        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3172        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3173        prod_into(&a, None, &mut out).unwrap();
3174        assert_eq!(out.iter().next().copied().unwrap(), 24.0);
3175    }
3176
3177    #[test]
3178    fn min_into_axis_basic() {
3179        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 2.0, 4.0, 3.0, 6.0])
3180            .unwrap();
3181        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
3182        min_into(&a, Some(1), &mut out).unwrap();
3183        assert_eq!(out.as_slice().unwrap(), &[1.0, 3.0]);
3184    }
3185
3186    #[test]
3187    fn max_into_axis_basic() {
3188        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 2.0, 4.0, 3.0, 6.0])
3189            .unwrap();
3190        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
3191        max_into(&a, Some(1), &mut out).unwrap();
3192        assert_eq!(out.as_slice().unwrap(), &[5.0, 6.0]);
3193    }
3194
3195    #[test]
3196    fn mean_into_axis_basic() {
3197        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3198            .unwrap();
3199        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
3200        mean_into(&a, Some(1), &mut out).unwrap();
3201        assert_eq!(out.as_slice().unwrap(), &[2.0, 5.0]);
3202    }
3203
3204    #[test]
3205    fn var_into_basic() {
3206        // Variance with ddof=0 of [1,2,3,4] = 1.25
3207        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3208        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3209        var_into(&a, None, 0, &mut out).unwrap();
3210        assert!((out.iter().next().copied().unwrap() - 1.25).abs() < 1e-12);
3211    }
3212
3213    #[test]
3214    fn std_into_basic() {
3215        // sqrt(var) for [1,2,3,4] with ddof=0 = sqrt(1.25)
3216        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3217        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3218        std_into(&a, None, 0, &mut out).unwrap();
3219        let expected = 1.25_f64.sqrt();
3220        assert!((out.iter().next().copied().unwrap() - expected).abs() < 1e-12);
3221    }
3222
3223    #[test]
3224    fn into_reductions_can_reuse_destination_across_calls() {
3225        // The whole point of the out= API: a single allocation reused
3226        // across many reductions on same-shaped data.
3227        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
3228        for k in 0..3 {
3229            let base = f64::from(k);
3230            let a = Array::<f64, Ix2>::from_vec(
3231                Ix2::new([2, 3]),
3232                vec![
3233                    base + 1.0,
3234                    base + 2.0,
3235                    base + 3.0,
3236                    base + 4.0,
3237                    base + 5.0,
3238                    base + 6.0,
3239                ],
3240            )
3241            .unwrap();
3242            sum_into(&a, Some(1), &mut out).unwrap();
3243            assert_eq!(
3244                out.as_slice().unwrap(),
3245                &[3.0f64.mul_add(base, 6.0), 3.0f64.mul_add(base, 15.0)]
3246            );
3247        }
3248    }
3249
3250    // ---- zero-alloc kernel regression tests (#563) ----
3251
3252    #[test]
3253    fn sum_into_overwrites_existing_destination_garbage() {
3254        // The new kernel writes results directly into `out` rather than
3255        // routing through write_into's copy_from_slice; make sure existing
3256        // garbage is fully overwritten and never bleeds through.
3257        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3258            .unwrap();
3259        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![999.0, -999.0]).unwrap();
3260        sum_into(&a, Some(1), &mut out).unwrap();
3261        assert_eq!(out.as_slice().unwrap(), &[6.0, 15.0]);
3262    }
3263
3264    #[test]
3265    fn reduce_axis_typed_into_matches_reduce_axis_typed() {
3266        // The in-place kernel must produce identical output to the
3267        // allocating kernel for any (shape, axis, fn).
3268        use super::{reduce_axis_typed, reduce_axis_typed_into};
3269        let data: Vec<f64> = (0..24).map(f64::from).collect();
3270        for shape in [vec![24usize], vec![4, 6], vec![2, 3, 4], vec![2, 2, 2, 3]] {
3271            for ax in 0..shape.len() {
3272                let allocated: Vec<f64> =
3273                    reduce_axis_typed(&data, &shape, ax, |lane| lane.iter().sum());
3274                let mut dst = vec![0.0; allocated.len()];
3275                reduce_axis_typed_into(&data, &shape, ax, &mut dst, |lane| lane.iter().sum());
3276                assert_eq!(dst, allocated, "shape {shape:?} axis {ax}");
3277            }
3278        }
3279    }
3280
3281    #[test]
3282    fn sum_into_3d_axis_correct() {
3283        use ferray_core::Ix3;
3284        // (2, 3, 4) reducing axis 1 → shape (2, 4)
3285        let data: Vec<f64> = (0..24).map(f64::from).collect();
3286        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
3287        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 4]), vec![0.0; 8]).unwrap();
3288        sum_into(&a, Some(1), &mut out).unwrap();
3289        // Hand check: out[i, k] = sum_{j} a[i, j, k] = sum_{j} (i*12 + j*4 + k)
3290        let expected: Vec<f64> = (0..2)
3291            .flat_map(|i| {
3292                (0..4).map(move |k| (0..3).map(|j| f64::from(i * 12 + j * 4 + k)).sum::<f64>())
3293            })
3294            .collect();
3295        assert_eq!(out.as_slice().unwrap(), expected.as_slice());
3296    }
3297
3298    #[test]
3299    fn min_into_rejects_empty_input() {
3300        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
3301        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3302        assert!(min_into(&a, None, &mut out).is_err());
3303    }
3304
3305    #[test]
3306    fn max_into_rejects_empty_input() {
3307        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
3308        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3309        assert!(max_into(&a, None, &mut out).is_err());
3310    }
3311
3312    #[test]
3313    fn var_into_ddof_too_large_is_inf_not_error() -> FerrayResult<()> {
3314        let a = arr1d(vec![1.0, 2.0, 3.0]);
3315        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0])?;
3316        // n=3, ddof=3 → divisor clamps to 0; IEEE-754 true-division yields +inf,
3317        // matching numpy (np.var([1.,2.,3.], ddof=3) == inf; _core/_methods.py:204).
3318        var_into(&a, None, 3, &mut out)?;
3319        let got: f64 = out.iter().copied().next().unwrap_or(0.0);
3320        assert!(got.is_infinite() && got.is_sign_positive());
3321        Ok(())
3322    }
3323
3324    #[test]
3325    fn std_into_does_not_leave_variance_in_destination_on_success() {
3326        // var of [1,2,3,4] with ddof=0 = 1.25; std = sqrt(1.25) ≈ 1.118.
3327        // The destination must hold the std value, not the variance, after
3328        // a successful call.
3329        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3330        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
3331        std_into(&a, None, 0, &mut out).unwrap();
3332        let got = out.iter().next().copied().unwrap();
3333        let expected = 1.25_f64.sqrt();
3334        assert!((got - expected).abs() < 1e-12);
3335        // Sanity: result is not the variance.
3336        assert!((got - 1.25).abs() > 1e-3);
3337    }
3338
3339    // ---- *_with reductions: initial= and where= (#459) ----
3340
3341    #[test]
3342    fn sum_with_initial_only() {
3343        // Initial = 100; sum becomes 100 + 1 + 2 + 3 + 4 = 110
3344        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3345        let r = sum_with(&a, None, Some(100.0), None).unwrap();
3346        assert_eq!(r.iter().next().copied().unwrap(), 110.0);
3347    }
3348
3349    #[test]
3350    fn sum_with_where_mask_only() {
3351        // Sum positives only.
3352        let a = arr1d(vec![1.0, -2.0, 3.0, -4.0, 5.0]);
3353        let mask =
3354            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
3355                .unwrap();
3356        let r = sum_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3357        assert_eq!(r.iter().next().copied().unwrap(), 9.0);
3358    }
3359
3360    #[test]
3361    fn sum_with_initial_and_mask_combined() {
3362        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
3363        let mask =
3364            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
3365        let r = sum_with(&a, None, Some(50.0), Some(&mask.to_dyn())).unwrap();
3366        // 50 + 1 + 3 = 54
3367        assert_eq!(r.iter().next().copied().unwrap(), 54.0);
3368    }
3369
3370    #[test]
3371    fn sum_with_axis_and_mask() {
3372        // (2, 3); mask zeroes out the first column; row sums become
3373        // [2+3, 5+6] = [5, 11].
3374        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3375            .unwrap();
3376        let mask = Array::<bool, Ix2>::from_vec(
3377            Ix2::new([2, 3]),
3378            vec![false, true, true, false, true, true],
3379        )
3380        .unwrap();
3381        let r = sum_with(&a, Some(1), None, Some(&mask.to_dyn())).unwrap();
3382        assert_eq!(r.shape(), &[2]);
3383        assert_eq!(r.as_slice().unwrap(), &[5.0, 11.0]);
3384    }
3385
3386    #[test]
3387    fn sum_with_axis_and_initial() {
3388        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3389            .unwrap();
3390        // axis=1 with initial=10 → row sums become [10+6, 10+15] = [16, 25]
3391        let r = sum_with(&a, Some(1), Some(10.0), None).unwrap();
3392        assert_eq!(r.as_slice().unwrap(), &[16.0, 25.0]);
3393    }
3394
3395    #[test]
3396    fn sum_with_no_initial_no_mask_matches_legacy_sum() {
3397        // Sanity: omitting both knobs should match the existing sum().
3398        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3399            .unwrap();
3400        let legacy = sum(&a, Some(1)).unwrap();
3401        let with_form = sum_with(&a, Some(1), None, None).unwrap();
3402        assert_eq!(legacy.as_slice().unwrap(), with_form.as_slice().unwrap());
3403    }
3404
3405    #[test]
3406    fn sum_with_rejects_mismatched_mask_shape() {
3407        let a = arr1d(vec![1.0, 2.0, 3.0]);
3408        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
3409        assert!(sum_with(&a, None, None, Some(&mask.to_dyn())).is_err());
3410    }
3411
3412    #[test]
3413    fn prod_with_initial_only() {
3414        let a = arr1d(vec![2.0, 3.0, 4.0]);
3415        let r = prod_with(&a, None, Some(10.0), None).unwrap();
3416        // 10 * 2 * 3 * 4 = 240
3417        assert_eq!(r.iter().next().copied().unwrap(), 240.0);
3418    }
3419
3420    #[test]
3421    fn prod_with_where_mask() {
3422        let a = arr1d(vec![2.0, 0.0, 3.0, 0.0, 4.0]);
3423        let mask =
3424            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
3425                .unwrap();
3426        let r = prod_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3427        // 2 * 3 * 4 = 24 (zero positions skipped)
3428        assert_eq!(r.iter().next().copied().unwrap(), 24.0);
3429    }
3430
3431    #[test]
3432    fn min_with_initial_provides_fallback_for_empty_lane() {
3433        // Empty array + initial = the initial value.
3434        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
3435        let r = min_with(&a, None, Some(99.0), None).unwrap();
3436        assert_eq!(r.iter().next().copied().unwrap(), 99.0);
3437    }
3438
3439    #[test]
3440    fn min_with_initial_caps_actual_min() {
3441        // Initial=0, data=[1,2,3] → min(0, 1, 2, 3) = 0.
3442        let a = arr1d(vec![1.0, 2.0, 3.0]);
3443        let r = min_with(&a, None, Some(0.0), None).unwrap();
3444        assert_eq!(r.iter().next().copied().unwrap(), 0.0);
3445    }
3446
3447    #[test]
3448    fn min_with_where_mask_filters_then_reduces() {
3449        let a = arr1d(vec![5.0, 1.0, 4.0, 2.0, 3.0]);
3450        // Mask out positions 1 and 3 (values 1 and 2).
3451        let mask =
3452            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
3453                .unwrap();
3454        let r = min_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3455        // Filtered: [5, 4, 3] → min = 3
3456        assert_eq!(r.iter().next().copied().unwrap(), 3.0);
3457    }
3458
3459    #[test]
3460    fn min_with_empty_lane_no_initial_errors() {
3461        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
3462        assert!(min_with(&a, None, None, None).is_err());
3463    }
3464
3465    #[test]
3466    fn min_with_fully_masked_axis_lane_no_initial_errors() {
3467        // Each row fully masked out → error per lane.
3468        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3469            .unwrap();
3470        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
3471        assert!(min_with(&a, Some(1), None, Some(&mask.to_dyn())).is_err());
3472    }
3473
3474    #[test]
3475    fn min_with_fully_masked_axis_lane_with_initial_uses_initial() {
3476        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3477            .unwrap();
3478        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
3479        let r = min_with(&a, Some(1), Some(99.0), Some(&mask.to_dyn())).unwrap();
3480        assert_eq!(r.as_slice().unwrap(), &[99.0, 99.0]);
3481    }
3482
3483    #[test]
3484    fn max_with_initial_caps_actual_max() {
3485        let a = arr1d(vec![1.0, 2.0, 3.0]);
3486        let r = max_with(&a, None, Some(99.0), None).unwrap();
3487        assert_eq!(r.iter().next().copied().unwrap(), 99.0);
3488    }
3489
3490    #[test]
3491    fn max_with_where_mask_filters_then_reduces() {
3492        let a = arr1d(vec![5.0, 10.0, 4.0, 20.0, 3.0]);
3493        // Mask out positions 1 and 3 (values 10 and 20).
3494        let mask =
3495            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
3496                .unwrap();
3497        let r = max_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3498        // Filtered: [5, 4, 3] → max = 5
3499        assert_eq!(r.iter().next().copied().unwrap(), 5.0);
3500    }
3501
3502    #[test]
3503    fn mean_where_filters_and_divides_by_count() {
3504        // [1, 2, 3, 4, 5] with mask [T, F, T, F, T] → (1+3+5)/3 = 3.0
3505        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3506        let mask =
3507            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
3508                .unwrap();
3509        let r = mean_where(&a, None, Some(&mask.to_dyn())).unwrap();
3510        assert!((r.iter().next().copied().unwrap() - 3.0).abs() < 1e-12);
3511    }
3512
3513    #[test]
3514    fn mean_where_no_mask_matches_legacy_mean() {
3515        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3516            .unwrap();
3517        let legacy = mean(&a, Some(1)).unwrap();
3518        let where_form = mean_where(&a, Some(1), None).unwrap();
3519        assert_eq!(legacy.as_slice().unwrap(), where_form.as_slice().unwrap());
3520    }
3521
3522    #[test]
3523    fn mean_where_fully_masked_lane_returns_nan() {
3524        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3525            .unwrap();
3526        let mask = Array::<bool, Ix2>::from_vec(
3527            Ix2::new([2, 3]),
3528            vec![true, true, true, false, false, false],
3529        )
3530        .unwrap();
3531        let r = mean_where(&a, Some(1), Some(&mask.to_dyn())).unwrap();
3532        let s = r.as_slice().unwrap();
3533        assert!((s[0] - 2.0).abs() < 1e-12);
3534        assert!(s[1].is_nan());
3535    }
3536
3537    #[test]
3538    fn mean_where_axis_with_partial_mask() {
3539        // (2, 3); mask = [[T,T,F],[F,T,T]]
3540        // Row 0 mean: (1+2)/2 = 1.5
3541        // Row 1 mean: (5+6)/2 = 5.5
3542        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3543            .unwrap();
3544        let mask = Array::<bool, Ix2>::from_vec(
3545            Ix2::new([2, 3]),
3546            vec![true, true, false, false, true, true],
3547        )
3548        .unwrap();
3549        let r = mean_where(&a, Some(1), Some(&mask.to_dyn())).unwrap();
3550        assert_eq!(r.shape(), &[2]);
3551        let s = r.as_slice().unwrap();
3552        assert!((s[0] - 1.5).abs() < 1e-12);
3553        assert!((s[1] - 5.5).abs() < 1e-12);
3554    }
3555
3556    // ---- broadcast-compatible where masks (#565) ----
3557
3558    #[test]
3559    fn sum_with_mask_broadcasts_ix1_into_ix2() {
3560        // (2, 3) input; mask of shape (3,) — broadcasts to (2, 3) with
3561        // every row identical. sum ignores the first column, so both
3562        // row sums skip that column's contribution.
3563        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3564            .unwrap();
3565        let mask_1d = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, true]).unwrap();
3566        let r = sum_with(&a, None, None, Some(&mask_1d.to_dyn())).unwrap();
3567        // Sum with first column masked out: 2 + 3 + 5 + 6 = 16.0
3568        assert!((r.iter().next().copied().unwrap() - 16.0).abs() < 1e-12);
3569    }
3570
3571    #[test]
3572    fn sum_with_mask_broadcasts_column_vector_across_rows() {
3573        // (3, 4) input; mask of shape (3, 1) — broadcasts across the
3574        // four columns, so each row is either fully kept or fully
3575        // masked out.
3576        let a = Array::<f64, Ix2>::from_vec(
3577            Ix2::new([3, 4]),
3578            vec![
3579                1.0, 2.0, 3.0, 4.0, // row 0
3580                5.0, 6.0, 7.0, 8.0, // row 1
3581                9.0, 10.0, 11.0, 12.0, // row 2
3582            ],
3583        )
3584        .unwrap();
3585        let mask_col =
3586            Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![true, false, true]).unwrap();
3587        let r = sum_with(&a, None, None, Some(&mask_col.to_dyn())).unwrap();
3588        // Rows 0 and 2 kept: 1+2+3+4 + 9+10+11+12 = 10 + 42 = 52
3589        assert!((r.iter().next().copied().unwrap() - 52.0).abs() < 1e-12);
3590    }
3591
3592    #[test]
3593    fn sum_with_mask_broadcasts_row_vector_against_axis_reduction() {
3594        // (2, 3) reducing axis=1; mask shape (3,) broadcasts against
3595        // each row.
3596        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3597            .unwrap();
3598        let mask_1d = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3599        let r = sum_with(&a, Some(1), None, Some(&mask_1d.to_dyn())).unwrap();
3600        assert_eq!(r.shape(), &[2]);
3601        // Row 0: 1 + 3 = 4; row 1: 4 + 6 = 10
3602        assert_eq!(r.as_slice().unwrap(), &[4.0, 10.0]);
3603    }
3604
3605    #[test]
3606    fn prod_with_mask_broadcasts_ix1() {
3607        // (2, 3); mask shape (3,) keeps columns 0 and 2.
3608        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
3609            .unwrap();
3610        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3611        let r = prod_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3612        // Product of kept values: 2 * 4 * 5 * 7 = 280
3613        assert!((r.iter().next().copied().unwrap() - 280.0).abs() < 1e-12);
3614    }
3615
3616    #[test]
3617    fn min_with_mask_broadcasts_ix1() {
3618        // Column mask; find min across the kept columns.
3619        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 1.0, 4.0, 2.0, 10.0, 3.0])
3620            .unwrap();
3621        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3622        let r = min_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3623        // Kept values: 5, 4, 2, 3 → min = 2
3624        assert!((r.iter().next().copied().unwrap() - 2.0).abs() < 1e-12);
3625    }
3626
3627    #[test]
3628    fn max_with_mask_broadcasts_ix1() {
3629        let a =
3630            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 100.0, 4.0, 2.0, 200.0, 3.0])
3631                .unwrap();
3632        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3633        let r = max_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3634        // Kept values: 5, 4, 2, 3 → max = 5
3635        assert!((r.iter().next().copied().unwrap() - 5.0).abs() < 1e-12);
3636    }
3637
3638    #[test]
3639    fn mean_where_mask_broadcasts_ix1() {
3640        // (2, 3); mask (3,) keeps columns 0 and 2. Mean of 4 kept
3641        // values: (1 + 3 + 4 + 6) / 4 = 3.5
3642        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3643            .unwrap();
3644        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3645        let r = mean_where(&a, None, Some(&mask.to_dyn())).unwrap();
3646        assert!((r.iter().next().copied().unwrap() - 3.5).abs() < 1e-12);
3647    }
3648
3649    #[test]
3650    fn with_mask_rejects_incompatible_broadcast_shape() {
3651        // Mask rank compatible but length wrong: shape (2,) against a
3652        // (2, 3) input cannot broadcast (the 2 aligns with the last
3653        // dim which is 3, not the first which is 2).
3654        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3655            .unwrap();
3656        let bad_mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
3657        assert!(sum_with(&a, None, None, Some(&bad_mask.to_dyn())).is_err());
3658    }
3659
3660    #[test]
3661    fn sum_with_mask_broadcast_scalar_like_length_1() {
3662        // A shape-(1,) mask is the scalar case — broadcasts to every
3663        // position of the input, which for `true` is an identity and
3664        // for `false` yields 0 (everything masked out).
3665        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3666            .unwrap();
3667        let true_mask = Array::<bool, Ix1>::from_vec(Ix1::new([1]), vec![true]).unwrap();
3668        let r = sum_with(&a, None, None, Some(&true_mask.to_dyn())).unwrap();
3669        // All positions kept: 1+2+3+4+5+6 = 21
3670        assert!((r.iter().next().copied().unwrap() - 21.0).abs() < 1e-12);
3671
3672        let false_mask = Array::<bool, Ix1>::from_vec(Ix1::new([1]), vec![false]).unwrap();
3673        let r2 = sum_with(&a, None, Some(100.0), Some(&false_mask.to_dyn())).unwrap();
3674        // Nothing kept, initial = 100 → result = 100
3675        assert!((r2.iter().next().copied().unwrap() - 100.0).abs() < 1e-12);
3676    }
3677
3678    #[test]
3679    fn into_reductions_match_allocating_versions_3d_axis_2() {
3680        // Cross-check every *_into against its allocating sibling on a
3681        // 3-D input, axis = last dim — guards against an off-by-one in
3682        // the in-place index walker.
3683        use ferray_core::Ix3;
3684        let data: Vec<f64> = (0..24).map(|i| f64::from(i) + 0.5).collect();
3685        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
3686
3687        let s_alloc = sum(&a, Some(2)).unwrap();
3688        let mut s_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3689        sum_into(&a, Some(2), &mut s_out).unwrap();
3690        assert_eq!(s_alloc.as_slice().unwrap(), s_out.as_slice().unwrap());
3691
3692        let p_alloc = prod(&a, Some(2)).unwrap();
3693        let mut p_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3694        prod_into(&a, Some(2), &mut p_out).unwrap();
3695        assert_eq!(p_alloc.as_slice().unwrap(), p_out.as_slice().unwrap());
3696
3697        let mn_alloc = min(&a, Some(2)).unwrap();
3698        let mut mn_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3699        min_into(&a, Some(2), &mut mn_out).unwrap();
3700        assert_eq!(mn_alloc.as_slice().unwrap(), mn_out.as_slice().unwrap());
3701
3702        let mx_alloc = max(&a, Some(2)).unwrap();
3703        let mut mx_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3704        max_into(&a, Some(2), &mut mx_out).unwrap();
3705        assert_eq!(mx_alloc.as_slice().unwrap(), mx_out.as_slice().unwrap());
3706
3707        let me_alloc = mean(&a, Some(2)).unwrap();
3708        let mut me_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3709        mean_into(&a, Some(2), &mut me_out).unwrap();
3710        assert_eq!(me_alloc.as_slice().unwrap(), me_out.as_slice().unwrap());
3711
3712        let v_alloc = var(&a, Some(2), 0).unwrap();
3713        let mut v_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3714        var_into(&a, Some(2), 0, &mut v_out).unwrap();
3715        // var output may have small numerical drift between the two-pass
3716        // and one-shot kernels — compare element-wise within tolerance.
3717        for (a, b) in v_alloc
3718            .as_slice()
3719            .unwrap()
3720            .iter()
3721            .zip(v_out.as_slice().unwrap())
3722        {
3723            assert!((a - b).abs() < 1e-10);
3724        }
3725
3726        let sd_alloc = std_(&a, Some(2), 0).unwrap();
3727        let mut sd_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3728        std_into(&a, Some(2), 0, &mut sd_out).unwrap();
3729        for (a, b) in sd_alloc
3730            .as_slice()
3731            .unwrap()
3732            .iter()
3733            .zip(sd_out.as_slice().unwrap())
3734        {
3735            assert!((a - b).abs() < 1e-10);
3736        }
3737    }
3738
3739    // -- ptp --
3740
3741    #[test]
3742    fn test_ptp_1d() {
3743        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 5.0, 3.0, 9.0, 2.0]).unwrap();
3744        let r = ptp(&a, None).unwrap();
3745        assert_eq!(r.iter().copied().next().unwrap(), 8.0);
3746    }
3747
3748    #[test]
3749    fn test_ptp_2d_axis() {
3750        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 3.0, 7.0, 2.0, 9.0])
3751            .unwrap();
3752        let r = ptp(&a, Some(1)).unwrap();
3753        // row 0: max=5, min=1 → 4; row 1: max=9, min=2 → 7
3754        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![4.0, 7.0]);
3755    }
3756
3757    #[test]
3758    fn test_ptp_empty_errs() {
3759        let a: Array<f64, Ix1> = Array::from_vec(Ix1::new([0]), vec![]).unwrap();
3760        assert!(ptp(&a, None).is_err());
3761    }
3762
3763    // -- average --
3764
3765    #[test]
3766    fn test_average_unweighted_matches_mean() {
3767        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3768        let r = average(&a, None, None).unwrap();
3769        assert!((r.iter().copied().next().unwrap() - 2.5).abs() < 1e-12);
3770    }
3771
3772    #[test]
3773    fn test_average_weighted() {
3774        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3775        let w = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 1.0, 1.0, 7.0]).unwrap();
3776        // (1+2+3+4*7) / (1+1+1+7) = 34/10 = 3.4
3777        let r = average(&a, Some(&w), None).unwrap();
3778        assert!((r.iter().copied().next().unwrap() - 3.4).abs() < 1e-12);
3779    }
3780
3781    #[test]
3782    fn test_average_weighted_axis() {
3783        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3784            .unwrap();
3785        let w = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
3786            .unwrap();
3787        // With uniform weights, axis-1 average = row mean = [2, 5]
3788        let r = average(&a, Some(&w), Some(1)).unwrap();
3789        let data: Vec<f64> = r.iter().copied().collect();
3790        assert!((data[0] - 2.0).abs() < 1e-12);
3791        assert!((data[1] - 5.0).abs() < 1e-12);
3792    }
3793
3794    #[test]
3795    fn test_average_weights_zero_sum_errs() {
3796        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
3797        let w = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.0, 0.0, 0.0]).unwrap();
3798        assert!(average(&a, Some(&w), None).is_err());
3799    }
3800
3801    #[test]
3802    fn test_average_shape_mismatch_errs() {
3803        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
3804        let w = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
3805        assert!(average(&a, Some(&w), None).is_err());
3806    }
3807
3808    // ----------------------------------------------------------------------
3809    // var_as_f64 / std_as_f64 integer promotion (#170)
3810    // ----------------------------------------------------------------------
3811
3812    #[test]
3813    fn var_as_f64_promotes_int_input() {
3814        let a = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3815        let v = var_as_f64(&a, None, 0).unwrap();
3816        // var([1,2,3,4,5]) with ddof=0 = 2.0
3817        assert!((v.iter().next().unwrap() - 2.0).abs() < 1e-12);
3818    }
3819
3820    #[test]
3821    fn var_as_f64_ddof_1() {
3822        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3823        let v = var_as_f64(&a, None, 1).unwrap();
3824        // var([1,2,3,4,5]) with ddof=1 = 2.5
3825        assert!((v.iter().next().unwrap() - 2.5).abs() < 1e-12);
3826    }
3827
3828    #[test]
3829    fn std_as_f64_promotes_int_input() {
3830        let a = Array::<u32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3831        let s = std_as_f64(&a, None, 0).unwrap();
3832        // std([1,2,3,4,5]) with ddof=0 = sqrt(2.0)
3833        assert!((s.iter().next().unwrap() - 2.0_f64.sqrt()).abs() < 1e-12);
3834    }
3835
3836    #[test]
3837    fn var_as_f64_axis_2d() {
3838        use ferray_core::dimension::Ix2;
3839        let a = Array::<i64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
3840        // var along axis 0 of [[1,2,3],[4,5,6]] is [2.25, 2.25, 2.25]
3841        let v = var_as_f64(&a, Some(0), 0).unwrap();
3842        for x in v.iter() {
3843            assert!((x - 2.25).abs() < 1e-12);
3844        }
3845    }
3846
3847    // ----------------------------------------------------------------------
3848    // f32 sibling tests (#185) — exercises the f32 SIMD path added in #173
3849    // and the generic Float-bound reduction paths.
3850    // ----------------------------------------------------------------------
3851
3852    fn arr1d_f32(data: Vec<f32>) -> Array<f32, Ix1> {
3853        Array::<f32, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
3854    }
3855
3856    #[test]
3857    fn sum_f32_basic() {
3858        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3859        let s = sum(&a, None).unwrap();
3860        assert!((s.iter().next().copied().unwrap() - 15.0).abs() < 1e-6);
3861    }
3862
3863    #[test]
3864    fn sum_f32_large_for_simd() {
3865        // Big enough to actually exercise the SIMD pairwise sum kernel.
3866        let n = 4096;
3867        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
3868        let a = arr1d_f32(data);
3869        let s = sum(&a, None).unwrap();
3870        let expected = 0.1 * (n as f32) * ((n - 1) as f32) / 2.0;
3871        let got = s.iter().copied().next().unwrap();
3872        assert!((got - expected).abs() / expected < 1e-4);
3873    }
3874
3875    #[test]
3876    fn mean_f32_basic() {
3877        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3878        let m = mean(&a, None).unwrap();
3879        assert!((m.iter().next().copied().unwrap() - 3.0).abs() < 1e-6);
3880    }
3881
3882    #[test]
3883    fn var_f32_basic() {
3884        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3885        let v = var(&a, None, 0).unwrap();
3886        assert!((v.iter().next().copied().unwrap() - 2.0).abs() < 1e-6);
3887    }
3888
3889    #[test]
3890    fn var_f32_large_for_simd() {
3891        // Exercises the simd_sum_sq_diff_f32 kernel.
3892        let n = 4096;
3893        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
3894        let a = arr1d_f32(data);
3895        let v = var(&a, None, 0).unwrap();
3896        let expected = 0.25 * ((n * n - 1) as f32) / 12.0;
3897        let got = v.iter().copied().next().unwrap();
3898        assert!((got - expected).abs() / expected < 1e-3);
3899    }
3900
3901    #[test]
3902    fn std_f32_basic() {
3903        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3904        let s = std_(&a, None, 0).unwrap();
3905        assert!((s.iter().next().copied().unwrap() - 2.0_f32.sqrt()).abs() < 1e-6);
3906    }
3907
3908    #[test]
3909    fn min_max_f32_basic() {
3910        let a = arr1d_f32(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
3911        let mn = min(&a, None).unwrap();
3912        let mx = max(&a, None).unwrap();
3913        assert_eq!(mn.iter().next().copied().unwrap(), 1.0);
3914        assert_eq!(mx.iter().next().copied().unwrap(), 9.0);
3915    }
3916
3917    #[test]
3918    fn prod_f32_basic() {
3919        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0]);
3920        let p = prod(&a, None).unwrap();
3921        assert!((p.iter().next().copied().unwrap() - 24.0).abs() < 1e-6);
3922    }
3923
3924    #[test]
3925    fn argmin_argmax_f32_basic() {
3926        let a = arr1d_f32(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
3927        let amin = argmin(&a, None).unwrap();
3928        let amax = argmax(&a, None).unwrap();
3929        assert_eq!(amin.iter().next().copied().unwrap(), 1);
3930        assert_eq!(amax.iter().next().copied().unwrap(), 5);
3931    }
3932
3933    #[test]
3934    fn cumsum_f32_basic() {
3935        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0]);
3936        let c = cumsum(&a, None).unwrap();
3937        let v: Vec<f32> = c.iter().copied().collect();
3938        for (got, expected) in v.iter().zip(&[1.0_f32, 3.0, 6.0, 10.0]) {
3939            assert!((got - expected).abs() < 1e-6);
3940        }
3941    }
3942}