Skip to main content

ferray_core/
ops.rs

1// ferray-core: Operator overloading for Array<T, D>
2//
3// Implements std::ops::{Add, Sub, Mul, Div, Rem, Neg} with
4// Output = FerrayResult<Array<T, D>>.
5//
6// Users write `(a + b)?` to get the result, maintaining the zero-panic
7// guarantee while enabling natural math syntax.
8//
9// Broadcasting: when both operands share the same dimension type `D`,
10// the operators broadcast along compatible axes (NumPy rules). Cross-rank
11// broadcasting (e.g. Ix1 + Ix2) is exposed via `add_broadcast` etc., which
12// return `Array<T, IxDyn>` because the result rank cannot be expressed in
13// the type system without specialization.
14//
15// See: https://github.com/dollspace-gay/ferray/issues/7
16// Broadcasting: https://github.com/dollspace-gay/ferray/issues/346
17
18use crate::array::owned::Array;
19use crate::dimension::Dimension;
20use crate::dimension::IxDyn;
21use crate::dimension::broadcast::{broadcast_shapes, broadcast_to};
22use crate::dtype::Element;
23use crate::error::{FerrayError, FerrayResult};
24
25/// Elementwise binary operation on two same-D arrays, with NumPy broadcasting.
26///
27/// - If shapes match exactly, takes the fast path (zip iter, no broadcast).
28/// - Otherwise, broadcasts both inputs to the common shape and applies `op`.
29///
30/// Both inputs share dimension type `D`, so the broadcast result also has
31/// rank `D::NDIM` (or, for `IxDyn`, the maximum of the two ranks). The result
32/// is reconstructed via [`Dimension::from_dim_slice`].
33///
34/// # Errors
35/// Returns `FerrayError::ShapeMismatch` if the shapes are not broadcast-compatible.
36fn elementwise_binary<T, D, F>(
37    a: &Array<T, D>,
38    b: &Array<T, D>,
39    op: F,
40    op_name: &str,
41) -> FerrayResult<Array<T, D>>
42where
43    T: Element + Copy,
44    D: Dimension,
45    F: Fn(T, T) -> T,
46{
47    // Fast path: identical shapes — no broadcasting needed.
48    if a.shape() == b.shape() {
49        let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
50        return Array::from_vec(a.dim().clone(), data);
51    }
52
53    // Broadcasting path.
54    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
55        FerrayError::shape_mismatch(format!(
56            "operator {}: shapes {:?} and {:?} are not broadcast-compatible",
57            op_name,
58            a.shape(),
59            b.shape()
60        ))
61    })?;
62
63    let a_view = broadcast_to(a, &target_shape)?;
64    let b_view = broadcast_to(b, &target_shape)?;
65
66    let data: Vec<T> = a_view
67        .iter()
68        .zip(b_view.iter())
69        .map(|(&x, &y)| op(x, y))
70        .collect();
71
72    let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
73        FerrayError::shape_mismatch(format!(
74            "operator {}: cannot represent broadcast result shape {:?} as the input dimension type",
75            op_name, target_shape
76        ))
77    })?;
78
79    Array::from_vec(result_dim, data)
80}
81
82/// Cross-rank broadcasting helper: apply a binary op to two arrays with
83/// possibly different dimension types, returning a dynamic-rank result.
84///
85/// This is the primitive behind [`Array::add_broadcast`], [`Array::sub_broadcast`],
86/// etc. Always returns `Array<T, IxDyn>` because the result rank depends
87/// on input shapes at runtime.
88fn elementwise_binary_dyn<T, D1, D2, F>(
89    a: &Array<T, D1>,
90    b: &Array<T, D2>,
91    op: F,
92    op_name: &str,
93) -> FerrayResult<Array<T, IxDyn>>
94where
95    T: Element + Copy,
96    D1: Dimension,
97    D2: Dimension,
98    F: Fn(T, T) -> T,
99{
100    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
101        FerrayError::shape_mismatch(format!(
102            "{}: shapes {:?} and {:?} are not broadcast-compatible",
103            op_name,
104            a.shape(),
105            b.shape()
106        ))
107    })?;
108
109    let a_view = broadcast_to(a, &target_shape)?;
110    let b_view = broadcast_to(b, &target_shape)?;
111
112    let data: Vec<T> = a_view
113        .iter()
114        .zip(b_view.iter())
115        .map(|(&x, &y)| op(x, y))
116        .collect();
117
118    Array::from_vec(IxDyn::from(&target_shape[..]), data)
119}
120
121/// Implement a binary operator for all ownership combinations of Array.
122///
123/// Generates impls for:
124///   &Array op &Array
125///   Array  op Array
126///   Array  op &Array
127///   &Array op Array
128macro_rules! impl_binary_op {
129    ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
130        // &Array op &Array
131        impl<T, D> std::ops::$trait<&Array<T, D>> for &Array<T, D>
132        where
133            T: Element + Copy + std::ops::$trait<Output = T>,
134            D: Dimension,
135        {
136            type Output = FerrayResult<Array<T, D>>;
137
138            fn $method(self, rhs: &Array<T, D>) -> Self::Output {
139                elementwise_binary(self, rhs, $op_fn, $op_name)
140            }
141        }
142
143        // Array op Array
144        impl<T, D> std::ops::$trait<Array<T, D>> for Array<T, D>
145        where
146            T: Element + Copy + std::ops::$trait<Output = T>,
147            D: Dimension,
148        {
149            type Output = FerrayResult<Array<T, D>>;
150
151            fn $method(self, rhs: Array<T, D>) -> Self::Output {
152                elementwise_binary(&self, &rhs, $op_fn, $op_name)
153            }
154        }
155
156        // Array op &Array
157        impl<T, D> std::ops::$trait<&Array<T, D>> for Array<T, D>
158        where
159            T: Element + Copy + std::ops::$trait<Output = T>,
160            D: Dimension,
161        {
162            type Output = FerrayResult<Array<T, D>>;
163
164            fn $method(self, rhs: &Array<T, D>) -> Self::Output {
165                elementwise_binary(&self, rhs, $op_fn, $op_name)
166            }
167        }
168
169        // &Array op Array
170        impl<T, D> std::ops::$trait<Array<T, D>> for &Array<T, D>
171        where
172            T: Element + Copy + std::ops::$trait<Output = T>,
173            D: Dimension,
174        {
175            type Output = FerrayResult<Array<T, D>>;
176
177            fn $method(self, rhs: Array<T, D>) -> Self::Output {
178                elementwise_binary(self, &rhs, $op_fn, $op_name)
179            }
180        }
181    };
182}
183
184impl_binary_op!(Add, add, |a, b| a + b, "+");
185impl_binary_op!(Sub, sub, |a, b| a - b, "-");
186impl_binary_op!(Mul, mul, |a, b| a * b, "*");
187impl_binary_op!(Div, div, |a, b| a / b, "/");
188impl_binary_op!(Rem, rem, |a, b| a % b, "%");
189
190// ---------------------------------------------------------------------------
// Scalar-array operations: Array op scalar. (A mirrored `scalar op Array`
// impl is not provided: Rust rejects a generic `impl Add<Array<T, D>> for T`
// because the foreign trait leaves the type parameter `T` uncovered.)
192// ---------------------------------------------------------------------------
193
194/// Implement scalar-array binary operators (Array op T and &Array op T).
195macro_rules! impl_scalar_op {
196    ($trait:ident, $method:ident, $op_fn:expr) => {
197        // &Array op scalar
198        impl<T, D> std::ops::$trait<T> for &Array<T, D>
199        where
200            T: Element + Copy + std::ops::$trait<Output = T>,
201            D: Dimension,
202        {
203            type Output = FerrayResult<Array<T, D>>;
204
205            fn $method(self, rhs: T) -> Self::Output {
206                let data: Vec<T> = self.iter().map(|&x| $op_fn(x, rhs)).collect();
207                Array::from_vec(self.dim().clone(), data)
208            }
209        }
210
211        // Array op scalar
212        impl<T, D> std::ops::$trait<T> for Array<T, D>
213        where
214            T: Element + Copy + std::ops::$trait<Output = T>,
215            D: Dimension,
216        {
217            type Output = FerrayResult<Array<T, D>>;
218
219            fn $method(self, rhs: T) -> Self::Output {
220                (&self).$method(rhs)
221            }
222        }
223    };
224}
225
226impl_scalar_op!(Add, add, |a, b| a + b);
227impl_scalar_op!(Sub, sub, |a, b| a - b);
228impl_scalar_op!(Mul, mul, |a, b| a * b);
229impl_scalar_op!(Div, div, |a, b| a / b);
230impl_scalar_op!(Rem, rem, |a, b| a % b);
231
232// Unary negation: -&Array and -Array
233impl<T, D> std::ops::Neg for &Array<T, D>
234where
235    T: Element + Copy + std::ops::Neg<Output = T>,
236    D: Dimension,
237{
238    type Output = FerrayResult<Array<T, D>>;
239
240    fn neg(self) -> Self::Output {
241        let data: Vec<T> = self.iter().map(|&x| -x).collect();
242        Array::from_vec(self.dim().clone(), data)
243    }
244}
245
246impl<T, D> std::ops::Neg for Array<T, D>
247where
248    T: Element + Copy + std::ops::Neg<Output = T>,
249    D: Dimension,
250{
251    type Output = FerrayResult<Array<T, D>>;
252
253    fn neg(self) -> Self::Output {
254        -&self
255    }
256}
257
258// ---------------------------------------------------------------------------
259// Cross-rank broadcasting methods
260//
261// These methods accept an operand of any dimension type and return a
262// dynamic-rank result. They handle the case where the std::ops operators
263// can't apply because the type system requires both operands to share `D`.
264//
265// Example: a (Ix1, shape (3,)) + b (Ix2, shape (2,1)) -> Array<T, IxDyn> shape (2,3)
266// ---------------------------------------------------------------------------
267
268impl<T, D> Array<T, D>
269where
270    T: Element + Copy,
271    D: Dimension,
272{
273    /// Elementwise add with NumPy broadcasting across arbitrary ranks.
274    ///
275    /// Returns a dynamic-rank `Array<T, IxDyn>` so that mixed-rank inputs
276    /// (e.g. 1D + 2D) can produce a result whose rank is determined at
277    /// runtime. For same-rank inputs prefer the `+` operator.
278    ///
279    /// # Errors
280    /// Returns `FerrayError::ShapeMismatch` if shapes are not broadcast-compatible.
281    pub fn add_broadcast<D2: Dimension>(
282        &self,
283        other: &Array<T, D2>,
284    ) -> FerrayResult<Array<T, IxDyn>>
285    where
286        T: std::ops::Add<Output = T>,
287    {
288        elementwise_binary_dyn(self, other, |x, y| x + y, "add_broadcast")
289    }
290
291    /// Elementwise subtract with NumPy broadcasting across arbitrary ranks.
292    ///
293    /// See [`Array::add_broadcast`] for details.
294    pub fn sub_broadcast<D2: Dimension>(
295        &self,
296        other: &Array<T, D2>,
297    ) -> FerrayResult<Array<T, IxDyn>>
298    where
299        T: std::ops::Sub<Output = T>,
300    {
301        elementwise_binary_dyn(self, other, |x, y| x - y, "sub_broadcast")
302    }
303
304    /// Elementwise multiply with NumPy broadcasting across arbitrary ranks.
305    ///
306    /// See [`Array::add_broadcast`] for details.
307    pub fn mul_broadcast<D2: Dimension>(
308        &self,
309        other: &Array<T, D2>,
310    ) -> FerrayResult<Array<T, IxDyn>>
311    where
312        T: std::ops::Mul<Output = T>,
313    {
314        elementwise_binary_dyn(self, other, |x, y| x * y, "mul_broadcast")
315    }
316
317    /// Elementwise divide with NumPy broadcasting across arbitrary ranks.
318    ///
319    /// See [`Array::add_broadcast`] for details.
320    pub fn div_broadcast<D2: Dimension>(
321        &self,
322        other: &Array<T, D2>,
323    ) -> FerrayResult<Array<T, IxDyn>>
324    where
325        T: std::ops::Div<Output = T>,
326    {
327        elementwise_binary_dyn(self, other, |x, y| x / y, "div_broadcast")
328    }
329
330    /// Elementwise remainder with NumPy broadcasting across arbitrary ranks.
331    ///
332    /// See [`Array::add_broadcast`] for details.
333    pub fn rem_broadcast<D2: Dimension>(
334        &self,
335        other: &Array<T, D2>,
336    ) -> FerrayResult<Array<T, IxDyn>>
337    where
338        T: std::ops::Rem<Output = T>,
339    {
340        elementwise_binary_dyn(self, other, |x, y| x % y, "rem_broadcast")
341    }
342}
343
344// ---------------------------------------------------------------------------
345// In-place operators (#348)
346//
347// NumPy supports `arr += 5`, `arr *= other` — mutation without allocation.
348// We split these into two groups:
349//
350// 1. Scalar in-place: `std::ops::*Assign<T> for Array<T, D>`. Always safe
351//    (scalar ops never fail), implemented via `mapv_inplace`.
352//
353// 2. Array-array in-place: inherent fallible methods `add_inplace`,
354//    `sub_inplace`, `mul_inplace`, `div_inplace`, `rem_inplace` returning
355//    `FerrayResult<()>`. We cannot implement `std::ops::AddAssign<&Array>`
356//    because the trait signature returns `()`, leaving no channel for
357//    shape-mismatch errors — incompatible with ferray's zero-panic rule.
358//    Use the `*_inplace` methods instead of `arr += other`.
359//
360//    The RHS is broadcast to `self.shape()` (NumPy semantics for in-place
361//    arithmetic: the destination shape is fixed and the source must be
362//    broadcastable to it).
363// ---------------------------------------------------------------------------
364
365/// Implement a scalar `*Assign` operator trait using `mapv_inplace`.
366macro_rules! impl_scalar_op_assign {
367    ($trait:ident, $method:ident, $op:tt) => {
368        impl<T, D> std::ops::$trait<T> for Array<T, D>
369        where
370            T: Element + Copy + std::ops::$trait,
371            D: Dimension,
372        {
373            fn $method(&mut self, rhs: T) {
374                self.mapv_inplace(|mut x| {
375                    x $op rhs;
376                    x
377                });
378            }
379        }
380    };
381}
382
383impl_scalar_op_assign!(AddAssign, add_assign, +=);
384impl_scalar_op_assign!(SubAssign, sub_assign, -=);
385impl_scalar_op_assign!(MulAssign, mul_assign, *=);
386impl_scalar_op_assign!(DivAssign, div_assign, /=);
387impl_scalar_op_assign!(RemAssign, rem_assign, %=);
388
389/// Shared implementation for in-place array-array ops with broadcasting.
390///
391/// Fast path on identical shapes uses `zip_mut_with`. Otherwise the RHS is
392/// broadcast into `self.shape()` via [`broadcast_to`] and then zipped with
393/// `self.iter_mut()` in logical order.
394fn inplace_binary<T, D, F>(
395    lhs: &mut Array<T, D>,
396    rhs: &Array<T, D>,
397    op: F,
398    op_name: &str,
399) -> FerrayResult<()>
400where
401    T: Element + Copy,
402    D: Dimension,
403    F: Fn(T, T) -> T,
404{
405    // Fast path: identical shapes — delegate to zip_mut_with.
406    if lhs.shape() == rhs.shape() {
407        return lhs.zip_mut_with(rhs, |a, b| *a = op(*a, *b));
408    }
409
410    // Broadcasting path: rhs must broadcast into lhs.shape() (the destination
411    // shape is fixed for in-place operations). Any shape change on the LHS
412    // would require a reallocation, defeating the purpose of `*_inplace`.
413    let target_shape: Vec<usize> = lhs.shape().to_vec();
414    let rhs_view = broadcast_to(rhs, &target_shape).map_err(|_| {
415        FerrayError::shape_mismatch(format!(
416            "{}: shape {:?} cannot be broadcast into destination shape {:?}",
417            op_name,
418            rhs.shape(),
419            target_shape
420        ))
421    })?;
422
423    for (a, b) in lhs.iter_mut().zip(rhs_view.iter()) {
424        *a = op(*a, *b);
425    }
426    Ok(())
427}
428
429impl<T, D> Array<T, D>
430where
431    T: Element + Copy,
432    D: Dimension,
433{
434    /// In-place elementwise add: `self[i] += other[i]`. `other` is broadcast
435    /// into `self.shape()`; the destination shape never changes.
436    ///
437    /// Prefer this over `self = (&self + &other)?` for large arrays — it
438    /// avoids allocating a new result buffer.
439    ///
440    /// # Errors
441    /// Returns `FerrayError::ShapeMismatch` if `other.shape()` cannot be
442    /// broadcast into `self.shape()`.
443    pub fn add_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
444    where
445        T: std::ops::Add<Output = T>,
446    {
447        inplace_binary(self, other, |a, b| a + b, "add_inplace")
448    }
449
450    /// In-place elementwise subtract. See [`Array::add_inplace`].
451    pub fn sub_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
452    where
453        T: std::ops::Sub<Output = T>,
454    {
455        inplace_binary(self, other, |a, b| a - b, "sub_inplace")
456    }
457
458    /// In-place elementwise multiply. See [`Array::add_inplace`].
459    pub fn mul_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
460    where
461        T: std::ops::Mul<Output = T>,
462    {
463        inplace_binary(self, other, |a, b| a * b, "mul_inplace")
464    }
465
466    /// In-place elementwise divide. See [`Array::add_inplace`].
467    pub fn div_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
468    where
469        T: std::ops::Div<Output = T>,
470    {
471        inplace_binary(self, other, |a, b| a / b, "div_inplace")
472    }
473
474    /// In-place elementwise remainder. See [`Array::add_inplace`].
475    pub fn rem_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
476    where
477        T: std::ops::Rem<Output = T>,
478    {
479        inplace_binary(self, other, |a, b| a % b, "rem_inplace")
480    }
481}
482
483// ---------------------------------------------------------------------------
484// copyto: broadcasted elementwise assignment (#352)
485//
486// NumPy parity for `np.copyto(dst, src)`: copies values from `src` into `dst`,
487// broadcasting `src` into `dst.shape()`. The destination shape is fixed — a
488// copyto never reallocates `dst`. `casting=` is not modeled because ferray's
489// type system already enforces matching element types; use `astype()` before
490// copyto if conversion is desired.
491//
492// Kept separate from the in-place arithmetic ops because copyto is the
493// "fundamental building block" called out in the issue: it has to handle
494// cross-rank src/dst pairs (e.g. `Ix1` into `Ix2` via broadcasting), whereas
495// `*_inplace` stay within one dimension type `D` to match std::ops::*Assign.
496// ---------------------------------------------------------------------------
497
498/// Copy values from `src` into `dst`, broadcasting `src` to `dst.shape()`.
499///
500/// Equivalent to `np.copyto(dst, src)` without the `casting=` or `where=`
501/// parameters. `src` may have any dimension type — it only needs to be
502/// broadcast-compatible with `dst.shape()`. The destination shape is fixed,
503/// so `src` can never grow the destination.
504///
505/// # Errors
506/// Returns `FerrayError::ShapeMismatch` if `src.shape()` cannot be broadcast
507/// into `dst.shape()`. On error `dst` is left untouched.
508pub fn copyto<T, D1, D2>(dst: &mut Array<T, D1>, src: &Array<T, D2>) -> FerrayResult<()>
509where
510    T: Element,
511    D1: Dimension,
512    D2: Dimension,
513{
514    // Fast path: identical shapes — straight iter-zip, no broadcast machinery.
515    if dst.shape() == src.shape() {
516        for (d, s) in dst.iter_mut().zip(src.iter()) {
517            *d = s.clone();
518        }
519        return Ok(());
520    }
521
522    // Broadcasting path: src must broadcast into dst.shape(). We validate up
523    // front so that a shape error leaves dst completely untouched.
524    let target_shape: Vec<usize> = dst.shape().to_vec();
525    let src_view = broadcast_to(src, &target_shape).map_err(|_| {
526        FerrayError::shape_mismatch(format!(
527            "copyto: source shape {:?} cannot be broadcast into destination shape {:?}",
528            src.shape(),
529            target_shape
530        ))
531    })?;
532
533    for (d, s) in dst.iter_mut().zip(src_view.iter()) {
534        *d = s.clone();
535    }
536    Ok(())
537}
538
539/// Copy values from `src` into `dst` where `mask` is `true`, broadcasting
540/// both `src` and `mask` to `dst.shape()`.
541///
542/// Equivalent to `np.copyto(dst, src, where=mask)`. This is the generalized
543/// form of NumPy's `where=` ufunc parameter (#353): positions where the mask
544/// is `false` are left untouched in `dst`. All three shapes must be
545/// broadcast-compatible with `dst.shape()`; the destination shape is fixed,
546/// so neither `src` nor `mask` can grow `dst`.
547///
548/// # Errors
549/// Returns `FerrayError::ShapeMismatch` if either `src.shape()` or
550/// `mask.shape()` cannot be broadcast into `dst.shape()`. On error `dst` is
551/// left untouched (broadcast validation happens before any writes).
552pub fn copyto_where<T, D1, D2, D3>(
553    dst: &mut Array<T, D1>,
554    src: &Array<T, D2>,
555    mask: &Array<bool, D3>,
556) -> FerrayResult<()>
557where
558    T: Element,
559    D1: Dimension,
560    D2: Dimension,
561    D3: Dimension,
562{
563    // Validate + materialize broadcast views BEFORE writing, so that a shape
564    // mismatch can't leave dst in a partially-updated state.
565    let target_shape: Vec<usize> = dst.shape().to_vec();
566
567    let src_view = broadcast_to(src, &target_shape).map_err(|_| {
568        FerrayError::shape_mismatch(format!(
569            "copyto_where: source shape {:?} cannot be broadcast into destination shape {:?}",
570            src.shape(),
571            target_shape
572        ))
573    })?;
574
575    let mask_view = broadcast_to(mask, &target_shape).map_err(|_| {
576        FerrayError::shape_mismatch(format!(
577            "copyto_where: mask shape {:?} cannot be broadcast into destination shape {:?}",
578            mask.shape(),
579            target_shape
580        ))
581    })?;
582
583    for ((d, s), &m) in dst.iter_mut().zip(src_view.iter()).zip(mask_view.iter()) {
584        if m {
585            *d = s.clone();
586        }
587    }
588    Ok(())
589}
590
591impl<T, D> Array<T, D>
592where
593    T: Element,
594    D: Dimension,
595{
596    /// Copy values from `src` into `self`, broadcasting `src` to `self.shape()`.
597    ///
598    /// See [`copyto`] for the free-function form and full semantics.
599    ///
600    /// # Errors
601    /// Returns `FerrayError::ShapeMismatch` if `src` cannot be broadcast into
602    /// `self.shape()`. On error `self` is left untouched.
603    pub fn copy_from<D2: Dimension>(&mut self, src: &Array<T, D2>) -> FerrayResult<()> {
604        copyto(self, src)
605    }
606
607    /// Copy values from `src` into `self` at positions where `mask` is `true`,
608    /// broadcasting both inputs to `self.shape()`.
609    ///
610    /// See [`copyto_where`] for the free-function form and full semantics.
611    ///
612    /// # Errors
613    /// Returns `FerrayError::ShapeMismatch` if `src` or `mask` cannot be
614    /// broadcast into `self.shape()`. On error `self` is left untouched.
615    pub fn copy_from_where<D2: Dimension, D3: Dimension>(
616        &mut self,
617        src: &Array<T, D2>,
618        mask: &Array<bool, D3>,
619    ) -> FerrayResult<()> {
620        copyto_where(self, src, mask)
621    }
622}
623
624#[cfg(test)]
625mod tests {
626    use super::*;
627    use crate::dimension::Ix1;
628
629    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
630        let n = data.len();
631        Array::from_vec(Ix1::new([n]), data).unwrap()
632    }
633
634    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
635        let n = data.len();
636        Array::from_vec(Ix1::new([n]), data).unwrap()
637    }
638
639    #[test]
640    fn test_add_ref_ref() {
641        let a = arr(vec![1.0, 2.0, 3.0]);
642        let b = arr(vec![4.0, 5.0, 6.0]);
643        let c = (&a + &b).unwrap();
644        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
645    }
646
647    #[test]
648    fn test_add_owned_owned() {
649        let a = arr(vec![1.0, 2.0]);
650        let b = arr(vec![3.0, 4.0]);
651        let c = (a + b).unwrap();
652        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
653    }
654
655    #[test]
656    fn test_add_mixed() {
657        let a = arr(vec![1.0, 2.0]);
658        let b = arr(vec![3.0, 4.0]);
659        let c = (a + &b).unwrap();
660        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
661
662        let d = arr(vec![10.0, 20.0]);
663        let e = (&b + d).unwrap();
664        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
665    }
666
667    #[test]
668    fn test_sub() {
669        let a = arr(vec![5.0, 7.0]);
670        let b = arr(vec![1.0, 2.0]);
671        let c = (&a - &b).unwrap();
672        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
673    }
674
675    #[test]
676    fn test_mul() {
677        let a = arr(vec![2.0, 3.0]);
678        let b = arr(vec![4.0, 5.0]);
679        let c = (&a * &b).unwrap();
680        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
681    }
682
683    #[test]
684    fn test_div() {
685        let a = arr(vec![10.0, 20.0]);
686        let b = arr(vec![2.0, 5.0]);
687        let c = (&a / &b).unwrap();
688        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
689    }
690
691    #[test]
692    fn test_rem() {
693        let a = arr_i32(vec![7, 10]);
694        let b = arr_i32(vec![3, 4]);
695        let c = (&a % &b).unwrap();
696        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
697    }
698
699    #[test]
700    fn test_neg() {
701        let a = arr(vec![1.0, -2.0, 3.0]);
702        let b = (-&a).unwrap();
703        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
704    }
705
706    #[test]
707    fn test_neg_owned() {
708        let a = arr(vec![1.0, -2.0]);
709        let b = (-a).unwrap();
710        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
711    }
712
713    #[test]
714    fn test_shape_mismatch_errors() {
715        let a = arr(vec![1.0, 2.0]);
716        let b = arr(vec![1.0, 2.0, 3.0]);
717        let result = &a + &b;
718        assert!(result.is_err());
719    }
720
721    // --- Scalar-array operations ---
722
723    #[test]
724    fn test_add_scalar() {
725        let a = arr(vec![1.0, 2.0, 3.0]);
726        let c = (&a + 10.0).unwrap();
727        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
728    }
729
730    #[test]
731    fn test_sub_scalar() {
732        let a = arr(vec![10.0, 20.0, 30.0]);
733        let c = (&a - 5.0).unwrap();
734        assert_eq!(c.as_slice().unwrap(), &[5.0, 15.0, 25.0]);
735    }
736
737    #[test]
738    fn test_mul_scalar() {
739        let a = arr(vec![1.0, 2.0, 3.0]);
740        let c = (&a * 3.0).unwrap();
741        assert_eq!(c.as_slice().unwrap(), &[3.0, 6.0, 9.0]);
742    }
743
744    #[test]
745    fn test_div_scalar() {
746        let a = arr(vec![10.0, 20.0, 30.0]);
747        let c = (&a / 10.0).unwrap();
748        assert_eq!(c.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
749    }
750
751    #[test]
752    fn test_rem_scalar() {
753        let a = arr_i32(vec![7, 10, 15]);
754        let c = (&a % 4).unwrap();
755        assert_eq!(c.as_slice().unwrap(), &[3, 2, 3]);
756    }
757
758    #[test]
759    fn test_scalar_op_owned() {
760        let a = arr(vec![1.0, 2.0, 3.0]);
761        let c = (a + 10.0).unwrap();
762        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
763    }
764
765    #[test]
766    fn test_chained_ops() {
767        let a = arr(vec![1.0, 2.0, 3.0]);
768        let b = arr(vec![4.0, 5.0, 6.0]);
769        let c = arr(vec![10.0, 10.0, 10.0]);
770        // (a + b)? * c)?
771        let result = (&(&a + &b).unwrap() * &c).unwrap();
772        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
773    }
774
775    // -----------------------------------------------------------------------
776    // Broadcasting tests (issue #346)
777    // -----------------------------------------------------------------------
778
779    use crate::dimension::{Ix2, Ix3, IxDyn};
780
781    #[test]
782    fn test_broadcast_2d_row_plus_column() {
783        // (3, 1) + (1, 4) -> (3, 4) — both Ix2
784        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
785        let row =
786            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
787        let result = (&col + &row).unwrap();
788        assert_eq!(result.shape(), &[3, 4]);
789        assert_eq!(
790            result.as_slice().unwrap(),
791            &[
792                11.0, 21.0, 31.0, 41.0, // row 1
793                12.0, 22.0, 32.0, 42.0, // row 2
794                13.0, 23.0, 33.0, 43.0, // row 3
795            ]
796        );
797    }
798
799    #[test]
800    fn test_broadcast_2d_stretch_one_axis() {
801        // (3, 4) + (1, 4) -> (3, 4)
802        let a = Array::<f64, Ix2>::from_vec(
803            Ix2::new([3, 4]),
804            vec![
805                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
806            ],
807        )
808        .unwrap();
809        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![100.0, 200.0, 300.0, 400.0])
810            .unwrap();
811        let result = (&a + &b).unwrap();
812        assert_eq!(result.shape(), &[3, 4]);
813        assert_eq!(
814            result.as_slice().unwrap(),
815            &[
816                101.0, 202.0, 303.0, 404.0, 105.0, 206.0, 307.0, 408.0, 109.0, 210.0, 311.0, 412.0,
817            ]
818        );
819    }
820
821    #[test]
822    fn test_broadcast_3d_with_2d_axis() {
823        // (2, 3, 4) - (1, 3, 4) -> (2, 3, 4) — both Ix3
824        let a =
825            Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), (1..=24).map(|i| i as f64).collect())
826                .unwrap();
827        let b =
828            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (1..=12).map(|i| i as f64).collect())
829                .unwrap();
830        let result = (&a - &b).unwrap();
831        assert_eq!(result.shape(), &[2, 3, 4]);
832        // First 12 elements: a[0..12] - b
833        let first_half: Vec<f64> = (1..=12).map(|_| 0.0).collect();
834        assert_eq!(&result.as_slice().unwrap()[..12], &first_half[..]);
835        // Second 12 elements: a[12..24] - b == 12 each
836        let second_half: Vec<f64> = (0..12).map(|_| 12.0).collect();
837        assert_eq!(&result.as_slice().unwrap()[12..], &second_half[..]);
838    }
839
840    #[test]
841    fn test_broadcast_incompatible_shapes_error() {
842        // (3,) + (4,) — incompatible
843        let a = arr(vec![1.0, 2.0, 3.0]);
844        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
845        let result = &a + &b;
846        assert!(result.is_err());
847
848        // (3, 4) + (3, 5) — incompatible
849        let c = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![0.0; 12]).unwrap();
850        let d = Array::<f64, Ix2>::from_vec(Ix2::new([3, 5]), vec![0.0; 15]).unwrap();
851        assert!((&c + &d).is_err());
852    }
853
854    #[test]
855    fn test_broadcast_mul_2d() {
856        // (3, 1) * (1, 3) -> (3, 3) outer-product style
857        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
858        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
859        let result = (&col * &row).unwrap();
860        assert_eq!(result.shape(), &[3, 3]);
861        assert_eq!(
862            result.as_slice().unwrap(),
863            &[10, 20, 30, 20, 40, 60, 30, 60, 90]
864        );
865    }
866
867    // -----------------------------------------------------------------------
868    // Cross-rank broadcasting via `add_broadcast` etc.
869    // -----------------------------------------------------------------------
870
871    #[test]
872    fn test_add_broadcast_1d_plus_2d() {
873        // (3,) + (2, 3) -> (2, 3)
874        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
875        let m =
876            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
877                .unwrap();
878        let result = v.add_broadcast(&m).unwrap();
879        assert_eq!(result.shape(), &[2, 3]);
880        assert_eq!(
881            result.as_slice().unwrap(),
882            &[11.0, 22.0, 33.0, 41.0, 52.0, 63.0]
883        );
884    }
885
886    #[test]
887    fn test_add_broadcast_1d_plus_column() {
888        // (3,) + (2, 1) -> (2, 3) — the canonical NumPy example from issue #346
889        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
890        let col = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
891        let result = v.add_broadcast(&col).unwrap();
892        assert_eq!(result.shape(), &[2, 3]);
893        assert_eq!(
894            result.as_slice().unwrap(),
895            &[11.0, 12.0, 13.0, 21.0, 22.0, 23.0]
896        );
897    }
898
899    #[test]
900    fn test_sub_broadcast_2d_minus_1d() {
901        // (2, 3) - (3,) -> (2, 3)
902        let m =
903            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
904                .unwrap();
905        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
906        let result = m.sub_broadcast(&v).unwrap();
907        assert_eq!(result.shape(), &[2, 3]);
908        assert_eq!(
909            result.as_slice().unwrap(),
910            &[9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
911        );
912    }
913
914    #[test]
915    fn test_mul_broadcast_returns_dyn() {
916        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
917        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
918        let result: Array<f64, IxDyn> = a.mul_broadcast(&b).unwrap();
919        assert_eq!(result.shape(), &[2, 3]);
920        assert_eq!(
921            result.as_slice().unwrap(),
922            &[10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
923        );
924    }
925
926    #[test]
927    fn test_div_broadcast_incompatible() {
928        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
929        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
930        assert!(a.div_broadcast(&b).is_err());
931    }
932
933    #[test]
934    fn test_rem_broadcast_2d() {
935        let a =
936            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
937        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![3, 7, 11]).unwrap();
938        let result = a.rem_broadcast(&b).unwrap();
939        assert_eq!(result.shape(), &[2, 3]);
940        assert_eq!(
941            result.as_slice().unwrap(),
942            &[10 % 3, 20 % 7, 30 % 11, 40 % 3, 50 % 7, 60 % 11]
943        );
944    }
945
946    // ------------------------------------------------------------------
947    // #348: in-place operators
948    // ------------------------------------------------------------------
949
950    #[test]
951    fn scalar_add_assign_mutates_in_place() {
952        let mut a = arr(vec![1.0, 2.0, 3.0]);
953        a += 10.0;
954        assert_eq!(a.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
955    }
956
957    #[test]
958    fn scalar_sub_mul_div_rem_assign() {
959        let mut a = arr(vec![10.0, 20.0, 30.0]);
960        a -= 1.0;
961        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
962        a *= 2.0;
963        assert_eq!(a.as_slice().unwrap(), &[18.0, 38.0, 58.0]);
964        a /= 2.0;
965        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
966        let mut b = arr_i32(vec![10, 11, 12]);
967        b %= 3;
968        assert_eq!(b.as_slice().unwrap(), &[1, 2, 0]);
969    }
970
971    #[test]
972    fn scalar_assign_preserves_shape_ix2() {
973        let mut a =
974            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
975                .unwrap();
976        a += 1.0;
977        assert_eq!(a.shape(), &[2, 3]);
978        assert_eq!(a.as_slice().unwrap(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
979    }
980
981    #[test]
982    fn add_inplace_same_shape_fast_path() {
983        let mut a = arr(vec![1.0, 2.0, 3.0]);
984        let b = arr(vec![10.0, 20.0, 30.0]);
985        a.add_inplace(&b).unwrap();
986        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
987    }
988
989    #[test]
990    fn sub_mul_div_rem_inplace_same_shape() {
991        let mut a = arr(vec![10.0, 20.0, 30.0]);
992        let b = arr(vec![1.0, 2.0, 3.0]);
993        a.sub_inplace(&b).unwrap();
994        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
995        a.mul_inplace(&b).unwrap();
996        assert_eq!(a.as_slice().unwrap(), &[9.0, 36.0, 81.0]);
997        a.div_inplace(&b).unwrap();
998        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
999        let mut c = arr_i32(vec![10, 20, 30]);
1000        let d = arr_i32(vec![3, 7, 11]);
1001        c.rem_inplace(&d).unwrap();
1002        assert_eq!(c.as_slice().unwrap(), &[1, 6, 8]);
1003    }
1004
1005    #[test]
1006    fn add_inplace_broadcasts_rhs_into_lhs_shape() {
1007        // (2, 3) += (1, 3) — RHS row broadcast across LHS rows.
1008        let mut a =
1009            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1010                .unwrap();
1011        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
1012        a.add_inplace(&b).unwrap();
1013        assert_eq!(a.shape(), &[2, 3]);
1014        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
1015    }
1016
1017    #[test]
1018    fn add_inplace_broadcasts_column_into_rows() {
1019        // (2, 3) += (2, 1) — RHS column broadcast across LHS columns.
1020        let mut a =
1021            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1022                .unwrap();
1023        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![100.0, 200.0]).unwrap();
1024        a.add_inplace(&b).unwrap();
1025        assert_eq!(
1026            a.as_slice().unwrap(),
1027            &[101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
1028        );
1029    }
1030
1031    #[test]
1032    fn add_inplace_rejects_incompatible_rhs() {
1033        // (3,) += (4,) — not broadcast-compatible; error must be Err, not panic.
1034        let mut a = arr(vec![1.0, 2.0, 3.0]);
1035        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
1036        assert!(a.add_inplace(&b).is_err());
1037        // LHS must be untouched on error.
1038        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1039    }
1040
1041    #[test]
1042    fn add_inplace_rejects_growing_shape() {
1043        // (1, 3) += (2, 3) — RHS bigger than LHS; the destination shape is
1044        // fixed for in-place operations, so this must error.
1045        let mut a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
1046        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0; 6]).unwrap();
1047        assert!(a.add_inplace(&b).is_err());
1048        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1049    }
1050
1051    #[test]
1052    fn copyto_same_shape_fast_path() {
1053        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1054        let src = arr(vec![1.0, 2.0, 3.0]);
1055        copyto(&mut dst, &src).unwrap();
1056        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1057    }
1058
1059    #[test]
1060    fn copyto_broadcasts_row_into_matrix() {
1061        // (2, 3) <= (1, 3)
1062        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1063        let src = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
1064        copyto(&mut dst, &src).unwrap();
1065        assert_eq!(
1066            dst.as_slice().unwrap(),
1067            &[10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
1068        );
1069    }
1070
1071    #[test]
1072    fn copyto_broadcasts_cross_rank_src() {
1073        // (2, 3) <= (3,)  — a lower-rank src broadcasts against a higher-rank dst.
1074        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1075        let src = arr(vec![7.0, 8.0, 9.0]);
1076        copyto(&mut dst, &src).unwrap();
1077        assert_eq!(dst.as_slice().unwrap(), &[7.0, 8.0, 9.0, 7.0, 8.0, 9.0]);
1078    }
1079
1080    #[test]
1081    fn copyto_scalar_src_broadcasts_to_full_dst() {
1082        // (2, 3) <= () via a length-1 1D stand-in.
1083        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1084        let src = arr(vec![42.0]);
1085        copyto(&mut dst, &src).unwrap();
1086        assert_eq!(dst.as_slice().unwrap(), &[42.0; 6]);
1087    }
1088
1089    #[test]
1090    fn copyto_rejects_growing_dst() {
1091        // (1, 3) <= (2, 3) — src wants to grow dst; must error, dst untouched.
1092        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
1093        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![99.0; 6]).unwrap();
1094        assert!(copyto(&mut dst, &src).is_err());
1095        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1096    }
1097
1098    #[test]
1099    fn copyto_rejects_incompatible_shapes() {
1100        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1101        let src = arr(vec![1.0, 2.0, 3.0, 4.0]);
1102        assert!(copyto(&mut dst, &src).is_err());
1103        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1104    }
1105
1106    #[test]
1107    fn copyto_method_form_equivalent_to_function() {
1108        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1109        let src = arr(vec![1.0, 2.0, 3.0]);
1110        dst.copy_from(&src).unwrap();
1111        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1112    }
1113
    #[test]
    fn copyto_works_for_non_copy_element_type_i64() {
        // NOTE(review): despite the test name and the original comment, `i64`
        // *does* implement `Copy` in Rust, so this does NOT exercise a
        // Clone-only element path. As written it only verifies that `copyto`
        // works for a second integer dtype. If the intent is to pin
        // "copyto must not require T: Copy", this should use an element type
        // that genuinely lacks `Copy` — TODO confirm whether `Element` has a
        // suitable implementor.
        let mut dst = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![0, 0, 0, 0]).unwrap();
        let src = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1, 2, 3, 4]);
    }
1123
1124    #[test]
1125    fn copyto_where_same_shape_only_writes_masked_positions() {
1126        let mut dst = arr(vec![1.0, 2.0, 3.0, 4.0]);
1127        let src = arr(vec![10.0, 20.0, 30.0, 40.0]);
1128        let mask =
1129            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
1130        copyto_where(&mut dst, &src, &mask).unwrap();
1131        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 4.0]);
1132    }
1133
1134    #[test]
1135    fn copyto_where_broadcasts_mask_across_dst() {
1136        // (2, 3) <= (2, 3), mask = (1, 3) broadcasts across rows.
1137        let mut dst =
1138            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1139                .unwrap();
1140        let src =
1141            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
1142                .unwrap();
1143        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
1144        copyto_where(&mut dst, &src, &mask).unwrap();
1145        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 40.0, 5.0, 60.0]);
1146    }
1147
1148    #[test]
1149    fn copyto_where_broadcasts_scalar_src_with_mask() {
1150        // Conditional scalar fill: set dst[i,j] = 99.0 wherever mask is true.
1151        let mut dst =
1152            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1153                .unwrap();
1154        let src = arr(vec![99.0]);
1155        let mask = Array::<bool, Ix2>::from_vec(
1156            Ix2::new([2, 3]),
1157            vec![true, false, true, false, true, false],
1158        )
1159        .unwrap();
1160        copyto_where(&mut dst, &src, &mask).unwrap();
1161        assert_eq!(dst.as_slice().unwrap(), &[99.0, 2.0, 99.0, 4.0, 99.0, 6.0]);
1162    }
1163
1164    #[test]
1165    fn copyto_where_all_false_mask_is_noop() {
1166        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1167        let original = dst.as_slice().unwrap().to_vec();
1168        let src = arr(vec![99.0, 99.0, 99.0]);
1169        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
1170        copyto_where(&mut dst, &src, &mask).unwrap();
1171        assert_eq!(dst.as_slice().unwrap(), &original[..]);
1172    }
1173
1174    #[test]
1175    fn copyto_where_all_true_mask_matches_copyto() {
1176        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1177        let src = arr(vec![1.0, 2.0, 3.0]);
1178        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1179        copyto_where(&mut dst, &src, &mask).unwrap();
1180        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1181    }
1182
1183    #[test]
1184    fn copyto_where_rejects_incompatible_src_shape() {
1185        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1186        let src = arr(vec![1.0, 2.0]);
1187        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1188        assert!(copyto_where(&mut dst, &src, &mask).is_err());
1189        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1190    }
1191
1192    #[test]
1193    fn copyto_where_rejects_incompatible_mask_shape() {
1194        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1195        let src = arr(vec![10.0, 20.0, 30.0]);
1196        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
1197        assert!(copyto_where(&mut dst, &src, &mask).is_err());
1198        // Validation happens before writes — dst untouched.
1199        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1200    }
1201
1202    #[test]
1203    fn copy_from_where_method_form_equivalent() {
1204        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1205        let src = arr(vec![10.0, 20.0, 30.0]);
1206        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
1207        dst.copy_from_where(&src, &mask).unwrap();
1208        assert_eq!(dst.as_slice().unwrap(), &[1.0, 20.0, 3.0]);
1209    }
1210
1211    #[test]
1212    fn div_inplace_by_zero_yields_ieee_sentinels() {
1213        // Pin current semantics: division-by-zero in-place produces IEEE
1214        // inf/NaN at the offending positions and does NOT error.
1215        let mut a = arr(vec![1.0, 2.0, 0.0]);
1216        let b = arr(vec![2.0, 0.0, 0.0]);
1217        a.div_inplace(&b).unwrap();
1218        let s = a.as_slice().unwrap();
1219        assert_eq!(s[0], 0.5);
1220        assert!(s[1].is_infinite() && s[1].is_sign_positive());
1221        assert!(s[2].is_nan());
1222    }
1223}