// ferrite/autograd/scalar/backward.rs

1use super::{Graph, Scalar, Value};  // Import Graph from parent module
2// scalar.rs
3use std::cell::{RefCell};
4use std::collections::HashSet;
5
6pub trait Backward {
7  fn add_backward(&self, lhs: usize, rhs: usize, grad: f32);
8  fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32);
9  fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32);
10  fn div_backward(&self, lhs: usize, rhs: usize, grad: f32);
11  fn exp_backward(&self, lhs: usize, grad: f32);
12  fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32);
13  fn backward(&self, value: Value);
14  fn backward_from(&self, idx: usize);
15
16}
17
18impl Backward for Graph {
19  // d(a + b)/da = 1, d(a + b)/db = 1
20  fn add_backward(&self, lhs: usize, rhs: usize, grad: f32) {
21    let scalars = self.scalars.borrow();
22    let lhs_scalar = &scalars[lhs];
23    let rhs_scalar = &scalars[rhs];
24    lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
25    rhs_scalar.grad.set(rhs_scalar.grad.get() + grad);
26  }
27
28  // d(a + b)/da = 1, d(a + b)/db = 1
29  fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32) {
30    let scalars = self.scalars.borrow();
31    let lhs_scalar = &scalars[lhs];
32    let rhs_scalar = &scalars[rhs];
33    lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
34    rhs_scalar.grad.set(rhs_scalar.grad.get() - grad);
35  }
36
37  // d(a * b)/da = b, d(a * b)/db = a
38  fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32) {
39    let scalars = self.scalars.borrow();
40    let lhs_scalar = &scalars[lhs];
41    let rhs_scalar = &scalars[rhs];
42    lhs_scalar.grad.set(lhs_scalar.grad.get() + rhs_scalar.data * grad);
43    rhs_scalar.grad.set(rhs_scalar.grad.get() + lhs_scalar.data * grad);
44  }
45
46  // d(a/b)/da = 1/b, d(a/b)/db = (-a/b**2)
47  fn div_backward(&self, lhs: usize, rhs: usize, grad: f32) {
48    let scalars = self.scalars.borrow();
49    let lhs_scalar = &scalars[lhs];
50    let rhs_scalar = &scalars[rhs];
51    lhs_scalar.grad.set(lhs_scalar.grad.get() + (1./rhs_scalar.data) * grad);
52    rhs_scalar.grad.set(rhs_scalar.grad.get() + (-lhs_scalar.data/(rhs_scalar.data*rhs_scalar.data)) * grad);
53  }
54
55
56  // d(e^x)/dx = e^x
57  fn exp_backward(&self, lhs: usize, grad: f32) {
58    let scalars = self.scalars.borrow();
59    let lhs_scalar = &scalars[lhs];
60    let exp_x = f32::exp(lhs_scalar.data);
61    lhs_scalar.grad.set(lhs_scalar.grad.get() + grad * exp_x);
62  }
63
64  // d(x^a)/dx = ax^(a-1), d(x^a)/da = x^a * lnx
65  fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32) {
66    let scalars = self.scalars.borrow();
67    let lhs_scalar = &scalars[lhs];
68    let rhs_scalar = &scalars[rhs];
69    lhs_scalar.grad.set(lhs_scalar.grad.get() + (rhs_scalar.data * (f32::powf(lhs_scalar.data, rhs_scalar.data - 1.))) * grad);
70    rhs_scalar.grad.set(rhs_scalar.grad.get() + (f32::powf(lhs_scalar.data, rhs_scalar.data) * f32::ln(lhs_scalar.data)) * grad);
71  }
72
73
74  fn backward(&self, value: Value) {
75    let id = value.idx;
76    self.scalars.borrow()[id].grad.set(1.0);
77
78    // Topo sort
79    let mut topo = vec![];
80    let mut visited = HashSet::new();
81
82    fn build_topo(v: usize, scalars: &RefCell<Vec<Scalar>>, visited: &mut HashSet<usize>, topo: &mut Vec<usize>) {
83      if !visited.contains(&v) {
84        visited.insert(v);
85
86        let children = {
87          let scalars = scalars.borrow();
88          scalars[v].prev.clone()
89        };
90
91        for &child in &children {
92          build_topo(child, scalars, visited, topo);
93        }
94
95        topo.push(v);
96      }
97    }
98
99    build_topo(id, &self.scalars, &mut visited, &mut topo);
100    
101    for &idx in topo.iter().rev(){
102      self.backward_from(idx);
103    }
104
105   // 
106  }
107
108  fn backward_from(&self, idx: usize) {
109    let scalars = self.scalars.borrow();
110    let scalar = &scalars[idx];
111
112    if !scalar.requires_grad {
113      return;
114    }
115
116    let grad = scalar.grad.get();
117
118    match scalar.op.as_str() {
119      "+" => {
120        let lhs = scalar.prev[0];
121        let rhs = scalar.prev[1];
122        self.add_backward(lhs, rhs, grad);
123      },
124      "-" => {
125        let lhs = scalar.prev[0];
126        let rhs = scalar.prev[1];
127        self.sub_backward(lhs, rhs, grad);
128      },
129      "*" => {
130        let lhs = scalar.prev[0];
131        let rhs = scalar.prev[1];
132        self.mul_backward(lhs, rhs, grad);
133      },
134      "/" => {
135        let lhs = scalar.prev[0];
136        let rhs = scalar.prev[1];
137        self.div_backward(lhs, rhs, grad);
138      },
139      "exp" => {
140        let lhs = scalar.prev[0];
141        self.exp_backward(lhs, grad);
142      },
143      "pow" => {
144        let lhs = scalar.prev[0];
145        let rhs = scalar.prev[1];
146        self.pow_backward(lhs, rhs, grad);
147      },
148      
149      _ => {}
150    }
151  }
152
153}
154