acme_tensor/impls/ops/binary.rs

1/*
2    Appellation: arith <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::prelude::{Scalar, TensorExpr};
6use crate::tensor::{from_vec_with_op, TensorBase};
7use acme::ops::binary::BinaryOp;
8use core::ops;
9use num::traits::float::{Float, FloatCore};
10use num::traits::Pow;
11
12#[allow(dead_code)]
13pub(crate) fn broadcast_scalar_op<F, T>(
14    lhs: &TensorBase<T>,
15    rhs: &TensorBase<T>,
16    op: BinaryOp,
17    f: F,
18) -> TensorBase<T>
19where
20    F: Fn(T, T) -> T,
21    T: Copy + Default,
22{
23    let mut lhs = lhs.clone();
24    let mut rhs = rhs.clone();
25    if lhs.is_scalar() {
26        lhs = lhs.broadcast(rhs.shape());
27    }
28    if rhs.is_scalar() {
29        rhs = rhs.broadcast(lhs.shape());
30    }
31    let shape = lhs.shape().clone();
32    let store = lhs
33        .data()
34        .iter()
35        .zip(rhs.data().iter())
36        .map(|(a, b)| f(*a, *b))
37        .collect();
38    let op = TensorExpr::binary(lhs, rhs, op);
39    from_vec_with_op(false, op, shape, store)
40}
41
42fn check_shapes_or_scalar<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>)
43where
44    T: Clone + Default,
45{
46    let is_scalar = lhs.is_scalar() || rhs.is_scalar();
47    debug_assert!(
48        is_scalar || lhs.shape() == rhs.shape(),
49        "Shape Mismatch: {:?} != {:?}",
50        lhs.shape(),
51        rhs.shape()
52    );
53}
54
/// Panics with a "Shape Mismatch" message when two values are not equal.
///
/// Usage: `check!(ne: lhs, rhs)` — the `ne:` tag names the failing condition
/// (the operands are *not equal*). Both operands are borrowed, compared with
/// `!=`, and printed with `{:?}` on failure, so any `PartialEq + Debug`
/// values (typically shapes) are accepted.
macro_rules! check {
    (ne: $lhs:expr, $rhs:expr) => {
        match (&$lhs, &$rhs) {
            (lhs, rhs) if lhs != rhs => {
                panic!("Shape Mismatch: {:?} != {:?}", lhs, rhs)
            }
            _ => {}
        }
    };
}
62
63impl<T> TensorBase<T>
64where
65    T: Scalar,
66{
67    pub fn apply_binary(&self, other: &Self, op: BinaryOp) -> Self {
68        check_shapes_or_scalar(self, other);
69        let shape = self.shape();
70        let store = self
71            .data()
72            .iter()
73            .zip(other.data().iter())
74            .map(|(a, b)| *a + *b)
75            .collect();
76        let op = TensorExpr::binary(self.clone(), other.clone(), op);
77        from_vec_with_op(false, op, shape, store)
78    }
79
80    pub fn apply_binaryf<F>(&self, other: &Self, op: BinaryOp, f: F) -> Self
81    where
82        F: Fn(T, T) -> T,
83    {
84        check_shapes_or_scalar(self, other);
85        let shape = self.shape();
86        let store = self
87            .data()
88            .iter()
89            .zip(other.data().iter())
90            .map(|(a, b)| f(*a, *b))
91            .collect();
92        let op = TensorExpr::binary(self.clone(), other.clone(), op);
93        from_vec_with_op(false, op, shape, store)
94    }
95}
96
97impl<T> TensorBase<T> {
98    pub fn pow(&self, exp: T) -> Self
99    where
100        T: Copy + Pow<T, Output = T>,
101    {
102        let shape = self.shape();
103        let store = self.data().iter().copied().map(|a| a.pow(exp)).collect();
104        let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
105        from_vec_with_op(false, op, shape, store)
106    }
107
108    pub fn powf(&self, exp: T) -> Self
109    where
110        T: Float,
111    {
112        let shape = self.shape();
113        let store = self.data().iter().copied().map(|a| a.powf(exp)).collect();
114        let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
115        from_vec_with_op(false, op, shape, store)
116    }
117
118    pub fn powi(&self, exp: i32) -> Self
119    where
120        T: FloatCore,
121    {
122        let shape = self.shape();
123        let store = self.data().iter().copied().map(|a| a.powi(exp)).collect();
124        let op = TensorExpr::binary_scalar(self.clone(), T::from(exp).unwrap(), BinaryOp::pow());
125        from_vec_with_op(false, op, shape, store)
126    }
127}
128
129// impl<T> TensorBase<T> where T: ComplexFloat<Real = T> + Scalar<Complex = Complex<T>, Real = T> {
130
131//     pub fn powc(&self, exp: <T as Scalar>::Complex) -> TensorBase<<T as Scalar>::Complex> {
132//         let shape = self.shape();
133//         let store = self.data().iter().copied().map(|a| Scalar::powc(a, exp)).collect();
134//         let op = TensorExpr::binary_scalar_c(self.clone(), exp, BinaryOp::Pow);
135//         TensorBase {
136//             id: TensorId::new(),
137//             data: store,
138//             kind: TensorKind::default(),
139//             layout: Layout::contiguous(shape),
140//             op: BackpropOp::new(op)
141//         }
142//     }
143// }
144
145impl<T> Pow<T> for TensorBase<T>
146where
147    T: Copy + Pow<T, Output = T>,
148{
149    type Output = Self;
150
151    fn pow(self, exp: T) -> Self::Output {
152        let shape = self.shape().clone();
153        let store = self.data().iter().map(|a| a.pow(exp)).collect();
154        let op = TensorExpr::binary_scalar(self, exp, BinaryOp::pow());
155        from_vec_with_op(false, op, shape, store)
156    }
157}
158
159impl<'a, T> Pow<T> for &'a TensorBase<T>
160where
161    T: Copy + Pow<T, Output = T>,
162{
163    type Output = TensorBase<T>;
164
165    fn pow(self, exp: T) -> Self::Output {
166        let shape = self.shape().clone();
167        let store = self.data().iter().map(|a| a.pow(exp)).collect();
168        let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
169        from_vec_with_op(false, op, shape, store)
170    }
171}
172
/// Implements the `core::ops` arithmetic trait `$trait` for [`TensorBase`].
///
/// Accepted forms:
///
/// * `impl_binary_op!((Trait, method, op), ...)` — repeats the next form for each tuple.
/// * `impl_binary_op!(Trait, method, op)` — emits both the `scalar:` and `tensor:` forms.
/// * `impl_binary_op!(scalar: Trait, method, op)` — `TensorBase<T> op T` for owned and
///   borrowed tensors, applying `op` between every element and the scalar.
/// * `impl_binary_op!(tensor: Trait, method, op)` — `TensorBase<T> op TensorBase<T>` for
///   all four owned/borrowed operand combinations, pairing elements positionally.
///
/// Every generated impl records `BinaryOp::method()` and its operands in the
/// result's expression.
///
/// # Panics
///
/// The tensor forms panic via `check!` ("Shape Mismatch: ... != ...") when the
/// operand shapes differ.
macro_rules! impl_binary_op {
    ($(($trait:ident, $method:ident, $op:tt)),*) => {
        $( impl_binary_op!($trait, $method, $op); )*
    };
    ($trait:ident, $method:ident, $op:tt) => {
        impl_binary_op!(scalar: $trait, $method, $op);
        impl_binary_op!(tensor: $trait, $method, $op);
    };
    (scalar: $trait:ident, $method:ident, $op:tt) => {

        impl<T> ops::$trait<T> for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = Self;

            fn $method(self, other: T) -> Self::Output {
                let shape = self.shape().clone();
                let store = self.data().iter().map(|a| *a $op other).collect();
                let op = TensorExpr::binary_scalar(self, other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<T> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: T) -> Self::Output {
                let shape = self.shape().clone();
                let store = self.data().iter().map(|a| *a $op other).collect();
                let op = TensorExpr::binary_scalar(self.clone(), other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }
    };
    (tensor: $trait:ident, $method:ident, $op:tt) => {
        // All four ownership combinations validate shapes with the same
        // `check!`, so a mismatch always reports both shapes in one format.
        impl<T> ops::$trait for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = Self;

            fn $method(self, other: Self) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self, other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<&'a TensorBase<T>> for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: &'a TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self, other.clone(), BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<TensorBase<T>> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self.clone(), other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, 'b, T> ops::$trait<&'b TensorBase<T>> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: &'b TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }
    };
}
280
/// Implements the compound-assignment trait `$trait` (e.g. `AddAssign`) for
/// [`TensorBase`], with either an owned or a borrowed right-hand tensor.
///
/// Arguments, in order: the assignment trait, its method, the `BinaryOp`
/// constructor recorded in the expression, the element-level trait `$inner`
/// (e.g. `Add`), and the element operator token.
///
/// The receiver is replaced by a freshly built tensor whose expression holds a
/// clone of the receiver's previous value and the right-hand operand. Panics
/// via `check!` when the shapes differ.
macro_rules! impl_assign_op {
    ($trait:ident, $method:ident, $constructor:ident, $inner:ident, $op:tt) => {
        impl<T> ops::$trait for TensorBase<T>
        where
            T: Copy + ops::$inner<Output = T>,
        {
            fn $method(&mut self, other: Self) {
                check!(ne: self.shape(), other.shape());
                let dims = self.shape().clone();
                let values = self
                    .data()
                    .iter()
                    .zip(other.data().iter())
                    .map(|(&x, &y)| x $op y)
                    .collect();
                let expr = TensorExpr::binary(self.clone(), other, BinaryOp::$constructor());
                *self = from_vec_with_op(false, expr, dims, values);
            }
        }

        impl<'a, T> ops::$trait<&'a TensorBase<T>> for TensorBase<T>
        where
            T: Copy + ops::$inner<Output = T>,
        {
            fn $method(&mut self, other: &'a TensorBase<T>) {
                check!(ne: self.shape(), other.shape());
                let dims = self.shape().clone();
                let values = self
                    .data()
                    .iter()
                    .zip(other.data().iter())
                    .map(|(&x, &y)| x $op y)
                    .collect();
                let expr =
                    TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$constructor());
                *self = from_vec_with_op(false, expr, dims, values);
            }
        }
    };
}
313
/// Generates inherent elementwise arithmetic methods inside an `impl TensorBase<T>`.
///
/// * `tensor: name, op` — `fn name(&self, other: &Self) -> Self`, combining
///   elements positionally; panics via `check!` on shape mismatch. Records
///   `BinaryOp::name()`.
/// * `scalar: Variant, name, op` — `fn name(&self, other: T) -> Self`, applying
///   `op` between every element and `other`. Records `BinaryOp::Variant()`.
/// * `name, f` — `fn name(&self, other: &Self) -> Self` that forwards to the
///   callable `f(self, other)`.
macro_rules! impl_binary_method {
    (tensor: $method:ident, $op:tt) => {
        pub fn $method(&self, other: &Self) -> Self {
            check!(ne: self.shape(), other.shape());
            let values = self
                .data()
                .iter()
                .zip(other.data().iter())
                .map(|(&x, &y)| x $op y)
                .collect();
            let expr = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$method());
            from_vec_with_op(false, expr, self.shape(), values)
        }
    };
    (scalar: $variant:tt, $method:ident, $op:tt) => {
        pub fn $method(&self, other: T) -> Self {
            let values = self.data().iter().map(|&x| x $op other).collect();
            let expr = TensorExpr::binary_scalar(self.clone(), other, BinaryOp::$variant());
            from_vec_with_op(false, expr, self.shape(), values)
        }
    };
    ($method:ident, $f:expr) => {
        pub fn $method(&self, other: &Self) -> Self {
            $f(self, other)
        }
    };
}
341
342impl_binary_op!((Add, add, +), (Div, div, /), (Mul, mul, *), (Rem, rem, %), (Sub, sub, -));
343
344impl_assign_op!(AddAssign, add_assign, add, Add, +);
345impl_assign_op!(DivAssign, div_assign, div, Div, /);
346impl_assign_op!(MulAssign, mul_assign, mul, Mul, *);
347impl_assign_op!(RemAssign, rem_assign, rem, Rem, %);
348impl_assign_op!(SubAssign, sub_assign, sub, Sub, -);
349
350impl<T> TensorBase<T>
351where
352    T: Scalar,
353{
354    impl_binary_method!(tensor: add, +);
355    impl_binary_method!(scalar: add, add_scalar, +);
356}