use crate::error::AutogradError;
use crate::Result as AgResult;
/// Absolute perturbation applied to a coordinate by the central-difference
/// derivative approximations in this module (`f(x ± h·e_i)`).
const FD_H: f64 = 1.0e-5;
/// Returns a closure that approximates the gradient of the scalar function `f`.
///
/// Component `i` is the central difference `(f(x + h·e_i) - f(x - h·e_i)) / (2h)` with
/// `h = FD_H`, so one call evaluates `f` exactly `2 * x.len()` times. An empty input
/// yields an empty gradient without calling `f` at all.
pub fn grad<F>(f: F) -> impl Fn(&[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> f64,
{
    move |x: &[f64]| {
        let denom = 2.0 * FD_H;
        // Two probe points, each differing from `x` in at most one coordinate at a time.
        let mut plus = x.to_vec();
        let mut minus = x.to_vec();
        x.iter()
            .enumerate()
            .map(|(k, &xk)| {
                plus[k] = xk + FD_H;
                minus[k] = xk - FD_H;
                let slope = (f(&plus) - f(&minus)) / denom;
                // Restore the coordinate before moving on to the next one.
                plus[k] = xk;
                minus[k] = xk;
                slope
            })
            .collect()
    }
}
/// Like [`grad`], but validates the input and every function evaluation.
///
/// Uses the same central difference with step `FD_H`, evaluating `f` twice per coordinate.
///
/// # Errors
///
/// Returns `AutogradError::invalid_argument` when
/// * `x` is empty, or
/// * either perturbed evaluation for some component is NaN or infinite. Evaluation stops at the
///   first such component, and the message names its index and both values.
pub fn grad_checked<F>(f: F) -> impl Fn(&[f64]) -> AgResult<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
{
    move |x: &[f64]| {
        if x.is_empty() {
            return Err(AutogradError::invalid_argument(
                "grad_checked: input must be non-empty".to_string(),
            ));
        }
        let denom = 2.0 * FD_H;
        let mut plus = x.to_vec();
        let mut minus = x.to_vec();
        let mut out = Vec::with_capacity(x.len());
        for (i, &xi) in x.iter().enumerate() {
            plus[i] = xi + FD_H;
            minus[i] = xi - FD_H;
            let fp = f(&plus);
            let fm = f(&minus);
            let both_finite = fp.is_finite() && fm.is_finite();
            if !both_finite {
                return Err(AutogradError::invalid_argument(format!(
                    "grad_checked: non-finite value at component {i}: fp={fp}, fm={fm}"
                )));
            }
            out.push((fp - fm) / denom);
            plus[i] = xi;
            minus[i] = xi;
        }
        Ok(out)
    }
}
/// Returns a closure producing both `f(x)` and its central-difference gradient.
///
/// `f(x)` is evaluated first. Each gradient component then costs two evaluations at
/// `x ± FD_H·e_k`, so one call evaluates `f` `1 + 2 * x.len()` times.
pub fn value_and_grad<F>(f: F) -> impl Fn(&[f64]) -> (f64, Vec<f64>)
where
    F: Fn(&[f64]) -> f64,
{
    move |x: &[f64]| {
        let value = f(x);
        let denom = 2.0 * FD_H;
        let mut probe_up = x.to_vec();
        let mut probe_down = x.to_vec();
        let gradient: Vec<f64> = (0..x.len())
            .map(|k| {
                let orig = x[k];
                probe_up[k] = orig + FD_H;
                probe_down[k] = orig - FD_H;
                let slope = (f(&probe_up) - f(&probe_down)) / denom;
                probe_up[k] = orig;
                probe_down[k] = orig;
                slope
            })
            .collect();
        (value, gradient)
    }
}
/// Returns a closure approximating the Jacobian of the vector function `f`.
///
/// The result has `n_outputs` rows and `x.len()` columns, with entry `[r][c]` ≈ `∂f_r/∂x_c`
/// obtained by a central difference of step `FD_H` along coordinate `c`.
///
/// Output components beyond `n_outputs` are ignored. If `f` returns fewer than `n_outputs`
/// values at a perturbed point, the missing rows of that column are left at `0.0`.
pub fn jacobian<F>(f: F, n_outputs: usize) -> impl Fn(&[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    move |x: &[f64]| {
        let n_inputs = x.len();
        let denom = 2.0 * FD_H;
        let mut jac = vec![vec![0.0f64; n_inputs]; n_outputs];
        let mut up = x.to_vec();
        let mut down = x.to_vec();
        for (col, &xc) in x.iter().enumerate() {
            up[col] = xc + FD_H;
            down[col] = xc - FD_H;
            let f_up = f(&up);
            let f_down = f(&down);
            // `zip` stops at the shortest of: the rows, f_up, f_down.
            for ((row, &a), &b) in jac.iter_mut().zip(&f_up).zip(&f_down) {
                row[col] = (a - b) / denom;
            }
            up[col] = xc;
            down[col] = xc;
        }
        jac
    }
}
/// Returns a closure that approximates the Hessian of `f` with central second differences.
///
/// * Diagonal: `(f(x + h·e_i) + f(x - h·e_i) - 2·f(x)) / h²`.
/// * Off-diagonal: `(f(x+h·e_i+h·e_j) - f(x+h·e_i-h·e_j) - f(x-h·e_i+h·e_j) + f(x-h·e_i-h·e_j)) / (4h²)`.
///   Each pair is computed once and written to both `[i][j]` and `[j][i]`, so the result is
///   exactly symmetric.
///
/// The step `h` is `1e-4`, deliberately larger than the first-derivative step `FD_H`.
/// Second differences lose roughly `eps·|f| / h²` to round-off, which at `h = 1e-5` is already
/// about `2e-6·|f|`. A step near `eps^(1/4) ≈ 1.2e-4` balances that round-off against the
/// `O(h²)` truncation error.
///
/// Costs `1 + 2n + 2n(n-1)` evaluations of `f` for `n = x.len()` (`f(x)` is evaluated even
/// when `x` is empty).
pub fn hessian<F>(f: F) -> impl Fn(&[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
{
    const HESS_STEP: f64 = 1e-4;
    move |x: &[f64]| {
        let n = x.len();
        let mut hess = vec![vec![0.0f64; n]; n];
        let h2 = HESS_STEP * HESS_STEP;
        let fx = f(x);
        // A single scratch point is perturbed and restored around every evaluation,
        // instead of allocating fresh copies of `x` for each Hessian entry.
        let mut p = x.to_vec();
        for i in 0..n {
            p[i] = x[i] + HESS_STEP;
            let fp = f(&p);
            p[i] = x[i] - HESS_STEP;
            let fm = f(&p);
            p[i] = x[i];
            hess[i][i] = (fp + fm - 2.0 * fx) / h2;
        }
        for i in 0..n {
            for j in (i + 1)..n {
                // Evaluate f at x + di·e_i + dj·e_j, then restore both coordinates.
                let mut eval = |di: f64, dj: f64| -> f64 {
                    p[i] = x[i] + di;
                    p[j] = x[j] + dj;
                    let v = f(&p);
                    p[i] = x[i];
                    p[j] = x[j];
                    v
                };
                let fpp = eval(HESS_STEP, HESS_STEP);
                let fpm = eval(HESS_STEP, -HESS_STEP);
                let fmp = eval(-HESS_STEP, HESS_STEP);
                let fmm = eval(-HESS_STEP, -HESS_STEP);
                let entry = (fpp - fpm - fmp + fmm) / (4.0 * h2);
                hess[i][j] = entry;
                hess[j][i] = entry;
            }
        }
        hess
    }
}
/// Lifts a per-sample function `f` to a batch.
///
/// The returned closure applies `f` to every slice in `batch`, in order, and returns one
/// output vector per sample. An empty batch yields an empty result.
pub fn vmap<F>(f: F) -> impl Fn(&[&[f64]]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    move |batch: &[&[f64]]| {
        let mut outputs = Vec::with_capacity(batch.len());
        for &sample in batch {
            outputs.push(f(sample));
        }
        outputs
    }
}
/// Batched evaluation of a vector function together with the gradient of a scalar function.
///
/// For each slice in `batch`, the first output collects `f(sample)`. The second collects the
/// [`grad`] (central difference) of `f_scalar` at the same sample. All values are computed
/// before any gradient, and both outputs preserve batch order.
pub fn vmap_with_grad<F, G>(
    f: F,
    f_scalar: G,
) -> impl Fn(&[&[f64]]) -> (Vec<Vec<f64>>, Vec<Vec<f64>>)
where
    F: Fn(&[f64]) -> Vec<f64>,
    G: Fn(&[f64]) -> f64,
{
    let scalar_grad = grad(f_scalar);
    move |batch: &[&[f64]]| {
        let mut values = Vec::with_capacity(batch.len());
        for &sample in batch {
            values.push(f(sample));
        }
        let mut grads = Vec::with_capacity(batch.len());
        for &sample in batch {
            grads.push(scalar_grad(sample));
        }
        (values, grads)
    }
}
/// Jacobian-vector product by a single central difference along `tangent`.
///
/// Returns `(f(x), J(x)·tangent)`. The second component is
/// `(f(x + h·tangent) - f(x - h·tangent)) / (2h)` with `h = FD_H`. If the two perturbed
/// outputs differ in length, it is truncated to the shorter one.
///
/// Degenerate inputs:
/// * `tangent.len() != x.len()`: the derivative component is empty.
/// * empty `x` (with empty `tangent`): a zero vector of length `f(x).len()`. A function of no
///   variables is constant, so its directional derivative is exactly zero.
///
/// Costs 3 evaluations of `f` in the regular case and 1 in the degenerate cases.
pub fn jvp<F>(f: F, x: &[f64], tangent: &[f64]) -> (Vec<f64>, Vec<f64>)
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = x.len();
    let fx = f(x);
    // Validate before building the perturbed points so rejected inputs cost nothing extra.
    if tangent.len() != n {
        return (fx, Vec::new());
    }
    if n == 0 {
        let zeros = vec![0.0f64; fx.len()];
        return (fx, zeros);
    }
    let xp: Vec<f64> = x.iter().zip(tangent).map(|(&xi, &vi)| xi + FD_H * vi).collect();
    let xm: Vec<f64> = x.iter().zip(tangent).map(|(&xi, &vi)| xi - FD_H * vi).collect();
    let fp = f(&xp);
    let fm = f(&xm);
    let two_h = 2.0 * FD_H;
    let jv: Vec<f64> = fp.iter().zip(&fm).map(|(&a, &b)| (a - b) / two_h).collect();
    (fx, jv)
}
/// Vector-Jacobian product `cotangentᵀ · J(x)` by central differences.
///
/// Returns `(f(x), g)` with `g[c] = Σ_r cotangent[r] · ∂f_r/∂x_c`. Each partial is
/// `(f(x + h·e_c)[r] - f(x - h·e_c)[r]) / (2h)` with `h = FD_H`. Output rows missing from
/// either perturbed evaluation contribute zero.
///
/// If `x` is empty or `cotangent.len() != f(x).len()`, `g` is all zeros of length `x.len()`.
///
/// Costs `1 + 2·x.len()` evaluations of `f`. Only one pair of perturbed outputs is held at a
/// time, so the full `m × n` Jacobian is never materialised.
pub fn vjp<F>(f: F, x: &[f64], cotangent: &[f64]) -> (Vec<f64>, Vec<f64>)
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = x.len();
    let fx = f(x);
    if n == 0 || cotangent.len() != fx.len() {
        return (fx, vec![0.0f64; n]);
    }
    let two_h = 2.0 * FD_H;
    let mut xp = x.to_vec();
    let mut xm = x.to_vec();
    let mut g = Vec::with_capacity(n);
    for col in 0..n {
        xp[col] = x[col] + FD_H;
        xm[col] = x[col] - FD_H;
        let fp = f(&xp);
        let fm = f(&xm);
        // Contract this Jacobian column with the cotangent immediately.
        let component: f64 = cotangent
            .iter()
            .enumerate()
            .map(|(row, &c)| {
                let d_row = match (fp.get(row), fm.get(row)) {
                    (Some(&a), Some(&b)) => (a - b) / two_h,
                    // A perturbed output shorter than f(x) contributes a zero entry.
                    _ => 0.0,
                };
                c * d_row
            })
            .sum();
        g.push(component);
        xp[col] = x[col];
        xm[col] = x[col];
    }
    (fx, g)
}
/// Returns a closure computing `H(x) · ∇f(x)`, the Hessian of `f` applied to its own gradient.
///
/// This equals the gradient of `½‖∇f‖²`. Both factors are finite-difference estimates:
/// * `∇f` comes from [`grad`];
/// * column `j` of the Hessian is the central difference of that gradient along coordinate `j`
///   (step `FD_H`).
///
/// Because the differences are nested, expect fewer accurate digits than from `grad` alone.
/// One call evaluates `f` `2n(2n + 1)` times for `n = x.len()`.
pub fn grad_of_grad<F>(f: F) -> impl Fn(&[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> f64,
{
    // `f` is only needed inside the gradient closure, so it is moved straight in; no copy
    // of it (and hence no `Clone` bound) is required.
    let g = grad(f);
    move |x: &[f64]| {
        let n = x.len();
        let gx = g(x);
        let two_h = 2.0 * FD_H;
        let mut result = vec![0.0f64; n];
        let mut xp = x.to_vec();
        let mut xm = x.to_vec();
        for j in 0..n {
            xp[j] = x[j] + FD_H;
            xm[j] = x[j] - FD_H;
            let gp = g(&xp);
            let gm = g(&xm);
            // Accumulate H[i][j] * g[j] for every i without materialising the Hessian.
            for ((r, &p), &m) in result.iter_mut().zip(&gp).zip(&gm) {
                *r += (p - m) / two_h * gx[j];
            }
            xp[j] = x[j];
            xm[j] = x[j];
        }
        result
    }
}
/// Evaluates `f` at `x` and pushes `tangent` forward through its linearisation.
///
/// Delegates entirely to [`jvp`]. The first component is `f(x)`; the second is the
/// finite-difference directional derivative of `f` at `x` along `tangent`.
pub fn linearize<F>(f: F, x: &[f64], tangent: &[f64]) -> (Vec<f64>, Vec<f64>)
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let (primal, pushed_forward) = jvp(f, x, tangent);
    (primal, pushed_forward)
}
/// Composes two vector functions left to right: the result maps `x` to `g(f(x))`.
pub fn compose_vec<F, G>(f: F, g: G) -> impl Fn(Vec<f64>) -> Vec<f64>
where
    F: Fn(Vec<f64>) -> Vec<f64>,
    G: Fn(Vec<f64>) -> Vec<f64>,
{
    move |input: Vec<f64>| {
        let intermediate = f(input);
        g(intermediate)
    }
}
/// Returns the `n`-fold self-composition of `f`: `x ↦ f(f(…f(x)…))`, applied `n` times.
///
/// With `n == 0` the returned closure is the identity.
pub fn iterate_scalar<F>(f: F, n: usize) -> impl Fn(f64) -> f64
where
    F: Fn(f64) -> f64,
{
    move |start: f64| (0..n).fold(start, |acc, _| f(acc))
}
/// Compares an analytical gradient against the central-difference estimate from [`grad`].
///
/// Returns the largest absolute component-wise difference (max-norm of the error), intended
/// to be compared against a tolerance. Two situations that would otherwise look like a pass
/// are reported as failures:
///
/// * if `grad_fn(x)` and the numerical gradient differ in length, the result is
///   `f64::INFINITY`, because a truncated comparison would silently skip missing components;
/// * if any component difference is NaN, the result is NaN, because `f64::max` would
///   otherwise discard it.
///
/// An empty `x` with an empty analytical gradient yields `0.0`.
pub fn check_grad<F, G>(f: F, grad_fn: G, x: &[f64]) -> f64
where
    F: Fn(&[f64]) -> f64,
    G: Fn(&[f64]) -> Vec<f64>,
{
    let numerical_g = grad(f)(x);
    let analytical_g = grad_fn(x);
    if numerical_g.len() != analytical_g.len() {
        return f64::INFINITY;
    }
    numerical_g
        .iter()
        .zip(&analytical_g)
        .map(|(&num, &ana)| (num - ana).abs())
        // NaN is sticky: a NaN difference replaces the accumulator, and a NaN accumulator is
        // never replaced again because `d > NaN` is always false.
        .fold(0.0_f64, |worst, d| if d.is_nan() || d > worst { d } else { worst })
}
/// Returns `x` unchanged, moving ownership through without copying.
///
/// NOTE(review): presumably kept for API parity with autodiff libraries where
/// `stop_gradient` detaches a value. The finite-difference transforms here differentiate
/// only through the function they are given, so this performs no work.
pub fn stop_gradient(x: Vec<f64>) -> Vec<f64> {
    std::convert::identity(x)
}
// Unit tests for the finite-difference transforms above. Expected values are exact
// analytic derivatives; the tolerances absorb finite-difference truncation and round-off.
#[cfg(test)]
mod tests {
use super::*;
// grad of x0^2 + x1^2 is (2*x0, 2*x1) = (6, 8) at (3, 4).
#[test]
fn test_grad_quadratic() {
let g = grad(|xs: &[f64]| xs[0] * xs[0] + xs[1] * xs[1]);
let result = g(&[3.0, 4.0]);
assert!((result[0] - 6.0).abs() < 1e-3, "g[0] = {}", result[0]);
assert!((result[1] - 8.0).abs() < 1e-3, "g[1] = {}", result[1]);
}
// d/dx x^3 = 3x^2 = 12 at x = 2.
#[test]
fn test_grad_cubic() {
let g = grad(|xs: &[f64]| xs[0].powi(3));
let result = g(&[2.0]);
assert!((result[0] - 12.0).abs() < 1e-2, "g[0] = {}", result[0]);
}
// grad_checked rejects an empty input with an error instead of returning an empty gradient.
#[test]
fn test_grad_checked_empty_input() {
let g = grad_checked(|xs: &[f64]| xs[0] * xs[0]);
assert!(g(&[]).is_err());
}
// f(3, 4) = 25 is computed directly, so it is checked tightly; the gradient is (6, 8).
#[test]
fn test_value_and_grad() {
let vg = value_and_grad(|xs: &[f64]| xs[0] * xs[0] + xs[1] * xs[1]);
let (val, g) = vg(&[3.0, 4.0]);
assert!((val - 25.0).abs() < 1e-12);
assert!((g[0] - 6.0).abs() < 1e-3);
assert!((g[1] - 8.0).abs() < 1e-3);
}
// The Jacobian of a linear map is its coefficient matrix [[2, 1], [1, -1]].
#[test]
fn test_jacobian_linear_map() {
let jac = jacobian(|xs: &[f64]| vec![2.0 * xs[0] + xs[1], xs[0] - xs[1]], 2);
let j = jac(&[5.0, 3.0]);
assert!((j[0][0] - 2.0).abs() < 1e-3, "J[0][0] = {}", j[0][0]);
assert!((j[0][1] - 1.0).abs() < 1e-3, "J[0][1] = {}", j[0][1]);
assert!((j[1][0] - 1.0).abs() < 1e-3, "J[1][0] = {}", j[1][0]);
assert!((j[1][1] - (-1.0)).abs() < 1e-3, "J[1][1] = {}", j[1][1]);
}
// f = x0^2 + 3*x0*x1 + 2*x1^2 has the constant Hessian [[2, 3], [3, 4]].
#[test]
fn test_hessian_quadratic_form() {
let hess = hessian(|xs: &[f64]| xs[0] * xs[0] + 3.0 * xs[0] * xs[1] + 2.0 * xs[1] * xs[1]);
let h = hess(&[1.0, 1.0]);
assert!((h[0][0] - 2.0).abs() < 1e-2, "H[0][0] = {}", h[0][0]);
assert!((h[0][1] - 3.0).abs() < 1e-2, "H[0][1] = {}", h[0][1]);
assert!((h[1][0] - 3.0).abs() < 1e-2, "H[1][0] = {}", h[1][0]);
assert!((h[1][1] - 4.0).abs() < 1e-2, "H[1][1] = {}", h[1][1]);
}
// Off-diagonal entries are written to both [i][j] and [j][i], so symmetry should hold.
#[test]
fn test_hessian_symmetry() {
let f = |xs: &[f64]| xs[0].sin() * xs[1].cos() + xs[0] * xs[1] * xs[1];
let h = hessian(f)(&[1.0, 2.0]);
assert!((h[0][1] - h[1][0]).abs() < 1e-5, "Symmetry violated: {} vs {}", h[0][1], h[1][0]);
}
// vmap applies f independently to each batch element and keeps the batch order.
#[test]
fn test_vmap_batch() {
let f = |xs: &[f64]| vec![xs[0] * xs[0], xs[1] * 2.0];
let batched = vmap(f);
let a: &[f64] = &[1.0, 2.0];
let b: &[f64] = &[3.0, 4.0];
let out = batched(&[a, b]);
assert_eq!(out.len(), 2);
assert!((out[0][0] - 1.0).abs() < 1e-12);
assert!((out[1][0] - 9.0).abs() < 1e-12);
assert!((out[1][1] - 8.0).abs() < 1e-12);
}
// J = [[2*x0, 0], [x1, x0]] = [[4, 0], [3, 2]] at (2, 3), so J * (1, 0) = (4, 3).
#[test]
fn test_jvp_quadratic() {
let f = |xs: &[f64]| vec![xs[0] * xs[0], xs[0] * xs[1]];
let (fx, jv) = jvp(f, &[2.0, 3.0], &[1.0, 0.0]);
assert!((fx[0] - 4.0).abs() < 1e-12);
assert!((fx[1] - 6.0).abs() < 1e-12);
assert!((jv[0] - 4.0).abs() < 1e-4, "jvp[0] = {}", jv[0]);
assert!((jv[1] - 3.0).abs() < 1e-4, "jvp[1] = {}", jv[1]);
}
// Same J as above; (1, 0)^T * J = (4, 0).
#[test]
fn test_vjp_quadratic() {
let f = |xs: &[f64]| vec![xs[0] * xs[0], xs[0] * xs[1]];
let (fx, g) = vjp(f, &[2.0, 3.0], &[1.0, 0.0]);
assert!((fx[0] - 4.0).abs() < 1e-12);
assert!((g[0] - 4.0).abs() < 1e-4, "vjp[0] = {}", g[0]);
assert!((g[1] - 0.0).abs() < 1e-4, "vjp[1] = {}", g[1]);
}
// f = (x0 + x1, x0*x1): J = [[1, 1], [3, 2]] at (2, 3); (1, 1)^T * J = (4, 3).
#[test]
fn test_vjp_with_nontrivial_cotangent() {
let f = |xs: &[f64]| vec![xs[0] + xs[1], xs[0] * xs[1]];
let (_, g) = vjp(f, &[2.0, 3.0], &[1.0, 1.0]);
assert!((g[0] - 4.0).abs() < 1e-4, "vjp[0] = {}", g[0]);
assert!((g[1] - 3.0).abs() < 1e-4, "vjp[1] = {}", g[1]);
}
// d/dx0 exp(x0) = 1 at x0 = 0; d/dx0 (x0*x1) = x1 = 2.
#[test]
fn test_linearize_exp() {
let f = |xs: &[f64]| vec![xs[0].exp(), xs[0] * xs[1]];
let (primal, tangent_out) = linearize(f, &[0.0, 2.0], &[1.0, 0.0]);
assert!((primal[0] - 1.0).abs() < 1e-9);
assert!((tangent_out[0] - 1.0).abs() < 1e-4);
assert!((tangent_out[1] - 2.0).abs() < 1e-4);
}
// A correct analytical gradient should match the numerical one to within FD error.
#[test]
fn test_check_grad_correct() {
let err = check_grad(
|xs: &[f64]| xs[0] * xs[0] + xs[1] * xs[1],
|xs: &[f64]| vec![2.0 * xs[0], 2.0 * xs[1]],
&[3.0, 4.0],
);
assert!(err < 1e-3, "gradient check error = {}", err);
}
// compose_vec(f, g) applies f first, then g: each element becomes 2v + 1.
#[test]
fn test_compose_vec() {
let f = |xs: Vec<f64>| xs.iter().map(|&v| v * 2.0).collect::<Vec<_>>();
let g = |xs: Vec<f64>| xs.iter().map(|&v| v + 1.0).collect::<Vec<_>>();
let h = compose_vec(f, g);
let out = h(vec![1.0, 2.0, 3.0]);
assert!((out[0] - 3.0).abs() < 1e-12);
assert!((out[1] - 5.0).abs() < 1e-12);
assert!((out[2] - 7.0).abs() < 1e-12);
}
// Five doublings of 1.0 give 2^5 = 32.
#[test]
fn test_iterate_scalar() {
let f = |x: f64| x * 2.0;
let f5 = iterate_scalar(f, 5);
assert!((f5(1.0) - 32.0).abs() < 1e-12);
}
// stop_gradient returns its input unchanged.
#[test]
fn test_stop_gradient_passthrough() {
let x = vec![1.0, 2.0, 3.0];
let y = stop_gradient(x.clone());
assert_eq!(y, x);
}
// Hessian is 2I and the gradient is (6, 8) at (3, 4), so H * grad = (12, 16).
#[test]
fn test_grad_of_grad_quadratic() {
let gg = grad_of_grad(|xs: &[f64]| xs[0] * xs[0] + xs[1] * xs[1]);
let result = gg(&[3.0, 4.0]);
assert!((result[0] - 12.0).abs() < 1e-1, "gg[0] = {}", result[0]);
assert!((result[1] - 16.0).abs() < 1e-1, "gg[1] = {}", result[1]);
}
}