// ferrite/autograd/grad_fn/arithmetic.rs

1use crate::{reduce_grad, tensor::*};
2use super::super::grad::*;
3
4
5#[derive(Debug)]
6pub struct AddGrad {
7  lhs: Tensor,
8  rhs: Tensor,
9  output: Tensor,
10}
11
12impl AddGrad {
13  pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Self {
14    AddGrad {
15      lhs: lhs.clone(),
16      rhs: rhs.clone(),
17      output: output.clone(),
18    }
19  }
20}
21
22impl GradientFunction for AddGrad {
23  fn backward(&self) {
24    // Get output gradient
25    let out_grad = self.output.grad().unwrap();
26    let out_grad = out_grad.borrow();
27
28    // Propagate to lhs
29    if let Some(lhs_grad) = &self.lhs.grad() {
30      let reduced_grad = reduce_grad!(out_grad, self.lhs.tensor().shape());
31    
32      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
33    }
34    
35    // Propagate to rhs
36    if let Some(rhs_grad) = &self.rhs.grad() {
37      let reduced_grad = reduce_grad!(out_grad, self.rhs.tensor().shape());
38      
39      rhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
40    }
41  }
42
43  fn prev(&self) -> Vec<&Tensor> {
44    vec![&self.lhs, &self.rhs]
45  }
46}
47
48
49#[derive(Debug)]
50pub struct SubGrad {
51  lhs: Tensor,
52  rhs: Tensor,
53  output: Tensor,
54}
55
56impl SubGrad {
57  pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Self {
58    SubGrad {
59      lhs: lhs.clone(),
60      rhs: rhs.clone(),
61      output: output.clone(),
62    }
63  }
64}
65
66impl GradientFunction for SubGrad {
67  fn backward(&self) {
68    // Get output gradient
69    let out_grad = self.output.grad().unwrap();
70    let out_grad = out_grad.borrow();
71
72    // Propagate to lhs
73    if let Some(lhs_grad) = &self.lhs.grad() {
74      let reduced_grad = reduce_grad!(out_grad, self.lhs.tensor().shape());
75      
76      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
77    }
78    
79    // Propagate to rhs
80    if let Some(rhs_grad) = &self.rhs.grad() {
81      let grad_for_rhs = &*out_grad * -1.;
82      let reduced_grad = reduce_grad!(grad_for_rhs, self.rhs.tensor().shape());
83      
84      rhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
85    }
86  }
87
88  fn prev(&self) -> Vec<&Tensor> {
89    vec![&self.lhs, &self.rhs]
90  }
91}
92
93#[derive(Debug)]
94pub struct MulGrad {
95  lhs: Tensor,
96  rhs: Tensor,
97  output: Tensor,
98}
99
100impl MulGrad {
101  pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Self {
102    MulGrad {
103      lhs: lhs.clone(),
104      rhs: rhs.clone(),
105      output: output.clone(),
106    }
107  }
108}
109
110impl GradientFunction for MulGrad {
111  fn backward(&self) {
112    let out_grad = self.output.grad().unwrap();
113    let out_grad = out_grad.borrow();
114
115    // Propagate to lhs
116    if let Some(lhs_grad) = &self.lhs.grad() {
117      let grad_for_lhs = &*out_grad * self.rhs.tensor();
118      
119      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
120      
121      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
122    }
123    
124    // Propagate to rhs
125    if let Some(rhs_grad) = &self.rhs.grad() {
126      let grad_for_rhs = &*out_grad * self.lhs.tensor();
127      
128      let reduced_grad = reduce_grad!(grad_for_rhs, self.rhs.tensor().shape());
129
130      rhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
131    }
132  }
133
134  fn prev(&self) -> Vec<&Tensor> {
135    vec![&self.lhs, &self.rhs]
136  }
137}
138
139
140#[derive(Debug)]
141pub struct DivGrad {
142  lhs: Tensor,
143  rhs: Tensor,
144  output: Tensor,
145}
146
147impl DivGrad {
148  pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Self {
149    DivGrad {
150      lhs: lhs.clone(),
151      rhs: rhs.clone(),
152      output: output.clone(),
153    }
154  }
155}
156
157impl GradientFunction for DivGrad {
158  fn backward(&self) {
159    let out_grad = self.output.grad().unwrap();
160    let out_grad = out_grad.borrow();
161
162    // Propagate to lhs
163    if let Some(lhs_grad) = &self.lhs.grad() {
164      let grad_for_lhs = &*out_grad / self.rhs.tensor();
165      
166      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
167      
168      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
169    }
170    
171    // Propagate to rhs
172    if let Some(rhs_grad) = &self.rhs.grad() {
173      // Form grad for rhs
174      let grad_for_rhs = &(&*out_grad * self.lhs.tensor()).mul_f32(-1.) / &(self.rhs.tensor().pow_f32(2.));
175      
176      let reduced_grad = reduce_grad!(grad_for_rhs, self.rhs.tensor().shape());
177
178      rhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
179    }
180  }
181
182  fn prev(&self) -> Vec<&Tensor> {
183    vec![&self.lhs, &self.rhs]
184  }
185}
186
187#[derive(Debug)]
188pub struct PowF32Grad {
189  lhs: Tensor,
190  rhs: f32,
191  output: Tensor,
192}
193
194impl PowF32Grad {
195  pub fn new(lhs: &Tensor, rhs: f32, output: &Tensor) -> Self {
196    PowF32Grad {
197      lhs: lhs.clone(),
198      rhs: rhs,
199      output: output.clone(),
200    }
201  }
202}
203
204impl GradientFunction for PowF32Grad {
205  fn backward(&self) {
206    let out_grad = self.output.grad().unwrap();
207    let out_grad = out_grad.borrow();
208
209    // Propagate to lhs
210    if let Some(lhs_grad) = &self.lhs.grad() {
211      let grad_for_lhs = &(&*out_grad * self.rhs) * &self.lhs.tensor().pow_f32(self.rhs-1.);
212      
213      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
214      
215      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
216
217    }
218  }
219
220  fn prev(&self) -> Vec<&Tensor> {
221    vec![&self.lhs]
222  }
223}
224
225
226#[derive(Debug)]
227pub struct AddF32Grad {
228  lhs: Tensor,
229  rhs: f32,
230  output: Tensor,
231}
232
233impl AddF32Grad {
234  pub fn new(lhs: &Tensor, rhs: f32, output: &Tensor) -> Self {
235    AddF32Grad {
236      lhs: lhs.clone(),
237      rhs: rhs,
238      output: output.clone(),
239    }
240  }
241}
242
243impl GradientFunction for AddF32Grad {
244  fn backward(&self) {
245    // Get output gradient
246    let out_grad = self.output.grad().unwrap();
247    let out_grad = out_grad.borrow();
248
249    // Propagate to lhs
250    if let Some(lhs_grad) = &self.lhs.grad() {
251      let reduced_grad = reduce_grad!(out_grad, self.lhs.tensor().shape());
252    
253      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
254    }
255  }
256
257  fn prev(&self) -> Vec<&Tensor> {
258    vec![&self.lhs]
259  }
260}
261
262
263#[derive(Debug)]
264pub struct SubF32Grad {
265  lhs: Tensor,
266  rhs: f32,
267  output: Tensor,
268}
269
270impl SubF32Grad {
271  pub fn new(lhs: &Tensor, rhs: f32, output: &Tensor) -> Self {
272    SubF32Grad {
273      lhs: lhs.clone(),
274      rhs: rhs,
275      output: output.clone(),
276    }
277  }
278}
279
280impl GradientFunction for SubF32Grad {
281  fn backward(&self) {
282    // Get output gradient
283    let out_grad = self.output.grad().unwrap();
284    let out_grad = out_grad.borrow();
285
286    // Propagate to lhs
287    if let Some(lhs_grad) = &self.lhs.grad() {
288      let reduced_grad = reduce_grad!(out_grad, self.lhs.tensor().shape());
289      
290      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
291    }
292  }
293
294  fn prev(&self) -> Vec<&Tensor> {
295    vec![&self.lhs]
296  }
297}
298
299#[derive(Debug)]
300pub struct MulF32Grad {
301  lhs: Tensor,
302  rhs: f32,
303  output: Tensor,
304}
305
306impl MulF32Grad {
307  pub fn new(lhs: &Tensor, rhs: f32, output: &Tensor) -> Self {
308    MulF32Grad {
309      lhs: lhs.clone(),
310      rhs: rhs,
311      output: output.clone(),
312    }
313  }
314}
315
316impl GradientFunction for MulF32Grad {
317  fn backward(&self) {
318    let out_grad = self.output.grad().unwrap();
319    let out_grad = out_grad.borrow();
320
321    // Propagate to lhs
322    if let Some(lhs_grad) = &self.lhs.grad() {
323      let grad_for_lhs = &*out_grad * self.rhs;
324      
325      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
326      
327      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
328    }
329  
330  }
331
332  fn prev(&self) -> Vec<&Tensor> {
333    vec![&self.lhs]
334  }
335}
336
337
338#[derive(Debug)]
339pub struct DivF32Grad {
340  lhs: Tensor,
341  rhs: f32,
342  output: Tensor,
343}
344
345impl DivF32Grad {
346  pub fn new(lhs: &Tensor, rhs: f32, output: &Tensor) -> Self {
347    DivF32Grad {
348      lhs: lhs.clone(),
349      rhs: rhs,
350      output: output.clone(),
351    }
352  }
353}
354
355impl GradientFunction for DivF32Grad {
356  fn backward(&self) {
357    let out_grad = self.output.grad().unwrap();
358    let out_grad = out_grad.borrow();
359
360    // Propagate to lhs
361    if let Some(lhs_grad) = &self.lhs.grad() {
362      let grad_for_lhs = &*out_grad / self.rhs;
363      
364      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
365      
366      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
367    }
368  }
369
370  fn prev(&self) -> Vec<&Tensor> {
371    vec![&self.lhs]
372  }
373}
374
375
376#[derive(Debug)]
377pub struct AbsGrad {
378  lhs: Tensor,
379  output: Tensor,
380}
381
382impl AbsGrad {
383  pub fn new(lhs: &Tensor, output: &Tensor) -> Self {
384    AbsGrad {
385      lhs: lhs.clone(),
386      output: output.clone(),
387    }
388  }
389}
390
391impl GradientFunction for AbsGrad {
392  fn backward(&self) {
393    let out_grad = self.output.grad().unwrap();
394    let out_grad = out_grad.borrow();
395
396    // Propagate to lhs
397    if let Some(lhs_grad) = &self.lhs.grad() {
398      let grad_for_lhs = &*out_grad * &self.lhs.tensor().sign();
399      
400      let reduced_grad = reduce_grad!(grad_for_lhs, self.lhs.tensor().shape());
401      
402      lhs_grad.borrow_mut().add_tensor_assign(&reduced_grad);
403    }
404  }
405
406  fn prev(&self) -> Vec<&Tensor> {
407    vec![&self.lhs]
408  }
409}