1use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
// Implement the binary operator trait `$trt` (method `$mth`, operator token
// `$operator`) for two kinds of operands:
//   - every owned/borrowed pairing of arrays;
//   - an array combined with a `ScalarOperand` on the right.
//
// `$iop` (the matching compound-assignment token) is accepted so that call
// sites stay uniform, but the expansion does not use it. `$doc` is the
// operation's name and is spliced into the rustdoc of every generated impl.
//
// Array-with-array impls broadcast both operands to the common shape chosen by
// `DimMax`, and panic (through `unwrap`) when that fails. An impl that consumes
// an owned array writes the result into that array's buffer when the buffer
// already has the result shape. Otherwise it allocates a new array.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`, and return the result.
///
/// `self` is consumed. Its buffer is reused for the result when it already
/// has the broadcast shape.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        // Only `self`'s buffer is a reuse candidate; `rhs` is borrowed and then dropped.
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the reference `rhs`, and return the result.
///
/// `self` is consumed. Its buffer is reused for the result when it already
/// has the broadcast shape.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        self.$mth(&**rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the reference `rhs`, and return the result.
///
/// `self` is consumed. Its buffer is reused for the result when it already
/// has the broadcast shape.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, D, E> $trt<&'a ArrayRef<B, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: &ArrayRef<B, E>) -> Self::Output
    {
        // Identical shapes: update `self` in place. The rank test is implied by
        // the shape test. It is kept as a cheap first check; it can presumably be
        // folded away when `D` and `E` are fixed, different ranks.
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // Only `rhs` was stretched; `self` already has the result shape.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // `self` is too small for the result: collect into a new buffer.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the owned `rhs`, and return the result.
///
/// `rhs` is consumed. Its buffer is reused for the result when it already
/// has the broadcast shape.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        (&**self).$mth(rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the owned `rhs`, and return the result.
///
/// `rhs` is consumed. Its buffer is reused for the result when it already
/// has the broadcast shape.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayRef<A, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        // The result is written into `rhs`. `clone_iopf_rev` still evaluates
        // `lhs op rhs`, so operand order is preserved for non-commutative ops.
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // Only `self` was stretched; `rhs` already has the result shape.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the references `self` and `rhs`, and return the result as a new `Array`.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        (&**self).$mth(&**rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the references `self` and `rhs`, and return the result as a new `Array`.
///
/// If their shapes disagree, both operands are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, D, E> $trt<&'a ArrayRef<B, E>> for &'a ArrayRef<A, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;

    #[track_caller]
    fn $mth(self, rhs: &'a ArrayRef<B, E>) -> Self::Output {
        // Equal shapes only need a change of dimension type; otherwise broadcast both.
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`, and return the result (based on `self`).
///
/// `self` is updated in place and returned.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`, and return the result as a
/// new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;

    fn $mth(self, x: B) -> Self::Output {
        (&**self).$mth(x)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`, and return the result as a
/// new `Array`.
impl<'a, A, D, B> $trt<B> for &'a ArrayRef<A, D>
    where A: Clone + $trt<B, Output=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;

    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);

// Compile-time branch used by `impl_scalar_lhs_op!`.
// The flag token selects which expression is emitted:
//   - `Commute`: the first expression;
//   - `Ordered`: the second expression.
// Any other flag token fails to match.
macro_rules! if_commutative {
    (Ordered { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_ordered
    };
    (Commute { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_commutative
    };
}

// Implement `scalar op array` for one concrete scalar type `$scalar`.
//
// `$commutative` is `Commute` or `Ordered` (see `if_commutative!`):
//   - `Commute`: the work is forwarded to the matching `array op scalar` impl;
//   - `Ordered`: the scalar stays on the left of each element.
// The ordered paths copy elements out by value, so `$scalar` must be `Copy`.
// `$doc` is accepted for call-site symmetry but is not emitted, so these impls
// carry no rustdoc text.
macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// Owned array on the right: its buffer is updated in place and returned.
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
    where S: DataOwned<Elem=$scalar> + DataMut,
          D: Dimension,
{
    type Output = ArrayBase<S, D>;

    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            let mut arr = rhs;
            arr.map_inplace(move |x| *x = self $operator *x);
            arr
        }})
    }
}

// Borrowed `ArrayBase` on the right: forward to the `&ArrayRef` impl below.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
    where S: Data<Elem=$scalar>,
          D: Dimension,
{
    type Output = Array<$scalar, D>;

    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        let view: &ArrayRef<$scalar, D> = &**rhs;
        self.$mth(view)
    }
}

// Borrowed `ArrayRef` on the right: the result is a newly mapped `Array`.
impl<'a, D> $trt<&'a ArrayRef<$scalar, D>> for $scalar
    where D: Dimension
{
    type Output = Array<$scalar, D>;

    fn $mth(self, rhs: &ArrayRef<$scalar, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            rhs.map(move |&x| self $operator x)
        })
    }
}
    );
}

407mod arithmetic_ops
408{
409 use super::*;
410 use crate::imp_prelude::*;
411
412 use std::ops::*;
413
414 fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C
415 {
416 move |x, y| f(x.clone(), y.clone())
417 }
418
419 fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B)
420 {
421 move |x, y| *x = f(x.clone(), y.clone())
422 }
423
424 fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A)
425 {
426 move |x, y| *x = f(y.clone(), x.clone())
427 }
428
429 impl_binary_op!(Add, +, add, +=, "addition");
430 impl_binary_op!(Sub, -, sub, -=, "subtraction");
431 impl_binary_op!(Mul, *, mul, *=, "multiplication");
432 impl_binary_op!(Div, /, div, /=, "division");
433 impl_binary_op!(Rem, %, rem, %=, "remainder");
434 impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
435 impl_binary_op!(BitOr, |, bitor, |=, "bit or");
436 impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
437 impl_binary_op!(Shl, <<, shl, <<=, "left shift");
438 impl_binary_op!(Shr, >>, shr, >>=, "right shift");
439
440 macro_rules! all_scalar_ops {
441 ($int_scalar:ty) => (
442 impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
443 impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
444 impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
445 impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
446 impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
447 impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
448 impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
449 impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
450 impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
451 impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
452 );
453 }
454 all_scalar_ops!(i8);
455 all_scalar_ops!(u8);
456 all_scalar_ops!(i16);
457 all_scalar_ops!(u16);
458 all_scalar_ops!(i32);
459 all_scalar_ops!(u32);
460 all_scalar_ops!(i64);
461 all_scalar_ops!(u64);
462 all_scalar_ops!(isize);
463 all_scalar_ops!(usize);
464 all_scalar_ops!(i128);
465 all_scalar_ops!(u128);
466
467 impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
468 impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
469 impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
470
471 impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
472 impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
473 impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
474 impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
475 impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
476
477 impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
478 impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
479 impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
480 impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
481 impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
482
483 impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
484 impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
485 impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
486 impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
487
488 impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
489 impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
490 impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
491 impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
492
493 impl<A, S, D> Neg for ArrayBase<S, D>
494 where
495 A: Clone + Neg<Output = A>,
496 S: DataOwned<Elem = A> + DataMut,
497 D: Dimension,
498 {
499 type Output = Self;
500
501 fn neg(mut self) -> Self
503 {
504 self.map_inplace(|elt| {
505 *elt = -elt.clone();
506 });
507 self
508 }
509 }
510
511 impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
512 where
513 &'a A: 'a + Neg<Output = A>,
514 S: Data<Elem = A>,
515 D: Dimension,
516 {
517 type Output = Array<A, D>;
518
519 fn neg(self) -> Array<A, D>
522 {
523 (&**self).neg()
524 }
525 }
526
527 impl<'a, A, D> Neg for &'a ArrayRef<A, D>
528 where
529 &'a A: 'a + Neg<Output = A>,
530 D: Dimension,
531 {
532 type Output = Array<A, D>;
533
534 fn neg(self) -> Array<A, D>
537 {
538 self.map(Neg::neg)
539 }
540 }
541
542 impl<A, S, D> Not for ArrayBase<S, D>
543 where
544 A: Clone + Not<Output = A>,
545 S: DataOwned<Elem = A> + DataMut,
546 D: Dimension,
547 {
548 type Output = Self;
549
550 fn not(mut self) -> Self
552 {
553 self.map_inplace(|elt| {
554 *elt = !elt.clone();
555 });
556 self
557 }
558 }
559
560 impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
561 where
562 &'a A: 'a + Not<Output = A>,
563 S: Data<Elem = A>,
564 D: Dimension,
565 {
566 type Output = Array<A, D>;
567
568 fn not(self) -> Array<A, D>
571 {
572 (&**self).not()
573 }
574 }
575
576 impl<'a, A, D> Not for &'a ArrayRef<A, D>
577 where
578 &'a A: 'a + Not<Output = A>,
579 D: Dimension,
580 {
581 type Output = Array<A, D>;
582
583 fn not(self) -> Array<A, D>
586 {
587 self.map(Not::not)
588 }
589 }
590}
591
592mod assign_ops
593{
594 use super::*;
595 use crate::imp_prelude::*;
596
597 macro_rules! impl_assign_op {
598 ($trt:ident, $method:ident, $doc:expr) => {
599 use std::ops::$trt;
600
601 #[doc=$doc]
602 impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
606 where
607 A: Clone + $trt<A>,
608 S: DataMut<Elem = A>,
609 S2: Data<Elem = A>,
610 D: Dimension,
611 E: Dimension,
612 {
613 #[track_caller]
614 fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
615 (**self).$method(&**rhs)
616 }
617 }
618
619 #[doc=$doc]
620 impl<'a, A, D, E> $trt<&'a ArrayRef<A, E>> for ArrayRef<A, D>
624 where
625 A: Clone + $trt<A>,
626 D: Dimension,
627 E: Dimension,
628 {
629 #[track_caller]
630 fn $method(&mut self, rhs: &ArrayRef<A, E>) {
631 self.zip_mut_with(rhs, |x, y| {
632 x.$method(y.clone());
633 });
634 }
635 }
636
637 #[doc=$doc]
638 impl<A, S, D> $trt<A> for ArrayBase<S, D>
639 where
640 A: ScalarOperand + $trt<A>,
641 S: DataMut<Elem = A>,
642 D: Dimension,
643 {
644 fn $method(&mut self, rhs: A) {
645 (**self).$method(rhs)
646 }
647 }
648
649 #[doc=$doc]
650 impl<A, D> $trt<A> for ArrayRef<A, D>
651 where
652 A: ScalarOperand + $trt<A>,
653 D: Dimension,
654 {
655 fn $method(&mut self, rhs: A) {
656 self.map_inplace(move |elt| {
657 elt.$method(rhs.clone());
658 });
659 }
660 }
661 };
662 }
663
664 impl_assign_op!(
665 AddAssign,
666 add_assign,
667 "Perform `self += rhs` as elementwise addition (in place).\n"
668 );
669 impl_assign_op!(
670 SubAssign,
671 sub_assign,
672 "Perform `self -= rhs` as elementwise subtraction (in place).\n"
673 );
674 impl_assign_op!(
675 MulAssign,
676 mul_assign,
677 "Perform `self *= rhs` as elementwise multiplication (in place).\n"
678 );
679 impl_assign_op!(
680 DivAssign,
681 div_assign,
682 "Perform `self /= rhs` as elementwise division (in place).\n"
683 );
684 impl_assign_op!(
685 RemAssign,
686 rem_assign,
687 "Perform `self %= rhs` as elementwise remainder (in place).\n"
688 );
689 impl_assign_op!(
690 BitAndAssign,
691 bitand_assign,
692 "Perform `self &= rhs` as elementwise bit and (in place).\n"
693 );
694 impl_assign_op!(
695 BitOrAssign,
696 bitor_assign,
697 "Perform `self |= rhs` as elementwise bit or (in place).\n"
698 );
699 impl_assign_op!(
700 BitXorAssign,
701 bitxor_assign,
702 "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
703 );
704 impl_assign_op!(
705 ShlAssign,
706 shl_assign,
707 "Perform `self <<= rhs` as elementwise left shift (in place).\n"
708 );
709 impl_assign_op!(
710 ShrAssign,
711 shr_assign,
712 "Perform `self >>= rhs` as elementwise right shift (in place).\n"
713 );
714}