// ferrite/autograd/grad_fn/reduction.rs

use crate::{reduce_grad, tensor::*};
use super::super::grad::*;
3
4
5#[derive(Debug)]
6pub struct SumGrad {
7  input: Tensor,
8  output: Tensor,
9}
10
11impl SumGrad {
12  pub fn new(input: &Tensor, output: &Tensor) -> Self {
13    SumGrad {
14      input: input.clone(),
15      output: output.clone(),
16    }
17  }
18}
19
20impl GradientFunction for SumGrad {
21  fn backward(&self) {
22    let device = self.output.device();
23    if let Some(input_grad) = &self.input.grad() {
24      if let Some(out_grad) = &self.output.grad() {
25        // For sum, we need to expand the gradient to match input shape
26        let input_shape = self.input.tensor().shape();
27        let ones = Storage::ones(input_shape.clone(), Some(device), None);
28        let expanded_grad = &ones * out_grad.borrow().get(&[0]);
29        input_grad.borrow_mut().add_tensor_assign(&expanded_grad);
30      }
31    }
32  }
33
34  fn prev(&self) -> Vec<&Tensor> {
35    vec![&self.input]
36  }
37}
38
39#[derive(Debug)]
40pub struct MeanGrad {
41  input: Tensor,
42  output: Tensor,
43}
44
45impl MeanGrad {
46  pub fn new(input: &Tensor, output: &Tensor) -> Self {
47    MeanGrad {
48      input: input.clone(),
49      output: output.clone(),
50    }
51  }
52}
53
54impl GradientFunction for MeanGrad {
55  fn backward(&self) {
56    let device = self.output.device();
57
58    if let Some(input_grad) = &self.input.grad() {
59      if let Some(out_grad) = &self.output.grad() {
60        // For mean, expand gradient and divide by number of elements
61        let input_shape = self.input.tensor().shape();
62        let n_elements = input_shape.iter().product::<usize>() as f32;
63        let ones = Storage::ones(input_shape.clone(), Some(device), None);
64        let expanded_grad = &ones * (out_grad.borrow().get(&[0]) / n_elements);
65        input_grad.borrow_mut().add_tensor_assign(&expanded_grad);
66      }
67    }
68  }
69
70  fn prev(&self) -> Vec<&Tensor> {
71    vec![&self.input]
72  }
73}
74
75#[derive(Debug)]
76pub struct ProductGrad {
77  input: Tensor,
78  output: Tensor,
79}
80
81impl ProductGrad {
82  pub fn new(input: &Tensor, output: &Tensor) -> Self {
83    ProductGrad {
84      input: input.clone(),
85      output: output.clone(),
86    }
87  }
88}
89
90impl GradientFunction for ProductGrad {
91  fn backward(&self) {
92    let device = self.output.device();
93
94    if let Some(input_grad) = &self.input.grad() {
95      if let Some(out_grad) = &self.output.grad() {
96        // For product, each element's gradient is the product of all other elements
97        let input_data = self.input.tensor();
98        let mut grad = Storage::zeros(input_data.shape().clone(), Some(device), None);
99        let total_product = self.output.tensor().get(&[0]);
100        
101        // For each element, divide total product by that element to get product of others
102        for i in 0..input_data.data().read().unwrap().len() {
103          let element = input_data.data().read().unwrap()[i];
104          if element != 0.0 {
105            grad.data_mut()[i] = total_product / element;
106          }
107        }
108        
109        // Multiply by output gradient
110        grad = &grad * out_grad.borrow().get(&[0]);
111        input_grad.borrow_mut().add_tensor_assign(&grad);
112      }
113    }
114  }
115
116  fn prev(&self) -> Vec<&Tensor> {
117    vec![&self.input]
118  }
119}
120