ferrite/autograd/scalar/
backward.rs1use super::{Graph, Scalar, Value}; use std::cell::{RefCell};
4use std::collections::HashSet;
5
6pub trait Backward {
7 fn add_backward(&self, lhs: usize, rhs: usize, grad: f32);
8 fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32);
9 fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32);
10 fn div_backward(&self, lhs: usize, rhs: usize, grad: f32);
11 fn exp_backward(&self, lhs: usize, grad: f32);
12 fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32);
13 fn backward(&self, value: Value);
14 fn backward_from(&self, idx: usize);
15
16}
17
18impl Backward for Graph {
19 fn add_backward(&self, lhs: usize, rhs: usize, grad: f32) {
21 let scalars = self.scalars.borrow();
22 let lhs_scalar = &scalars[lhs];
23 let rhs_scalar = &scalars[rhs];
24 lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
25 rhs_scalar.grad.set(rhs_scalar.grad.get() + grad);
26 }
27
28 fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32) {
30 let scalars = self.scalars.borrow();
31 let lhs_scalar = &scalars[lhs];
32 let rhs_scalar = &scalars[rhs];
33 lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
34 rhs_scalar.grad.set(rhs_scalar.grad.get() - grad);
35 }
36
37 fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32) {
39 let scalars = self.scalars.borrow();
40 let lhs_scalar = &scalars[lhs];
41 let rhs_scalar = &scalars[rhs];
42 lhs_scalar.grad.set(lhs_scalar.grad.get() + rhs_scalar.data * grad);
43 rhs_scalar.grad.set(rhs_scalar.grad.get() + lhs_scalar.data * grad);
44 }
45
46 fn div_backward(&self, lhs: usize, rhs: usize, grad: f32) {
48 let scalars = self.scalars.borrow();
49 let lhs_scalar = &scalars[lhs];
50 let rhs_scalar = &scalars[rhs];
51 lhs_scalar.grad.set(lhs_scalar.grad.get() + (1./rhs_scalar.data) * grad);
52 rhs_scalar.grad.set(rhs_scalar.grad.get() + (-lhs_scalar.data/(rhs_scalar.data*rhs_scalar.data)) * grad);
53 }
54
55
56 fn exp_backward(&self, lhs: usize, grad: f32) {
58 let scalars = self.scalars.borrow();
59 let lhs_scalar = &scalars[lhs];
60 let exp_x = f32::exp(lhs_scalar.data);
61 lhs_scalar.grad.set(lhs_scalar.grad.get() + grad * exp_x);
62 }
63
64 fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32) {
66 let scalars = self.scalars.borrow();
67 let lhs_scalar = &scalars[lhs];
68 let rhs_scalar = &scalars[rhs];
69 lhs_scalar.grad.set(lhs_scalar.grad.get() + (rhs_scalar.data * (f32::powf(lhs_scalar.data, rhs_scalar.data - 1.))) * grad);
70 rhs_scalar.grad.set(rhs_scalar.grad.get() + (f32::powf(lhs_scalar.data, rhs_scalar.data) * f32::ln(lhs_scalar.data)) * grad);
71 }
72
73
74 fn backward(&self, value: Value) {
75 let id = value.idx;
76 self.scalars.borrow()[id].grad.set(1.0);
77
78 let mut topo = vec![];
80 let mut visited = HashSet::new();
81
82 fn build_topo(v: usize, scalars: &RefCell<Vec<Scalar>>, visited: &mut HashSet<usize>, topo: &mut Vec<usize>) {
83 if !visited.contains(&v) {
84 visited.insert(v);
85
86 let children = {
87 let scalars = scalars.borrow();
88 scalars[v].prev.clone()
89 };
90
91 for &child in &children {
92 build_topo(child, scalars, visited, topo);
93 }
94
95 topo.push(v);
96 }
97 }
98
99 build_topo(id, &self.scalars, &mut visited, &mut topo);
100
101 for &idx in topo.iter().rev(){
102 self.backward_from(idx);
103 }
104
105 }
107
108 fn backward_from(&self, idx: usize) {
109 let scalars = self.scalars.borrow();
110 let scalar = &scalars[idx];
111
112 if !scalar.requires_grad {
113 return;
114 }
115
116 let grad = scalar.grad.get();
117
118 match scalar.op.as_str() {
119 "+" => {
120 let lhs = scalar.prev[0];
121 let rhs = scalar.prev[1];
122 self.add_backward(lhs, rhs, grad);
123 },
124 "-" => {
125 let lhs = scalar.prev[0];
126 let rhs = scalar.prev[1];
127 self.sub_backward(lhs, rhs, grad);
128 },
129 "*" => {
130 let lhs = scalar.prev[0];
131 let rhs = scalar.prev[1];
132 self.mul_backward(lhs, rhs, grad);
133 },
134 "/" => {
135 let lhs = scalar.prev[0];
136 let rhs = scalar.prev[1];
137 self.div_backward(lhs, rhs, grad);
138 },
139 "exp" => {
140 let lhs = scalar.prev[0];
141 self.exp_backward(lhs, grad);
142 },
143 "pow" => {
144 let lhs = scalar.prev[0];
145 let rhs = scalar.prev[1];
146 self.pow_backward(lhs, rhs, grad);
147 },
148
149 _ => {}
150 }
151 }
152
153}
154