Skip to main content

ferray_core/
ops.rs

1// ferray-core: Operator overloading for Array<T, D>
2//
3// Implements std::ops::{Add, Sub, Mul, Div, Rem, Neg} with
4// Output = FerrayResult<Array<T, D>>.
5//
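// Users write `(a + b)?` to get the result: shape errors come back as `Err`
// instead of panicking, while keeping natural math syntax. Panics raised by
// the element operation itself (e.g. integer division by zero) still propagate.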
8//
9// Broadcasting: when both operands share the same dimension type `D`,
10// the operators broadcast along compatible axes (NumPy rules). Cross-rank
11// broadcasting (e.g. Ix1 + Ix2) is exposed via `add_broadcast` etc., which
12// return `Array<T, IxDyn>` because the result rank cannot be expressed in
13// the type system without specialization.
14//
15// See: https://github.com/dollspace-gay/ferray/issues/7
16// Broadcasting: https://github.com/dollspace-gay/ferray/issues/346
17
18use crate::array::owned::Array;
19use crate::dimension::Dimension;
20use crate::dimension::IxDyn;
21use crate::dimension::broadcast::{broadcast_shapes, broadcast_to};
22use crate::dtype::Element;
23use crate::error::{FerrayError, FerrayResult};
24
25/// Elementwise binary operation on two same-D arrays, with `NumPy` broadcasting.
26///
27/// - If shapes match exactly, takes the fast path (zip iter, no broadcast).
28/// - Otherwise, broadcasts both inputs to the common shape and applies `op`.
29///
30/// Both inputs share dimension type `D`, so the broadcast result also has
31/// rank `D::NDIM` (or, for `IxDyn`, the maximum of the two ranks). The result
32/// is reconstructed via [`Dimension::from_dim_slice`].
33///
34/// # Errors
35/// Returns `FerrayError::ShapeMismatch` if the shapes are not broadcast-compatible.
36fn elementwise_binary<T, D, F>(
37    a: &Array<T, D>,
38    b: &Array<T, D>,
39    op: F,
40    op_name: &str,
41) -> FerrayResult<Array<T, D>>
42where
43    T: Element + Copy,
44    D: Dimension,
45    F: Fn(T, T) -> T,
46{
47    // Fast path: identical shapes — no broadcasting needed.
48    if a.shape() == b.shape() {
49        let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
50        return Array::from_vec(a.dim().clone(), data);
51    }
52
53    // Broadcasting path.
54    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
55        FerrayError::shape_mismatch(format!(
56            "operator {}: shapes {:?} and {:?} are not broadcast-compatible",
57            op_name,
58            a.shape(),
59            b.shape()
60        ))
61    })?;
62
63    let a_view = broadcast_to(a, &target_shape)?;
64    let b_view = broadcast_to(b, &target_shape)?;
65
66    let data: Vec<T> = a_view
67        .iter()
68        .zip(b_view.iter())
69        .map(|(&x, &y)| op(x, y))
70        .collect();
71
72    let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
73        FerrayError::shape_mismatch(format!(
74            "operator {op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
75        ))
76    })?;
77
78    Array::from_vec(result_dim, data)
79}
80
81/// Cross-rank broadcasting helper: apply a binary op to two arrays with
82/// possibly different dimension types, returning a dynamic-rank result.
83///
84/// This is the primitive behind [`Array::add_broadcast`], [`Array::sub_broadcast`],
85/// etc. Always returns `Array<T, IxDyn>` because the result rank depends
86/// on input shapes at runtime.
87fn elementwise_binary_dyn<T, D1, D2, F>(
88    a: &Array<T, D1>,
89    b: &Array<T, D2>,
90    op: F,
91    op_name: &str,
92) -> FerrayResult<Array<T, IxDyn>>
93where
94    T: Element + Copy,
95    D1: Dimension,
96    D2: Dimension,
97    F: Fn(T, T) -> T,
98{
99    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
100        FerrayError::shape_mismatch(format!(
101            "{}: shapes {:?} and {:?} are not broadcast-compatible",
102            op_name,
103            a.shape(),
104            b.shape()
105        ))
106    })?;
107
108    let a_view = broadcast_to(a, &target_shape)?;
109    let b_view = broadcast_to(b, &target_shape)?;
110
111    let data: Vec<T> = a_view
112        .iter()
113        .zip(b_view.iter())
114        .map(|(&x, &y)| op(x, y))
115        .collect();
116
117    Array::from_vec(IxDyn::from(&target_shape[..]), data)
118}
119
120/// Implement a binary operator for all ownership combinations of Array.
121///
122/// Generates impls for:
123///   &Array op &Array
124///   Array  op Array
125///   Array  op &Array
126///   &Array op Array
127macro_rules! impl_binary_op {
128    ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
129        // &Array op &Array
130        impl<T, D> std::ops::$trait<&Array<T, D>> for &Array<T, D>
131        where
132            T: Element + Copy + std::ops::$trait<Output = T>,
133            D: Dimension,
134        {
135            type Output = FerrayResult<Array<T, D>>;
136
137            fn $method(self, rhs: &Array<T, D>) -> Self::Output {
138                elementwise_binary(self, rhs, $op_fn, $op_name)
139            }
140        }
141
142        // Array op Array
143        impl<T, D> std::ops::$trait<Array<T, D>> for Array<T, D>
144        where
145            T: Element + Copy + std::ops::$trait<Output = T>,
146            D: Dimension,
147        {
148            type Output = FerrayResult<Array<T, D>>;
149
150            fn $method(self, rhs: Array<T, D>) -> Self::Output {
151                elementwise_binary(&self, &rhs, $op_fn, $op_name)
152            }
153        }
154
155        // Array op &Array
156        impl<T, D> std::ops::$trait<&Array<T, D>> for Array<T, D>
157        where
158            T: Element + Copy + std::ops::$trait<Output = T>,
159            D: Dimension,
160        {
161            type Output = FerrayResult<Array<T, D>>;
162
163            fn $method(self, rhs: &Array<T, D>) -> Self::Output {
164                elementwise_binary(&self, rhs, $op_fn, $op_name)
165            }
166        }
167
168        // &Array op Array
169        impl<T, D> std::ops::$trait<Array<T, D>> for &Array<T, D>
170        where
171            T: Element + Copy + std::ops::$trait<Output = T>,
172            D: Dimension,
173        {
174            type Output = FerrayResult<Array<T, D>>;
175
176            fn $method(self, rhs: Array<T, D>) -> Self::Output {
177                elementwise_binary(self, &rhs, $op_fn, $op_name)
178            }
179        }
180    };
181}
182
183impl_binary_op!(Add, add, |a, b| a + b, "+");
184impl_binary_op!(Sub, sub, |a, b| a - b, "-");
185impl_binary_op!(Mul, mul, |a, b| a * b, "*");
186impl_binary_op!(Div, div, |a, b| a / b, "/");
187impl_binary_op!(Rem, rem, |a, b| a % b, "%");
188
189// ---------------------------------------------------------------------------
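// Scalar-array operations: `Array op scalar` and `&Array op scalar`. The
// mirrored `scalar op Array` form cannot be provided generically: an impl such
// as `impl<T, D> Add<Array<T, D>> for T` is rejected by Rust's orphan rules.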
191// ---------------------------------------------------------------------------
192
193/// Implement scalar-array binary operators (Array op T and &Array op T).
194macro_rules! impl_scalar_op {
195    ($trait:ident, $method:ident, $op_fn:expr) => {
196        // &Array op scalar
197        impl<T, D> std::ops::$trait<T> for &Array<T, D>
198        where
199            T: Element + Copy + std::ops::$trait<Output = T>,
200            D: Dimension,
201        {
202            type Output = FerrayResult<Array<T, D>>;
203
204            fn $method(self, rhs: T) -> Self::Output {
205                let data: Vec<T> = self.iter().map(|&x| $op_fn(x, rhs)).collect();
206                Array::from_vec(self.dim().clone(), data)
207            }
208        }
209
210        // Array op scalar
211        impl<T, D> std::ops::$trait<T> for Array<T, D>
212        where
213            T: Element + Copy + std::ops::$trait<Output = T>,
214            D: Dimension,
215        {
216            type Output = FerrayResult<Array<T, D>>;
217
218            fn $method(self, rhs: T) -> Self::Output {
219                (&self).$method(rhs)
220            }
221        }
222    };
223}
224
225impl_scalar_op!(Add, add, |a, b| a + b);
226impl_scalar_op!(Sub, sub, |a, b| a - b);
227impl_scalar_op!(Mul, mul, |a, b| a * b);
228impl_scalar_op!(Div, div, |a, b| a / b);
229impl_scalar_op!(Rem, rem, |a, b| a % b);
230
231// Unary negation: -&Array and -Array
232impl<T, D> std::ops::Neg for &Array<T, D>
233where
234    T: Element + Copy + std::ops::Neg<Output = T>,
235    D: Dimension,
236{
237    type Output = FerrayResult<Array<T, D>>;
238
239    fn neg(self) -> Self::Output {
240        let data: Vec<T> = self.iter().map(|&x| -x).collect();
241        Array::from_vec(self.dim().clone(), data)
242    }
243}
244
245impl<T, D> std::ops::Neg for Array<T, D>
246where
247    T: Element + Copy + std::ops::Neg<Output = T>,
248    D: Dimension,
249{
250    type Output = FerrayResult<Self>;
251
252    fn neg(self) -> Self::Output {
253        -&self
254    }
255}
256
257// ---------------------------------------------------------------------------
258// Cross-rank broadcasting methods
259//
260// These methods accept an operand of any dimension type and return a
261// dynamic-rank result. They handle the case where the std::ops operators
262// can't apply because the type system requires both operands to share `D`.
263//
264// Example: a (Ix1, shape (3,)) + b (Ix2, shape (2,1)) -> Array<T, IxDyn> shape (2,3)
265// ---------------------------------------------------------------------------
266
267impl<T, D> Array<T, D>
268where
269    T: Element + Copy,
270    D: Dimension,
271{
272    /// Elementwise add with `NumPy` broadcasting across arbitrary ranks.
273    ///
274    /// Returns a dynamic-rank `Array<T, IxDyn>` so that mixed-rank inputs
275    /// (e.g. 1D + 2D) can produce a result whose rank is determined at
276    /// runtime. For same-rank inputs prefer the `+` operator.
277    ///
278    /// # Errors
279    /// Returns `FerrayError::ShapeMismatch` if shapes are not broadcast-compatible.
280    pub fn add_broadcast<D2: Dimension>(
281        &self,
282        other: &Array<T, D2>,
283    ) -> FerrayResult<Array<T, IxDyn>>
284    where
285        T: std::ops::Add<Output = T>,
286    {
287        elementwise_binary_dyn(self, other, |x, y| x + y, "add_broadcast")
288    }
289
290    /// Elementwise subtract with `NumPy` broadcasting across arbitrary ranks.
291    ///
292    /// See [`Array::add_broadcast`] for details.
293    pub fn sub_broadcast<D2: Dimension>(
294        &self,
295        other: &Array<T, D2>,
296    ) -> FerrayResult<Array<T, IxDyn>>
297    where
298        T: std::ops::Sub<Output = T>,
299    {
300        elementwise_binary_dyn(self, other, |x, y| x - y, "sub_broadcast")
301    }
302
303    /// Elementwise multiply with `NumPy` broadcasting across arbitrary ranks.
304    ///
305    /// See [`Array::add_broadcast`] for details.
306    pub fn mul_broadcast<D2: Dimension>(
307        &self,
308        other: &Array<T, D2>,
309    ) -> FerrayResult<Array<T, IxDyn>>
310    where
311        T: std::ops::Mul<Output = T>,
312    {
313        elementwise_binary_dyn(self, other, |x, y| x * y, "mul_broadcast")
314    }
315
316    /// Elementwise divide with `NumPy` broadcasting across arbitrary ranks.
317    ///
318    /// See [`Array::add_broadcast`] for details.
319    pub fn div_broadcast<D2: Dimension>(
320        &self,
321        other: &Array<T, D2>,
322    ) -> FerrayResult<Array<T, IxDyn>>
323    where
324        T: std::ops::Div<Output = T>,
325    {
326        elementwise_binary_dyn(self, other, |x, y| x / y, "div_broadcast")
327    }
328
329    /// Elementwise remainder with `NumPy` broadcasting across arbitrary ranks.
330    ///
331    /// See [`Array::add_broadcast`] for details.
332    pub fn rem_broadcast<D2: Dimension>(
333        &self,
334        other: &Array<T, D2>,
335    ) -> FerrayResult<Array<T, IxDyn>>
336    where
337        T: std::ops::Rem<Output = T>,
338    {
339        elementwise_binary_dyn(self, other, |x, y| x % y, "rem_broadcast")
340    }
341}
342
343// ---------------------------------------------------------------------------
344// In-place operators (#348)
345//
346// NumPy supports `arr += 5`, `arr *= other` — mutation without allocation.
347// We split these into two groups:
348//
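// 1. Scalar in-place: `std::ops::*Assign<T> for Array<T, D>`. These can never
//    hit a shape mismatch (a scalar has no shape), so no error channel is
//    needed; they are implemented via `mapv_inplace`. Integer `/=` and `%=`
//    by zero still panic, exactly as the primitive operators do.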
351//
352// 2. Array-array in-place: inherent fallible methods `add_inplace`,
353//    `sub_inplace`, `mul_inplace`, `div_inplace`, `rem_inplace` returning
354//    `FerrayResult<()>`. We cannot implement `std::ops::AddAssign<&Array>`
355//    because the trait signature returns `()`, leaving no channel for
356//    shape-mismatch errors — incompatible with ferray's zero-panic rule.
357//    Use the `*_inplace` methods instead of `arr += other`.
358//
359//    The RHS is broadcast to `self.shape()` (NumPy semantics for in-place
360//    arithmetic: the destination shape is fixed and the source must be
361//    broadcastable to it).
362// ---------------------------------------------------------------------------
363
364/// Implement a scalar `*Assign` operator trait using `mapv_inplace`.
365macro_rules! impl_scalar_op_assign {
366    ($trait:ident, $method:ident, $op:tt) => {
367        impl<T, D> std::ops::$trait<T> for Array<T, D>
368        where
369            T: Element + Copy + std::ops::$trait,
370            D: Dimension,
371        {
372            fn $method(&mut self, rhs: T) {
373                self.mapv_inplace(|mut x| {
374                    x $op rhs;
375                    x
376                });
377            }
378        }
379    };
380}
381
382impl_scalar_op_assign!(AddAssign, add_assign, +=);
383impl_scalar_op_assign!(SubAssign, sub_assign, -=);
384impl_scalar_op_assign!(MulAssign, mul_assign, *=);
385impl_scalar_op_assign!(DivAssign, div_assign, /=);
386impl_scalar_op_assign!(RemAssign, rem_assign, %=);
387
388/// Shared implementation for in-place array-array ops with broadcasting.
389///
390/// Fast path on identical shapes uses `zip_mut_with`. Otherwise the RHS is
391/// broadcast into `self.shape()` via [`broadcast_to`] and then zipped with
392/// `self.iter_mut()` in logical order.
393fn inplace_binary<T, D, F>(
394    lhs: &mut Array<T, D>,
395    rhs: &Array<T, D>,
396    op: F,
397    op_name: &str,
398) -> FerrayResult<()>
399where
400    T: Element + Copy,
401    D: Dimension,
402    F: Fn(T, T) -> T,
403{
404    // Fast path: identical shapes — delegate to zip_mut_with.
405    if lhs.shape() == rhs.shape() {
406        return lhs.zip_mut_with(rhs, |a, b| *a = op(*a, *b));
407    }
408
409    // Broadcasting path: rhs must broadcast into lhs.shape() (the destination
410    // shape is fixed for in-place operations). Any shape change on the LHS
411    // would require a reallocation, defeating the purpose of `*_inplace`.
412    let target_shape: Vec<usize> = lhs.shape().to_vec();
413    let rhs_view = broadcast_to(rhs, &target_shape).map_err(|_| {
414        FerrayError::shape_mismatch(format!(
415            "{}: shape {:?} cannot be broadcast into destination shape {:?}",
416            op_name,
417            rhs.shape(),
418            target_shape
419        ))
420    })?;
421
422    for (a, b) in lhs.iter_mut().zip(rhs_view.iter()) {
423        *a = op(*a, *b);
424    }
425    Ok(())
426}
427
428impl<T, D> Array<T, D>
429where
430    T: Element + Copy,
431    D: Dimension,
432{
433    /// In-place elementwise add: `self[i] += other[i]`. `other` is broadcast
434    /// into `self.shape()`; the destination shape never changes.
435    ///
436    /// Prefer this over `self = (&self + &other)?` for large arrays — it
437    /// avoids allocating a new result buffer.
438    ///
439    /// # Errors
440    /// Returns `FerrayError::ShapeMismatch` if `other.shape()` cannot be
441    /// broadcast into `self.shape()`.
442    pub fn add_inplace(&mut self, other: &Self) -> FerrayResult<()>
443    where
444        T: std::ops::Add<Output = T>,
445    {
446        inplace_binary(self, other, |a, b| a + b, "add_inplace")
447    }
448
449    /// In-place elementwise subtract. See [`Array::add_inplace`].
450    pub fn sub_inplace(&mut self, other: &Self) -> FerrayResult<()>
451    where
452        T: std::ops::Sub<Output = T>,
453    {
454        inplace_binary(self, other, |a, b| a - b, "sub_inplace")
455    }
456
457    /// In-place elementwise multiply. See [`Array::add_inplace`].
458    pub fn mul_inplace(&mut self, other: &Self) -> FerrayResult<()>
459    where
460        T: std::ops::Mul<Output = T>,
461    {
462        inplace_binary(self, other, |a, b| a * b, "mul_inplace")
463    }
464
465    /// In-place elementwise divide. See [`Array::add_inplace`].
466    pub fn div_inplace(&mut self, other: &Self) -> FerrayResult<()>
467    where
468        T: std::ops::Div<Output = T>,
469    {
470        inplace_binary(self, other, |a, b| a / b, "div_inplace")
471    }
472
473    /// In-place elementwise remainder. See [`Array::add_inplace`].
474    pub fn rem_inplace(&mut self, other: &Self) -> FerrayResult<()>
475    where
476        T: std::ops::Rem<Output = T>,
477    {
478        inplace_binary(self, other, |a, b| a % b, "rem_inplace")
479    }
480}
481
482// ---------------------------------------------------------------------------
483// copyto: broadcasted elementwise assignment (#352)
484//
485// NumPy parity for `np.copyto(dst, src)`: copies values from `src` into `dst`,
486// broadcasting `src` into `dst.shape()`. The destination shape is fixed — a
487// copyto never reallocates `dst`. `casting=` is not modeled because ferray's
488// type system already enforces matching element types; use `astype()` before
489// copyto if conversion is desired.
490//
491// Kept separate from the in-place arithmetic ops because copyto is the
492// "fundamental building block" called out in the issue: it has to handle
493// cross-rank src/dst pairs (e.g. `Ix1` into `Ix2` via broadcasting), whereas
494// `*_inplace` stay within one dimension type `D` to match std::ops::*Assign.
495// ---------------------------------------------------------------------------
496
497/// Copy values from `src` into `dst`, broadcasting `src` to `dst.shape()`.
498///
499/// Equivalent to `np.copyto(dst, src)` without the `casting=` or `where=`
500/// parameters. `src` may have any dimension type — it only needs to be
501/// broadcast-compatible with `dst.shape()`. The destination shape is fixed,
502/// so `src` can never grow the destination.
503///
504/// # Errors
505/// Returns `FerrayError::ShapeMismatch` if `src.shape()` cannot be broadcast
506/// into `dst.shape()`. On error `dst` is left untouched.
507pub fn copyto<T, D1, D2>(dst: &mut Array<T, D1>, src: &Array<T, D2>) -> FerrayResult<()>
508where
509    T: Element,
510    D1: Dimension,
511    D2: Dimension,
512{
513    // Fast path: identical shapes — straight iter-zip, no broadcast machinery.
514    if dst.shape() == src.shape() {
515        for (d, s) in dst.iter_mut().zip(src.iter()) {
516            *d = s.clone();
517        }
518        return Ok(());
519    }
520
521    // Broadcasting path: src must broadcast into dst.shape(). We validate up
522    // front so that a shape error leaves dst completely untouched.
523    let target_shape: Vec<usize> = dst.shape().to_vec();
524    let src_view = broadcast_to(src, &target_shape).map_err(|_| {
525        FerrayError::shape_mismatch(format!(
526            "copyto: source shape {:?} cannot be broadcast into destination shape {:?}",
527            src.shape(),
528            target_shape
529        ))
530    })?;
531
532    for (d, s) in dst.iter_mut().zip(src_view.iter()) {
533        *d = s.clone();
534    }
535    Ok(())
536}
537
538/// Copy values from `src` into `dst` where `mask` is `true`, broadcasting
539/// both `src` and `mask` to `dst.shape()`.
540///
541/// Equivalent to `np.copyto(dst, src, where=mask)`. This is the generalized
542/// form of `NumPy`'s `where=` ufunc parameter (#353): positions where the mask
543/// is `false` are left untouched in `dst`. All three shapes must be
544/// broadcast-compatible with `dst.shape()`; the destination shape is fixed,
545/// so neither `src` nor `mask` can grow `dst`.
546///
547/// # Errors
548/// Returns `FerrayError::ShapeMismatch` if either `src.shape()` or
549/// `mask.shape()` cannot be broadcast into `dst.shape()`. On error `dst` is
550/// left untouched (broadcast validation happens before any writes).
551pub fn copyto_where<T, D1, D2, D3>(
552    dst: &mut Array<T, D1>,
553    src: &Array<T, D2>,
554    mask: &Array<bool, D3>,
555) -> FerrayResult<()>
556where
557    T: Element,
558    D1: Dimension,
559    D2: Dimension,
560    D3: Dimension,
561{
562    // Validate + materialize broadcast views BEFORE writing, so that a shape
563    // mismatch can't leave dst in a partially-updated state.
564    let target_shape: Vec<usize> = dst.shape().to_vec();
565
566    let src_view = broadcast_to(src, &target_shape).map_err(|_| {
567        FerrayError::shape_mismatch(format!(
568            "copyto_where: source shape {:?} cannot be broadcast into destination shape {:?}",
569            src.shape(),
570            target_shape
571        ))
572    })?;
573
574    let mask_view = broadcast_to(mask, &target_shape).map_err(|_| {
575        FerrayError::shape_mismatch(format!(
576            "copyto_where: mask shape {:?} cannot be broadcast into destination shape {:?}",
577            mask.shape(),
578            target_shape
579        ))
580    })?;
581
582    for ((d, s), &m) in dst.iter_mut().zip(src_view.iter()).zip(mask_view.iter()) {
583        if m {
584            *d = s.clone();
585        }
586    }
587    Ok(())
588}
589
590impl<T, D> Array<T, D>
591where
592    T: Element,
593    D: Dimension,
594{
595    /// Copy values from `src` into `self`, broadcasting `src` to `self.shape()`.
596    ///
597    /// See [`copyto`] for the free-function form and full semantics.
598    ///
599    /// # Errors
600    /// Returns `FerrayError::ShapeMismatch` if `src` cannot be broadcast into
601    /// `self.shape()`. On error `self` is left untouched.
602    pub fn copy_from<D2: Dimension>(&mut self, src: &Array<T, D2>) -> FerrayResult<()> {
603        copyto(self, src)
604    }
605
606    /// Copy values from `src` into `self` at positions where `mask` is `true`,
607    /// broadcasting both inputs to `self.shape()`.
608    ///
609    /// See [`copyto_where`] for the free-function form and full semantics.
610    ///
611    /// # Errors
612    /// Returns `FerrayError::ShapeMismatch` if `src` or `mask` cannot be
613    /// broadcast into `self.shape()`. On error `self` is left untouched.
614    pub fn copy_from_where<D2: Dimension, D3: Dimension>(
615        &mut self,
616        src: &Array<T, D2>,
617        mask: &Array<bool, D3>,
618    ) -> FerrayResult<()> {
619        copyto_where(self, src, mask)
620    }
621}
622
623#[cfg(test)]
624mod tests {
625    use super::*;
626    use crate::dimension::Ix1;
627
628    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
629        let n = data.len();
630        Array::from_vec(Ix1::new([n]), data).unwrap()
631    }
632
633    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
634        let n = data.len();
635        Array::from_vec(Ix1::new([n]), data).unwrap()
636    }
637
638    #[test]
639    fn test_add_ref_ref() {
640        let a = arr(vec![1.0, 2.0, 3.0]);
641        let b = arr(vec![4.0, 5.0, 6.0]);
642        let c = (&a + &b).unwrap();
643        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
644    }
645
646    #[test]
647    fn test_add_owned_owned() {
648        let a = arr(vec![1.0, 2.0]);
649        let b = arr(vec![3.0, 4.0]);
650        let c = (a + b).unwrap();
651        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
652    }
653
654    #[test]
655    fn test_add_mixed() {
656        let a = arr(vec![1.0, 2.0]);
657        let b = arr(vec![3.0, 4.0]);
658        let c = (a + &b).unwrap();
659        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
660
661        let d = arr(vec![10.0, 20.0]);
662        let e = (&b + d).unwrap();
663        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
664    }
665
666    #[test]
667    fn test_sub() {
668        let a = arr(vec![5.0, 7.0]);
669        let b = arr(vec![1.0, 2.0]);
670        let c = (&a - &b).unwrap();
671        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
672    }
673
674    #[test]
675    fn test_mul() {
676        let a = arr(vec![2.0, 3.0]);
677        let b = arr(vec![4.0, 5.0]);
678        let c = (&a * &b).unwrap();
679        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
680    }
681
682    #[test]
683    fn test_div() {
684        let a = arr(vec![10.0, 20.0]);
685        let b = arr(vec![2.0, 5.0]);
686        let c = (&a / &b).unwrap();
687        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
688    }
689
690    #[test]
691    fn test_rem() {
692        let a = arr_i32(vec![7, 10]);
693        let b = arr_i32(vec![3, 4]);
694        let c = (&a % &b).unwrap();
695        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
696    }
697
698    #[test]
699    fn test_neg() {
700        let a = arr(vec![1.0, -2.0, 3.0]);
701        let b = (-&a).unwrap();
702        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
703    }
704
705    #[test]
706    fn test_neg_owned() {
707        let a = arr(vec![1.0, -2.0]);
708        let b = (-a).unwrap();
709        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
710    }
711
712    #[test]
713    fn test_shape_mismatch_errors() {
714        let a = arr(vec![1.0, 2.0]);
715        let b = arr(vec![1.0, 2.0, 3.0]);
716        let result = &a + &b;
717        assert!(result.is_err());
718    }
719
720    // --- Scalar-array operations ---
721
722    #[test]
723    fn test_add_scalar() {
724        let a = arr(vec![1.0, 2.0, 3.0]);
725        let c = (&a + 10.0).unwrap();
726        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
727    }
728
729    #[test]
730    fn test_sub_scalar() {
731        let a = arr(vec![10.0, 20.0, 30.0]);
732        let c = (&a - 5.0).unwrap();
733        assert_eq!(c.as_slice().unwrap(), &[5.0, 15.0, 25.0]);
734    }
735
736    #[test]
737    fn test_mul_scalar() {
738        let a = arr(vec![1.0, 2.0, 3.0]);
739        let c = (&a * 3.0).unwrap();
740        assert_eq!(c.as_slice().unwrap(), &[3.0, 6.0, 9.0]);
741    }
742
743    #[test]
744    fn test_div_scalar() {
745        let a = arr(vec![10.0, 20.0, 30.0]);
746        let c = (&a / 10.0).unwrap();
747        assert_eq!(c.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
748    }
749
750    #[test]
751    fn test_rem_scalar() {
752        let a = arr_i32(vec![7, 10, 15]);
753        let c = (&a % 4).unwrap();
754        assert_eq!(c.as_slice().unwrap(), &[3, 2, 3]);
755    }
756
757    #[test]
758    fn test_scalar_op_owned() {
759        let a = arr(vec![1.0, 2.0, 3.0]);
760        let c = (a + 10.0).unwrap();
761        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
762    }
763
764    #[test]
765    fn test_chained_ops() {
766        let a = arr(vec![1.0, 2.0, 3.0]);
767        let b = arr(vec![4.0, 5.0, 6.0]);
768        let c = arr(vec![10.0, 10.0, 10.0]);
769        // (a + b)? * c)?
770        let result = (&(&a + &b).unwrap() * &c).unwrap();
771        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
772    }
773
774    // -----------------------------------------------------------------------
775    // Broadcasting tests (issue #346)
776    // -----------------------------------------------------------------------
777
778    use crate::dimension::{Ix2, Ix3, IxDyn};
779
780    #[test]
781    fn test_broadcast_2d_row_plus_column() {
782        // (3, 1) + (1, 4) -> (3, 4) — both Ix2
783        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
784        let row =
785            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
786        let result = (&col + &row).unwrap();
787        assert_eq!(result.shape(), &[3, 4]);
788        assert_eq!(
789            result.as_slice().unwrap(),
790            &[
791                11.0, 21.0, 31.0, 41.0, // row 1
792                12.0, 22.0, 32.0, 42.0, // row 2
793                13.0, 23.0, 33.0, 43.0, // row 3
794            ]
795        );
796    }
797
798    #[test]
799    fn test_broadcast_2d_stretch_one_axis() {
800        // (3, 4) + (1, 4) -> (3, 4)
801        let a = Array::<f64, Ix2>::from_vec(
802            Ix2::new([3, 4]),
803            vec![
804                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
805            ],
806        )
807        .unwrap();
808        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![100.0, 200.0, 300.0, 400.0])
809            .unwrap();
810        let result = (&a + &b).unwrap();
811        assert_eq!(result.shape(), &[3, 4]);
812        assert_eq!(
813            result.as_slice().unwrap(),
814            &[
815                101.0, 202.0, 303.0, 404.0, 105.0, 206.0, 307.0, 408.0, 109.0, 210.0, 311.0, 412.0,
816            ]
817        );
818    }
819
820    #[test]
821    fn test_broadcast_3d_with_2d_axis() {
822        // (2, 3, 4) - (1, 3, 4) -> (2, 3, 4) — both Ix3
823        let a =
824            Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), (1..=24).map(|i| i as f64).collect())
825                .unwrap();
826        let b =
827            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (1..=12).map(|i| i as f64).collect())
828                .unwrap();
829        let result = (&a - &b).unwrap();
830        assert_eq!(result.shape(), &[2, 3, 4]);
831        // First 12 elements: a[0..12] - b
832        let first_half: Vec<f64> = (1..=12).map(|_| 0.0).collect();
833        assert_eq!(&result.as_slice().unwrap()[..12], &first_half[..]);
834        // Second 12 elements: a[12..24] - b == 12 each
835        let second_half: Vec<f64> = (0..12).map(|_| 12.0).collect();
836        assert_eq!(&result.as_slice().unwrap()[12..], &second_half[..]);
837    }
838
839    #[test]
840    fn test_broadcast_incompatible_shapes_error() {
841        // (3,) + (4,) — incompatible
842        let a = arr(vec![1.0, 2.0, 3.0]);
843        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
844        let result = &a + &b;
845        assert!(result.is_err());
846
847        // (3, 4) + (3, 5) — incompatible
848        let c = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![0.0; 12]).unwrap();
849        let d = Array::<f64, Ix2>::from_vec(Ix2::new([3, 5]), vec![0.0; 15]).unwrap();
850        assert!((&c + &d).is_err());
851    }
852
853    #[test]
854    fn test_broadcast_mul_2d() {
855        // (3, 1) * (1, 3) -> (3, 3) outer-product style
856        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
857        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
858        let result = (&col * &row).unwrap();
859        assert_eq!(result.shape(), &[3, 3]);
860        assert_eq!(
861            result.as_slice().unwrap(),
862            &[10, 20, 30, 20, 40, 60, 30, 60, 90]
863        );
864    }
865
866    // -----------------------------------------------------------------------
867    // Cross-rank broadcasting via `add_broadcast` etc.
868    // -----------------------------------------------------------------------
869
870    #[test]
871    fn test_add_broadcast_1d_plus_2d() {
872        // (3,) + (2, 3) -> (2, 3)
873        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
874        let m =
875            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
876                .unwrap();
877        let result = v.add_broadcast(&m).unwrap();
878        assert_eq!(result.shape(), &[2, 3]);
879        assert_eq!(
880            result.as_slice().unwrap(),
881            &[11.0, 22.0, 33.0, 41.0, 52.0, 63.0]
882        );
883    }
884
885    #[test]
886    fn test_add_broadcast_1d_plus_column() {
887        // (3,) + (2, 1) -> (2, 3) — the canonical NumPy example from issue #346
888        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
889        let col = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
890        let result = v.add_broadcast(&col).unwrap();
891        assert_eq!(result.shape(), &[2, 3]);
892        assert_eq!(
893            result.as_slice().unwrap(),
894            &[11.0, 12.0, 13.0, 21.0, 22.0, 23.0]
895        );
896    }
897
898    #[test]
899    fn test_sub_broadcast_2d_minus_1d() {
900        // (2, 3) - (3,) -> (2, 3)
901        let m =
902            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
903                .unwrap();
904        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
905        let result = m.sub_broadcast(&v).unwrap();
906        assert_eq!(result.shape(), &[2, 3]);
907        assert_eq!(
908            result.as_slice().unwrap(),
909            &[9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
910        );
911    }
912
913    #[test]
914    fn test_mul_broadcast_returns_dyn() {
915        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
916        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
917        let result: Array<f64, IxDyn> = a.mul_broadcast(&b).unwrap();
918        assert_eq!(result.shape(), &[2, 3]);
919        assert_eq!(
920            result.as_slice().unwrap(),
921            &[10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
922        );
923    }
924
925    #[test]
926    fn test_div_broadcast_incompatible() {
927        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
928        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
929        assert!(a.div_broadcast(&b).is_err());
930    }
931
932    #[test]
933    fn test_rem_broadcast_2d() {
934        let a =
935            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
936        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![3, 7, 11]).unwrap();
937        let result = a.rem_broadcast(&b).unwrap();
938        assert_eq!(result.shape(), &[2, 3]);
939        assert_eq!(
940            result.as_slice().unwrap(),
941            &[10 % 3, 20 % 7, 30 % 11, 40 % 3, 50 % 7, 60 % 11]
942        );
943    }
944
945    // ------------------------------------------------------------------
946    // #348: in-place operators
947    // ------------------------------------------------------------------
948
949    #[test]
950    fn scalar_add_assign_mutates_in_place() {
951        let mut a = arr(vec![1.0, 2.0, 3.0]);
952        a += 10.0;
953        assert_eq!(a.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
954    }
955
956    #[test]
957    fn scalar_sub_mul_div_rem_assign() {
958        let mut a = arr(vec![10.0, 20.0, 30.0]);
959        a -= 1.0;
960        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
961        a *= 2.0;
962        assert_eq!(a.as_slice().unwrap(), &[18.0, 38.0, 58.0]);
963        a /= 2.0;
964        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
965        let mut b = arr_i32(vec![10, 11, 12]);
966        b %= 3;
967        assert_eq!(b.as_slice().unwrap(), &[1, 2, 0]);
968    }
969
970    #[test]
971    fn scalar_assign_preserves_shape_ix2() {
972        let mut a =
973            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
974                .unwrap();
975        a += 1.0;
976        assert_eq!(a.shape(), &[2, 3]);
977        assert_eq!(a.as_slice().unwrap(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
978    }
979
980    #[test]
981    fn add_inplace_same_shape_fast_path() {
982        let mut a = arr(vec![1.0, 2.0, 3.0]);
983        let b = arr(vec![10.0, 20.0, 30.0]);
984        a.add_inplace(&b).unwrap();
985        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
986    }
987
988    #[test]
989    fn sub_mul_div_rem_inplace_same_shape() {
990        let mut a = arr(vec![10.0, 20.0, 30.0]);
991        let b = arr(vec![1.0, 2.0, 3.0]);
992        a.sub_inplace(&b).unwrap();
993        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
994        a.mul_inplace(&b).unwrap();
995        assert_eq!(a.as_slice().unwrap(), &[9.0, 36.0, 81.0]);
996        a.div_inplace(&b).unwrap();
997        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
998        let mut c = arr_i32(vec![10, 20, 30]);
999        let d = arr_i32(vec![3, 7, 11]);
1000        c.rem_inplace(&d).unwrap();
1001        assert_eq!(c.as_slice().unwrap(), &[1, 6, 8]);
1002    }
1003
1004    #[test]
1005    fn add_inplace_broadcasts_rhs_into_lhs_shape() {
1006        // (2, 3) += (1, 3) — RHS row broadcast across LHS rows.
1007        let mut a =
1008            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1009                .unwrap();
1010        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
1011        a.add_inplace(&b).unwrap();
1012        assert_eq!(a.shape(), &[2, 3]);
1013        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
1014    }
1015
1016    #[test]
1017    fn add_inplace_broadcasts_column_into_rows() {
1018        // (2, 3) += (2, 1) — RHS column broadcast across LHS columns.
1019        let mut a =
1020            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1021                .unwrap();
1022        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![100.0, 200.0]).unwrap();
1023        a.add_inplace(&b).unwrap();
1024        assert_eq!(
1025            a.as_slice().unwrap(),
1026            &[101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
1027        );
1028    }
1029
1030    #[test]
1031    fn add_inplace_rejects_incompatible_rhs() {
1032        // (3,) += (4,) — not broadcast-compatible; error must be Err, not panic.
1033        let mut a = arr(vec![1.0, 2.0, 3.0]);
1034        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
1035        assert!(a.add_inplace(&b).is_err());
1036        // LHS must be untouched on error.
1037        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1038    }
1039
1040    #[test]
1041    fn add_inplace_rejects_growing_shape() {
1042        // (1, 3) += (2, 3) — RHS bigger than LHS; the destination shape is
1043        // fixed for in-place operations, so this must error.
1044        let mut a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
1045        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0; 6]).unwrap();
1046        assert!(a.add_inplace(&b).is_err());
1047        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1048    }
1049
1050    #[test]
1051    fn copyto_same_shape_fast_path() {
1052        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1053        let src = arr(vec![1.0, 2.0, 3.0]);
1054        copyto(&mut dst, &src).unwrap();
1055        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1056    }
1057
1058    #[test]
1059    fn copyto_broadcasts_row_into_matrix() {
1060        // (2, 3) <= (1, 3)
1061        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1062        let src = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
1063        copyto(&mut dst, &src).unwrap();
1064        assert_eq!(
1065            dst.as_slice().unwrap(),
1066            &[10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
1067        );
1068    }
1069
1070    #[test]
1071    fn copyto_broadcasts_cross_rank_src() {
1072        // (2, 3) <= (3,)  — a lower-rank src broadcasts against a higher-rank dst.
1073        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1074        let src = arr(vec![7.0, 8.0, 9.0]);
1075        copyto(&mut dst, &src).unwrap();
1076        assert_eq!(dst.as_slice().unwrap(), &[7.0, 8.0, 9.0, 7.0, 8.0, 9.0]);
1077    }
1078
1079    #[test]
1080    fn copyto_scalar_src_broadcasts_to_full_dst() {
1081        // (2, 3) <= () via a length-1 1D stand-in.
1082        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
1083        let src = arr(vec![42.0]);
1084        copyto(&mut dst, &src).unwrap();
1085        assert_eq!(dst.as_slice().unwrap(), &[42.0; 6]);
1086    }
1087
1088    #[test]
1089    fn copyto_rejects_growing_dst() {
1090        // (1, 3) <= (2, 3) — src wants to grow dst; must error, dst untouched.
1091        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
1092        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![99.0; 6]).unwrap();
1093        assert!(copyto(&mut dst, &src).is_err());
1094        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1095    }
1096
1097    #[test]
1098    fn copyto_rejects_incompatible_shapes() {
1099        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1100        let src = arr(vec![1.0, 2.0, 3.0, 4.0]);
1101        assert!(copyto(&mut dst, &src).is_err());
1102        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1103    }
1104
1105    #[test]
1106    fn copyto_method_form_equivalent_to_function() {
1107        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1108        let src = arr(vec![1.0, 2.0, 3.0]);
1109        dst.copy_from(&src).unwrap();
1110        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1111    }
1112
1113    #[test]
1114    fn copyto_works_for_non_copy_element_type_i64() {
1115        // Exercises the Clone-not-Copy path on the Element trait — copyto
1116        // must not require T: Copy.
1117        let mut dst = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![0, 0, 0, 0]).unwrap();
1118        let src = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
1119        copyto(&mut dst, &src).unwrap();
1120        assert_eq!(dst.as_slice().unwrap(), &[1, 2, 3, 4]);
1121    }
1122
1123    #[test]
1124    fn copyto_where_same_shape_only_writes_masked_positions() {
1125        let mut dst = arr(vec![1.0, 2.0, 3.0, 4.0]);
1126        let src = arr(vec![10.0, 20.0, 30.0, 40.0]);
1127        let mask =
1128            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
1129        copyto_where(&mut dst, &src, &mask).unwrap();
1130        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 4.0]);
1131    }
1132
1133    #[test]
1134    fn copyto_where_broadcasts_mask_across_dst() {
1135        // (2, 3) <= (2, 3), mask = (1, 3) broadcasts across rows.
1136        let mut dst =
1137            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1138                .unwrap();
1139        let src =
1140            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
1141                .unwrap();
1142        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
1143        copyto_where(&mut dst, &src, &mask).unwrap();
1144        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 40.0, 5.0, 60.0]);
1145    }
1146
1147    #[test]
1148    fn copyto_where_broadcasts_scalar_src_with_mask() {
1149        // Conditional scalar fill: set dst[i,j] = 99.0 wherever mask is true.
1150        let mut dst =
1151            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1152                .unwrap();
1153        let src = arr(vec![99.0]);
1154        let mask = Array::<bool, Ix2>::from_vec(
1155            Ix2::new([2, 3]),
1156            vec![true, false, true, false, true, false],
1157        )
1158        .unwrap();
1159        copyto_where(&mut dst, &src, &mask).unwrap();
1160        assert_eq!(dst.as_slice().unwrap(), &[99.0, 2.0, 99.0, 4.0, 99.0, 6.0]);
1161    }
1162
1163    #[test]
1164    fn copyto_where_all_false_mask_is_noop() {
1165        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1166        let original = dst.as_slice().unwrap().to_vec();
1167        let src = arr(vec![99.0, 99.0, 99.0]);
1168        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
1169        copyto_where(&mut dst, &src, &mask).unwrap();
1170        assert_eq!(dst.as_slice().unwrap(), &original[..]);
1171    }
1172
1173    #[test]
1174    fn copyto_where_all_true_mask_matches_copyto() {
1175        let mut dst = arr(vec![0.0, 0.0, 0.0]);
1176        let src = arr(vec![1.0, 2.0, 3.0]);
1177        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1178        copyto_where(&mut dst, &src, &mask).unwrap();
1179        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1180    }
1181
1182    #[test]
1183    fn copyto_where_rejects_incompatible_src_shape() {
1184        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1185        let src = arr(vec![1.0, 2.0]);
1186        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
1187        assert!(copyto_where(&mut dst, &src, &mask).is_err());
1188        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1189    }
1190
1191    #[test]
1192    fn copyto_where_rejects_incompatible_mask_shape() {
1193        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1194        let src = arr(vec![10.0, 20.0, 30.0]);
1195        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
1196        assert!(copyto_where(&mut dst, &src, &mask).is_err());
1197        // Validation happens before writes — dst untouched.
1198        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1199    }
1200
1201    #[test]
1202    fn copy_from_where_method_form_equivalent() {
1203        let mut dst = arr(vec![1.0, 2.0, 3.0]);
1204        let src = arr(vec![10.0, 20.0, 30.0]);
1205        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
1206        dst.copy_from_where(&src, &mask).unwrap();
1207        assert_eq!(dst.as_slice().unwrap(), &[1.0, 20.0, 3.0]);
1208    }
1209
1210    #[test]
1211    fn div_inplace_by_zero_yields_ieee_sentinels() {
1212        // Pin current semantics: division-by-zero in-place produces IEEE
1213        // inf/NaN at the offending positions and does NOT error.
1214        let mut a = arr(vec![1.0, 2.0, 0.0]);
1215        let b = arr(vec![2.0, 0.0, 0.0]);
1216        a.div_inplace(&b).unwrap();
1217        let s = a.as_slice().unwrap();
1218        assert_eq!(s[0], 0.5);
1219        assert!(s[1].is_infinite() && s[1].is_sign_positive());
1220        assert!(s[2].is_nan());
1221    }
1222}