use crate::tape::{Tape, TapeRef};
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::rc::Rc;
/// A scalar value recorded on a shared tape for reverse-mode differentiation.
///
/// Cloning a `Var` copies `index` and `value` and clones the `tape` handle;
/// the derived `Clone` itself does not call any tape-recording function.
#[derive(Clone, Debug)]
pub struct Var {
/// Index of this variable's node on the tape. It is passed to
/// `Tape::push_unary` / `Tape::push_binary` as the parent index whenever this
/// variable is used as an operand.
pub(crate) index: usize,
/// Numeric value of this variable.
pub value: f64,
/// Handle to the tape this variable was recorded on; results of operations
/// on this variable carry a handle to the same tape.
pub(crate) tape: TapeRef,
}
impl Var {
pub fn cst(&self, value: f64) -> Var {
let (index, value) = {
let mut t = self.tape.borrow_mut();
let idx = t.nodes.len();
t.nodes.push(crate::tape::Node {
value,
parent1: None,
parent2: None,
});
(idx, value)
};
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn sin(&self) -> Var {
let val = self.value.sin();
let deriv = self.value.cos(); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn cos(&self) -> Var {
let val = self.value.cos();
let deriv = -self.value.sin(); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn tan(&self) -> Var {
let val = self.value.tan();
let c = self.value.cos();
let deriv = 1.0 / (c * c); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn exp(&self) -> Var {
let val = self.value.exp();
let deriv = val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn ln(&self) -> Var {
let val = self.value.ln();
let deriv = 1.0 / self.value; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn sqrt(&self) -> Var {
let val = self.value.sqrt();
let deriv = 0.5 / val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn abs(&self) -> Var {
let val = self.value.abs();
let deriv = if self.value >= 0.0 { 1.0 } else { -1.0 };
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn tanh(&self) -> Var {
let val = self.value.tanh();
let deriv = 1.0 - val * val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn sinh(&self) -> Var {
let val = self.value.sinh();
let deriv = self.value.cosh();
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn cosh(&self) -> Var {
let val = self.value.cosh();
let deriv = self.value.sinh();
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn pow(&self, n: &Var) -> Var {
let val = self.value.powf(n.value);
let d_self = n.value * self.value.powf(n.value - 1.0);
let d_n = val * self.value.ln();
let (index, value) = Tape::push_binary(&self.tape, val, self.index, d_self, n.index, d_n);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn powf(&self, n: f64) -> Var {
let val = self.value.powf(n);
let deriv = n * self.value.powf(n - 1.0);
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn powi(&self, n: i32) -> Var {
self.powf(n as f64)
}
pub fn asin(&self) -> Var {
let val = self.value.asin();
let deriv = 1.0 / (1.0 - self.value * self.value).sqrt();
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn acos(&self) -> Var {
let val = self.value.acos();
let deriv = -1.0 / (1.0 - self.value * self.value).sqrt();
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
pub fn atan(&self) -> Var {
let val = self.value.atan();
let deriv = 1.0 / (1.0 + self.value * self.value);
let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
Var {
index,
value,
tape: Rc::clone(&self.tape),
}
}
}
/// `Var + Var`, consuming both operands.
///
/// Pushes a binary node on the left operand's tape with local partials of
/// `1` for each operand; the result takes over the left operand's tape handle.
impl Add for Var {
    type Output = Var;

    fn add(self, rhs: Var) -> Self::Output {
        let Var {
            index: lhs_index,
            value: lhs_value,
            tape,
        } = self;
        let (index, value) =
            Tape::push_binary(&tape, lhs_value + rhs.value, lhs_index, 1.0, rhs.index, 1.0);
        Var { index, value, tape }
    }
}
/// `&Var + &Var`, leaving both operands usable afterwards.
///
/// Pushes a binary node on the left operand's tape with local partials of
/// `1` for each operand; the result holds a cloned handle to that tape.
impl Add for &Var {
    type Output = Var;

    fn add(self, rhs: &Var) -> Self::Output {
        let sum = self.value + rhs.value;
        let node = Tape::push_binary(&self.tape, sum, self.index, 1.0, rhs.index, 1.0);
        Var {
            index: node.0,
            value: node.1,
            tape: Rc::clone(&self.tape),
        }
    }
}
/// `Var + f64`.
///
/// The scalar is not recorded on the tape; a unary node is pushed with a
/// local derivative of `1` with respect to the variable.
impl Add<f64> for Var {
    type Output = Var;

    fn add(self, rhs: f64) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = self;
        let (index, value) = Tape::push_unary(&tape, operand_value + rhs, operand, 1.0);
        Var { index, value, tape }
    }
}
/// `f64 + Var`.
///
/// The scalar is not recorded on the tape; a unary node is pushed on the
/// variable's tape with a local derivative of `1`.
impl Add<Var> for f64 {
    type Output = Var;

    fn add(self, rhs: Var) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = rhs;
        let (index, value) = Tape::push_unary(&tape, self + operand_value, operand, 1.0);
        Var { index, value, tape }
    }
}
/// `Var - Var`, consuming both operands.
///
/// Pushes a binary node on the left operand's tape with local partials `1`
/// (left) and `-1` (right); the result takes over the left operand's tape handle.
impl Sub for Var {
    type Output = Var;

    fn sub(self, rhs: Var) -> Self::Output {
        let Var {
            index: lhs_index,
            value: lhs_value,
            tape,
        } = self;
        let (index, value) =
            Tape::push_binary(&tape, lhs_value - rhs.value, lhs_index, 1.0, rhs.index, -1.0);
        Var { index, value, tape }
    }
}
/// `&Var - &Var`, leaving both operands usable afterwards.
///
/// Pushes a binary node on the left operand's tape with local partials `1`
/// (left) and `-1` (right); the result holds a cloned handle to that tape.
impl Sub for &Var {
    type Output = Var;

    fn sub(self, rhs: &Var) -> Self::Output {
        let difference = self.value - rhs.value;
        let node = Tape::push_binary(&self.tape, difference, self.index, 1.0, rhs.index, -1.0);
        Var {
            index: node.0,
            value: node.1,
            tape: Rc::clone(&self.tape),
        }
    }
}
/// `Var - f64`.
///
/// The scalar is not recorded on the tape; a unary node is pushed with a
/// local derivative of `1` with respect to the variable.
impl Sub<f64> for Var {
    type Output = Var;

    fn sub(self, rhs: f64) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = self;
        let (index, value) = Tape::push_unary(&tape, operand_value - rhs, operand, 1.0);
        Var { index, value, tape }
    }
}
/// `f64 - Var`.
///
/// The scalar is not recorded on the tape; a unary node is pushed on the
/// variable's tape with a local derivative of `-1`.
impl Sub<Var> for f64 {
    type Output = Var;

    fn sub(self, rhs: Var) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = rhs;
        let (index, value) = Tape::push_unary(&tape, self - operand_value, operand, -1.0);
        Var { index, value, tape }
    }
}
/// `Var * Var`, consuming both operands.
///
/// Product rule: the local partial with respect to each operand is the value
/// of the other operand. The result takes over the left operand's tape handle.
impl Mul for Var {
    type Output = Var;

    fn mul(self, rhs: Var) -> Self::Output {
        let Var {
            index: lhs_index,
            value: lhs_value,
            tape,
        } = self;
        let (index, value) = Tape::push_binary(
            &tape,
            lhs_value * rhs.value,
            lhs_index,
            rhs.value,
            rhs.index,
            lhs_value,
        );
        Var { index, value, tape }
    }
}
/// `&Var * &Var`, leaving both operands usable afterwards.
///
/// Product rule: the local partial with respect to each operand is the value
/// of the other operand. The result holds a cloned handle to the left
/// operand's tape.
impl Mul for &Var {
    type Output = Var;

    fn mul(self, rhs: &Var) -> Self::Output {
        let product = self.value * rhs.value;
        let node = Tape::push_binary(
            &self.tape,
            product,
            self.index,
            rhs.value,
            rhs.index,
            self.value,
        );
        Var {
            index: node.0,
            value: node.1,
            tape: Rc::clone(&self.tape),
        }
    }
}
/// `Var * f64`.
///
/// The scalar is not recorded on the tape; it is the local derivative of the
/// product with respect to the variable.
impl Mul<f64> for Var {
    type Output = Var;

    fn mul(self, rhs: f64) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = self;
        let (index, value) = Tape::push_unary(&tape, operand_value * rhs, operand, rhs);
        Var { index, value, tape }
    }
}
/// `f64 * Var`.
///
/// The scalar is not recorded on the tape; it is the local derivative of the
/// product with respect to the variable.
impl Mul<Var> for f64 {
    type Output = Var;

    fn mul(self, rhs: Var) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = rhs;
        let (index, value) = Tape::push_unary(&tape, self * operand_value, operand, self);
        Var { index, value, tape }
    }
}
/// `Var * &Var`: consumes the left operand, borrows the right one.
///
/// Product rule: the local partial with respect to each operand is the value
/// of the other operand. The result takes over the left operand's tape handle.
impl Mul<&Var> for Var {
    type Output = Var;

    fn mul(self, rhs: &Var) -> Self::Output {
        let Var {
            index: lhs_index,
            value: lhs_value,
            tape,
        } = self;
        let (index, value) = Tape::push_binary(
            &tape,
            lhs_value * rhs.value,
            lhs_index,
            rhs.value,
            rhs.index,
            lhs_value,
        );
        Var { index, value, tape }
    }
}
/// `&Var * Var`: borrows the left operand, consumes the right one.
///
/// Product rule: the local partial with respect to each operand is the value
/// of the other operand. The node is pushed through the left operand's tape
/// handle, while the result reuses the right operand's owned handle instead
/// of cloning one.
impl Mul<Var> for &Var {
    type Output = Var;

    fn mul(self, rhs: Var) -> Self::Output {
        let product = self.value * rhs.value;
        let node = Tape::push_binary(
            &self.tape,
            product,
            self.index,
            rhs.value,
            rhs.index,
            self.value,
        );
        Var {
            index: node.0,
            value: node.1,
            tape: rhs.tape,
        }
    }
}
/// `Var / Var`, consuming both operands.
///
/// Quotient rule: the local partials are `1 / den` with respect to the
/// numerator and `-num / den^2` with respect to the denominator. The result
/// takes over the numerator's tape handle.
impl Div for Var {
    type Output = Var;

    fn div(self, rhs: Var) -> Self::Output {
        let Var {
            index: num_index,
            value: num,
            tape,
        } = self;
        let den = rhs.value;
        let (index, value) = Tape::push_binary(
            &tape,
            num / den,
            num_index,
            1.0 / den,
            rhs.index,
            -num / (den * den),
        );
        Var { index, value, tape }
    }
}
/// `&Var / &Var`, leaving both operands usable afterwards.
///
/// Quotient rule: the local partials are `1 / den` with respect to the
/// numerator and `-num / den^2` with respect to the denominator. The result
/// holds a cloned handle to the numerator's tape.
impl Div for &Var {
    type Output = Var;

    fn div(self, rhs: &Var) -> Self::Output {
        let (num, den) = (self.value, rhs.value);
        let node = Tape::push_binary(
            &self.tape,
            num / den,
            self.index,
            1.0 / den,
            rhs.index,
            -num / (den * den),
        );
        Var {
            index: node.0,
            value: node.1,
            tape: Rc::clone(&self.tape),
        }
    }
}
/// `Var / f64`.
///
/// The scalar is not recorded on the tape; the local derivative with respect
/// to the variable is `1 / rhs`.
impl Div<f64> for Var {
    type Output = Var;

    fn div(self, rhs: f64) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = self;
        let (index, value) = Tape::push_unary(&tape, operand_value / rhs, operand, 1.0 / rhs);
        Var { index, value, tape }
    }
}
/// `-Var`, consuming the operand.
///
/// Pushes a unary node with a local derivative of `-1`; the result takes
/// over the operand's tape handle.
impl Neg for Var {
    type Output = Var;

    fn neg(self) -> Self::Output {
        let Var {
            index: operand,
            value: operand_value,
            tape,
        } = self;
        let (index, value) = Tape::push_unary(&tape, -operand_value, operand, -1.0);
        Var { index, value, tape }
    }
}
/// `-&Var`, leaving the operand usable afterwards.
///
/// Pushes a unary node with a local derivative of `-1`; the result holds a
/// cloned handle to the operand's tape.
impl Neg for &Var {
    type Output = Var;

    fn neg(self) -> Self::Output {
        let negated = -self.value;
        let node = Tape::push_unary(&self.tape, negated, self.index, -1.0);
        Var {
            index: node.0,
            value: node.1,
            tape: Rc::clone(&self.tape),
        }
    }
}
/// Evaluates `f` at `x` on a fresh tape and returns `Tape::gradient` of its output.
///
/// One tape variable is created per element of `x`, in order, before `f` is
/// called, so the inputs are the first nodes on the tape. The tests in this
/// file read entry `i` of the result as the derivative with respect to `x[i]`.
pub fn grad<F>(f: F, x: &[f64]) -> Vec<f64>
where
    F: Fn(&[Var]) -> Var,
{
    let tape = Tape::new();

    let mut inputs: Vec<Var> = Vec::with_capacity(x.len());
    for &xi in x {
        inputs.push(Tape::var(&tape, xi));
    }

    let result = f(&inputs);
    Tape::gradient(&tape, &result)
}
/// Evaluates the vector-valued `f` at `x` on a fresh tape and returns
/// `Tape::jacobian` of its outputs.
///
/// One tape variable is created per element of `x`, in order, before `f` is
/// called. The tests in this file read `result[i][j]` as the derivative of
/// output `i` with respect to `x[j]`.
pub fn jacobian_reverse<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[Var]) -> Vec<Var>,
{
    let tape = Tape::new();

    let mut inputs: Vec<Var> = Vec::with_capacity(x.len());
    for &xi in x {
        inputs.push(Tape::var(&tape, xi));
    }

    let results = f(&inputs);
    Tape::jacobian(&tape, &results)
}
/// Approximates the Hessian of `f` at `x` by forward finite differences of
/// the reverse-mode gradient.
///
/// Column `j` of the result is `(grad f(x + h_j e_j) - grad f(x)) / h_j`, so
/// `result[i][j]` approximates `d^2 f / dx_i dx_j`. The matrix is `n x n` with
/// `n = x.len()` and is not symmetrised. `grad` is called `n + 1` times.
///
/// The step is `1e-7 * max(1, |x_j|)`. A fixed absolute step would be
/// absorbed by rounding for large `|x_j|` (`x_j + eps == x_j`), silently
/// producing an all-zero column. The divisor is the step actually realised in
/// floating point rather than the nominal one.
pub fn hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[Var]) -> Var,
{
    let n = x.len();
    let eps = 1e-7;
    let mut h = vec![vec![0.0; n]; n];
    let g0 = grad(&f, x);

    // One scratch buffer for all columns; each coordinate is restored after use.
    let mut x_pert = x.to_vec();
    for j in 0..n {
        let original = x[j];
        x_pert[j] = original + eps * original.abs().max(1.0);
        // Use the representable step, which may differ slightly from the nominal one.
        let step = x_pert[j] - original;

        let g_pert = grad(&f, &x_pert);
        for (i, row) in h.iter_mut().enumerate() {
            row[j] = (g_pert[i] - g0[i]) / step;
        }

        x_pert[j] = original;
    }
    h
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Rosenbrock function `(1 - x0)^2 + 100 * (x1 - x0^2)^2` built from
    /// by-value operators and tape constants.
    fn rosenbrock(x: &[Var]) -> Var {
        let a = x[0].cst(1.0) - x[0].clone();
        let b = x[1].clone() - x[0].clone() * x[0].clone();
        a.clone() * a + x[0].cst(100.0) * b.clone() * b
    }

    // d(x + y) = (1, 1).
    #[test]
    fn test_basic_arithmetic() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 2.0), Tape::var(&tape, 3.0));
        let sum = x.clone() + y.clone();
        assert_close(sum.value, 5.0, 1e-14);

        let g = Tape::gradient(&tape, &sum);
        assert_close(g[0], 1.0, 1e-14);
        assert_close(g[1], 1.0, 1e-14);
    }

    // d(x * y) = (y, x).
    #[test]
    fn test_multiplication() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 2.0), Tape::var(&tape, 3.0));
        let product = x * y;

        let g = Tape::gradient(&tape, &product);
        assert_close(g[0], 3.0, 1e-14);
        assert_close(g[1], 2.0, 1e-14);
    }

    // d(x^3)/dx = 3x^2 = 12 at x = 2, accumulated over repeated uses of x.
    #[test]
    fn test_chain_rule() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let cube = x.clone() * x.clone() * x;

        let g = Tape::gradient(&tape, &cube);
        assert_close(g[0], 12.0, 1e-12);
    }

    // d(x - y) = (1, -1).
    #[test]
    fn test_subtraction() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 5.0), Tape::var(&tape, 3.0));
        let difference = x - y;
        assert_close(difference.value, 2.0, 1e-14);

        let g = Tape::gradient(&tape, &difference);
        assert_close(g[0], 1.0, 1e-14);
        assert_close(g[1], -1.0, 1e-14);
    }

    // d(x / y) = (1/y, -x/y^2).
    #[test]
    fn test_division() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 6.0), Tape::var(&tape, 3.0));
        let quotient = x / y;

        let g = Tape::gradient(&tape, &quotient);
        assert_close(g[0], 1.0 / 3.0, 1e-14);
        assert_close(g[1], -6.0 / 9.0, 1e-14);
    }

    // d(-x)/dx = -1.
    #[test]
    fn test_negation() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let negated = -x;
        assert_close(negated.value, -3.0, 1e-14);

        let g = Tape::gradient(&tape, &negated);
        assert_close(g[0], -1.0, 1e-14);
    }

    // sin' = cos and cos' = -sin, each on its own tape.
    #[test]
    fn test_sin_cos() {
        let sin_tape = Tape::new();
        let x = Tape::var(&sin_tape, 1.0);
        let s = x.sin();
        let g_sin = Tape::gradient(&sin_tape, &s);
        assert_close(g_sin[0], 1.0_f64.cos(), 1e-14);

        let cos_tape = Tape::new();
        let x2 = Tape::var(&cos_tape, 1.0);
        let c = x2.cos();
        let g_cos = Tape::gradient(&cos_tape, &c);
        assert_close(g_cos[0], -(1.0_f64.sin()), 1e-14);
    }

    // exp' = exp and ln' = 1/x, each on its own tape.
    #[test]
    fn test_exp_ln() {
        let exp_tape = Tape::new();
        let x = Tape::var(&exp_tape, 2.0);
        let e = x.exp();
        let g_exp = Tape::gradient(&exp_tape, &e);
        assert_close(g_exp[0], 2.0_f64.exp(), 1e-12);

        let ln_tape = Tape::new();
        let x2 = Tape::var(&ln_tape, 3.0);
        let l = x2.ln();
        let g_ln = Tape::gradient(&ln_tape, &l);
        assert_close(g_ln[0], 1.0 / 3.0, 1e-14);
    }

    // sqrt(4) = 2 and sqrt'(4) = 1 / (2 * 2).
    #[test]
    fn test_sqrt() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 4.0);
        let root = x.sqrt();
        assert_close(root.value, 2.0, 1e-14);

        let g = Tape::gradient(&tape, &root);
        assert_close(g[0], 0.25, 1e-14);
    }

    // tanh' = 1 - tanh^2.
    #[test]
    fn test_tanh() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 1.0);
        let t = x.tanh();

        let g = Tape::gradient(&tape, &t);
        let expected = 1.0 - 1.0_f64.tanh().powi(2);
        assert_close(g[0], expected, 1e-14);
    }

    // x^2 at x = 3: value 9, derivative 6.
    #[test]
    fn test_powf() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let squared = x.powf(2.0);
        assert_close(squared.value, 9.0, 1e-14);

        let g = Tape::gradient(&tape, &squared);
        assert_close(g[0], 6.0, 1e-12);
    }

    // x^y at (2, 3): value 8, d/dx = 12, d/dy = 8 * ln 2.
    #[test]
    fn test_pow_var() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 2.0), Tape::var(&tape, 3.0));
        let power = x.pow(&y);
        assert_close(power.value, 8.0, 1e-12);

        let g = Tape::gradient(&tape, &power);
        assert_close(g[0], 12.0, 1e-10);
        assert_close(g[1], 8.0 * 2.0_f64.ln(), 1e-10);
    }

    // Values of every Var/f64 mixed operator.
    #[test]
    fn test_scalar_ops() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);

        let var_plus = x.clone() + 2.0;
        assert_close(var_plus.value, 5.0, 1e-14);
        let plus_var = 2.0 + x.clone();
        assert_close(plus_var.value, 5.0, 1e-14);

        let var_times = x.clone() * 3.0;
        assert_close(var_times.value, 9.0, 1e-14);
        let times_var = 3.0 * x.clone();
        assert_close(times_var.value, 9.0, 1e-14);

        let var_minus = x.clone() - 1.0;
        assert_close(var_minus.value, 2.0, 1e-14);
        let minus_var = 10.0 - x.clone();
        assert_close(minus_var.value, 7.0, 1e-14);

        let var_over = x / 2.0;
        assert_close(var_over.value, 1.5, 1e-14);
    }

    // Values of the by-reference binary operators.
    #[test]
    fn test_reference_ops() {
        let tape = Tape::new();
        let (x, y) = (Tape::var(&tape, 2.0), Tape::var(&tape, 3.0));

        let sum = &x + &y;
        assert_close(sum.value, 5.0, 1e-14);
        let product = &x * &y;
        assert_close(product.value, 6.0, 1e-14);
        let difference = &x - &y;
        assert_close(difference.value, -1.0, 1e-14);
        let quotient = &x / &y;
        assert_close(quotient.value, 2.0 / 3.0, 1e-14);
    }

    // The Rosenbrock gradient vanishes at its minimum (1, 1).
    #[test]
    fn test_grad_rosenbrock() {
        let g = grad(rosenbrock, &[1.0, 1.0]);
        assert_close(g[0], 0.0, 1e-10);
        assert_close(g[1], 0.0, 1e-10);
    }

    // At the origin the Rosenbrock gradient is (-2, 0).
    #[test]
    fn test_grad_rosenbrock_nonzero() {
        let g = grad(rosenbrock, &[0.0, 0.0]);
        assert_close(g[0], -2.0, 1e-10);
        assert_close(g[1], 0.0, 1e-10);
    }

    // Jacobian of (x + y, x * y) at (2, 3).
    #[test]
    fn test_jacobian_reverse_fn() {
        let jac = jacobian_reverse(|x| vec![&x[0] + &x[1], &x[0] * &x[1]], &[2.0, 3.0]);
        assert_eq!(jac.len(), 2);

        assert_close(jac[0][0], 1.0, 1e-14);
        assert_close(jac[0][1], 1.0, 1e-14);
        assert_close(jac[1][0], 3.0, 1e-14);
        assert_close(jac[1][1], 2.0, 1e-14);
    }

    // The Jacobian of a planar rotation is the rotation matrix itself.
    #[test]
    fn test_jacobian_rotation() {
        let theta: f64 = 0.5;
        let rotate = |v: &[Var]| {
            let x = &v[0];
            let y = &v[1];
            let ct = x.cst(theta.cos());
            let st = x.cst(theta.sin());
            vec![&(x * &ct) - &(y * &st), &(x * &st) + &(y * &ct)]
        };
        let jac = jacobian_reverse(rotate, &[1.0, 0.0]);

        assert_close(jac[0][0], theta.cos(), 1e-12);
        assert_close(jac[0][1], -theta.sin(), 1e-12);
        assert_close(jac[1][0], theta.sin(), 1e-12);
        assert_close(jac[1][1], theta.cos(), 1e-12);
    }

    // Hessian of x^2 + 2xy + 3y^2 is [[2, 2], [2, 6]].
    #[test]
    fn test_hessian_quadratic() {
        let h = hessian(
            |x| &x[0] * &x[0] + x[0].cst(2.0) * &x[0] * &x[1] + x[0].cst(3.0) * &x[1] * &x[1],
            &[1.0, 1.0],
        );

        assert_close(h[0][0], 2.0, 1e-5);
        assert_close(h[0][1], 2.0, 1e-5);
        assert_close(h[1][0], 2.0, 1e-5);
        assert_close(h[1][1], 6.0, 1e-5);
    }

    // Rosenbrock Hessian at (1, 1) is [[802, -400], [-400, 200]].
    #[test]
    fn test_hessian_rosenbrock() {
        let h = hessian(rosenbrock, &[1.0, 1.0]);

        assert_close(h[0][0], 802.0, 1e-3);
        assert_close(h[0][1], -400.0, 1e-3);
        assert_close(h[1][0], -400.0, 1e-3);
        assert_close(h[1][1], 200.0, 1e-3);
    }

    // asin' = 1 / sqrt(1 - x^2) and atan' = 1 / (1 + x^2).
    #[test]
    fn test_inverse_trig() {
        let asin_tape = Tape::new();
        let x = Tape::var(&asin_tape, 0.5);
        let a = x.asin();
        let g_asin = Tape::gradient(&asin_tape, &a);
        assert_close(g_asin[0], 1.0 / (1.0 - 0.25_f64).sqrt(), 1e-12);

        let atan_tape = Tape::new();
        let x2 = Tape::var(&atan_tape, 1.0);
        let a2 = x2.atan();
        let g_atan = Tape::gradient(&atan_tape, &a2);
        assert_close(g_atan[0], 0.5, 1e-14);
    }

    // Reverse-mode gradient agrees with a forward finite difference.
    #[test]
    fn test_grad_matches_fd() {
        let f = |x: &[Var]| x[0].sin() * x[1].exp() + x[2].clone() * x[2].clone();
        let f_val = |x: &[f64]| x[0].sin() * x[1].exp() + x[2] * x[2];

        let x0 = [1.0, 2.0, 3.0];
        let g = grad(f, &x0);

        let eps = 1e-7;
        let f0 = f_val(&x0);
        for i in 0..3 {
            let mut xp = x0;
            xp[i] += eps;
            let fd = (f_val(&xp) - f0) / eps;
            assert!(
                (g[i] - fd).abs() < 1e-5,
                "component {} mismatch: {} vs {}",
                i,
                g[i],
                fd
            );
        }
    }

    // d(x * c)/dx = c for a tape constant c.
    #[test]
    fn test_constant() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let c = x.cst(5.0);
        let scaled = x * c;

        let g = Tape::gradient(&tape, &scaled);
        assert_close(g[0], 5.0, 1e-14);
    }

    // d/dx sin(x^2 + e^x) = cos(x^2 + e^x) * (2x + e^x).
    #[test]
    fn test_complex_composition() {
        let g = grad(
            |x| {
                let x2 = &x[0] * &x[0];
                let ex = x[0].exp();
                let inner = x2 + ex;
                inner.sin()
            },
            &[1.0],
        );

        let x = 1.0_f64;
        let inner = x * x + x.exp();
        let expected = inner.cos() * (2.0 * x + x.exp());
        assert_close(g[0], expected, 1e-10);
    }
}