// redstone_ml/ndarray/binary_ops.rs

use std::ops::{Add, BitAnd, BitOr, Div, Mul, Rem, Shl, Shr, Sub};

use paste::paste;

use crate::broadcast::broadcast_shapes;
use crate::broadcast::broadcast_stride;
use crate::common::constructors::Constructors;
use crate::ops::binary_op_add::BinaryOpAdd;
use crate::ops::binary_op_div::BinaryOpDiv;
use crate::ops::binary_op_mul::BinaryOpMul;
use crate::ops::binary_op_sub::BinaryOpSub;
use crate::ops::binary_ops::*;
use crate::{NdArray, RawDataType, StridedMemory};

/// Implements `std::ops` binary operator traits for [`NdArray`].
///
/// Each entry has the form `OpTrait, KernelTrait, operator, method;` where:
///
/// * `OpTrait` is the `std::ops` trait being implemented (e.g. `Add`);
/// * `KernelTrait` is the crate's element-wise kernel trait (e.g. `BinaryOpAdd`). It must
///   provide `method` (array op array, strided inputs) and `method_scalar` (array op scalar);
/// * `operator` is the operator token (e.g. `+`), used by the forwarding impls;
/// * `method` is the trait method name (e.g. `add`).
///
/// For every entry this generates `OpTrait` impls for:
///
/// * all four owned/borrowed combinations of two arrays. The three that take an owned operand
///   borrow it and forward to `&NdArray op &NdArray`, which broadcasts both operands to a
///   common shape via `broadcast_shapes`;
/// * an owned or borrowed array on the left and a scalar `T` on the right. The owned form
///   forwards to the borrowed one.
///
/// Every impl returns a newly allocated `NdArray<'static, T>` built from a contiguous buffer.
macro_rules! implement_binary_ops {
    ($($binary_op:ident, $binary_op_trait:ident, $operator:tt, $method:ident;)*) => { $(
        impl<T: RawDataType + $binary_op_trait> $binary_op<NdArray<'_, T>> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: NdArray<'_, T>) -> Self::Output { &self $operator &rhs }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<&NdArray<'_, T>> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: &NdArray<'_, T>) -> Self::Output { &self $operator rhs }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<NdArray<'_, T>> for &NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: NdArray<'_, T>) -> Self::Output { self $operator &rhs }
        }

        // The only array-op-array impl that does real work; the three above forward here.
        impl<T: RawDataType + $binary_op_trait> $binary_op<&NdArray<'_, T>> for &NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: &NdArray<'_, T>) -> Self::Output {
                // Common output shape, plus each operand's strides adjusted by
                // `broadcast_stride` for that shape.
                let shape = broadcast_shapes(self.shape(), rhs.shape());
                let lhs_stride = broadcast_stride(self.stride(), &shape, self.shape());
                let rhs_stride = broadcast_stride(rhs.stride(), &shape, rhs.shape());

                // Output buffer: one element per position of the broadcast shape.
                let mut data = vec![T::default(); shape.iter().product()];

                // SAFETY: both input pointers come from arrays borrowed for the whole call, and
                // the output pointer refers to `data`, which holds exactly
                // `shape.iter().product()` elements.
                // NOTE(review): soundness also relies on `broadcast_stride` producing strides
                // that stay inside each operand's buffer for `shape` (not visible here).
                unsafe {
                    <T as $binary_op_trait>::$method(self.ptr(), &lhs_stride,
                                                     rhs.ptr(), &rhs_stride,
                                                     data.as_mut_ptr(), &shape);

                    NdArray::from_contiguous_owned_buffer(shape, data)
                }
            }
        }

        impl<T: RawDataType + $binary_op_trait> $binary_op<T> for NdArray<'_, T> {
            type Output = NdArray<'static, T>;

            fn $method(self, rhs: T) -> Self::Output { &self $operator rhs }
        }

        // `paste!` builds the scalar kernel's name, `<method>_scalar`.
        paste! {
            impl<T: RawDataType + $binary_op_trait> $binary_op<T> for &NdArray<'_, T> {
                type Output = NdArray<'static, T>;

                fn $method(self, rhs: T) -> Self::Output {
                    // The result takes this array's shape and is stored contiguously.
                    let mut data = vec![T::default(); self.size()];

                    // SAFETY: the input pointer, shape and strides all come from `self`, which is
                    // borrowed for the whole call; the output pointer refers to `data`, which
                    // holds `self.size()` elements.
                    // NOTE(review): assumes `size()` is the element count of `shape()` — confirm.
                    unsafe {
                        <T as $binary_op_trait>::[<$method _scalar>](self.ptr(), self.shape(), self.stride(),
                                                                     rhs, data.as_mut_ptr());

                        NdArray::from_contiguous_owned_buffer(self.shape().to_vec(), data)
                    }
                }
            }
        }
    )*};
}
89
90
91implement_binary_ops!(
92    Add, BinaryOpAdd, +, add;
93    Sub, BinaryOpSub, -, sub;
94    Mul, BinaryOpMul, *, mul;
95    Div, BinaryOpDiv, /, div;
96    Rem, BinaryOpRem, %, rem;
97    BitAnd, BinaryOpBitAnd, &, bitand;
98    BitOr, BinaryOpBitOr, |, bitor;
99    Shl, BinaryOpShl, <<, shl;
100    Shr, BinaryOpShr, >>, shr;
101);