//! ferrite/tensor/ops/arithmetic.rs
//!
//! Element-wise arithmetic for `Storage` and `Tensor`, plus operator-trait
//! (`+`, `-`, `*`, `/` and their assign forms) implementations via macros.

use crate::*;
use std::{ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}, rc::Rc};
3
/// Element-wise arithmetic operations shared by tensor-like containers.
///
/// Implemented in this file for both `Storage` (raw device math, no autograd)
/// and `Tensor` (which additionally installs gradient-function nodes when
/// `requires_grad` is set). Each binary op has a tensor-tensor form, an
/// in-place `_assign` form, and a scalar (`f32`) form.
pub trait ArithmeticOps {
  /// Element-wise addition, returning a new value.
  fn add_tensor(&self, other: &Self) -> Self;
  /// In-place element-wise addition.
  fn add_tensor_assign(&mut self, other: &Self);

  /// Element-wise subtraction, returning a new value.
  fn sub_tensor(&self, other: &Self) -> Self;
  /// In-place element-wise subtraction.
  fn sub_tensor_assign(&mut self, other: &Self);

  /// Element-wise multiplication, returning a new value.
  fn mul_tensor(&self, other: &Self) -> Self;
  /// In-place element-wise multiplication.
  fn mul_tensor_assign(&mut self, other: &Self);

  /// Element-wise division, returning a new value.
  fn div_tensor(&self, other: &Self) -> Self;
  /// In-place element-wise division.
  fn div_tensor_assign(&mut self, other: &Self);

  /// Add a scalar to every element, returning a new value.
  fn add_f32(&self, other: f32) -> Self;
  /// In-place scalar addition.
  fn add_f32_assign(&mut self, other: f32);

  /// Subtract a scalar from every element, returning a new value.
  fn sub_f32(&self, other: f32) -> Self;
  /// In-place scalar subtraction.
  fn sub_f32_assign(&mut self, other: f32);

  /// Multiply every element by a scalar, returning a new value.
  fn mul_f32(&self, other: f32) -> Self;
  /// In-place scalar multiplication.
  fn mul_f32_assign(&mut self, other: f32);

  /// Divide every element by a scalar, returning a new value.
  fn div_f32(&self, other: f32) -> Self;
  /// In-place scalar division.
  fn div_f32_assign(&mut self, other: f32);

  /// Raise every element to the power `other`, returning a new value.
  fn pow_f32(&self, other: f32) -> Self;
  /// In-place element-wise power.
  fn pow_f32_assign(&mut self, other: f32);

  /// Element-wise `self > other` comparison. `make_binary` controls the
  /// output encoding (see the backend implementation for its exact meaning).
  fn greater_than(&self, other: &Self, make_binary: bool) -> Self;
  /// Element-wise comparison against a scalar: `self > other`.
  fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self;
  /// Element-wise `self < other` comparison.
  fn less_than(&self, other: &Self, make_binary: bool) -> Self;
  /// Element-wise comparison against a scalar: `self < other`.
  fn less_than_f32(&self, other: f32, make_binary: bool) -> Self;

  /// Element-wise sign of each element.
  fn sign(&self) -> Self;
  /// Element-wise absolute value, returning a new value.
  fn abs(&self) -> Self;
  /// In-place element-wise absolute value.
  fn abs_assign(&mut self);
}
41
/// `Storage`-level arithmetic: every method dispatches through the
/// `match_storage!` / `match_storage_assign!` macros, which match on the
/// concrete device backend held by this `Storage` and forward the call to
/// that backend's implementation of the same-named method. No autograd
/// bookkeeping happens at this layer.
impl ArithmeticOps for Storage {
  fn add_tensor(&self, other: &Self) -> Self {
    match_storage!(binary self, add_tensor, other)
  }

  fn add_tensor_assign(&mut self, other: &Self) {
    match_storage_assign!(binary self, add_tensor_assign, other);
  }

  fn sub_tensor(&self, other: &Self) -> Self {
    match_storage!(binary self, sub_tensor, other)
  }

  fn sub_tensor_assign(&mut self, other: &Self) {
    match_storage_assign!(binary self, sub_tensor_assign, other);
  }

  fn mul_tensor(&self, other: &Self) -> Self {
    match_storage!(binary self, mul_tensor, other)
  }

  fn mul_tensor_assign(&mut self, other: &Self) {
    match_storage_assign!(binary self, mul_tensor_assign, other);
  }

  fn div_tensor(&self, other: &Self) -> Self {
    match_storage!(binary self, div_tensor, other)
  }

  fn div_tensor_assign(&mut self, other: &Self) {
    match_storage_assign!(binary self, div_tensor_assign, other);
  }

  fn add_f32(&self, other: f32) -> Self {
    match_storage!(unary self, add_f32, other)
  }

  fn add_f32_assign(&mut self, other: f32) {
    match_storage_assign!(unary self, add_f32_assign, other);
  }

  fn sub_f32(&self, other: f32) -> Self {
    match_storage!(unary self, sub_f32, other)
  }

  fn sub_f32_assign(&mut self, other: f32) {
    match_storage_assign!(unary self, sub_f32_assign, other);
  }

  fn mul_f32(&self, other: f32) -> Self {
    match_storage!(unary self, mul_f32, other)
  }

  fn mul_f32_assign(&mut self, other: f32) {
    match_storage_assign!(unary self, mul_f32_assign, other);
  }

  fn div_f32(&self, other: f32) -> Self {
    match_storage!(unary self, div_f32, other)
  }

  fn div_f32_assign(&mut self, other: f32) {
    match_storage_assign!(unary self, div_f32_assign, other);
  }

  fn pow_f32(&self, other: f32) -> Self {
    match_storage!(unary self, pow_f32, other)
  }

  fn pow_f32_assign(&mut self, other: f32) {
    match_storage_assign!(unary self, pow_f32_assign, other);
  }

  // Comparisons pass `make_binary` straight through to the backend.
  fn greater_than(&self, other: &Self, make_binary: bool) -> Self {
    match_storage!(binary self, greater_than, other, make_binary)
  }

  fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self {
    match_storage!(unary self, greater_than_f32, other, make_binary)
  }

  fn less_than(&self, other: &Self, make_binary: bool) -> Self {
    match_storage!(binary self, less_than, other, make_binary)
  }

  fn less_than_f32(&self, other: f32, make_binary: bool) -> Self {
    match_storage!(unary self, less_than_f32, other, make_binary)
  }

  fn sign(&self) -> Self {
    match_storage!(unary self, sign)
  }

  fn abs(&self) -> Self {
    match_storage!(unary self, abs)
  }

  fn abs_assign(&mut self) {
    match_storage_assign!(unary self, abs_assign)
  }
}
143
/// `Tensor`-level arithmetic: delegates the numeric work to the underlying
/// `Storage` (via `self.tensor()`), then wires up autograd. Non-assign ops
/// that are differentiable create a result whose `requires_grad` is inherited
/// from the inputs and attach the matching `*Grad` node; comparison ops and
/// `sign` always return `requires_grad = false` results (non-differentiable).
///
/// NOTE(review): the `_assign` methods mutate storage directly and install no
/// gradient function — in-place ops appear to bypass autograd entirely.
/// Confirm this is intended before using them on tensors that require grad.
impl ArithmeticOps for Tensor {
  fn add_tensor(&self, other: &Self) -> Self {
    // Compute the actual tensor addition
    let tensor = self.tensor().add_tensor(other.tensor());
    
    // Result requires grad if either input does
    let requires_grad = *self.requires_grad() || *other.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(AddGrad::new(
        self, 
        other,
        &result
      ))));
    }
    
    result
  }

  fn sub_tensor(&self, other: &Self) -> Self {
    // Compute the actual tensor subtraction
    let tensor = self.tensor().sub_tensor(other.tensor());
    
    // Result requires grad if either input does
    let requires_grad = *self.requires_grad() || *other.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(SubGrad::new(
        self, 
        other,
        &result
      ))));
    }
    
    result
  }

  fn mul_tensor(&self, other: &Self) -> Self {
    let tensor = self.tensor().mul_tensor(other.tensor());
    
    let requires_grad = *self.requires_grad() || *other.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(MulGrad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }

  fn div_tensor(&self, other: &Self) -> Self {
    let tensor = self.tensor().div_tensor(other.tensor());
    
    let requires_grad = *self.requires_grad() || *other.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(DivGrad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }


  fn pow_f32(&self, other: f32) -> Self {
    let tensor = self.tensor().pow_f32(other);
    
    // Scalar ops only depend on self's grad requirement
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(PowF32Grad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }

  // Scalar arithmetic ops — same pattern: compute on storage, then attach
  // the corresponding *F32Grad node when gradients are required.

  fn add_f32(&self, other: f32) -> Self {
    let tensor = self.tensor().add_f32(other);
    
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(AddF32Grad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }
  fn sub_f32(&self, other: f32) -> Self {
    let tensor = self.tensor().sub_f32(other);
    
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(SubF32Grad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }

  fn mul_f32(&self, other: f32) -> Self {
    let tensor = self.tensor().mul_f32(other);
    
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(MulF32Grad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }

  fn div_f32(&self, other: f32) -> Self {
    let tensor = self.tensor().div_f32(other);
    
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(DivF32Grad::new(
        self,
        other,
        &result
      ))));
    }
    
    result
  }

  fn abs(&self) -> Self {
    let tensor = self.tensor().abs();
    
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(AbsGrad::new(
        self,
        &result
      ))));
    }
    
    result
  }
  

  // Assignment operations — mutate storage in place; no grad_fn is installed
  // (see NOTE(review) on the impl header).
  fn add_tensor_assign(&mut self, other: &Self) {
    self.tensor_mut().add_tensor_assign(other.tensor());
  }

  fn sub_tensor_assign(&mut self, other: &Self) {
    self.tensor_mut().sub_tensor_assign(other.tensor());
  }

  fn mul_tensor_assign(&mut self, other: &Self) {
    self.tensor_mut().mul_tensor_assign(other.tensor());
  }

  fn div_tensor_assign(&mut self, other: &Self) {
    self.tensor_mut().div_tensor_assign(other.tensor());
  }

  fn add_f32_assign(&mut self, other: f32) {
    self.tensor_mut().add_f32_assign(other);
  }

  fn sub_f32_assign(&mut self, other: f32) {
    self.tensor_mut().sub_f32_assign(other);
  }

  fn mul_f32_assign(&mut self, other: f32) {
    self.tensor_mut().mul_f32_assign(other);
  }

  fn div_f32_assign(&mut self, other: f32) {
    self.tensor_mut().div_f32_assign(other);
  }

  fn pow_f32_assign(&mut self, other: f32) {
    self.tensor_mut().pow_f32_assign(other);
  }

  fn abs_assign(&mut self) {
    self.tensor_mut().abs_assign();
  }

  // Comparisons and sign are non-differentiable, so the result never
  // requires grad regardless of the inputs.
  fn greater_than(&self, other: &Self, make_binary: bool) -> Self {
    let tensor = self.tensor().greater_than(other.tensor(), make_binary);
    Tensor::new(tensor, self.device(), false)
  }

  fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self {
    let tensor = self.tensor().greater_than_f32(other, make_binary);
    Tensor::new(tensor, self.device(), false)
  }

  fn less_than(&self, other: &Self, make_binary: bool) -> Self {
    let tensor = self.tensor().less_than(other.tensor(), make_binary);
    Tensor::new(tensor, self.device(), false)
  }

  fn less_than_f32(&self, other: f32, make_binary: bool) -> Self {
    let tensor = self.tensor().less_than_f32(other, make_binary);
    Tensor::new(tensor, self.device(), false)
  }

  fn sign(&self) -> Self {
    let tensor = self.tensor().sign();
    Tensor::new(tensor, self.device(), false)
  }


}
391
392
/// Generates the `std::ops` binary-operator impls (`+`, `-`, `*`, `/` and
/// their `*Assign` forms) for a type, forwarding each operator to the
/// corresponding `ArithmeticOps` method.
///
/// The non-assign operators are implemented on `&$type` taking `&$target`
/// (borrowing both operands, producing an owned `$target`); the assign
/// operators are implemented on `$type` itself with a `&$target` right-hand
/// side.
macro_rules! impl_binary_ops {
  ($type:ty, $target:ty) => {
    impl Add<&$target> for &$type {
      type Output = $target;
      fn add(self, rhs: &$target) -> Self::Output {
        self.add_tensor(rhs)
      }
    }

    impl AddAssign<&$target> for $type {
      fn add_assign(&mut self, rhs: &$target) {
        self.add_tensor_assign(rhs)
      }
    }
    
    impl Sub<&$target> for &$type {
      type Output = $target;
      fn sub(self, rhs: &$target) -> Self::Output {
        self.sub_tensor(rhs)
      }
    }

    impl SubAssign<&$target> for $type {
      fn sub_assign(&mut self, rhs: &$target) {
        self.sub_tensor_assign(rhs)
      }
    }

    impl Mul<&$target> for &$type {
      type Output = $target;
      fn mul(self, rhs: &$target) -> Self::Output {
        self.mul_tensor(rhs)
      }
    }

    impl MulAssign<&$target> for $type {
      fn mul_assign(&mut self, rhs: &$target) {
        self.mul_tensor_assign(rhs)
      }
    }

    impl Div<&$target> for &$type {
      type Output = $target;
      fn div(self, rhs: &$target) -> Self::Output {
        self.div_tensor(rhs)
      }
    }

    impl DivAssign<&$target> for $type {
      fn div_assign(&mut self, rhs: &$target) {
        self.div_tensor_assign(rhs)
      }
    }
  }
}
448
/// Generates the `std::ops` scalar-operator impls (`tensor + 1.0`, `t *= 2.0`,
/// etc.) for a type, forwarding each `f32` operator to the corresponding
/// `ArithmeticOps` `*_f32` method. Non-assign forms are on `&$type` and
/// return an owned `$type`; assign forms mutate `$type` in place.
macro_rules! impl_scalar_ops {
  ($type:ty) => {
    impl Add<f32> for &$type {
      type Output = $type;
      fn add(self, rhs: f32) -> Self::Output {
        self.add_f32(rhs)
      }
    }
    
    impl AddAssign<f32> for $type {
      fn add_assign(&mut self, rhs: f32) {
        self.add_f32_assign(rhs);
      }
    }

    impl Sub<f32> for &$type {
      type Output = $type;
      fn sub(self, rhs: f32) -> Self::Output {
        self.sub_f32(rhs)
      }
    }
    
    impl SubAssign<f32> for $type {
      fn sub_assign(&mut self, rhs: f32) {
        self.sub_f32_assign(rhs);
      }
    }

    impl Mul<f32> for &$type {
      type Output = $type;
      fn mul(self, rhs: f32) -> Self::Output {
        self.mul_f32(rhs)
      }
    }
    
    impl MulAssign<f32> for $type {
      fn mul_assign(&mut self, rhs: f32) {
        self.mul_f32_assign(rhs);
      }
    }

    impl Div<f32> for &$type {
      type Output = $type;
      fn div(self, rhs: f32) -> Self::Output {
        self.div_f32(rhs)
      }
    }
    
    impl DivAssign<f32> for $type {
      fn div_assign(&mut self, rhs: f32) {
        self.div_f32_assign(rhs);
      }
    }
      
    
  }
}
506
// Wire up the operator traits for both layers: Tensor gets autograd-aware
// operators, Storage gets raw-math operators.
impl_binary_ops!(Tensor, Tensor);
impl_binary_ops!(Storage, Storage);

impl_scalar_ops!(Tensor);
impl_scalar_ops!(Storage);