redstone_ml/ndarray/
assign_ops.rs

1use crate::ndarray::NdArrayFlags;
2use crate::ops::binary_op_add::BinaryOpAdd;
3use crate::ops::binary_op_div::BinaryOpDiv;
4use crate::ops::binary_op_mul::BinaryOpMul;
5use crate::ops::binary_op_sub::BinaryOpSub;
6use crate::ops::binary_ops::{BinaryOpBitAnd, BinaryOpBitOr, BinaryOpRem, BinaryOpShl, BinaryOpShr};
7use crate::RawDataType;
8use crate::{NdArray, StridedMemory};
9use paste::paste;
10use std::ops::{AddAssign, BitAndAssign, BitOrAssign, DivAssign, MulAssign, RemAssign, ShlAssign, ShrAssign, SubAssign};
11
12
/// Generates the compound-assignment operator impls for `NdArray`.
///
/// One invocation expands to three impls of `$iop_trait` for `NdArray<'_, T>`:
///
/// * right-hand side `NdArray<'_, T>` (by value) - forwards to the by-reference impl
///   through `$operator`;
/// * right-hand side `&NdArray<'_, T>` - calls `<T as $binary_op_trait>::$method`,
///   broadcasting the right-hand side to `self.shape` when the shapes differ;
/// * right-hand side `T` (scalar) - calls `<T as $binary_op_trait>::<$method>_scalar`.
///
/// # Parameters
///
/// * `$binary_op_trait` - trait bound on the element type that supplies the kernels.
/// * `$iop_trait` - the assignment-operator trait being implemented (e.g. `AddAssign`).
/// * `$operator` - the matching operator token (e.g. `+=`).
/// * `$method` - kernel method stem; the generated trait method is `<$method>_assign`.
///
/// # Panics
///
/// The by-reference and scalar impls panic with `"tensor is readonly."` when
/// `self.flags` does not contain `NdArrayFlags::Writeable`. The by-value impl
/// inherits this because it delegates to the by-reference impl.
macro_rules! define_binary_iop {
    ( $binary_op_trait:ident, $iop_trait:ident, $operator:tt, $method:ident ) => {
        paste! {
            // Owned right-hand side: borrow it and reuse the by-reference impl below.
            impl<T: RawDataType + $binary_op_trait> $iop_trait<NdArray<'_, T>> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: NdArray<'_, T>) {
                    *self $operator &rhs;
                }
            }

            // Borrowed right-hand side: the array-with-array kernel, writing into `self`.
            impl<T: RawDataType + $binary_op_trait> $iop_trait<&NdArray<'_, T>> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: &NdArray<'_, T>) {
                    assert!(
                        self.flags.contains(NdArrayFlags::Writeable),
                        "tensor is readonly."
                    );

                    // Shapes differ: expand the right-hand side to our shape first.
                    if rhs.shape() != self.shape() {
                        let expanded = rhs.broadcast_to(&self.shape);

                        // SAFETY (assumption - confirm against the kernel's contract):
                        // `self` supplies both the left operand and the destination, and
                        // `expanded` was produced by broadcasting to `self.shape`.
                        unsafe {
                            <T as $binary_op_trait>::$method(
                                self.ptr(),
                                &self.stride(),
                                expanded.ptr(),
                                &expanded.stride(),
                                self.mut_ptr(),
                                self.shape(),
                            );
                        }
                        return;
                    }

                    // SAFETY (assumption - confirm against the kernel's contract):
                    // `self` supplies both the left operand and the destination, and
                    // `rhs` was just checked to have the same shape as `self`.
                    unsafe {
                        <T as $binary_op_trait>::$method(
                            self.ptr(),
                            &self.stride(),
                            rhs.ptr(),
                            &rhs.stride(),
                            self.mut_ptr(),
                            self.shape(),
                        );
                    }
                }
            }

            // Scalar right-hand side: the array-with-scalar kernel, writing into `self`.
            impl<T: RawDataType + $binary_op_trait> $iop_trait<T> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: T) {
                    assert!(
                        self.flags.contains(NdArrayFlags::Writeable),
                        "tensor is readonly."
                    );

                    // SAFETY (assumption - confirm against the kernel's contract):
                    // source pointer, shape, stride and destination all come from `self`.
                    unsafe {
                        <T as $binary_op_trait>::[<$method _scalar>](
                            self.ptr(),
                            self.shape(),
                            self.stride(),
                            rhs,
                            self.mut_ptr(),
                        );
                    }
                }
            }
        }
    };
}
64
65define_binary_iop!(BinaryOpAdd, AddAssign, +=, add);
66define_binary_iop!(BinaryOpSub, SubAssign, -=, sub);
67define_binary_iop!(BinaryOpMul, MulAssign, *=, mul);
68define_binary_iop!(BinaryOpDiv, DivAssign, /=, div);
69define_binary_iop!(BinaryOpRem, RemAssign, %=, rem);
70define_binary_iop!(BinaryOpBitAnd, BitAndAssign, &=, bitand);
71define_binary_iop!(BinaryOpBitOr, BitOrAssign, |=, bitor);
72define_binary_iop!(BinaryOpShl, ShlAssign, <<=, shl);
73define_binary_iop!(BinaryOpShr, ShrAssign, >>=, shr);