use super::{Graph, Scalar, Value}; use std::cell::{RefCell};
use std::collections::HashSet;
/// Reverse-mode differentiation hooks for a computation graph whose nodes are
/// addressed by `usize` index.
///
/// Each `*_backward` hook receives the node indices of one operation's operands
/// together with the gradient flowing into that operation's output, and is
/// expected to accumulate the operands' partial derivatives into their gradients.
pub trait Backward {
    /// Differentiates the whole graph reachable from `value`, seeding `value` itself.
    fn backward(&self, value: Value);
    /// Applies the local backward rule of the single node at `idx`.
    fn backward_from(&self, idx: usize);

    // Binary operations: `lhs` / `rhs` are the operand node indices,
    // `grad` is the gradient of the operation's output.
    fn add_backward(&self, lhs: usize, rhs: usize, grad: f32);
    fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32);
    fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32);
    fn div_backward(&self, lhs: usize, rhs: usize, grad: f32);
    fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32);

    // Unary operations: `lhs` is the single operand node index.
    fn exp_backward(&self, lhs: usize, grad: f32);
}
impl Backward for Graph {
    /// Backward rule for `lhs + rhs`: both operands receive `grad` unchanged.
    fn add_backward(&self, lhs: usize, rhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let lhs_scalar = &scalars[lhs];
        let rhs_scalar = &scalars[rhs];
        // Gradients are accumulated, not overwritten, so a node feeding several
        // operations sums all contributions. Each `get` runs after the previous
        // `set`, so `x + x` (lhs == rhs) correctly receives 2 * grad.
        lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
        rhs_scalar.grad.set(rhs_scalar.grad.get() + grad);
    }

    /// Backward rule for `lhs - rhs`: d/dlhs = 1, d/drhs = -1.
    fn sub_backward(&self, lhs: usize, rhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let lhs_scalar = &scalars[lhs];
        let rhs_scalar = &scalars[rhs];
        lhs_scalar.grad.set(lhs_scalar.grad.get() + grad);
        rhs_scalar.grad.set(rhs_scalar.grad.get() - grad);
    }

    /// Backward rule for `lhs * rhs`: d/dlhs = rhs, d/drhs = lhs.
    fn mul_backward(&self, lhs: usize, rhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let lhs_scalar = &scalars[lhs];
        let rhs_scalar = &scalars[rhs];
        lhs_scalar.grad.set(lhs_scalar.grad.get() + rhs_scalar.data * grad);
        rhs_scalar.grad.set(rhs_scalar.grad.get() + lhs_scalar.data * grad);
    }

    /// Backward rule for `lhs / rhs`: d/dlhs = 1 / rhs, d/drhs = -lhs / rhs^2.
    fn div_backward(&self, lhs: usize, rhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let lhs_scalar = &scalars[lhs];
        let rhs_scalar = &scalars[rhs];
        lhs_scalar.grad.set(lhs_scalar.grad.get() + (1. / rhs_scalar.data) * grad);
        rhs_scalar.grad.set(
            rhs_scalar.grad.get()
                + (-lhs_scalar.data / (rhs_scalar.data * rhs_scalar.data)) * grad,
        );
    }

    /// Backward rule for `exp(lhs)`: d/dlhs = exp(lhs).
    fn exp_backward(&self, lhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let lhs_scalar = &scalars[lhs];
        // Recomputed from the operand's data, because only the operand index is passed in.
        let exp_x = f32::exp(lhs_scalar.data);
        lhs_scalar.grad.set(lhs_scalar.grad.get() + grad * exp_x);
    }

    /// Backward rule for `lhs ^ rhs` (base `x`, exponent `y`).
    ///
    /// * d/dx = y * x^(y-1), taken as 0 when y == 0. Otherwise x == 0 would give
    ///   `0 * inf = NaN`.
    /// * d/dy = x^y * ln(x), taken as 0 when x == 0 and y >= 0. Otherwise
    ///   `ln(0) = -inf` would give NaN or -inf.
    ///
    /// This masking matches PyTorch's `pow` autograd formulas. A negative base
    /// still yields NaN for d/dy, since ln is undefined there.
    fn pow_backward(&self, lhs: usize, rhs: usize, grad: f32) {
        let scalars = self.scalars.borrow();
        let base = &scalars[lhs];
        let exponent = &scalars[rhs];
        let (x, y) = (base.data, exponent.data);

        let d_base = if y == 0.0 {
            0.0
        } else {
            y * f32::powf(x, y - 1.)
        };
        let d_exponent = if x == 0.0 && y >= 0.0 {
            0.0
        } else {
            f32::powf(x, y) * f32::ln(x)
        };

        // Base first, then exponent, each reading the freshly stored gradient,
        // so `x.pow(x)` (lhs == rhs) accumulates both contributions.
        base.grad.set(base.grad.get() + d_base * grad);
        exponent.grad.set(exponent.grad.get() + d_exponent * grad);
    }

    /// Runs reverse-mode differentiation starting from `value`.
    ///
    /// Steps:
    /// 1. Sets `value`'s gradient to 1.0.
    /// 2. Orders every node reachable through `prev` so that each node comes
    ///    after all of its inputs.
    /// 3. Walks that order in reverse, calling [`Backward::backward_from`] on
    ///    each node, so every node's gradient is complete before it is pushed
    ///    to its inputs.
    ///
    /// Only the root gradient is reset. All other gradients are added to, so
    /// clear them first if this graph was already differentiated.
    fn backward(&self, value: Value) {
        let id = value.idx;
        self.scalars.borrow()[id].grad.set(1.0);
        let mut topo = vec![];
        let mut visited = HashSet::new();

        // Appends every not-yet-visited node reachable from `root` to `topo` in
        // post-order (inputs before the nodes that use them). It uses an
        // explicit stack instead of recursion, so very deep graphs such as long
        // accumulation chains cannot overflow the call stack. Inputs are
        // explored in `prev` order, which gives the same ordering as a
        // recursive DFS.
        fn build_topo(
            root: usize,
            scalars: &RefCell<Vec<Scalar>>,
            visited: &mut HashSet<usize>,
            topo: &mut Vec<usize>,
        ) {
            let scalars = scalars.borrow();
            // Each frame is (node, index of the next input of that node to explore).
            let mut stack: Vec<(usize, usize)> = Vec::new();
            if visited.insert(root) {
                stack.push((root, 0));
            }
            while let Some(&(node, cursor)) = stack.last() {
                let inputs = &scalars[node].prev;
                if cursor < inputs.len() {
                    let child = inputs[cursor];
                    let top = stack.len() - 1;
                    stack[top].1 += 1;
                    if visited.insert(child) {
                        stack.push((child, 0));
                    }
                } else {
                    // All inputs of `node` are already in `topo`, so it can follow them.
                    stack.pop();
                    topo.push(node);
                }
            }
        }

        build_topo(id, &self.scalars, &mut visited, &mut topo);
        for &idx in topo.iter().rev() {
            self.backward_from(idx);
        }
    }

    /// Applies the local backward rule of the node at `idx`, using its current gradient.
    ///
    /// Nodes with `requires_grad == false` are skipped. Ops not matched below
    /// propagate nothing.
    fn backward_from(&self, idx: usize) {
        // Shared borrow only; the `*_backward` helpers take further shared
        // borrows, which `RefCell` allows.
        let scalars = self.scalars.borrow();
        let scalar = &scalars[idx];
        if !scalar.requires_grad {
            return;
        }
        let grad = scalar.grad.get();
        // Binary ops store their operands as prev[0] (lhs) and prev[1] (rhs);
        // `exp` uses only prev[0].
        match scalar.op.as_str() {
            "+" => {
                let lhs = scalar.prev[0];
                let rhs = scalar.prev[1];
                self.add_backward(lhs, rhs, grad);
            }
            "-" => {
                let lhs = scalar.prev[0];
                let rhs = scalar.prev[1];
                self.sub_backward(lhs, rhs, grad);
            }
            "*" => {
                let lhs = scalar.prev[0];
                let rhs = scalar.prev[1];
                self.mul_backward(lhs, rhs, grad);
            }
            "/" => {
                let lhs = scalar.prev[0];
                let rhs = scalar.prev[1];
                self.div_backward(lhs, rhs, grad);
            }
            "exp" => {
                let lhs = scalar.prev[0];
                self.exp_backward(lhs, grad);
            }
            "pow" => {
                let lhs = scalar.prev[0];
                let rhs = scalar.prev[1];
                self.pow_backward(lhs, rhs, grad);
            }
            _ => {}
        }
    }
}