Skip to main content

gam_test_support/
fd_checker.rs

//! Central-difference finite-difference checking harness for tests.
//!
//! Test modules across the crate repeatedly hand-roll the same central-difference
//! gradient check: clone the parameter vector, bump one coordinate by `±eps`,
//! evaluate a scalar objective, form `(f₊ − f₋) / (2·eps)`, and compare against an
//! analytic gradient component. This module captures the two mechanical shapes —
//! a coordinate-wise scalar-objective gradient ([`numerical_gradient_central_diff`])
//! and a directional derivative of a vector-valued map ([`directional_central_diff`])
//! — plus the final comparison step ([`verify_gradient_vs_fd`], which checks an
//! analytic gradient against the coordinate-wise FD under a mixed absolute/relative
//! tolerance). Each call site then routes through one audited implementation
//! instead of an open-coded loop.
//!
//! These helpers are *only* for tests. They are not part of any production solver
//! path; the production outer-gradient FD audit lives in
//! [`crate::solver::rho_optimizer::fd_audit`] and is a different (criterion-level,
//! diagnostic-logging) facility.

16use ndarray::Array1;
17
18/// Central finite-difference gradient of a scalar objective at `x`.
19///
20/// For each coordinate `i`, returns `(f(x + eps·eᵢ) − f(x − eps·eᵢ)) / (2·eps)`.
21/// `f` is evaluated `2·len(x)` times. The input slice is never mutated (each
22/// evaluation operates on a fresh clone), so `f` may borrow `x`'s surroundings
23/// freely.
24pub fn numerical_gradient_central_diff<F>(mut f: F, x: &Array1<f64>, eps: f64) -> Array1<f64>
25where
26    F: FnMut(&Array1<f64>) -> f64,
27{
28    let mut grad = Array1::zeros(x.len());
29    for i in 0..x.len() {
30        let mut xp = x.clone();
31        let mut xm = x.clone();
32        xp[i] += eps;
33        xm[i] -= eps;
34        grad[i] = (f(&xp) - f(&xm)) / (2.0 * eps);
35    }
36    grad
37}
38
39/// Directional central finite-difference of a vector-valued map `f` at `x` along
40/// `direction`: `(f(x + eps·d) − f(x − eps·d)) / (2·eps)`.
41///
42/// This is the shape used to validate a Hessian-vector product or a directional
43/// score derivative against an analytic operator action: pass the gradient/score
44/// map as `f` and the probe vector as `direction`.
45pub fn directional_central_diff<F>(
46    mut f: F,
47    x: &Array1<f64>,
48    direction: &Array1<f64>,
49    eps: f64,
50) -> Array1<f64>
51where
52    F: FnMut(&Array1<f64>) -> Array1<f64>,
53{
54    assert_eq!(
55        x.len(),
56        direction.len(),
57        "directional_central_diff: x and direction must have equal length"
58    );
59    let xp = x + &(direction * eps);
60    let xm = x - &(direction * eps);
61    (f(&xp) - f(&xm)) / (2.0 * eps)
62}
63
64/// Verify an analytic gradient against the central finite-difference of the
65/// objective, coordinate by coordinate.
66///
67/// Each component must agree to `tol·(1 + |fd|)` — a mixed absolute/relative
68/// bound that stays meaningful both where the gradient is `O(1)` and where it is
69/// near zero. Returns `Err` naming the first failing coordinate (with both
70/// values and the realized gap) so the test panic message localizes the
71/// disagreement; returns `Ok(())` when every coordinate agrees.
72pub fn verify_gradient_vs_fd<F>(
73    objective: F,
74    analytic_grad: &Array1<f64>,
75    x: &Array1<f64>,
76    eps: f64,
77    tol: f64,
78) -> Result<(), String>
79where
80    F: FnMut(&Array1<f64>) -> f64,
81{
82    if analytic_grad.len() != x.len() {
83        return Err(format!(
84            "verify_gradient_vs_fd: analytic gradient length {} != x length {}",
85            analytic_grad.len(),
86            x.len()
87        ));
88    }
89    let fd = numerical_gradient_central_diff(objective, x, eps);
90    for i in 0..x.len() {
91        let bound = tol * (1.0 + fd[i].abs());
92        let gap = (analytic_grad[i] - fd[i]).abs();
93        if gap > bound {
94            return Err(format!(
95                "verify_gradient_vs_fd: coordinate {i} disagrees: analytic={:.6e}, fd={:.6e}, gap={:.3e}, tol={:.3e} (bound {:.3e})",
96                analytic_grad[i], fd[i], gap, tol, bound
97            ));
98        }
99    }
100    Ok(())
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Quadratic `f(x) = ½·xᵀA x + bᵀx` with symmetric `A`. Its exact gradient
    /// is `A x + b` and its Hessian is `A`, so all three helpers can be checked
    /// against a closed form.
    #[test]
    fn quadratic_gradient_and_directional_match_closed_form() {
        let hess = array![[3.0, 0.5, -0.2], [0.5, 2.0, 0.4], [-0.2, 0.4, 1.5]];
        let lin = array![0.3, -1.1, 0.7];
        let point = array![0.9, -0.4, 1.3];
        let eps = 1e-6;

        let objective = |v: &Array1<f64>| 0.5 * v.dot(&hess.dot(v)) + lin.dot(v);
        let grad_exact = hess.dot(&point) + &lin;

        // The coordinate-wise FD gradient reproduces the closed-form gradient.
        let grad_fd = numerical_gradient_central_diff(objective, &point, eps);
        for (&got, &want) in grad_fd.iter().zip(grad_exact.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }

        // The tolerance-based verifier accepts the exact gradient.
        verify_gradient_vs_fd(objective, &grad_exact, &point, eps, 1e-5)
            .expect("analytic gradient matches FD of the quadratic");

        // Differentiating the gradient map along a probe recovers the Hessian
        // action A·d.
        let probe = array![0.6, -0.8, 0.2];
        let grad_map = |v: &Array1<f64>| hess.dot(v) + &lin;
        let hvp_fd = directional_central_diff(grad_map, &point, &probe, eps);
        let hvp_exact = hess.dot(&probe);
        for (&got, &want) in hvp_fd.iter().zip(hvp_exact.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// The verifier accepts the true gradient of `x₀² + x₁²` and rejects a
    /// corrupted one, naming the corrupted coordinate in its error.
    #[test]
    fn verify_rejects_wrong_gradient() {
        let objective = |v: &Array1<f64>| v[0] * v[0] + v[1] * v[1];
        let point = array![1.0, 2.0];
        let (eps, tol) = (1e-6, 1e-5);

        let exact = array![2.0, 4.0];
        verify_gradient_vs_fd(objective, &exact, &point, eps, tol).expect("exact gradient passes");

        let wrong = array![2.0, 4.5];
        let err = match verify_gradient_vs_fd(objective, &wrong, &point, eps, tol) {
            Ok(()) => panic!("perturbed gradient must be rejected"),
            Err(msg) => msg,
        };
        assert!(
            err.contains("coordinate 1"),
            "error should name coord 1: {err}"
        );
    }
}