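//! `gam_test_support/fd_checker.rs`: central finite-difference helpers for
//! checking analytic gradients and Jacobian-/Hessian-vector products in tests.
use ndarray::Array1;
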
18pub fn numerical_gradient_central_diff<F>(mut f: F, x: &Array1<f64>, eps: f64) -> Array1<f64>
25where
26 F: FnMut(&Array1<f64>) -> f64,
27{
28 let mut grad = Array1::zeros(x.len());
29 for i in 0..x.len() {
30 let mut xp = x.clone();
31 let mut xm = x.clone();
32 xp[i] += eps;
33 xm[i] -= eps;
34 grad[i] = (f(&xp) - f(&xm)) / (2.0 * eps);
35 }
36 grad
37}
38
39pub fn directional_central_diff<F>(
46 mut f: F,
47 x: &Array1<f64>,
48 direction: &Array1<f64>,
49 eps: f64,
50) -> Array1<f64>
51where
52 F: FnMut(&Array1<f64>) -> Array1<f64>,
53{
54 assert_eq!(
55 x.len(),
56 direction.len(),
57 "directional_central_diff: x and direction must have equal length"
58 );
59 let xp = x + &(direction * eps);
60 let xm = x - &(direction * eps);
61 (f(&xp) - f(&xm)) / (2.0 * eps)
62}
63
64pub fn verify_gradient_vs_fd<F>(
73 objective: F,
74 analytic_grad: &Array1<f64>,
75 x: &Array1<f64>,
76 eps: f64,
77 tol: f64,
78) -> Result<(), String>
79where
80 F: FnMut(&Array1<f64>) -> f64,
81{
82 if analytic_grad.len() != x.len() {
83 return Err(format!(
84 "verify_gradient_vs_fd: analytic gradient length {} != x length {}",
85 analytic_grad.len(),
86 x.len()
87 ));
88 }
89 let fd = numerical_gradient_central_diff(objective, x, eps);
90 for i in 0..x.len() {
91 let bound = tol * (1.0 + fd[i].abs());
92 let gap = (analytic_grad[i] - fd[i]).abs();
93 if gap > bound {
94 return Err(format!(
95 "verify_gradient_vs_fd: coordinate {i} disagrees: analytic={:.6e}, fd={:.6e}, gap={:.3e}, tol={:.3e} (bound {:.3e})",
96 analytic_grad[i], fd[i], gap, tol, bound
97 ));
98 }
99 }
100 Ok(())
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// For `f(v) = ½·vᵀAv + bᵀv` with symmetric `A`, the gradient is `Av + b`
    /// and the directional derivative of the gradient along `d` is `Ad`. Both
    /// finite-difference helpers must reproduce these closed forms.
    #[test]
    fn quadratic_gradient_and_directional_match_closed_form() {
        let hess = array![[3.0, 0.5, -0.2], [0.5, 2.0, 0.4], [-0.2, 0.4, 1.5]];
        let lin = array![0.3, -1.1, 0.7];
        let point = array![0.9, -0.4, 1.3];
        let step = 1e-6;

        let quadratic = |v: &Array1<f64>| 0.5 * v.dot(&hess.dot(v)) + lin.dot(v);
        let exact_grad = hess.dot(&point) + &lin;

        let fd_grad = numerical_gradient_central_diff(quadratic, &point, step);
        for (got, want) in fd_grad.iter().zip(exact_grad.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-6);
        }

        verify_gradient_vs_fd(quadratic, &exact_grad, &point, step, 1e-5)
            .expect("analytic gradient matches FD of the quadratic");

        let dir = array![0.6, -0.8, 0.2];
        let gradient_field = |v: &Array1<f64>| hess.dot(v) + &lin;
        let fd_hvp = directional_central_diff(gradient_field, &point, &dir, step);
        let exact_hvp = hess.dot(&dir);
        for (got, want) in fd_hvp.iter().zip(exact_hvp.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-6);
        }
    }

    /// The exact gradient of `‖v‖²` must pass. Perturbing one coordinate well
    /// beyond the tolerance must be rejected, and the error must name that
    /// coordinate.
    #[test]
    fn verify_rejects_wrong_gradient() {
        let point = array![1.0, 2.0];
        let sum_sq = |v: &Array1<f64>| v[0] * v[0] + v[1] * v[1];

        let correct = array![2.0, 4.0];
        verify_gradient_vs_fd(sum_sq, &correct, &point, 1e-6, 1e-5)
            .expect("exact gradient passes");

        let perturbed = array![2.0, 4.5];
        let err = verify_gradient_vs_fd(sum_sq, &perturbed, &point, 1e-6, 1e-5)
            .expect_err("perturbed gradient must be rejected");
        assert!(
            err.contains("coordinate 1"),
            "error should name coord 1: {err}"
        );
    }
}