ferrite/autograd/grad_fn/
reduction.rs1use crate::{reduce_grad, tensor::*};
2use super::super::grad::*;
3
4
/// Backward node for a full-tensor sum reduction.
///
/// Holds clones of the forward pass's input and its (scalar) output so the
/// gradient can be routed back during `backward()`. The output is treated as
/// a single element (read via `get(&[0])` in the backward pass).
#[derive(Debug)]
pub struct SumGrad {
    // Tensor that was summed; receives the broadcast upstream gradient.
    input: Tensor,
    // Scalar result of the sum; its grad is the upstream gradient.
    output: Tensor,
}
10
11impl SumGrad {
12 pub fn new(input: &Tensor, output: &Tensor) -> Self {
13 SumGrad {
14 input: input.clone(),
15 output: output.clone(),
16 }
17 }
18}
19
20impl GradientFunction for SumGrad {
21 fn backward(&self) {
22 let device = self.output.device();
23 if let Some(input_grad) = &self.input.grad() {
24 if let Some(out_grad) = &self.output.grad() {
25 let input_shape = self.input.tensor().shape();
27 let ones = Storage::ones(input_shape.clone(), Some(device), None);
28 let expanded_grad = &ones * out_grad.borrow().get(&[0]);
29 input_grad.borrow_mut().add_tensor_assign(&expanded_grad);
30 }
31 }
32 }
33
34 fn prev(&self) -> Vec<&Tensor> {
35 vec![&self.input]
36 }
37}
38
/// Backward node for a full-tensor mean reduction.
///
/// Holds clones of the forward pass's input and its (scalar) output; the
/// backward pass divides the upstream gradient by the element count and
/// broadcasts it to the input's shape.
#[derive(Debug)]
pub struct MeanGrad {
    // Tensor that was averaged; receives the scaled upstream gradient.
    input: Tensor,
    // Scalar result of the mean; its grad is the upstream gradient.
    output: Tensor,
}
44
45impl MeanGrad {
46 pub fn new(input: &Tensor, output: &Tensor) -> Self {
47 MeanGrad {
48 input: input.clone(),
49 output: output.clone(),
50 }
51 }
52}
53
54impl GradientFunction for MeanGrad {
55 fn backward(&self) {
56 let device = self.output.device();
57
58 if let Some(input_grad) = &self.input.grad() {
59 if let Some(out_grad) = &self.output.grad() {
60 let input_shape = self.input.tensor().shape();
62 let n_elements = input_shape.iter().product::<usize>() as f32;
63 let ones = Storage::ones(input_shape.clone(), Some(device), None);
64 let expanded_grad = &ones * (out_grad.borrow().get(&[0]) / n_elements);
65 input_grad.borrow_mut().add_tensor_assign(&expanded_grad);
66 }
67 }
68 }
69
70 fn prev(&self) -> Vec<&Tensor> {
71 vec![&self.input]
72 }
73}
74
/// Backward node for a full-tensor product reduction.
///
/// Holds clones of the forward pass's input and its (scalar) output; the
/// backward pass computes, per element, the product of all *other* elements
/// scaled by the upstream gradient.
#[derive(Debug)]
pub struct ProductGrad {
    // Tensor that was multiplied together; receives the per-element gradient.
    input: Tensor,
    // Scalar result of the product; its grad is the upstream gradient.
    output: Tensor,
}
80
81impl ProductGrad {
82 pub fn new(input: &Tensor, output: &Tensor) -> Self {
83 ProductGrad {
84 input: input.clone(),
85 output: output.clone(),
86 }
87 }
88}
89
90impl GradientFunction for ProductGrad {
91 fn backward(&self) {
92 let device = self.output.device();
93
94 if let Some(input_grad) = &self.input.grad() {
95 if let Some(out_grad) = &self.output.grad() {
96 let input_data = self.input.tensor();
98 let mut grad = Storage::zeros(input_data.shape().clone(), Some(device), None);
99 let total_product = self.output.tensor().get(&[0]);
100
101 for i in 0..input_data.data().read().unwrap().len() {
103 let element = input_data.data().read().unwrap()[i];
104 if element != 0.0 {
105 grad.data_mut()[i] = total_product / element;
106 }
107 }
108
109 grad = &grad * out_grad.borrow().get(&[0]);
111 input_grad.borrow_mut().add_tensor_assign(&grad);
112 }
113 }
114 }
115
116 fn prev(&self) -> Vec<&Tensor> {
117 vec![&self.input]
118 }
119}
120