use std::ops;
/// A dual number `val + deriv·ε`, pairing a value with a derivative part.
///
/// Used for forward-mode automatic differentiation: arithmetic on `Dual`
/// propagates the derivative part alongside the value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual {
    /// The real (primal) part.
    pub val: f64,
    /// The coefficient of `ε`, i.e. the derivative part.
    pub deriv: f64,
}
impl Dual {
    /// Creates a dual number from an explicit value and derivative part.
    #[must_use]
    #[inline]
    pub const fn new(val: f64, deriv: f64) -> Self {
        Self { val, deriv }
    }

    /// Creates an independent variable: value `val` with derivative seed `1.0`.
    #[must_use]
    #[inline]
    pub const fn var(val: f64) -> Self {
        Self { val, deriv: 1.0 }
    }

    /// Creates a constant: value `val` with derivative `0.0`.
    #[must_use]
    #[inline]
    pub const fn constant(val: f64) -> Self {
        Self { val, deriv: 0.0 }
    }

    /// Sine: value `sin(val)`, derivative `deriv * cos(val)`.
    #[must_use]
    #[inline]
    pub fn sin(self) -> Self {
        Self {
            val: self.val.sin(),
            deriv: self.deriv * self.val.cos(),
        }
    }

    /// Cosine: value `cos(val)`, derivative `-deriv * sin(val)`.
    #[must_use]
    #[inline]
    pub fn cos(self) -> Self {
        Self {
            val: self.val.cos(),
            deriv: -self.deriv * self.val.sin(),
        }
    }

    /// Exponential: value `exp(val)`, derivative `deriv * exp(val)`.
    #[must_use]
    #[inline]
    pub fn exp(self) -> Self {
        // exp is its own derivative, so the value is computed once and reused.
        let e = self.val.exp();
        Self {
            val: e,
            deriv: self.deriv * e,
        }
    }

    /// Natural logarithm: value `ln(val)`, derivative `deriv / val`.
    ///
    /// No domain check is performed; non-positive inputs follow plain `f64`
    /// semantics (NaN / infinities).
    #[must_use]
    #[inline]
    pub fn ln(self) -> Self {
        Self {
            val: self.val.ln(),
            deriv: self.deriv / self.val,
        }
    }

    /// Square root: value `sqrt(val)`, derivative `deriv / (2 * sqrt(val))`.
    ///
    /// No domain check is performed; at `val == 0.0` the derivative is the
    /// result of an `f64` division by zero.
    #[must_use]
    #[inline]
    pub fn sqrt(self) -> Self {
        let s = self.val.sqrt();
        Self {
            val: s,
            deriv: self.deriv / (2.0 * s),
        }
    }

    /// Power with a constant real exponent: value `val^n`,
    /// derivative `deriv * n * val^(n - 1)`.
    ///
    /// For `n == 0.0` the derivative is exactly `0.0`, because `x^0` is the
    /// constant `1` (which is also what `f64::powf` returns for the value).
    #[must_use]
    #[inline]
    pub fn powf(self, n: f64) -> Self {
        // Without this guard, val == 0 gives n * val^(n-1) = 0 * inf = NaN
        // even though the function being differentiated is constant.
        let deriv = if n == 0.0 {
            0.0
        } else {
            self.deriv * n * self.val.powf(n - 1.0)
        };
        Self {
            val: self.val.powf(n),
            deriv,
        }
    }

    /// Absolute value: value `|val|`; the derivative keeps its sign when
    /// `val >= 0.0` and is negated otherwise.
    ///
    /// At `val == 0.0` (where `|x|` is not differentiable) the derivative is
    /// passed through unchanged.
    #[must_use]
    #[inline]
    pub fn abs(self) -> Self {
        Self {
            val: self.val.abs(),
            deriv: if self.val >= 0.0 {
                self.deriv
            } else {
                -self.deriv
            },
        }
    }

    /// Tangent: value `tan(val)`, derivative `deriv / cos(val)^2`.
    #[must_use]
    #[inline]
    pub fn tan(self) -> Self {
        let c = self.val.cos();
        Self {
            val: self.val.tan(),
            deriv: self.deriv / (c * c),
        }
    }
}
impl ops::Add for Dual {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self {
Self {
val: self.val + rhs.val,
deriv: self.deriv + rhs.deriv,
}
}
}
impl ops::Sub for Dual {
type Output = Self;
#[inline]
fn sub(self, rhs: Self) -> Self {
Self {
val: self.val - rhs.val,
deriv: self.deriv - rhs.deriv,
}
}
}
impl ops::Mul for Dual {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self {
Self {
val: self.val * rhs.val,
deriv: self.val * rhs.deriv + self.deriv * rhs.val,
}
}
}
impl ops::Div for Dual {
type Output = Self;
#[inline]
fn div(self, rhs: Self) -> Self {
Self {
val: self.val / rhs.val,
deriv: (self.deriv * rhs.val - self.val * rhs.deriv) / (rhs.val * rhs.val),
}
}
}
impl ops::Neg for Dual {
type Output = Self;
#[inline]
fn neg(self) -> Self {
Self {
val: -self.val,
deriv: -self.deriv,
}
}
}
impl ops::Add<f64> for Dual {
type Output = Self;
#[inline]
fn add(self, rhs: f64) -> Self {
Self {
val: self.val + rhs,
deriv: self.deriv,
}
}
}
impl ops::Sub<f64> for Dual {
type Output = Self;
#[inline]
fn sub(self, rhs: f64) -> Self {
Self {
val: self.val - rhs,
deriv: self.deriv,
}
}
}
impl ops::Mul<f64> for Dual {
type Output = Self;
#[inline]
fn mul(self, rhs: f64) -> Self {
Self {
val: self.val * rhs,
deriv: self.deriv * rhs,
}
}
}
impl ops::Div<f64> for Dual {
type Output = Self;
#[inline]
fn div(self, rhs: f64) -> Self {
Self {
val: self.val / rhs,
deriv: self.deriv / rhs,
}
}
}
impl ops::Add<Dual> for f64 {
    type Output = Dual;

    /// Scalar-on-the-left addition: the scalar is added to the value and the
    /// derivative part of `rhs` is carried over unchanged.
    #[inline]
    fn add(self, rhs: Dual) -> Self::Output {
        Dual {
            val: self + rhs.val,
            ..rhs
        }
    }
}
impl ops::Mul<Dual> for f64 {
    type Output = Dual;

    /// Scalar-on-the-left multiplication: scales both parts of `rhs`.
    #[inline]
    fn mul(self, rhs: Dual) -> Self::Output {
        let val = self * rhs.val;
        let deriv = self * rhs.deriv;
        Dual { val, deriv }
    }
}
impl From<f64> for Dual {
#[inline]
fn from(val: f64) -> Self {
Self::constant(val)
}
}
impl std::fmt::Display for Dual {
    /// Formats the number as `<val>+<deriv>ε`.
    ///
    /// The `+` separator is always emitted, so a negative derivative part
    /// renders as e.g. `1+-2ε`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { val, deriv } = *self;
        f.write_fmt(format_args!("{}+{}ε", val, deriv))
    }
}
/// One recorded operation on a [`Tape`].
///
/// Every `usize` payload is the tape index of an operand node.
#[derive(Debug, Clone, Copy)]
enum TapeOp {
    /// Leaf node with no operands (used for both variables and constants).
    Const,
    /// `lhs + rhs`.
    Add(usize, usize),
    /// `lhs * rhs`.
    Mul(usize, usize),
    /// `lhs - rhs`.
    Sub(usize, usize),
    /// `lhs / rhs`.
    Div(usize, usize),
    /// `-operand`.
    Neg(usize),
    /// `sin(operand)`.
    Sin(usize),
    /// `cos(operand)`.
    Cos(usize),
    /// `exp(operand)`.
    Exp(usize),
    /// `ln(operand)`.
    Ln(usize),
    /// `operand` raised to a constant `f64` exponent.
    Pow(usize, f64),
}
/// Recording of a computation for reverse-mode differentiation.
///
/// `ops` and `values` are parallel vectors: entry `i` of each describes
/// node `i` of the recorded computation.
#[derive(Debug)]
pub struct Tape {
    /// The operation that produced each node.
    ops: Vec<TapeOp>,
    /// The numeric value of each node.
    values: Vec<f64>,
}
/// Handle to a node recorded on a [`Tape`].
///
/// Carries the node's position on the tape together with a copy of its value.
#[derive(Debug, Clone, Copy)]
pub struct Var {
    /// Position of the node on the tape.
    index: usize,
    /// Value of the node at the time it was recorded.
    val: f64,
}
impl Tape {
    /// Creates an empty tape.
    #[must_use]
    pub fn new() -> Self {
        Self {
            ops: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Records an input variable with value `val` as a leaf node.
    pub fn var(&mut self, val: f64) -> Var {
        self.push(TapeOp::Const, val)
    }

    /// Records a constant with value `val` as a leaf node.
    ///
    /// The bookkeeping is the same as for [`Tape::var`]; the separate name
    /// only expresses the caller's intent.
    pub fn constant(&mut self, val: f64) -> Var {
        self.push(TapeOp::Const, val)
    }

    /// Appends a node with operation `op` and value `val`, returning its handle.
    fn push(&mut self, op: TapeOp, val: f64) -> Var {
        let index = self.ops.len();
        self.ops.push(op);
        self.values.push(val);
        Var { index, val }
    }

    /// Records `a + b`.
    pub fn add(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeOp::Add(a.index, b.index), a.val + b.val)
    }

    /// Records `a - b`.
    pub fn sub(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeOp::Sub(a.index, b.index), a.val - b.val)
    }

    /// Records `a * b`.
    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeOp::Mul(a.index, b.index), a.val * b.val)
    }

    /// Records `a / b`.
    pub fn div(&mut self, a: Var, b: Var) -> Var {
        self.push(TapeOp::Div(a.index, b.index), a.val / b.val)
    }

    /// Records `-a`.
    pub fn neg(&mut self, a: Var) -> Var {
        self.push(TapeOp::Neg(a.index), -a.val)
    }

    /// Records `sin(a)`.
    pub fn sin(&mut self, a: Var) -> Var {
        self.push(TapeOp::Sin(a.index), a.val.sin())
    }

    /// Records `cos(a)`.
    pub fn cos(&mut self, a: Var) -> Var {
        self.push(TapeOp::Cos(a.index), a.val.cos())
    }

    /// Records `exp(a)`.
    pub fn exp(&mut self, a: Var) -> Var {
        self.push(TapeOp::Exp(a.index), a.val.exp())
    }

    /// Records `ln(a)`.
    pub fn ln(&mut self, a: Var) -> Var {
        self.push(TapeOp::Ln(a.index), a.val.ln())
    }

    /// Records `a` raised to the constant exponent `n`.
    pub fn powf(&mut self, a: Var, n: f64) -> Var {
        self.push(TapeOp::Pow(a.index, n), a.val.powf(n))
    }

    /// Runs the reverse sweep from `output` and returns one adjoint per node.
    ///
    /// Entry `i` of the result is the partial derivative of `output` with
    /// respect to node `i`; the entry for `output` itself is seeded with `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `output.index` is not a valid node index of this tape
    /// (which includes calling this on an empty tape).
    #[must_use]
    pub fn backward(&self, output: Var) -> Vec<f64> {
        let n = self.ops.len();
        let mut grads = vec![0.0; n];
        grads[output.index] = 1.0;
        // Walk the nodes newest-first, pushing each node's adjoint to its operands.
        for i in (0..n).rev() {
            let g = grads[i];
            // A zero adjoint contributes nothing to any operand.
            if g == 0.0 {
                continue;
            }
            match self.ops[i] {
                TapeOp::Const => {}
                TapeOp::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
                TapeOp::Sub(a, b) => {
                    grads[a] += g;
                    grads[b] -= g;
                }
                TapeOp::Mul(a, b) => {
                    grads[a] += g * self.values[b];
                    grads[b] += g * self.values[a];
                }
                TapeOp::Div(a, b) => {
                    grads[a] += g / self.values[b];
                    grads[b] -= g * self.values[a] / (self.values[b] * self.values[b]);
                }
                TapeOp::Neg(a) => {
                    grads[a] -= g;
                }
                TapeOp::Sin(a) => {
                    grads[a] += g * self.values[a].cos();
                }
                TapeOp::Cos(a) => {
                    grads[a] -= g * self.values[a].sin();
                }
                TapeOp::Exp(a) => {
                    // The node's own value is exp(operand), which is also its derivative.
                    grads[a] += g * self.values[i];
                }
                TapeOp::Ln(a) => {
                    grads[a] += g / self.values[a];
                }
                TapeOp::Pow(a, n) => {
                    // x^0 is constant, so it contributes no gradient. Skipping it
                    // also avoids 0 * 0^(-1) = 0 * inf = NaN when the operand is 0.
                    if n != 0.0 {
                        grads[a] += g * n * self.values[a].powf(n - 1.0);
                    }
                }
            }
        }
        grads
    }
}
impl Default for Tape {
fn default() -> Self {
Self::new()
}
}
/// Computes the gradient of `f` at `x` using reverse-mode differentiation.
///
/// A fresh [`Tape`] is created, one variable is recorded per entry of `x`
/// (in order), and `f` is called with the tape and those variables to build
/// the computation. The returned vector holds the partial derivative of
/// `f`'s output with respect to each input, in the same order as `x`.
#[must_use]
pub fn reverse_gradient(f: impl Fn(&mut Tape, &[Var]) -> Var, x: &[f64]) -> Vec<f64> {
    let mut tape = Tape::new();

    // Register the inputs first so they occupy the leading tape slots.
    let mut inputs: Vec<Var> = Vec::with_capacity(x.len());
    for &value in x {
        inputs.push(tape.var(value));
    }

    let result = f(&mut tape, &inputs);
    let adjoints = tape.backward(result);

    // Keep only the adjoints that belong to the input variables.
    let mut gradient = Vec::with_capacity(inputs.len());
    for input in &inputs {
        gradient.push(adjoints[input.index]);
    }
    gradient
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by less than [`EPSILON`].
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < EPSILON
    }

    /// Asserts that `actual` is within [`EPSILON`] of `expected`.
    ///
    /// Unlike a bare `assert!(approx(..))`, a failure prints both values, and
    /// `#[track_caller]` makes the panic point at the calling test line.
    #[track_caller]
    fn assert_approx(actual: f64, expected: f64) {
        assert!(
            approx(actual, expected),
            "expected {}, got {} (tolerance {})",
            expected,
            actual,
            EPSILON
        );
    }

    // ---- forward mode (Dual) ----

    #[test]
    fn dual_arithmetic() {
        // f(x) = x^2 + 2x at x = 3: f = 15, f' = 8.
        let x = Dual::var(3.0);
        let c = Dual::constant(2.0);
        let r = x * x + c * x;
        assert_approx(r.val, 15.0);
        assert_approx(r.deriv, 8.0);
    }

    #[test]
    fn dual_division() {
        // f(x) = 1/x at x = 4: f = 1/4, f' = -1/16.
        let x = Dual::var(4.0);
        let one = Dual::constant(1.0);
        let r = one / x;
        assert_approx(r.val, 0.25);
        assert_approx(r.deriv, -1.0 / 16.0);
    }

    #[test]
    fn dual_sin_cos() {
        let x = Dual::var(0.0);
        let s = x.sin();
        assert_approx(s.val, 0.0);
        assert_approx(s.deriv, 1.0);
        let c = x.cos();
        assert_approx(c.val, 1.0);
        assert_approx(c.deriv, 0.0);
    }

    #[test]
    fn dual_exp_ln() {
        let x = Dual::var(1.0);
        let e = x.exp();
        assert_approx(e.val, std::f64::consts::E);
        assert_approx(e.deriv, std::f64::consts::E);
        let l = x.ln();
        assert_approx(l.val, 0.0);
        assert_approx(l.deriv, 1.0);
    }

    #[test]
    fn dual_sqrt() {
        let x = Dual::var(4.0);
        let s = x.sqrt();
        assert_approx(s.val, 2.0);
        assert_approx(s.deriv, 0.25);
    }

    #[test]
    fn dual_powf() {
        // f(x) = x^3 at x = 2: f = 8, f' = 12.
        let x = Dual::var(2.0);
        let p = x.powf(3.0);
        assert_approx(p.val, 8.0);
        assert_approx(p.deriv, 12.0);
    }

    #[test]
    fn dual_chain_rule() {
        // f(x) = sin(x^2) at x = 1: f' = 2x cos(x^2) = 2 cos(1).
        let x = Dual::var(1.0);
        let r = (x * x).sin();
        assert_approx(r.val, 1.0_f64.sin());
        assert_approx(r.deriv, 2.0 * 1.0_f64.cos());
    }

    #[test]
    fn dual_neg() {
        let x = Dual::var(3.0);
        let r = -x;
        assert_approx(r.val, -3.0);
        assert_approx(r.deriv, -1.0);
    }

    #[test]
    fn dual_abs() {
        let x = Dual::var(-3.0);
        let r = x.abs();
        assert_approx(r.val, 3.0);
        assert_approx(r.deriv, -1.0);
    }

    #[test]
    fn dual_tan() {
        let x = Dual::var(0.0);
        let r = x.tan();
        assert_approx(r.val, 0.0);
        assert_approx(r.deriv, 1.0);
    }

    #[test]
    fn dual_display() {
        let d = Dual::new(1.0, 2.0);
        assert_eq!(format!("{d}"), "1+2ε");
    }

    #[test]
    fn dual_from_f64() {
        let d: Dual = 5.0.into();
        assert_approx(d.val, 5.0);
        assert_approx(d.deriv, 0.0);
    }

    #[test]
    fn dual_scalar_ops() {
        let x = Dual::var(3.0);
        let r = x + 1.0;
        assert_approx(r.val, 4.0);
        assert_approx(r.deriv, 1.0);
        let r2 = x * 2.0;
        assert_approx(r2.val, 6.0);
        assert_approx(r2.deriv, 2.0);
    }

    #[test]
    fn dual_sub_scalar() {
        let x = Dual::var(5.0);
        let r = x - 3.0;
        assert_approx(r.val, 2.0);
        assert_approx(r.deriv, 1.0);
    }

    #[test]
    fn dual_div_scalar() {
        let x = Dual::var(6.0);
        let r = x / 3.0;
        assert_approx(r.val, 2.0);
        assert_approx(r.deriv, 1.0 / 3.0);
    }

    #[test]
    fn dual_reverse_scalar_ops() {
        // Scalar on the left-hand side.
        let x = Dual::var(3.0);
        let r = 2.0 * x;
        assert_approx(r.val, 6.0);
        assert_approx(r.deriv, 2.0);
        let r2 = 10.0 + x;
        assert_approx(r2.val, 13.0);
        assert_approx(r2.deriv, 1.0);
    }

    // ---- reverse mode (Tape) ----

    #[test]
    fn reverse_simple_product() {
        let grad = reverse_gradient(|tape, vars| tape.mul(vars[0], vars[1]), &[3.0, 5.0]);
        assert_approx(grad[0], 5.0);
        assert_approx(grad[1], 3.0);
    }

    #[test]
    fn reverse_sum() {
        let grad = reverse_gradient(|tape, vars| tape.add(vars[0], vars[1]), &[3.0, 5.0]);
        assert_approx(grad[0], 1.0);
        assert_approx(grad[1], 1.0);
    }

    #[test]
    fn reverse_chain_rule() {
        // f(x) = sin(x^2) at x = 1: f' = 2 cos(1).
        let grad = reverse_gradient(
            |tape, vars| {
                let x2 = tape.mul(vars[0], vars[0]);
                tape.sin(x2)
            },
            &[1.0],
        );
        let expected = 2.0 * 1.0_f64.cos();
        assert_approx(grad[0], expected);
    }

    #[test]
    fn reverse_matches_forward() {
        // f(x) = x^3 + 2x evaluated in both modes at the same point.
        let x_val = 2.0;
        let x = Dual::var(x_val);
        let fwd = x * x * x + Dual::constant(2.0) * x;
        let grad = reverse_gradient(
            |tape, vars| {
                let x2 = tape.mul(vars[0], vars[0]);
                let x3 = tape.mul(x2, vars[0]);
                let two = tape.constant(2.0);
                let two_x = tape.mul(two, vars[0]);
                tape.add(x3, two_x)
            },
            &[x_val],
        );
        assert!(
            (fwd.deriv - grad[0]).abs() < 1e-8,
            "forward={} vs reverse={}",
            fwd.deriv,
            grad[0]
        );
    }

    #[test]
    fn reverse_exp_ln() {
        // exp(ln(x)) = x, so the derivative is 1.
        let grad = reverse_gradient(
            |tape, vars| {
                let ln_x = tape.ln(vars[0]);
                tape.exp(ln_x)
            },
            &[3.0],
        );
        assert_approx(grad[0], 1.0);
    }

    #[test]
    fn reverse_power() {
        // f(x) = x^3 at x = 2: f' = 12.
        let grad = reverse_gradient(|tape, vars| tape.powf(vars[0], 3.0), &[2.0]);
        assert_approx(grad[0], 12.0);
    }
}