use ndarray::Array1;
/// Estimates the gradient of a scalar function `f` at `x` by central finite differences.
///
/// For each coordinate `i`, `f` is evaluated at `x + eps·e_i` and then at `x - eps·e_i`.
/// `grad[i]` is the difference of those two values divided by the distance that actually
/// separates the two evaluation points.
///
/// Dividing by the realized step, rather than the nominal `2 * eps`, removes the bias that
/// appears when `x[i] ± eps` is rounded to the nearest representable `f64`. That bias becomes
/// large when `eps` is comparable to the float spacing around `x[i]`.
///
/// `f` is called exactly `2 * x.len()` times, with the `+` point before the `-` point for each
/// coordinate. A single scratch copy of `x` is perturbed in place and restored after each
/// coordinate, so the input is copied once rather than twice per coordinate.
///
/// # Arguments
/// * `f` - scalar objective to differentiate.
/// * `x` - evaluation point; not modified.
/// * `eps` - half-width of the finite-difference stencil.
///
/// # Returns
/// A gradient estimate with the same length as `x`. If `eps` is zero, or too small to move
/// `x[i]` to a different `f64`, the two evaluation points coincide. The entry is then a
/// division by zero, which is NaN for a deterministic `f`.
pub fn numerical_gradient_central_diff<F>(mut f: F, x: &Array1<f64>, eps: f64) -> Array1<f64>
where
    F: FnMut(&Array1<f64>) -> f64,
{
    let mut probe = x.clone();
    let mut grad = Array1::zeros(x.len());
    for i in 0..x.len() {
        let origin = x[i];
        let upper = origin + eps;
        let lower = origin - eps;

        probe[i] = upper;
        let f_upper = f(&probe);
        probe[i] = lower;
        let f_lower = f(&probe);
        // Restore so every later coordinate is perturbed from the original point.
        probe[i] = origin;

        // `upper - lower` is the step that was really taken after rounding; it only equals
        // `2 * eps` when neither perturbed coordinate was rounded.
        grad[i] = (f_upper - f_lower) / (upper - lower);
    }
    grad
}
/// Approximates the directional derivative of a vector-valued map `f` at `x` along `direction`.
///
/// The map is evaluated first at `x + eps·direction` and then at `x - eps·direction`. The
/// function returns `(f(x + eps·direction) - f(x - eps·direction)) / (2·eps)`, a
/// central-difference estimate of the Jacobian–vector product. When `f` is a gradient map,
/// this is a Hessian–vector product.
///
/// `direction` is used as given (not normalised), so the result scales linearly with its
/// length.
///
/// # Panics
/// Panics if `x` and `direction` have different lengths.
pub fn directional_central_diff<F>(
    mut f: F,
    x: &Array1<f64>,
    direction: &Array1<f64>,
    eps: f64,
) -> Array1<f64>
where
    F: FnMut(&Array1<f64>) -> Array1<f64>,
{
    assert_eq!(
        x.len(),
        direction.len(),
        "directional_central_diff: x and direction must have equal length"
    );
    // Both stencil points share the same displacement, so build it once.
    let offset = direction * eps;
    let forward_point = x + &offset;
    let backward_point = x - &offset;
    // Forward evaluation happens before backward, as in a left-to-right `f(a) - f(b)`.
    let forward = f(&forward_point);
    let backward = f(&backward_point);
    (forward - backward) / (2.0 * eps)
}
/// Checks an analytic gradient against a central finite-difference estimate.
///
/// The finite-difference gradient `fd` is computed with [`numerical_gradient_central_diff`]
/// using step `eps`. Coordinate `i` passes when
/// `|analytic_grad[i] - fd[i]| <= tol * (1 + |fd[i]|)`. The tolerance is therefore roughly
/// absolute for small derivatives and relative for large ones.
///
/// # Errors
/// Returns `Err` with a descriptive message in two cases:
/// * `analytic_grad` and `x` have different lengths.
/// * A coordinate fails the bound. Only the first failing coordinate is reported.
///
/// A NaN anywhere in the comparison (the analytic value, the finite-difference value, or
/// `tol`) counts as a failure; it is never a silent pass.
pub fn verify_gradient_vs_fd<F>(
    objective: F,
    analytic_grad: &Array1<f64>,
    x: &Array1<f64>,
    eps: f64,
    tol: f64,
) -> Result<(), String>
where
    F: FnMut(&Array1<f64>) -> f64,
{
    if analytic_grad.len() != x.len() {
        return Err(format!(
            "verify_gradient_vs_fd: analytic gradient length {} != x length {}",
            analytic_grad.len(),
            x.len()
        ));
    }
    let fd = numerical_gradient_central_diff(objective, x, eps);
    for (i, (&analytic, &estimate)) in analytic_grad.iter().zip(fd.iter()).enumerate() {
        let bound = tol * (1.0 + estimate.abs());
        let gap = (analytic - estimate).abs();
        // Phrase the check as "is within bound" and reject its negation: any comparison
        // involving NaN is false, so a plain `gap > bound` test would let NaN slip through.
        let within = gap <= bound;
        if !within {
            return Err(format!(
                "verify_gradient_vs_fd: coordinate {i} disagrees: analytic={:.6e}, fd={:.6e}, gap={:.3e}, tol={:.3e} (bound {:.3e})",
                analytic, estimate, gap, tol, bound
            ));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// For f(v) = ½·vᵀAv + bᵀv with symmetric A, the gradient is Av + b and the Hessian–vector
    /// product along d is Ad. Both finite-difference helpers must reproduce these closed forms.
    #[test]
    fn quadratic_gradient_and_directional_match_closed_form() {
        let hessian = array![[3.0, 0.5, -0.2], [0.5, 2.0, 0.4], [-0.2, 0.4, 1.5]];
        let linear = array![0.3, -1.1, 0.7];
        let point = array![0.9, -0.4, 1.3];
        let step = 1e-6;

        let objective = |v: &Array1<f64>| 0.5 * v.dot(&hessian.dot(v)) + linear.dot(v);
        let expected_grad = hessian.dot(&point) + &linear;

        let fd_grad = numerical_gradient_central_diff(objective, &point, step);
        for (&got, &want) in fd_grad.iter().zip(expected_grad.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
        verify_gradient_vs_fd(objective, &expected_grad, &point, step, 1e-5)
            .expect("analytic gradient matches FD of the quadratic");

        let direction = array![0.6, -0.8, 0.2];
        let gradient_map = |v: &Array1<f64>| hessian.dot(v) + &linear;
        let hvp_estimate = directional_central_diff(gradient_map, &point, &direction, step);
        let hvp_expected = hessian.dot(&direction);
        for (&got, &want) in hvp_estimate.iter().zip(hvp_expected.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// A gradient that is wrong in one coordinate must be rejected, and the error must name
    /// that coordinate.
    #[test]
    fn verify_rejects_wrong_gradient() {
        let point = array![1.0, 2.0];
        let sum_of_squares = |v: &Array1<f64>| v[0] * v[0] + v[1] * v[1];
        let (eps, tol) = (1e-6, 1e-5);

        let exact = array![2.0, 4.0];
        verify_gradient_vs_fd(sum_of_squares, &exact, &point, eps, tol)
            .expect("exact gradient passes");

        let wrong = array![2.0, 4.5];
        let err = verify_gradient_vs_fd(sum_of_squares, &wrong, &point, eps, tol)
            .expect_err("perturbed gradient must be rejected");
        assert!(
            err.contains("coordinate 1"),
            "error should name coord 1: {err}"
        );
    }
}