1use crate::array::owned::Array;
19use crate::dimension::Dimension;
20use crate::dimension::IxDyn;
21use crate::dimension::broadcast::{broadcast_shapes, broadcast_to};
22use crate::dtype::Element;
23use crate::error::{FerrayError, FerrayResult};
24
25fn elementwise_binary<T, D, F>(
37 a: &Array<T, D>,
38 b: &Array<T, D>,
39 op: F,
40 op_name: &str,
41) -> FerrayResult<Array<T, D>>
42where
43 T: Element + Copy,
44 D: Dimension,
45 F: Fn(T, T) -> T,
46{
47 if a.shape() == b.shape() {
49 let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
50 return Array::from_vec(a.dim().clone(), data);
51 }
52
53 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
55 FerrayError::shape_mismatch(format!(
56 "operator {}: shapes {:?} and {:?} are not broadcast-compatible",
57 op_name,
58 a.shape(),
59 b.shape()
60 ))
61 })?;
62
63 let a_view = broadcast_to(a, &target_shape)?;
64 let b_view = broadcast_to(b, &target_shape)?;
65
66 let data: Vec<T> = a_view
67 .iter()
68 .zip(b_view.iter())
69 .map(|(&x, &y)| op(x, y))
70 .collect();
71
72 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
73 FerrayError::shape_mismatch(format!(
74 "operator {op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
75 ))
76 })?;
77
78 Array::from_vec(result_dim, data)
79}
80
81fn elementwise_binary_dyn<T, D1, D2, F>(
88 a: &Array<T, D1>,
89 b: &Array<T, D2>,
90 op: F,
91 op_name: &str,
92) -> FerrayResult<Array<T, IxDyn>>
93where
94 T: Element + Copy,
95 D1: Dimension,
96 D2: Dimension,
97 F: Fn(T, T) -> T,
98{
99 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
100 FerrayError::shape_mismatch(format!(
101 "{}: shapes {:?} and {:?} are not broadcast-compatible",
102 op_name,
103 a.shape(),
104 b.shape()
105 ))
106 })?;
107
108 let a_view = broadcast_to(a, &target_shape)?;
109 let b_view = broadcast_to(b, &target_shape)?;
110
111 let data: Vec<T> = a_view
112 .iter()
113 .zip(b_view.iter())
114 .map(|(&x, &y)| op(x, y))
115 .collect();
116
117 Array::from_vec(IxDyn::from(&target_shape[..]), data)
118}
119
120macro_rules! impl_binary_op {
128 ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
129 impl<T, D> std::ops::$trait<&Array<T, D>> for &Array<T, D>
131 where
132 T: Element + Copy + std::ops::$trait<Output = T>,
133 D: Dimension,
134 {
135 type Output = FerrayResult<Array<T, D>>;
136
137 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
138 elementwise_binary(self, rhs, $op_fn, $op_name)
139 }
140 }
141
142 impl<T, D> std::ops::$trait<Array<T, D>> for Array<T, D>
144 where
145 T: Element + Copy + std::ops::$trait<Output = T>,
146 D: Dimension,
147 {
148 type Output = FerrayResult<Array<T, D>>;
149
150 fn $method(self, rhs: Array<T, D>) -> Self::Output {
151 elementwise_binary(&self, &rhs, $op_fn, $op_name)
152 }
153 }
154
155 impl<T, D> std::ops::$trait<&Array<T, D>> for Array<T, D>
157 where
158 T: Element + Copy + std::ops::$trait<Output = T>,
159 D: Dimension,
160 {
161 type Output = FerrayResult<Array<T, D>>;
162
163 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
164 elementwise_binary(&self, rhs, $op_fn, $op_name)
165 }
166 }
167
168 impl<T, D> std::ops::$trait<Array<T, D>> for &Array<T, D>
170 where
171 T: Element + Copy + std::ops::$trait<Output = T>,
172 D: Dimension,
173 {
174 type Output = FerrayResult<Array<T, D>>;
175
176 fn $method(self, rhs: Array<T, D>) -> Self::Output {
177 elementwise_binary(self, &rhs, $op_fn, $op_name)
178 }
179 }
180 };
181}
182
183impl_binary_op!(Add, add, |a, b| a + b, "+");
184impl_binary_op!(Sub, sub, |a, b| a - b, "-");
185impl_binary_op!(Mul, mul, |a, b| a * b, "*");
186impl_binary_op!(Div, div, |a, b| a / b, "/");
187impl_binary_op!(Rem, rem, |a, b| a % b, "%");
188
189macro_rules! impl_scalar_op {
195 ($trait:ident, $method:ident, $op_fn:expr) => {
196 impl<T, D> std::ops::$trait<T> for &Array<T, D>
198 where
199 T: Element + Copy + std::ops::$trait<Output = T>,
200 D: Dimension,
201 {
202 type Output = FerrayResult<Array<T, D>>;
203
204 fn $method(self, rhs: T) -> Self::Output {
205 let data: Vec<T> = self.iter().map(|&x| $op_fn(x, rhs)).collect();
206 Array::from_vec(self.dim().clone(), data)
207 }
208 }
209
210 impl<T, D> std::ops::$trait<T> for Array<T, D>
212 where
213 T: Element + Copy + std::ops::$trait<Output = T>,
214 D: Dimension,
215 {
216 type Output = FerrayResult<Array<T, D>>;
217
218 fn $method(self, rhs: T) -> Self::Output {
219 (&self).$method(rhs)
220 }
221 }
222 };
223}
224
225impl_scalar_op!(Add, add, |a, b| a + b);
226impl_scalar_op!(Sub, sub, |a, b| a - b);
227impl_scalar_op!(Mul, mul, |a, b| a * b);
228impl_scalar_op!(Div, div, |a, b| a / b);
229impl_scalar_op!(Rem, rem, |a, b| a % b);
230
231impl<T, D> std::ops::Neg for &Array<T, D>
233where
234 T: Element + Copy + std::ops::Neg<Output = T>,
235 D: Dimension,
236{
237 type Output = FerrayResult<Array<T, D>>;
238
239 fn neg(self) -> Self::Output {
240 let data: Vec<T> = self.iter().map(|&x| -x).collect();
241 Array::from_vec(self.dim().clone(), data)
242 }
243}
244
245impl<T, D> std::ops::Neg for Array<T, D>
246where
247 T: Element + Copy + std::ops::Neg<Output = T>,
248 D: Dimension,
249{
250 type Output = FerrayResult<Self>;
251
252 fn neg(self) -> Self::Output {
253 -&self
254 }
255}
256
257impl<T, D> Array<T, D>
268where
269 T: Element + Copy,
270 D: Dimension,
271{
272 pub fn add_broadcast<D2: Dimension>(
281 &self,
282 other: &Array<T, D2>,
283 ) -> FerrayResult<Array<T, IxDyn>>
284 where
285 T: std::ops::Add<Output = T>,
286 {
287 elementwise_binary_dyn(self, other, |x, y| x + y, "add_broadcast")
288 }
289
290 pub fn sub_broadcast<D2: Dimension>(
294 &self,
295 other: &Array<T, D2>,
296 ) -> FerrayResult<Array<T, IxDyn>>
297 where
298 T: std::ops::Sub<Output = T>,
299 {
300 elementwise_binary_dyn(self, other, |x, y| x - y, "sub_broadcast")
301 }
302
303 pub fn mul_broadcast<D2: Dimension>(
307 &self,
308 other: &Array<T, D2>,
309 ) -> FerrayResult<Array<T, IxDyn>>
310 where
311 T: std::ops::Mul<Output = T>,
312 {
313 elementwise_binary_dyn(self, other, |x, y| x * y, "mul_broadcast")
314 }
315
316 pub fn div_broadcast<D2: Dimension>(
320 &self,
321 other: &Array<T, D2>,
322 ) -> FerrayResult<Array<T, IxDyn>>
323 where
324 T: std::ops::Div<Output = T>,
325 {
326 elementwise_binary_dyn(self, other, |x, y| x / y, "div_broadcast")
327 }
328
329 pub fn rem_broadcast<D2: Dimension>(
333 &self,
334 other: &Array<T, D2>,
335 ) -> FerrayResult<Array<T, IxDyn>>
336 where
337 T: std::ops::Rem<Output = T>,
338 {
339 elementwise_binary_dyn(self, other, |x, y| x % y, "rem_broadcast")
340 }
341}
342
343macro_rules! impl_scalar_op_assign {
366 ($trait:ident, $method:ident, $op:tt) => {
367 impl<T, D> std::ops::$trait<T> for Array<T, D>
368 where
369 T: Element + Copy + std::ops::$trait,
370 D: Dimension,
371 {
372 fn $method(&mut self, rhs: T) {
373 self.mapv_inplace(|mut x| {
374 x $op rhs;
375 x
376 });
377 }
378 }
379 };
380}
381
382impl_scalar_op_assign!(AddAssign, add_assign, +=);
383impl_scalar_op_assign!(SubAssign, sub_assign, -=);
384impl_scalar_op_assign!(MulAssign, mul_assign, *=);
385impl_scalar_op_assign!(DivAssign, div_assign, /=);
386impl_scalar_op_assign!(RemAssign, rem_assign, %=);
387
388fn inplace_binary<T, D, F>(
394 lhs: &mut Array<T, D>,
395 rhs: &Array<T, D>,
396 op: F,
397 op_name: &str,
398) -> FerrayResult<()>
399where
400 T: Element + Copy,
401 D: Dimension,
402 F: Fn(T, T) -> T,
403{
404 if lhs.shape() == rhs.shape() {
406 return lhs.zip_mut_with(rhs, |a, b| *a = op(*a, *b));
407 }
408
409 let target_shape: Vec<usize> = lhs.shape().to_vec();
413 let rhs_view = broadcast_to(rhs, &target_shape).map_err(|_| {
414 FerrayError::shape_mismatch(format!(
415 "{}: shape {:?} cannot be broadcast into destination shape {:?}",
416 op_name,
417 rhs.shape(),
418 target_shape
419 ))
420 })?;
421
422 for (a, b) in lhs.iter_mut().zip(rhs_view.iter()) {
423 *a = op(*a, *b);
424 }
425 Ok(())
426}
427
428impl<T, D> Array<T, D>
429where
430 T: Element + Copy,
431 D: Dimension,
432{
433 pub fn add_inplace(&mut self, other: &Self) -> FerrayResult<()>
443 where
444 T: std::ops::Add<Output = T>,
445 {
446 inplace_binary(self, other, |a, b| a + b, "add_inplace")
447 }
448
449 pub fn sub_inplace(&mut self, other: &Self) -> FerrayResult<()>
451 where
452 T: std::ops::Sub<Output = T>,
453 {
454 inplace_binary(self, other, |a, b| a - b, "sub_inplace")
455 }
456
457 pub fn mul_inplace(&mut self, other: &Self) -> FerrayResult<()>
459 where
460 T: std::ops::Mul<Output = T>,
461 {
462 inplace_binary(self, other, |a, b| a * b, "mul_inplace")
463 }
464
465 pub fn div_inplace(&mut self, other: &Self) -> FerrayResult<()>
467 where
468 T: std::ops::Div<Output = T>,
469 {
470 inplace_binary(self, other, |a, b| a / b, "div_inplace")
471 }
472
473 pub fn rem_inplace(&mut self, other: &Self) -> FerrayResult<()>
475 where
476 T: std::ops::Rem<Output = T>,
477 {
478 inplace_binary(self, other, |a, b| a % b, "rem_inplace")
479 }
480}
481
482pub fn copyto<T, D1, D2>(dst: &mut Array<T, D1>, src: &Array<T, D2>) -> FerrayResult<()>
508where
509 T: Element,
510 D1: Dimension,
511 D2: Dimension,
512{
513 if dst.shape() == src.shape() {
515 for (d, s) in dst.iter_mut().zip(src.iter()) {
516 *d = s.clone();
517 }
518 return Ok(());
519 }
520
521 let target_shape: Vec<usize> = dst.shape().to_vec();
524 let src_view = broadcast_to(src, &target_shape).map_err(|_| {
525 FerrayError::shape_mismatch(format!(
526 "copyto: source shape {:?} cannot be broadcast into destination shape {:?}",
527 src.shape(),
528 target_shape
529 ))
530 })?;
531
532 for (d, s) in dst.iter_mut().zip(src_view.iter()) {
533 *d = s.clone();
534 }
535 Ok(())
536}
537
538pub fn copyto_where<T, D1, D2, D3>(
552 dst: &mut Array<T, D1>,
553 src: &Array<T, D2>,
554 mask: &Array<bool, D3>,
555) -> FerrayResult<()>
556where
557 T: Element,
558 D1: Dimension,
559 D2: Dimension,
560 D3: Dimension,
561{
562 let target_shape: Vec<usize> = dst.shape().to_vec();
565
566 let src_view = broadcast_to(src, &target_shape).map_err(|_| {
567 FerrayError::shape_mismatch(format!(
568 "copyto_where: source shape {:?} cannot be broadcast into destination shape {:?}",
569 src.shape(),
570 target_shape
571 ))
572 })?;
573
574 let mask_view = broadcast_to(mask, &target_shape).map_err(|_| {
575 FerrayError::shape_mismatch(format!(
576 "copyto_where: mask shape {:?} cannot be broadcast into destination shape {:?}",
577 mask.shape(),
578 target_shape
579 ))
580 })?;
581
582 for ((d, s), &m) in dst.iter_mut().zip(src_view.iter()).zip(mask_view.iter()) {
583 if m {
584 *d = s.clone();
585 }
586 }
587 Ok(())
588}
589
590impl<T, D> Array<T, D>
591where
592 T: Element,
593 D: Dimension,
594{
595 pub fn copy_from<D2: Dimension>(&mut self, src: &Array<T, D2>) -> FerrayResult<()> {
603 copyto(self, src)
604 }
605
606 pub fn copy_from_where<D2: Dimension, D3: Dimension>(
615 &mut self,
616 src: &Array<T, D2>,
617 mask: &Array<bool, D3>,
618 ) -> FerrayResult<()> {
619 copyto_where(self, src, mask)
620 }
621}
622
#[cfg(test)]
mod tests {
    //! Tests for array arithmetic, broadcasting, in-place ops and `copyto`.

    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3, IxDyn};

    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // ---- array op array, identical shapes ----

    #[test]
    fn test_add_ref_ref() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = (&a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_owned_owned() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
    }

    #[test]
    fn test_add_mixed() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);

        let d = arr(vec![10.0, 20.0]);
        let e = (&b + d).unwrap();
        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
    }

    #[test]
    fn test_sub() {
        let a = arr(vec![5.0, 7.0]);
        let b = arr(vec![1.0, 2.0]);
        let c = (&a - &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
    }

    #[test]
    fn test_mul() {
        let a = arr(vec![2.0, 3.0]);
        let b = arr(vec![4.0, 5.0]);
        let c = (&a * &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = arr(vec![10.0, 20.0]);
        let b = arr(vec![2.0, 5.0]);
        let c = (&a / &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
    }

    #[test]
    fn test_rem() {
        let a = arr_i32(vec![7, 10]);
        let b = arr_i32(vec![3, 4]);
        let c = (&a % &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
    }

    #[test]
    fn test_neg() {
        let a = arr(vec![1.0, -2.0, 3.0]);
        let b = (-&a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_owned() {
        let a = arr(vec![1.0, -2.0]);
        let b = (-a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        let result = &a + &b;
        assert!(result.is_err());
    }

    // ---- array op scalar ----

    #[test]
    fn test_add_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sub_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a - 5.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 15.0, 25.0]);
    }

    #[test]
    fn test_mul_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a * 3.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_div_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a / 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_rem_scalar() {
        let a = arr_i32(vec![7, 10, 15]);
        let c = (&a % 4).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3, 2, 3]);
    }

    #[test]
    fn test_scalar_op_owned() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_chained_ops() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = arr(vec![10.0, 10.0, 10.0]);
        let result = (&(&a + &b).unwrap() * &c).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
    }

    #[test]
    fn scalar_and_neg_preserve_shape_ix2() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, -2.0, 3.0, -4.0]).unwrap();
        let scaled = (&a * 2.0).unwrap();
        assert_eq!(scaled.shape(), &[2, 2]);
        assert_eq!(scaled.as_slice().unwrap(), &[2.0, -4.0, 6.0, -8.0]);
        let negated = (-&a).unwrap();
        assert_eq!(negated.shape(), &[2, 2]);
        assert_eq!(negated.as_slice().unwrap(), &[-1.0, 2.0, -3.0, 4.0]);
    }

    // ---- operator broadcasting (same dimension type) ----

    #[test]
    fn test_broadcast_2d_row_plus_column() {
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let row =
            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = (&col + &row).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[
                11.0, 21.0, 31.0, 41.0, //
                12.0, 22.0, 32.0, 42.0, //
                13.0, 23.0, 33.0, 43.0,
            ]
        );
    }

    #[test]
    fn test_broadcast_2d_stretch_one_axis() {
        let a = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 4]),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![100.0, 200.0, 300.0, 400.0])
            .unwrap();
        let result = (&a + &b).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[
                101.0, 202.0, 303.0, 404.0, 105.0, 206.0, 307.0, 408.0, 109.0, 210.0, 311.0, 412.0,
            ]
        );
    }

    #[test]
    fn test_broadcast_3d_with_2d_axis() {
        let a =
            Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), (1..=24).map(|i| i as f64).collect())
                .unwrap();
        let b =
            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (1..=12).map(|i| i as f64).collect())
                .unwrap();
        let result = (&a - &b).unwrap();
        assert_eq!(result.shape(), &[2, 3, 4]);
        let data = result.as_slice().unwrap();
        // First block subtracts itself; the second block is offset by exactly 12.
        assert_eq!(&data[..12], &vec![0.0; 12][..]);
        assert_eq!(&data[12..], &vec![12.0; 12][..]);
    }

    #[test]
    fn test_broadcast_incompatible_shapes_error() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let result = &a + &b;
        assert!(result.is_err());

        let c = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![0.0; 12]).unwrap();
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([3, 5]), vec![0.0; 15]).unwrap();
        assert!((&c + &d).is_err());
    }

    #[test]
    fn test_broadcast_mul_2d() {
        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
        let result = (&col * &row).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10, 20, 30, 20, 40, 60, 30, 60, 90]
        );
    }

    #[test]
    fn dyn_operator_broadcasts_across_ranks() {
        let m = Array::<f64, IxDyn>::from_vec(
            IxDyn::from(&[2usize, 3][..]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        .unwrap();
        let v = Array::<f64, IxDyn>::from_vec(IxDyn::from(&[3usize][..]), vec![10.0, 20.0, 30.0])
            .unwrap();
        let result = (&m + &v).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]
        );
    }

    // ---- cross-dimension *_broadcast methods ----

    #[test]
    fn test_add_broadcast_1d_plus_2d() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let result = v.add_broadcast(&m).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 22.0, 33.0, 41.0, 52.0, 63.0]
        );
    }

    #[test]
    fn test_add_broadcast_1d_plus_column() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result = v.add_broadcast(&col).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 12.0, 13.0, 21.0, 22.0, 23.0]
        );
    }

    #[test]
    fn add_broadcast_same_shape_returns_dyn() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![10.0, 20.0, 30.0]);
        let result: Array<f64, IxDyn> = a.add_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub_broadcast_2d_minus_1d() {
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let result = m.sub_broadcast(&v).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
        );
    }

    #[test]
    fn test_mul_broadcast_returns_dyn() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result: Array<f64, IxDyn> = a.mul_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
        );
    }

    #[test]
    fn test_div_broadcast_incompatible() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(a.div_broadcast(&b).is_err());
    }

    #[test]
    fn test_rem_broadcast_2d() {
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![3, 7, 11]).unwrap();
        let result = a.rem_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10 % 3, 20 % 7, 30 % 11, 40 % 3, 50 % 7, 60 % 11]
        );
    }

    // ---- compound assignment with a scalar ----

    #[test]
    fn scalar_add_assign_mutates_in_place() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        a += 10.0;
        assert_eq!(a.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn scalar_sub_mul_div_rem_assign() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        a -= 1.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        a *= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[18.0, 38.0, 58.0]);
        a /= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        let mut b = arr_i32(vec![10, 11, 12]);
        b %= 3;
        assert_eq!(b.as_slice().unwrap(), &[1, 2, 0]);
    }

    #[test]
    fn scalar_assign_preserves_shape_ix2() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        a += 1.0;
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    // ---- *_inplace array methods ----

    #[test]
    fn add_inplace_same_shape_fast_path() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![10.0, 20.0, 30.0]);
        a.add_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn sub_mul_div_rem_inplace_same_shape() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        a.sub_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        a.mul_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 36.0, 81.0]);
        a.div_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        let mut c = arr_i32(vec![10, 20, 30]);
        let d = arr_i32(vec![3, 7, 11]);
        c.rem_inplace(&d).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 6, 8]);
    }

    #[test]
    fn add_inplace_broadcasts_rhs_into_lhs_shape() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    }

    #[test]
    fn add_inplace_broadcasts_column_into_rows() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![100.0, 200.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(
            a.as_slice().unwrap(),
            &[101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
        );
    }

    #[test]
    fn add_inplace_rejects_incompatible_rhs() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(a.add_inplace(&b).is_err());
        // The destination must be untouched after a rejected operation.
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn add_inplace_rejects_growing_shape() {
        // In-place ops never grow the destination, even when the shapes are
        // mutually broadcast-compatible.
        let mut a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0; 6]).unwrap();
        assert!(a.add_inplace(&b).is_err());
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // ---- copyto / copy_from ----

    #[test]
    fn copyto_same_shape_fast_path() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_broadcasts_row_into_matrix() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(
            dst.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    #[test]
    fn copyto_broadcasts_cross_rank_src() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![7.0, 8.0, 9.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[7.0, 8.0, 9.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn copyto_scalar_src_broadcasts_to_full_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![42.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[42.0; 6]);
    }

    #[test]
    fn copyto_rejects_growing_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![99.0; 6]).unwrap();
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_rejects_incompatible_shapes() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_method_form_equivalent_to_function() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        dst.copy_from(&src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_works_for_i64_elements() {
        // `copyto` only requires `T: Element` (no `Copy` bound); exercise it
        // with an integer element type.
        let mut dst = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![0, 0, 0, 0]).unwrap();
        let src = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1, 2, 3, 4]);
    }

    // ---- copyto_where / copy_from_where ----

    #[test]
    fn copyto_where_same_shape_only_writes_masked_positions() {
        let mut dst = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let src = arr(vec![10.0, 20.0, 30.0, 40.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 4.0]);
    }

    #[test]
    fn copyto_where_broadcasts_mask_across_dst() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 40.0, 5.0, 60.0]);
    }

    #[test]
    fn copyto_where_broadcasts_scalar_src_with_mask() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src = arr(vec![99.0]);
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[99.0, 2.0, 99.0, 4.0, 99.0, 6.0]);
    }

    #[test]
    fn copyto_where_broadcasts_src_and_mask_together() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![0.0; 4]).unwrap();
        let src = arr(vec![5.0, 6.0]);
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 1]), vec![true, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[5.0, 6.0, 0.0, 0.0]);
    }

    #[test]
    fn copyto_where_all_false_mask_is_noop() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let original = dst.as_slice().unwrap().to_vec();
        let src = arr(vec![99.0, 99.0, 99.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &original[..]);
    }

    #[test]
    fn copyto_where_all_true_mask_matches_copyto() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_where_rejects_incompatible_src_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_where_rejects_incompatible_mask_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        // A bad mask must be rejected before any element is written.
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copy_from_where_method_form_equivalent() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        dst.copy_from_where(&src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 20.0, 3.0]);
    }

    #[test]
    fn div_inplace_by_zero_yields_ieee_sentinels() {
        // Float division by zero follows IEEE-754 instead of erroring:
        // x / 0 -> +inf for x > 0, and 0 / 0 -> NaN.
        let mut a = arr(vec![1.0, 2.0, 0.0]);
        let b = arr(vec![2.0, 0.0, 0.0]);
        a.div_inplace(&b).unwrap();
        let s = a.as_slice().unwrap();
        assert_eq!(s[0], 0.5);
        assert!(s[1].is_infinite() && s[1].is_sign_positive());
        assert!(s[2].is_nan());
    }
}