// redstone_ml/ndarray/assign_ops.rs
use std::ops::{
    AddAssign, BitAndAssign, BitOrAssign, DivAssign, MulAssign, RemAssign, ShlAssign, ShrAssign,
    SubAssign,
};

use paste::paste;

use crate::ndarray::NdArrayFlags;
use crate::ops::binary_op_add::BinaryOpAdd;
use crate::ops::binary_op_div::BinaryOpDiv;
use crate::ops::binary_op_mul::BinaryOpMul;
use crate::ops::binary_op_sub::BinaryOpSub;
use crate::ops::binary_ops::{BinaryOpBitAnd, BinaryOpBitOr, BinaryOpRem, BinaryOpShl, BinaryOpShr};
use crate::RawDataType;
use crate::{NdArray, StridedMemory};
11
12
/// Implements a compound-assignment operator (e.g. `+=`) for [`NdArray`].
///
/// Each invocation generates three impls of `$iop_trait`:
/// * `NdArray op= NdArray` — forwards to the `&NdArray` impl.
/// * `NdArray op= &NdArray` — element-wise; when the shapes differ, `rhs` is first
///   passed through `broadcast_to(&self.shape)`.
/// * `NdArray op= T` — applies the scalar `rhs` to every element of `self`.
///
/// Macro arguments:
/// * `$binary_op_trait` — kernel trait providing `$method` (array/array) and
///   `[<$method _scalar>]` (array/scalar).
/// * `$iop_trait` — the `std::ops::*Assign` trait to implement.
/// * `$operator` — the compound operator token (e.g. `+=`), used by the by-value impl to
///   forward to the by-reference impl.
/// * `$method` — kernel method name; the generated trait method is `[<$method _assign>]`.
///
/// # Panics
/// The `&NdArray` and scalar impls panic with `"tensor is readonly."` when `self` lacks
/// [`NdArrayFlags::Writeable`]; the by-value impl inherits this by forwarding.
/// NOTE(review): presumably `broadcast_to` also panics when `rhs` cannot be broadcast to
/// `self`'s shape — confirm.
macro_rules! define_binary_iop {
    ( $binary_op_trait:ident, $iop_trait:ident, $operator:tt, $method:ident ) => {
        paste! {
            impl<T: RawDataType + $binary_op_trait> $iop_trait<NdArray<'_, T>> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: NdArray<'_, T>) {
                    // Forward to the `&NdArray` impl so the writeable check and the
                    // broadcasting logic live in exactly one place.
                    *self $operator &rhs
                }
            }

            impl<T: RawDataType + $binary_op_trait> $iop_trait<&NdArray<'_, T>> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: &NdArray<'_, T>) {
                    if !self.flags.contains(NdArrayFlags::Writeable) {
                        panic!("tensor is readonly.");
                    }

                    // Only build a broadcast view when the shapes differ. `broadcast` is
                    // declared up front so the view outlives the reference taken to it,
                    // letting both paths share the single kernel call below.
                    let broadcast;
                    let rhs = if rhs.shape() == self.shape() {
                        rhs
                    } else {
                        broadcast = rhs.broadcast_to(&self.shape);
                        &broadcast
                    };

                    // SAFETY: `self` was checked writeable above, and `rhs` now has
                    // `self`'s shape (equal, or broadcast to it), so both operands are
                    // walked with `self.shape()`. The output pointer is `self`'s own
                    // buffer, i.e. the kernel is used in place.
                    // NOTE(review): assumes the kernel permits its output to alias the
                    // left-hand input — confirm against `$binary_op_trait::$method`.
                    unsafe {
                        <T as $binary_op_trait>::$method(
                            self.ptr(), &self.stride(),
                            rhs.ptr(), &rhs.stride(),
                            self.mut_ptr(), self.shape(),
                        );
                    }
                }
            }

            impl<T: RawDataType + $binary_op_trait> $iop_trait<T> for NdArray<'_, T> {
                fn [<$method _assign>](&mut self, rhs: T) {
                    if !self.flags.contains(NdArrayFlags::Writeable) {
                        panic!("tensor is readonly.");
                    }

                    // SAFETY: `self` was checked writeable above; the kernel reads and
                    // writes `self`'s buffer using `self`'s own shape and stride.
                    // NOTE(review): assumes the scalar kernel supports in-place output —
                    // confirm against `[<$method _scalar>]`.
                    unsafe {
                        <T as $binary_op_trait>::[<$method _scalar>](
                            self.ptr(), self.shape(), self.stride(),
                            rhs, self.mut_ptr(),
                        );
                    }
                }
            }
        }
    };
}
64
65define_binary_iop!(BinaryOpAdd, AddAssign, +=, add);
66define_binary_iop!(BinaryOpSub, SubAssign, -=, sub);
67define_binary_iop!(BinaryOpMul, MulAssign, *=, mul);
68define_binary_iop!(BinaryOpDiv, DivAssign, /=, div);
69define_binary_iop!(BinaryOpRem, RemAssign, %=, rem);
70define_binary_iop!(BinaryOpBitAnd, BitAndAssign, &=, bitand);
71define_binary_iop!(BinaryOpBitOr, BitOrAssign, |=, bitor);
72define_binary_iop!(BinaryOpShl, ShlAssign, <<=, shl);
73define_binary_iop!(BinaryOpShr, ShrAssign, >>=, shr);