1use crate::*;
2use std::{ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}, rc::Rc};
3
/// Arithmetic, comparison and sign operations implemented by both [`Storage`] and
/// [`Tensor`].
///
/// Each arithmetic operation comes in two flavours:
/// * an out-of-place form (`add_tensor`, `mul_f32`, `abs`, …) that leaves `self`
///   untouched and returns a new value, and
/// * an in-place `*_assign` form that overwrites `self`.
///
/// The out-of-place forms are `#[must_use]`: ignoring their return value discards the
/// whole computation, which almost always means the caller wanted the `*_assign`
/// variant instead.
///
/// NOTE(review): whether the operations are element-wise and which shape/broadcast
/// rules apply is decided by the backend implementations, not by this trait — confirm
/// against the storage backends.
pub trait ArithmeticOps {
    /// Returns `self + other`.
    #[must_use]
    fn add_tensor(&self, other: &Self) -> Self;
    /// Overwrites `self` with `self + other`.
    fn add_tensor_assign(&mut self, other: &Self);

    /// Returns `self - other`.
    #[must_use]
    fn sub_tensor(&self, other: &Self) -> Self;
    /// Overwrites `self` with `self - other`.
    fn sub_tensor_assign(&mut self, other: &Self);

    /// Returns `self * other`.
    #[must_use]
    fn mul_tensor(&self, other: &Self) -> Self;
    /// Overwrites `self` with `self * other`.
    fn mul_tensor_assign(&mut self, other: &Self);

    /// Returns `self / other`.
    #[must_use]
    fn div_tensor(&self, other: &Self) -> Self;
    /// Overwrites `self` with `self / other`.
    fn div_tensor_assign(&mut self, other: &Self);

    /// Returns `self` plus the scalar `other`.
    #[must_use]
    fn add_f32(&self, other: f32) -> Self;
    /// Adds the scalar `other` to `self` in place.
    fn add_f32_assign(&mut self, other: f32);

    /// Returns `self` minus the scalar `other`.
    #[must_use]
    fn sub_f32(&self, other: f32) -> Self;
    /// Subtracts the scalar `other` from `self` in place.
    fn sub_f32_assign(&mut self, other: f32);

    /// Returns `self` multiplied by the scalar `other`.
    #[must_use]
    fn mul_f32(&self, other: f32) -> Self;
    /// Multiplies `self` by the scalar `other` in place.
    fn mul_f32_assign(&mut self, other: f32);

    /// Returns `self` divided by the scalar `other`.
    #[must_use]
    fn div_f32(&self, other: f32) -> Self;
    /// Divides `self` by the scalar `other` in place.
    fn div_f32_assign(&mut self, other: f32);

    /// Returns `self` raised to the power `other`.
    #[must_use]
    fn pow_f32(&self, other: f32) -> Self;
    /// Raises `self` to the power `other` in place.
    fn pow_f32_assign(&mut self, other: f32);

    /// Compares `self > other`.
    ///
    /// NOTE(review): `make_binary` is interpreted by the backend; presumably `true`
    /// yields a 0/1 mask — TODO confirm against the storage implementations.
    #[must_use]
    fn greater_than(&self, other: &Self, make_binary: bool) -> Self;
    /// Compares `self > other` against a scalar; see [`ArithmeticOps::greater_than`]
    /// for the (backend-defined) meaning of `make_binary`.
    #[must_use]
    fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self;
    /// Compares `self < other`; `make_binary` is backend-defined as for
    /// [`ArithmeticOps::greater_than`].
    #[must_use]
    fn less_than(&self, other: &Self, make_binary: bool) -> Self;
    /// Compares `self < other` against a scalar; `make_binary` is backend-defined as
    /// for [`ArithmeticOps::greater_than`].
    #[must_use]
    fn less_than_f32(&self, other: f32, make_binary: bool) -> Self;

    /// Returns the sign of `self` (presumably -1 / 0 / +1 — TODO confirm per backend).
    #[must_use]
    fn sign(&self) -> Self;
    /// Returns the absolute value of `self`.
    #[must_use]
    fn abs(&self) -> Self;
    /// Replaces `self` with its absolute value in place.
    fn abs_assign(&mut self);
}
41
42impl ArithmeticOps for Storage {
43 fn add_tensor(&self, other: &Self) -> Self {
44 match_storage!(binary self, add_tensor, other)
45 }
46
47 fn add_tensor_assign(&mut self, other: &Self) {
48 match_storage_assign!(binary self, add_tensor_assign, other);
49 }
50
51 fn sub_tensor(&self, other: &Self) -> Self {
52 match_storage!(binary self, sub_tensor, other)
53 }
54
55 fn sub_tensor_assign(&mut self, other: &Self) {
56 match_storage_assign!(binary self, sub_tensor_assign, other);
57 }
58
59 fn mul_tensor(&self, other: &Self) -> Self {
60 match_storage!(binary self, mul_tensor, other)
61 }
62
63 fn mul_tensor_assign(&mut self, other: &Self) {
64 match_storage_assign!(binary self, mul_tensor_assign, other);
65 }
66
67 fn div_tensor(&self, other: &Self) -> Self {
68 match_storage!(binary self, div_tensor, other)
69 }
70
71 fn div_tensor_assign(&mut self, other: &Self) {
72 match_storage_assign!(binary self, div_tensor_assign, other);
73 }
74
75 fn add_f32(&self, other: f32) -> Self {
76 match_storage!(unary self, add_f32, other)
77 }
78
79 fn add_f32_assign(&mut self, other: f32) {
80 match_storage_assign!(unary self, add_f32_assign, other);
81 }
82
83 fn sub_f32(&self, other: f32) -> Self {
84 match_storage!(unary self, sub_f32, other)
85 }
86
87 fn sub_f32_assign(&mut self, other: f32) {
88 match_storage_assign!(unary self, sub_f32_assign, other);
89 }
90
91 fn mul_f32(&self, other: f32) -> Self {
92 match_storage!(unary self, mul_f32, other)
93 }
94
95 fn mul_f32_assign(&mut self, other: f32) {
96 match_storage_assign!(unary self, mul_f32_assign, other);
97 }
98
99 fn div_f32(&self, other: f32) -> Self {
100 match_storage!(unary self, div_f32, other)
101 }
102
103 fn div_f32_assign(&mut self, other: f32) {
104 match_storage_assign!(unary self, div_f32_assign, other);
105 }
106
107 fn pow_f32(&self, other: f32) -> Self {
108 match_storage!(unary self, pow_f32, other)
109 }
110
111 fn pow_f32_assign(&mut self, other: f32) {
112 match_storage_assign!(unary self, pow_f32_assign, other);
113 }
114
115 fn greater_than(&self, other: &Self, make_binary: bool) -> Self {
116 match_storage!(binary self, greater_than, other, make_binary)
117 }
118
119 fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self {
120 match_storage!(unary self, greater_than_f32, other, make_binary)
121 }
122
123 fn less_than(&self, other: &Self, make_binary: bool) -> Self {
124 match_storage!(binary self, less_than, other, make_binary)
125 }
126
127 fn less_than_f32(&self, other: f32, make_binary: bool) -> Self {
128 match_storage!(unary self, less_than_f32, other, make_binary)
129 }
130
131 fn sign(&self) -> Self {
132 match_storage!(unary self, sign)
133 }
134
135 fn abs(&self) -> Self {
136 match_storage!(unary self, abs)
137 }
138
139 fn abs_assign(&mut self) {
140 match_storage_assign!(unary self, abs_assign)
141 }
142}
143
144impl ArithmeticOps for Tensor {
145 fn add_tensor(&self, other: &Self) -> Self {
146 let tensor = self.tensor().add_tensor(other.tensor());
148
149 let requires_grad = *self.requires_grad() || *other.requires_grad();
151 let mut result = Tensor::new(tensor, self.device(), requires_grad);
152
153 if requires_grad {
155 result.set_grad_fn(Some(Rc::new(AddGrad::new(
156 self,
157 other,
158 &result
159 ))));
160 }
161
162 result
163 }
164
165 fn sub_tensor(&self, other: &Self) -> Self {
166 let tensor = self.tensor().sub_tensor(other.tensor());
168
169 let requires_grad = *self.requires_grad() || *other.requires_grad();
171 let mut result = Tensor::new(tensor, self.device(), requires_grad);
172
173 if requires_grad {
175 result.set_grad_fn(Some(Rc::new(SubGrad::new(
176 self,
177 other,
178 &result
179 ))));
180 }
181
182 result
183 }
184
185 fn mul_tensor(&self, other: &Self) -> Self {
186 let tensor = self.tensor().mul_tensor(other.tensor());
187
188 let requires_grad = *self.requires_grad() || *other.requires_grad();
189 let mut result = Tensor::new(tensor, self.device(), requires_grad);
190
191 if requires_grad {
192 result.set_grad_fn(Some(Rc::new(MulGrad::new(
193 self,
194 other,
195 &result
196 ))));
197 }
198
199 result
200 }
201
202 fn div_tensor(&self, other: &Self) -> Self {
203 let tensor = self.tensor().div_tensor(other.tensor());
204
205 let requires_grad = *self.requires_grad() || *other.requires_grad();
206 let mut result = Tensor::new(tensor, self.device(), requires_grad);
207
208 if requires_grad {
209 result.set_grad_fn(Some(Rc::new(DivGrad::new(
210 self,
211 other,
212 &result
213 ))));
214 }
215
216 result
217 }
218
219
220 fn pow_f32(&self, other: f32) -> Self {
221 let tensor = self.tensor().pow_f32(other);
222
223 let requires_grad = *self.requires_grad();
224 let mut result = Tensor::new(tensor, self.device(), requires_grad);
225
226 if requires_grad {
227 result.set_grad_fn(Some(Rc::new(PowF32Grad::new(
228 self,
229 other,
230 &result
231 ))));
232 }
233
234 result
235 }
236
237 fn add_f32(&self, other: f32) -> Self {
240 let tensor = self.tensor().add_f32(other);
241
242 let requires_grad = *self.requires_grad();
243 let mut result = Tensor::new(tensor, self.device(), requires_grad);
244
245 if requires_grad {
246 result.set_grad_fn(Some(Rc::new(AddF32Grad::new(
247 self,
248 other,
249 &result
250 ))));
251 }
252
253 result
254 }
255 fn sub_f32(&self, other: f32) -> Self {
256 let tensor = self.tensor().sub_f32(other);
257
258 let requires_grad = *self.requires_grad();
259 let mut result = Tensor::new(tensor, self.device(), requires_grad);
260
261 if requires_grad {
262 result.set_grad_fn(Some(Rc::new(SubF32Grad::new(
263 self,
264 other,
265 &result
266 ))));
267 }
268
269 result
270 }
271
272 fn mul_f32(&self, other: f32) -> Self {
273 let tensor = self.tensor().mul_f32(other);
274
275 let requires_grad = *self.requires_grad();
276 let mut result = Tensor::new(tensor, self.device(), requires_grad);
277
278 if requires_grad {
279 result.set_grad_fn(Some(Rc::new(MulF32Grad::new(
280 self,
281 other,
282 &result
283 ))));
284 }
285
286 result
287 }
288
289 fn div_f32(&self, other: f32) -> Self {
290 let tensor = self.tensor().div_f32(other);
291
292 let requires_grad = *self.requires_grad();
293 let mut result = Tensor::new(tensor, self.device(), requires_grad);
294
295 if requires_grad {
296 result.set_grad_fn(Some(Rc::new(DivF32Grad::new(
297 self,
298 other,
299 &result
300 ))));
301 }
302
303 result
304 }
305
306 fn abs(&self) -> Self {
307 let tensor = self.tensor().abs();
308
309 let requires_grad = *self.requires_grad();
310 let mut result = Tensor::new(tensor, self.device(), requires_grad);
311
312 if requires_grad {
313 result.set_grad_fn(Some(Rc::new(AbsGrad::new(
314 self,
315 &result
316 ))));
317 }
318
319 result
320 }
321
322
323 fn add_tensor_assign(&mut self, other: &Self) {
325 self.tensor_mut().add_tensor_assign(other.tensor());
326 }
327
328 fn sub_tensor_assign(&mut self, other: &Self) {
329 self.tensor_mut().sub_tensor_assign(other.tensor());
330 }
331
332 fn mul_tensor_assign(&mut self, other: &Self) {
333 self.tensor_mut().mul_tensor_assign(other.tensor());
334 }
335
336 fn div_tensor_assign(&mut self, other: &Self) {
337 self.tensor_mut().div_tensor_assign(other.tensor());
338 }
339
340 fn add_f32_assign(&mut self, other: f32) {
341 self.tensor_mut().add_f32_assign(other);
342 }
343
344 fn sub_f32_assign(&mut self, other: f32) {
345 self.tensor_mut().sub_f32_assign(other);
346 }
347
348 fn mul_f32_assign(&mut self, other: f32) {
349 self.tensor_mut().mul_f32_assign(other);
350 }
351
352 fn div_f32_assign(&mut self, other: f32) {
353 self.tensor_mut().div_f32_assign(other);
354 }
355
356 fn pow_f32_assign(&mut self, other: f32) {
357 self.tensor_mut().pow_f32_assign(other);
358 }
359
360 fn abs_assign(&mut self) {
361 self.tensor_mut().abs_assign();
362 }
363
364 fn greater_than(&self, other: &Self, make_binary: bool) -> Self {
365 let tensor = self.tensor().greater_than(other.tensor(), make_binary);
366 Tensor::new(tensor, self.device(), false)
367 }
368
369 fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self {
370 let tensor = self.tensor().greater_than_f32(other, make_binary);
371 Tensor::new(tensor, self.device(), false)
372 }
373
374 fn less_than(&self, other: &Self, make_binary: bool) -> Self {
375 let tensor = self.tensor().less_than(other.tensor(), make_binary);
376 Tensor::new(tensor, self.device(), false)
377 }
378
379 fn less_than_f32(&self, other: f32, make_binary: bool) -> Self {
380 let tensor = self.tensor().less_than_f32(other, make_binary);
381 Tensor::new(tensor, self.device(), false)
382 }
383
384 fn sign(&self) -> Self {
385 let tensor = self.tensor().sign();
386 Tensor::new(tensor, self.device(), false)
387 }
388
389
390}
391
392
/// Implements the borrowed-operand tensor operators for `$type`:
/// `&a + &b`, `&a - &b`, `&a * &b`, `&a / &b` and `a += &b`, `a -= &b`, `a *= &b`,
/// `a /= &b`.
///
/// Each operator forwards to the matching `*_tensor` / `*_tensor_assign` method.
/// `$target` is both the right-hand operand type and the output type. With the
/// [`ArithmeticOps`] signatures (`&Self` operands), `$target` has to be `$type`
/// itself.
macro_rules! impl_binary_ops {
    // Internal rule: one operator impl plus its compound-assignment impl.
    // It is listed first so that ordinary invocations (which never start with `@`)
    // fail on the literal `@` and fall through to the public rule below.
    (@pair $type:ty, $target:ty,
        $op:ident :: $op_fn:ident => $method:ident,
        $assign:ident :: $assign_fn:ident => $assign_method:ident) => {
        impl $op<&$target> for &$type {
            type Output = $target;
            fn $op_fn(self, rhs: &$target) -> Self::Output {
                self.$method(rhs)
            }
        }

        impl $assign<&$target> for $type {
            fn $assign_fn(&mut self, rhs: &$target) {
                self.$assign_method(rhs)
            }
        }
    };

    ($type:ty, $target:ty) => {
        impl_binary_ops!(@pair $type, $target,
            Add::add => add_tensor,
            AddAssign::add_assign => add_tensor_assign);
        impl_binary_ops!(@pair $type, $target,
            Sub::sub => sub_tensor,
            SubAssign::sub_assign => sub_tensor_assign);
        impl_binary_ops!(@pair $type, $target,
            Mul::mul => mul_tensor,
            MulAssign::mul_assign => mul_tensor_assign);
        impl_binary_ops!(@pair $type, $target,
            Div::div => div_tensor,
            DivAssign::div_assign => div_tensor_assign);
    };
}
448
/// Implements the scalar (`f32`) operators for `$type`:
/// `&a + s`, `&a - s`, `&a * s`, `&a / s` (each returning a new `$type`) and
/// `a += s`, `a -= s`, `a *= s`, `a /= s`.
///
/// Each operator forwards to the matching `*_f32` / `*_f32_assign` method.
macro_rules! impl_scalar_ops {
    // Internal rule: one operator impl plus its compound-assignment impl.
    // It is listed first so that ordinary invocations (which never start with `@`)
    // fail on the literal `@` and fall through to the public rule below.
    (@pair $type:ty,
        $op:ident :: $op_fn:ident => $method:ident,
        $assign:ident :: $assign_fn:ident => $assign_method:ident) => {
        impl $op<f32> for &$type {
            type Output = $type;
            fn $op_fn(self, rhs: f32) -> Self::Output {
                self.$method(rhs)
            }
        }

        impl $assign<f32> for $type {
            fn $assign_fn(&mut self, rhs: f32) {
                self.$assign_method(rhs);
            }
        }
    };

    ($type:ty) => {
        impl_scalar_ops!(@pair $type,
            Add::add => add_f32,
            AddAssign::add_assign => add_f32_assign);
        impl_scalar_ops!(@pair $type,
            Sub::sub => sub_f32,
            SubAssign::sub_assign => sub_f32_assign);
        impl_scalar_ops!(@pair $type,
            Mul::mul => mul_f32,
            MulAssign::mul_assign => mul_f32_assign);
        impl_scalar_ops!(@pair $type,
            Div::div => div_f32,
            DivAssign::div_assign => div_f32_assign);
    };
}
506
507impl_binary_ops!(Tensor, Tensor);
508impl_binary_ops!(Storage, Storage);
509
510impl_scalar_ops!(Tensor);
511impl_scalar_ops!(Storage);