use crate::error::AutogradError;
use scirs2_core::ndarray::{Array2, Axis};
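/// Finite-difference step size. `1e-5` sits near `f64::EPSILON.cbrt()`
/// (about `6e-6`), the usual sweet spot for first-order central differences.
/// The second-order stencils in [`hessian`] divide by `h^2`, so expect
/// roughly `1e-6` absolute noise there.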
const FD_STEP: f64 = 1e-5;
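/// Gradient of a scalar function by central differences:
/// `g[i] ~= (f(x + h*e_i) - f(x - h*e_i)) / (2h)` with `h = FD_STEP`,
/// at a cost of `2n` evaluations of `f`.
///
/// A minimal usage sketch; the example is `ignore`d because the import path
/// depends on how this module is exposed (adjust the `use` line to your
/// crate layout):
///
/// ```ignore
/// use your_crate::functional::grad; // assumed path
///
/// let g = grad(|xs| xs[0] * xs[0], &[3.0]).unwrap();
/// assert!((g[0] - 6.0).abs() < 1e-3); // d/dx x^2 = 2x = 6 at x = 3
/// ```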
pub fn grad(
f: impl Fn(&[f64]) -> f64,
x: &[f64],
) -> Result<Vec<f64>, AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"grad: input must be non-empty".to_string(),
));
}
let mut g = vec![0.0f64; n];
let mut xp = x.to_vec();
let mut xm = x.to_vec();
let two_h = 2.0 * FD_STEP;
for i in 0..n {
xp[i] = x[i] + FD_STEP;
xm[i] = x[i] - FD_STEP;
g[i] = (f(&xp) - f(&xm)) / two_h;
xp[i] = x[i];
xm[i] = x[i];
}
Ok(g)
}
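/// Jacobian `J[(i, j)] = df_i/dx_j` of a vector-valued function, built one
/// column at a time by central differences. For `f: R^n -> R^m` the result
/// has shape `(m, n)` and costs `2n + 1` evaluations of `f` (the extra call
/// sizes the output).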
pub fn jacobian(
f: impl Fn(&[f64]) -> Vec<f64>,
x: &[f64],
) -> Result<Array2<f64>, AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"jacobian: input must be non-empty".to_string(),
));
}
let f0 = f(x);
let m = f0.len();
if m == 0 {
return Err(AutogradError::OperationError(
"jacobian: output must be non-empty".to_string(),
));
}
let mut jac = Array2::<f64>::zeros((m, n));
let mut xp = x.to_vec();
let mut xm = x.to_vec();
let two_h = 2.0 * FD_STEP;
for j in 0..n {
xp[j] = x[j] + FD_STEP;
xm[j] = x[j] - FD_STEP;
let fp = f(&xp);
let fm = f(&xm);
for i in 0..m {
jac[[i, j]] = (fp[i] - fm[i]) / two_h;
}
xp[j] = x[j];
xm[j] = x[j];
}
Ok(jac)
}
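/// Jacobian-vector product: returns `(f(x), J*v)` without materializing the
/// Jacobian. The directional derivative is taken along `v` directly,
/// `(f(x + h*v) - f(x - h*v)) / (2h)`, so only three evaluations of `f` are
/// needed regardless of the input dimension.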
pub fn jvp(
f: impl Fn(&[f64]) -> Vec<f64>,
x: &[f64],
v: &[f64],
) -> Result<(Vec<f64>, Vec<f64>), AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"jvp: input must be non-empty".to_string(),
));
}
if v.len() != n {
return Err(AutogradError::ShapeMismatch(format!(
"jvp: tangent vector length {} does not match input length {}",
v.len(),
n
)));
}
let xp: Vec<f64> = x.iter().zip(v.iter()).map(|(&xi, &vi)| xi + FD_STEP * vi).collect();
let xm: Vec<f64> = x.iter().zip(v.iter()).map(|(&xi, &vi)| xi - FD_STEP * vi).collect();
let fp = f(&xp);
let fm = f(&xm);
let fx = f(x);
let two_h = 2.0 * FD_STEP;
let jvp_val: Vec<f64> = fp.iter().zip(fm.iter()).map(|(&fpi, &fmi)| (fpi - fmi) / two_h).collect();
Ok((fx, jvp_val))
}
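/// Vector-Jacobian product: returns `(f(x), v^T * J)` for a cotangent `v` of
/// length `m`. A finite-difference backend has no cheap reverse pass, so this
/// builds the full Jacobian via [`jacobian`] and contracts it with `v`;
/// prefer [`jvp`] when a forward-mode product suffices.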
pub fn vjp(
f: impl Fn(&[f64]) -> Vec<f64>,
x: &[f64],
v: &[f64],
) -> Result<(Vec<f64>, Vec<f64>), AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"vjp: input must be non-empty".to_string(),
));
}
let fx = f(x);
let m = fx.len();
if m == 0 {
return Err(AutogradError::OperationError(
"vjp: function output must be non-empty".to_string(),
));
}
if v.len() != m {
return Err(AutogradError::ShapeMismatch(format!(
"vjp: cotangent vector length {} does not match output length {}",
v.len(),
m
)));
}
let jac = jacobian(f, x)?;
let mut result = vec![0.0f64; n];
for j in 0..n {
for i in 0..m {
result[j] += v[i] * jac[[i, j]];
}
}
Ok((fx, result))
}
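/// Hessian of a scalar function. Diagonal entries use the second central
/// difference `(f(x + h*e_i) + f(x - h*e_i) - 2*f(x)) / h^2`; off-diagonal
/// entries use the four-point mixed stencil divided by `4h^2`, and each value
/// is mirrored so the result is exactly symmetric.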
pub fn hessian(
f: impl Fn(&[f64]) -> f64,
x: &[f64],
) -> Result<Array2<f64>, AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"hessian: input must be non-empty".to_string(),
));
}
let mut h_mat = Array2::<f64>::zeros((n, n));
let fx = f(x);
let h2_diag = FD_STEP * FD_STEP;
let h2_off = 4.0 * FD_STEP * FD_STEP;
let mut xp = x.to_vec();
let mut xm = x.to_vec();
for i in 0..n {
xp[i] = x[i] + FD_STEP;
xm[i] = x[i] - FD_STEP;
h_mat[[i, i]] = (f(&xp) + f(&xm) - 2.0 * fx) / h2_diag;
xp[i] = x[i];
xm[i] = x[i];
}
let mut xpp = x.to_vec();
let mut xpm = x.to_vec();
let mut xmp = x.to_vec();
let mut xmm = x.to_vec();
for i in 0..n {
for j in (i + 1)..n {
xpp[i] = x[i] + FD_STEP;
xpp[j] = x[j] + FD_STEP;
xpm[i] = x[i] + FD_STEP;
xpm[j] = x[j] - FD_STEP;
xmp[i] = x[i] - FD_STEP;
xmp[j] = x[j] + FD_STEP;
xmm[i] = x[i] - FD_STEP;
xmm[j] = x[j] - FD_STEP;
let val = (f(&xpp) - f(&xpm) - f(&xmp) + f(&xmm)) / h2_off;
h_mat[[i, j]] = val;
h_mat[[j, i]] = val;
xpp[i] = x[i];
xpp[j] = x[j];
xpm[i] = x[i];
xpm[j] = x[j];
xmp[i] = x[i];
xmp[j] = x[j];
xmm[i] = x[i];
xmm[j] = x[j];
}
}
Ok(h_mat)
}
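/// Hessian-vector product `H*v` without forming `H`: a central difference of
/// gradients, `(grad f(x + h*v) - grad f(x - h*v)) / (2h)`. This costs two
/// [`grad`] calls (`4n` evaluations of `f`) versus the `O(n^2)` evaluations
/// of a full [`hessian`].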
pub fn hvp(
f: impl Fn(&[f64]) -> f64,
x: &[f64],
v: &[f64],
) -> Result<Vec<f64>, AutogradError> {
let n = x.len();
if n == 0 {
return Err(AutogradError::OperationError(
"hvp: input must be non-empty".to_string(),
));
}
if v.len() != n {
return Err(AutogradError::ShapeMismatch(format!(
"hvp: vector length {} does not match input length {}",
v.len(),
n
)));
}
let xp: Vec<f64> = x.iter().zip(v.iter()).map(|(&xi, &vi)| xi + FD_STEP * vi).collect();
let xm: Vec<f64> = x.iter().zip(v.iter()).map(|(&xi, &vi)| xi - FD_STEP * vi).collect();
let gp = grad(&f, &xp)?;
let gm = grad(&f, &xm)?;
let two_h = 2.0 * FD_STEP;
let result: Vec<f64> = gp.iter().zip(gm.iter()).map(|(&gpi, &gmi)| (gpi - gmi) / two_h).collect();
Ok(result)
}
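/// Applies `f` independently to every row of `inputs` and stacks the outputs
/// into a `(batch, m)` array, where `m` is the length of the first row's
/// output. Any row producing a different length is reported as a shape
/// mismatch.
///
/// A sketch of the expected shapes (`ignore`d; the import paths are assumed):
///
/// ```ignore
/// use your_crate::functional::vmap; // assumed path
/// use scirs2_core::ndarray::Array2;
///
/// let batch = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let out = vmap(|xs| vec![xs[0] + xs[1]], &batch).unwrap();
/// assert_eq!(out.shape(), &[2, 1]);
/// ```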
pub fn vmap(
f: impl Fn(&[f64]) -> Vec<f64>,
inputs: &Array2<f64>,
) -> Result<Array2<f64>, AutogradError> {
let batch = inputs.nrows();
if batch == 0 {
return Err(AutogradError::OperationError(
"vmap: input batch is empty".to_string(),
));
}
    let row0 = inputs.row(0);
    // `as_slice` is `None` for non-contiguous rows; fall back to a copy.
    // A match avoids the eager allocation that `unwrap_or` would force
    // even in the contiguous case.
    let out0 = match row0.as_slice() {
        Some(s) => f(s),
        None => f(&row0.iter().copied().collect::<Vec<f64>>()),
    };
let m = out0.len();
if m == 0 {
return Err(AutogradError::OperationError(
"vmap: function returned empty output".to_string(),
));
}
let mut result_data = vec![0.0f64; batch * m];
result_data[..m].copy_from_slice(&out0);
for i in 1..batch {
let row = inputs.row(i);
let row_slice: Vec<f64>;
let slice_ref: &[f64] = match row.as_slice() {
Some(s) => s,
None => {
row_slice = row.iter().copied().collect();
&row_slice
}
};
let out = f(slice_ref);
if out.len() != m {
return Err(AutogradError::ShapeMismatch(format!(
"vmap: row {} produced output of length {} but expected {}",
i,
out.len(),
m
)));
}
result_data[i * m..(i + 1) * m].copy_from_slice(&out);
}
Array2::from_shape_vec((batch, m), result_data).map_err(|e| {
AutogradError::ShapeMismatch(format!("vmap: failed to create output array: {}", e))
})
}
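/// Sum of per-sample loss gradients with respect to `params`:
/// `sum_s grad_theta loss_fn(theta, s)` over the rows `s` of `batch`, each
/// gradient taken by central differences. Note this accumulates a sum, not a
/// mean; divide by `batch.nrows()` if you want the average gradient.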
pub fn batch_grad(
loss_fn: impl Fn(&[f64], &[f64]) -> f64,
params: &[f64],
batch: &Array2<f64>,
) -> Result<Vec<f64>, AutogradError> {
let p = params.len();
if p == 0 {
return Err(AutogradError::OperationError(
"batch_grad: params must be non-empty".to_string(),
));
}
let n_samples = batch.nrows();
if n_samples == 0 {
return Err(AutogradError::OperationError(
"batch_grad: batch must be non-empty".to_string(),
));
}
    let mut acc = vec![0.0f64; p];
    // Hoist the perturbation buffers and the constant divisor out of the
    // per-sample loop; each inner iteration restores them afterwards.
    let mut pp = params.to_vec();
    let mut pm = params.to_vec();
    let two_h = 2.0 * FD_STEP;
    for row in batch.axis_iter(Axis(0)) {
        let sample: Vec<f64> = row.iter().copied().collect();
        for k in 0..p {
            pp[k] = params[k] + FD_STEP;
            pm[k] = params[k] - FD_STEP;
            acc[k] += (loss_fn(&pp, &sample) - loss_fn(&pm, &sample)) / two_h;
            pp[k] = params[k];
            pm[k] = params[k];
        }
    }
}
Ok(acc)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
const TOL: f64 = 1e-3;
#[test]
fn test_grad_x_squared_at_3() {
let g = grad(|xs| xs[0] * xs[0], &[3.0]).expect("grad x^2 at 3");
assert!((g[0] - 6.0).abs() < TOL, "expected 6.0, got {}", g[0]);
}
#[test]
fn test_grad_multivariate_quadratic() {
let g = grad(|xs| xs[0] * xs[0] + xs[1] * xs[1], &[3.0, 4.0])
.expect("grad multivariate");
assert!((g[0] - 6.0).abs() < TOL);
assert!((g[1] - 8.0).abs() < TOL);
}
#[test]
fn test_grad_empty_input_returns_error() {
let result = grad(|_xs| 0.0, &[]);
assert!(result.is_err());
}
#[test]
fn test_grad_rosenbrock() {
let x = &[1.0, 1.0];
let g = grad(
|xs| {
let a = 1.0 - xs[0];
let b = xs[1] - xs[0] * xs[0];
a * a + 100.0 * b * b
},
x,
)
.expect("grad rosenbrock");
assert!(g[0].abs() < 1e-2, "∂f/∂x at (1,1) ≈ 0, got {}", g[0]);
assert!(g[1].abs() < 1e-2, "∂f/∂y at (1,1) ≈ 0, got {}", g[1]);
}
#[test]
fn test_jacobian_vector_quadratic() {
let j = jacobian(|xs| vec![xs[0] * xs[0], xs[0] * xs[1]], &[2.0, 3.0])
.expect("jacobian");
assert!((j[[0, 0]] - 4.0).abs() < TOL);
assert!((j[[0, 1]] - 0.0).abs() < TOL);
assert!((j[[1, 0]] - 3.0).abs() < TOL);
assert!((j[[1, 1]] - 2.0).abs() < TOL);
}
#[test]
fn test_jacobian_identity() {
let j = jacobian(|xs| xs.to_vec(), &[1.0, 2.0, 3.0]).expect("jacobian identity");
assert_eq!(j.shape(), &[3, 3]);
for i in 0..3 {
for k in 0..3 {
let expected = if i == k { 1.0 } else { 0.0 };
assert!((j[[i, k]] - expected).abs() < TOL);
}
}
}
#[test]
fn test_jvp_basic() {
let (fx, jvp_val) =
jvp(|xs| vec![xs[0] * xs[0], xs[0] * xs[1]], &[2.0, 3.0], &[1.0, 0.0])
.expect("jvp");
assert!((fx[0] - 4.0).abs() < TOL);
assert!((fx[1] - 6.0).abs() < TOL);
assert!((jvp_val[0] - 4.0).abs() < TOL);
assert!((jvp_val[1] - 3.0).abs() < TOL);
}
#[test]
fn test_jvp_dimension_mismatch_error() {
let result = jvp(|xs| vec![xs[0]], &[1.0, 2.0], &[1.0]);
assert!(result.is_err());
}
#[test]
fn test_vjp_basic() {
let (fx, vjp_val) =
vjp(|xs| vec![xs[0] * xs[0], xs[0] * xs[1]], &[2.0, 3.0], &[1.0, 0.0])
.expect("vjp");
assert!((fx[0] - 4.0).abs() < TOL);
        assert!((vjp_val[0] - 4.0).abs() < TOL);
        assert!((vjp_val[1] - 0.0).abs() < TOL);
    }
#[test]
fn test_vjp_dimension_mismatch_cotangent() {
let result = vjp(|xs| vec![xs[0]], &[1.0], &[1.0, 2.0]);
assert!(result.is_err());
}
#[test]
fn test_hessian_spherical() {
let h = hessian(|xs| xs[0] * xs[0] + xs[1] * xs[1], &[1.0, 1.0])
.expect("hessian spherical");
assert!((h[[0, 0]] - 2.0).abs() < TOL);
assert!((h[[1, 1]] - 2.0).abs() < TOL);
assert!(h[[0, 1]].abs() < TOL);
assert!(h[[1, 0]].abs() < TOL);
}
#[test]
fn test_hessian_cross_term() {
let h = hessian(|xs| xs[0] * xs[1], &[2.0, 3.0]).expect("hessian cross");
assert!(h[[0, 0]].abs() < TOL);
assert!(h[[1, 1]].abs() < TOL);
assert!((h[[0, 1]] - 1.0).abs() < TOL);
assert!((h[[1, 0]] - 1.0).abs() < TOL);
}
#[test]
fn test_hessian_empty_input_error() {
let result = hessian(|_xs| 0.0, &[]);
assert!(result.is_err());
}
#[test]
fn test_hvp_spherical() {
let result = hvp(
|xs| xs[0] * xs[0] + xs[1] * xs[1],
&[1.0, 1.0],
&[1.0, 0.0],
)
.expect("hvp spherical");
assert!((result[0] - 2.0).abs() < TOL);
assert!(result[1].abs() < TOL);
}
#[test]
fn test_hvp_second_direction() {
let result = hvp(
|xs| xs[0] * xs[0] + xs[1] * xs[1],
&[1.0, 1.0],
&[0.0, 1.0],
)
.expect("hvp second direction");
assert!(result[0].abs() < TOL);
assert!((result[1] - 2.0).abs() < TOL);
}
#[test]
fn test_hvp_dimension_mismatch_error() {
let result = hvp(|xs| xs[0] * xs[0], &[1.0], &[1.0, 0.0]);
assert!(result.is_err());
}
#[test]
fn test_vmap_scale() {
let batch = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.expect("batch");
let result = vmap(|xs| vec![xs[0] * 2.0, xs[1] * 3.0], &batch).expect("vmap scale");
assert_eq!(result.shape(), &[3, 2]);
assert!((result[[0, 0]] - 2.0).abs() < 1e-12);
assert!((result[[0, 1]] - 6.0).abs() < 1e-12);
assert!((result[[2, 0]] - 10.0).abs() < 1e-12);
assert!((result[[2, 1]] - 18.0).abs() < 1e-12);
}
#[test]
fn test_vmap_applies_independently_to_each_row() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let batch = Array2::from_shape_vec((2, 2), data).expect("batch");
let result = vmap(|xs| vec![(xs[0] + xs[1]) * (xs[0] + xs[1])], &batch)
.expect("vmap sum sq");
        assert!((result[[0, 0]] - 9.0).abs() < 1e-12);
        assert!((result[[1, 0]] - 49.0).abs() < 1e-12);
    }
#[test]
fn test_vmap_empty_batch_error() {
let empty = Array2::<f64>::zeros((0, 2));
let result = vmap(|xs| vec![xs[0]], &empty);
assert!(result.is_err());
}
#[test]
fn test_batch_grad_linear_regression() {
let batch = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).expect("batch");
let g = batch_grad(
|params, sample| {
let diff = params[0] * sample[0] - 1.0;
diff * diff
},
&[0.5],
&batch,
)
.expect("batch_grad linear");
        // Analytic check: d/dθ Σ_x (θ·x − 1)² at θ = 0.5 over x ∈ {1, 2, 3, 4}
        // is Σ_x 2(θx − 1)x = −1 + 0 + 3 + 8 = 10.
        assert!(g[0].is_finite());
        assert!((g[0] - 10.0).abs() < TOL, "expected 10.0, got {}", g[0]);
}
#[test]
fn test_batch_grad_empty_params_error() {
let batch = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).expect("batch");
let result = batch_grad(|_p, _s| 0.0, &[], &batch);
assert!(result.is_err());
}
#[test]
fn test_batch_grad_empty_batch_error() {
let empty = Array2::<f64>::zeros((0, 1));
let result = batch_grad(|_p, _s| 0.0, &[1.0], &empty);
assert!(result.is_err());
}
}