#![no_std]
extern crate alloc;
mod tape;
mod var;
pub use tape::Tape;
pub use var::Var;
use alloc::vec::Vec;
use crate::la::{DMat, DVec};
/// Computes the gradient of a scalar-valued function at `x` using reverse-mode AD.
///
/// Every entry of `x` is registered as an independent variable on a fresh [`Tape`],
/// `f` is evaluated once, and a single backward sweep from its output produces the
/// adjoints of all tape entries.
///
/// Returns a vector of length `x.len()` whose `k`-th entry is `∂f/∂x[k]`.
pub fn grad<F>(f: F, x: &[f64]) -> DVec<f64>
where
    F: Fn(&[Var]) -> Var,
{
    let tape = Tape::new();
    let inputs: Vec<Var> = x.iter().copied().map(|value| tape.var(value)).collect();
    let output = f(&inputs);
    let adjoints = output.backward();
    // Each input remembers its own tape slot; use it to pick that input's adjoint.
    DVec::from_fn(x.len(), |k| adjoints[inputs[k].index])
}
/// Computes the Jacobian of a vector-valued function at `x` using forward-mode AD.
///
/// One forward pass is made per input. On pass `col`, input `col` is seeded as the
/// active dual variable and every other input is a constant. The dual parts of the
/// outputs therefore form column `col` of the Jacobian.
///
/// Returns an `m × n` matrix with `n = x.len()`. `m` is the number of outputs from
/// the first pass; it is `0` when `x` is empty, because `f` is then never called.
pub fn jacobian_fwd<F>(f: F, x: &[f64]) -> DMat<f64>
where
    F: Fn(&[crate::Dual<f64>]) -> Vec<crate::Dual<f64>>,
{
    let n = x.len();
    let columns: Vec<Vec<f64>> = (0..n)
        .map(|col| {
            let seeded: Vec<crate::Dual<f64>> = x
                .iter()
                .enumerate()
                .map(|(k, &value)| {
                    if k != col {
                        crate::Dual::constant(value)
                    } else {
                        crate::Dual::var(value)
                    }
                })
                .collect();
            // Directional derivatives of every output along input `col`.
            f(&seeded).iter().map(|out| out.dual).collect()
        })
        .collect();
    let rows = columns.first().map_or(0, Vec::len);
    DMat::from_fn(rows, n, |r, c| columns[c][r])
}
/// Computes the Hessian of a scalar function at `x` using nested (dual-of-dual)
/// forward-mode AD.
///
/// Entry `(i, j)` comes from one evaluation of `f` on `Dual<Dual<f64>>` inputs.
/// The inner dual seeds direction `j` and the outer dual seeds direction `i`, so the
/// mixed coefficient `.dual.dual` of the result is `∂²f/∂x_i∂x_j`.
/// `f` is evaluated `n²` times for `n = x.len()`.
pub fn hessian<F>(f: F, x: &[f64]) -> DMat<f64>
where
    F: Fn(&[crate::Dual<crate::Dual<f64>>]) -> crate::Dual<crate::Dual<f64>>,
{
    // Unit-basis coefficient: 1 along the seeded axis, 0 everywhere else.
    let unit = |k: usize, axis: usize| -> f64 {
        if k == axis {
            1.0
        } else {
            0.0
        }
    };
    let n = x.len();
    DMat::from_fn(n, n, |i, j| {
        let seeded: Vec<crate::Dual<crate::Dual<f64>>> = x
            .iter()
            .enumerate()
            .map(|(k, &value)| {
                let primal = crate::Dual::new(value, unit(k, j));
                let tangent = crate::Dual::new(unit(k, i), 0.0);
                crate::Dual::new(primal, tangent)
            })
            .collect();
        let output = f(&seeded);
        output.dual.dual
    })
}
/// Computes the vector–Jacobian product `vᵀ·J` of a vector-valued function at `x`
/// using reverse-mode AD.
///
/// `f` is evaluated once on a fresh [`Tape`]. One backward sweep is then run per
/// output, and the input adjoints of output `k` are accumulated with weight `v[k]`.
///
/// Returns a vector of length `x.len()`.
///
/// # Panics
///
/// Panics if `v.len()` differs from the number of outputs returned by `f`.
pub fn vjp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
where
    F: Fn(&[Var]) -> Vec<Var>,
{
    let tape = Tape::new();
    let vars: Vec<Var> = x.iter().map(|&val| tape.var(val)).collect();
    let indices: Vec<usize> = vars.iter().map(|var| var.index).collect();
    let outputs = f(&vars);
    // A cotangent needs exactly one weight per output. Without this check, a longer
    // `v` is silently truncated and a shorter one panics with an opaque
    // index-out-of-bounds error partway through the accumulation.
    assert_eq!(
        v.len(),
        outputs.len(),
        "vjp: cotangent length must equal the number of outputs of f"
    );
    let n = x.len();
    let mut result = DVec::zeros(n);
    for (out, &vk) in outputs.iter().zip(v) {
        let grads = out.backward();
        for (i, &idx) in indices.iter().enumerate() {
            result[i] = result[i] + vk * grads[idx];
        }
    }
    result
}
/// Computes the Jacobian–vector product `J·v` of a vector-valued function at `x`
/// using forward-mode AD.
///
/// Each input `x[i]` is seeded as a dual number with tangent `v[i]`. A single
/// evaluation of `f` then yields the directional derivative in the outputs' dual parts.
///
/// Returns a vector with one entry per output of `f`.
///
/// # Panics
///
/// Panics if `x` and `v` have different lengths.
pub fn jvp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
where
    F: Fn(&[crate::Dual<f64>]) -> Vec<crate::Dual<f64>>,
{
    // `zip` stops at the shorter slice. Without this check, a short `v` would
    // silently drop trailing inputs and hand `f` fewer variables than `x` provides.
    assert_eq!(
        x.len(),
        v.len(),
        "jvp: tangent length must equal the number of inputs"
    );
    let inputs: Vec<crate::Dual<f64>> = x
        .iter()
        .zip(v)
        .map(|(&xi, &vi)| crate::Dual::new(xi, vi))
        .collect();
    let outputs = f(&inputs);
    DVec::from_fn(outputs.len(), |i| outputs[i].dual)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Scalar;

    /// Asserts two floats agree to within `1e-10`, reporting both values on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn grad_simple() {
        let g = grad(|x| &x[0] * &x[1], &[3.0, 5.0]);
        assert_close(g[0], 5.0);
        assert_close(g[1], 3.0);
    }

    #[test]
    fn grad_quadratic() {
        let g = grad(|x| &x[0] * &x[0], &[4.0]);
        assert_close(g[0], 8.0);
    }

    #[test]
    fn jacobian_fwd_linear() {
        let j = jacobian_fwd(|x| alloc::vec![x[0] + x[1], x[0] - x[1]], &[1.0, 2.0]);
        assert_eq!(j.get(0, 0), 1.0);
        assert_eq!(j.get(0, 1), 1.0);
        assert_eq!(j.get(1, 0), 1.0);
        assert_eq!(j.get(1, 1), -1.0);
    }

    #[test]
    fn hessian_quadratic() {
        let h = hessian(
            |x| x[0] * x[0] + x[0] * x[1] * crate::Dual::from_f64(2.0) + x[1] * x[1],
            &[1.0, 1.0],
        );
        assert_close(h.get(0, 0), 2.0);
        assert_close(h.get(0, 1), 2.0);
        assert_close(h.get(1, 0), 2.0);
        assert_close(h.get(1, 1), 2.0);
    }

    #[test]
    fn jvp_simple() {
        let result = jvp(
            |x| alloc::vec![x[0] * x[1], x[0] + x[1]],
            &[3.0, 5.0],
            &[1.0, 0.0],
        );
        assert_close(result[0], 5.0);
        assert_close(result[1], 1.0);
    }

    #[test]
    fn vjp_weights_each_output() {
        // f(x) = [x0*x1, x0*x0] has J = [[x1, x0], [2*x0, 0]] = [[5, 3], [6, 0]] at (3, 5).
        // With v = (1, 2): vᵀJ = (1*5 + 2*6, 1*3 + 2*0) = (17, 3).
        let result = vjp(
            |x| alloc::vec![&x[0] * &x[1], &x[0] * &x[0]],
            &[3.0, 5.0],
            &[1.0, 2.0],
        );
        assert_close(result[0], 17.0);
        assert_close(result[1], 3.0);
    }
}