1use crate::broadcast::broadcast_shapes;
2use crate::broadcast::broadcast_stride;
3use crate::common::constructors::Constructors;
4use crate::{NdArray, RawDataType, StridedMemory};
5use std::ops::{Add, BitAnd, BitOr, Div, Mul, Rem, Shl, Shr, Sub};
6
7use crate::ops::binary_ops::*;
8use crate::ops::binary_op_add::BinaryOpAdd;
9use crate::ops::binary_op_div::BinaryOpDiv;
10use crate::ops::binary_op_mul::BinaryOpMul;
11use crate::ops::binary_op_sub::BinaryOpSub;
12use paste::paste;
13
/// Implements element-wise binary operator traits (`Add`, `Sub`, `BitAnd`, ...) for [`NdArray`].
///
/// Each entry has the form `StdTrait, KernelTrait, operator_token, method;`:
///
/// * `StdTrait` — the `std::ops` trait to implement (e.g. `Add`);
/// * `KernelTrait` — the crate's kernel trait; it must provide an associated
///   `method` (array ⊕ array) and `method_scalar` (array ⊕ scalar);
/// * `operator_token` — the infix operator token (e.g. `+`);
/// * `method` — the trait method name (e.g. `add`).
///
/// For every entry the following impls are generated, each returning a newly
/// allocated `NdArray<'static, T>` built from a contiguous owned buffer:
///
/// * `NdArray op NdArray`, `NdArray op &NdArray`, `&NdArray op NdArray` — forward to
///   the `&NdArray op &NdArray` impl;
/// * `&NdArray op &NdArray` — computes a broadcast output shape and runs the kernel;
/// * `NdArray op T` — forwards to the `&NdArray op T` impl;
/// * `&NdArray op T` — runs the `<method>_scalar` kernel over `self`'s shape/stride.
macro_rules! implement_binary_ops {
    ($($binary_op:ident, $binary_op_trait:ident, $operator:tt, $method:ident;)*) => { $(
        impl<T: RawDataType + $binary_op_trait> $binary_op<NdArray<'_, T>> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: NdArray<T>) -> Self::Output { &self $operator &rhs }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<&NdArray<'_, T>> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: &NdArray<T>) -> Self::Output { &self $operator rhs }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<NdArray<'_, T>> for &NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: NdArray<T>) -> Self::Output { self $operator &rhs }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<&NdArray<'_, T>> for &NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: &NdArray<T>) -> Self::Output {
                // The output shape is the broadcast of both operand shapes;
                // `broadcast_stride` presumably re-expresses each operand's stride
                // against that output shape (TODO confirm against its docs).
                let shape = broadcast_shapes(self.shape(), rhs.shape());
                let lhs_stride = broadcast_stride(self.stride(), &shape, self.shape());
                let rhs_stride = broadcast_stride(rhs.stride(), &shape, rhs.shape());

                // One output slot per element of the broadcast shape.
                let mut data = vec![T::default(); shape.iter().product()];

                // SAFETY: NOTE(review) — relies on the kernel reading `self`/`rhs`
                // only through the strides computed above for `shape`, and writing at
                // most `shape.iter().product()` elements, which is exactly
                // `data.len()`. Confirm against the kernel trait's contract.
                unsafe {
                    <T as $binary_op_trait>::$method(
                        self.ptr(), &lhs_stride,
                        rhs.ptr(), &rhs_stride,
                        data.as_mut_ptr(), &shape,
                    );

                    NdArray::from_contiguous_owned_buffer(shape, data)
                }
            }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<T> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: T) -> Self::Output { &self $operator rhs }
        }

        // `paste!` is needed here only to form the `<method>_scalar` kernel name.
        paste! {
            impl<T: RawDataType + $binary_op_trait> $binary_op<T> for &NdArray<'_, T> {
                type Output = NdArray<'static, T>;

                fn $method(self, rhs: T) -> Self::Output {
                    // The result has the same shape as `self`, stored contiguously.
                    let mut data = vec![T::default(); self.size()];

                    // SAFETY: NOTE(review) — relies on the scalar kernel reading `self`
                    // through its own shape/stride and writing at most `self.size()`
                    // elements, which is exactly `data.len()`. Confirm against the
                    // kernel trait's contract.
                    unsafe {
                        <T as $binary_op_trait>::[<$method _scalar>](
                            self.ptr(), self.shape(), self.stride(),
                            rhs, data.as_mut_ptr(),
                        );

                        NdArray::from_contiguous_owned_buffer(self.shape().to_vec(), data)
                    }
                }
            }
        }
    )*};
}
89
90
91implement_binary_ops!(
92 Add, BinaryOpAdd, +, add;
93 Sub, BinaryOpSub, -, sub;
94 Mul, BinaryOpMul, *, mul;
95 Div, BinaryOpDiv, /, div;
96 Rem, BinaryOpRem, %, rem;
97 BitAnd, BinaryOpBitAnd, &, bitand;
98 BitOr, BinaryOpBitOr, |, bitor;
99 Shl, BinaryOpShl, <<, shl;
100 Shr, BinaryOpShr, >>, shr;
101);