Skip to main content

ferray_stats/reductions/
mod.rs

1// ferray-stats: Core reductions — sum, prod, min, max, argmin, argmax, mean, var, std (REQ-1, REQ-2)
2//
3// `cumulative` used to exist as an empty placeholder whose only content
4// was a note that cumulative functions live in nan_aware / mod.rs; it's
5// been deleted to stop pretending there's a dedicated module there (#162).
6
7pub mod nan_aware;
8pub mod quantile;
9
10use std::any::TypeId;
11
12use ferray_core::error::{FerrayError, FerrayResult};
13use ferray_core::{Array, Dimension, Element, IxDyn};
14use num_traits::Float;
15
16use crate::parallel;
17
/// Try SIMD-accelerated fused sum of squared differences for f64 or f32 (#173).
/// Returns sum((x - mean)²) without allocating an intermediate Vec.
///
/// Returns `None` when `T` is neither `f64` nor `f32`; the caller then falls
/// back to its generic scalar path. The runtime `TypeId` checks are what make
/// the pointer casts and `transmute_copy` calls below sound: inside each
/// branch `T` is statically known to be exactly that float type, so sizes
/// and layouts match.
#[inline]
fn try_simd_sum_sq_diff<T: Element + Copy + 'static>(data: &[T], mean: T) -> Option<T> {
    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: TypeId check guarantees T is f64. size_of::<T>() == size_of::<f64>().
        let f64_slice =
            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f64>(), data.len()) };
        // SAFETY: T is f64 (checked above), so this is a like-for-like bit copy.
        let mean_f64: f64 = unsafe { std::mem::transmute_copy(&mean) };
        let result = parallel::simd_sum_sq_diff_f64(f64_slice, mean_f64);
        // SAFETY: converts the f64 result back to T, where T == f64.
        Some(unsafe { std::mem::transmute_copy(&result) })
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: TypeId check guarantees T is f32. size_of::<T>() == size_of::<f32>().
        let f32_slice =
            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), data.len()) };
        // SAFETY: T is f32 (checked above), so this is a like-for-like bit copy.
        let mean_f32: f32 = unsafe { std::mem::transmute_copy(&mean) };
        let result = parallel::simd_sum_sq_diff_f32(f32_slice, mean_f32);
        // SAFETY: converts the f32 result back to T, where T == f32.
        Some(unsafe { std::mem::transmute_copy(&result) })
    } else {
        // Not a SIMD-supported element type; signal the caller to use
        // the generic path.
        None
    }
}
40
/// Try SIMD-accelerated pairwise sum for f64 or f32 (#173).
/// Returns the sum transmuted back to T, or None if T is not f64/f32.
///
/// Same specialization trick as [`try_simd_sum_sq_diff`]: the `TypeId`
/// comparisons prove at runtime which concrete float type `T` is, which is
/// what justifies the raw-pointer reinterpretation and the `transmute_copy`
/// of the result back into `T`.
#[inline]
fn try_simd_pairwise_sum<T: Element + Copy + 'static>(data: &[T]) -> Option<T> {
    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: TypeId check guarantees T is f64. size_of::<T>() == size_of::<f64>().
        let f64_slice =
            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f64>(), data.len()) };
        let result = parallel::pairwise_sum_f64(f64_slice);
        // SAFETY: converts the f64 result back to T, where T == f64.
        Some(unsafe { std::mem::transmute_copy(&result) })
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: TypeId check guarantees T is f32. size_of::<T>() == size_of::<f32>().
        let f32_slice =
            unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<f32>(), data.len()) };
        let result = parallel::pairwise_sum_f32(f32_slice);
        // SAFETY: converts the f32 result back to T, where T == f32.
        Some(unsafe { std::mem::transmute_copy(&result) })
    } else {
        // No SIMD kernel for this element type; caller falls back.
        None
    }
}
61
62// ---------------------------------------------------------------------------
63// Internal axis-reduction helper
64// ---------------------------------------------------------------------------
65
/// Row-major (C-order) strides for `shape`: the last axis always has
/// stride 1 and each earlier stride is the product of all later dims.
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;
    // Walk from the fastest-moving (last) axis outward, accumulating the
    // product of trailing dimensions.
    for (stride, &dim) in strides.iter_mut().rev().zip(shape.iter().rev()) {
        *stride = running;
        running *= dim;
    }
    strides
}
75
/// Flat (linear) offset of a multi-index under the given row-major strides.
pub(crate) fn flat_index(multi: &[usize], strides: &[usize]) -> usize {
    let mut flat = 0usize;
    for (&i, &s) in multi.iter().zip(strides) {
        flat += i * s;
    }
    flat
}
80
/// Advance a multi-index one step in row-major order (last axis fastest).
/// Returns `false` once the index wraps past the end of `shape`.
pub(crate) fn increment_multi_index(multi: &mut [usize], shape: &[usize]) -> bool {
    // Carry-propagate from the innermost axis outward.
    for (idx, &dim) in multi.iter_mut().zip(shape).rev() {
        *idx += 1;
        if *idx < dim {
            return true;
        }
        *idx = 0; // wrapped this axis; carry into the next one out
    }
    false
}
92
93/// General axis reduction parameterized on the output element type.
94///
95/// Walks `data` (in row-major order with `shape`), gathers each lane
96/// along `axis` into a temporary `Vec<T>`, and calls `f` to collapse it
97/// to an output of type `U`. Returns the concatenated outputs in
98/// row-major order for the shape with `axis` removed.
99///
100/// Used as the shared backbone for `reduce_axis_general` (T=T path) and
101/// `reduce_axis_general_u64` (T=T, U=u64 path). The two used to have
102/// copy-pasted bodies differing only in return type (#161).
103pub(crate) fn reduce_axis_typed<T, U, F>(data: &[T], shape: &[usize], axis: usize, f: F) -> Vec<U>
104where
105    T: Copy,
106    F: Fn(&[T]) -> U,
107{
108    let ndim = shape.len();
109    let axis_len = shape[axis];
110    let strides = compute_strides(shape);
111
112    // Output shape: shape with axis removed
113    let out_shape: Vec<usize> = shape
114        .iter()
115        .enumerate()
116        .filter(|&(i, _)| i != axis)
117        .map(|(_, &s)| s)
118        .collect();
119    let out_size: usize = if out_shape.is_empty() {
120        1
121    } else {
122        out_shape.iter().product()
123    };
124
125    let mut result = Vec::with_capacity(out_size);
126    let mut out_multi = vec![0usize; out_shape.len()];
127    let mut in_multi = vec![0usize; ndim];
128    let mut lane_vec: Vec<T> = Vec::with_capacity(axis_len);
129
130    for _ in 0..out_size {
131        // Build input multi-index by inserting axis position
132        let mut out_dim = 0;
133        for (d, idx) in in_multi.iter_mut().enumerate() {
134            if d == axis {
135                *idx = 0;
136            } else {
137                *idx = out_multi[out_dim];
138                out_dim += 1;
139            }
140        }
141
142        // Gather lane values
143        lane_vec.clear();
144        for k in 0..axis_len {
145            in_multi[axis] = k;
146            let idx = flat_index(&in_multi, &strides);
147            lane_vec.push(data[idx]);
148        }
149
150        result.push(f(&lane_vec));
151
152        // Increment output multi-index
153        if !out_shape.is_empty() {
154            increment_multi_index(&mut out_multi, &out_shape);
155        }
156    }
157
158    result
159}
160
161/// Thin wrapper: `T -> T` reduction. Preserved for the many call sites
162/// that pass `Fn(&[T]) -> T` kernels.
163#[inline]
164pub(crate) fn reduce_axis_general<T, F>(data: &[T], shape: &[usize], axis: usize, f: F) -> Vec<T>
165where
166    T: Copy,
167    F: Fn(&[T]) -> T,
168{
169    reduce_axis_typed(data, shape, axis, f)
170}
171
172/// In-place variant of [`reduce_axis_typed`]: writes results directly into
173/// `dst` without allocating a fresh `Vec<U>` to hold the output.
174///
175/// Used by the `*_into` reductions (#563) so callers that pre-allocate an
176/// output buffer truly avoid every per-call allocation. The destination
177/// slice must already have exactly `out_size` elements where `out_size` is
178/// the product of `shape` with `axis` removed (or 1 if the result is
179/// 0-D); callers should validate this via [`check_out_shape`] before
180/// invoking the kernel.
181pub(crate) fn reduce_axis_typed_into<T, U, F>(
182    data: &[T],
183    shape: &[usize],
184    axis: usize,
185    dst: &mut [U],
186    f: F,
187) where
188    T: Copy,
189    F: Fn(&[T]) -> U,
190{
191    let ndim = shape.len();
192    let axis_len = shape[axis];
193    let strides = compute_strides(shape);
194
195    // Output multi-index walks the shape with `axis` removed.
196    let out_shape: Vec<usize> = shape
197        .iter()
198        .enumerate()
199        .filter(|&(i, _)| i != axis)
200        .map(|(_, &s)| s)
201        .collect();
202
203    debug_assert_eq!(
204        dst.len(),
205        if out_shape.is_empty() {
206            1
207        } else {
208            out_shape.iter().product::<usize>()
209        },
210        "reduce_axis_typed_into: dst length must match the reduction's output size"
211    );
212
213    let mut out_multi = vec![0usize; out_shape.len()];
214    let mut in_multi = vec![0usize; ndim];
215    // Reused per-lane buffer — one allocation total instead of one per
216    // lane the way `reduce_axis_typed` does it via `Vec::push`.
217    let mut lane_vec: Vec<T> = Vec::with_capacity(axis_len);
218
219    for slot in dst.iter_mut() {
220        // Build the input multi-index by inserting the axis position.
221        let mut out_dim = 0;
222        for (d, idx) in in_multi.iter_mut().enumerate() {
223            if d == axis {
224                *idx = 0;
225            } else {
226                *idx = out_multi[out_dim];
227                out_dim += 1;
228            }
229        }
230
231        // Gather lane values into the reusable buffer.
232        lane_vec.clear();
233        for k in 0..axis_len {
234            in_multi[axis] = k;
235            let idx = flat_index(&in_multi, &strides);
236            lane_vec.push(data[idx]);
237        }
238
239        *slot = f(&lane_vec);
240
241        if !out_shape.is_empty() {
242            increment_multi_index(&mut out_multi, &out_shape);
243        }
244    }
245}
246
247/// In-place T-to-T reduction wrapper around [`reduce_axis_typed_into`].
248#[inline]
249pub(crate) fn reduce_axis_general_into<T, F>(
250    data: &[T],
251    shape: &[usize],
252    axis: usize,
253    dst: &mut [T],
254    f: F,
255) where
256    T: Copy,
257    F: Fn(&[T]) -> T,
258{
259    reduce_axis_typed_into(data, shape, axis, dst, f);
260}
261
262/// Validate axis parameter and return an error if out of bounds.
263pub(crate) const fn validate_axis(axis: usize, ndim: usize) -> FerrayResult<()> {
264    if axis >= ndim {
265        Err(FerrayError::axis_out_of_bounds(axis, ndim))
266    } else {
267        Ok(())
268    }
269}
270
271/// Collect array data into a contiguous Vec in logical (row-major) order.
272pub(crate) fn collect_data<T: Element + Copy, D: Dimension>(a: &Array<T, D>) -> Vec<T> {
273    a.iter().copied().collect()
274}
275
/// Borrow contiguous data or copy if strided. Avoids allocation for contiguous arrays.
pub(crate) enum DataRef<'a, T> {
    /// Zero-copy view into an array whose backing memory is already contiguous.
    Borrowed(&'a [T]),
    /// Owned copy, made when the source array is strided / non-contiguous.
    Owned(Vec<T>),
}
281
282impl<T> std::ops::Deref for DataRef<'_, T> {
283    type Target = [T];
284    fn deref(&self) -> &[T] {
285        match self {
286            DataRef::Borrowed(s) => s,
287            DataRef::Owned(v) => v,
288        }
289    }
290}
291
292/// Get a reference to contiguous data, or copy if strided.
293pub(crate) fn borrow_data<T: Element + Copy, D: Dimension>(a: &Array<T, D>) -> DataRef<'_, T> {
294    if let Some(slice) = a.as_slice() {
295        DataRef::Borrowed(slice)
296    } else {
297        DataRef::Owned(a.iter().copied().collect())
298    }
299}
300
301/// Build an `IxDyn` result array from output shape and data.
302pub(crate) fn make_result<T: Element>(
303    out_shape: &[usize],
304    data: Vec<T>,
305) -> FerrayResult<Array<T, IxDyn>> {
306    Array::from_vec(IxDyn::new(out_shape), data)
307}
308
309/// Validate that `out` has the expected shape and is C-contiguous,
310/// returning a mutable slice into its backing buffer.
311///
312/// Shared by every `*_into` reduction (#467, #563) so the validation
313/// surface lives in exactly one place. Broadcasting is intentionally not
314/// allowed because the destination shape is fixed by the input + axis
315/// combination — accepting a mismatched destination would silently
316/// re-route to a different reduction shape.
317pub(crate) fn check_out_shape<'a, T: Element + Copy>(
318    out: &'a mut Array<T, IxDyn>,
319    expected_shape: &[usize],
320    op_name: &str,
321) -> FerrayResult<&'a mut [T]> {
322    if out.shape() != expected_shape {
323        return Err(FerrayError::shape_mismatch(format!(
324            "{op_name}_into: out shape {:?} does not match expected reduction shape {:?}",
325            out.shape(),
326            expected_shape
327        )));
328    }
329    out.as_slice_mut().ok_or_else(|| {
330        FerrayError::invalid_value(format!("{op_name}_into: out must be C-contiguous"))
331    })
332}
333
/// Compute the output shape when reducing along a single axis: the input
/// shape with that axis removed.
pub(crate) fn output_shape(shape: &[usize], axis: usize) -> Vec<usize> {
    let mut out = Vec::with_capacity(shape.len().saturating_sub(1));
    for (d, &s) in shape.iter().enumerate() {
        if d != axis {
            out.push(s);
        }
    }
    out
}
343
344// ---------------------------------------------------------------------------
345// Multi-axis reduction helpers (issues #457 + #458)
346// ---------------------------------------------------------------------------
347
348/// Normalize the `axes: Option<&[usize]>` argument of a multi-axis reduction:
349///
350/// - `None` or `Some(&[])` expands to all axes `[0..ndim]` (reduce everything)
351/// - Duplicate axes are an error (matches `NumPy`'s `np.sum(a, axis=(0, 0))`)
352/// - Any out-of-bounds axis is an error
353///
354/// Returns the sorted, unique axis list.
355pub(crate) fn normalize_axes(axes: Option<&[usize]>, ndim: usize) -> FerrayResult<Vec<usize>> {
356    let ax: Vec<usize> = match axes {
357        None | Some([]) => (0..ndim).collect(),
358        Some(s) => s.to_vec(),
359    };
360    for &a in &ax {
361        if a >= ndim {
362            return Err(FerrayError::axis_out_of_bounds(a, ndim));
363        }
364    }
365    let mut sorted = ax;
366    sorted.sort_unstable();
367    for w in sorted.windows(2) {
368        if w[0] == w[1] {
369            return Err(FerrayError::invalid_value(format!(
370                "duplicate axis {} in reduction axes",
371                w[0]
372            )));
373        }
374    }
375    Ok(sorted)
376}
377
/// Compute the output shape when reducing over multiple axes.
///
/// `axes` must be sorted and unique (produced by [`normalize_axes`]).
/// With `keepdims = true`, reduced axes collapse to size 1; otherwise
/// they are dropped entirely.
pub(crate) fn output_shape_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Vec<usize> {
    let mut out = Vec::with_capacity(shape.len());
    for (d, &s) in shape.iter().enumerate() {
        let reduced = axes.contains(&d);
        if keepdims {
            out.push(if reduced { 1 } else { s });
        } else if !reduced {
            out.push(s);
        }
    }
    out
}
399
/// Reduce over a set of axes, returning `(result_data, output_shape)`.
///
/// `axes` must be sorted and unique (typically produced by [`normalize_axes`]).
/// The caller supplies a reduction function `f` that collapses a lane of
/// reduced-axis values to a single scalar. Each lane gathers its values in
/// row-major order over the reduced axes; kept-axis cells are visited in
/// row-major order over the kept axes.
///
/// The output shape honors `keepdims`.
pub(crate) fn reduce_axes_general<T: Copy, F: Fn(&[T]) -> T>(
    data: &[T],
    shape: &[usize],
    axes: &[usize],
    keepdims: bool,
    f: F,
) -> (Vec<T>, Vec<usize>) {
    let ndim = shape.len();
    let strides = compute_strides(shape);

    // Partition axes into reduce / keep.
    let is_reduce: Vec<bool> = (0..ndim).map(|i| axes.contains(&i)).collect();
    let keep_axes: Vec<usize> = (0..ndim).filter(|i| !is_reduce[*i]).collect();
    let reduce_axes: Vec<usize> = (0..ndim).filter(|i| is_reduce[*i]).collect();
    let keep_shape: Vec<usize> = keep_axes.iter().map(|&i| shape[i]).collect();
    let reduce_shape: Vec<usize> = reduce_axes.iter().map(|&i| shape[i]).collect();

    // Number of output cells / lane length; an empty partition means "1"
    // (reduce-everything produces one cell; reduce-nothing has 1-long lanes).
    let out_size: usize = if keep_shape.is_empty() {
        1
    } else {
        keep_shape.iter().product()
    };
    let lane_size: usize = if reduce_shape.is_empty() {
        1
    } else {
        reduce_shape.iter().product()
    };

    let out_shape = output_shape_axes(shape, axes, keepdims);

    let mut result = Vec::with_capacity(out_size);
    // Lane buffer reused for every output cell (one allocation total).
    let mut lane: Vec<T> = Vec::with_capacity(lane_size);
    let mut keep_multi = vec![0usize; keep_shape.len()];
    let mut reduce_multi = vec![0usize; reduce_shape.len()];
    let mut full_multi = vec![0usize; ndim];

    for _ in 0..out_size {
        // Fill kept-axis positions from keep_multi.
        for (i, &ax) in keep_axes.iter().enumerate() {
            full_multi[ax] = keep_multi[i];
        }

        // Gather values along all reduced axes (row-major over reduce_shape).
        lane.clear();
        reduce_multi.fill(0);
        for _ in 0..lane_size {
            for (i, &ax) in reduce_axes.iter().enumerate() {
                full_multi[ax] = reduce_multi[i];
            }
            lane.push(data[flat_index(&full_multi, &strides)]);
            if !reduce_shape.is_empty() {
                increment_multi_index(&mut reduce_multi, &reduce_shape);
            }
        }

        result.push(f(&lane));

        // Step to the next output cell.
        if !keep_shape.is_empty() {
            increment_multi_index(&mut keep_multi, &keep_shape);
        }
    }

    (result, out_shape)
}
471
472// ---------------------------------------------------------------------------
473// sum
474// ---------------------------------------------------------------------------
475
476/// Sum of array elements over a given axis, or over all elements if axis is None.
477///
478/// Equivalent to `numpy.sum`.
479///
480/// **Note:** Unlike `NumPy`, which auto-promotes `int32` sums to `int64`,
481/// ferray returns the same type as the input. For large integer arrays
482/// this may overflow. Use [`sum_as_f64`] for overflow-safe integer summation.
483///
484/// # Examples
485/// ```ignore
486/// let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
487/// let s = sum(&a, None).unwrap();
488/// assert_eq!(s.iter().next(), Some(&10.0));
489/// ```
490pub fn sum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
491where
492    T: Element + std::ops::Add<Output = T> + Copy + Send + Sync,
493    D: Dimension,
494{
495    let data = borrow_data(a);
496    match axis {
497        None => {
498            let total = try_simd_pairwise_sum(&data)
499                .unwrap_or_else(|| parallel::parallel_sum(&data, <T as Element>::zero()));
500            make_result(&[], vec![total])
501        }
502        Some(ax) => {
503            validate_axis(ax, a.ndim())?;
504            let shape = a.shape();
505            let out_s = output_shape(shape, ax);
506            let result = reduce_axis_general(&data, shape, ax, |lane| {
507                try_simd_pairwise_sum(lane)
508                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
509            });
510            make_result(&out_s, result)
511        }
512    }
513}
514
515/// Sum of array elements, returning `f64` regardless of input type.
516///
517/// This works on integer arrays (i32, u64, etc.) without overflow risk.
518/// The result is always `Array<f64, IxDyn>`, matching `NumPy`'s behavior
519/// of promoting integer sums to a wider type.
520pub fn sum_as_f64<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<f64, IxDyn>>
521where
522    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
523    D: Dimension,
524{
525    match axis {
526        None => {
527            let total: f64 = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).sum();
528            make_result(&[], vec![total])
529        }
530        Some(ax) => {
531            validate_axis(ax, a.ndim())?;
532            let shape = a.shape();
533            let out_s = output_shape(shape, ax);
534            let f64_data: Vec<f64> = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).collect();
535            let result = reduce_axis_general(&f64_data, shape, ax, |lane| lane.iter().sum());
536            make_result(&out_s, result)
537        }
538    }
539}
540
541// ---------------------------------------------------------------------------
542// prod
543// ---------------------------------------------------------------------------
544
545/// Product of array elements over a given axis.
546///
547/// **Note:** Unlike `NumPy`, which auto-promotes integer products,
548/// ferray returns the same type as the input. For large integer arrays
549/// this may overflow.
550/// Equivalent to `numpy.prod`.
551pub fn prod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
552where
553    T: Element + std::ops::Mul<Output = T> + Copy + Send + Sync,
554    D: Dimension,
555{
556    let data = borrow_data(a);
557    match axis {
558        None => {
559            let total = parallel::parallel_prod(&data, <T as Element>::one());
560            make_result(&[], vec![total])
561        }
562        Some(ax) => {
563            validate_axis(ax, a.ndim())?;
564            let shape = a.shape();
565            let out_s = output_shape(shape, ax);
566            let result = reduce_axis_general(&data, shape, ax, |lane| {
567                lane.iter()
568                    .copied()
569                    .fold(<T as Element>::one(), |acc, x| acc * x)
570            });
571            make_result(&out_s, result)
572        }
573    }
574}
575
576// ---------------------------------------------------------------------------
577// min / max
578// ---------------------------------------------------------------------------
579
580/// Minimum value of array elements over a given axis.
581///
582/// Equivalent to `numpy.min` / `numpy.amin`.
583pub fn min<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
584where
585    T: Element + PartialOrd + Copy,
586    D: Dimension,
587{
588    if a.is_empty() {
589        return Err(FerrayError::invalid_value(
590            "cannot compute min of empty array",
591        ));
592    }
593    // NaN-propagating min: if either operand is NaN (comparison returns false
594    // for both orderings), propagate NaN to match NumPy behavior.
595    let nan_min = |a: T, b: T| -> T {
596        if a <= b {
597            a
598        } else if a > b {
599            b
600        } else {
601            // One of them is NaN — return whichever is unordered
602            // (if a is NaN, a <= b and a > b are both false; return a)
603            a
604        }
605    };
606    let data = borrow_data(a);
607    match axis {
608        None => {
609            let m = data.iter().copied().reduce(nan_min).unwrap();
610            make_result(&[], vec![m])
611        }
612        Some(ax) => {
613            validate_axis(ax, a.ndim())?;
614            let shape = a.shape();
615            let out_s = output_shape(shape, ax);
616            let result = reduce_axis_general(&data, shape, ax, |lane| {
617                lane.iter().copied().reduce(nan_min).unwrap()
618            });
619            make_result(&out_s, result)
620        }
621    }
622}
623
624/// Maximum value of array elements over a given axis.
625///
626/// Equivalent to `numpy.max` / `numpy.amax`.
627pub fn max<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
628where
629    T: Element + PartialOrd + Copy,
630    D: Dimension,
631{
632    if a.is_empty() {
633        return Err(FerrayError::invalid_value(
634            "cannot compute max of empty array",
635        ));
636    }
637    // NaN-propagating max: same logic as min but reversed ordering.
638    let nan_max = |a: T, b: T| -> T {
639        if a >= b {
640            a
641        } else if a < b {
642            b
643        } else {
644            a
645        }
646    };
647    let data = borrow_data(a);
648    match axis {
649        None => {
650            let m = data.iter().copied().reduce(nan_max).unwrap();
651            make_result(&[], vec![m])
652        }
653        Some(ax) => {
654            validate_axis(ax, a.ndim())?;
655            let shape = a.shape();
656            let out_s = output_shape(shape, ax);
657            let result = reduce_axis_general(&data, shape, ax, |lane| {
658                lane.iter().copied().reduce(nan_max).unwrap()
659            });
660            make_result(&out_s, result)
661        }
662    }
663}
664
665// ---------------------------------------------------------------------------
666// argmin / argmax
667// ---------------------------------------------------------------------------
668
669/// Index of the minimum value. For axis=None, returns the flat index.
670/// For axis=Some(ax), returns indices along that axis.
671///
672/// Equivalent to `numpy.argmin`.
673pub fn argmin<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
674where
675    T: Element + PartialOrd + Copy,
676    D: Dimension,
677{
678    if a.is_empty() {
679        return Err(FerrayError::invalid_value(
680            "cannot compute argmin of empty array",
681        ));
682    }
683    let data = borrow_data(a);
684    match axis {
685        None => {
686            let idx = data
687                .iter()
688                .enumerate()
689                .reduce(|(ai, av), (bi, bv)| if av <= bv { (ai, av) } else { (bi, bv) })
690                .unwrap()
691                .0 as u64;
692            make_result(&[], vec![idx])
693        }
694        Some(ax) => {
695            validate_axis(ax, a.ndim())?;
696            let shape = a.shape();
697            let out_s = output_shape(shape, ax);
698            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
699                lane.iter()
700                    .enumerate()
701                    .reduce(|(ai, av), (bi, bv)| if av <= bv { (ai, av) } else { (bi, bv) })
702                    .unwrap()
703                    .0 as u64
704            });
705            make_result(&out_s, result)
706        }
707    }
708}
709
710/// Index of the maximum value.
711///
712/// Equivalent to `numpy.argmax`.
713pub fn argmax<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
714where
715    T: Element + PartialOrd + Copy,
716    D: Dimension,
717{
718    if a.is_empty() {
719        return Err(FerrayError::invalid_value(
720            "cannot compute argmax of empty array",
721        ));
722    }
723    let data = borrow_data(a);
724    match axis {
725        None => {
726            let idx = data
727                .iter()
728                .enumerate()
729                .reduce(|(ai, av), (bi, bv)| if av >= bv { (ai, av) } else { (bi, bv) })
730                .unwrap()
731                .0 as u64;
732            make_result(&[], vec![idx])
733        }
734        Some(ax) => {
735            validate_axis(ax, a.ndim())?;
736            let shape = a.shape();
737            let out_s = output_shape(shape, ax);
738            let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
739                lane.iter()
740                    .enumerate()
741                    .reduce(|(ai, av), (bi, bv)| if av >= bv { (ai, av) } else { (bi, bv) })
742                    .unwrap()
743                    .0 as u64
744            });
745            make_result(&out_s, result)
746        }
747    }
748}
749
750/// Thin wrapper: `T -> u64` reduction (for `argmin`/`argmax` /
751/// `bincount` paths). Shares its body with `reduce_axis_general`
752/// via `reduce_axis_typed`.
753#[inline]
754pub(crate) fn reduce_axis_general_u64<T, F>(
755    data: &[T],
756    shape: &[usize],
757    axis: usize,
758    f: F,
759) -> Vec<u64>
760where
761    T: Copy,
762    F: Fn(&[T]) -> u64,
763{
764    reduce_axis_typed(data, shape, axis, f)
765}
766
// ---------------------------------------------------------------------------
// ptp / average / mean
// ---------------------------------------------------------------------------
770
771/// Range (peak-to-peak) of array elements over a given axis.
772///
773/// Returns `max - min`. Analogous to `numpy.ptp(a, axis=...)`.
774///
775/// # Errors
776/// Returns `FerrayError::InvalidValue` if the array is empty.
777pub fn ptp<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
778where
779    T: Element + PartialOrd + Copy + std::ops::Sub<Output = T>,
780    D: Dimension,
781{
782    if a.is_empty() {
783        return Err(FerrayError::invalid_value(
784            "cannot compute ptp of empty array",
785        ));
786    }
787    let lo = min(a, axis)?;
788    let hi = max(a, axis)?;
789    let lo_data: Vec<T> = lo.iter().copied().collect();
790    let hi_data: Vec<T> = hi.iter().copied().collect();
791    let result: Vec<T> = hi_data
792        .into_iter()
793        .zip(lo_data)
794        .map(|(h, l)| h - l)
795        .collect();
796    make_result(lo.shape(), result)
797}
798
799/// Weighted average of array elements.
800///
801/// When `weights` is `None`, equivalent to [`mean`]. When `Some(w)`, computes
802/// `sum(a * w) / sum(w)` along the given axis (or over all elements when
803/// `axis=None`). The weights array must have the same shape as `a` (the
804/// 1-D-along-axis broadcasting form supported by NumPy is intentionally
805/// omitted here — call broadcast yourself first if you need it).
806///
807/// Analogous to `numpy.average`.
808///
809/// # Errors
810/// - `FerrayError::InvalidValue` if the array is empty.
811/// - `FerrayError::ShapeMismatch` if `weights.shape() != a.shape()`.
812/// - `FerrayError::InvalidValue` if the weight sum along an axis is zero.
813pub fn average<T, D>(
814    a: &Array<T, D>,
815    weights: Option<&Array<T, D>>,
816    axis: Option<usize>,
817) -> FerrayResult<Array<T, IxDyn>>
818where
819    T: Element + Float + Send + Sync,
820    D: Dimension,
821{
822    let Some(w) = weights else {
823        return mean(a, axis);
824    };
825    if a.is_empty() {
826        return Err(FerrayError::invalid_value(
827            "cannot compute average of empty array",
828        ));
829    }
830    if a.shape() != w.shape() {
831        return Err(FerrayError::shape_mismatch(format!(
832            "average: weights shape {:?} differs from array shape {:?}",
833            w.shape(),
834            a.shape(),
835        )));
836    }
837    let a_data = borrow_data(a);
838    let w_data = borrow_data(w);
839    match axis {
840        None => {
841            let mut wsum = <T as Element>::zero();
842            let mut acc = <T as Element>::zero();
843            for (&x, &wi) in a_data.iter().zip(w_data.iter()) {
844                wsum = wsum + wi;
845                acc = acc + x * wi;
846            }
847            if wsum == <T as Element>::zero() {
848                return Err(FerrayError::invalid_value("average: weights sum to zero"));
849            }
850            make_result(&[], vec![acc / wsum])
851        }
852        Some(ax) => {
853            validate_axis(ax, a.ndim())?;
854            let shape = a.shape();
855            let out_s = output_shape(shape, ax);
856            // Walk lanes for both arrays; we can reuse reduce_axis_general
857            // by zipping data + weights into a tagged buffer. Simpler path:
858            // build a side buffer of (a_lane, w_lane) per lane.
859            let outer: usize = out_s.iter().product::<usize>().max(1);
860            let lane_len = shape[ax];
861            // Use the same lane-walk as reduce_axis_general. We replicate
862            // the inner loop here so we can read both data + weights.
863            let mut result = Vec::with_capacity(outer);
864            // Compute strides for picking out the lane.
865            let mut strides = vec![1usize; shape.len()];
866            for i in (0..shape.len() - 1).rev() {
867                strides[i] = strides[i + 1] * shape[i + 1];
868            }
869            let mut idx = vec![0usize; shape.len()];
870            for _ in 0..outer {
871                let mut wsum = <T as Element>::zero();
872                let mut acc = <T as Element>::zero();
873                for j in 0..lane_len {
874                    idx[ax] = j;
875                    let mut flat = 0usize;
876                    for (d, &s) in idx.iter().zip(strides.iter()) {
877                        flat += d * s;
878                    }
879                    let x = a_data[flat];
880                    let wi = w_data[flat];
881                    wsum = wsum + wi;
882                    acc = acc + x * wi;
883                }
884                if wsum == <T as Element>::zero() {
885                    return Err(FerrayError::invalid_value(
886                        "average: weights sum to zero along axis",
887                    ));
888                }
889                result.push(acc / wsum);
890                // Advance multi-index over output dims (every dim except ax).
891                idx[ax] = 0;
892                for d in (0..shape.len()).rev() {
893                    if d == ax {
894                        continue;
895                    }
896                    idx[d] += 1;
897                    if idx[d] < shape[d] {
898                        break;
899                    }
900                    idx[d] = 0;
901                }
902            }
903            make_result(&out_s, result)
904        }
905    }
906}
907
908/// Mean of array elements over a given axis.
909///
910/// Equivalent to `numpy.mean`. The result is always floating-point.
911pub fn mean<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
912where
913    T: Element + Float + Send + Sync,
914    D: Dimension,
915{
916    if a.is_empty() {
917        return Err(FerrayError::invalid_value(
918            "cannot compute mean of empty array",
919        ));
920    }
921    let data = borrow_data(a);
922    match axis {
923        None => {
924            let n = T::from(data.len()).unwrap();
925            let total = try_simd_pairwise_sum(&data)
926                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()));
927            make_result(&[], vec![total / n])
928        }
929        Some(ax) => {
930            validate_axis(ax, a.ndim())?;
931            let shape = a.shape();
932            let out_s = output_shape(shape, ax);
933            let axis_len = shape[ax];
934            let n = T::from(axis_len).unwrap();
935            let result = reduce_axis_general(&data, shape, ax, |lane| {
936                let total = try_simd_pairwise_sum(lane)
937                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
938                total / n
939            });
940            make_result(&out_s, result)
941        }
942    }
943}
944
945/// Mean of array elements, returning `f64` regardless of input type.
946///
947/// This works on integer arrays (i32, u64, etc.) where [`mean`] would
948/// fail because integers don't implement `Float`. The result is always
949/// `Array<f64, IxDyn>`, matching `NumPy`'s behavior of promoting integer
950/// means to float64.
951///
952/// Equivalent to `numpy.mean` for integer inputs.
953pub fn mean_as_f64<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<f64, IxDyn>>
954where
955    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
956    D: Dimension,
957{
958    if a.is_empty() {
959        return Err(FerrayError::invalid_value(
960            "cannot compute mean of empty array",
961        ));
962    }
963    match axis {
964        None => {
965            let n = a.size() as f64;
966            let total: f64 = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).sum();
967            make_result(&[], vec![total / n])
968        }
969        Some(ax) => {
970            validate_axis(ax, a.ndim())?;
971            let shape = a.shape();
972            let out_s = output_shape(shape, ax);
973            let axis_len = shape[ax] as f64;
974            let f64_data: Vec<f64> = a.iter().map(|x| x.to_f64().unwrap_or(0.0)).collect();
975            let result = reduce_axis_general(&f64_data, shape, ax, |lane| {
976                let total: f64 = lane.iter().sum();
977                total / axis_len
978            });
979            make_result(&out_s, result)
980        }
981    }
982}
983
984// ---------------------------------------------------------------------------
985// var
986// ---------------------------------------------------------------------------
987
/// Variance of array elements over a given axis.
///
/// `ddof` is the delta degrees of freedom (0 for population variance, 1 for sample).
/// Equivalent to `numpy.var`.
///
/// Implementation: the classic two-pass algorithm — compute the mean
/// first, then sum the squared deviations. Both passes use a SIMD fast
/// path when `T` is `f32`/`f64` and fall back to pairwise summation /
/// a scalar fold for other float types.
///
/// # Errors
/// - `FerrayError::InvalidValue` if the array is empty, or if `ddof >= n`
///   (n = total element count for `axis=None`, else the axis length).
/// - An axis-validation error from `validate_axis` if `axis` is out of range.
pub fn var<T, D>(a: &Array<T, D>, axis: Option<usize>, ddof: usize) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Float + Send + Sync,
    D: Dimension,
{
    if a.is_empty() {
        return Err(FerrayError::invalid_value(
            "cannot compute variance of empty array",
        ));
    }
    let data = borrow_data(a);
    match axis {
        None => {
            let n = data.len();
            if n <= ddof {
                return Err(FerrayError::invalid_value(
                    "ddof >= number of elements, variance undefined",
                ));
            }
            let nf = T::from(n).unwrap();
            // Pass 1: mean of all elements (SIMD when T is f32/f64).
            let mean_val = try_simd_pairwise_sum(&data)
                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()))
                / nf;
            // Pass 2: sum((x - mean)^2) — fused SIMD kernel when available,
            // otherwise a scalar fold.
            let sum_sq = try_simd_sum_sq_diff(&data, mean_val).unwrap_or_else(|| {
                data.iter().copied().fold(<T as Element>::zero(), |acc, x| {
                    let d = x - mean_val;
                    acc + d * d
                })
            });
            // Bessel-corrected denominator: n - ddof.
            let var_val = sum_sq / T::from(n - ddof).unwrap();
            make_result(&[], vec![var_val])
        }
        Some(ax) => {
            validate_axis(ax, a.ndim())?;
            let shape = a.shape();
            let out_s = output_shape(shape, ax);
            let axis_len = shape[ax];
            if axis_len <= ddof {
                return Err(FerrayError::invalid_value(
                    "ddof >= axis length, variance undefined",
                ));
            }
            let nf = T::from(axis_len).unwrap();
            let denom = T::from(axis_len - ddof).unwrap();
            // Same two-pass computation, applied independently to each lane
            // along the reduced axis.
            let result = reduce_axis_general(&data, shape, ax, |lane| {
                let mean_val = try_simd_pairwise_sum(lane)
                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
                    / nf;
                let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
                    lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
                        let d = x - mean_val;
                        acc + d * d
                    })
                });
                sum_sq / denom
            });
            make_result(&out_s, result)
        }
    }
}
1052
1053// ---------------------------------------------------------------------------
1054// std_
1055// ---------------------------------------------------------------------------
1056
1057/// Standard deviation of array elements over a given axis.
1058///
1059/// `ddof` is the delta degrees of freedom.
1060/// Equivalent to `numpy.std`.
1061pub fn std_<T, D>(
1062    a: &Array<T, D>,
1063    axis: Option<usize>,
1064    ddof: usize,
1065) -> FerrayResult<Array<T, IxDyn>>
1066where
1067    T: Element + Float + Send + Sync,
1068    D: Dimension,
1069{
1070    let v = var(a, axis, ddof)?;
1071    let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
1072    make_result(v.shape(), data)
1073}
1074
1075/// Variance with integer-to-float64 promotion.
1076///
1077/// Sibling of [`mean_as_f64`] — accepts any `T: ToPrimitive` (including
1078/// `i64`/`i32`/`u64`/etc.) and returns an `Array<f64, IxDyn>`. Matches
1079/// `NumPy`'s behaviour of promoting integer variance to f64 (#170).
1080///
1081/// `ddof` is the delta degrees of freedom (0 for population, 1 for sample).
1082pub fn var_as_f64<T, D>(
1083    a: &Array<T, D>,
1084    axis: Option<usize>,
1085    ddof: usize,
1086) -> FerrayResult<Array<f64, IxDyn>>
1087where
1088    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
1089    D: Dimension,
1090{
1091    if a.is_empty() {
1092        return Err(FerrayError::invalid_value(
1093            "cannot compute variance of empty array",
1094        ));
1095    }
1096    // Promote each element to f64 once, then run the standard
1097    // float-typed `var` on the promoted array.
1098    let promoted: Vec<f64> = a
1099        .iter()
1100        .map(|v| {
1101            v.to_f64()
1102                .expect("ToPrimitive failed during var_as_f64 promotion")
1103        })
1104        .collect();
1105    let promoted_arr = Array::<f64, _>::from_vec(a.dim().clone(), promoted)?;
1106    var(&promoted_arr, axis, ddof)
1107}
1108
1109/// Standard deviation with integer-to-float64 promotion (#170).
1110///
1111/// Sibling of [`mean_as_f64`] / [`var_as_f64`]; sqrt of the float-promoted
1112/// variance.
1113pub fn std_as_f64<T, D>(
1114    a: &Array<T, D>,
1115    axis: Option<usize>,
1116    ddof: usize,
1117) -> FerrayResult<Array<f64, IxDyn>>
1118where
1119    T: Element + Copy + Send + Sync + num_traits::ToPrimitive,
1120    D: Dimension,
1121{
1122    let v = var_as_f64(a, axis, ddof)?;
1123    let data: Vec<f64> = v.iter().map(|x| x.sqrt()).collect();
1124    make_result(v.shape(), data)
1125}
1126
1127// ---------------------------------------------------------------------------
1128// Multi-axis + keepdims public API (issues #457 + #458)
1129//
1130// These `*_axes` variants complement the existing single-axis reductions
1131// with two additional features:
1132//   1. `axes: Option<&[usize]>` — reduce over multiple axes at once
1133//      (None or empty slice means "reduce all axes")
1134//   2. `keepdims: bool` — preserve reduced axes as size 1 so the result
1135//      can broadcast back against the original array
1136//
1137// The existing single-axis `sum/prod/etc.` functions remain unchanged.
1138// ---------------------------------------------------------------------------
1139
1140/// Multi-axis sum with optional `keepdims`.
1141///
1142/// Equivalent to `numpy.sum(a, axis=axes, keepdims=keepdims)`. If `axes` is
1143/// `None` or an empty slice, reduces over all axes.
1144///
1145/// # Errors
1146/// Returns `FerrayError::AxisOutOfBounds` on any out-of-range axis, or
1147/// `FerrayError::InvalidValue` on duplicate axes.
1148pub fn sum_axes<T, D>(
1149    a: &Array<T, D>,
1150    axes: Option<&[usize]>,
1151    keepdims: bool,
1152) -> FerrayResult<Array<T, IxDyn>>
1153where
1154    T: Element + std::ops::Add<Output = T> + Copy + Send + Sync,
1155    D: Dimension,
1156{
1157    let ax = normalize_axes(axes, a.ndim())?;
1158    let data = borrow_data(a);
1159    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1160        try_simd_pairwise_sum(lane)
1161            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1162    });
1163    make_result(&out_shape, result)
1164}
1165
1166/// Multi-axis product with optional `keepdims`.
1167pub fn prod_axes<T, D>(
1168    a: &Array<T, D>,
1169    axes: Option<&[usize]>,
1170    keepdims: bool,
1171) -> FerrayResult<Array<T, IxDyn>>
1172where
1173    T: Element + std::ops::Mul<Output = T> + Copy + Send + Sync,
1174    D: Dimension,
1175{
1176    let ax = normalize_axes(axes, a.ndim())?;
1177    let data = borrow_data(a);
1178    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1179        lane.iter()
1180            .copied()
1181            .fold(<T as Element>::one(), |acc, x| acc * x)
1182    });
1183    make_result(&out_shape, result)
1184}
1185
1186/// Multi-axis minimum with optional `keepdims`. NaN-propagating.
1187pub fn min_axes<T, D>(
1188    a: &Array<T, D>,
1189    axes: Option<&[usize]>,
1190    keepdims: bool,
1191) -> FerrayResult<Array<T, IxDyn>>
1192where
1193    T: Element + PartialOrd + Copy,
1194    D: Dimension,
1195{
1196    if a.is_empty() {
1197        return Err(FerrayError::invalid_value(
1198            "cannot compute min of empty array",
1199        ));
1200    }
1201    let ax = normalize_axes(axes, a.ndim())?;
1202    let nan_min = |a: T, b: T| -> T {
1203        if a <= b {
1204            a
1205        } else if a > b {
1206            b
1207        } else {
1208            a // NaN propagates
1209        }
1210    };
1211    let data = borrow_data(a);
1212    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1213        lane.iter().copied().reduce(nan_min).unwrap()
1214    });
1215    make_result(&out_shape, result)
1216}
1217
1218/// Multi-axis maximum with optional `keepdims`. NaN-propagating.
1219pub fn max_axes<T, D>(
1220    a: &Array<T, D>,
1221    axes: Option<&[usize]>,
1222    keepdims: bool,
1223) -> FerrayResult<Array<T, IxDyn>>
1224where
1225    T: Element + PartialOrd + Copy,
1226    D: Dimension,
1227{
1228    if a.is_empty() {
1229        return Err(FerrayError::invalid_value(
1230            "cannot compute max of empty array",
1231        ));
1232    }
1233    let ax = normalize_axes(axes, a.ndim())?;
1234    let nan_max = |a: T, b: T| -> T {
1235        if a >= b {
1236            a
1237        } else if a < b {
1238            b
1239        } else {
1240            a // NaN propagates
1241        }
1242    };
1243    let data = borrow_data(a);
1244    let (result, out_shape) = reduce_axes_general(&data, a.shape(), &ax, keepdims, |lane| {
1245        lane.iter().copied().reduce(nan_max).unwrap()
1246    });
1247    make_result(&out_shape, result)
1248}
1249
1250/// Multi-axis mean with optional `keepdims`.
1251pub fn mean_axes<T, D>(
1252    a: &Array<T, D>,
1253    axes: Option<&[usize]>,
1254    keepdims: bool,
1255) -> FerrayResult<Array<T, IxDyn>>
1256where
1257    T: Element + Float + Send + Sync,
1258    D: Dimension,
1259{
1260    if a.is_empty() {
1261        return Err(FerrayError::invalid_value(
1262            "cannot compute mean of empty array",
1263        ));
1264    }
1265    let ax = normalize_axes(axes, a.ndim())?;
1266    // Compute n = product of reduced-axis lengths.
1267    let shape = a.shape();
1268    let lane_len: usize = ax.iter().map(|&i| shape[i]).product();
1269    let n = T::from(lane_len).unwrap();
1270    let data = borrow_data(a);
1271    let (result, out_shape) = reduce_axes_general(&data, shape, &ax, keepdims, |lane| {
1272        let total = try_simd_pairwise_sum(lane)
1273            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
1274        total / n
1275    });
1276    make_result(&out_shape, result)
1277}
1278
/// Multi-axis variance with optional `keepdims` and Bessel correction `ddof`.
///
/// Same two-pass mean / sum-of-squared-deviations computation as [`var`],
/// applied per output lane; a lane gathers every element across all
/// reduced axes, so its length is the product of those axes' lengths.
///
/// # Errors
/// - `FerrayError::InvalidValue` for an empty input or when `ddof >=` the
///   reduced lane length.
/// - Axis errors from `normalize_axes` (out-of-range or duplicate axes).
pub fn var_axes<T, D>(
    a: &Array<T, D>,
    axes: Option<&[usize]>,
    ddof: usize,
    keepdims: bool,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Float + Send + Sync,
    D: Dimension,
{
    if a.is_empty() {
        return Err(FerrayError::invalid_value(
            "cannot compute variance of empty array",
        ));
    }
    let ax = normalize_axes(axes, a.ndim())?;
    let shape = a.shape();
    // n for the mean/denominator = product of the reduced-axis lengths.
    let lane_len: usize = ax.iter().map(|&i| shape[i]).product();
    if lane_len <= ddof {
        return Err(FerrayError::invalid_value(
            "ddof >= reduced-axis length, variance undefined",
        ));
    }
    let nf = T::from(lane_len).unwrap();
    let denom = T::from(lane_len - ddof).unwrap();
    let data = borrow_data(a);
    let (result, out_shape) = reduce_axes_general(&data, shape, &ax, keepdims, |lane| {
        // Pass 1: lane mean (SIMD fast path for f32/f64).
        let mean_val = try_simd_pairwise_sum(lane)
            .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
            / nf;
        // Pass 2: sum of squared deviations — fused SIMD kernel when available.
        let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
            lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
                let d = x - mean_val;
                acc + d * d
            })
        });
        sum_sq / denom
    });
    make_result(&out_shape, result)
}
1320
1321/// Multi-axis standard deviation with optional `keepdims`.
1322pub fn std_axes<T, D>(
1323    a: &Array<T, D>,
1324    axes: Option<&[usize]>,
1325    ddof: usize,
1326    keepdims: bool,
1327) -> FerrayResult<Array<T, IxDyn>>
1328where
1329    T: Element + Float + Send + Sync,
1330    D: Dimension,
1331{
1332    let v = var_axes(a, axes, ddof, keepdims)?;
1333    let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
1334    make_result(v.shape(), data)
1335}
1336
1337// ---------------------------------------------------------------------------
1338// `*_into` reductions: write into a caller-provided destination
1339//
1340// NumPy reductions accept an `out=` parameter that lets callers reuse a
1341// pre-allocated output buffer across repeated reductions on same-shaped
1342// data — common in streaming statistics, online ML metrics, etc. The
1343// ferray-stats reduction functions all allocate a fresh `Array<T, IxDyn>`
1344// internally, so we add `*_into` companion functions that take a
1345// `&mut Array<T, IxDyn>` destination and skip the allocation entirely.
1346//
1347// History:
1348//   #467 introduced the `*_into` API surface but the first implementation
1349//   still went through the allocating kernel and copied into `out`,
1350//   leaving one Vec materialization per call. #563 plumbs the destination
1351//   slice through the kernel itself via `reduce_axis_typed_into` so the
1352//   path is truly zero-alloc — only the per-lane scratch buffer
1353//   (allocated once per call, reused for every lane) and the input
1354//   contig-borrow allocation (only when the input is already non-contig)
1355//   remain.
1356// ---------------------------------------------------------------------------
1357
1358/// Sum reduction writing into a pre-allocated destination.
1359///
1360/// Equivalent to `np.sum(a, axis=axis, out=out)`. The destination must
1361/// be C-contiguous and have exactly the shape that `sum(a, axis)` would
1362/// produce; broadcasting is not supported.
1363///
1364/// # Errors
1365/// - `FerrayError::AxisOutOfBounds` if `axis` is out of range.
1366/// - `FerrayError::ShapeMismatch` if `out.shape()` does not match the
1367///   expected reduction shape.
1368/// - `FerrayError::InvalidValue` if `out` is not C-contiguous.
1369pub fn sum_into<T, D>(
1370    a: &Array<T, D>,
1371    axis: Option<usize>,
1372    out: &mut Array<T, IxDyn>,
1373) -> FerrayResult<()>
1374where
1375    T: Element + std::ops::Add<Output = T> + Copy + Send + Sync,
1376    D: Dimension,
1377{
1378    let data = borrow_data(a);
1379    match axis {
1380        None => {
1381            let dst = check_out_shape(out, &[], "sum")?;
1382            let total = try_simd_pairwise_sum(&data)
1383                .unwrap_or_else(|| parallel::parallel_sum(&data, <T as Element>::zero()));
1384            dst[0] = total;
1385            Ok(())
1386        }
1387        Some(ax) => {
1388            validate_axis(ax, a.ndim())?;
1389            let shape_vec = a.shape().to_vec();
1390            let out_s = output_shape(&shape_vec, ax);
1391            let dst = check_out_shape(out, &out_s, "sum")?;
1392            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1393                try_simd_pairwise_sum(lane)
1394                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
1395            });
1396            Ok(())
1397        }
1398    }
1399}
1400
1401/// Product reduction writing into a pre-allocated destination.
1402///
1403/// See [`sum_into`] for the contract on `out`.
1404pub fn prod_into<T, D>(
1405    a: &Array<T, D>,
1406    axis: Option<usize>,
1407    out: &mut Array<T, IxDyn>,
1408) -> FerrayResult<()>
1409where
1410    T: Element + std::ops::Mul<Output = T> + Copy + Send + Sync,
1411    D: Dimension,
1412{
1413    let data = borrow_data(a);
1414    match axis {
1415        None => {
1416            let dst = check_out_shape(out, &[], "prod")?;
1417            dst[0] = parallel::parallel_prod(&data, <T as Element>::one());
1418            Ok(())
1419        }
1420        Some(ax) => {
1421            validate_axis(ax, a.ndim())?;
1422            let shape_vec = a.shape().to_vec();
1423            let out_s = output_shape(&shape_vec, ax);
1424            let dst = check_out_shape(out, &out_s, "prod")?;
1425            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1426                lane.iter()
1427                    .copied()
1428                    .fold(<T as Element>::one(), |acc, x| acc * x)
1429            });
1430            Ok(())
1431        }
1432    }
1433}
1434
1435/// Min reduction writing into a pre-allocated destination.
1436///
1437/// See [`sum_into`] for the contract on `out`. Empty input arrays return
1438/// `FerrayError::InvalidValue` because `min` of an empty set is undefined.
1439pub fn min_into<T, D>(
1440    a: &Array<T, D>,
1441    axis: Option<usize>,
1442    out: &mut Array<T, IxDyn>,
1443) -> FerrayResult<()>
1444where
1445    T: Element + PartialOrd + Copy + Send + Sync,
1446    D: Dimension,
1447{
1448    if a.is_empty() {
1449        return Err(FerrayError::invalid_value(
1450            "cannot compute min of empty array",
1451        ));
1452    }
1453    // Same NaN-propagating reducer as `min` so the two paths agree.
1454    let nan_min = |a: T, b: T| -> T {
1455        if a <= b {
1456            a
1457        } else if a > b {
1458            b
1459        } else {
1460            a
1461        }
1462    };
1463    let data = borrow_data(a);
1464    match axis {
1465        None => {
1466            let dst = check_out_shape(out, &[], "min")?;
1467            dst[0] = data.iter().copied().reduce(nan_min).unwrap();
1468            Ok(())
1469        }
1470        Some(ax) => {
1471            validate_axis(ax, a.ndim())?;
1472            let shape_vec = a.shape().to_vec();
1473            let out_s = output_shape(&shape_vec, ax);
1474            let dst = check_out_shape(out, &out_s, "min")?;
1475            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1476                lane.iter().copied().reduce(nan_min).unwrap()
1477            });
1478            Ok(())
1479        }
1480    }
1481}
1482
1483/// Max reduction writing into a pre-allocated destination.
1484///
1485/// See [`sum_into`] for the contract on `out`. Empty input arrays return
1486/// `FerrayError::InvalidValue`.
1487pub fn max_into<T, D>(
1488    a: &Array<T, D>,
1489    axis: Option<usize>,
1490    out: &mut Array<T, IxDyn>,
1491) -> FerrayResult<()>
1492where
1493    T: Element + PartialOrd + Copy + Send + Sync,
1494    D: Dimension,
1495{
1496    if a.is_empty() {
1497        return Err(FerrayError::invalid_value(
1498            "cannot compute max of empty array",
1499        ));
1500    }
1501    let nan_max = |a: T, b: T| -> T {
1502        if a >= b {
1503            a
1504        } else if a < b {
1505            b
1506        } else {
1507            a
1508        }
1509    };
1510    let data = borrow_data(a);
1511    match axis {
1512        None => {
1513            let dst = check_out_shape(out, &[], "max")?;
1514            dst[0] = data.iter().copied().reduce(nan_max).unwrap();
1515            Ok(())
1516        }
1517        Some(ax) => {
1518            validate_axis(ax, a.ndim())?;
1519            let shape_vec = a.shape().to_vec();
1520            let out_s = output_shape(&shape_vec, ax);
1521            let dst = check_out_shape(out, &out_s, "max")?;
1522            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1523                lane.iter().copied().reduce(nan_max).unwrap()
1524            });
1525            Ok(())
1526        }
1527    }
1528}
1529
1530/// Mean reduction writing into a pre-allocated destination.
1531///
1532/// See [`sum_into`] for the contract on `out`. Empty inputs return
1533/// `FerrayError::InvalidValue`.
1534pub fn mean_into<T, D>(
1535    a: &Array<T, D>,
1536    axis: Option<usize>,
1537    out: &mut Array<T, IxDyn>,
1538) -> FerrayResult<()>
1539where
1540    T: Element + Float + Send + Sync,
1541    D: Dimension,
1542{
1543    if a.is_empty() {
1544        return Err(FerrayError::invalid_value(
1545            "cannot compute mean of empty array",
1546        ));
1547    }
1548    let data = borrow_data(a);
1549    match axis {
1550        None => {
1551            let dst = check_out_shape(out, &[], "mean")?;
1552            let n = T::from(data.len()).unwrap();
1553            let total = try_simd_pairwise_sum(&data)
1554                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()));
1555            dst[0] = total / n;
1556            Ok(())
1557        }
1558        Some(ax) => {
1559            validate_axis(ax, a.ndim())?;
1560            let shape_vec = a.shape().to_vec();
1561            let out_s = output_shape(&shape_vec, ax);
1562            let axis_len = shape_vec[ax];
1563            let n = T::from(axis_len).unwrap();
1564            let dst = check_out_shape(out, &out_s, "mean")?;
1565            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
1566                let total = try_simd_pairwise_sum(lane)
1567                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()));
1568                total / n
1569            });
1570            Ok(())
1571        }
1572    }
1573}
1574
/// Variance reduction writing into a pre-allocated destination.
///
/// `ddof` is the delta degrees of freedom. See [`sum_into`] for the
/// contract on `out`. Returns `FerrayError::InvalidValue` for empty
/// inputs or when `ddof >= n`.
///
/// Same two-pass mean / sum-of-squared-deviations computation as [`var`],
/// but the results go straight into `out` instead of a fresh allocation
/// (#563). All validation (ddof, axis, out shape) happens before any
/// write to `out`.
pub fn var_into<T, D>(
    a: &Array<T, D>,
    axis: Option<usize>,
    ddof: usize,
    out: &mut Array<T, IxDyn>,
) -> FerrayResult<()>
where
    T: Element + Float + Send + Sync,
    D: Dimension,
{
    if a.is_empty() {
        return Err(FerrayError::invalid_value(
            "cannot compute variance of empty array",
        ));
    }
    let data = borrow_data(a);
    match axis {
        None => {
            let n = data.len();
            if n <= ddof {
                return Err(FerrayError::invalid_value(
                    "ddof >= number of elements, variance undefined",
                ));
            }
            let dst = check_out_shape(out, &[], "var")?;
            let nf = T::from(n).unwrap();
            // Pass 1: mean over all elements (SIMD when T is f32/f64).
            let mean_val = try_simd_pairwise_sum(&data)
                .unwrap_or_else(|| parallel::pairwise_sum(&data, <T as Element>::zero()))
                / nf;
            // Pass 2: sum((x - mean)^2), fused SIMD kernel when available.
            let sum_sq = try_simd_sum_sq_diff(&data, mean_val).unwrap_or_else(|| {
                data.iter().copied().fold(<T as Element>::zero(), |acc, x| {
                    let d = x - mean_val;
                    acc + d * d
                })
            });
            // Bessel-corrected result written to the single output slot.
            dst[0] = sum_sq / T::from(n - ddof).unwrap();
            Ok(())
        }
        Some(ax) => {
            validate_axis(ax, a.ndim())?;
            let shape_vec = a.shape().to_vec();
            let axis_len = shape_vec[ax];
            if axis_len <= ddof {
                return Err(FerrayError::invalid_value(
                    "ddof >= axis length, variance undefined",
                ));
            }
            let out_s = output_shape(&shape_vec, ax);
            let nf = T::from(axis_len).unwrap();
            let denom = T::from(axis_len - ddof).unwrap();
            let dst = check_out_shape(out, &out_s, "var")?;
            // Two-pass variance per lane, written directly into `dst`.
            reduce_axis_general_into(&data, &shape_vec, ax, dst, |lane| {
                let mean_val = try_simd_pairwise_sum(lane)
                    .unwrap_or_else(|| parallel::pairwise_sum(lane, <T as Element>::zero()))
                    / nf;
                let sum_sq = try_simd_sum_sq_diff(lane, mean_val).unwrap_or_else(|| {
                    lane.iter().copied().fold(<T as Element>::zero(), |acc, x| {
                        let d = x - mean_val;
                        acc + d * d
                    })
                });
                sum_sq / denom
            });
            Ok(())
        }
    }
}
1647
1648/// Standard deviation reduction writing into a pre-allocated destination.
1649///
1650/// Computes the variance directly into `out` and then takes the
1651/// element-wise square root in place — the std value never lives in any
1652/// intermediate buffer.
1653pub fn std_into<T, D>(
1654    a: &Array<T, D>,
1655    axis: Option<usize>,
1656    ddof: usize,
1657    out: &mut Array<T, IxDyn>,
1658) -> FerrayResult<()>
1659where
1660    T: Element + Float + Send + Sync,
1661    D: Dimension,
1662{
1663    var_into(a, axis, ddof, out)?;
1664    // Variance is now in `out`; sqrt each element in place.
1665    let dst = out
1666        .as_slice_mut()
1667        .ok_or_else(|| FerrayError::invalid_value("std_into: out must be C-contiguous"))?;
1668    for slot in dst.iter_mut() {
1669        *slot = slot.sqrt();
1670    }
1671    Ok(())
1672}
1673
1674// ---------------------------------------------------------------------------
1675// `*_with` reductions: NumPy `initial=` and `where=` parameters (#459)
1676//
1677// NumPy reductions accept two extra parameters that ferray-stats was
1678// missing entirely:
1679//
1680//   - `initial`: starting value for the accumulator. Lets `np.sum([])`
1681//     return a non-zero seed, lets `np.min(arr, initial=999)` provide
1682//     a fallback for empty inputs, and lets cumulative-style code
1683//     thread an initial state through a reduction.
1684//
1685//   - `where`: same-shape boolean mask. Only positions where the mask
1686//     is `true` contribute to the reduction. This is the masked-array
1687//     equivalent without materializing a MaskedArray — `np.sum(arr,
1688//     where=arr > 0)` is the canonical "sum the positives" idiom.
1689//
1690// The `*_with` family adds these as `Option<T>` / `Option<&Array<bool,
1691// D>>` parameters. Mask broadcasting is intentionally not supported in
1692// this first cut — the mask must have exactly the same shape as the
1693// input. NumPy allows broadcast-compatible masks, but the same-shape
1694// path covers every common use case (predicates produced by comparison
1695// ufuncs against the same input always have a matching shape).
1696// ---------------------------------------------------------------------------
1697
1698/// Prepare a `where` mask for use by a `*_with` reduction kernel.
1699///
1700/// Accepts an `Option<&Array<bool, IxDyn>>` so the mask can have any
1701/// rank independently of the input — callers with typed masks should
1702/// call `.to_dyn()` before passing in.
1703///
1704/// Returns:
1705/// - `Ok(None)` if the caller passed no mask.
1706/// - `Ok(Some(vec))` where `vec` is a row-major materialization of the
1707///   mask aligned with `a`'s flat logical order. For same-shape masks
1708///   this is a direct copy; for broadcast-compatible masks the mask is
1709///   broadcast into `a.shape()` first via
1710///   [`ferray_core::dimension::broadcast::broadcast_to`] (#565) and
1711///   then materialized.
1712/// - `Err(FerrayError::ShapeMismatch)` if the mask is not
1713///   broadcast-compatible with `a.shape()`.
1714fn prepare_where_mask<T, D>(
1715    a: &Array<T, D>,
1716    mask: Option<&Array<bool, IxDyn>>,
1717    op_name: &str,
1718) -> FerrayResult<Option<Vec<bool>>>
1719where
1720    T: Element,
1721    D: Dimension,
1722{
1723    use ferray_core::dimension::broadcast::broadcast_to;
1724
1725    let Some(m) = mask else {
1726        return Ok(None);
1727    };
1728
1729    // Fast path: the mask is already the same shape as the input.
1730    if m.shape() == a.shape() {
1731        return Ok(Some(m.iter().copied().collect()));
1732    }
1733
1734    // Broadcast path: let the core machinery check compatibility and
1735    // produce a stride-tricked view at the target shape. We then
1736    // materialize it into a flat Vec<bool> aligned with `a`'s logical
1737    // row-major order.
1738    let view = broadcast_to(m, a.shape()).map_err(|_| {
1739        FerrayError::shape_mismatch(format!(
1740            "{op_name}: where mask shape {:?} is not broadcast-compatible with array shape {:?}",
1741            m.shape(),
1742            a.shape()
1743        ))
1744    })?;
1745    Ok(Some(view.iter().copied().collect()))
1746}
1747
1748/// Walk a single axis-Some output position, gathering masked-true lane
1749/// values into `lane_buf` (cleared first). Used by every `*_with` axis
1750/// path so the lane-gather logic exists in one place.
1751///
1752/// `out_multi` is the output multi-index excluding `axis`. The function
1753/// reconstructs the corresponding input flat positions and consults the
1754/// mask in lockstep.
1755fn gather_lane_with_mask<T: Copy>(
1756    data: &[T],
1757    mask: Option<&[bool]>,
1758    shape: &[usize],
1759    strides: &[usize],
1760    axis: usize,
1761    out_multi: &[usize],
1762    lane_buf: &mut Vec<T>,
1763) {
1764    let ndim = shape.len();
1765    let axis_len = shape[axis];
1766    let mut in_multi = vec![0usize; ndim];
1767    let mut out_dim = 0;
1768    for (d, idx) in in_multi.iter_mut().enumerate() {
1769        if d == axis {
1770            *idx = 0;
1771        } else {
1772            *idx = out_multi[out_dim];
1773            out_dim += 1;
1774        }
1775    }
1776    lane_buf.clear();
1777    for k in 0..axis_len {
1778        in_multi[axis] = k;
1779        let idx = flat_index(&in_multi, strides);
1780        match mask {
1781            Some(m) if !m[idx] => {}
1782            _ => lane_buf.push(data[idx]),
1783        }
1784    }
1785}
1786
1787/// Sum reduction with `initial` and `where` parameters.
1788///
1789/// Equivalent to `np.sum(a, axis=axis, initial=initial, where=where_mask)`.
1790///
1791/// `initial` defaults to `T::zero()` when `None`. `where_mask`, when
1792/// provided, must be broadcast-compatible with `a.shape()` (#565).
1793/// Callers with typed masks (e.g. `Array<bool, Ix1>`) should call
1794/// `.to_dyn()` before passing. Positions where the broadcast mask is
1795/// `false` are skipped.
1796///
1797/// # Errors
1798/// - `FerrayError::AxisOutOfBounds` if `axis` is out of range.
1799/// - `FerrayError::ShapeMismatch` if `where_mask` is not
1800///   broadcast-compatible with `a.shape()`.
1801pub fn sum_with<T, D>(
1802    a: &Array<T, D>,
1803    axis: Option<usize>,
1804    initial: Option<T>,
1805    where_mask: Option<&Array<bool, IxDyn>>,
1806) -> FerrayResult<Array<T, IxDyn>>
1807where
1808    T: Element + std::ops::Add<Output = T> + Copy,
1809    D: Dimension,
1810{
1811    let init = initial.unwrap_or_else(<T as Element>::zero);
1812    let data = borrow_data(a);
1813    let mask_vec = prepare_where_mask(a, where_mask, "sum")?;
1814    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
1815
1816    match axis {
1817        None => {
1818            let total = match mask_slice {
1819                None => data.iter().copied().fold(init, |acc, x| acc + x),
1820                Some(mask) => data
1821                    .iter()
1822                    .zip(mask.iter())
1823                    .filter(|&(_, &m)| m)
1824                    .fold(init, |acc, (&x, _)| acc + x),
1825            };
1826            make_result(&[], vec![total])
1827        }
1828        Some(ax) => {
1829            validate_axis(ax, a.ndim())?;
1830            let shape_vec = a.shape().to_vec();
1831            let out_s = output_shape(&shape_vec, ax);
1832            let strides = compute_strides(&shape_vec);
1833
1834            let out_size: usize = if out_s.is_empty() {
1835                1
1836            } else {
1837                out_s.iter().product()
1838            };
1839            let mut result: Vec<T> = Vec::with_capacity(out_size);
1840            let mut out_multi = vec![0usize; out_s.len()];
1841            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
1842
1843            for _ in 0..out_size {
1844                gather_lane_with_mask(
1845                    &data,
1846                    mask_slice,
1847                    &shape_vec,
1848                    &strides,
1849                    ax,
1850                    &out_multi,
1851                    &mut lane_buf,
1852                );
1853                let lane_sum = lane_buf.iter().copied().fold(init, |acc, x| acc + x);
1854                result.push(lane_sum);
1855                if !out_s.is_empty() {
1856                    increment_multi_index(&mut out_multi, &out_s);
1857                }
1858            }
1859            make_result(&out_s, result)
1860        }
1861    }
1862}
1863
1864/// Product reduction with `initial` and `where` parameters.
1865///
1866/// Equivalent to `np.prod(a, axis=axis, initial=initial, where=where_mask)`.
1867/// `initial` defaults to `T::one()` when `None`. `where_mask` is
1868/// broadcast-compatible with `a.shape()` (#565).
1869pub fn prod_with<T, D>(
1870    a: &Array<T, D>,
1871    axis: Option<usize>,
1872    initial: Option<T>,
1873    where_mask: Option<&Array<bool, IxDyn>>,
1874) -> FerrayResult<Array<T, IxDyn>>
1875where
1876    T: Element + std::ops::Mul<Output = T> + Copy,
1877    D: Dimension,
1878{
1879    let init = initial.unwrap_or_else(<T as Element>::one);
1880    let data = borrow_data(a);
1881    let mask_vec = prepare_where_mask(a, where_mask, "prod")?;
1882    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
1883
1884    match axis {
1885        None => {
1886            let total = match mask_slice {
1887                None => data.iter().copied().fold(init, |acc, x| acc * x),
1888                Some(mask) => data
1889                    .iter()
1890                    .zip(mask.iter())
1891                    .filter(|&(_, &m)| m)
1892                    .fold(init, |acc, (&x, _)| acc * x),
1893            };
1894            make_result(&[], vec![total])
1895        }
1896        Some(ax) => {
1897            validate_axis(ax, a.ndim())?;
1898            let shape_vec = a.shape().to_vec();
1899            let out_s = output_shape(&shape_vec, ax);
1900            let strides = compute_strides(&shape_vec);
1901
1902            let out_size: usize = if out_s.is_empty() {
1903                1
1904            } else {
1905                out_s.iter().product()
1906            };
1907            let mut result: Vec<T> = Vec::with_capacity(out_size);
1908            let mut out_multi = vec![0usize; out_s.len()];
1909            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
1910
1911            for _ in 0..out_size {
1912                gather_lane_with_mask(
1913                    &data,
1914                    mask_slice,
1915                    &shape_vec,
1916                    &strides,
1917                    ax,
1918                    &out_multi,
1919                    &mut lane_buf,
1920                );
1921                let lane_prod = lane_buf.iter().copied().fold(init, |acc, x| acc * x);
1922                result.push(lane_prod);
1923                if !out_s.is_empty() {
1924                    increment_multi_index(&mut out_multi, &out_s);
1925                }
1926            }
1927            make_result(&out_s, result)
1928        }
1929    }
1930}
1931
1932/// Min reduction with `initial` and `where` parameters.
1933///
1934/// Equivalent to `np.min(a, axis=axis, initial=initial, where=where_mask)`.
1935/// Unlike plain `min`, an empty input (or a fully-masked-out lane) is
1936/// allowed when `initial` is supplied — the result is `initial`. Without
1937/// `initial`, an empty lane is an error. `where_mask` is
1938/// broadcast-compatible with `a.shape()` (#565).
1939pub fn min_with<T, D>(
1940    a: &Array<T, D>,
1941    axis: Option<usize>,
1942    initial: Option<T>,
1943    where_mask: Option<&Array<bool, IxDyn>>,
1944) -> FerrayResult<Array<T, IxDyn>>
1945where
1946    T: Element + PartialOrd + Copy,
1947    D: Dimension,
1948{
1949    let nan_min = |a: T, b: T| -> T {
1950        if a <= b {
1951            a
1952        } else if a > b {
1953            b
1954        } else {
1955            a
1956        }
1957    };
1958    let data = borrow_data(a);
1959    let mask_vec = prepare_where_mask(a, where_mask, "min")?;
1960    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
1961
1962    let lane_min = |lane: &[T], initial: Option<T>| -> FerrayResult<T> {
1963        let mut iter = lane.iter().copied();
1964        let seed = match initial {
1965            Some(v) => Some(v),
1966            None => iter.next(),
1967        };
1968        match seed {
1969            Some(s) => Ok(iter.fold(s, nan_min)),
1970            None => Err(FerrayError::invalid_value(
1971                "min: empty lane and no initial value",
1972            )),
1973        }
1974    };
1975
1976    match axis {
1977        None => {
1978            let total = match mask_slice {
1979                None => lane_min(&data, initial)?,
1980                Some(mask) => {
1981                    let filtered: Vec<T> = data
1982                        .iter()
1983                        .copied()
1984                        .zip(mask.iter().copied())
1985                        .filter_map(|(x, m)| if m { Some(x) } else { None })
1986                        .collect();
1987                    lane_min(&filtered, initial)?
1988                }
1989            };
1990            make_result(&[], vec![total])
1991        }
1992        Some(ax) => {
1993            validate_axis(ax, a.ndim())?;
1994            let shape_vec = a.shape().to_vec();
1995            let out_s = output_shape(&shape_vec, ax);
1996            let strides = compute_strides(&shape_vec);
1997
1998            let out_size: usize = if out_s.is_empty() {
1999                1
2000            } else {
2001                out_s.iter().product()
2002            };
2003            let mut result: Vec<T> = Vec::with_capacity(out_size);
2004            let mut out_multi = vec![0usize; out_s.len()];
2005            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2006
2007            for _ in 0..out_size {
2008                gather_lane_with_mask(
2009                    &data,
2010                    mask_slice,
2011                    &shape_vec,
2012                    &strides,
2013                    ax,
2014                    &out_multi,
2015                    &mut lane_buf,
2016                );
2017                result.push(lane_min(&lane_buf, initial)?);
2018                if !out_s.is_empty() {
2019                    increment_multi_index(&mut out_multi, &out_s);
2020                }
2021            }
2022            make_result(&out_s, result)
2023        }
2024    }
2025}
2026
2027/// Max reduction with `initial` and `where` parameters.
2028///
2029/// Symmetric to [`min_with`]. `where_mask` is broadcast-compatible with
2030/// `a.shape()` (#565).
2031pub fn max_with<T, D>(
2032    a: &Array<T, D>,
2033    axis: Option<usize>,
2034    initial: Option<T>,
2035    where_mask: Option<&Array<bool, IxDyn>>,
2036) -> FerrayResult<Array<T, IxDyn>>
2037where
2038    T: Element + PartialOrd + Copy,
2039    D: Dimension,
2040{
2041    let nan_max = |a: T, b: T| -> T {
2042        if a >= b {
2043            a
2044        } else if a < b {
2045            b
2046        } else {
2047            a
2048        }
2049    };
2050    let data = borrow_data(a);
2051    let mask_vec = prepare_where_mask(a, where_mask, "max")?;
2052    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2053
2054    let lane_max = |lane: &[T], initial: Option<T>| -> FerrayResult<T> {
2055        let mut iter = lane.iter().copied();
2056        let seed = match initial {
2057            Some(v) => Some(v),
2058            None => iter.next(),
2059        };
2060        match seed {
2061            Some(s) => Ok(iter.fold(s, nan_max)),
2062            None => Err(FerrayError::invalid_value(
2063                "max: empty lane and no initial value",
2064            )),
2065        }
2066    };
2067
2068    match axis {
2069        None => {
2070            let total = match mask_slice {
2071                None => lane_max(&data, initial)?,
2072                Some(mask) => {
2073                    let filtered: Vec<T> = data
2074                        .iter()
2075                        .copied()
2076                        .zip(mask.iter().copied())
2077                        .filter_map(|(x, m)| if m { Some(x) } else { None })
2078                        .collect();
2079                    lane_max(&filtered, initial)?
2080                }
2081            };
2082            make_result(&[], vec![total])
2083        }
2084        Some(ax) => {
2085            validate_axis(ax, a.ndim())?;
2086            let shape_vec = a.shape().to_vec();
2087            let out_s = output_shape(&shape_vec, ax);
2088            let strides = compute_strides(&shape_vec);
2089
2090            let out_size: usize = if out_s.is_empty() {
2091                1
2092            } else {
2093                out_s.iter().product()
2094            };
2095            let mut result: Vec<T> = Vec::with_capacity(out_size);
2096            let mut out_multi = vec![0usize; out_s.len()];
2097            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2098
2099            for _ in 0..out_size {
2100                gather_lane_with_mask(
2101                    &data,
2102                    mask_slice,
2103                    &shape_vec,
2104                    &strides,
2105                    ax,
2106                    &out_multi,
2107                    &mut lane_buf,
2108                );
2109                result.push(lane_max(&lane_buf, initial)?);
2110                if !out_s.is_empty() {
2111                    increment_multi_index(&mut out_multi, &out_s);
2112                }
2113            }
2114            make_result(&out_s, result)
2115        }
2116    }
2117}
2118
2119/// Mean reduction with a `where` mask.
2120///
2121/// Equivalent to `np.mean(a, axis=axis, where=where_mask)`. The divisor
2122/// is the count of `true` positions in the (broadcast) mask, NOT the
2123/// lane length — fully-masked-out lanes return `T::nan()` (matching
2124/// `NumPy`'s "`RuntimeWarning`: Mean of empty slice" behavior, but without
2125/// the warning machinery). `initial` is intentionally not modeled
2126/// because the divisor for an "initial-bumped" mean is ambiguous in
2127/// `NumPy` too. `where_mask` is broadcast-compatible with `a.shape()`
2128/// (#565).
2129pub fn mean_where<T, D>(
2130    a: &Array<T, D>,
2131    axis: Option<usize>,
2132    where_mask: Option<&Array<bool, IxDyn>>,
2133) -> FerrayResult<Array<T, IxDyn>>
2134where
2135    T: Element + Float,
2136    D: Dimension,
2137{
2138    let data = borrow_data(a);
2139    let mask_vec = prepare_where_mask(a, where_mask, "mean")?;
2140    let mask_slice: Option<&[bool]> = mask_vec.as_deref();
2141
2142    match axis {
2143        None => {
2144            let (sum, count) = match mask_slice {
2145                None => {
2146                    let total = data
2147                        .iter()
2148                        .copied()
2149                        .fold(<T as Element>::zero(), |acc, x| acc + x);
2150                    (total, data.len())
2151                }
2152                Some(mask) => {
2153                    let mut s = <T as Element>::zero();
2154                    let mut c = 0usize;
2155                    for (&x, &m) in data.iter().zip(mask.iter()) {
2156                        if m {
2157                            s = s + x;
2158                            c += 1;
2159                        }
2160                    }
2161                    (s, c)
2162                }
2163            };
2164            let result = if count == 0 {
2165                <T as Float>::nan()
2166            } else {
2167                sum / T::from(count).unwrap()
2168            };
2169            make_result(&[], vec![result])
2170        }
2171        Some(ax) => {
2172            validate_axis(ax, a.ndim())?;
2173            let shape_vec = a.shape().to_vec();
2174            let out_s = output_shape(&shape_vec, ax);
2175            let strides = compute_strides(&shape_vec);
2176
2177            let out_size: usize = if out_s.is_empty() {
2178                1
2179            } else {
2180                out_s.iter().product()
2181            };
2182            let mut result: Vec<T> = Vec::with_capacity(out_size);
2183            let mut out_multi = vec![0usize; out_s.len()];
2184            let mut lane_buf: Vec<T> = Vec::with_capacity(shape_vec[ax]);
2185
2186            for _ in 0..out_size {
2187                gather_lane_with_mask(
2188                    &data,
2189                    mask_slice,
2190                    &shape_vec,
2191                    &strides,
2192                    ax,
2193                    &out_multi,
2194                    &mut lane_buf,
2195                );
2196                let lane_mean = if lane_buf.is_empty() {
2197                    <T as Float>::nan()
2198                } else {
2199                    let s = lane_buf
2200                        .iter()
2201                        .copied()
2202                        .fold(<T as Element>::zero(), |acc, x| acc + x);
2203                    s / T::from(lane_buf.len()).unwrap()
2204                };
2205                result.push(lane_mean);
2206                if !out_s.is_empty() {
2207                    increment_multi_index(&mut out_multi, &out_s);
2208                }
2209            }
2210            make_result(&out_s, result)
2211        }
2212    }
2213}
2214
2215/// Single-axis argmin with optional `keepdims`.
2216///
2217/// `NumPy`'s `argmin` only accepts a single axis (or `None`); this mirrors
2218/// that constraint. `keepdims` preserves the reduced axis as size 1.
2219pub fn argmin_keepdims<T, D>(
2220    a: &Array<T, D>,
2221    axis: Option<usize>,
2222    keepdims: bool,
2223) -> FerrayResult<Array<u64, IxDyn>>
2224where
2225    T: Element + PartialOrd + Copy,
2226    D: Dimension,
2227{
2228    if a.is_empty() {
2229        return Err(FerrayError::invalid_value(
2230            "cannot compute argmin of empty array",
2231        ));
2232    }
2233    let ndim = a.ndim();
2234    let ax_vec: Vec<usize> = match axis {
2235        None => (0..ndim).collect(),
2236        Some(ax) => {
2237            if ax >= ndim {
2238                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
2239            }
2240            vec![ax]
2241        }
2242    };
2243    let shape = a.shape();
2244    let data = borrow_data(a);
2245    let (result_f64, out_shape) =
2246        reduce_axes_general_u64(&data, shape, &ax_vec, keepdims, |lane| {
2247            lane.iter()
2248                .enumerate()
2249                .reduce(|(ai, av), (bi, bv)| if av <= bv { (ai, av) } else { (bi, bv) })
2250                .unwrap()
2251                .0 as u64
2252        });
2253    make_result(&out_shape, result_f64)
2254}
2255
2256/// Single-axis argmax with optional `keepdims`.
2257pub fn argmax_keepdims<T, D>(
2258    a: &Array<T, D>,
2259    axis: Option<usize>,
2260    keepdims: bool,
2261) -> FerrayResult<Array<u64, IxDyn>>
2262where
2263    T: Element + PartialOrd + Copy,
2264    D: Dimension,
2265{
2266    if a.is_empty() {
2267        return Err(FerrayError::invalid_value(
2268            "cannot compute argmax of empty array",
2269        ));
2270    }
2271    let ndim = a.ndim();
2272    let ax_vec: Vec<usize> = match axis {
2273        None => (0..ndim).collect(),
2274        Some(ax) => {
2275            if ax >= ndim {
2276                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
2277            }
2278            vec![ax]
2279        }
2280    };
2281    let shape = a.shape();
2282    let data = borrow_data(a);
2283    let (result_u64, out_shape) =
2284        reduce_axes_general_u64(&data, shape, &ax_vec, keepdims, |lane| {
2285            lane.iter()
2286                .enumerate()
2287                .reduce(|(ai, av), (bi, bv)| if av >= bv { (ai, av) } else { (bi, bv) })
2288                .unwrap()
2289                .0 as u64
2290        });
2291    make_result(&out_shape, result_u64)
2292}
2293
/// Multi-axis reduction that returns `u64` values (used by argmin/argmax variants).
///
/// Partitions the dimensions of `shape` into kept axes and reduced
/// axes (`axes`), then for every output position gathers the full
/// reduced "lane" of input values into a reusable buffer and applies
/// `f` to it. Returns the flat row-major results together with the
/// output shape from `output_shape_axes` (reduced axes dropped, or
/// kept as size 1 when `keepdims` is true).
///
/// `data` must be the row-major flat contents of an array with shape
/// `shape`; `axes` entries are assumed in-bounds (callers validate).
pub(crate) fn reduce_axes_general_u64<T: Copy, F: Fn(&[T]) -> u64>(
    data: &[T],
    shape: &[usize],
    axes: &[usize],
    keepdims: bool,
    f: F,
) -> (Vec<u64>, Vec<usize>) {
    let ndim = shape.len();
    let strides = compute_strides(shape);

    // Split dimensions into "kept" (iterated in the outer loop) and
    // "reduced" (swept per lane in the inner loop).
    let is_reduce: Vec<bool> = (0..ndim).map(|i| axes.contains(&i)).collect();
    let keep_axes: Vec<usize> = (0..ndim).filter(|i| !is_reduce[*i]).collect();
    let reduce_axes: Vec<usize> = (0..ndim).filter(|i| is_reduce[*i]).collect();
    let keep_shape: Vec<usize> = keep_axes.iter().map(|&i| shape[i]).collect();
    let reduce_shape: Vec<usize> = reduce_axes.iter().map(|&i| shape[i]).collect();

    // A full reduction (no kept axes) still produces one scalar output;
    // likewise an empty reduce set means each lane has one element.
    let out_size: usize = if keep_shape.is_empty() {
        1
    } else {
        keep_shape.iter().product()
    };
    let lane_size: usize = if reduce_shape.is_empty() {
        1
    } else {
        reduce_shape.iter().product()
    };
    let out_shape = output_shape_axes(shape, axes, keepdims);

    // `lane` is reused across output positions to avoid reallocating.
    let mut result = Vec::with_capacity(out_size);
    let mut lane: Vec<T> = Vec::with_capacity(lane_size);
    let mut keep_multi = vec![0usize; keep_shape.len()];
    let mut reduce_multi = vec![0usize; reduce_shape.len()];
    let mut full_multi = vec![0usize; ndim];

    for _ in 0..out_size {
        // Project the kept-axes coordinates into the full multi-index.
        for (i, &ax) in keep_axes.iter().enumerate() {
            full_multi[ax] = keep_multi[i];
        }

        // Sweep the reduced axes to gather this output position's lane.
        lane.clear();
        reduce_multi.fill(0);
        for _ in 0..lane_size {
            for (i, &ax) in reduce_axes.iter().enumerate() {
                full_multi[ax] = reduce_multi[i];
            }
            lane.push(data[flat_index(&full_multi, &strides)]);
            if !reduce_shape.is_empty() {
                increment_multi_index(&mut reduce_multi, &reduce_shape);
            }
        }

        result.push(f(&lane));

        // Advance to the next output position in row-major order.
        if !keep_shape.is_empty() {
            increment_multi_index(&mut keep_multi, &keep_shape);
        }
    }

    (result, out_shape)
}
2355
2356// ---------------------------------------------------------------------------
2357// Re-export cumulative operations from ferray-ufunc for discoverability
2358// ---------------------------------------------------------------------------
2359
2360/// Cumulative sum along an axis (or flattened if axis is None).
2361///
2362/// Re-exported from `ferray_ufunc::cumsum` for convenience.
2363pub fn cumsum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
2364where
2365    T: Element + std::ops::Add<Output = T> + Copy,
2366    D: Dimension,
2367{
2368    ferray_ufunc::cumsum(a, axis)
2369}
2370
2371/// Cumulative product along an axis (or flattened if axis is None).
2372///
2373/// Re-exported from `ferray_ufunc::cumprod` for convenience.
2374pub fn cumprod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
2375where
2376    T: Element + std::ops::Mul<Output = T> + Copy,
2377    D: Dimension,
2378{
2379    ferray_ufunc::cumprod(a, axis)
2380}
2381
2382#[cfg(test)]
2383mod tests {
2384    use super::*;
2385    use ferray_core::{Ix1, Ix2};
2386
2387    #[test]
2388    fn test_sum_1d_no_axis() {
2389        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2390        let s = sum(&a, None).unwrap();
2391        assert_eq!(s.shape(), &[]);
2392        assert_eq!(s.iter().next(), Some(&10.0));
2393    }
2394
2395    #[test]
2396    fn test_sum_2d_axis0() {
2397        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2398            .unwrap();
2399        let s = sum(&a, Some(0)).unwrap();
2400        assert_eq!(s.shape(), &[3]);
2401        let data: Vec<f64> = s.iter().copied().collect();
2402        assert_eq!(data, vec![5.0, 7.0, 9.0]);
2403    }
2404
2405    #[test]
2406    fn test_sum_2d_axis1() {
2407        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2408            .unwrap();
2409        let s = sum(&a, Some(1)).unwrap();
2410        assert_eq!(s.shape(), &[2]);
2411        let data: Vec<f64> = s.iter().copied().collect();
2412        assert_eq!(data, vec![6.0, 15.0]);
2413    }
2414
2415    #[test]
2416    fn test_prod_1d() {
2417        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2418        let p = prod(&a, None).unwrap();
2419        assert_eq!(p.iter().next(), Some(&24.0));
2420    }
2421
2422    #[test]
2423    fn test_min_max() {
2424        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
2425        let mn = min(&a, None).unwrap();
2426        let mx = max(&a, None).unwrap();
2427        assert_eq!(mn.iter().next(), Some(&1.0));
2428        assert_eq!(mx.iter().next(), Some(&4.0));
2429    }
2430
2431    #[test]
2432    fn test_argmin_argmax() {
2433        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
2434        let ami = argmin(&a, None).unwrap();
2435        let amx = argmax(&a, None).unwrap();
2436        assert_eq!(ami.iter().next(), Some(&1u64));
2437        assert_eq!(amx.iter().next(), Some(&2u64));
2438    }
2439
2440    #[test]
2441    fn test_mean() {
2442        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2443        let m = mean(&a, None).unwrap();
2444        assert!((m.iter().next().unwrap() - 2.5).abs() < 1e-12);
2445    }
2446
2447    #[test]
2448    fn test_var_population() {
2449        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2450        let v = var(&a, None, 0).unwrap();
2451        // var = ((1-2.5)^2 + (2-2.5)^2 + (3-2.5)^2 + (4-2.5)^2) / 4 = 1.25
2452        assert!((v.iter().next().unwrap() - 1.25).abs() < 1e-12);
2453    }
2454
2455    #[test]
2456    fn test_var_sample() {
2457        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2458        let v = var(&a, None, 1).unwrap();
2459        // var = 5.0 / 3.0 = 1.6666...
2460        assert!((v.iter().next().unwrap() - 5.0 / 3.0).abs() < 1e-12);
2461    }
2462
2463    #[test]
2464    fn test_std() {
2465        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2466        let s = std_(&a, None, 1).unwrap();
2467        let expected = (5.0_f64 / 3.0).sqrt();
2468        assert!((s.iter().next().unwrap() - expected).abs() < 1e-12);
2469    }
2470
2471    #[test]
2472    fn test_sum_axis_out_of_bounds() {
2473        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
2474        assert!(sum(&a, Some(1)).is_err());
2475    }
2476
2477    #[test]
2478    fn test_cumsum_reexport() {
2479        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2480        let cs = cumsum(&a, None).unwrap();
2481        assert_eq!(cs.as_slice().unwrap(), &[1, 3, 6, 10]);
2482    }
2483
2484    #[test]
2485    fn test_cumprod_reexport() {
2486        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2487        let cp = cumprod(&a, None).unwrap();
2488        assert_eq!(cp.as_slice().unwrap(), &[1, 2, 6, 24]);
2489    }
2490
2491    #[test]
2492    fn test_min_2d_axis0() {
2493        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0])
2494            .unwrap();
2495        let m = min(&a, Some(0)).unwrap();
2496        let data: Vec<f64> = m.iter().copied().collect();
2497        assert_eq!(data, vec![1.0, 1.0, 2.0]);
2498    }
2499
2500    #[test]
2501    fn test_argmin_2d_axis1() {
2502        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0])
2503            .unwrap();
2504        let ami = argmin(&a, Some(1)).unwrap();
2505        let data: Vec<u64> = ami.iter().copied().collect();
2506        assert_eq!(data, vec![1, 0]);
2507    }
2508
2509    #[test]
2510    fn test_mean_2d_axis0() {
2511        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2512            .unwrap();
2513        let m = mean(&a, Some(0)).unwrap();
2514        let data: Vec<f64> = m.iter().copied().collect();
2515        assert_eq!(data, vec![2.5, 3.5, 4.5]);
2516    }
2517
2518    #[test]
2519    fn test_sum_integer() {
2520        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
2521        let s = sum(&a, None).unwrap();
2522        assert_eq!(s.iter().next(), Some(&15));
2523    }
2524
2525    #[test]
2526    fn test_mean_as_f64_integer() {
2527        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2528        let m = mean_as_f64(&a, None).unwrap();
2529        assert!((m.iter().next().unwrap() - 2.5).abs() < 1e-12);
2530    }
2531
2532    #[test]
2533    fn test_mean_as_f64_integer_axis() {
2534        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
2535        let m = mean_as_f64(&a, Some(1)).unwrap();
2536        assert_eq!(m.shape(), &[2]);
2537        let data: Vec<f64> = m.iter().copied().collect();
2538        assert!((data[0] - 2.0).abs() < 1e-12);
2539        assert!((data[1] - 5.0).abs() < 1e-12);
2540    }
2541
2542    #[test]
2543    fn test_mean_as_f64_u8() {
2544        let a = Array::<u8, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
2545        let m = mean_as_f64(&a, None).unwrap();
2546        assert!((m.iter().next().unwrap() - 25.0).abs() < 1e-12);
2547    }
2548
2549    #[test]
2550    fn test_sum_as_f64_integer() {
2551        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
2552        let s = sum_as_f64(&a, None).unwrap();
2553        assert!((s.iter().next().unwrap() - 10.0).abs() < 1e-12);
2554    }
2555
2556    #[test]
2557    fn test_sum_as_f64_large_values() {
2558        // Values that would overflow i32 if summed as i32
2559        let a =
2560            Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![i32::MAX, i32::MAX, i32::MAX]).unwrap();
2561        let s = sum_as_f64(&a, None).unwrap();
2562        let expected = 3.0 * f64::from(i32::MAX);
2563        assert!((s.iter().next().unwrap() - expected).abs() < 1.0);
2564    }
2565
2566    // -----------------------------------------------------------------------
2567    // Multi-axis + keepdims tests (issues #457 + #458)
2568    // -----------------------------------------------------------------------
2569
2570    use ferray_core::Ix3;
2571
2572    fn arr1d(data: Vec<f64>) -> Array<f64, Ix1> {
2573        let n = data.len();
2574        Array::<f64, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
2575    }
2576
2577    fn arr2d(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
2578        Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap()
2579    }
2580
2581    fn arr3d(s0: usize, s1: usize, s2: usize, data: Vec<f64>) -> Array<f64, Ix3> {
2582        Array::<f64, Ix3>::from_vec(Ix3::new([s0, s1, s2]), data).unwrap()
2583    }
2584
    #[test]
    fn test_normalize_axes_none_reduces_all() {
        // `None` expands to the full axis list for the given rank.
        let ax = normalize_axes(None, 3).unwrap();
        assert_eq!(ax, vec![0, 1, 2]);
    }

    #[test]
    fn test_normalize_axes_empty_reduces_all() {
        // An empty slice behaves like `None`: reduce every axis.
        let ax = normalize_axes(Some(&[]), 3).unwrap();
        assert_eq!(ax, vec![0, 1, 2]);
    }

    #[test]
    fn test_normalize_axes_sorts() {
        // Output is sorted ascending regardless of the order given.
        let ax = normalize_axes(Some(&[2, 0]), 3).unwrap();
        assert_eq!(ax, vec![0, 2]);
    }

    #[test]
    fn test_normalize_axes_rejects_duplicate() {
        // A repeated axis is an error, not a silent dedup.
        assert!(normalize_axes(Some(&[0, 0]), 3).is_err());
    }

    #[test]
    fn test_normalize_axes_rejects_out_of_bounds() {
        // axis == ndim is one past the last valid axis for rank 3.
        assert!(normalize_axes(Some(&[3]), 3).is_err());
    }
2612
    #[test]
    fn test_sum_axes_single_axis_matches_legacy() {
        // The multi-axis API with a single axis must agree exactly with the
        // legacy single-axis sum().
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let legacy = sum(&a, Some(0)).unwrap();
        let new = sum_axes(&a, Some(&[0]), false).unwrap();
        assert_eq!(legacy.shape(), new.shape());
        let la: Vec<f64> = legacy.iter().copied().collect();
        let na: Vec<f64> = new.iter().copied().collect();
        assert_eq!(la, na);
    }

    #[test]
    fn test_sum_axes_keepdims_2d() {
        // sum((2, 3), axis=1, keepdims=True) -> shape (2, 1)
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s = sum_axes(&a, Some(&[1]), true).unwrap();
        assert_eq!(s.shape(), &[2, 1]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![6.0, 15.0]);
    }

    #[test]
    fn test_sum_axes_keepdims_supports_broadcast_back() {
        // Canonical NumPy pattern: arr - arr.mean(axis=1, keepdims=True)
        // We do it with a 2x3 array.
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m = mean_axes(&a, Some(&[1]), true).unwrap();
        assert_eq!(m.shape(), &[2, 1]);
        // Row 0 mean = 2.0, row 1 mean = 5.0
        let md: Vec<f64> = m.iter().copied().collect();
        assert_eq!(md, vec![2.0, 5.0]);
    }

    #[test]
    fn test_sum_axes_multi_axis_3d() {
        // shape (2, 3, 4), sum over axes (0, 2) -> shape (3,)
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = arr3d(2, 3, 4, data);
        let s = sum_axes(&a, Some(&[0, 2]), false).unwrap();
        assert_eq!(s.shape(), &[3]);
        // For each j in [0,1,2], sum over i in [0,1] and k in [0,1,2,3]:
        //   a[i,j,k] = i*12 + j*4 + k
        //   sum = sum_{i,k} (i*12 + j*4 + k) = 4*12 + 8*(4*j) + 2*(0+1+2+3)
        //       = 48 + 32*j + 12 = 32*j + 60
        // For j=0: (0+1+2+3) + (12+13+14+15) = 6 + 54 = 60
        // For j=1: (4+5+6+7) + (16+17+18+19) = 22 + 70 = 92
        // For j=2: (8+9+10+11) + (20+21+22+23) = 38 + 86 = 124
        let d: Vec<f64> = s.iter().copied().collect();
        assert_eq!(d, vec![60.0, 92.0, 124.0]);
    }

    #[test]
    fn test_sum_axes_multi_axis_keepdims_3d() {
        // Same as above but keepdims -> shape (1, 3, 1)
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = arr3d(2, 3, 4, data);
        let s = sum_axes(&a, Some(&[0, 2]), true).unwrap();
        assert_eq!(s.shape(), &[1, 3, 1]);
        let d: Vec<f64> = s.iter().copied().collect();
        assert_eq!(d, vec![60.0, 92.0, 124.0]);
    }

    #[test]
    fn test_sum_axes_none_reduces_all() {
        // axes=None without keepdims yields a 0-d (scalar) result.
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s = sum_axes(&a, None, false).unwrap();
        assert_eq!(s.shape(), &[]);
        assert_eq!(s.iter().next(), Some(&21.0));
    }

    #[test]
    fn test_sum_axes_none_keepdims() {
        // axes=None with keepdims preserves rank: every axis becomes length 1.
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s = sum_axes(&a, None, true).unwrap();
        assert_eq!(s.shape(), &[1, 1]);
        assert_eq!(s.iter().next(), Some(&21.0));
    }

    #[test]
    fn test_prod_axes_keepdims() {
        // Row products with keepdims: [1*2*3, 4*5*6] in shape (2, 1).
        let a = arr2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p = prod_axes(&a, Some(&[1]), true).unwrap();
        assert_eq!(p.shape(), &[2, 1]);
        let d: Vec<f64> = p.iter().copied().collect();
        assert_eq!(d, vec![6.0, 120.0]);
    }
2699
    #[test]
    fn test_min_max_axes_multi_axis() {
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = arr3d(2, 3, 4, data);
        let mn = min_axes(&a, Some(&[0, 2]), false).unwrap();
        let mx = max_axes(&a, Some(&[0, 2]), false).unwrap();
        // min over i=0..2, k=0..4 for each j: min value on i=0 has smallest
        // indices. For j=0, values are 0..3 and 12..15 -> min=0, max=15
        assert_eq!(mn.iter().copied().collect::<Vec<_>>(), vec![0.0, 4.0, 8.0]);
        assert_eq!(
            mx.iter().copied().collect::<Vec<_>>(),
            vec![15.0, 19.0, 23.0]
        );
    }

    #[test]
    fn test_mean_axes_multi_axis() {
        // Means over axes (0, 2): the per-j sums from the sum_axes tests
        // (60, 92, 124) divided by the 2*4 = 8 reduced elements.
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = arr3d(2, 3, 4, data);
        let m = mean_axes(&a, Some(&[0, 2]), false).unwrap();
        assert_eq!(m.shape(), &[3]);
        // For j=0: 60 / 8 = 7.5
        // For j=1: 92 / 8 = 11.5
        // For j=2: 124 / 8 = 15.5
        let d: Vec<f64> = m.iter().copied().collect();
        assert_eq!(d, vec![7.5, 11.5, 15.5]);
    }

    #[test]
    fn test_var_axes_single_axis_keepdims() {
        let a = arr2d(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        // Population variance along axis 1 (each row), keepdims = (2, 1)
        let v = var_axes(&a, Some(&[1]), 0, true).unwrap();
        assert_eq!(v.shape(), &[2, 1]);
        // Row 0: mean=2.5, var=((1-2.5)^2+(2-2.5)^2+(3-2.5)^2+(4-2.5)^2)/4 = 1.25
        // Row 1: same distribution = 1.25
        let d: Vec<f64> = v.iter().copied().collect();
        assert!((d[0] - 1.25).abs() < 1e-12);
        assert!((d[1] - 1.25).abs() < 1e-12);
    }

    #[test]
    fn test_std_axes_multi_axis() {
        let a = arr2d(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        // Population std over all: mean=4.5, var=5.25, std≈2.2913
        let s = std_axes(&a, None, 0, false).unwrap();
        assert_eq!(s.shape(), &[]);
        let v = *s.iter().next().unwrap();
        assert!((v - 5.25_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_argmin_keepdims_single_axis() {
        // Row 0 [3,1,4] -> min at index 1; row 1 [1,5,2] -> min at index 0.
        let a = arr2d(2, 3, vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0]);
        let am = argmin_keepdims(&a, Some(1), true).unwrap();
        assert_eq!(am.shape(), &[2, 1]);
        let d: Vec<u64> = am.iter().copied().collect();
        assert_eq!(d, vec![1, 0]);
    }

    #[test]
    fn test_argmax_keepdims_single_axis() {
        // keepdims=false: result drops the reduced axis -> shape (2,).
        let a = arr2d(2, 3, vec![3.0, 1.0, 4.0, 1.0, 5.0, 2.0]);
        let am = argmax_keepdims(&a, Some(1), false).unwrap();
        assert_eq!(am.shape(), &[2]);
        let d: Vec<u64> = am.iter().copied().collect();
        assert_eq!(d, vec![2, 1]);
    }

    #[test]
    fn test_axes_out_of_bounds_error() {
        // Axis 5 on a rank-2 array must be rejected.
        let a = arr2d(2, 3, vec![0.0; 6]);
        assert!(sum_axes(&a, Some(&[5]), false).is_err());
    }

    #[test]
    fn test_axes_duplicate_error() {
        // The same axis listed twice must be rejected.
        let a = arr3d(2, 3, 4, vec![0.0; 24]);
        assert!(sum_axes(&a, Some(&[0, 0]), false).is_err());
    }
2780
2781    // ----- Infinity tests (#177) -----
2782
    #[test]
    fn sum_with_infinity() {
        // A finite + INF sum saturates to INF.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let s = sum(&a, None).unwrap();
        assert_eq!(s.iter().next().copied().unwrap(), f64::INFINITY);
    }

    #[test]
    fn sum_inf_minus_inf_is_nan() {
        // INF + (-INF) = NaN per IEEE 754.
        let a = arr1d(vec![f64::INFINITY, f64::NEG_INFINITY]);
        let s = sum(&a, None).unwrap();
        assert!(s.iter().next().copied().unwrap().is_nan());
    }

    #[test]
    fn mean_with_infinity() {
        // INF / n is still INF, so the mean is INF.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let m = mean(&a, None).unwrap();
        assert_eq!(m.iter().next().copied().unwrap(), f64::INFINITY);
    }

    #[test]
    fn min_with_neg_infinity() {
        // -INF compares below every finite value.
        let a = arr1d(vec![1.0, f64::NEG_INFINITY, 3.0]);
        let m = min(&a, None).unwrap();
        assert_eq!(m.iter().next().copied().unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn max_with_infinity() {
        // +INF compares above every finite value.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let m = max(&a, None).unwrap();
        assert_eq!(m.iter().next().copied().unwrap(), f64::INFINITY);
    }

    #[test]
    fn prod_with_zero_and_infinity_is_nan() {
        // 0 * INF = NaN per IEEE 754
        let a = arr1d(vec![0.0, f64::INFINITY, 1.0]);
        let p = prod(&a, None).unwrap();
        assert!(p.iter().next().copied().unwrap().is_nan());
    }

    #[test]
    fn prod_with_infinity_propagates() {
        // Nonzero finite * INF stays INF.
        let a = arr1d(vec![2.0, f64::INFINITY, 3.0]);
        let p = prod(&a, None).unwrap();
        assert_eq!(p.iter().next().copied().unwrap(), f64::INFINITY);
    }

    #[test]
    fn var_with_infinity_is_nan() {
        // var includes (x - mean)^2 where mean is INF; (INF - INF)^2 = NaN.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let v = var(&a, None, 0).unwrap();
        assert!(v.iter().next().copied().unwrap().is_nan());
    }

    #[test]
    fn std_with_infinity_is_nan() {
        // sqrt(NaN) = NaN, so std inherits var's NaN.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let s = std_(&a, None, 0).unwrap();
        assert!(s.iter().next().copied().unwrap().is_nan());
    }

    #[test]
    fn argmin_finds_neg_infinity() {
        // -INF beats every finite candidate, including -100.
        let a = arr1d(vec![1.0, f64::NEG_INFINITY, 3.0, -100.0]);
        let i = crate::reductions::argmin(&a, None).unwrap();
        assert_eq!(i.iter().next().copied().unwrap(), 1);
    }

    #[test]
    fn argmax_finds_infinity() {
        // +INF beats 1e300, the largest finite candidate here.
        let a = arr1d(vec![1.0, f64::INFINITY, 1e300, 5.0]);
        let i = crate::reductions::argmax(&a, None).unwrap();
        assert_eq!(i.iter().next().copied().unwrap(), 1);
    }

    #[test]
    fn cumsum_propagates_infinity() {
        // Once INF enters the running sum, every later prefix is INF.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let c = cumsum(&a, None).unwrap();
        let v: Vec<f64> = c.iter().copied().collect();
        assert_eq!(v[0], 1.0);
        assert!(v[1].is_infinite());
        assert!(v[2].is_infinite());
    }

    #[test]
    fn cumprod_inf_then_zero_yields_nan() {
        let a = arr1d(vec![2.0, f64::INFINITY, 0.0]);
        let c = cumprod(&a, None).unwrap();
        let v: Vec<f64> = c.iter().copied().collect();
        assert_eq!(v[0], 2.0);
        assert!(v[1].is_infinite());
        assert!(v[2].is_nan()); // INF * 0 = NaN
    }

    #[test]
    fn ptp_with_infinity_is_inf() {
        // peak-to-peak = max - min = INF - 1 = INF.
        let a = arr1d(vec![1.0, f64::INFINITY, 3.0]);
        let p = ptp(&a, None).unwrap();
        assert!(p.iter().next().copied().unwrap().is_infinite());
    }
2888
2889    // ----- Single-element var/std with ddof (#178) -----
2890
    #[test]
    fn var_single_element_ddof0() {
        // A single sample has zero population variance (divide by N=1).
        let a = arr1d(vec![5.0]);
        let v = var(&a, None, 0).unwrap();
        assert_eq!(v.iter().next().copied().unwrap(), 0.0);
    }

    #[test]
    fn var_single_element_ddof1_errors() {
        // ddof=1 with N=1 → ddof >= N, ferray errors instead of
        // returning NaN (stricter than NumPy which returns NaN).
        let a = arr1d(vec![5.0]);
        assert!(var(&a, None, 1).is_err());
    }

    #[test]
    fn std_single_element_ddof0() {
        // sqrt(0) = 0: std mirrors the single-sample variance result.
        let a = arr1d(vec![5.0]);
        let s = std_(&a, None, 0).unwrap();
        assert_eq!(s.iter().next().copied().unwrap(), 0.0);
    }

    #[test]
    fn std_single_element_ddof1_errors() {
        // Same shape as var: ddof=1 with N=1 hits the same ddof >= N
        // guard (std_ delegates to var internally). Confirms the
        // error path is plumbed through.
        let a = arr1d(vec![5.0]);
        assert!(std_(&a, None, 1).is_err());
    }

    #[test]
    fn var_two_elements_ddof1_population_to_sample() {
        // Sanity: ddof=0 vs ddof=1 on the same input — Bessel's correction
        // should give a 2x bigger variance for N=2.
        let a = arr1d(vec![1.0, 3.0]);
        let v0 = var(&a, None, 0).unwrap();
        let v1 = var(&a, None, 1).unwrap();
        let v0_val = v0.iter().next().copied().unwrap();
        let v1_val = v1.iter().next().copied().unwrap();
        assert!((v0_val - 1.0).abs() < 1e-12); // ((1-2)^2 + (3-2)^2) / 2 = 1.0
        assert!((v1_val - 2.0).abs() < 1e-12); // ((1-2)^2 + (3-2)^2) / 1 = 2.0
        // Sample variance (ddof=1) should be 2x population (ddof=0) for N=2.
        assert!((v1_val / v0_val - 2.0).abs() < 1e-12);
    }

    #[test]
    fn var_single_element_2d_ddof0() {
        // 1×1 array — single element along every axis
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 1]), vec![5.0]).unwrap();
        let v = var(&a, None, 0).unwrap();
        assert_eq!(v.iter().next().copied().unwrap(), 0.0);
    }

    #[test]
    fn var_single_element_axis_ddof_too_large_errors() {
        // 1×3 array with ddof=1 reduced along axis 0 (length 1) should error
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(var(&a, Some(0), 1).is_err());
    }
2953
2954    // ---- *_into reductions (#467) ----
2955
    #[test]
    fn sum_into_axis_writes_into_destination() {
        // (2, 3) sum along axis=1 → shape (2,)
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
        sum_into(&a, Some(1), &mut out).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[6.0, 15.0]);
    }

    #[test]
    fn sum_into_no_axis_writes_scalar_destination() {
        // Full reduction targets a 0-d (scalar-shaped) destination.
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        sum_into(&a, None, &mut out).unwrap();
        assert_eq!(out.iter().next().copied().unwrap(), 10.0);
    }

    #[test]
    fn sum_into_rejects_wrong_shape() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        // axis=1 reduces to shape (2,), but out has shape (3,)
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0; 3]).unwrap();
        let err = sum_into(&a, Some(1), &mut out);
        assert!(err.is_err());
    }

    #[test]
    fn sum_into_rejects_axis_out_of_bounds() {
        // Axis 5 on a rank-1 input must fail before touching `out`.
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1]), vec![0.0]).unwrap();
        assert!(sum_into(&a, Some(5), &mut out).is_err());
    }

    #[test]
    fn prod_into_basic() {
        // 1*2*3*4 = 24 written into a scalar destination.
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        prod_into(&a, None, &mut out).unwrap();
        assert_eq!(out.iter().next().copied().unwrap(), 24.0);
    }

    #[test]
    fn min_into_axis_basic() {
        // Row minima of [[1,5,2],[4,3,6]] = [1, 3].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 2.0, 4.0, 3.0, 6.0])
            .unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
        min_into(&a, Some(1), &mut out).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[1.0, 3.0]);
    }

    #[test]
    fn max_into_axis_basic() {
        // Row maxima of [[1,5,2],[4,3,6]] = [5, 6].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 2.0, 4.0, 3.0, 6.0])
            .unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
        max_into(&a, Some(1), &mut out).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[5.0, 6.0]);
    }

    #[test]
    fn mean_into_axis_basic() {
        // Row means of [[1,2,3],[4,5,6]] = [2, 5].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
        mean_into(&a, Some(1), &mut out).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[2.0, 5.0]);
    }

    #[test]
    fn var_into_basic() {
        // Variance with ddof=0 of [1,2,3,4] = 1.25
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        var_into(&a, None, 0, &mut out).unwrap();
        assert!((out.iter().next().copied().unwrap() - 1.25).abs() < 1e-12);
    }

    #[test]
    fn std_into_basic() {
        // sqrt(var) for [1,2,3,4] with ddof=0 = sqrt(1.25)
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        std_into(&a, None, 0, &mut out).unwrap();
        let expected = 1.25_f64.sqrt();
        assert!((out.iter().next().copied().unwrap() - expected).abs() < 1e-12);
    }

    #[test]
    fn into_reductions_can_reuse_destination_across_calls() {
        // The whole point of the out= API: a single allocation reused
        // across many reductions on same-shaped data.
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0; 2]).unwrap();
        for k in 0..3 {
            let base = f64::from(k);
            let a = Array::<f64, Ix2>::from_vec(
                Ix2::new([2, 3]),
                vec![
                    base + 1.0,
                    base + 2.0,
                    base + 3.0,
                    base + 4.0,
                    base + 5.0,
                    base + 6.0,
                ],
            )
            .unwrap();
            sum_into(&a, Some(1), &mut out).unwrap();
            // Row sums shift by 3*base each iteration: [6+3k, 15+3k].
            assert_eq!(
                out.as_slice().unwrap(),
                &[3.0f64.mul_add(base, 6.0), 3.0f64.mul_add(base, 15.0)]
            );
        }
    }
3071
3072    // ---- zero-alloc kernel regression tests (#563) ----
3073
    #[test]
    fn sum_into_overwrites_existing_destination_garbage() {
        // The new kernel writes results directly into `out` rather than
        // routing through write_into's copy_from_slice; make sure existing
        // garbage is fully overwritten and never bleeds through.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![999.0, -999.0]).unwrap();
        sum_into(&a, Some(1), &mut out).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[6.0, 15.0]);
    }

    #[test]
    fn reduce_axis_typed_into_matches_reduce_axis_typed() {
        // The in-place kernel must produce identical output to the
        // allocating kernel for any (shape, axis, fn).
        use super::{reduce_axis_typed, reduce_axis_typed_into};
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        // Cover rank 1 through 4, reducing every possible axis of each.
        for shape in [vec![24usize], vec![4, 6], vec![2, 3, 4], vec![2, 2, 2, 3]] {
            for ax in 0..shape.len() {
                let allocated: Vec<f64> =
                    reduce_axis_typed(&data, &shape, ax, |lane| lane.iter().sum());
                let mut dst = vec![0.0; allocated.len()];
                reduce_axis_typed_into(&data, &shape, ax, &mut dst, |lane| lane.iter().sum());
                assert_eq!(dst, allocated, "shape {shape:?} axis {ax}");
            }
        }
    }

    #[test]
    fn sum_into_3d_axis_correct() {
        use ferray_core::Ix3;
        // (2, 3, 4) reducing axis 1 → shape (2, 4)
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 4]), vec![0.0; 8]).unwrap();
        sum_into(&a, Some(1), &mut out).unwrap();
        // Hand check: out[i, k] = sum_{j} a[i, j, k] = sum_{j} (i*12 + j*4 + k)
        let expected: Vec<f64> = (0..2)
            .flat_map(|i| {
                (0..4).map(move |k| (0..3).map(|j| f64::from(i * 12 + j * 4 + k)).sum::<f64>())
            })
            .collect();
        assert_eq!(out.as_slice().unwrap(), expected.as_slice());
    }

    #[test]
    fn min_into_rejects_empty_input() {
        // min of an empty array is undefined — must error, not write junk.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        assert!(min_into(&a, None, &mut out).is_err());
    }

    #[test]
    fn max_into_rejects_empty_input() {
        // Same guard as min_into: empty input has no maximum.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        assert!(max_into(&a, None, &mut out).is_err());
    }

    #[test]
    fn var_into_rejects_ddof_too_large() {
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        // n=3, ddof=3 → ddof >= n, must error.
        assert!(var_into(&a, None, 3, &mut out).is_err());
    }

    #[test]
    fn std_into_does_not_leave_variance_in_destination_on_success() {
        // var of [1,2,3,4] with ddof=0 = 1.25; std = sqrt(1.25) ≈ 1.118.
        // The destination must hold the std value, not the variance, after
        // a successful call.
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mut out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[]), vec![0.0]).unwrap();
        std_into(&a, None, 0, &mut out).unwrap();
        let got = out.iter().next().copied().unwrap();
        let expected = 1.25_f64.sqrt();
        assert!((got - expected).abs() < 1e-12);
        // Sanity: result is not the variance.
        assert!((got - 1.25).abs() > 1e-3);
    }
3156
3157    // ---- *_with reductions: initial= and where= (#459) ----
3158
    #[test]
    fn sum_with_initial_only() {
        // Initial = 100; sum becomes 100 + 1 + 2 + 3 + 4 = 110
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let r = sum_with(&a, None, Some(100.0), None).unwrap();
        assert_eq!(r.iter().next().copied().unwrap(), 110.0);
    }

    #[test]
    fn sum_with_where_mask_only() {
        // Sum positives only.
        let a = arr1d(vec![1.0, -2.0, 3.0, -4.0, 5.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let r = sum_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
        assert_eq!(r.iter().next().copied().unwrap(), 9.0);
    }

    #[test]
    fn sum_with_initial_and_mask_combined() {
        // initial= and where= compose: masked-in elements added on top of 50.
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let r = sum_with(&a, None, Some(50.0), Some(&mask.to_dyn())).unwrap();
        // 50 + 1 + 3 = 54
        assert_eq!(r.iter().next().copied().unwrap(), 54.0);
    }

    #[test]
    fn sum_with_axis_and_mask() {
        // (2, 3); mask zeroes out the first column; row sums become
        // [2+3, 5+6] = [5, 11].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, true, false, true, true],
        )
        .unwrap();
        let r = sum_with(&a, Some(1), None, Some(&mask.to_dyn())).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[5.0, 11.0]);
    }

    #[test]
    fn sum_with_axis_and_initial() {
        // initial= applies per lane, not once globally.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        // axis=1 with initial=10 → row sums become [10+6, 10+15] = [16, 25]
        let r = sum_with(&a, Some(1), Some(10.0), None).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[16.0, 25.0]);
    }

    #[test]
    fn sum_with_no_initial_no_mask_matches_legacy_sum() {
        // Sanity: omitting both knobs should match the existing sum().
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let legacy = sum(&a, Some(1)).unwrap();
        let with_form = sum_with(&a, Some(1), None, None).unwrap();
        assert_eq!(legacy.as_slice().unwrap(), with_form.as_slice().unwrap());
    }

    #[test]
    fn sum_with_rejects_mismatched_mask_shape() {
        // Mask of length 4 against data of length 3 must be rejected.
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(sum_with(&a, None, None, Some(&mask.to_dyn())).is_err());
    }
3229
    #[test]
    fn prod_with_initial_only() {
        let a = arr1d(vec![2.0, 3.0, 4.0]);
        let r = prod_with(&a, None, Some(10.0), None).unwrap();
        // 10 * 2 * 3 * 4 = 240
        assert_eq!(r.iter().next().copied().unwrap(), 240.0);
    }

    #[test]
    fn prod_with_where_mask() {
        // Masked-out zeros must not zero the product.
        let a = arr1d(vec![2.0, 0.0, 3.0, 0.0, 4.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let r = prod_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
        // 2 * 3 * 4 = 24 (zero positions skipped)
        assert_eq!(r.iter().next().copied().unwrap(), 24.0);
    }

    #[test]
    fn min_with_initial_provides_fallback_for_empty_lane() {
        // Empty array + initial = the initial value.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let r = min_with(&a, None, Some(99.0), None).unwrap();
        assert_eq!(r.iter().next().copied().unwrap(), 99.0);
    }

    #[test]
    fn min_with_initial_caps_actual_min() {
        // Initial=0, data=[1,2,3] → min(0, 1, 2, 3) = 0.
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let r = min_with(&a, None, Some(0.0), None).unwrap();
        assert_eq!(r.iter().next().copied().unwrap(), 0.0);
    }

    #[test]
    fn min_with_where_mask_filters_then_reduces() {
        let a = arr1d(vec![5.0, 1.0, 4.0, 2.0, 3.0]);
        // Mask out positions 1 and 3 (values 1 and 2).
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let r = min_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
        // Filtered: [5, 4, 3] → min = 3
        assert_eq!(r.iter().next().copied().unwrap(), 3.0);
    }

    #[test]
    fn min_with_empty_lane_no_initial_errors() {
        // No elements and no initial= fallback → there is no min to return.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert!(min_with(&a, None, None, None).is_err());
    }

    #[test]
    fn min_with_fully_masked_axis_lane_no_initial_errors() {
        // Each row fully masked out → error per lane.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
        assert!(min_with(&a, Some(1), None, Some(&mask.to_dyn())).is_err());
    }

    #[test]
    fn min_with_fully_masked_axis_lane_with_initial_uses_initial() {
        // initial= rescues fully-masked lanes: every lane yields 99.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
        let r = min_with(&a, Some(1), Some(99.0), Some(&mask.to_dyn())).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[99.0, 99.0]);
    }

    #[test]
    fn max_with_initial_caps_actual_max() {
        // Initial=99 exceeds every element → max is the initial itself.
        let a = arr1d(vec![1.0, 2.0, 3.0]);
        let r = max_with(&a, None, Some(99.0), None).unwrap();
        assert_eq!(r.iter().next().copied().unwrap(), 99.0);
    }

    #[test]
    fn max_with_where_mask_filters_then_reduces() {
        let a = arr1d(vec![5.0, 10.0, 4.0, 20.0, 3.0]);
        // Mask out positions 1 and 3 (values 10 and 20).
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let r = max_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
        // Filtered: [5, 4, 3] → max = 5
        assert_eq!(r.iter().next().copied().unwrap(), 5.0);
    }

    #[test]
    fn mean_where_filters_and_divides_by_count() {
        // [1, 2, 3, 4, 5] with mask [T, F, T, F, T] → (1+3+5)/3 = 3.0
        let a = arr1d(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let r = mean_where(&a, None, Some(&mask.to_dyn())).unwrap();
        assert!((r.iter().next().copied().unwrap() - 3.0).abs() < 1e-12);
    }

    #[test]
    fn mean_where_no_mask_matches_legacy_mean() {
        // Sanity: mask=None should match the existing mean().
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let legacy = mean(&a, Some(1)).unwrap();
        let where_form = mean_where(&a, Some(1), None).unwrap();
        assert_eq!(legacy.as_slice().unwrap(), where_form.as_slice().unwrap());
    }

    #[test]
    fn mean_where_fully_masked_lane_returns_nan() {
        // Row 0 keeps all three elements; row 1 is fully masked → 0/0 = NaN.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, true, false, false, false],
        )
        .unwrap();
        let r = mean_where(&a, Some(1), Some(&mask.to_dyn())).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 2.0).abs() < 1e-12);
        assert!(s[1].is_nan());
    }

    #[test]
    fn mean_where_axis_with_partial_mask() {
        // (2, 3); mask = [[T,T,F],[F,T,T]]
        // Row 0 mean: (1+2)/2 = 1.5
        // Row 1 mean: (5+6)/2 = 5.5
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, false, false, true, true],
        )
        .unwrap();
        let r = mean_where(&a, Some(1), Some(&mask.to_dyn())).unwrap();
        assert_eq!(r.shape(), &[2]);
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.5).abs() < 1e-12);
        assert!((s[1] - 5.5).abs() < 1e-12);
    }
3373
3374    // ---- broadcast-compatible where masks (#565) ----
3375
3376    #[test]
3377    fn sum_with_mask_broadcasts_ix1_into_ix2() {
3378        // (2, 3) input; mask of shape (3,) — broadcasts to (2, 3) with
3379        // every row identical. sum ignores the first column, so both
3380        // row sums skip that column's contribution.
3381        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3382            .unwrap();
3383        let mask_1d = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, true]).unwrap();
3384        let r = sum_with(&a, None, None, Some(&mask_1d.to_dyn())).unwrap();
3385        // Sum with first column masked out: 2 + 3 + 5 + 6 = 16.0
3386        assert!((r.iter().next().copied().unwrap() - 16.0).abs() < 1e-12);
3387    }
3388
3389    #[test]
3390    fn sum_with_mask_broadcasts_column_vector_across_rows() {
3391        // (3, 4) input; mask of shape (3, 1) — broadcasts across the
3392        // four columns, so each row is either fully kept or fully
3393        // masked out.
3394        let a = Array::<f64, Ix2>::from_vec(
3395            Ix2::new([3, 4]),
3396            vec![
3397                1.0, 2.0, 3.0, 4.0, // row 0
3398                5.0, 6.0, 7.0, 8.0, // row 1
3399                9.0, 10.0, 11.0, 12.0, // row 2
3400            ],
3401        )
3402        .unwrap();
3403        let mask_col =
3404            Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![true, false, true]).unwrap();
3405        let r = sum_with(&a, None, None, Some(&mask_col.to_dyn())).unwrap();
3406        // Rows 0 and 2 kept: 1+2+3+4 + 9+10+11+12 = 10 + 42 = 52
3407        assert!((r.iter().next().copied().unwrap() - 52.0).abs() < 1e-12);
3408    }
3409
3410    #[test]
3411    fn sum_with_mask_broadcasts_row_vector_against_axis_reduction() {
3412        // (2, 3) reducing axis=1; mask shape (3,) broadcasts against
3413        // each row.
3414        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3415            .unwrap();
3416        let mask_1d = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3417        let r = sum_with(&a, Some(1), None, Some(&mask_1d.to_dyn())).unwrap();
3418        assert_eq!(r.shape(), &[2]);
3419        // Row 0: 1 + 3 = 4; row 1: 4 + 6 = 10
3420        assert_eq!(r.as_slice().unwrap(), &[4.0, 10.0]);
3421    }
3422
3423    #[test]
3424    fn prod_with_mask_broadcasts_ix1() {
3425        // (2, 3); mask shape (3,) keeps columns 0 and 2.
3426        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
3427            .unwrap();
3428        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3429        let r = prod_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3430        // Product of kept values: 2 * 4 * 5 * 7 = 280
3431        assert!((r.iter().next().copied().unwrap() - 280.0).abs() < 1e-12);
3432    }
3433
3434    #[test]
3435    fn min_with_mask_broadcasts_ix1() {
3436        // Column mask; find min across the kept columns.
3437        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 1.0, 4.0, 2.0, 10.0, 3.0])
3438            .unwrap();
3439        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3440        let r = min_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3441        // Kept values: 5, 4, 2, 3 → min = 2
3442        assert!((r.iter().next().copied().unwrap() - 2.0).abs() < 1e-12);
3443    }
3444
3445    #[test]
3446    fn max_with_mask_broadcasts_ix1() {
3447        let a =
3448            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 100.0, 4.0, 2.0, 200.0, 3.0])
3449                .unwrap();
3450        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3451        let r = max_with(&a, None, None, Some(&mask.to_dyn())).unwrap();
3452        // Kept values: 5, 4, 2, 3 → max = 5
3453        assert!((r.iter().next().copied().unwrap() - 5.0).abs() < 1e-12);
3454    }
3455
3456    #[test]
3457    fn mean_where_mask_broadcasts_ix1() {
3458        // (2, 3); mask (3,) keeps columns 0 and 2. Mean of 4 kept
3459        // values: (1 + 3 + 4 + 6) / 4 = 3.5
3460        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3461            .unwrap();
3462        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
3463        let r = mean_where(&a, None, Some(&mask.to_dyn())).unwrap();
3464        assert!((r.iter().next().copied().unwrap() - 3.5).abs() < 1e-12);
3465    }
3466
3467    #[test]
3468    fn with_mask_rejects_incompatible_broadcast_shape() {
3469        // Mask rank compatible but length wrong: shape (2,) against a
3470        // (2, 3) input cannot broadcast (the 2 aligns with the last
3471        // dim which is 3, not the first which is 2).
3472        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3473            .unwrap();
3474        let bad_mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
3475        assert!(sum_with(&a, None, None, Some(&bad_mask.to_dyn())).is_err());
3476    }
3477
3478    #[test]
3479    fn sum_with_mask_broadcast_scalar_like_length_1() {
3480        // A shape-(1,) mask is the scalar case — broadcasts to every
3481        // position of the input, which for `true` is an identity and
3482        // for `false` yields 0 (everything masked out).
3483        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3484            .unwrap();
3485        let true_mask = Array::<bool, Ix1>::from_vec(Ix1::new([1]), vec![true]).unwrap();
3486        let r = sum_with(&a, None, None, Some(&true_mask.to_dyn())).unwrap();
3487        // All positions kept: 1+2+3+4+5+6 = 21
3488        assert!((r.iter().next().copied().unwrap() - 21.0).abs() < 1e-12);
3489
3490        let false_mask = Array::<bool, Ix1>::from_vec(Ix1::new([1]), vec![false]).unwrap();
3491        let r2 = sum_with(&a, None, Some(100.0), Some(&false_mask.to_dyn())).unwrap();
3492        // Nothing kept, initial = 100 → result = 100
3493        assert!((r2.iter().next().copied().unwrap() - 100.0).abs() < 1e-12);
3494    }
3495
3496    #[test]
3497    fn into_reductions_match_allocating_versions_3d_axis_2() {
3498        // Cross-check every *_into against its allocating sibling on a
3499        // 3-D input, axis = last dim — guards against an off-by-one in
3500        // the in-place index walker.
3501        use ferray_core::Ix3;
3502        let data: Vec<f64> = (0..24).map(|i| f64::from(i) + 0.5).collect();
3503        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
3504
3505        let s_alloc = sum(&a, Some(2)).unwrap();
3506        let mut s_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3507        sum_into(&a, Some(2), &mut s_out).unwrap();
3508        assert_eq!(s_alloc.as_slice().unwrap(), s_out.as_slice().unwrap());
3509
3510        let p_alloc = prod(&a, Some(2)).unwrap();
3511        let mut p_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3512        prod_into(&a, Some(2), &mut p_out).unwrap();
3513        assert_eq!(p_alloc.as_slice().unwrap(), p_out.as_slice().unwrap());
3514
3515        let mn_alloc = min(&a, Some(2)).unwrap();
3516        let mut mn_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3517        min_into(&a, Some(2), &mut mn_out).unwrap();
3518        assert_eq!(mn_alloc.as_slice().unwrap(), mn_out.as_slice().unwrap());
3519
3520        let mx_alloc = max(&a, Some(2)).unwrap();
3521        let mut mx_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3522        max_into(&a, Some(2), &mut mx_out).unwrap();
3523        assert_eq!(mx_alloc.as_slice().unwrap(), mx_out.as_slice().unwrap());
3524
3525        let me_alloc = mean(&a, Some(2)).unwrap();
3526        let mut me_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3527        mean_into(&a, Some(2), &mut me_out).unwrap();
3528        assert_eq!(me_alloc.as_slice().unwrap(), me_out.as_slice().unwrap());
3529
3530        let v_alloc = var(&a, Some(2), 0).unwrap();
3531        let mut v_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3532        var_into(&a, Some(2), 0, &mut v_out).unwrap();
3533        // var output may have small numerical drift between the two-pass
3534        // and one-shot kernels — compare element-wise within tolerance.
3535        for (a, b) in v_alloc
3536            .as_slice()
3537            .unwrap()
3538            .iter()
3539            .zip(v_out.as_slice().unwrap())
3540        {
3541            assert!((a - b).abs() < 1e-10);
3542        }
3543
3544        let sd_alloc = std_(&a, Some(2), 0).unwrap();
3545        let mut sd_out = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
3546        std_into(&a, Some(2), 0, &mut sd_out).unwrap();
3547        for (a, b) in sd_alloc
3548            .as_slice()
3549            .unwrap()
3550            .iter()
3551            .zip(sd_out.as_slice().unwrap())
3552        {
3553            assert!((a - b).abs() < 1e-10);
3554        }
3555    }
3556
3557    // -- ptp --
3558
3559    #[test]
3560    fn test_ptp_1d() {
3561        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 5.0, 3.0, 9.0, 2.0]).unwrap();
3562        let r = ptp(&a, None).unwrap();
3563        assert_eq!(r.iter().copied().next().unwrap(), 8.0);
3564    }
3565
3566    #[test]
3567    fn test_ptp_2d_axis() {
3568        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 3.0, 7.0, 2.0, 9.0])
3569            .unwrap();
3570        let r = ptp(&a, Some(1)).unwrap();
3571        // row 0: max=5, min=1 → 4; row 1: max=9, min=2 → 7
3572        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![4.0, 7.0]);
3573    }
3574
3575    #[test]
3576    fn test_ptp_empty_errs() {
3577        let a: Array<f64, Ix1> = Array::from_vec(Ix1::new([0]), vec![]).unwrap();
3578        assert!(ptp(&a, None).is_err());
3579    }
3580
3581    // -- average --
3582
3583    #[test]
3584    fn test_average_unweighted_matches_mean() {
3585        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3586        let r = average(&a, None, None).unwrap();
3587        assert!((r.iter().copied().next().unwrap() - 2.5).abs() < 1e-12);
3588    }
3589
3590    #[test]
3591    fn test_average_weighted() {
3592        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3593        let w = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 1.0, 1.0, 7.0]).unwrap();
3594        // (1+2+3+4*7) / (1+1+1+7) = 34/10 = 3.4
3595        let r = average(&a, Some(&w), None).unwrap();
3596        assert!((r.iter().copied().next().unwrap() - 3.4).abs() < 1e-12);
3597    }
3598
3599    #[test]
3600    fn test_average_weighted_axis() {
3601        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
3602            .unwrap();
3603        let w = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
3604            .unwrap();
3605        // With uniform weights, axis-1 average = row mean = [2, 5]
3606        let r = average(&a, Some(&w), Some(1)).unwrap();
3607        let data: Vec<f64> = r.iter().copied().collect();
3608        assert!((data[0] - 2.0).abs() < 1e-12);
3609        assert!((data[1] - 5.0).abs() < 1e-12);
3610    }
3611
3612    #[test]
3613    fn test_average_weights_zero_sum_errs() {
3614        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
3615        let w = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.0, 0.0, 0.0]).unwrap();
3616        assert!(average(&a, Some(&w), None).is_err());
3617    }
3618
3619    #[test]
3620    fn test_average_shape_mismatch_errs() {
3621        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
3622        let w = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0; 4]).unwrap();
3623        assert!(average(&a, Some(&w), None).is_err());
3624    }
3625
3626    // ----------------------------------------------------------------------
3627    // var_as_f64 / std_as_f64 integer promotion (#170)
3628    // ----------------------------------------------------------------------
3629
3630    #[test]
3631    fn var_as_f64_promotes_int_input() {
3632        let a = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3633        let v = var_as_f64(&a, None, 0).unwrap();
3634        // var([1,2,3,4,5]) with ddof=0 = 2.0
3635        assert!((v.iter().next().unwrap() - 2.0).abs() < 1e-12);
3636    }
3637
3638    #[test]
3639    fn var_as_f64_ddof_1() {
3640        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3641        let v = var_as_f64(&a, None, 1).unwrap();
3642        // var([1,2,3,4,5]) with ddof=1 = 2.5
3643        assert!((v.iter().next().unwrap() - 2.5).abs() < 1e-12);
3644    }
3645
3646    #[test]
3647    fn std_as_f64_promotes_int_input() {
3648        let a = Array::<u32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
3649        let s = std_as_f64(&a, None, 0).unwrap();
3650        // std([1,2,3,4,5]) with ddof=0 = sqrt(2.0)
3651        assert!((s.iter().next().unwrap() - 2.0_f64.sqrt()).abs() < 1e-12);
3652    }
3653
3654    #[test]
3655    fn var_as_f64_axis_2d() {
3656        use ferray_core::dimension::Ix2;
3657        let a = Array::<i64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
3658        // var along axis 0 of [[1,2,3],[4,5,6]] is [2.25, 2.25, 2.25]
3659        let v = var_as_f64(&a, Some(0), 0).unwrap();
3660        for x in v.iter() {
3661            assert!((x - 2.25).abs() < 1e-12);
3662        }
3663    }
3664
3665    // ----------------------------------------------------------------------
3666    // f32 sibling tests (#185) — exercises the f32 SIMD path added in #173
3667    // and the generic Float-bound reduction paths.
3668    // ----------------------------------------------------------------------
3669
3670    fn arr1d_f32(data: Vec<f32>) -> Array<f32, Ix1> {
3671        Array::<f32, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
3672    }
3673
3674    #[test]
3675    fn sum_f32_basic() {
3676        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3677        let s = sum(&a, None).unwrap();
3678        assert!((s.iter().next().copied().unwrap() - 15.0).abs() < 1e-6);
3679    }
3680
3681    #[test]
3682    fn sum_f32_large_for_simd() {
3683        // Big enough to actually exercise the SIMD pairwise sum kernel.
3684        let n = 4096;
3685        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
3686        let a = arr1d_f32(data);
3687        let s = sum(&a, None).unwrap();
3688        let expected = 0.1 * (n as f32) * ((n - 1) as f32) / 2.0;
3689        let got = s.iter().copied().next().unwrap();
3690        assert!((got - expected).abs() / expected < 1e-4);
3691    }
3692
3693    #[test]
3694    fn mean_f32_basic() {
3695        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3696        let m = mean(&a, None).unwrap();
3697        assert!((m.iter().next().copied().unwrap() - 3.0).abs() < 1e-6);
3698    }
3699
3700    #[test]
3701    fn var_f32_basic() {
3702        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3703        let v = var(&a, None, 0).unwrap();
3704        assert!((v.iter().next().copied().unwrap() - 2.0).abs() < 1e-6);
3705    }
3706
3707    #[test]
3708    fn var_f32_large_for_simd() {
3709        // Exercises the simd_sum_sq_diff_f32 kernel.
3710        let n = 4096;
3711        let data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
3712        let a = arr1d_f32(data);
3713        let v = var(&a, None, 0).unwrap();
3714        let expected = 0.25 * ((n * n - 1) as f32) / 12.0;
3715        let got = v.iter().copied().next().unwrap();
3716        assert!((got - expected).abs() / expected < 1e-3);
3717    }
3718
3719    #[test]
3720    fn std_f32_basic() {
3721        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
3722        let s = std_(&a, None, 0).unwrap();
3723        assert!((s.iter().next().copied().unwrap() - 2.0_f32.sqrt()).abs() < 1e-6);
3724    }
3725
3726    #[test]
3727    fn min_max_f32_basic() {
3728        let a = arr1d_f32(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
3729        let mn = min(&a, None).unwrap();
3730        let mx = max(&a, None).unwrap();
3731        assert_eq!(mn.iter().next().copied().unwrap(), 1.0);
3732        assert_eq!(mx.iter().next().copied().unwrap(), 9.0);
3733    }
3734
3735    #[test]
3736    fn prod_f32_basic() {
3737        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0]);
3738        let p = prod(&a, None).unwrap();
3739        assert!((p.iter().next().copied().unwrap() - 24.0).abs() < 1e-6);
3740    }
3741
3742    #[test]
3743    fn argmin_argmax_f32_basic() {
3744        let a = arr1d_f32(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
3745        let amin = argmin(&a, None).unwrap();
3746        let amax = argmax(&a, None).unwrap();
3747        assert_eq!(amin.iter().next().copied().unwrap(), 1);
3748        assert_eq!(amax.iter().next().copied().unwrap(), 5);
3749    }
3750
3751    #[test]
3752    fn cumsum_f32_basic() {
3753        let a = arr1d_f32(vec![1.0, 2.0, 3.0, 4.0]);
3754        let c = cumsum(&a, None).unwrap();
3755        let v: Vec<f32> = c.iter().copied().collect();
3756        for (got, expected) in v.iter().zip(&[1.0_f32, 3.0, 6.0, 10.0]) {
3757            assert!((got - expected).abs() < 1e-6);
3758        }
3759    }
3760}