use super::Tape;
use alloc::sync::Arc;
use alloc::vec;
use core::ops::{Add, Div, Mul, Neg, Sub};
use crate::la::DVec;
#[derive(Clone, Debug)]
pub struct Var {
pub(crate) index: usize,
pub(crate) value: f64,
pub(crate) tape: Arc<Tape>,
}
impl Var {
#[inline]
pub fn value(&self) -> f64 {
self.value
}
pub fn backward(&self) -> DVec<f64> {
let ops = self.tape.ops.borrow();
let n = ops.len();
let mut adjoints = vec![0.0_f64; n];
adjoints[self.index] = 1.0;
for i in (0..n).rev() {
let adj = adjoints[i];
if adj == 0.0 {
continue;
}
let op = &ops[i];
if op.num_inputs >= 1 {
adjoints[op.inputs[0]] += adj * op.partials[0];
}
if op.num_inputs >= 2 {
adjoints[op.inputs[1]] += adj * op.partials[1];
}
}
DVec::from_vec(adjoints)
}
pub fn sin(&self) -> Var {
self.tape
.unary(self.index, self.value.sin(), self.value.cos())
}
pub fn cos(&self) -> Var {
self.tape
.unary(self.index, self.value.cos(), -self.value.sin())
}
pub fn exp(&self) -> Var {
let e = self.value.exp();
self.tape.unary(self.index, e, e)
}
pub fn ln(&self) -> Var {
self.tape
.unary(self.index, self.value.ln(), 1.0 / self.value)
}
pub fn sqrt(&self) -> Var {
let s = self.value.sqrt();
self.tape.unary(self.index, s, 0.5 / s)
}
pub fn abs(&self) -> Var {
self.tape
.unary(self.index, self.value.abs(), self.value.signum())
}
pub fn tanh(&self) -> Var {
let t = self.value.tanh();
self.tape.unary(self.index, t, 1.0 - t * t)
}
pub fn powf(&self, p: f64) -> Var {
let val = self.value.powf(p);
self.tape
.unary(self.index, val, p * self.value.powf(p - 1.0))
}
pub fn constant(tape: &Arc<Tape>, value: f64) -> Var {
let mut ops = tape.ops.borrow_mut();
let index = ops.len();
ops.push(super::tape::Op {
inputs: [0, 0],
num_inputs: 0,
partials: [0.0, 0.0],
});
Var {
index,
value,
tape: Arc::clone(tape),
}
}
}
impl Add for &Var {
type Output = Var;
fn add(self, rhs: &Var) -> Var {
self.tape
.binary(self.index, rhs.index, self.value + rhs.value, 1.0, 1.0)
}
}
impl Sub for &Var {
type Output = Var;
fn sub(self, rhs: &Var) -> Var {
self.tape
.binary(self.index, rhs.index, self.value - rhs.value, 1.0, -1.0)
}
}
impl Mul for &Var {
type Output = Var;
fn mul(self, rhs: &Var) -> Var {
self.tape.binary(
self.index,
rhs.index,
self.value * rhs.value,
rhs.value,
self.value,
)
}
}
impl Div for &Var {
type Output = Var;
fn div(self, rhs: &Var) -> Var {
let inv = 1.0 / rhs.value;
self.tape.binary(
self.index,
rhs.index,
self.value * inv,
inv,
-self.value * inv * inv,
)
}
}
impl Neg for &Var {
type Output = Var;
fn neg(self) -> Var {
self.tape.unary(self.index, -self.value, -1.0)
}
}
impl Add for Var {
type Output = Var;
fn add(self, rhs: Var) -> Var {
(&self) + (&rhs)
}
}
impl Sub for Var {
type Output = Var;
fn sub(self, rhs: Var) -> Var {
(&self) - (&rhs)
}
}
impl Mul for Var {
type Output = Var;
fn mul(self, rhs: Var) -> Var {
(&self) * (&rhs)
}
}
impl Div for Var {
type Output = Var;
fn div(self, rhs: Var) -> Var {
(&self) / (&rhs)
}
}
impl Neg for Var {
type Output = Var;
fn neg(self) -> Var {
-(&self)
}
}
impl Mul<f64> for &Var {
type Output = Var;
fn mul(self, rhs: f64) -> Var {
self.tape.unary(self.index, self.value * rhs, rhs)
}
}
impl Mul<f64> for Var {
type Output = Var;
fn mul(self, rhs: f64) -> Var {
(&self) * rhs
}
}
impl Add<f64> for &Var {
type Output = Var;
fn add(self, rhs: f64) -> Var {
self.tape.unary(self.index, self.value + rhs, 1.0)
}
}
impl Add<f64> for Var {
type Output = Var;
fn add(self, rhs: f64) -> Var {
(&self) + rhs
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn gradient_of_sum() {
let tape = Tape::new();
let x = tape.var(3.0);
let y = tape.var(5.0);
let z = &x + &y;
let grads = z.backward();
assert_eq!(grads[x.index], 1.0);
assert_eq!(grads[y.index], 1.0);
}
#[test]
fn gradient_of_product() {
let tape = Tape::new();
let x = tape.var(3.0);
let y = tape.var(5.0);
let z = &x * &y;
let grads = z.backward();
assert_eq!(grads[x.index], 5.0); assert_eq!(grads[y.index], 3.0); }
#[test]
fn gradient_of_square() {
let tape = Tape::new();
let x = tape.var(3.0);
let z = &x * &x;
let grads = z.backward();
assert!((grads[x.index] - 6.0).abs() < 1e-10); }
#[test]
fn gradient_of_exp() {
let tape = Tape::new();
let x = tape.var(1.0);
let z = x.exp();
let grads = z.backward();
let e = 1.0_f64.exp();
assert!((grads[0] - e).abs() < 1e-10);
}
#[test]
fn gradient_of_sin() {
let tape = Tape::new();
let x = tape.var(0.0);
let z = x.sin();
let grads = z.backward();
assert!((grads[0] - 1.0).abs() < 1e-10); }
#[test]
fn gradient_chain_rule() {
let tape = Tape::new();
let x = tape.var(1.0);
let x2 = &x * &x;
let z = x2.sin();
let grads = z.backward();
let expected = 2.0 * 1.0_f64.cos();
assert!((grads[x.index] - expected).abs() < 1e-10);
}
#[test]
fn gradient_complex_expr() {
let tape = Tape::new();
let x = tape.var(2.0);
let y = tape.var(3.0);
let xy = &x * &y;
let sx = x.sin();
let z = xy + sx;
let grads = z.backward();
assert!((grads[x.index] - (3.0 + 2.0_f64.cos())).abs() < 1e-10);
assert!((grads[y.index] - 2.0).abs() < 1e-10);
}
}