1use crate::prelude::{Scalar, TensorExpr};
6use crate::tensor::{from_vec_with_op, TensorBase};
7use acme::ops::binary::BinaryOp;
8use core::ops;
9use num::traits::float::{Float, FloatCore};
10use num::traits::Pow;
11
12#[allow(dead_code)]
13pub(crate) fn broadcast_scalar_op<F, T>(
14 lhs: &TensorBase<T>,
15 rhs: &TensorBase<T>,
16 op: BinaryOp,
17 f: F,
18) -> TensorBase<T>
19where
20 F: Fn(T, T) -> T,
21 T: Copy + Default,
22{
23 let mut lhs = lhs.clone();
24 let mut rhs = rhs.clone();
25 if lhs.is_scalar() {
26 lhs = lhs.broadcast(rhs.shape());
27 }
28 if rhs.is_scalar() {
29 rhs = rhs.broadcast(lhs.shape());
30 }
31 let shape = lhs.shape().clone();
32 let store = lhs
33 .data()
34 .iter()
35 .zip(rhs.data().iter())
36 .map(|(a, b)| f(*a, *b))
37 .collect();
38 let op = TensorExpr::binary(lhs, rhs, op);
39 from_vec_with_op(false, op, shape, store)
40}
41
42fn check_shapes_or_scalar<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>)
43where
44 T: Clone + Default,
45{
46 let is_scalar = lhs.is_scalar() || rhs.is_scalar();
47 debug_assert!(
48 is_scalar || lhs.shape() == rhs.shape(),
49 "Shape Mismatch: {:?} != {:?}",
50 lhs.shape(),
51 rhs.shape()
52 );
53}
54
/// Guard used by the operator impls: `check!(ne: a, b)` panics when `a != b`.
///
/// Each operand expression is evaluated exactly once, and the values that were compared
/// are the ones printed in the panic message.
///
/// # Panics
///
/// Panics with `Shape Mismatch: {lhs:?} != {rhs:?}` when the operands are not equal.
macro_rules! check {
    (ne: $lhs:expr, $rhs:expr) => {
        // Bind through `match` (the same technique `assert_eq!` uses) so that the operand
        // expressions are not expanded a second time inside the panic branch.
        match (&$lhs, &$rhs) {
            (lhs, rhs) => {
                if *lhs != *rhs {
                    panic!("Shape Mismatch: {:?} != {:?}", lhs, rhs);
                }
            }
        }
    };
}
62
63impl<T> TensorBase<T>
64where
65 T: Scalar,
66{
67 pub fn apply_binary(&self, other: &Self, op: BinaryOp) -> Self {
68 check_shapes_or_scalar(self, other);
69 let shape = self.shape();
70 let store = self
71 .data()
72 .iter()
73 .zip(other.data().iter())
74 .map(|(a, b)| *a + *b)
75 .collect();
76 let op = TensorExpr::binary(self.clone(), other.clone(), op);
77 from_vec_with_op(false, op, shape, store)
78 }
79
80 pub fn apply_binaryf<F>(&self, other: &Self, op: BinaryOp, f: F) -> Self
81 where
82 F: Fn(T, T) -> T,
83 {
84 check_shapes_or_scalar(self, other);
85 let shape = self.shape();
86 let store = self
87 .data()
88 .iter()
89 .zip(other.data().iter())
90 .map(|(a, b)| f(*a, *b))
91 .collect();
92 let op = TensorExpr::binary(self.clone(), other.clone(), op);
93 from_vec_with_op(false, op, shape, store)
94 }
95}
96
97impl<T> TensorBase<T> {
98 pub fn pow(&self, exp: T) -> Self
99 where
100 T: Copy + Pow<T, Output = T>,
101 {
102 let shape = self.shape();
103 let store = self.data().iter().copied().map(|a| a.pow(exp)).collect();
104 let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
105 from_vec_with_op(false, op, shape, store)
106 }
107
108 pub fn powf(&self, exp: T) -> Self
109 where
110 T: Float,
111 {
112 let shape = self.shape();
113 let store = self.data().iter().copied().map(|a| a.powf(exp)).collect();
114 let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
115 from_vec_with_op(false, op, shape, store)
116 }
117
118 pub fn powi(&self, exp: i32) -> Self
119 where
120 T: FloatCore,
121 {
122 let shape = self.shape();
123 let store = self.data().iter().copied().map(|a| a.powi(exp)).collect();
124 let op = TensorExpr::binary_scalar(self.clone(), T::from(exp).unwrap(), BinaryOp::pow());
125 from_vec_with_op(false, op, shape, store)
126 }
127}
128
129impl<T> Pow<T> for TensorBase<T>
146where
147 T: Copy + Pow<T, Output = T>,
148{
149 type Output = Self;
150
151 fn pow(self, exp: T) -> Self::Output {
152 let shape = self.shape().clone();
153 let store = self.data().iter().map(|a| a.pow(exp)).collect();
154 let op = TensorExpr::binary_scalar(self, exp, BinaryOp::pow());
155 from_vec_with_op(false, op, shape, store)
156 }
157}
158
159impl<'a, T> Pow<T> for &'a TensorBase<T>
160where
161 T: Copy + Pow<T, Output = T>,
162{
163 type Output = TensorBase<T>;
164
165 fn pow(self, exp: T) -> Self::Output {
166 let shape = self.shape().clone();
167 let store = self.data().iter().map(|a| a.pow(exp)).collect();
168 let op = TensorExpr::binary_scalar(self.clone(), exp, BinaryOp::pow());
169 from_vec_with_op(false, op, shape, store)
170 }
171}
172
/// Generates the arithmetic operator impls (`ops::$trait`) for [`TensorBase`].
///
/// Invoke with a comma-separated list of `(Trait, method, operator)` tuples. Each
/// tuple produces:
///
/// * `scalar:` impls — `TensorBase<T> op T` and `&TensorBase<T> op T`, applying the
///   operator between every element and the scalar;
/// * `tensor:` impls — all four owned/borrowed combinations of
///   `TensorBase<T> op TensorBase<T>`, applying the operator pair-wise.
///
/// `$method` is also used as the name of the `BinaryOp` constructor recorded in the
/// result's expression node. Every result keeps the shape of the left operand.
///
/// # Panics
///
/// All generated tensor-tensor impls panic through `check!` (message
/// `Shape Mismatch: … != …`) when the operand shapes differ, so the diagnostic is the
/// same whichever owned/borrowed combination is used.
macro_rules! impl_binary_op {
    ($(($trait:ident, $method:ident, $op:tt)),*) => {
        $( impl_binary_op!($trait, $method, $op); )*
    };
    ($trait:ident, $method:ident, $op:tt) => {
        impl_binary_op!(scalar: $trait, $method, $op);
        impl_binary_op!(tensor: $trait, $method, $op);
    };
    (scalar: $trait:ident, $method:ident, $op:tt) => {
        impl<T> ops::$trait<T> for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = Self;

            fn $method(self, other: T) -> Self::Output {
                // Clone the shape before `self` is moved into the expression node.
                let shape = self.shape().clone();
                let store = self.data().iter().map(|a| *a $op other).collect();
                let op = TensorExpr::binary_scalar(self, other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<T> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: T) -> Self::Output {
                let shape = self.shape().clone();
                let store = self.data().iter().map(|a| *a $op other).collect();
                let op = TensorExpr::binary_scalar(self.clone(), other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }
    };
    (tensor: $trait:ident, $method:ident, $op:tt) => {
        impl<T> ops::$trait for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = Self;

            fn $method(self, other: Self) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self, other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<&'a TensorBase<T>> for TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: &'a TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self, other.clone(), BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, T> ops::$trait<TensorBase<T>> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self.clone(), other, BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }

        impl<'a, 'b, T> ops::$trait<&'b TensorBase<T>> for &'a TensorBase<T>
        where
            T: Copy + ops::$trait<Output = T>,
        {
            type Output = TensorBase<T>;

            fn $method(self, other: &'b TensorBase<T>) -> Self::Output {
                check!(ne: self.shape(), other.shape());
                let shape = self.shape().clone();
                let store = self.data().iter().zip(other.data().iter()).map(|(a, b)| *a $op *b).collect();
                let op = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$method());
                from_vec_with_op(false, op, shape, store)
            }
        }
    };
}
280
/// Generates the compound-assignment operator impls for [`TensorBase`].
///
/// Arguments: the assignment trait (`AddAssign`), its method (`add_assign`), the name
/// of the `BinaryOp` constructor to record (`add`), the underlying arithmetic trait
/// required of the element type (`Add`), and the operator token (`+`).
///
/// Two impls are produced, one taking the right operand by value and one by reference.
/// Both compute the values pair-wise, build an expression node from a clone of the old
/// `self` and the right operand, and then overwrite `*self` with the new tensor.
///
/// # Panics
///
/// The generated methods panic through `check!` when the operand shapes differ.
macro_rules! impl_assign_op {
    ($trait:ident, $method:ident, $constructor:ident, $inner:ident, $op:tt) => {
        impl<T> ops::$trait for TensorBase<T>
        where
            T: Copy + ops::$inner<T, Output = T>,
        {
            fn $method(&mut self, other: Self) {
                check!(ne: self.shape(), other.shape());
                let dims = self.shape().clone();
                let values = self
                    .data()
                    .iter()
                    .copied()
                    .zip(other.data().iter().copied())
                    .map(|(a, b)| a $op b)
                    .collect();
                let expr = TensorExpr::binary(self.clone(), other, BinaryOp::$constructor());

                *self = from_vec_with_op(false, expr, dims, values);
            }
        }

        impl<'a, T> ops::$trait<&'a TensorBase<T>> for TensorBase<T>
        where
            T: Copy + ops::$inner<Output = T>,
        {
            fn $method(&mut self, other: &'a TensorBase<T>) {
                check!(ne: self.shape(), other.shape());
                let dims = self.shape().clone();
                let values = self
                    .data()
                    .iter()
                    .copied()
                    .zip(other.data().iter().copied())
                    .map(|(a, b)| a $op b)
                    .collect();
                let expr = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$constructor());

                *self = from_vec_with_op(false, expr, dims, values);
            }
        }
    };
}
313
/// Generates inherent binary methods; intended for use inside an
/// `impl<T> TensorBase<T>` block (the `scalar:` arm refers to the type parameter `T`).
///
/// * `(name, f)` — a method `name(&self, other: &Self)` that forwards to the callable `f`.
/// * `(scalar: ctor, name, op)` — a method `name(&self, other: T)` applying `op` between
///   every element and the scalar, recording `BinaryOp::ctor()`.
/// * `(tensor: name, op)` — a method `name(&self, other: &Self)` applying `op` pair-wise,
///   recording `BinaryOp::name()`; it panics through `check!` when the shapes differ.
macro_rules! impl_binary_method {
    ($method:ident, $f:expr) => {
        pub fn $method(&self, other: &Self) -> Self {
            $f(self, other)
        }
    };
    (scalar: $variant:tt, $method:ident, $op:tt) => {
        pub fn $method(&self, other: T) -> Self {
            let values = self
                .data()
                .iter()
                .copied()
                .map(|elem| elem $op other)
                .collect();
            let expr = TensorExpr::binary_scalar(self.clone(), other, BinaryOp::$variant());
            from_vec_with_op(false, expr, self.shape(), values)
        }
    };
    (tensor: $method:ident, $op:tt) => {
        pub fn $method(&self, other: &Self) -> Self {
            check!(ne: self.shape(), other.shape());
            let values = self
                .data()
                .iter()
                .copied()
                .zip(other.data().iter().copied())
                .map(|(a, b)| a $op b)
                .collect();
            let expr = TensorExpr::binary(self.clone(), other.clone(), BinaryOp::$method());
            from_vec_with_op(false, expr, self.shape(), values)
        }
    };
}
341
342impl_binary_op!((Add, add, +), (Div, div, /), (Mul, mul, *), (Rem, rem, %), (Sub, sub, -));
343
344impl_assign_op!(AddAssign, add_assign, add, Add, +);
345impl_assign_op!(DivAssign, div_assign, div, Div, /);
346impl_assign_op!(MulAssign, mul_assign, mul, Mul, *);
347impl_assign_op!(RemAssign, rem_assign, rem, Rem, %);
348impl_assign_op!(SubAssign, sub_assign, sub, Sub, -);
349
350impl<T> TensorBase<T>
351where
352 T: Scalar,
353{
354 impl_binary_method!(tensor: add, +);
355 impl_binary_method!(scalar: add, add_scalar, +);
356}