/// Handle to a node recorded on a [`Tape`]; the wrapped value is that node's index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Var(pub usize);
/// One recorded operation on a tape.
///
/// Every `usize` operand except `Input`'s is the index of an earlier node on the same tape.
#[derive(Debug, Clone, PartialEq)]
pub enum TapeNode {
    /// A constant leaf carrying its value.
    Const(f64),
    /// An input leaf; the payload is the gradient slot it contributes to.
    Input(usize),
    /// Sum of two nodes.
    Add(usize, usize),
    /// Difference `lhs - rhs`.
    Sub(usize, usize),
    /// Product of two nodes.
    Mul(usize, usize),
    /// Quotient `numerator / denominator`.
    Div(usize, usize),
    /// Sine of a node.
    Sin(usize),
    /// Cosine of a node.
    Cos(usize),
    /// Natural exponential of a node.
    Exp(usize),
    /// Natural logarithm of a node.
    Ln(usize),
    /// Square root of a node.
    Sqrt(usize),
    /// Reciprocal `1 / x` of a node.
    Recip(usize),
}
/// A Wengert-style list of recorded operations plus their forward values.
///
/// `nodes[i]` and `values[i]` always describe the same node; both vectors grow together.
#[derive(Default, Debug, Clone)]
pub struct Tape {
    /// Recorded operations, in evaluation order.
    nodes: Vec<TapeNode>,
    /// Forward value of each recorded operation, parallel to `nodes`.
    values: Vec<f64>,
}
/// Recording and reverse-mode differentiation.
///
/// Every method that takes a [`Var`] indexes this tape with it and panics if the index is
/// out of range.
impl Tape {
    /// Creates an empty tape.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a constant leaf holding `value`.
    ///
    /// Constants have no gradient slot, so any adjoint reaching them is discarded.
    pub fn constant(&mut self, value: f64) -> Var {
        self.push(TapeNode::Const(value), value)
    }

    /// Records an input leaf with forward value `value`, bound to gradient slot `slot`.
    ///
    /// Several input nodes may share a slot; [`Tape::grad`] sums their adjoints.
    pub fn input(&mut self, slot: usize, value: f64) -> Var {
        self.push(TapeNode::Input(slot), value)
    }

    /// Records `a + b`.
    pub fn add(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeNode::Add(a.0, b.0), self.values[a.0] + self.values[b.0])
    }

    /// Records `a - b`.
    pub fn sub(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeNode::Sub(a.0, b.0), self.values[a.0] - self.values[b.0])
    }

    /// Records `a * b`.
    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeNode::Mul(a.0, b.0), self.values[a.0] * self.values[b.0])
    }

    /// Records `a / b` (IEEE semantics: division by zero yields ±inf or NaN, no panic).
    pub fn div(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeNode::Div(a.0, b.0), self.values[a.0] / self.values[b.0])
    }

    /// Records `sin(arg)`.
    pub fn sin(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Sin(arg.0), self.values[arg.0].sin())
    }

    /// Records `cos(arg)`.
    pub fn cos(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Cos(arg.0), self.values[arg.0].cos())
    }

    /// Records `exp(arg)`.
    pub fn exp(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Exp(arg.0), self.values[arg.0].exp())
    }

    /// Records `ln(arg)` (natural log; NaN for negative inputs, -inf at zero).
    pub fn ln(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Ln(arg.0), self.values[arg.0].ln())
    }

    /// Records `sqrt(arg)` (NaN for negative inputs).
    pub fn sqrt(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Sqrt(arg.0), self.values[arg.0].sqrt())
    }

    /// Records `1 / arg`.
    pub fn recip(&mut self, arg: Var) -> Var {
        self.push(TapeNode::Recip(arg.0), self.values[arg.0].recip())
    }

    /// Returns the forward value recorded for `var`.
    ///
    /// # Panics
    ///
    /// Panics if `var`'s index is out of range for this tape.
    pub fn value(&self, var: Var) -> f64 {
        self.values[var.0]
    }

    /// Runs reverse-mode differentiation from `out` and returns `d(out)/d(slot)`.
    ///
    /// The result has length `n_inputs`. Entry `i` is the sum of the adjoints of every
    /// `Input` node recorded with slot `i`. Input nodes whose slot is `>= n_inputs` are
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if `out`'s index is out of range for this tape.
    pub fn grad(&self, out: Var, n_inputs: usize) -> Vec<f64> {
        let mut adjoints = vec![0.0; self.nodes.len()];
        let mut input_grad = vec![0.0; n_inputs];
        adjoints[out.0] = 1.0;
        // Operands always precede their users, so nodes recorded after `out` can never
        // receive adjoint; the backward sweep starts at `out` rather than the tape's end.
        for (index, node) in self.nodes[..=out.0].iter().enumerate().rev() {
            let seed = adjoints[index];
            // No adjoint flows through this node. Skipping it also keeps a zero adjoint
            // from turning into NaN through local derivatives such as 0 / 0.
            if seed == 0.0 {
                continue;
            }
            match *node {
                TapeNode::Const(_) => {}
                TapeNode::Input(slot) => {
                    if let Some(grad) = input_grad.get_mut(slot) {
                        *grad += seed;
                    }
                }
                TapeNode::Add(a, b) => {
                    adjoints[a] += seed;
                    adjoints[b] += seed;
                }
                TapeNode::Sub(a, b) => {
                    adjoints[a] += seed;
                    adjoints[b] -= seed;
                }
                TapeNode::Mul(a, b) => {
                    adjoints[a] += seed * self.values[b];
                    adjoints[b] += seed * self.values[a];
                }
                TapeNode::Div(a, b) => {
                    // d(a/b)/db = -(a/b)/b. Reusing the recorded quotient avoids forming
                    // b*b, which overflows to inf (and zeroes the gradient) for large |b|.
                    let quotient = self.values[index];
                    adjoints[a] += seed / self.values[b];
                    adjoints[b] -= seed * quotient / self.values[b];
                }
                TapeNode::Sin(arg) => adjoints[arg] += seed * self.values[arg].cos(),
                TapeNode::Cos(arg) => adjoints[arg] -= seed * self.values[arg].sin(),
                // d(exp x)/dx = exp x, which is this node's own recorded value.
                TapeNode::Exp(arg) => adjoints[arg] += seed * self.values[index],
                TapeNode::Ln(arg) => adjoints[arg] += seed / self.values[arg],
                // d(sqrt x)/dx = 1 / (2 sqrt x); sqrt x is this node's recorded value.
                TapeNode::Sqrt(arg) => adjoints[arg] += seed / (2.0 * self.values[index]),
                TapeNode::Recip(arg) => {
                    // d(1/x)/dx = -(1/x)^2. Squaring the recorded reciprocal avoids x*x
                    // overflowing for large |x|.
                    let reciprocal = self.values[index];
                    adjoints[arg] -= seed * reciprocal * reciprocal;
                }
            }
        }
        input_grad
    }

    /// Appends `node` and its forward `value` at the same index and returns a handle to it.
    fn push(&mut self, node: TapeNode, value: f64) -> Var {
        let index = self.nodes.len();
        self.nodes.push(node);
        self.values.push(value);
        Var(index)
    }
}