use crate::error::AutogradError;
use crate::Result;
use std::cell::RefCell;
/// One recorded operation on a tape: the value it produced and the local partial
/// derivatives linking it to the nodes it was computed from.
#[derive(Clone, Debug)]
pub(crate) struct Node {
    /// Forward value computed when the node was recorded.
    pub(crate) value: f64,
    /// `(parent node index, local partial derivative of this node w.r.t. that parent)`.
    /// Empty for input and constant nodes.
    pub(crate) parents: Vec<(usize, f64)>,
}
/// Append-only record of the operations performed during a forward evaluation.
///
/// Nodes are stored in the order they were pushed, so every node's parents have
/// smaller indices than the node itself.
pub struct Tape {
    /// Recorded nodes in push order. Wrapped in a `RefCell` so that every `push_*`
    /// method can take `&self` instead of `&mut self`.
    nodes: RefCell<Vec<Node>>,
}
impl Tape {
    /// Creates an empty tape.
    pub fn new() -> Self {
        Self { nodes: RefCell::new(Vec::new()) }
    }

    /// Creates an empty tape with room for `n` nodes before the node buffer reallocates.
    pub fn with_capacity(n: usize) -> Self {
        Self { nodes: RefCell::new(Vec::with_capacity(n)) }
    }

    /// Records an independent input node (no parents) holding `value`.
    pub fn push_input(&self, value: f64) -> TapeVar {
        let mut nodes = self.nodes.borrow_mut();
        let idx = nodes.len();
        nodes.push(Node { value, parents: Vec::new() });
        TapeVar { idx, value }
    }

    /// Records a constant. Stored exactly like an input: a parentless node.
    pub fn push_constant(&self, value: f64) -> TapeVar {
        self.push_input(value)
    }

    /// Records a single-argument operation.
    ///
    /// `forward(a)` computes the output value. `backward(a, out)` returns the local
    /// derivative `d out / d a`, given the input and the already-computed output.
    /// Both closures run eagerly here, before the tape is mutably borrowed.
    pub fn push_unary(
        &self,
        a: TapeVar,
        forward: impl Fn(f64) -> f64,
        backward: impl Fn(f64, f64) -> f64,
    ) -> TapeVar {
        let out_val = forward(a.value);
        let local_grad = backward(a.value, out_val);
        let mut nodes = self.nodes.borrow_mut();
        let idx = nodes.len();
        nodes.push(Node {
            value: out_val,
            parents: vec![(a.idx, local_grad)],
        });
        TapeVar { idx, value: out_val }
    }

    /// Records a two-argument operation.
    ///
    /// `forward(a, b)` computes the output value. `backward(a, b)` returns the local
    /// partials `(d out / d a, d out / d b)`. Both closures run before the tape is
    /// mutably borrowed.
    pub fn push_binary(
        &self,
        a: TapeVar,
        b: TapeVar,
        forward: impl Fn(f64, f64) -> f64,
        backward: impl Fn(f64, f64) -> (f64, f64),
    ) -> TapeVar {
        let out_val = forward(a.value, b.value);
        let (ga, gb) = backward(a.value, b.value);
        let mut nodes = self.nodes.borrow_mut();
        let idx = nodes.len();
        nodes.push(Node {
            value: out_val,
            parents: vec![(a.idx, ga), (b.idx, gb)],
        });
        TapeVar { idx, value: out_val }
    }

    /// Records a node with a caller-supplied value and arbitrary `(parent, local partial)` edges.
    ///
    /// Every parent must already be on this tape. The reverse sweep visits nodes from high to
    /// low index, so a parent at or above the new node's index would be mis-ordered. This is
    /// checked in debug builds.
    pub fn push_custom(
        &self,
        value: f64,
        parents: Vec<(TapeVar, f64)>,
    ) -> TapeVar {
        let mut nodes = self.nodes.borrow_mut();
        let idx = nodes.len();
        debug_assert!(
            parents.iter().all(|(v, _)| v.idx < idx),
            "Tape::push_custom: every parent must precede the new node (tape has {} nodes)",
            idx
        );
        nodes.push(Node {
            value,
            parents: parents.into_iter().map(|(v, w)| (v.idx, w)).collect(),
        });
        TapeVar { idx, value }
    }

    /// Number of recorded nodes.
    pub fn len(&self) -> usize {
        self.nodes.borrow().len()
    }

    /// `true` if no node has been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.nodes.borrow().is_empty()
    }

    /// Forward values of all nodes, indexed by node position.
    pub fn values(&self) -> Vec<f64> {
        self.nodes.borrow().iter().map(|n| n.value).collect()
    }

    /// Shorthand for [`Tape::push_input`].
    pub fn var(&self, value: f64) -> TapeVar {
        self.push_input(value)
    }

    /// `a + b`.
    pub fn add(&self, a: TapeVar, b: TapeVar) -> TapeVar {
        self.push_binary(a, b, |va, vb| va + vb, |_, _| (1.0, 1.0))
    }

    /// `a - b`.
    pub fn sub(&self, a: TapeVar, b: TapeVar) -> TapeVar {
        self.push_binary(a, b, |va, vb| va - vb, |_, _| (1.0, -1.0))
    }

    /// `a * b`.
    pub fn mul(&self, a: TapeVar, b: TapeVar) -> TapeVar {
        self.push_binary(a, b, |va, vb| va * vb, |va, vb| (vb, va))
    }

    /// `a / b`. No guard on `b == 0`: the value and partials follow IEEE-754 semantics.
    pub fn div(&self, a: TapeVar, b: TapeVar) -> TapeVar {
        self.push_binary(
            a,
            b,
            |va, vb| va / vb,
            |va, vb| (1.0 / vb, -va / (vb * vb)),
        )
    }

    /// `-a`.
    pub fn neg(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| -va, |_, _| -1.0)
    }

    /// `e^a`. The derivative reuses the output value.
    pub fn exp(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.exp(), |_, out| out)
    }

    /// Natural logarithm of `a`.
    pub fn ln(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.ln(), |va, _| 1.0 / va)
    }

    /// Square root of `a`. The derivative `0.5 / sqrt(a)` is infinite at `a == 0`.
    pub fn sqrt(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.sqrt(), |_, out| 0.5 / out)
    }

    /// `sin(a)`.
    pub fn sin(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.sin(), |va, _| va.cos())
    }

    /// `cos(a)`.
    pub fn cos(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.cos(), |va, _| -va.sin())
    }

    /// `tanh(a)`. The derivative `1 - tanh^2` reuses the output value.
    pub fn tanh(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.tanh(), |_, out| 1.0 - out * out)
    }

    /// `a^n` for an integer exponent.
    pub fn powi(&self, a: TapeVar, n: i32) -> TapeVar {
        self.push_unary(
            a,
            |va| va.powi(n),
            |va, _| {
                if n == 0 {
                    // x^0 is constant. The general formula would give 0 * 0^-1 = 0 * inf = NaN at x = 0.
                    0.0
                } else {
                    f64::from(n) * va.powi(n - 1)
                }
            },
        )
    }

    /// `a * scalar`, with `scalar` treated as a constant (not recorded on the tape).
    pub fn scale(&self, a: TapeVar, scalar: f64) -> TapeVar {
        self.push_unary(a, |va| va * scalar, |_, _| scalar)
    }

    /// Logistic sigmoid `1 / (1 + e^-a)`. The derivative `s * (1 - s)` reuses the output.
    pub fn sigmoid(&self, a: TapeVar) -> TapeVar {
        self.push_unary(
            a,
            |va| {
                let e = (-va).exp();
                1.0 / (1.0 + e)
            },
            |_, out| out * (1.0 - out),
        )
    }

    /// `max(a, 0)`. The subgradient at `a == 0` is taken to be `0`.
    pub fn relu(&self, a: TapeVar) -> TapeVar {
        self.push_unary(a, |va| va.max(0.0), |va, _| if va > 0.0 { 1.0 } else { 0.0 })
    }

    /// Left-to-right sum of `vars`, recorded as a chain of binary additions.
    ///
    /// A single-element slice returns that element unchanged (no node is recorded).
    ///
    /// # Errors
    /// Returns an invalid-argument error if `vars` is empty.
    pub fn sum(&self, vars: &[TapeVar]) -> Result<TapeVar> {
        match vars.split_first() {
            Some((&first, rest)) => Ok(rest.iter().fold(first, |acc, &v| self.add(acc, v))),
            None => Err(AutogradError::invalid_argument(
                "tape::sum: empty input slice".to_string(),
            )),
        }
    }

    /// Dot product `sum_i a[i] * b[i]`.
    ///
    /// # Errors
    /// Returns an invalid-argument error if the slices differ in length or are empty.
    pub fn dot(&self, a: &[TapeVar], b: &[TapeVar]) -> Result<TapeVar> {
        if a.len() != b.len() {
            return Err(AutogradError::invalid_argument(format!(
                "tape::dot: length mismatch {} vs {}",
                a.len(),
                b.len()
            )));
        }
        if a.is_empty() {
            return Err(AutogradError::invalid_argument(
                "tape::dot: empty inputs".to_string(),
            ));
        }
        let products: Vec<TapeVar> = a
            .iter()
            .zip(b.iter())
            .map(|(&ai, &bi)| self.mul(ai, bi))
            .collect();
        self.sum(&products)
    }
}
impl Default for Tape {
fn default() -> Self {
Self::new()
}
}
/// Lightweight handle to a node on a tape: the node's index plus a cached copy of its
/// forward value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TapeVar {
    /// Position of the node on the tape that recorded it.
    pub(crate) idx: usize,
    /// Forward value of the node, copied at record time.
    pub value: f64,
}
impl TapeVar {
#[inline]
pub fn index(&self) -> usize {
self.idx
}
pub fn add(self, other: TapeVar, tape: &Tape) -> TapeVar {
tape.add(self, other)
}
pub fn mul(self, other: TapeVar, tape: &Tape) -> TapeVar {
tape.mul(self, other)
}
}
/// Reverse-mode sweep from `output` with an initial adjoint (seed) of `1.0`.
///
/// Returns one adjoint per tape node, indexed by node position. Entry `i` is
/// `d output / d node_i`. An empty tape yields an empty vector.
///
/// # Panics
/// Panics if the tape is non-empty and `output` does not index a node on it, for
/// example because it was recorded on a different tape.
pub fn backward(tape: &Tape, output: TapeVar) -> Vec<f64> {
    let nodes = tape.nodes.borrow();
    let n = nodes.len();
    let mut adjoints = vec![0.0f64; n];
    if n == 0 {
        return adjoints;
    }
    assert!(
        output.idx < n,
        "backward: output index {} is out of range for a tape with {} nodes",
        output.idx,
        n
    );
    adjoints[output.idx] = 1.0;
    // Nodes recorded after `output` cannot feed into it, so their adjoints stay zero.
    // Start the sweep at `output` itself.
    for i in (0..=output.idx).rev() {
        let adjoint_i = adjoints[i];
        // Skipping zero adjoints also avoids turning 0 * inf partials into NaN.
        if adjoint_i == 0.0 {
            continue;
        }
        for &(parent_idx, weight) in &nodes[i].parents {
            adjoints[parent_idx] += adjoint_i * weight;
        }
    }
    adjoints
}
/// Reverse-mode sweep from `output` with a caller-chosen initial adjoint `seed`.
///
/// Equivalent to [`backward`] with every adjoint scaled by `seed`, which is useful for
/// vector-Jacobian products. An empty tape yields an empty vector.
///
/// # Panics
/// Panics if the tape is non-empty and `output` does not index a node on it, for
/// example because it was recorded on a different tape.
pub fn backward_with_seed(tape: &Tape, output: TapeVar, seed: f64) -> Vec<f64> {
    let nodes = tape.nodes.borrow();
    let n = nodes.len();
    let mut adjoints = vec![0.0f64; n];
    if n == 0 {
        return adjoints;
    }
    assert!(
        output.idx < n,
        "backward_with_seed: output index {} is out of range for a tape with {} nodes",
        output.idx,
        n
    );
    adjoints[output.idx] = seed;
    // Nodes recorded after `output` cannot feed into it, so their adjoints stay zero.
    // Start the sweep at `output` itself.
    for i in (0..=output.idx).rev() {
        let adjoint_i = adjoints[i];
        // Skipping zero adjoints also avoids turning 0 * inf partials into NaN.
        if adjoint_i == 0.0 {
            continue;
        }
        for &(parent_idx, weight) in &nodes[i].parents {
            adjoints[parent_idx] += adjoint_i * weight;
        }
    }
    adjoints
}
/// Convenience wrapper that owns a [`Tape`] and offers a record → compute → differentiate
/// workflow.
pub struct GradientTape {
    /// Underlying tape that every recorded input and operation is written to.
    tape: Tape,
}
impl GradientTape {
pub fn new() -> Self {
Self { tape: Tape::new() }
}
pub fn record_inputs(&self, inputs: &[f64]) -> Vec<TapeVar> {
inputs.iter().map(|&v| self.tape.push_input(v)).collect()
}
pub fn compute<F>(&self, f: F, inputs: &[TapeVar]) -> TapeVar
where
F: FnOnce(&Tape, &[TapeVar]) -> TapeVar,
{
f(&self.tape, inputs)
}
pub fn gradient(&self, output: TapeVar, wrt: &[TapeVar]) -> Vec<f64> {
let adjoints = backward(&self.tape, output);
wrt.iter().map(|v| adjoints[v.idx]).collect()
}
pub fn tape(&self) -> &Tape {
&self.tape
}
}
impl Default for GradientTape {
fn default() -> Self {
Self::new()
}
}
/// Gradient of a scalar function at `x`, evaluated on a fresh tape.
///
/// Every element of `x` is recorded as an input. `f` builds the computation on the tape
/// and returns the output variable, which must belong to that tape. A single reverse
/// sweep with seed `1.0` then yields one partial derivative per input, in the same
/// order as `x`.
///
/// # Errors
/// Returns an invalid-argument error if `x` is empty.
pub fn tape_grad<F>(f: F, x: &[f64]) -> Result<Vec<f64>>
where
    F: FnOnce(&Tape, &[TapeVar]) -> TapeVar,
{
    if x.is_empty() {
        let msg = "tape_grad: input must be non-empty".to_string();
        return Err(AutogradError::invalid_argument(msg));
    }
    // Heuristic pre-allocation: a few operation nodes per input.
    let tape = Tape::with_capacity(x.len() * 4);
    let mut vars = Vec::with_capacity(x.len());
    for &xi in x {
        vars.push(tape.push_input(xi));
    }
    let out = f(&tape, &vars);
    let adjoints = backward(&tape, out);
    let grad: Vec<f64> = vars.iter().map(|var| adjoints[var.idx]).collect();
    Ok(grad)
}
/// Jacobian of a vector-valued function at `x`, evaluated on a fresh tape.
///
/// `f` receives the tape and one input variable per element of `x`, and returns the
/// output variables. `f` is called exactly once, so any `FnOnce` closure is accepted.
/// One reverse sweep is run per output. Row `r` of the result holds
/// `d outputs[r] / d x[c]` for every column `c`.
///
/// # Errors
/// Returns an invalid-argument error if `x` is empty or if `f` returns no outputs.
pub fn tape_jacobian<F>(f: F, x: &[f64]) -> Result<Vec<Vec<f64>>>
where
    F: FnOnce(&Tape, &[TapeVar]) -> Vec<TapeVar>,
{
    if x.is_empty() {
        return Err(AutogradError::invalid_argument(
            "tape_jacobian: input must be non-empty".to_string(),
        ));
    }
    // Heuristic pre-allocation: a handful of operation nodes per input.
    let tape = Tape::with_capacity(x.len() * 8);
    let inputs: Vec<TapeVar> = x.iter().map(|&v| tape.push_input(v)).collect();
    let outputs = f(&tape, &inputs);
    if outputs.is_empty() {
        return Err(AutogradError::invalid_argument(
            "tape_jacobian: function returned no outputs".to_string(),
        ));
    }
    let jacobian: Vec<Vec<f64>> = outputs
        .iter()
        .map(|&out_var| {
            let adjoints = backward(&tape, out_var);
            inputs.iter().map(|inp| adjoints[inp.idx]).collect()
        })
        .collect();
    Ok(jacobian)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tape_basic_gradient() {
        // f(x, y) = x^2 + y^2 at (3, 4): grad = (2x, 2y) = (6, 8).
        let g = tape_grad(
            |tape, xs| {
                let x2 = tape.powi(xs[0], 2);
                let y2 = tape.powi(xs[1], 2);
                tape.add(x2, y2)
            },
            &[3.0, 4.0],
        )
        .expect("tape_grad");
        assert!((g[0] - 6.0).abs() < 1e-12, "∂f/∂x = {}", g[0]);
        assert!((g[1] - 8.0).abs() < 1e-12, "∂f/∂y = {}", g[1]);
    }

    #[test]
    fn test_tape_product_gradient() {
        let g = tape_grad(|tape, xs| tape.mul(xs[0], xs[1]), &[3.0, 4.0])
            .expect("tape_grad mul");
        assert!((g[0] - 4.0).abs() < 1e-12);
        assert!((g[1] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_tape_quotient_gradient() {
        // f(x, y) = x / y at (6, 3): grad = (1/y, -x/y^2) = (1/3, -2/3).
        let g = tape_grad(|tape, xs| tape.div(xs[0], xs[1]), &[6.0, 3.0])
            .expect("tape_grad div");
        assert!((g[0] - 1.0 / 3.0).abs() < 1e-12, "∂f/∂x = {}", g[0]);
        assert!((g[1] + 2.0 / 3.0).abs() < 1e-12, "∂f/∂y = {}", g[1]);
    }

    #[test]
    fn test_tape_chain_rule() {
        // f(x) = exp(x^2) at 2: f'(x) = 2x * exp(x^2) = 4 e^4.
        let g = tape_grad(
            |tape, xs| {
                let x2 = tape.powi(xs[0], 2);
                tape.exp(x2)
            },
            &[2.0],
        )
        .expect("chain rule");
        let expected = 4.0 * (4.0_f64).exp();
        assert!(
            (g[0] - expected).abs() < 1e-6,
            "f'(2) = {} expected {}",
            g[0],
            expected
        );
    }

    #[test]
    fn test_tape_jacobian_vector_function() {
        // f0 = x^2 * y, f1 = x + y at (2, 3).
        let j = tape_jacobian(
            |tape, xs| {
                let x2 = tape.powi(xs[0], 2);
                let f0 = tape.mul(x2, xs[1]);
                let f1 = tape.add(xs[0], xs[1]);
                vec![f0, f1]
            },
            &[2.0, 3.0],
        )
        .expect("tape_jacobian");
        assert!((j[0][0] - 12.0).abs() < 1e-12, "J[0][0] = {}", j[0][0]);
        assert!((j[0][1] - 4.0).abs() < 1e-12, "J[0][1] = {}", j[0][1]);
        assert!((j[1][0] - 1.0).abs() < 1e-12, "J[1][0] = {}", j[1][0]);
        assert!((j[1][1] - 1.0).abs() < 1e-12, "J[1][1] = {}", j[1][1]);
    }

    #[test]
    fn test_tape_sigmoid() {
        let g = tape_grad(|tape, xs| tape.sigmoid(xs[0]), &[0.0]).expect("sigmoid grad");
        assert!((g[0] - 0.25).abs() < 1e-12, "sigmoid'(0) = {}", g[0]);
    }

    #[test]
    fn test_gradient_tape_high_level() {
        let gt = GradientTape::new();
        let inputs = gt.record_inputs(&[2.0, 3.0]);
        let out = gt.compute(
            |tape, inp| {
                let x2 = tape.powi(inp[0], 2);
                let y2 = tape.powi(inp[1], 2);
                tape.add(x2, y2)
            },
            &inputs,
        );
        let grads = gt.gradient(out, &inputs);
        assert!((grads[0] - 4.0).abs() < 1e-12, "dx = {}", grads[0]);
        assert!((grads[1] - 6.0).abs() < 1e-12, "dy = {}", grads[1]);
    }

    #[test]
    fn test_backward_with_seed() {
        let tape = Tape::new();
        let x = tape.var(2.0);
        let y = tape.var(3.0);
        let z = tape.mul(x, y);
        // A seed of 2 doubles every adjoint: (2 * y, 2 * x) = (6, 4).
        let adjs = backward_with_seed(&tape, z, 2.0);
        assert!((adjs[x.idx] - 6.0).abs() < 1e-12);
        assert!((adjs[y.idx] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_tape_empty_error() {
        assert!(tape_grad(|_tape, xs| xs[0], &[]).is_err());
        assert!(tape_jacobian(|_tape, _xs| vec![], &[1.0]).is_err());
    }

    #[test]
    fn test_tape_sum_and_dot_errors() {
        let tape = Tape::new();
        assert!(tape.sum(&[]).is_err());
        let a = [tape.var(1.0), tape.var(2.0)];
        let b = [tape.var(3.0)];
        assert!(tape.dot(&a, &b).is_err());
        assert!(tape.dot(&[], &[]).is_err());
    }

    #[test]
    fn test_tape_dot_product() {
        let tape = Tape::new();
        let a: Vec<TapeVar> = [1.0, 2.0, 3.0].iter().map(|&v| tape.var(v)).collect();
        let b: Vec<TapeVar> = [4.0, 5.0, 6.0].iter().map(|&v| tape.var(v)).collect();
        let out = tape.dot(&a, &b).expect("dot product");
        let adjs = backward(&tape, out);
        assert!((out.value - 32.0).abs() < 1e-12);
        assert!((adjs[a[0].idx] - 4.0).abs() < 1e-12);
        assert!((adjs[a[2].idx] - 6.0).abs() < 1e-12);
        assert!((adjs[b[1].idx] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_tape_relu_gradient() {
        let g_pos = tape_grad(|tape, xs| tape.relu(xs[0]), &[2.0]).expect("relu+");
        let g_neg = tape_grad(|tape, xs| tape.relu(xs[0]), &[-1.0]).expect("relu-");
        assert!((g_pos[0] - 1.0).abs() < 1e-12);
        assert!(g_neg[0].abs() < 1e-12);
    }
}