use crate::error::AutogradError;
use crate::Result;
/// Step `h` of the central-difference quotient `(f(x + h) - f(x - h)) / (2h)`
/// shared by [`jvp`], [`vjp`], [`jacfwd`] and [`grad_scalar`].
///
/// [`hessian`] and [`hvp`] define their own, larger steps locally.
const FD_H: f64 = 1.0e-7;
/// Jacobian-vector product `J_f(x) · t`, estimated by a central finite
/// difference along the tangent direction `t`.
///
/// `f` is evaluated exactly three times: at `x`, at `x + FD_H·t` and at
/// `x - FD_H·t`. The returned pair is
/// `(f(x), (f(x + FD_H·t) - f(x - FD_H·t)) / (2·FD_H))`.
///
/// # Errors
///
/// * `invalid_argument` if `x` is empty.
/// * `ShapeMismatch` if `t.len() != x.len()`.
/// * `ShapeMismatch` if `f` returns vectors of different lengths at the three
///   evaluation points (an element-wise difference would be meaningless).
pub fn jvp<F>(f: &F, x: &[f64], t: &[f64]) -> Result<(Vec<f64>, Vec<f64>)>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = x.len();
    if n == 0 {
        return Err(AutogradError::invalid_argument(
            "jvp: primal input must be non-empty".to_string(),
        ));
    }
    if t.len() != n {
        return Err(AutogradError::ShapeMismatch(format!(
            "jvp: tangent length {} != input length {}",
            t.len(),
            n
        )));
    }
    let xp: Vec<f64> = x.iter().zip(t).map(|(&xi, &ti)| xi + FD_H * ti).collect();
    let xm: Vec<f64> = x.iter().zip(t).map(|(&xi, &ti)| xi - FD_H * ti).collect();
    let fx = f(x);
    let fp = f(&xp);
    let fm = f(&xm);
    // `zip` would silently truncate to the shortest output; reject instead.
    if fp.len() != fx.len() || fm.len() != fx.len() {
        return Err(AutogradError::ShapeMismatch(format!(
            "jvp: function output length changed under perturbation \
             ({} at x, {} at x + h*t, {} at x - h*t)",
            fx.len(),
            fp.len(),
            fm.len()
        )));
    }
    let two_h = 2.0 * FD_H;
    let jvp_val: Vec<f64> = fp.iter().zip(&fm).map(|(&a, &b)| (a - b) / two_h).collect();
    Ok((fx, jvp_val))
}
/// Vector-Jacobian product `vᵀ · J_f(x)` by central finite differences.
///
/// Each input coordinate `j` is displaced by `±FD_H` in turn. The resulting
/// Jacobian column `∂f/∂x_j` is contracted with `v` immediately, so only the
/// two perturbed outputs are held at a time rather than the full `m × n`
/// Jacobian. Costs `2·n + 1` evaluations of `f`.
///
/// Returns `(f(x), vᵀ·J)`; the second vector has length `x.len()`.
///
/// # Errors
///
/// * `invalid_argument` if `x` is empty or `f(x)` is empty.
/// * `ShapeMismatch` if `v.len() != f(x).len()`.
/// * `ShapeMismatch` if `f` returns an output of a different length at any
///   perturbed point.
pub fn vjp<F>(f: &F, x: &[f64], v: &[f64]) -> Result<(Vec<f64>, Vec<f64>)>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = x.len();
    if n == 0 {
        return Err(AutogradError::invalid_argument(
            "vjp: primal input must be non-empty".to_string(),
        ));
    }
    let fx = f(x);
    let m = fx.len();
    if m == 0 {
        return Err(AutogradError::invalid_argument(
            "vjp: function output must be non-empty".to_string(),
        ));
    }
    if v.len() != m {
        return Err(AutogradError::ShapeMismatch(format!(
            "vjp: cotangent length {} != output length {}",
            v.len(),
            m
        )));
    }
    let two_h = 2.0 * FD_H;
    // One scratch copy of `x`; only coordinate `j` is displaced per iteration
    // and it is restored before the next one.
    let mut probe = x.to_vec();
    let mut result = Vec::with_capacity(n);
    for (j, &xj) in x.iter().enumerate() {
        probe[j] = xj + FD_H;
        let fp = f(&probe);
        probe[j] = xj - FD_H;
        let fm = f(&probe);
        probe[j] = xj;
        if fp.len() != m || fm.len() != m {
            return Err(AutogradError::ShapeMismatch(format!(
                "vjp: function output length changed from {} to {}/{} while perturbing input {}",
                m,
                fp.len(),
                fm.len(),
                j
            )));
        }
        // Sum over outputs i = 0..m starting from 0.0: v_i * ∂f_i/∂x_j.
        let dot = v
            .iter()
            .zip(fp.iter().zip(&fm))
            .fold(0.0, |acc, (&vi, (&a, &b))| acc + vi * ((a - b) / two_h));
        result.push(dot);
    }
    Ok((fx, result))
}
/// Full Jacobian of `f` at `x` by central finite differences, perturbing one
/// input coordinate at a time.
///
/// Returns `m` rows of length `n` (`n = x.len()`, `m = f(x).len()`) with
/// `jac[i][j] = (f_i(x + FD_H·e_j) - f_i(x - FD_H·e_j)) / (2·FD_H)`.
/// Costs `2·n + 1` evaluations of `f`.
///
/// # Errors
///
/// * `invalid_argument` if `x` is empty or `f(x)` is empty.
/// * `ShapeMismatch` if `f` returns an output whose length differs from
///   `f(x).len()` at any perturbed point.
pub fn jacfwd<F>(f: &F, x: &[f64]) -> Result<Vec<Vec<f64>>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = x.len();
    if n == 0 {
        return Err(AutogradError::invalid_argument(
            "jacfwd: input must be non-empty".to_string(),
        ));
    }
    let m = f(x).len();
    if m == 0 {
        return Err(AutogradError::invalid_argument(
            "jacfwd: function output must be non-empty".to_string(),
        ));
    }
    let two_h = 2.0 * FD_H;
    let mut jac = vec![vec![0.0f64; n]; m];
    // One scratch copy of `x`; only coordinate `j` is displaced per iteration
    // and it is restored before the next one.
    let mut probe = x.to_vec();
    for (j, &xj) in x.iter().enumerate() {
        probe[j] = xj + FD_H;
        let fp = f(&probe);
        probe[j] = xj - FD_H;
        let fm = f(&probe);
        probe[j] = xj;
        // Without this check a shorter output would panic on indexing and a
        // longer one would be silently ignored.
        if fp.len() != m || fm.len() != m {
            return Err(AutogradError::ShapeMismatch(format!(
                "jacfwd: function output length changed from {} to {}/{} while perturbing input {}",
                m,
                fp.len(),
                fm.len(),
                j
            )));
        }
        for (row, (&a, &b)) in jac.iter_mut().zip(fp.iter().zip(&fm)) {
            row[j] = (a - b) / two_h;
        }
    }
    Ok(jac)
}
pub fn jacrev<F>(f: &F, x: &[f64]) -> Result<Vec<Vec<f64>>>
where
F: Fn(&[f64]) -> Vec<f64>,
{
let n = x.len();
if n == 0 {
return Err(AutogradError::invalid_argument(
"jacrev: input must be non-empty".to_string(),
));
}
let f0 = f(x);
let m = f0.len();
if m == 0 {
return Err(AutogradError::invalid_argument(
"jacrev: function output must be non-empty".to_string(),
));
}
let mut jac = vec![vec![0.0f64; n]; m];
for i in 0..m {
let mut cotangent = vec![0.0f64; m];
cotangent[i] = 1.0;
let (_fx, row) = vjp(f, x, &cotangent)?;
jac[i] = row;
}
Ok(jac)
}
/// Hessian of a scalar function `f` at `x` by second-order central
/// differences with step `s = 1e-4`:
///
/// * diagonal: `H[i][i] = (f(x + s·e_i) - 2·f(x) + f(x - s·e_i)) / s²`
/// * off-diagonal: `H[i][j] = (f(++) - f(+-) - f(-+) + f(--)) / (4·s²)`,
///   where each sign pattern displaces `x_i` and `x_j` by `±s`.
///
/// Each off-diagonal entry is computed once per pair `i < j` and mirrored, so
/// the result is exactly symmetric. Costs `1 + 2·n + 2·n·(n - 1)` evaluations
/// of `f`.
///
/// # Errors
///
/// `invalid_argument` if `x` is empty.
pub fn hessian<F>(f: &F, x: &[f64]) -> Result<Vec<Vec<f64>>>
where
    F: Fn(&[f64]) -> f64,
{
    // Larger than `FD_H`: these quotients divide by s², which amplifies
    // roundoff in the numerator much more than a first difference does.
    const STEP: f64 = 1e-4;
    let n = x.len();
    if n == 0 {
        return Err(AutogradError::invalid_argument(
            "hessian: input must be non-empty".to_string(),
        ));
    }
    let f0 = f(x);
    let step_sq = STEP * STEP;
    let four_step_sq = 4.0 * step_sq;
    let mut hess = vec![vec![0.0f64; n]; n];
    // Single scratch copy of `x`; every displaced coordinate is restored
    // before the next evaluation pattern.
    let mut probe = x.to_vec();
    for i in 0..n {
        let xi = x[i];
        probe[i] = xi + STEP;
        let f_plus = f(&probe);
        probe[i] = xi - STEP;
        let f_minus = f(&probe);
        probe[i] = xi;
        hess[i][i] = (f_plus - 2.0 * f0 + f_minus) / step_sq;
        for j in (i + 1)..n {
            let xj = x[j];
            probe[i] = xi + STEP;
            probe[j] = xj + STEP;
            let f_pp = f(&probe);
            probe[j] = xj - STEP;
            let f_pm = f(&probe);
            probe[i] = xi - STEP;
            probe[j] = xj + STEP;
            let f_mp = f(&probe);
            probe[j] = xj - STEP;
            let f_mm = f(&probe);
            probe[i] = xi;
            probe[j] = xj;
            let mixed = (f_pp - f_pm - f_mp + f_mm) / four_step_sq;
            hess[i][j] = mixed;
            hess[j][i] = mixed;
        }
    }
    Ok(hess)
}
/// Evaluates `f` at `x` together with its linearization applied to `t`.
///
/// This is a thin alias for [`jvp`]: the result is `(f(x), J_f(x) · t)`,
/// and it reports exactly the same errors under the same conditions.
pub fn linearize<F>(f: &F, x: &[f64], t: &[f64]) -> Result<(Vec<f64>, Vec<f64>)>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let (primal_out, tangent_out) = jvp(f, x, t)?;
    Ok((primal_out, tangent_out))
}
pub fn grad_scalar<F>(f: &F, x: &[f64]) -> Result<Vec<f64>>
where
F: Fn(&[f64]) -> f64,
{
let wrapper = |xs: &[f64]| vec![f(xs)];
let jac = jacfwd(&wrapper, x)?;
Ok(jac.into_iter().next().unwrap_or_default())
}
/// Hessian-vector product `H_f(x) · v` for a scalar function `f`.
///
/// Takes a central difference of gradients along `v`:
/// `(∇f(x + s·v) - ∇f(x - s·v)) / (2s)` with `s = 1e-4`. Each gradient is
/// computed by [`grad_scalar`].
///
/// # Errors
///
/// * `invalid_argument` if `x` is empty.
/// * `ShapeMismatch` if `v.len() != x.len()`.
/// * Any error returned by [`grad_scalar`] is propagated.
pub fn hvp<F>(f: &F, x: &[f64], v: &[f64]) -> Result<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
{
    const HVP_STEP: f64 = 1e-4;
    let dim = x.len();
    if dim == 0 {
        return Err(AutogradError::invalid_argument(
            "hvp: input must be non-empty".to_string(),
        ));
    }
    if v.len() != dim {
        return Err(AutogradError::ShapeMismatch(format!(
            "hvp: vector length {} != input length {}",
            v.len(),
            dim
        )));
    }
    let mut forward = Vec::with_capacity(dim);
    let mut backward = Vec::with_capacity(dim);
    for (&xi, &vi) in x.iter().zip(v) {
        let offset = HVP_STEP * vi;
        forward.push(xi + offset);
        backward.push(xi - offset);
    }
    let grad_fwd = grad_scalar(f, &forward)?;
    let grad_bwd = grad_scalar(f, &backward)?;
    let denom = 2.0 * HVP_STEP;
    Ok(grad_fwd
        .into_iter()
        .zip(grad_bwd)
        .map(|(hi, lo)| (hi - lo) / denom)
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// f(x0, x1) = [x0², x0·x1]; Jacobian = [[2·x0, 0], [x1, x0]].
    fn f_quad(xs: &[f64]) -> Vec<f64> {
        vec![xs[0] * xs[0], xs[0] * xs[1]]
    }

    /// g(x0, x1) = x0² + x1²; gradient = [2·x0, 2·x1], Hessian = 2·I.
    fn f_scalar(xs: &[f64]) -> f64 {
        xs[0] * xs[0] + xs[1] * xs[1]
    }

    #[test]
    fn test_jvp_basic() {
        let (fx, jvp_val) = jvp(&f_quad, &[2.0, 3.0], &[1.0, 0.0]).expect("jvp");
        assert!((fx[0] - 4.0).abs() < 1e-9, "fx[0] = {}", fx[0]);
        assert!((fx[1] - 6.0).abs() < 1e-9, "fx[1] = {}", fx[1]);
        assert!((jvp_val[0] - 4.0).abs() < 1e-4, "jvp[0] = {}", jvp_val[0]);
        assert!((jvp_val[1] - 3.0).abs() < 1e-4, "jvp[1] = {}", jvp_val[1]);
    }

    #[test]
    fn test_jvp_second_direction() {
        let (_, jvp_val) = jvp(&f_quad, &[2.0, 3.0], &[0.0, 1.0]).expect("jvp");
        assert!((jvp_val[0] - 0.0).abs() < 1e-4, "jvp[0] = {}", jvp_val[0]);
        assert!((jvp_val[1] - 2.0).abs() < 1e-4, "jvp[1] = {}", jvp_val[1]);
    }

    #[test]
    fn test_vjp_basic() {
        let (fx, vjp_val) = vjp(&f_quad, &[2.0, 3.0], &[1.0, 0.0]).expect("vjp");
        assert!((fx[0] - 4.0).abs() < 1e-9, "fx[0] = {}", fx[0]);
        assert!((vjp_val[0] - 4.0).abs() < 1e-4, "vjp[0] = {}", vjp_val[0]);
        assert!((vjp_val[1] - 0.0).abs() < 1e-4, "vjp[1] = {}", vjp_val[1]);
    }

    #[test]
    fn test_vjp_second_cotangent() {
        let (_, vjp_val) = vjp(&f_quad, &[2.0, 3.0], &[0.0, 1.0]).expect("vjp");
        assert!((vjp_val[0] - 3.0).abs() < 1e-4, "vjp[0] = {}", vjp_val[0]);
        assert!((vjp_val[1] - 2.0).abs() < 1e-4, "vjp[1] = {}", vjp_val[1]);
    }

    #[test]
    fn test_jacfwd() {
        let j = jacfwd(&f_quad, &[2.0, 3.0]).expect("jacfwd");
        assert!((j[0][0] - 4.0).abs() < 1e-4, "j[0][0] = {}", j[0][0]);
        assert!((j[0][1] - 0.0).abs() < 1e-4, "j[0][1] = {}", j[0][1]);
        assert!((j[1][0] - 3.0).abs() < 1e-4, "j[1][0] = {}", j[1][0]);
        assert!((j[1][1] - 2.0).abs() < 1e-4, "j[1][1] = {}", j[1][1]);
    }

    #[test]
    fn test_jacrev() {
        let j = jacrev(&f_quad, &[2.0, 3.0]).expect("jacrev");
        assert!((j[0][0] - 4.0).abs() < 1e-4, "j[0][0] = {}", j[0][0]);
        assert!((j[0][1] - 0.0).abs() < 1e-4, "j[0][1] = {}", j[0][1]);
        assert!((j[1][0] - 3.0).abs() < 1e-4, "j[1][0] = {}", j[1][0]);
        assert!((j[1][1] - 2.0).abs() < 1e-4, "j[1][1] = {}", j[1][1]);
    }

    #[test]
    fn test_jacfwd_jacrev_agree() {
        let jf = jacfwd(&f_quad, &[1.5, 2.5]).expect("jacfwd");
        let jr = jacrev(&f_quad, &[1.5, 2.5]).expect("jacrev");
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (jf[i][j] - jr[i][j]).abs() < 1e-4,
                    "jf[{i}][{j}]={} jr[{i}][{j}]={}",
                    jf[i][j],
                    jr[i][j]
                );
            }
        }
    }

    #[test]
    fn test_hessian_diagonal() {
        let h = hessian(&f_scalar, &[1.0, 1.0]).expect("hessian");
        assert!((h[0][0] - 2.0).abs() < 1e-3, "H[0][0] = {}", h[0][0]);
        assert!((h[1][1] - 2.0).abs() < 1e-3, "H[1][1] = {}", h[1][1]);
        assert!(h[0][1].abs() < 1e-3, "H[0][1] = {}", h[0][1]);
        assert!(h[1][0].abs() < 1e-3, "H[1][0] = {}", h[1][0]);
    }

    #[test]
    fn test_hessian_mixed_partial() {
        let g = |xs: &[f64]| xs[0] * xs[0] + xs[0] * xs[1];
        let h = hessian(&g, &[1.0, 1.0]).expect("hessian mixed");
        assert!((h[0][0] - 2.0).abs() < 1e-3, "H[0][0] = {}", h[0][0]);
        assert!((h[0][1] - 1.0).abs() < 1e-3, "H[0][1] = {}", h[0][1]);
        assert!((h[1][0] - 1.0).abs() < 1e-3, "H[1][0] = {}", h[1][0]);
        assert!((h[1][1] - 0.0).abs() < 1e-3, "H[1][1] = {}", h[1][1]);
    }

    #[test]
    fn test_linearize() {
        let f = |xs: &[f64]| vec![xs[0].exp(), xs[0] * xs[1]];
        let (primal, tangent) = linearize(&f, &[0.0, 2.0], &[1.0, 0.0]).expect("linearize");
        assert!((primal[0] - 1.0).abs() < 1e-9, "primal[0] = {}", primal[0]);
        assert!((tangent[0] - 1.0).abs() < 1e-4, "tangent[0] = {}", tangent[0]);
        assert!((tangent[1] - 2.0).abs() < 1e-4, "tangent[1] = {}", tangent[1]);
    }

    #[test]
    fn test_grad_scalar() {
        let g = grad_scalar(&f_scalar, &[3.0, 4.0]).expect("grad_scalar");
        assert!((g[0] - 6.0).abs() < 1e-4, "g[0] = {}", g[0]);
        assert!((g[1] - 8.0).abs() < 1e-4, "g[1] = {}", g[1]);
    }

    #[test]
    fn test_hvp() {
        let h = hvp(&f_scalar, &[1.0, 1.0], &[1.0, 0.0]).expect("hvp");
        assert!((h[0] - 2.0).abs() < 1e-3, "hvp[0] = {}", h[0]);
        assert!(h[1].abs() < 1e-3, "hvp[1] = {}", h[1]);
    }

    #[test]
    fn test_hvp_mixed_partial() {
        // g = x0·x1 has Hessian [[0, 1], [1, 0]], so H·e0 = [0, 1].
        let g = |xs: &[f64]| xs[0] * xs[1];
        let h = hvp(&g, &[1.0, 1.0], &[1.0, 0.0]).expect("hvp mixed");
        assert!(h[0].abs() < 1e-3, "hvp[0] = {}", h[0]);
        assert!((h[1] - 1.0).abs() < 1e-3, "hvp[1] = {}", h[1]);
    }

    #[test]
    fn test_jvp_empty_input_error() {
        let f = |_xs: &[f64]| vec![1.0_f64];
        assert!(jvp(&f, &[], &[]).is_err());
    }

    #[test]
    fn test_jvp_tangent_length_mismatch_error() {
        assert!(jvp(&f_quad, &[1.0, 2.0], &[1.0]).is_err());
    }

    #[test]
    fn test_vjp_dimension_mismatch_error() {
        let f = |xs: &[f64]| vec![xs[0]];
        assert!(vjp(&f, &[1.0], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_vjp_empty_output_error() {
        let f = |_xs: &[f64]| Vec::<f64>::new();
        assert!(vjp(&f, &[1.0], &[]).is_err());
    }

    #[test]
    fn test_jacfwd_empty_input_error() {
        let f = |_xs: &[f64]| vec![1.0_f64];
        assert!(jacfwd(&f, &[]).is_err());
    }

    #[test]
    fn test_jacrev_empty_input_error() {
        let f = |_xs: &[f64]| vec![1.0_f64];
        assert!(jacrev(&f, &[]).is_err());
    }

    #[test]
    fn test_hessian_empty_input_error() {
        let f = |_xs: &[f64]| 0.0_f64;
        assert!(hessian(&f, &[]).is_err());
    }

    #[test]
    fn test_grad_scalar_empty_input_error() {
        let f = |_xs: &[f64]| 0.0_f64;
        assert!(grad_scalar(&f, &[]).is_err());
    }

    #[test]
    fn test_hvp_length_mismatch_error() {
        assert!(hvp(&f_scalar, &[1.0, 1.0], &[1.0]).is_err());
    }
}