1use crate::array::owned::Array;
19use crate::dimension::Dimension;
20use crate::dimension::IxDyn;
21use crate::dimension::broadcast::{broadcast_shapes, broadcast_to};
22use crate::dtype::Element;
23use crate::error::{FerrayError, FerrayResult};
24
25fn elementwise_binary<T, D, F>(
37 a: &Array<T, D>,
38 b: &Array<T, D>,
39 op: F,
40 op_name: &str,
41) -> FerrayResult<Array<T, D>>
42where
43 T: Element + Copy,
44 D: Dimension,
45 F: Fn(T, T) -> T,
46{
47 if a.shape() == b.shape() {
49 let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
50 return Array::from_vec(a.dim().clone(), data);
51 }
52
53 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
55 FerrayError::shape_mismatch(format!(
56 "operator {}: shapes {:?} and {:?} are not broadcast-compatible",
57 op_name,
58 a.shape(),
59 b.shape()
60 ))
61 })?;
62
63 let a_view = broadcast_to(a, &target_shape)?;
64 let b_view = broadcast_to(b, &target_shape)?;
65
66 let data: Vec<T> = a_view
67 .iter()
68 .zip(b_view.iter())
69 .map(|(&x, &y)| op(x, y))
70 .collect();
71
72 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
73 FerrayError::shape_mismatch(format!(
74 "operator {}: cannot represent broadcast result shape {:?} as the input dimension type",
75 op_name, target_shape
76 ))
77 })?;
78
79 Array::from_vec(result_dim, data)
80}
81
82fn elementwise_binary_dyn<T, D1, D2, F>(
89 a: &Array<T, D1>,
90 b: &Array<T, D2>,
91 op: F,
92 op_name: &str,
93) -> FerrayResult<Array<T, IxDyn>>
94where
95 T: Element + Copy,
96 D1: Dimension,
97 D2: Dimension,
98 F: Fn(T, T) -> T,
99{
100 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
101 FerrayError::shape_mismatch(format!(
102 "{}: shapes {:?} and {:?} are not broadcast-compatible",
103 op_name,
104 a.shape(),
105 b.shape()
106 ))
107 })?;
108
109 let a_view = broadcast_to(a, &target_shape)?;
110 let b_view = broadcast_to(b, &target_shape)?;
111
112 let data: Vec<T> = a_view
113 .iter()
114 .zip(b_view.iter())
115 .map(|(&x, &y)| op(x, y))
116 .collect();
117
118 Array::from_vec(IxDyn::from(&target_shape[..]), data)
119}
120
121macro_rules! impl_binary_op {
129 ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
130 impl<T, D> std::ops::$trait<&Array<T, D>> for &Array<T, D>
132 where
133 T: Element + Copy + std::ops::$trait<Output = T>,
134 D: Dimension,
135 {
136 type Output = FerrayResult<Array<T, D>>;
137
138 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
139 elementwise_binary(self, rhs, $op_fn, $op_name)
140 }
141 }
142
143 impl<T, D> std::ops::$trait<Array<T, D>> for Array<T, D>
145 where
146 T: Element + Copy + std::ops::$trait<Output = T>,
147 D: Dimension,
148 {
149 type Output = FerrayResult<Array<T, D>>;
150
151 fn $method(self, rhs: Array<T, D>) -> Self::Output {
152 elementwise_binary(&self, &rhs, $op_fn, $op_name)
153 }
154 }
155
156 impl<T, D> std::ops::$trait<&Array<T, D>> for Array<T, D>
158 where
159 T: Element + Copy + std::ops::$trait<Output = T>,
160 D: Dimension,
161 {
162 type Output = FerrayResult<Array<T, D>>;
163
164 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
165 elementwise_binary(&self, rhs, $op_fn, $op_name)
166 }
167 }
168
169 impl<T, D> std::ops::$trait<Array<T, D>> for &Array<T, D>
171 where
172 T: Element + Copy + std::ops::$trait<Output = T>,
173 D: Dimension,
174 {
175 type Output = FerrayResult<Array<T, D>>;
176
177 fn $method(self, rhs: Array<T, D>) -> Self::Output {
178 elementwise_binary(self, &rhs, $op_fn, $op_name)
179 }
180 }
181 };
182}
183
184impl_binary_op!(Add, add, |a, b| a + b, "+");
185impl_binary_op!(Sub, sub, |a, b| a - b, "-");
186impl_binary_op!(Mul, mul, |a, b| a * b, "*");
187impl_binary_op!(Div, div, |a, b| a / b, "/");
188impl_binary_op!(Rem, rem, |a, b| a % b, "%");
189
190macro_rules! impl_scalar_op {
196 ($trait:ident, $method:ident, $op_fn:expr) => {
197 impl<T, D> std::ops::$trait<T> for &Array<T, D>
199 where
200 T: Element + Copy + std::ops::$trait<Output = T>,
201 D: Dimension,
202 {
203 type Output = FerrayResult<Array<T, D>>;
204
205 fn $method(self, rhs: T) -> Self::Output {
206 let data: Vec<T> = self.iter().map(|&x| $op_fn(x, rhs)).collect();
207 Array::from_vec(self.dim().clone(), data)
208 }
209 }
210
211 impl<T, D> std::ops::$trait<T> for Array<T, D>
213 where
214 T: Element + Copy + std::ops::$trait<Output = T>,
215 D: Dimension,
216 {
217 type Output = FerrayResult<Array<T, D>>;
218
219 fn $method(self, rhs: T) -> Self::Output {
220 (&self).$method(rhs)
221 }
222 }
223 };
224}
225
226impl_scalar_op!(Add, add, |a, b| a + b);
227impl_scalar_op!(Sub, sub, |a, b| a - b);
228impl_scalar_op!(Mul, mul, |a, b| a * b);
229impl_scalar_op!(Div, div, |a, b| a / b);
230impl_scalar_op!(Rem, rem, |a, b| a % b);
231
232impl<T, D> std::ops::Neg for &Array<T, D>
234where
235 T: Element + Copy + std::ops::Neg<Output = T>,
236 D: Dimension,
237{
238 type Output = FerrayResult<Array<T, D>>;
239
240 fn neg(self) -> Self::Output {
241 let data: Vec<T> = self.iter().map(|&x| -x).collect();
242 Array::from_vec(self.dim().clone(), data)
243 }
244}
245
246impl<T, D> std::ops::Neg for Array<T, D>
247where
248 T: Element + Copy + std::ops::Neg<Output = T>,
249 D: Dimension,
250{
251 type Output = FerrayResult<Array<T, D>>;
252
253 fn neg(self) -> Self::Output {
254 -&self
255 }
256}
257
258impl<T, D> Array<T, D>
269where
270 T: Element + Copy,
271 D: Dimension,
272{
273 pub fn add_broadcast<D2: Dimension>(
282 &self,
283 other: &Array<T, D2>,
284 ) -> FerrayResult<Array<T, IxDyn>>
285 where
286 T: std::ops::Add<Output = T>,
287 {
288 elementwise_binary_dyn(self, other, |x, y| x + y, "add_broadcast")
289 }
290
291 pub fn sub_broadcast<D2: Dimension>(
295 &self,
296 other: &Array<T, D2>,
297 ) -> FerrayResult<Array<T, IxDyn>>
298 where
299 T: std::ops::Sub<Output = T>,
300 {
301 elementwise_binary_dyn(self, other, |x, y| x - y, "sub_broadcast")
302 }
303
304 pub fn mul_broadcast<D2: Dimension>(
308 &self,
309 other: &Array<T, D2>,
310 ) -> FerrayResult<Array<T, IxDyn>>
311 where
312 T: std::ops::Mul<Output = T>,
313 {
314 elementwise_binary_dyn(self, other, |x, y| x * y, "mul_broadcast")
315 }
316
317 pub fn div_broadcast<D2: Dimension>(
321 &self,
322 other: &Array<T, D2>,
323 ) -> FerrayResult<Array<T, IxDyn>>
324 where
325 T: std::ops::Div<Output = T>,
326 {
327 elementwise_binary_dyn(self, other, |x, y| x / y, "div_broadcast")
328 }
329
330 pub fn rem_broadcast<D2: Dimension>(
334 &self,
335 other: &Array<T, D2>,
336 ) -> FerrayResult<Array<T, IxDyn>>
337 where
338 T: std::ops::Rem<Output = T>,
339 {
340 elementwise_binary_dyn(self, other, |x, y| x % y, "rem_broadcast")
341 }
342}
343
344macro_rules! impl_scalar_op_assign {
367 ($trait:ident, $method:ident, $op:tt) => {
368 impl<T, D> std::ops::$trait<T> for Array<T, D>
369 where
370 T: Element + Copy + std::ops::$trait,
371 D: Dimension,
372 {
373 fn $method(&mut self, rhs: T) {
374 self.mapv_inplace(|mut x| {
375 x $op rhs;
376 x
377 });
378 }
379 }
380 };
381}
382
383impl_scalar_op_assign!(AddAssign, add_assign, +=);
384impl_scalar_op_assign!(SubAssign, sub_assign, -=);
385impl_scalar_op_assign!(MulAssign, mul_assign, *=);
386impl_scalar_op_assign!(DivAssign, div_assign, /=);
387impl_scalar_op_assign!(RemAssign, rem_assign, %=);
388
389fn inplace_binary<T, D, F>(
395 lhs: &mut Array<T, D>,
396 rhs: &Array<T, D>,
397 op: F,
398 op_name: &str,
399) -> FerrayResult<()>
400where
401 T: Element + Copy,
402 D: Dimension,
403 F: Fn(T, T) -> T,
404{
405 if lhs.shape() == rhs.shape() {
407 return lhs.zip_mut_with(rhs, |a, b| *a = op(*a, *b));
408 }
409
410 let target_shape: Vec<usize> = lhs.shape().to_vec();
414 let rhs_view = broadcast_to(rhs, &target_shape).map_err(|_| {
415 FerrayError::shape_mismatch(format!(
416 "{}: shape {:?} cannot be broadcast into destination shape {:?}",
417 op_name,
418 rhs.shape(),
419 target_shape
420 ))
421 })?;
422
423 for (a, b) in lhs.iter_mut().zip(rhs_view.iter()) {
424 *a = op(*a, *b);
425 }
426 Ok(())
427}
428
429impl<T, D> Array<T, D>
430where
431 T: Element + Copy,
432 D: Dimension,
433{
434 pub fn add_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
444 where
445 T: std::ops::Add<Output = T>,
446 {
447 inplace_binary(self, other, |a, b| a + b, "add_inplace")
448 }
449
450 pub fn sub_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
452 where
453 T: std::ops::Sub<Output = T>,
454 {
455 inplace_binary(self, other, |a, b| a - b, "sub_inplace")
456 }
457
458 pub fn mul_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
460 where
461 T: std::ops::Mul<Output = T>,
462 {
463 inplace_binary(self, other, |a, b| a * b, "mul_inplace")
464 }
465
466 pub fn div_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
468 where
469 T: std::ops::Div<Output = T>,
470 {
471 inplace_binary(self, other, |a, b| a / b, "div_inplace")
472 }
473
474 pub fn rem_inplace(&mut self, other: &Array<T, D>) -> FerrayResult<()>
476 where
477 T: std::ops::Rem<Output = T>,
478 {
479 inplace_binary(self, other, |a, b| a % b, "rem_inplace")
480 }
481}
482
483pub fn copyto<T, D1, D2>(dst: &mut Array<T, D1>, src: &Array<T, D2>) -> FerrayResult<()>
509where
510 T: Element,
511 D1: Dimension,
512 D2: Dimension,
513{
514 if dst.shape() == src.shape() {
516 for (d, s) in dst.iter_mut().zip(src.iter()) {
517 *d = s.clone();
518 }
519 return Ok(());
520 }
521
522 let target_shape: Vec<usize> = dst.shape().to_vec();
525 let src_view = broadcast_to(src, &target_shape).map_err(|_| {
526 FerrayError::shape_mismatch(format!(
527 "copyto: source shape {:?} cannot be broadcast into destination shape {:?}",
528 src.shape(),
529 target_shape
530 ))
531 })?;
532
533 for (d, s) in dst.iter_mut().zip(src_view.iter()) {
534 *d = s.clone();
535 }
536 Ok(())
537}
538
539pub fn copyto_where<T, D1, D2, D3>(
553 dst: &mut Array<T, D1>,
554 src: &Array<T, D2>,
555 mask: &Array<bool, D3>,
556) -> FerrayResult<()>
557where
558 T: Element,
559 D1: Dimension,
560 D2: Dimension,
561 D3: Dimension,
562{
563 let target_shape: Vec<usize> = dst.shape().to_vec();
566
567 let src_view = broadcast_to(src, &target_shape).map_err(|_| {
568 FerrayError::shape_mismatch(format!(
569 "copyto_where: source shape {:?} cannot be broadcast into destination shape {:?}",
570 src.shape(),
571 target_shape
572 ))
573 })?;
574
575 let mask_view = broadcast_to(mask, &target_shape).map_err(|_| {
576 FerrayError::shape_mismatch(format!(
577 "copyto_where: mask shape {:?} cannot be broadcast into destination shape {:?}",
578 mask.shape(),
579 target_shape
580 ))
581 })?;
582
583 for ((d, s), &m) in dst.iter_mut().zip(src_view.iter()).zip(mask_view.iter()) {
584 if m {
585 *d = s.clone();
586 }
587 }
588 Ok(())
589}
590
591impl<T, D> Array<T, D>
592where
593 T: Element,
594 D: Dimension,
595{
596 pub fn copy_from<D2: Dimension>(&mut self, src: &Array<T, D2>) -> FerrayResult<()> {
604 copyto(self, src)
605 }
606
607 pub fn copy_from_where<D2: Dimension, D3: Dimension>(
616 &mut self,
617 src: &Array<T, D2>,
618 mask: &Array<bool, D3>,
619 ) -> FerrayResult<()> {
620 copyto_where(self, src, mask)
621 }
622}
623
#[cfg(test)]
mod tests {
    // Unit tests for the array arithmetic operators, the cross-dimension
    // `*_broadcast` helpers, the in-place operations, and `copyto` /
    // `copyto_where` (plus their method forms).
    use super::*;
    use crate::dimension::Ix1;

    // Builds a 1-D `f64` array whose length is taken from `data`.
    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // Builds a 1-D `i32` array whose length is taken from `data`.
    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // ---- Array op Array, identical shapes ----------------------------------

    #[test]
    fn test_add_ref_ref() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = (&a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_owned_owned() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
    }

    // Exercises both mixed forms: `Array + &Array` and `&Array + Array`.
    #[test]
    fn test_add_mixed() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);

        let d = arr(vec![10.0, 20.0]);
        let e = (&b + d).unwrap();
        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
    }

    #[test]
    fn test_sub() {
        let a = arr(vec![5.0, 7.0]);
        let b = arr(vec![1.0, 2.0]);
        let c = (&a - &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
    }

    #[test]
    fn test_mul() {
        let a = arr(vec![2.0, 3.0]);
        let b = arr(vec![4.0, 5.0]);
        let c = (&a * &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = arr(vec![10.0, 20.0]);
        let b = arr(vec![2.0, 5.0]);
        let c = (&a / &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
    }

    // Integer remainder: 7 % 3 = 1, 10 % 4 = 2.
    #[test]
    fn test_rem() {
        let a = arr_i32(vec![7, 10]);
        let b = arr_i32(vec![3, 4]);
        let c = (&a % &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
    }

    #[test]
    fn test_neg() {
        let a = arr(vec![1.0, -2.0, 3.0]);
        let b = (-&a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_owned() {
        let a = arr(vec![1.0, -2.0]);
        let b = (-a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
    }

    // Lengths 2 and 3 are not broadcast-compatible, so `+` returns `Err`.
    #[test]
    fn test_shape_mismatch_errors() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        let result = &a + &b;
        assert!(result.is_err());
    }

    // ---- Array op scalar ----------------------------------------------------

    #[test]
    fn test_add_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sub_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a - 5.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 15.0, 25.0]);
    }

    #[test]
    fn test_mul_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a * 3.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_div_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a / 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_rem_scalar() {
        let a = arr_i32(vec![7, 10, 15]);
        let c = (&a % 4).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3, 2, 3]);
    }

    // The owned `Array op scalar` form must match the borrowed form.
    #[test]
    fn test_scalar_op_owned() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    // (a + b) * c, unwrapping the intermediate result.
    #[test]
    fn test_chained_ops() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = arr(vec![10.0, 10.0, 10.0]);
        let result = (&(&a + &b).unwrap() * &c).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
    }

    // ---- Broadcasting through the operator traits ---------------------------

    use crate::dimension::{Ix2, Ix3, IxDyn};

    // [3, 1] + [1, 4] broadcasts to [3, 4] (outer sum).
    #[test]
    fn test_broadcast_2d_row_plus_column() {
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let row =
            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = (&col + &row).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[
                11.0, 21.0, 31.0, 41.0, // 1 + row
                12.0, 22.0, 32.0, 42.0, // 2 + row
                13.0, 23.0, 33.0, 43.0, // 3 + row
            ]
        );
    }

    // A [1, 4] row is repeated across each of the 3 rows of `a`.
    #[test]
    fn test_broadcast_2d_stretch_one_axis() {
        let a = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 4]),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![100.0, 200.0, 300.0, 400.0])
            .unwrap();
        let result = (&a + &b).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[
                101.0, 202.0, 303.0, 404.0, 105.0, 206.0, 307.0, 408.0, 109.0, 210.0, 311.0, 412.0,
            ]
        );
    }

    // `b` ([1, 3, 4], values 1..=12) equals the first [3, 4] slab of `a`, so
    // the first slab of the difference is all 0; the second slab is
    // (13..=24) - (1..=12), i.e. all 12.
    #[test]
    fn test_broadcast_3d_with_2d_axis() {
        let a =
            Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), (1..=24).map(|i| i as f64).collect())
                .unwrap();
        let b =
            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (1..=12).map(|i| i as f64).collect())
                .unwrap();
        let result = (&a - &b).unwrap();
        assert_eq!(result.shape(), &[2, 3, 4]);
        let first_half: Vec<f64> = (1..=12).map(|_| 0.0).collect();
        assert_eq!(&result.as_slice().unwrap()[..12], &first_half[..]);
        let second_half: Vec<f64> = (0..12).map(|_| 12.0).collect();
        assert_eq!(&result.as_slice().unwrap()[12..], &second_half[..]);
    }

    // Mismatched non-1 extents (3 vs 4, and 4 vs 5 on the last axis) must error.
    #[test]
    fn test_broadcast_incompatible_shapes_error() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let result = &a + &b;
        assert!(result.is_err());

        let c = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![0.0; 12]).unwrap();
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([3, 5]), vec![0.0; 15]).unwrap();
        assert!((&c + &d).is_err());
    }

    // [3, 1] * [1, 3] gives the 3x3 outer product.
    #[test]
    fn test_broadcast_mul_2d() {
        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
        let result = (&col * &row).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10, 20, 30, 20, 40, 60, 30, 60, 90]
        );
    }

    // ---- Cross-dimension `*_broadcast` helpers (results are `IxDyn`) -------

    // A 1-D [3] operand is aligned with the trailing axis of a [2, 3] matrix.
    #[test]
    fn test_add_broadcast_1d_plus_2d() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let result = v.add_broadcast(&m).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 22.0, 33.0, 41.0, 52.0, 63.0]
        );
    }

    // [3] with [2, 1] expands in both directions to [2, 3].
    #[test]
    fn test_add_broadcast_1d_plus_column() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result = v.add_broadcast(&col).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 12.0, 13.0, 21.0, 22.0, 23.0]
        );
    }

    #[test]
    fn test_sub_broadcast_2d_minus_1d() {
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let result = m.sub_broadcast(&v).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
        );
    }

    // The explicit `Array<f64, IxDyn>` annotation pins the dynamic result type.
    #[test]
    fn test_mul_broadcast_returns_dyn() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result: Array<f64, IxDyn> = a.mul_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
        );
    }

    #[test]
    fn test_div_broadcast_incompatible() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(a.div_broadcast(&b).is_err());
    }

    #[test]
    fn test_rem_broadcast_2d() {
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![3, 7, 11]).unwrap();
        let result = a.rem_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10 % 3, 20 % 7, 30 % 11, 40 % 3, 50 % 7, 60 % 11]
        );
    }

    // ---- Scalar compound assignment (`+=`, `-=`, ...) -----------------------

    #[test]
    fn scalar_add_assign_mutates_in_place() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        a += 10.0;
        assert_eq!(a.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    // Each assertion observes the cumulative effect of the preceding operators.
    #[test]
    fn scalar_sub_mul_div_rem_assign() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        a -= 1.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        a *= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[18.0, 38.0, 58.0]);
        a /= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        let mut b = arr_i32(vec![10, 11, 12]);
        b %= 3;
        assert_eq!(b.as_slice().unwrap(), &[1, 2, 0]);
    }

    #[test]
    fn scalar_assign_preserves_shape_ix2() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        a += 1.0;
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    // ---- Array-valued in-place operations (`*_inplace`) ---------------------

    #[test]
    fn add_inplace_same_shape_fast_path() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![10.0, 20.0, 30.0]);
        a.add_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    // Each assertion observes the cumulative effect of the preceding calls.
    #[test]
    fn sub_mul_div_rem_inplace_same_shape() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        a.sub_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        a.mul_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 36.0, 81.0]);
        a.div_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        let mut c = arr_i32(vec![10, 20, 30]);
        let d = arr_i32(vec![3, 7, 11]);
        c.rem_inplace(&d).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 6, 8]);
    }

    // A [1, 3] rhs is repeated across both rows of the [2, 3] lhs.
    #[test]
    fn add_inplace_broadcasts_rhs_into_lhs_shape() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    }

    // A [2, 1] rhs adds one value per row of the [2, 3] lhs.
    #[test]
    fn add_inplace_broadcasts_column_into_rows() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![100.0, 200.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(
            a.as_slice().unwrap(),
            &[101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
        );
    }

    // On error the destination must be left untouched.
    #[test]
    fn add_inplace_rejects_incompatible_rhs() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(a.add_inplace(&b).is_err());
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // In-place ops never grow the lhs: [1, 3] += [2, 3] is rejected and `a`
    // is unchanged.
    #[test]
    fn add_inplace_rejects_growing_shape() {
        let mut a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0; 6]).unwrap();
        assert!(a.add_inplace(&b).is_err());
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // ---- copyto / copy_from -------------------------------------------------

    #[test]
    fn copyto_same_shape_fast_path() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_broadcasts_row_into_matrix() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(
            dst.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    // A 1-D source is broadcast into a 2-D destination (different `D` types).
    #[test]
    fn copyto_broadcasts_cross_rank_src() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![7.0, 8.0, 9.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[7.0, 8.0, 9.0, 7.0, 8.0, 9.0]);
    }

    // A single-element source fills every destination position.
    #[test]
    fn copyto_scalar_src_broadcasts_to_full_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![42.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[42.0; 6]);
    }

    // The destination shape is fixed; a larger source is an error and `dst`
    // is not modified.
    #[test]
    fn copyto_rejects_growing_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![99.0; 6]).unwrap();
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_rejects_incompatible_shapes() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_method_form_equivalent_to_function() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        dst.copy_from(&src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // NOTE(review): despite the test name, `i64` is a `Copy` type; this test
    // only exercises `copyto` with a second element type. `copyto` itself
    // requires just `T: Element` and copies via `clone()`.
    #[test]
    fn copyto_works_for_non_copy_element_type_i64() {
        let mut dst = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![0, 0, 0, 0]).unwrap();
        let src = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1, 2, 3, 4]);
    }

    // ---- copyto_where / copy_from_where -------------------------------------

    #[test]
    fn copyto_where_same_shape_only_writes_masked_positions() {
        let mut dst = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let src = arr(vec![10.0, 20.0, 30.0, 40.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 4.0]);
    }

    // A [1, 3] mask selects the same columns (0 and 2) in every row.
    #[test]
    fn copyto_where_broadcasts_mask_across_dst() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 40.0, 5.0, 60.0]);
    }

    // A single-element source is broadcast and written only where the mask is set.
    #[test]
    fn copyto_where_broadcasts_scalar_src_with_mask() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src = arr(vec![99.0]);
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[99.0, 2.0, 99.0, 4.0, 99.0, 6.0]);
    }

    #[test]
    fn copyto_where_all_false_mask_is_noop() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let original = dst.as_slice().unwrap().to_vec();
        let src = arr(vec![99.0, 99.0, 99.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &original[..]);
    }

    #[test]
    fn copyto_where_all_true_mask_matches_copyto() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // On error (src or mask not broadcastable) `dst` must be left untouched.
    #[test]
    fn copyto_where_rejects_incompatible_src_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_where_rejects_incompatible_mask_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copy_from_where_method_form_equivalent() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        dst.copy_from_where(&src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 20.0, 3.0]);
    }

    // Float division by zero does not error: 1/2 = 0.5, 2/0 = +inf, 0/0 = NaN.
    #[test]
    fn div_inplace_by_zero_yields_ieee_sentinels() {
        let mut a = arr(vec![1.0, 2.0, 0.0]);
        let b = arr(vec![2.0, 0.0, 0.0]);
        a.div_inplace(&b).unwrap();
        let s = a.as_slice().unwrap();
        assert_eq!(s[0], 0.5);
        assert!(s[1].is_infinite() && s[1].is_sign_positive());
        assert!(s[2].is_nan());
    }
}