Skip to main content

ferray_ufunc/ops/
arithmetic.rs

1// ferray-ufunc: Arithmetic functions
2//
3// add, subtract, multiply, divide, true_divide, floor_divide, power,
4// remainder, mod_, fmod, divmod, absolute, fabs, sign, negative, positive,
5// reciprocal, sqrt, cbrt, square, heaviside, gcd, lcm
6//
7// Cumulative: cumsum, cumprod, nancumsum, nancumprod
8// Differences: diff, ediff1d, gradient
9// Products: cross
10// Integration: trapezoid
11//
12// Reduction: add_reduce, add_accumulate, multiply_outer
13
14use ferray_core::Array;
15use ferray_core::dimension::{Dimension, Ix1, IxDyn};
16use ferray_core::dtype::Element;
17use ferray_core::error::{FerrayError, FerrayResult};
18use num_traits::Float;
19
20use crate::helpers::{
21    binary_broadcast_op, binary_elementwise_op, binary_elementwise_op_into, try_simd_f32_binary,
22    try_simd_f64_binary, unary_float_op, unary_float_op_compute, unary_float_op_into,
23};
24use crate::kernels::simd_f32::{add_f32, div_f32, mul_f32, sub_f32};
25use crate::kernels::simd_f64::{add_f64, div_f64, mul_f64, sub_f64};
26
27// ---------------------------------------------------------------------------
28// Basic arithmetic (binary, same-shape)
29// ---------------------------------------------------------------------------
30
31/// Elementwise addition with `NumPy` broadcasting.
32///
33/// Same-shape f64 / f32 inputs go through the explicit SIMD slice
34/// kernel (`add_f64` / `add_f32` via `pulp::Arch` runtime dispatch).
35/// Other dtypes and broadcasting paths fall through to the generic
36/// auto-vectorised loop in `binary_elementwise_op`. (#88)
37pub fn add<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
38where
39    T: Element + std::ops::Add<Output = T> + Copy,
40    D: Dimension,
41{
42    if let Some(r) = try_simd_f64_binary(a, b, add_f64) {
43        return r;
44    }
45    if let Some(r) = try_simd_f32_binary(a, b, add_f32) {
46        return r;
47    }
48    binary_elementwise_op(a, b, |x, y| x + y)
49}
50
51/// In-place elementwise addition, equivalent to `NumPy`'s
52/// `np.add(a, b, out=out)`. Writes `a + b` directly into `out` without
53/// allocating. All three arrays must be contiguous (C-order) and have the
54/// same shape; broadcasting is not supported on this fast path — use
55/// [`add`] if you need it.
56///
57/// # Errors
58/// - `FerrayError::ShapeMismatch` if shapes differ.
59/// - `FerrayError::InvalidValue` if any array is non-contiguous.
60pub fn add_into<T, D>(a: &Array<T, D>, b: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
61where
62    T: Element + std::ops::Add<Output = T> + Copy,
63    D: Dimension,
64{
65    binary_elementwise_op_into(a, b, out, "add", |x, y| x + y)
66}
67
68/// Elementwise subtraction with `NumPy` broadcasting. SIMD-dispatched
69/// for same-shape f64/f32 inputs (#88).
70pub fn subtract<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
71where
72    T: Element + std::ops::Sub<Output = T> + Copy,
73    D: Dimension,
74{
75    if let Some(r) = try_simd_f64_binary(a, b, sub_f64) {
76        return r;
77    }
78    if let Some(r) = try_simd_f32_binary(a, b, sub_f32) {
79        return r;
80    }
81    binary_elementwise_op(a, b, |x, y| x - y)
82}
83
84/// In-place subtraction — the `_into` counterpart of [`subtract`].
85pub fn subtract_into<T, D>(
86    a: &Array<T, D>,
87    b: &Array<T, D>,
88    out: &mut Array<T, D>,
89) -> FerrayResult<()>
90where
91    T: Element + std::ops::Sub<Output = T> + Copy,
92    D: Dimension,
93{
94    binary_elementwise_op_into(a, b, out, "subtract", |x, y| x - y)
95}
96
97/// Elementwise multiplication with `NumPy` broadcasting. SIMD-dispatched
98/// for same-shape f64/f32 inputs (#88).
99pub fn multiply<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
100where
101    T: Element + std::ops::Mul<Output = T> + Copy,
102    D: Dimension,
103{
104    if let Some(r) = try_simd_f64_binary(a, b, mul_f64) {
105        return r;
106    }
107    if let Some(r) = try_simd_f32_binary(a, b, mul_f32) {
108        return r;
109    }
110    binary_elementwise_op(a, b, |x, y| x * y)
111}
112
113/// In-place multiplication — the `_into` counterpart of [`multiply`].
114pub fn multiply_into<T, D>(
115    a: &Array<T, D>,
116    b: &Array<T, D>,
117    out: &mut Array<T, D>,
118) -> FerrayResult<()>
119where
120    T: Element + std::ops::Mul<Output = T> + Copy,
121    D: Dimension,
122{
123    binary_elementwise_op_into(a, b, out, "multiply", |x, y| x * y)
124}
125
126/// Elementwise division with `NumPy` broadcasting. SIMD-dispatched
127/// for same-shape f64/f32 inputs (#88).
128pub fn divide<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
129where
130    T: Element + std::ops::Div<Output = T> + Copy,
131    D: Dimension,
132{
133    if let Some(r) = try_simd_f64_binary(a, b, div_f64) {
134        return r;
135    }
136    if let Some(r) = try_simd_f32_binary(a, b, div_f32) {
137        return r;
138    }
139    binary_elementwise_op(a, b, |x, y| x / y)
140}
141
142/// In-place division — the `_into` counterpart of [`divide`].
143pub fn divide_into<T, D>(
144    a: &Array<T, D>,
145    b: &Array<T, D>,
146    out: &mut Array<T, D>,
147) -> FerrayResult<()>
148where
149    T: Element + std::ops::Div<Output = T> + Copy,
150    D: Dimension,
151{
152    binary_elementwise_op_into(a, b, out, "divide", |x, y| x / y)
153}
154
155/// Alias for [`divide`] — true division (float).
156pub fn true_divide<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
157where
158    T: Element + Float,
159    D: Dimension,
160{
161    binary_elementwise_op(a, b, |x, y| x / y)
162}
163
164/// Floor division: floor(a / b).
165pub fn floor_divide<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
166where
167    T: Element + Float,
168    D: Dimension,
169{
170    binary_elementwise_op(a, b, |x, y| (x / y).floor())
171}
172
173/// Elementwise power: a^b.
174pub fn power<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
175where
176    T: Element + Float,
177    D: Dimension,
178{
179    binary_elementwise_op(a, b, num_traits::Float::powf)
180}
181
182/// Elementwise remainder (Python-style modulo).
183pub fn remainder<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
184where
185    T: Element + Float,
186    D: Dimension,
187{
188    let z = <T as Element>::zero();
189    binary_elementwise_op(a, b, |x, y| {
190        let r = x % y;
191        // Python/NumPy mod: result has same sign as divisor
192        if (r < z && y > z) || (r > z && y < z) {
193            r + y
194        } else {
195            r
196        }
197    })
198}
199
200/// Alias for [`remainder`].
201pub fn mod_<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
202where
203    T: Element + Float,
204    D: Dimension,
205{
206    remainder(a, b)
207}
208
209/// C-style fmod (remainder has same sign as dividend).
210pub fn fmod<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
211where
212    T: Element + Float,
213    D: Dimension,
214{
215    binary_elementwise_op(a, b, |x, y| x % y)
216}
217
218/// Return `(floor_divide, remainder)` as a tuple of arrays, with broadcasting.
219///
220/// Computes both results in a single pass over the (broadcast) data,
221/// avoiding the redundant division that would occur from calling
222/// `floor_divide` and `remainder` separately.
223pub fn divmod<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<(Array<T, D>, Array<T, D>)>
224where
225    T: Element + Float,
226    D: Dimension,
227{
228    use ferray_core::dimension::broadcast::{broadcast_shapes, broadcast_to};
229
230    let z = <T as Element>::zero();
231
232    // Inline the divmod kernel so we can route both fast and broadcast
233    // paths through a single closure body.
234    let kernel = |x: T, y: T| -> (T, T) {
235        let q = (x / y).floor();
236        let mut r = x - q * y;
237        if (r < z && y > z) || (r > z && y < z) {
238            r = r + y;
239        }
240        (q, r)
241    };
242
243    // Fast path: identical shapes.
244    if a.shape() == b.shape() {
245        let mut quot_data = Vec::with_capacity(a.size());
246        let mut rem_data = Vec::with_capacity(a.size());
247        for (&x, &y) in a.iter().zip(b.iter()) {
248            let (q, r) = kernel(x, y);
249            quot_data.push(q);
250            rem_data.push(r);
251        }
252        let quot = Array::from_vec(a.dim().clone(), quot_data)?;
253        let rem = Array::from_vec(a.dim().clone(), rem_data)?;
254        return Ok((quot, rem));
255    }
256
257    // Broadcasting path.
258    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
259        FerrayError::shape_mismatch(format!(
260            "divmod: shapes {:?} and {:?} are not broadcast-compatible",
261            a.shape(),
262            b.shape()
263        ))
264    })?;
265    let a_view = broadcast_to(a, &target_shape)?;
266    let b_view = broadcast_to(b, &target_shape)?;
267    let n: usize = target_shape.iter().product();
268    let mut quot_data = Vec::with_capacity(n);
269    let mut rem_data = Vec::with_capacity(n);
270    for (&x, &y) in a_view.iter().zip(b_view.iter()) {
271        let (q, r) = kernel(x, y);
272        quot_data.push(q);
273        rem_data.push(r);
274    }
275    let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
276        FerrayError::shape_mismatch(format!(
277            "divmod: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
278        ))
279    })?;
280    let quot = Array::from_vec(result_dim.clone(), quot_data)?;
281    let rem = Array::from_vec(result_dim, rem_data)?;
282    Ok((quot, rem))
283}
284
285// ---------------------------------------------------------------------------
286// Unary arithmetic
287// ---------------------------------------------------------------------------
288
289/// Elementwise absolute value.
290///
291/// Uses hardware SIMD for contiguous f64 arrays.
292pub fn absolute<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
293where
294    T: Element + Float,
295    D: Dimension,
296{
297    if let Some(r) = crate::helpers::try_simd_f64_unary(input, crate::dispatch::simd_abs_f64) {
298        return r;
299    }
300    if let Some(r) = crate::helpers::try_simd_f32_unary(input, crate::dispatch::simd_abs_f32) {
301        return r;
302    }
303    unary_float_op(input, T::abs)
304}
305
306/// In-place elementwise absolute value — `_into` counterpart of [`absolute`].
307pub fn absolute_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
308where
309    T: Element + Float,
310    D: Dimension,
311{
312    unary_float_op_into(input, out, "absolute", T::abs)
313}
314
315/// Alias for [`absolute`] — float abs.
316pub fn fabs<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
317where
318    T: Element + Float,
319    D: Dimension,
320{
321    absolute(input)
322}
323
324/// Elementwise sign: -1 for negative, 0 for zero, +1 for positive.
325pub fn sign<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
326where
327    T: Element + Float,
328    D: Dimension,
329{
330    unary_float_op(input, |x| {
331        if x.is_nan() {
332            <T as Float>::nan()
333        } else if x > <T as Element>::zero() {
334            <T as Element>::one()
335        } else if x < <T as Element>::zero() {
336            -<T as Element>::one()
337        } else {
338            <T as Element>::zero()
339        }
340    })
341}
342
343/// Elementwise negation.
344///
345/// Uses hardware SIMD for contiguous f64 arrays.
346pub fn negative<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
347where
348    T: Element + Float,
349    D: Dimension,
350{
351    if let Some(r) = crate::helpers::try_simd_f64_unary(input, crate::dispatch::simd_neg_f64) {
352        return r;
353    }
354    if let Some(r) = crate::helpers::try_simd_f32_unary(input, crate::dispatch::simd_neg_f32) {
355        return r;
356    }
357    unary_float_op(input, |x| -x)
358}
359
360/// In-place elementwise negation — `_into` counterpart of [`negative`].
361pub fn negative_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
362where
363    T: Element + Float,
364    D: Dimension,
365{
366    unary_float_op_into(input, out, "negative", |x| -x)
367}
368
369/// Elementwise positive (identity for numeric types).
370pub fn positive<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
371where
372    T: Element + Float,
373    D: Dimension,
374{
375    unary_float_op(input, |x| x)
376}
377
378/// Elementwise reciprocal: 1/x.
379///
380/// Uses hardware SIMD for contiguous f64 arrays.
381pub fn reciprocal<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
382where
383    T: Element + Float,
384    D: Dimension,
385{
386    if let Some(r) = crate::helpers::try_simd_f64_unary(input, crate::dispatch::simd_reciprocal_f64)
387    {
388        return r;
389    }
390    if let Some(r) = crate::helpers::try_simd_f32_unary(input, crate::dispatch::simd_reciprocal_f32)
391    {
392        return r;
393    }
394    unary_float_op(input, T::recip)
395}
396
397/// Elementwise square root.
398///
399/// Uses hardware SIMD (`vsqrtpd`) for contiguous f64 arrays.
400pub fn sqrt<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
401where
402    T: Element + Float,
403    D: Dimension,
404{
405    if let Some(r) = crate::helpers::try_simd_f64_unary(input, crate::dispatch::simd_sqrt_f64) {
406        return r;
407    }
408    unary_float_op(input, T::sqrt)
409}
410
411/// In-place elementwise square root — the `_into` counterpart of [`sqrt`].
412pub fn sqrt_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
413where
414    T: Element + Float,
415    D: Dimension,
416{
417    unary_float_op_into(input, out, "sqrt", T::sqrt)
418}
419
420/// Elementwise cube root.
421pub fn cbrt<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
422where
423    T: Element + Float + crate::cr_math::CrMath,
424    D: Dimension,
425{
426    unary_float_op_compute(input, T::cr_cbrt)
427}
428
429/// Elementwise square: x^2.
430///
431/// Uses hardware SIMD for contiguous f64 arrays.
432pub fn square<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
433where
434    T: Element + Float,
435    D: Dimension,
436{
437    if let Some(r) = crate::helpers::try_simd_f64_unary(input, crate::dispatch::simd_square_f64) {
438        return r;
439    }
440    if let Some(r) = crate::helpers::try_simd_f32_unary(input, crate::dispatch::simd_square_f32) {
441        return r;
442    }
443    unary_float_op(input, |x| x * x)
444}
445
446/// In-place elementwise square — the `_into` counterpart of [`square`].
447pub fn square_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
448where
449    T: Element + Float,
450    D: Dimension,
451{
452    unary_float_op_into(input, out, "square", |x| x * x)
453}
454
455/// Heaviside step function.
456///
457/// `heaviside(x, h0)` returns 0 for x < 0, h0 for x == 0, and 1 for x > 0.
458pub fn heaviside<T, D>(x: &Array<T, D>, h0: &Array<T, D>) -> FerrayResult<Array<T, D>>
459where
460    T: Element + Float,
461    D: Dimension,
462{
463    binary_elementwise_op(x, h0, |xi, h0i| {
464        if xi.is_nan() {
465            xi
466        } else if xi < <T as Element>::zero() {
467            <T as Element>::zero()
468        } else if xi == <T as Element>::zero() {
469            h0i
470        } else {
471            <T as Element>::one()
472        }
473    })
474}
475
476/// Integer GCD (works on float representations of integers).
477pub fn gcd<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
478where
479    T: Element + Float,
480    D: Dimension,
481{
482    binary_elementwise_op(a, b, |mut x, mut y| {
483        if x.is_nan() || y.is_nan() {
484            return T::nan();
485        }
486        x = x.abs();
487        y = y.abs();
488        while y != <T as Element>::zero() {
489            let t = y;
490            y = x % y;
491            x = t;
492        }
493        x
494    })
495}
496
497/// Integer LCM (works on float representations of integers).
498pub fn lcm<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
499where
500    T: Element + Float,
501    D: Dimension,
502{
503    binary_elementwise_op(a, b, |x, y| {
504        if x.is_nan() || y.is_nan() {
505            return T::nan();
506        }
507        let ax = x.abs();
508        let ay = y.abs();
509        if ax == <T as Element>::zero() || ay == <T as Element>::zero() {
510            return <T as Element>::zero();
511        }
512        // lcm = |a*b| / gcd(a,b)
513        let mut gx = ax;
514        let mut gy = ay;
515        while gy != <T as Element>::zero() {
516            let t = gy;
517            gy = gx % gy;
518            gx = t;
519        }
520        ax / gx * ay
521    })
522}
523
524/// Integer GCD using the Euclidean algorithm.
525///
526/// Works on actual integer element types (i8, i16, i32, i64, u8, u16, u32, u64, etc.).
527/// For float-typed arrays, use [`gcd`] instead.
528pub fn gcd_int<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
529where
530    T: Element + Copy + PartialEq + std::ops::Rem<Output = T> + num_traits::Signed,
531    D: Dimension,
532{
533    binary_elementwise_op(a, b, |x, y| {
534        let mut ax = x.abs();
535        let mut ay = y.abs();
536        while ay != <T as Element>::zero() {
537            let t = ay;
538            ay = ax % ay;
539            ax = t;
540        }
541        ax
542    })
543}
544
545/// Integer LCM using the Euclidean GCD algorithm.
546///
547/// Works on actual integer element types. For float-typed arrays, use [`lcm`] instead.
548pub fn lcm_int<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
549where
550    T: Element
551        + Copy
552        + PartialEq
553        + std::ops::Rem<Output = T>
554        + std::ops::Div<Output = T>
555        + std::ops::Mul<Output = T>
556        + num_traits::Signed,
557    D: Dimension,
558{
559    binary_elementwise_op(a, b, |x, y| {
560        let ax = x.abs();
561        let ay = y.abs();
562        if ax == <T as Element>::zero() || ay == <T as Element>::zero() {
563            return <T as Element>::zero();
564        }
565        let mut gx = ax;
566        let mut gy = ay;
567        while gy != <T as Element>::zero() {
568            let t = gy;
569            gy = gx % gy;
570            gx = t;
571        }
572        ax / gx * ay
573    })
574}
575
576// ---------------------------------------------------------------------------
577// Broadcasting binary arithmetic
578// ---------------------------------------------------------------------------
579
580/// Elementwise addition with broadcasting.
581pub fn add_broadcast<T, D1, D2>(a: &Array<T, D1>, b: &Array<T, D2>) -> FerrayResult<Array<T, IxDyn>>
582where
583    T: Element + std::ops::Add<Output = T> + Copy,
584    D1: Dimension,
585    D2: Dimension,
586{
587    binary_broadcast_op(a, b, |x, y| x + y)
588}
589
590/// Elementwise subtraction with broadcasting.
591pub fn subtract_broadcast<T, D1, D2>(
592    a: &Array<T, D1>,
593    b: &Array<T, D2>,
594) -> FerrayResult<Array<T, IxDyn>>
595where
596    T: Element + std::ops::Sub<Output = T> + Copy,
597    D1: Dimension,
598    D2: Dimension,
599{
600    binary_broadcast_op(a, b, |x, y| x - y)
601}
602
603/// Elementwise multiplication with broadcasting.
604pub fn multiply_broadcast<T, D1, D2>(
605    a: &Array<T, D1>,
606    b: &Array<T, D2>,
607) -> FerrayResult<Array<T, IxDyn>>
608where
609    T: Element + std::ops::Mul<Output = T> + Copy,
610    D1: Dimension,
611    D2: Dimension,
612{
613    binary_broadcast_op(a, b, |x, y| x * y)
614}
615
616/// Elementwise division with broadcasting.
617pub fn divide_broadcast<T, D1, D2>(
618    a: &Array<T, D1>,
619    b: &Array<T, D2>,
620) -> FerrayResult<Array<T, IxDyn>>
621where
622    T: Element + std::ops::Div<Output = T> + Copy,
623    D1: Dimension,
624    D2: Dimension,
625{
626    binary_broadcast_op(a, b, |x, y| x / y)
627}
628
629// ---------------------------------------------------------------------------
630// Reductions
631// ---------------------------------------------------------------------------
632
633/// Reduce by addition along an axis (column sums, row sums, etc.).
634///
635/// Equivalent to `np.add.reduce(arr, axis=...)`. Delegates to the generic
636/// [`crate::ufunc_methods::reduce_axis`] with the `+` kernel and `0` seed.
637///
638/// AC-2: `add_reduce` computes correct column sums.
639pub fn add_reduce<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<T, IxDyn>>
640where
641    T: Element + std::ops::Add<Output = T> + Copy,
642    D: Dimension,
643{
644    crate::ufunc_methods::reduce_axis(input, axis, <T as Element>::zero(), |acc, x| acc + x)
645}
646
647/// Reduce by addition along an axis with an optional `keepdims` flag.
648///
649/// Equivalent to `np.add.reduce(arr, axis=..., keepdims=...)` /
650/// `np.sum(arr, axis=..., keepdims=...)`. When `keepdims = true` the
651/// reduced axis is preserved as a size-1 dimension so the result is
652/// broadcastable back against the original input — the classic pattern
653/// for row/column centering (`arr - arr.sum(axis=1, keepdims=True)`).
654///
655/// With `keepdims = false` this behaves exactly like [`add_reduce`].
656/// Added for #394.
657pub fn add_reduce_keepdims<T, D>(
658    input: &Array<T, D>,
659    axis: usize,
660    keepdims: bool,
661) -> FerrayResult<Array<T, IxDyn>>
662where
663    T: Element + std::ops::Add<Output = T> + Copy,
664    D: Dimension,
665{
666    crate::ufunc_methods::reduce_axis_keepdims(
667        input,
668        axis,
669        <T as Element>::zero(),
670        keepdims,
671        |acc, x| acc + x,
672    )
673}
674
675/// Reduce by addition over multiple axes simultaneously.
676///
677/// Equivalent to `np.add.reduce(arr, axis=axes, keepdims=keepdims)` /
678/// `np.sum(arr, axis=axes, keepdims=keepdims)` where `axes` is a tuple
679/// of axes to collapse. Reduces every listed axis in a single pass over
680/// the input — never materializes intermediates the way chained
681/// `add_reduce` calls would, and the order of `axes` is irrelevant.
682/// Added for #395.
683pub fn add_reduce_axes<T, D>(
684    input: &Array<T, D>,
685    axes: &[usize],
686    keepdims: bool,
687) -> FerrayResult<Array<T, IxDyn>>
688where
689    T: Element + std::ops::Add<Output = T> + Copy,
690    D: Dimension,
691{
692    crate::ufunc_methods::reduce_axes(input, axes, <T as Element>::zero(), keepdims, |acc, x| {
693        acc + x
694    })
695}
696
697/// Reduce by addition over the entire array (the `axis=None` form).
698///
699/// Equivalent to `np.add.reduce(arr, axis=None)` / `np.sum(arr)`.
700/// Returns a single scalar — use [`add_reduce_axes`] when you want a
701/// wrapped array result that supports `keepdims`.
702///
703/// Added for #395.
704pub fn add_reduce_all<T, D>(input: &Array<T, D>) -> T
705where
706    T: Element + std::ops::Add<Output = T> + Copy,
707    D: Dimension,
708{
709    crate::ufunc_methods::reduce_all(input, <T as Element>::zero(), |acc, x| acc + x)
710}
711
// ---------------------------------------------------------------------------
// NaN-aware reductions (#388)
//
// Parallel to add_reduce / multiply_reduce / max_reduce / min_reduce but
// with NaN-skipping kernels. ferray-stats already exposes high-level
// nansum / nanmean / etc. wrappers; these are the lower-level ufunc
// primitives that match the cumulative nancumsum/nancumprod pattern in
// the same module. They live here so the full reduction family
// (whole-array + axis + axes + keepdims) is available without depending
// on ferray-stats.
//
// Every function in the four families (sum, product, max, min) requires
// `T: Element + Float` so the kernel can call `.is_nan()`. The sum and
// product kernels map each NaN to the reduction identity (0 and 1,
// respectively). The max and min kernels leave the accumulator unchanged
// when they see a NaN, starting from a seed of -inf and +inf respectively.
// Whole-array forms return a scalar; axis-aware forms delegate to the
// generic reduce_axis_keepdims / reduce_axes helpers.
// ---------------------------------------------------------------------------
729
730/// Reduce by NaN-skipping addition along an axis with optional keepdims.
731///
732/// Equivalent to `np.nansum(arr, axis=axis, keepdims=keepdims)`. NaN
733/// elements are treated as zero and contribute nothing to the sum.
734pub fn nan_add_reduce<T, D>(
735    input: &Array<T, D>,
736    axis: usize,
737    keepdims: bool,
738) -> FerrayResult<Array<T, IxDyn>>
739where
740    T: Element + Float,
741    D: Dimension,
742{
743    crate::ufunc_methods::reduce_axis_keepdims(
744        input,
745        axis,
746        <T as Element>::zero(),
747        keepdims,
748        |acc, x| acc + nan_to_zero(x),
749    )
750}
751
752/// Reduce by NaN-skipping addition over multiple axes simultaneously.
753pub fn nan_add_reduce_axes<T, D>(
754    input: &Array<T, D>,
755    axes: &[usize],
756    keepdims: bool,
757) -> FerrayResult<Array<T, IxDyn>>
758where
759    T: Element + Float,
760    D: Dimension,
761{
762    crate::ufunc_methods::reduce_axes(input, axes, <T as Element>::zero(), keepdims, |acc, x| {
763        acc + nan_to_zero(x)
764    })
765}
766
767/// Reduce by NaN-skipping addition over the entire array.
768///
769/// Equivalent to `np.nansum(arr)` / `np.nansum(arr, axis=None)`. Returns
770/// a scalar. NaN elements contribute nothing to the sum; an array of all
771/// NaNs sums to zero.
772pub fn nan_add_reduce_all<T, D>(input: &Array<T, D>) -> T
773where
774    T: Element + Float,
775    D: Dimension,
776{
777    crate::ufunc_methods::reduce_all(input, <T as Element>::zero(), |acc, x| acc + nan_to_zero(x))
778}
779
780/// Reduce by NaN-skipping multiplication along an axis with optional keepdims.
781///
782/// Equivalent to `np.nanprod(arr, axis=axis, keepdims=keepdims)`. NaN
783/// elements are treated as one and contribute nothing to the product.
784pub fn nan_multiply_reduce<T, D>(
785    input: &Array<T, D>,
786    axis: usize,
787    keepdims: bool,
788) -> FerrayResult<Array<T, IxDyn>>
789where
790    T: Element + Float,
791    D: Dimension,
792{
793    crate::ufunc_methods::reduce_axis_keepdims(
794        input,
795        axis,
796        <T as Element>::one(),
797        keepdims,
798        |acc, x| acc * nan_to_one(x),
799    )
800}
801
802/// Reduce by NaN-skipping multiplication over multiple axes.
803pub fn nan_multiply_reduce_axes<T, D>(
804    input: &Array<T, D>,
805    axes: &[usize],
806    keepdims: bool,
807) -> FerrayResult<Array<T, IxDyn>>
808where
809    T: Element + Float,
810    D: Dimension,
811{
812    crate::ufunc_methods::reduce_axes(input, axes, <T as Element>::one(), keepdims, |acc, x| {
813        acc * nan_to_one(x)
814    })
815}
816
817/// Reduce by NaN-skipping multiplication over the entire array.
818pub fn nan_multiply_reduce_all<T, D>(input: &Array<T, D>) -> T
819where
820    T: Element + Float,
821    D: Dimension,
822{
823    crate::ufunc_methods::reduce_all(input, <T as Element>::one(), |acc, x| acc * nan_to_one(x))
824}
825
826/// Reduce by NaN-skipping maximum along an axis with optional keepdims.
827///
828/// Equivalent to `np.nanmax(arr, axis=axis, keepdims=keepdims)`. NaN
829/// elements are skipped (treated as `-inf`).
830pub fn nan_max_reduce<T, D>(
831    input: &Array<T, D>,
832    axis: usize,
833    keepdims: bool,
834) -> FerrayResult<Array<T, IxDyn>>
835where
836    T: Element + Float,
837    D: Dimension,
838{
839    crate::ufunc_methods::reduce_axis_keepdims(
840        input,
841        axis,
842        <T as Float>::neg_infinity(),
843        keepdims,
844        |acc, x| {
845            if x.is_nan() {
846                acc
847            } else if x > acc {
848                x
849            } else {
850                acc
851            }
852        },
853    )
854}
855
856/// Reduce by NaN-skipping maximum over multiple axes.
857pub fn nan_max_reduce_axes<T, D>(
858    input: &Array<T, D>,
859    axes: &[usize],
860    keepdims: bool,
861) -> FerrayResult<Array<T, IxDyn>>
862where
863    T: Element + Float,
864    D: Dimension,
865{
866    crate::ufunc_methods::reduce_axes(
867        input,
868        axes,
869        <T as Float>::neg_infinity(),
870        keepdims,
871        |acc, x| {
872            if x.is_nan() {
873                acc
874            } else if x > acc {
875                x
876            } else {
877                acc
878            }
879        },
880    )
881}
882
883/// Reduce by NaN-skipping maximum over the entire array.
884///
885/// Equivalent to `np.nanmax(arr)`. Returns `-inf` for an all-NaN input
886/// rather than raising — callers that need the all-NaN error semantics
887/// should use ferray-stats' `nanmax` (which checks the result and errors
888/// out instead of returning the seed).
889pub fn nan_max_reduce_all<T, D>(input: &Array<T, D>) -> T
890where
891    T: Element + Float,
892    D: Dimension,
893{
894    crate::ufunc_methods::reduce_all(input, <T as Float>::neg_infinity(), |acc, x| {
895        if x.is_nan() {
896            acc
897        } else if x > acc {
898            x
899        } else {
900            acc
901        }
902    })
903}
904
905/// Reduce by NaN-skipping minimum along an axis with optional keepdims.
906///
907/// Equivalent to `np.nanmin(arr, axis=axis, keepdims=keepdims)`. NaN
908/// elements are skipped (treated as `+inf`).
909pub fn nan_min_reduce<T, D>(
910    input: &Array<T, D>,
911    axis: usize,
912    keepdims: bool,
913) -> FerrayResult<Array<T, IxDyn>>
914where
915    T: Element + Float,
916    D: Dimension,
917{
918    crate::ufunc_methods::reduce_axis_keepdims(
919        input,
920        axis,
921        <T as Float>::infinity(),
922        keepdims,
923        |acc, x| {
924            if x.is_nan() {
925                acc
926            } else if x < acc {
927                x
928            } else {
929                acc
930            }
931        },
932    )
933}
934
935/// Reduce by NaN-skipping minimum over multiple axes.
936pub fn nan_min_reduce_axes<T, D>(
937    input: &Array<T, D>,
938    axes: &[usize],
939    keepdims: bool,
940) -> FerrayResult<Array<T, IxDyn>>
941where
942    T: Element + Float,
943    D: Dimension,
944{
945    crate::ufunc_methods::reduce_axes(input, axes, <T as Float>::infinity(), keepdims, |acc, x| {
946        if x.is_nan() {
947            acc
948        } else if x < acc {
949            x
950        } else {
951            acc
952        }
953    })
954}
955
956/// Reduce by NaN-skipping minimum over the entire array.
957pub fn nan_min_reduce_all<T, D>(input: &Array<T, D>) -> T
958where
959    T: Element + Float,
960    D: Dimension,
961{
962    crate::ufunc_methods::reduce_all(input, <T as Float>::infinity(), |acc, x| {
963        if x.is_nan() {
964            acc
965        } else if x < acc {
966            x
967        } else {
968            acc
969        }
970    })
971}
972
973#[inline]
974fn nan_to_zero<T: Float + Element>(x: T) -> T {
975    if x.is_nan() {
976        <T as Element>::zero()
977    } else {
978        x
979    }
980}
981
982#[inline]
983fn nan_to_one<T: Float + Element>(x: T) -> T {
984    if x.is_nan() { <T as Element>::one() } else { x }
985}
986
987/// Running (cumulative) addition along an axis.
988///
989/// AC-2: `add_accumulate` produces running sums.
990pub fn add_accumulate<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<T, D>>
991where
992    T: Element + std::ops::Add<Output = T> + Copy,
993    D: Dimension,
994{
995    cumsum(input, Some(axis))
996}
997
998/// Outer product: `multiply_outer(a, b)[i, j] = a[i] * b[j]`.
999///
1000/// Equivalent to `np.multiply.outer(a, b)`. Delegates to the generic
1001/// [`crate::ufunc_methods::outer`] with the `*` kernel.
1002///
1003/// AC-3: `multiply_outer` produces correct outer product.
1004pub fn multiply_outer<T>(a: &Array<T, Ix1>, b: &Array<T, Ix1>) -> FerrayResult<Array<T, IxDyn>>
1005where
1006    T: Element + std::ops::Mul<Output = T> + Copy,
1007{
1008    crate::ufunc_methods::outer(a, b, |x, y| x * y)
1009}
1010
1011// ---------------------------------------------------------------------------
1012// Cumulative operations
1013// ---------------------------------------------------------------------------
1014
1015/// Shared cumulative kernel: build the result buffer by applying
1016/// `preprocess` to every input element, then walk it in place with
1017/// `accumulate` along `axis` (or flat if `None`). Factored out so
1018/// `cumsum`, `cumprod`, `nancumsum` and `nancumprod` all share a single
1019/// pass — previously `nancumsum`/`nancumprod` materialized a cleaned
1020/// copy and then called `cumsum`/`cumprod`, which materialized a
1021/// second buffer (#156).
1022fn cumulative_with_preprocess<T, D, Pre, Acc>(
1023    input: &Array<T, D>,
1024    axis: Option<usize>,
1025    preprocess: Pre,
1026    accumulate: Acc,
1027) -> FerrayResult<Array<T, D>>
1028where
1029    T: Element + Copy,
1030    D: Dimension,
1031    Pre: Fn(T) -> T,
1032    Acc: Fn(T, T) -> T,
1033{
1034    if let Some(ax) = axis {
1035        if ax >= input.ndim() {
1036            return Err(FerrayError::axis_out_of_bounds(ax, input.ndim()));
1037        }
1038        let shape = input.shape().to_vec();
1039        let mut result: Vec<T> = input.iter().map(|&x| preprocess(x)).collect();
1040        let mut stride = 1usize;
1041        for d in shape.iter().skip(ax + 1) {
1042            stride *= d;
1043        }
1044        let axis_len = shape[ax];
1045        let outer_size: usize = shape[..ax].iter().product();
1046        let inner_size = stride;
1047
1048        for outer in 0..outer_size {
1049            for inner in 0..inner_size {
1050                let base = outer * axis_len * inner_size + inner;
1051                for k in 1..axis_len {
1052                    let prev = base + (k - 1) * inner_size;
1053                    let curr = base + k * inner_size;
1054                    result[curr] = accumulate(result[prev], result[curr]);
1055                }
1056            }
1057        }
1058        Array::from_vec(input.dim().clone(), result)
1059    } else {
1060        let mut data: Vec<T> = input.iter().map(|&x| preprocess(x)).collect();
1061        for i in 1..data.len() {
1062            data[i] = accumulate(data[i - 1], data[i]);
1063        }
1064        Array::from_vec(input.dim().clone(), data)
1065    }
1066}
1067
1068/// Cumulative sum along an axis (or flattened if axis is None).
1069///
1070/// When `axis=None`, data is flattened and accumulated, but the result retains
1071/// the original shape (unlike `NumPy` which returns a 1-D array). This is due to
1072/// the generic return type `Array<T, D>`.
1073///
1074/// AC-11: `cumsum([1,2,3,4]) == [1,3,6,10]`.
1075pub fn cumsum<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1076where
1077    T: Element + std::ops::Add<Output = T> + Copy,
1078    D: Dimension,
1079{
1080    cumulative_with_preprocess(input, axis, |x| x, |a, b| a + b)
1081}
1082
1083/// Cumulative product along an axis (or flattened if axis is None).
1084///
1085/// When `axis=None`, data is flattened and accumulated, but the result retains
1086/// the original shape (unlike `NumPy` which returns a 1-D array). This is due to
1087/// the generic return type `Array<T, D>`.
1088pub fn cumprod<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1089where
1090    T: Element + std::ops::Mul<Output = T> + Copy,
1091    D: Dimension,
1092{
1093    cumulative_with_preprocess(input, axis, |x| x, |a, b| a * b)
1094}
1095
1096/// Cumulative sum (Array API standard name).
1097///
1098/// Alias of [`cumsum`] matching the Python Array API specification's
1099/// `cumulative_sum` name (added to `numpy` in 2.0).
1100pub fn cumulative_sum<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1101where
1102    T: Element + std::ops::Add<Output = T> + Copy,
1103    D: Dimension,
1104{
1105    cumsum(input, axis)
1106}
1107
1108/// Cumulative product (Array API standard name).
1109///
1110/// Alias of [`cumprod`] matching the Python Array API specification's
1111/// `cumulative_prod` name (added to `numpy` in 2.0).
1112pub fn cumulative_prod<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1113where
1114    T: Element + std::ops::Mul<Output = T> + Copy,
1115    D: Dimension,
1116{
1117    cumprod(input, axis)
1118}
1119
1120/// Cumulative sum ignoring NaNs.
1121pub fn nancumsum<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1122where
1123    T: Element + Float,
1124    D: Dimension,
1125{
1126    cumulative_with_preprocess(
1127        input,
1128        axis,
1129        |x| {
1130            if x.is_nan() {
1131                <T as Element>::zero()
1132            } else {
1133                x
1134            }
1135        },
1136        |a, b| a + b,
1137    )
1138}
1139
1140/// Cumulative product ignoring NaNs.
1141pub fn nancumprod<T, D>(input: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
1142where
1143    T: Element + Float,
1144    D: Dimension,
1145{
1146    cumulative_with_preprocess(
1147        input,
1148        axis,
1149        |x| {
1150            if x.is_nan() { <T as Element>::one() } else { x }
1151        },
1152        |a, b| a * b,
1153    )
1154}
1155
1156// ---------------------------------------------------------------------------
1157// Differences
1158// ---------------------------------------------------------------------------
1159
1160/// Compute the n-th discrete difference along the given axis.
1161///
1162/// AC-11: `diff([1,3,6,10], 1) == [2,3,4]`.
1163pub fn diff<T>(input: &Array<T, Ix1>, n: usize) -> FerrayResult<Array<T, Ix1>>
1164where
1165    T: Element + std::ops::Sub<Output = T> + Copy,
1166{
1167    let mut data: Vec<T> = input.iter().copied().collect();
1168    for _ in 0..n {
1169        if data.len() <= 1 {
1170            data.clear();
1171            break;
1172        }
1173        let mut new_data = Vec::with_capacity(data.len() - 1);
1174        for i in 1..data.len() {
1175            new_data.push(data[i] - data[i - 1]);
1176        }
1177        data = new_data;
1178    }
1179    Array::from_vec(Ix1::new([data.len()]), data)
1180}
1181
1182/// Differences between consecutive elements of an array, with optional
1183/// prepend/append values.
1184pub fn ediff1d<T>(
1185    input: &Array<T, Ix1>,
1186    to_end: Option<&[T]>,
1187    to_begin: Option<&[T]>,
1188) -> FerrayResult<Array<T, Ix1>>
1189where
1190    T: Element + std::ops::Sub<Output = T> + Copy,
1191{
1192    let data: Vec<T> = input.iter().copied().collect();
1193    let mut result = Vec::new();
1194
1195    if let Some(begin) = to_begin {
1196        result.extend_from_slice(begin);
1197    }
1198
1199    for i in 1..data.len() {
1200        result.push(data[i] - data[i - 1]);
1201    }
1202
1203    if let Some(end) = to_end {
1204        result.extend_from_slice(end);
1205    }
1206
1207    Array::from_vec(Ix1::new([result.len()]), result)
1208}
1209
1210/// Compute the gradient of a 1-D array using central differences.
1211///
1212/// Edge values use forward/backward differences.
1213pub fn gradient<T>(input: &Array<T, Ix1>, spacing: Option<T>) -> FerrayResult<Array<T, Ix1>>
1214where
1215    T: Element + Float,
1216{
1217    let data: Vec<T> = input.iter().copied().collect();
1218    let n = data.len();
1219    if n == 0 {
1220        return Array::from_vec(Ix1::new([0]), vec![]);
1221    }
1222    let h = spacing.unwrap_or_else(|| <T as Element>::one());
1223    let two = <T as Element>::one() + <T as Element>::one();
1224    let mut result = Vec::with_capacity(n);
1225
1226    if n == 1 {
1227        result.push(<T as Element>::zero());
1228    } else {
1229        // Forward difference for first element
1230        result.push((data[1] - data[0]) / h);
1231        // Central differences for interior
1232        for i in 1..n - 1 {
1233            result.push((data[i + 1] - data[i - 1]) / (two * h));
1234        }
1235        // Backward difference for last element
1236        result.push((data[n - 1] - data[n - 2]) / h);
1237    }
1238
1239    Array::from_vec(Ix1::new([n]), result)
1240}
1241
1242// ---------------------------------------------------------------------------
1243// Cross product
1244// ---------------------------------------------------------------------------
1245
1246/// Cross product of two 3-element 1-D arrays.
1247pub fn cross<T>(a: &Array<T, Ix1>, b: &Array<T, Ix1>) -> FerrayResult<Array<T, Ix1>>
1248where
1249    T: Element + std::ops::Mul<Output = T> + std::ops::Sub<Output = T> + Copy,
1250{
1251    if a.size() != 3 || b.size() != 3 {
1252        return Err(FerrayError::invalid_value(
1253            "cross product requires 3-element vectors",
1254        ));
1255    }
1256    let ad: Vec<T> = a.iter().copied().collect();
1257    let bd: Vec<T> = b.iter().copied().collect();
1258    let result = vec![
1259        ad[1] * bd[2] - ad[2] * bd[1],
1260        ad[2] * bd[0] - ad[0] * bd[2],
1261        ad[0] * bd[1] - ad[1] * bd[0],
1262    ];
1263    Array::from_vec(Ix1::new([3]), result)
1264}
1265
1266// ---------------------------------------------------------------------------
1267// Integration
1268// ---------------------------------------------------------------------------
1269
1270/// Integrate using the trapezoidal rule.
1271///
1272/// If `dx` is provided, it is the spacing between sample points.
1273/// If `x` is provided, it gives the sample point coordinates.
1274pub fn trapezoid<T>(y: &Array<T, Ix1>, x: Option<&Array<T, Ix1>>, dx: Option<T>) -> FerrayResult<T>
1275where
1276    T: Element + Float,
1277{
1278    let ydata: Vec<T> = y.iter().copied().collect();
1279    let n = ydata.len();
1280    if n < 2 {
1281        return Ok(<T as Element>::zero());
1282    }
1283
1284    let two = <T as Element>::one() + <T as Element>::one();
1285    let mut total = <T as Element>::zero();
1286
1287    if let Some(xarr) = x {
1288        let xdata: Vec<T> = xarr.iter().copied().collect();
1289        if xdata.len() != n {
1290            return Err(FerrayError::shape_mismatch(
1291                "x and y must have the same length for trapezoid",
1292            ));
1293        }
1294        for i in 1..n {
1295            total = total + (ydata[i] + ydata[i - 1]) / two * (xdata[i] - xdata[i - 1]);
1296        }
1297    } else {
1298        let h = dx.unwrap_or_else(|| <T as Element>::one());
1299        for i in 1..n {
1300            total = total + (ydata[i] + ydata[i - 1]) / two * h;
1301        }
1302    }
1303
1304    Ok(total)
1305}
1306
1307// ---------------------------------------------------------------------------
1308// f16 variants (f32-promoted) — generated via the shared macros (#142).
1309// ---------------------------------------------------------------------------
1310
1311use crate::helpers::{binary_f16_fn, unary_f16_fn};
1312
// Unary f16 kernels. Each invocation names the generated function and the
// f32 operation it applies. The `///` docs describe these as f32-promoted;
// the f16 <-> f32 conversion itself lives in the `unary_f16_fn!` macro in
// `crate::helpers` (NOTE(review): confirm its rounding on the narrowing step there).
unary_f16_fn!(
    /// Elementwise absolute value for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    absolute_f16,
    f32::abs
);
unary_f16_fn!(
    /// Elementwise negation for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    negative_f16,
    |x: f32| -x
);
unary_f16_fn!(
    /// Elementwise square root for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    sqrt_f16,
    f32::sqrt
);
unary_f16_fn!(
    /// Elementwise cube root for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    cbrt_f16,
    f32::cbrt
);
unary_f16_fn!(
    /// Elementwise square for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    square_f16,
    |x: f32| x * x
);
unary_f16_fn!(
    /// Elementwise reciprocal for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    reciprocal_f16,
    f32::recip
);
unary_f16_fn!(
    /// Elementwise sign for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    sign_f16,
    |x: f32| {
        // NaN propagates. Both +0.0 and -0.0 fail `> 0.0` and `< 0.0`,
        // so either zero maps to +0.0.
        if x.is_nan() {
            f32::NAN
        } else if x > 0.0 {
            1.0
        } else if x < 0.0 {
            -1.0
        } else {
            0.0
        }
    }
);
// Binary f16 kernels. Each invocation names the generated function and the
// f32 operation applied to each element pair. Broadcasting/shape rules and
// the f16 <-> f32 conversion are handled by `binary_f16_fn!` in `crate::helpers`.
binary_f16_fn!(
    /// Elementwise addition for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    add_f16,
    |x: f32, y: f32| x + y
);
binary_f16_fn!(
    /// Elementwise subtraction for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    subtract_f16,
    |x: f32, y: f32| x - y
);
binary_f16_fn!(
    /// Elementwise multiplication for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    multiply_f16,
    |x: f32, y: f32| x * y
);
binary_f16_fn!(
    /// Elementwise division for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    divide_f16,
    // Plain IEEE-754 division: dividing by zero yields ±inf or NaN, not an error.
    |x: f32, y: f32| x / y
);
binary_f16_fn!(
    /// Elementwise power for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    power_f16,
    f32::powf
);
binary_f16_fn!(
    /// Floor division for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    floor_divide_f16,
    // Rounds the (already rounded) f32 quotient toward -inf.
    // NOTE(review): this can differ by one from NumPy/Python's fmod-based
    // floor division for some inputs (Python gives `1 // 0.1 == 9.0`), so
    // confirm whether exact NumPy parity is required here.
    |x: f32, y: f32| (x / y).floor()
);
binary_f16_fn!(
    /// Elementwise remainder for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    remainder_f16,
    |x: f32, y: f32| {
        // Rust's `%` keeps the sign of the dividend `x`. When a nonzero
        // remainder disagrees in sign with the divisor `y`, shift it by `y`
        // so the result takes the divisor's sign (NumPy `remainder` convention).
        let r = x % y;
        if (r < 0.0 && y > 0.0) || (r > 0.0 && y < 0.0) {
            r + y
        } else {
            r
        }
    }
);
1414
1415#[cfg(test)]
1416mod tests {
1417    use super::*;
1418    use ferray_core::dimension::Ix2;
1419
1420    use crate::test_util::arr1;
1421
1422    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
1423        let n = data.len();
1424        Array::from_vec(Ix1::new([n]), data).unwrap()
1425    }
1426
1427    fn arr1_f32(data: Vec<f32>) -> Array<f32, Ix1> {
1428        let n = data.len();
1429        Array::from_vec(Ix1::new([n]), data).unwrap()
1430    }
1431
1432    // ---- f32 coverage (#721) -------------------------------------------
1433    //
1434    // Existing arithmetic tests run against f64 only; the f32 SIMD
1435    // dispatch (try_simd_f32_binary / try_simd_f32_unary) is exercised
1436    // here. Arrays are sized at 32 elements so the SIMD path engages
1437    // (the threshold is >= a small multiple of the lane width).
1438
1439    #[test]
1440    fn add_f32_simd_path() {
1441        let n = 32;
1442        let a = arr1_f32((0..n).map(|i| i as f32).collect());
1443        let b = arr1_f32((0..n).map(|i| i as f32 * 2.0).collect());
1444        let r = add(&a, &b).unwrap();
1445        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1446            assert!((v - (i as f32 * 3.0)).abs() < 1e-6);
1447        }
1448    }
1449
1450    #[test]
1451    fn sub_f32_simd_path() {
1452        let n = 32;
1453        let a = arr1_f32((0..n).map(|i| i as f32 * 5.0).collect());
1454        let b = arr1_f32((0..n).map(|i| i as f32 * 2.0).collect());
1455        let r = subtract(&a, &b).unwrap();
1456        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1457            assert!((v - (i as f32 * 3.0)).abs() < 1e-6);
1458        }
1459    }
1460
1461    #[test]
1462    fn mul_f32_simd_path() {
1463        let n = 32;
1464        let a = arr1_f32((0..n).map(|i| i as f32).collect());
1465        let b = arr1_f32(vec![3.0_f32; n]);
1466        let r = multiply(&a, &b).unwrap();
1467        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1468            assert!((v - (i as f32 * 3.0)).abs() < 1e-6);
1469        }
1470    }
1471
1472    #[test]
1473    fn div_f32_simd_path() {
1474        let n = 32;
1475        let a = arr1_f32((0..n).map(|i| (i as f32 + 1.0) * 4.0).collect());
1476        let b = arr1_f32(vec![2.0_f32; n]);
1477        let r = divide(&a, &b).unwrap();
1478        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1479            assert!((v - (i as f32 + 1.0) * 2.0).abs() < 1e-6);
1480        }
1481    }
1482
1483    #[test]
1484    fn abs_f32_simd_path() {
1485        let n = 32;
1486        let a = arr1_f32(
1487            (0..n)
1488                .map(|i| if i % 2 == 0 { -(i as f32) } else { i as f32 })
1489                .collect(),
1490        );
1491        let r = absolute(&a).unwrap();
1492        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1493            assert!((v - (i as f32)).abs() < 1e-6);
1494        }
1495    }
1496
1497    #[test]
1498    fn neg_f32_simd_path() {
1499        let n = 32;
1500        let a = arr1_f32((0..n).map(|i| i as f32).collect());
1501        let r = negative(&a).unwrap();
1502        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1503            assert!((v - (-(i as f32))).abs() < 1e-6);
1504        }
1505    }
1506
1507    #[test]
1508    fn reciprocal_f32_simd_path() {
1509        let n = 32;
1510        let a = arr1_f32((1..=n).map(|i| i as f32).collect());
1511        let r = reciprocal(&a).unwrap();
1512        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1513            let want = 1.0_f32 / ((i + 1) as f32);
1514            assert!((v - want).abs() < 1e-6);
1515        }
1516    }
1517
1518    #[test]
1519    fn square_f32_simd_path() {
1520        let n = 32;
1521        let a = arr1_f32((0..n).map(|i| i as f32).collect());
1522        let r = square(&a).unwrap();
1523        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1524            let want = (i as f32) * (i as f32);
1525            assert!((v - want).abs() < 1e-4);
1526        }
1527    }
1528
1529    #[test]
1530    fn add_f32_below_simd_threshold_scalar_path() {
1531        // Tiny array — SIMD dispatch typically falls back to scalar.
1532        let a = arr1_f32(vec![1.5, 2.5, 3.5]);
1533        let b = arr1_f32(vec![0.5, 0.5, 0.5]);
1534        let r = add(&a, &b).unwrap();
1535        assert_eq!(r.as_slice().unwrap(), &[2.0, 3.0, 4.0]);
1536    }
1537
1538    #[test]
1539    fn add_f32_force_scalar_env_var() {
1540        // FERRAY_FORCE_SCALAR=1 should bypass SIMD; result must still
1541        // be correct.
1542        // SAFETY: This test is single-threaded by default per cargo
1543        // test runner; we set then unset the env var around the call.
1544        unsafe {
1545            std::env::set_var("FERRAY_FORCE_SCALAR", "1");
1546        }
1547        let a = arr1_f32((0..32).map(|i| i as f32).collect());
1548        let b = arr1_f32(vec![1.0_f32; 32]);
1549        let r = add(&a, &b).unwrap();
1550        unsafe {
1551            std::env::remove_var("FERRAY_FORCE_SCALAR");
1552        }
1553        for (i, &v) in r.as_slice().unwrap().iter().enumerate() {
1554            assert!((v - (i as f32 + 1.0)).abs() < 1e-6);
1555        }
1556    }
1557
1558    #[test]
1559    fn test_add() {
1560        let a = arr1(vec![1.0, 2.0, 3.0]);
1561        let b = arr1(vec![4.0, 5.0, 6.0]);
1562        let r = add(&a, &b).unwrap();
1563        assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
1564    }
1565
1566    #[test]
1567    fn test_subtract() {
1568        let a = arr1(vec![5.0, 7.0, 9.0]);
1569        let b = arr1(vec![1.0, 2.0, 3.0]);
1570        let r = subtract(&a, &b).unwrap();
1571        assert_eq!(r.as_slice().unwrap(), &[4.0, 5.0, 6.0]);
1572    }
1573
1574    #[test]
1575    fn test_multiply() {
1576        let a = arr1(vec![2.0, 3.0, 4.0]);
1577        let b = arr1(vec![5.0, 6.0, 7.0]);
1578        let r = multiply(&a, &b).unwrap();
1579        assert_eq!(r.as_slice().unwrap(), &[10.0, 18.0, 28.0]);
1580    }
1581
1582    #[test]
1583    fn test_divide() {
1584        let a = arr1(vec![10.0, 20.0, 30.0]);
1585        let b = arr1(vec![2.0, 4.0, 5.0]);
1586        let r = divide(&a, &b).unwrap();
1587        assert_eq!(r.as_slice().unwrap(), &[5.0, 5.0, 6.0]);
1588    }
1589
1590    #[test]
1591    fn test_floor_divide() {
1592        let a = arr1(vec![7.0, -7.0]);
1593        let b = arr1(vec![2.0, 2.0]);
1594        let r = floor_divide(&a, &b).unwrap();
1595        assert_eq!(r.as_slice().unwrap(), &[3.0, -4.0]);
1596    }
1597
1598    #[test]
1599    fn test_power() {
1600        let a = arr1(vec![2.0, 3.0]);
1601        let b = arr1(vec![3.0, 2.0]);
1602        let r = power(&a, &b).unwrap();
1603        assert_eq!(r.as_slice().unwrap(), &[8.0, 9.0]);
1604    }
1605
1606    #[test]
1607    fn test_remainder() {
1608        let a = arr1(vec![7.0, -7.0]);
1609        let b = arr1(vec![3.0, 3.0]);
1610        let r = remainder(&a, &b).unwrap();
1611        let s = r.as_slice().unwrap();
1612        assert!((s[0] - 1.0).abs() < 1e-12);
1613        assert!((s[1] - 2.0).abs() < 1e-12);
1614    }
1615
1616    #[test]
1617    fn test_fmod() {
1618        let a = arr1(vec![7.0, -7.0]);
1619        let b = arr1(vec![3.0, 3.0]);
1620        let r = fmod(&a, &b).unwrap();
1621        let s = r.as_slice().unwrap();
1622        assert!((s[0] - 1.0).abs() < 1e-12);
1623        assert!((s[1] - (-1.0)).abs() < 1e-12);
1624    }
1625
1626    #[test]
1627    fn test_absolute() {
1628        let a = arr1(vec![-1.0, 2.0, -3.0]);
1629        let r = absolute(&a).unwrap();
1630        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1631    }
1632
1633    #[test]
1634    fn test_sign() {
1635        let a = arr1(vec![-5.0, 0.0, 3.0]);
1636        let r = sign(&a).unwrap();
1637        assert_eq!(r.as_slice().unwrap(), &[-1.0, 0.0, 1.0]);
1638    }
1639
1640    #[test]
1641    fn test_negative() {
1642        let a = arr1(vec![1.0, -2.0, 3.0]);
1643        let r = negative(&a).unwrap();
1644        assert_eq!(r.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
1645    }
1646
1647    #[test]
1648    fn test_sqrt() {
1649        let a = arr1(vec![1.0, 4.0, 9.0, 16.0]);
1650        let r = sqrt(&a).unwrap();
1651        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
1652    }
1653
1654    #[test]
1655    fn test_cbrt() {
1656        let a = arr1(vec![8.0, 27.0]);
1657        let r = cbrt(&a).unwrap();
1658        let s = r.as_slice().unwrap();
1659        assert!((s[0] - 2.0).abs() < 1e-12);
1660        assert!((s[1] - 3.0).abs() < 1e-12);
1661    }
1662
1663    #[test]
1664    fn test_square() {
1665        let a = arr1(vec![2.0, 3.0, 4.0]);
1666        let r = square(&a).unwrap();
1667        assert_eq!(r.as_slice().unwrap(), &[4.0, 9.0, 16.0]);
1668    }
1669
1670    #[test]
1671    fn test_reciprocal() {
1672        let a = arr1(vec![2.0, 4.0, 5.0]);
1673        let r = reciprocal(&a).unwrap();
1674        assert_eq!(r.as_slice().unwrap(), &[0.5, 0.25, 0.2]);
1675    }
1676
1677    #[test]
1678    fn test_heaviside() {
1679        let x = arr1(vec![-1.0, 0.0, 1.0]);
1680        let h0 = arr1(vec![0.5, 0.5, 0.5]);
1681        let r = heaviside(&x, &h0).unwrap();
1682        assert_eq!(r.as_slice().unwrap(), &[0.0, 0.5, 1.0]);
1683    }
1684
1685    #[test]
1686    fn test_gcd() {
1687        let a = arr1(vec![12.0, 15.0]);
1688        let b = arr1(vec![8.0, 25.0]);
1689        let r = gcd(&a, &b).unwrap();
1690        assert_eq!(r.as_slice().unwrap(), &[4.0, 5.0]);
1691    }
1692
1693    #[test]
1694    fn test_lcm() {
1695        let a = arr1(vec![4.0, 6.0]);
1696        let b = arr1(vec![6.0, 8.0]);
1697        let r = lcm(&a, &b).unwrap();
1698        assert_eq!(r.as_slice().unwrap(), &[12.0, 24.0]);
1699    }
1700
1701    #[test]
1702    fn test_gcd_int() {
1703        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![12, 15, 0]).unwrap();
1704        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![8, 25, 7]).unwrap();
1705        let r = gcd_int(&a, &b).unwrap();
1706        assert_eq!(r.as_slice().unwrap(), &[4, 5, 7]);
1707    }
1708
1709    #[test]
1710    fn test_lcm_int() {
1711        let a = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![4, 6, 0]).unwrap();
1712        let b = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![6, 8, 5]).unwrap();
1713        let r = lcm_int(&a, &b).unwrap();
1714        assert_eq!(r.as_slice().unwrap(), &[12, 24, 0]);
1715    }
1716
1717    #[test]
1718    fn test_gcd_int_negative() {
1719        let a = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![-12, 15]).unwrap();
1720        let b = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![8, -25]).unwrap();
1721        let r = gcd_int(&a, &b).unwrap();
1722        assert_eq!(r.as_slice().unwrap(), &[4, 5]);
1723    }
1724
1725    #[test]
1726    fn test_cumsum_ac11() {
1727        // AC-11: cumsum([1,2,3,4]) == [1,3,6,10]
1728        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
1729        let r = cumsum(&a, None).unwrap();
1730        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 10.0]);
1731    }
1732
1733    #[test]
1734    fn test_cumsum_i32() {
1735        let a = arr1_i32(vec![1, 2, 3, 4]);
1736        let r = cumsum(&a, None).unwrap();
1737        assert_eq!(r.as_slice().unwrap(), &[1, 3, 6, 10]);
1738    }
1739
1740    #[test]
1741    fn test_cumprod() {
1742        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
1743        let r = cumprod(&a, None).unwrap();
1744        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 6.0, 24.0]);
1745    }
1746
1747    #[test]
1748    fn test_cumulative_sum_alias() {
1749        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
1750        let r = cumulative_sum(&a, None).unwrap();
1751        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 10.0]);
1752    }
1753
1754    #[test]
1755    fn test_cumulative_prod_alias() {
1756        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
1757        let r = cumulative_prod(&a, None).unwrap();
1758        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 6.0, 24.0]);
1759    }
1760
1761    #[test]
1762    fn test_diff_ac11() {
1763        // AC-11: diff([1,3,6,10], 1) == [2,3,4]
1764        let a = arr1(vec![1.0, 3.0, 6.0, 10.0]);
1765        let r = diff(&a, 1).unwrap();
1766        assert_eq!(r.as_slice().unwrap(), &[2.0, 3.0, 4.0]);
1767    }
1768
1769    #[test]
1770    fn test_diff_n2() {
1771        let a = arr1(vec![1.0, 3.0, 6.0, 10.0]);
1772        let r = diff(&a, 2).unwrap();
1773        assert_eq!(r.as_slice().unwrap(), &[1.0, 1.0]);
1774    }
1775
1776    #[test]
1777    fn test_ediff1d() {
1778        let a = arr1(vec![1.0, 2.0, 4.0, 7.0]);
1779        let r = ediff1d(&a, None, None).unwrap();
1780        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1781    }
1782
1783    #[test]
1784    fn test_gradient() {
1785        let a = arr1(vec![1.0, 2.0, 4.0, 7.0, 11.0]);
1786        let r = gradient(&a, None).unwrap();
1787        let s = r.as_slice().unwrap();
1788        // forward: 2-1=1, central: (4-1)/2=1.5, (7-2)/2=2.5, (11-4)/2=3.5, backward: 11-7=4
1789        assert!((s[0] - 1.0).abs() < 1e-12);
1790        assert!((s[1] - 1.5).abs() < 1e-12);
1791        assert!((s[2] - 2.5).abs() < 1e-12);
1792        assert!((s[3] - 3.5).abs() < 1e-12);
1793        assert!((s[4] - 4.0).abs() < 1e-12);
1794    }
1795
1796    #[test]
1797    fn test_cross() {
1798        let a = arr1(vec![1.0, 0.0, 0.0]);
1799        let b = arr1(vec![0.0, 1.0, 0.0]);
1800        let r = cross(&a, &b).unwrap();
1801        assert_eq!(r.as_slice().unwrap(), &[0.0, 0.0, 1.0]);
1802    }
1803
1804    #[test]
1805    fn test_trapezoid() {
1806        // Integrate y=x from 0 to 4: area = 8
1807        let y = arr1(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
1808        let r = trapezoid(&y, None, Some(1.0)).unwrap();
1809        assert!((r - 8.0).abs() < 1e-12);
1810    }
1811
1812    #[test]
1813    fn test_trapezoid_with_x() {
1814        let y = arr1(vec![0.0, 1.0, 4.0]);
1815        let x = arr1(vec![0.0, 1.0, 2.0]);
1816        let r = trapezoid(&y, Some(&x), None).unwrap();
1817        // (0+1)/2*1 + (1+4)/2*1 = 0.5 + 2.5 = 3.0
1818        assert!((r - 3.0).abs() < 1e-12);
1819    }
1820
1821    #[test]
1822    fn test_add_reduce_ac2() {
1823        // AC-2: add_reduce computes correct column sums
1824        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1825            .unwrap();
1826        let r = add_reduce(&a, 0).unwrap();
1827        assert_eq!(r.shape(), &[3]);
1828        let s: Vec<f64> = r.iter().copied().collect();
1829        assert_eq!(s, vec![5.0, 7.0, 9.0]);
1830    }
1831
1832    #[test]
1833    fn test_add_accumulate_ac2() {
1834        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
1835        let r = add_accumulate(&a, 0).unwrap();
1836        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 10.0]);
1837    }
1838
1839    #[test]
1840    fn add_reduce_keepdims_true_preserves_row_axis() {
1841        // (2,3) + axis=1 + keepdims=true → (2,1) with row sums.
1842        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1843            .unwrap();
1844        let r = add_reduce_keepdims(&a, 1, true).unwrap();
1845        assert_eq!(r.shape(), &[2, 1]);
1846        assert_eq!(r.as_slice().unwrap(), &[6.0, 15.0]);
1847    }
1848
1849    #[test]
1850    fn add_reduce_keepdims_true_preserves_col_axis() {
1851        // (2,3) + axis=0 + keepdims=true → (1,3) with column sums.
1852        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1853            .unwrap();
1854        let r = add_reduce_keepdims(&a, 0, true).unwrap();
1855        assert_eq!(r.shape(), &[1, 3]);
1856        assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
1857    }
1858
1859    #[test]
1860    fn add_reduce_keepdims_false_matches_legacy_add_reduce() {
1861        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1862            .unwrap();
1863        let legacy = add_reduce(&a, 1).unwrap();
1864        let new_false = add_reduce_keepdims(&a, 1, false).unwrap();
1865        assert_eq!(legacy.shape(), new_false.shape());
1866        assert_eq!(legacy.as_slice().unwrap(), new_false.as_slice().unwrap());
1867    }
1868
1869    #[test]
1870    fn add_reduce_axes_two_axes_3d() {
1871        // (2, 3, 4) reducing axes (0, 2) → length-3 result.
1872        use ferray_core::dimension::Ix3;
1873        let data: Vec<f64> = (0..24).map(f64::from).collect();
1874        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
1875        let r = add_reduce_axes(&a, &[0, 2], false).unwrap();
1876        assert_eq!(r.shape(), &[3]);
1877        // For each j in 0..3: sum_{i,k} (i*12 + j*4 + k)
1878        let expected: Vec<f64> = (0..3)
1879            .map(|j| {
1880                let mut s = 0.0;
1881                for i in 0..2 {
1882                    for k in 0..4 {
1883                        s += f64::from(i * 12 + j * 4 + k);
1884                    }
1885                }
1886                s
1887            })
1888            .collect();
1889        assert_eq!(r.as_slice().unwrap(), expected.as_slice());
1890    }
1891
1892    #[test]
1893    fn add_reduce_axes_keepdims_preserves_rank() {
1894        use ferray_core::dimension::Ix3;
1895        let data: Vec<f64> = (0..24).map(f64::from).collect();
1896        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
1897        let r = add_reduce_axes(&a, &[0, 2], true).unwrap();
1898        assert_eq!(r.shape(), &[1, 3, 1]);
1899    }
1900
1901    #[test]
1902    fn add_reduce_axes_all_axes_collapses_to_scalar_array() {
1903        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1904            .unwrap();
1905        let r = add_reduce_axes(&a, &[0, 1], false).unwrap();
1906        assert_eq!(r.shape(), &[1]);
1907        assert_eq!(r.as_slice().unwrap(), &[21.0]);
1908    }
1909
1910    #[test]
1911    fn add_reduce_all_returns_scalar_sum() {
1912        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1913            .unwrap();
1914        let s = add_reduce_all(&a);
1915        assert!((s - 21.0).abs() < 1e-12);
1916    }
1917
1918    #[test]
1919    fn add_reduce_all_integer_input_works() {
1920        // Multi-axis reductions must work for integer Element types too.
1921        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
1922        let s = add_reduce_all(&a);
1923        assert_eq!(s, 21);
1924    }
1925
1926    // ---- nan-aware reductions (#388) ----
1927
1928    #[test]
1929    fn nan_add_reduce_all_skips_nans() {
1930        let a = arr1(vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]);
1931        let s = nan_add_reduce_all(&a);
1932        assert!((s - 9.0).abs() < 1e-12);
1933    }
1934
1935    #[test]
1936    fn nan_add_reduce_all_nans_only_returns_zero() {
1937        let a = arr1(vec![f64::NAN, f64::NAN]);
1938        let s = nan_add_reduce_all(&a);
1939        assert!((s - 0.0).abs() < 1e-12);
1940    }
1941
1942    #[test]
1943    fn nan_add_reduce_axis_skips_nans_per_row() {
1944        // (2, 3) with row-1 having a NaN; reduce axis=1 → row sums.
1945        let a =
1946            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, f64::NAN, 6.0])
1947                .unwrap();
1948        let r = nan_add_reduce(&a, 1, false).unwrap();
1949        assert_eq!(r.shape(), &[2]);
1950        let s = r.as_slice().unwrap();
1951        assert!((s[0] - 6.0).abs() < 1e-12);
1952        assert!((s[1] - 10.0).abs() < 1e-12);
1953    }
1954
1955    #[test]
1956    fn nan_add_reduce_axes_multi_axis_skips_nans() {
1957        use ferray_core::dimension::Ix3;
1958        // (2, 2, 2) with one NaN; reduce axes (0, 2).
1959        let data = vec![1.0, 2.0, 3.0, 4.0, f64::NAN, 6.0, 7.0, 8.0];
1960        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 2, 2]), data).unwrap();
1961        let r = nan_add_reduce_axes(&a, &[0, 2], false).unwrap();
1962        assert_eq!(r.shape(), &[2]);
1963        // For j=0: sum(1, 2, NaN→0, 6) = 9.0
1964        // For j=1: sum(3, 4, 7, 8) = 22.0
1965        let s = r.as_slice().unwrap();
1966        assert!((s[0] - 9.0).abs() < 1e-12);
1967        assert!((s[1] - 22.0).abs() < 1e-12);
1968    }
1969
1970    #[test]
1971    fn nan_multiply_reduce_all_skips_nans() {
1972        let a = arr1(vec![2.0, f64::NAN, 3.0, f64::NAN, 4.0]);
1973        let p = nan_multiply_reduce_all(&a);
1974        assert!((p - 24.0).abs() < 1e-12);
1975    }
1976
1977    #[test]
1978    fn nan_multiply_reduce_all_nans_only_returns_one() {
1979        let a = arr1(vec![f64::NAN, f64::NAN]);
1980        let p = nan_multiply_reduce_all(&a);
1981        assert!((p - 1.0).abs() < 1e-12);
1982    }
1983
1984    #[test]
1985    fn nan_multiply_reduce_axis_per_row() {
1986        let a =
1987            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![2.0, 3.0, 4.0, 5.0, f64::NAN, 6.0])
1988                .unwrap();
1989        let r = nan_multiply_reduce(&a, 1, false).unwrap();
1990        let s = r.as_slice().unwrap();
1991        assert!((s[0] - 24.0).abs() < 1e-12); // 2*3*4
1992        assert!((s[1] - 30.0).abs() < 1e-12); // 5*1*6
1993    }
1994
1995    #[test]
1996    fn nan_max_reduce_all_skips_nans() {
1997        let a = arr1(vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0, 2.0]);
1998        let m = nan_max_reduce_all(&a);
1999        assert!((m - 5.0).abs() < 1e-12);
2000    }
2001
2002    #[test]
2003    fn nan_max_reduce_all_nans_only_returns_neg_infinity() {
2004        let a = arr1(vec![f64::NAN, f64::NAN]);
2005        let m = nan_max_reduce_all(&a);
2006        assert!(m.is_infinite() && m.is_sign_negative());
2007    }
2008
2009    #[test]
2010    fn nan_max_reduce_axis_per_row_with_nans() {
2011        let a = Array::<f64, Ix2>::from_vec(
2012            Ix2::new([2, 3]),
2013            vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0, 4.0],
2014        )
2015        .unwrap();
2016        let r = nan_max_reduce(&a, 1, false).unwrap();
2017        let s = r.as_slice().unwrap();
2018        assert!((s[0] - 3.0).abs() < 1e-12);
2019        assert!((s[1] - 5.0).abs() < 1e-12);
2020    }
2021
2022    #[test]
2023    fn nan_min_reduce_all_skips_nans() {
2024        let a = arr1(vec![5.0, f64::NAN, 3.0, f64::NAN, 1.0, 4.0]);
2025        let m = nan_min_reduce_all(&a);
2026        assert!((m - 1.0).abs() < 1e-12);
2027    }
2028
2029    #[test]
2030    fn nan_min_reduce_all_nans_only_returns_infinity() {
2031        let a = arr1(vec![f64::NAN, f64::NAN]);
2032        let m = nan_min_reduce_all(&a);
2033        assert!(m.is_infinite() && m.is_sign_positive());
2034    }
2035
2036    #[test]
2037    fn nan_min_reduce_axis_per_row_with_nans() {
2038        let a = Array::<f64, Ix2>::from_vec(
2039            Ix2::new([2, 3]),
2040            vec![5.0, f64::NAN, 3.0, f64::NAN, 5.0, 4.0],
2041        )
2042        .unwrap();
2043        let r = nan_min_reduce(&a, 1, false).unwrap();
2044        let s = r.as_slice().unwrap();
2045        assert!((s[0] - 3.0).abs() < 1e-12);
2046        assert!((s[1] - 4.0).abs() < 1e-12);
2047    }
2048
2049    #[test]
2050    fn nan_reductions_with_no_nans_match_regular_reductions() {
2051        // When the input has no NaNs the nan-aware versions must give
2052        // the exact same result as the regular reductions.
2053        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
2054        assert!((nan_add_reduce_all(&a) - 15.0).abs() < 1e-12);
2055        assert!((nan_multiply_reduce_all(&a) - 120.0).abs() < 1e-12);
2056        assert!((nan_max_reduce_all(&a) - 5.0).abs() < 1e-12);
2057        assert!((nan_min_reduce_all(&a) - 1.0).abs() < 1e-12);
2058    }
2059
2060    #[test]
2061    fn nan_add_reduce_keepdims_preserves_axis() {
2062        let a =
2063            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, f64::NAN, 6.0])
2064                .unwrap();
2065        let r = nan_add_reduce(&a, 1, true).unwrap();
2066        assert_eq!(r.shape(), &[2, 1]);
2067    }
2068
2069    #[test]
2070    fn test_multiply_outer_ac3() {
2071        // AC-3: multiply_outer produces correct outer product
2072        let a = arr1(vec![1.0, 2.0, 3.0]);
2073        let b = arr1(vec![4.0, 5.0]);
2074        let r = multiply_outer(&a, &b).unwrap();
2075        assert_eq!(r.shape(), &[3, 2]);
2076        let s: Vec<f64> = r.iter().copied().collect();
2077        assert_eq!(s, vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
2078    }
2079
2080    #[test]
2081    fn test_nancumsum() {
2082        let a = arr1(vec![1.0, f64::NAN, 3.0, 4.0]);
2083        let r = nancumsum(&a, None).unwrap();
2084        let s = r.as_slice().unwrap();
2085        assert_eq!(s[0], 1.0);
2086        assert_eq!(s[1], 1.0); // NaN treated as 0
2087        assert_eq!(s[2], 4.0);
2088        assert_eq!(s[3], 8.0);
2089    }
2090
2091    #[test]
2092    fn test_nancumprod() {
2093        let a = arr1(vec![1.0, f64::NAN, 3.0, 4.0]);
2094        let r = nancumprod(&a, None).unwrap();
2095        let s = r.as_slice().unwrap();
2096        assert_eq!(s[0], 1.0);
2097        assert_eq!(s[1], 1.0); // NaN treated as 1
2098        assert_eq!(s[2], 3.0);
2099        assert_eq!(s[3], 12.0);
2100    }
2101
2102    #[test]
2103    fn test_add_broadcast() {
2104        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
2105        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
2106        let r = add_broadcast(&a, &b).unwrap();
2107        assert_eq!(r.shape(), &[2, 3]);
2108    }
2109
2110    #[test]
2111    fn test_divmod() {
2112        let a = arr1(vec![7.0, -7.0]);
2113        let b = arr1(vec![3.0, 3.0]);
2114        let (q, r) = divmod(&a, &b).unwrap();
2115        assert_eq!(q.as_slice().unwrap(), &[2.0, -3.0]);
2116        let rs = r.as_slice().unwrap();
2117        assert!((rs[0] - 1.0).abs() < 1e-12);
2118        assert!((rs[1] - 2.0).abs() < 1e-12);
2119    }
2120
2121    #[test]
2122    fn test_positive() {
2123        let a = arr1(vec![-1.0, 2.0]);
2124        let r = positive(&a).unwrap();
2125        assert_eq!(r.as_slice().unwrap(), &[-1.0, 2.0]);
2126    }
2127
2128    #[test]
2129    fn test_true_divide() {
2130        let a = arr1(vec![10.0, 20.0]);
2131        let b = arr1(vec![3.0, 7.0]);
2132        let r = true_divide(&a, &b).unwrap();
2133        let s = r.as_slice().unwrap();
2134        assert!((s[0] - 10.0 / 3.0).abs() < 1e-12);
2135        assert!((s[1] - 20.0 / 7.0).abs() < 1e-12);
2136    }
2137
2138    // -----------------------------------------------------------------------
2139    // Broadcasting tests for arithmetic ops (issue #379)
2140    // -----------------------------------------------------------------------
2141
2142    #[test]
2143    fn test_add_broadcasts_within_same_rank() {
2144        // (3, 1) + (1, 4) -> (3, 4) — both Ix2
2145        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
2146        let row =
2147            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
2148        let r = add(&col, &row).unwrap();
2149        assert_eq!(r.shape(), &[3, 4]);
2150        assert_eq!(
2151            r.iter().copied().collect::<Vec<_>>(),
2152            vec![
2153                11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0,
2154            ]
2155        );
2156    }
2157
2158    #[test]
2159    fn test_subtract_broadcasts() {
2160        let a =
2161            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
2162                .unwrap();
2163        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
2164        let r = subtract(&a, &b).unwrap();
2165        assert_eq!(r.shape(), &[2, 3]);
2166        assert_eq!(
2167            r.iter().copied().collect::<Vec<_>>(),
2168            vec![9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
2169        );
2170    }
2171
2172    #[test]
2173    fn test_multiply_broadcasts() {
2174        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
2175        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
2176        let r = multiply(&col, &row).unwrap();
2177        assert_eq!(r.shape(), &[3, 3]);
2178        assert_eq!(
2179            r.iter().copied().collect::<Vec<_>>(),
2180            vec![10, 20, 30, 20, 40, 60, 30, 60, 90]
2181        );
2182    }
2183
2184    #[test]
2185    fn test_divide_broadcasts() {
2186        let a =
2187            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
2188                .unwrap();
2189        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 5.0, 2.0]).unwrap();
2190        let r = divide(&a, &b).unwrap();
2191        assert_eq!(r.shape(), &[2, 3]);
2192        assert_eq!(
2193            r.iter().copied().collect::<Vec<_>>(),
2194            vec![1.0, 4.0, 15.0, 4.0, 10.0, 30.0]
2195        );
2196    }
2197
2198    #[test]
2199    fn test_power_broadcasts() {
2200        let bases = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![2.0, 3.0, 4.0]).unwrap();
2201        let exps = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
2202        let r = power(&bases, &exps).unwrap();
2203        assert_eq!(r.shape(), &[3, 3]);
2204        assert_eq!(
2205            r.iter().copied().collect::<Vec<_>>(),
2206            vec![2.0, 4.0, 8.0, 3.0, 9.0, 27.0, 4.0, 16.0, 64.0]
2207        );
2208    }
2209
2210    #[test]
2211    fn test_remainder_broadcasts() {
2212        let a =
2213            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0])
2214                .unwrap();
2215        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![3.0, 4.0, 5.0]).unwrap();
2216        let r = remainder(&a, &b).unwrap();
2217        assert_eq!(r.shape(), &[2, 3]);
2218        assert_eq!(
2219            r.iter().copied().collect::<Vec<_>>(),
2220            vec![1.0, 0.0, 4.0, 1.0, 3.0, 2.0]
2221        );
2222    }
2223
2224    #[test]
2225    fn test_divmod_broadcasts() {
2226        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![7.0, 13.0]).unwrap();
2227        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, 3.0, 4.0]).unwrap();
2228        // Need both inputs to have the same D for the typed divmod entry point.
2229        // Use the cross-rank broadcast helper instead — but divmod is typed,
2230        // so route via an explicit Ix2 reshape of b.
2231        let b2 = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![2.0, 3.0, 4.0]).unwrap();
2232        let (q, r) = divmod(&a, &b2).unwrap();
2233        assert_eq!(q.shape(), &[2, 3]);
2234        assert_eq!(r.shape(), &[2, 3]);
2235        // a broadcasts to [[7,7,7],[13,13,13]], b broadcasts to [[2,3,4],[2,3,4]]
2236        // divmod(7,2)=(3,1), divmod(7,3)=(2,1), divmod(7,4)=(1,3)
2237        // divmod(13,2)=(6,1), divmod(13,3)=(4,1), divmod(13,4)=(3,1)
2238        let q_vec: Vec<f64> = q.iter().copied().collect();
2239        let r_vec: Vec<f64> = r.iter().copied().collect();
2240        assert_eq!(q_vec, vec![3.0, 2.0, 1.0, 6.0, 4.0, 3.0]);
2241        assert_eq!(r_vec, vec![1.0, 1.0, 3.0, 1.0, 1.0, 1.0]);
2242        let _ = b; // silence unused
2243    }
2244
2245    #[test]
2246    fn test_gcd_int_broadcasts() {
2247        let a = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![12, 18, 24]).unwrap();
2248        let b = Array::<i32, Ix2>::from_vec(Ix2::new([1, 2]), vec![8, 9]).unwrap();
2249        let r = gcd_int(&a, &b).unwrap();
2250        assert_eq!(r.shape(), &[3, 2]);
2251        // gcd(12,8)=4, gcd(12,9)=3, gcd(18,8)=2, gcd(18,9)=9, gcd(24,8)=8, gcd(24,9)=3
2252        assert_eq!(
2253            r.iter().copied().collect::<Vec<_>>(),
2254            vec![4, 3, 2, 9, 8, 3]
2255        );
2256    }
2257
2258    #[test]
2259    fn test_lcm_int_broadcasts() {
2260        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 1]), vec![4, 6]).unwrap();
2261        let b = Array::<i32, Ix2>::from_vec(Ix2::new([1, 2]), vec![6, 8]).unwrap();
2262        let r = lcm_int(&a, &b).unwrap();
2263        assert_eq!(r.shape(), &[2, 2]);
2264        // lcm(4,6)=12, lcm(4,8)=8, lcm(6,6)=6, lcm(6,8)=24
2265        assert_eq!(r.iter().copied().collect::<Vec<_>>(), vec![12, 8, 6, 24]);
2266    }
2267
2268    #[test]
2269    fn test_add_incompatible_shapes_errors() {
2270        let a = arr1(vec![1.0, 2.0, 3.0]);
2271        let b = arr1(vec![1.0, 2.0, 3.0, 4.0]);
2272        assert!(add(&a, &b).is_err());
2273    }
2274
2275    #[cfg(feature = "f16")]
2276    mod f16_tests {
2277        use super::*;
2278
2279        fn arr1_f16(data: &[f32]) -> Array<half::f16, Ix1> {
2280            let n = data.len();
2281            let vals: Vec<half::f16> = data.iter().map(|&x| half::f16::from_f32(x)).collect();
2282            Array::from_vec(Ix1::new([n]), vals).unwrap()
2283        }
2284
2285        #[test]
2286        fn test_add_f16() {
2287            let a = arr1_f16(&[1.0, 2.0, 3.0]);
2288            let b = arr1_f16(&[4.0, 5.0, 6.0]);
2289            let r = add_f16(&a, &b).unwrap();
2290            let s = r.as_slice().unwrap();
2291            assert!((s[0].to_f32() - 5.0).abs() < 0.01);
2292            assert!((s[1].to_f32() - 7.0).abs() < 0.01);
2293            assert!((s[2].to_f32() - 9.0).abs() < 0.01);
2294        }
2295
2296        #[test]
2297        fn test_multiply_f16() {
2298            let a = arr1_f16(&[2.0, 3.0]);
2299            let b = arr1_f16(&[4.0, 5.0]);
2300            let r = multiply_f16(&a, &b).unwrap();
2301            let s = r.as_slice().unwrap();
2302            assert!((s[0].to_f32() - 8.0).abs() < 0.01);
2303            assert!((s[1].to_f32() - 15.0).abs() < 0.1);
2304        }
2305
2306        #[test]
2307        fn test_sqrt_f16() {
2308            let a = arr1_f16(&[1.0, 4.0, 9.0, 16.0]);
2309            let r = sqrt_f16(&a).unwrap();
2310            let s = r.as_slice().unwrap();
2311            assert!((s[0].to_f32() - 1.0).abs() < 0.01);
2312            assert!((s[1].to_f32() - 2.0).abs() < 0.01);
2313            assert!((s[2].to_f32() - 3.0).abs() < 0.01);
2314            assert!((s[3].to_f32() - 4.0).abs() < 0.01);
2315        }
2316
2317        #[test]
2318        fn test_absolute_f16() {
2319            let a = arr1_f16(&[-1.0, 2.0, -3.0]);
2320            let r = absolute_f16(&a).unwrap();
2321            let s = r.as_slice().unwrap();
2322            assert!((s[0].to_f32() - 1.0).abs() < 0.01);
2323            assert!((s[1].to_f32() - 2.0).abs() < 0.01);
2324            assert!((s[2].to_f32() - 3.0).abs() < 0.01);
2325        }
2326
2327        #[test]
2328        fn test_power_f16() {
2329            let a = arr1_f16(&[2.0, 3.0]);
2330            let b = arr1_f16(&[3.0, 2.0]);
2331            let r = power_f16(&a, &b).unwrap();
2332            let s = r.as_slice().unwrap();
2333            assert!((s[0].to_f32() - 8.0).abs() < 0.1);
2334            assert!((s[1].to_f32() - 9.0).abs() < 0.1);
2335        }
2336
2337        #[test]
2338        fn test_divide_f16() {
2339            let a = arr1_f16(&[10.0, 20.0]);
2340            let b = arr1_f16(&[2.0, 4.0]);
2341            let r = divide_f16(&a, &b).unwrap();
2342            let s = r.as_slice().unwrap();
2343            assert!((s[0].to_f32() - 5.0).abs() < 0.01);
2344            assert!((s[1].to_f32() - 5.0).abs() < 0.01);
2345        }
2346    }
2347
2348    // -----------------------------------------------------------------------
2349    // In-place (_into) variants (issue #378)
2350    // -----------------------------------------------------------------------
2351
2352    mod into_tests {
2353        use super::*;
2354        use ferray_core::Array;
2355        use ferray_core::dimension::Ix1;
2356
2357        fn arr(data: &[f64]) -> Array<f64, Ix1> {
2358            Array::<f64, Ix1>::from_vec(Ix1::new([data.len()]), data.to_vec()).unwrap()
2359        }
2360
2361        #[test]
2362        fn add_into_writes_result() {
2363            let a = arr(&[1.0, 2.0, 3.0]);
2364            let b = arr(&[10.0, 20.0, 30.0]);
2365            let mut out = arr(&[0.0, 0.0, 0.0]);
2366            add_into(&a, &b, &mut out).unwrap();
2367            assert_eq!(out.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
2368        }
2369
2370        #[test]
2371        fn subtract_into_writes_result() {
2372            let a = arr(&[10.0, 20.0, 30.0]);
2373            let b = arr(&[1.0, 2.0, 3.0]);
2374            let mut out = arr(&[0.0, 0.0, 0.0]);
2375            subtract_into(&a, &b, &mut out).unwrap();
2376            assert_eq!(out.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
2377        }
2378
2379        #[test]
2380        fn multiply_into_writes_result() {
2381            let a = arr(&[1.0, 2.0, 3.0]);
2382            let b = arr(&[4.0, 5.0, 6.0]);
2383            let mut out = arr(&[0.0; 3]);
2384            multiply_into(&a, &b, &mut out).unwrap();
2385            assert_eq!(out.as_slice().unwrap(), &[4.0, 10.0, 18.0]);
2386        }
2387
2388        #[test]
2389        fn divide_into_writes_result() {
2390            let a = arr(&[10.0, 20.0, 30.0]);
2391            let b = arr(&[2.0, 4.0, 6.0]);
2392            let mut out = arr(&[0.0; 3]);
2393            divide_into(&a, &b, &mut out).unwrap();
2394            assert_eq!(out.as_slice().unwrap(), &[5.0, 5.0, 5.0]);
2395        }
2396
2397        #[test]
2398        fn add_into_shape_mismatch_errors() {
2399            let a = arr(&[1.0, 2.0, 3.0]);
2400            let b = arr(&[1.0, 2.0]);
2401            let mut out = arr(&[0.0, 0.0, 0.0]);
2402            assert!(add_into(&a, &b, &mut out).is_err());
2403        }
2404
2405        #[test]
2406        fn add_into_out_shape_mismatch_errors() {
2407            let a = arr(&[1.0, 2.0, 3.0]);
2408            let b = arr(&[4.0, 5.0, 6.0]);
2409            let mut out = arr(&[0.0, 0.0]); // wrong size
2410            assert!(add_into(&a, &b, &mut out).is_err());
2411        }
2412
2413        #[test]
2414        fn sqrt_into_writes_result() {
2415            let a = arr(&[1.0, 4.0, 9.0, 16.0]);
2416            let mut out = arr(&[0.0; 4]);
2417            sqrt_into(&a, &mut out).unwrap();
2418            assert_eq!(out.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
2419        }
2420
2421        #[test]
2422        fn square_into_writes_result() {
2423            let a = arr(&[1.0, -2.0, 3.0, -4.0]);
2424            let mut out = arr(&[0.0; 4]);
2425            square_into(&a, &mut out).unwrap();
2426            assert_eq!(out.as_slice().unwrap(), &[1.0, 4.0, 9.0, 16.0]);
2427        }
2428
2429        #[test]
2430        fn absolute_into_writes_result() {
2431            let a = arr(&[-1.0, 2.0, -3.0]);
2432            let mut out = arr(&[0.0; 3]);
2433            absolute_into(&a, &mut out).unwrap();
2434            assert_eq!(out.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
2435        }
2436
2437        #[test]
2438        fn negative_into_writes_result() {
2439            let a = arr(&[1.0, -2.0, 3.0]);
2440            let mut out = arr(&[0.0; 3]);
2441            negative_into(&a, &mut out).unwrap();
2442            assert_eq!(out.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
2443        }
2444
2445        #[test]
2446        fn into_variants_are_chainable_no_alloc() {
2447            // A realistic pattern: apply a pipeline in-place over and over
2448            // without touching the allocator after initial setup.
2449            let mut state = arr(&[1.0, 2.0, 3.0, 4.0]);
2450            let ones = arr(&[1.0; 4]);
2451            let mut scratch = arr(&[0.0; 4]);
2452            for _ in 0..100 {
2453                add_into(&state, &ones, &mut scratch).unwrap();
2454                std::mem::swap(&mut state, &mut scratch);
2455            }
2456            // After 100 increments of 1: [101, 102, 103, 104]
2457            assert_eq!(state.as_slice().unwrap(), &[101.0, 102.0, 103.0, 104.0]);
2458        }
2459
2460        #[test]
2461        fn exp_into_matches_exp() {
2462            use crate::ops::explog::{exp, exp_into};
2463            let a = arr(&[0.0, 1.0, 2.0]);
2464            let expected = exp(&a).unwrap();
2465            let mut out = arr(&[0.0; 3]);
2466            exp_into(&a, &mut out).unwrap();
2467            for (&x, &y) in expected
2468                .as_slice()
2469                .unwrap()
2470                .iter()
2471                .zip(out.as_slice().unwrap().iter())
2472            {
2473                assert!((x - y).abs() < 1e-14);
2474            }
2475        }
2476
2477        #[test]
2478        fn sin_into_matches_sin() {
2479            use crate::ops::trig::{sin, sin_into};
2480            let a = arr(&[0.0, std::f64::consts::FRAC_PI_2, std::f64::consts::PI]);
2481            let expected = sin(&a).unwrap();
2482            let mut out = arr(&[0.0; 3]);
2483            sin_into(&a, &mut out).unwrap();
2484            for (&x, &y) in expected
2485                .as_slice()
2486                .unwrap()
2487                .iter()
2488                .zip(out.as_slice().unwrap().iter())
2489            {
2490                assert!((x - y).abs() < 1e-14);
2491            }
2492        }
2493
2494        #[test]
2495        fn cos_into_matches_cos() {
2496            use crate::ops::trig::{cos, cos_into};
2497            let a = arr(&[0.0, std::f64::consts::FRAC_PI_2, std::f64::consts::PI]);
2498            let expected = cos(&a).unwrap();
2499            let mut out = arr(&[0.0; 3]);
2500            cos_into(&a, &mut out).unwrap();
2501            for (&x, &y) in expected
2502                .as_slice()
2503                .unwrap()
2504                .iter()
2505                .zip(out.as_slice().unwrap().iter())
2506            {
2507                assert!((x - y).abs() < 1e-14);
2508            }
2509        }
2510
2511        #[test]
2512        fn log_into_matches_log() {
2513            use crate::ops::explog::{log, log_into};
2514            let a = arr(&[1.0, std::f64::consts::E, 10.0]);
2515            let expected = log(&a).unwrap();
2516            let mut out = arr(&[0.0; 3]);
2517            log_into(&a, &mut out).unwrap();
2518            for (&x, &y) in expected
2519                .as_slice()
2520                .unwrap()
2521                .iter()
2522                .zip(out.as_slice().unwrap().iter())
2523            {
2524                assert!((x - y).abs() < 1e-14);
2525            }
2526        }
2527    }
2528}