Skip to main content

gam_sae/
certificates.rs

1//! Per-fit first-order optimality certificate for the SAE LAML criterion
2//! (issue #934).
3//!
4//! # What this is
5//!
6//! The recurring structural bug genus in this engine is *objective↔gradient
7//! desync*: the criterion value `V(ρ)` and its analytic derivative `∇V(ρ)`
8//! are computed by separate code paths that drift apart (#752, #748, #808,
9//! #901, …). Every one of those bugs was eventually diagnosed by a human
10//! running one finite-difference comparison by hand at the returned optimum:
11//! the analytic gradient claimed convergence while a finite difference of the
12//! *actual objective* disagreed. The engine never ran that check on itself.
13//!
14//! This module makes the engine run it on every real fit. Once, at the
15//! converged outer optimum `ρ̂`, outside all hot loops:
16//!
17//! 1. Draw one deterministic direction `v` on the ρ-sphere from the problem
18//!    fingerprint (no `Date`/random-source nondeterminism).
19//! 2. Central-difference the criterion **value path** at `ρ̂ ± h v` with a
20//!    Richardson second step `2h` to estimate the FD's own error bar.
21//! 3. Compare against the analytic directional derivative `∇V(ρ̂)·v`.
22//! 4. Record a [`CriterionCertificate`] on the fit payload.
23//!
24//! # Why finite differences are legal here
25//!
26//! The exact-REML-only policy bans approximate quantities from *producing*
27//! the fit. This FD probe does not produce anything the fit consumes — it
28//! **audits** the production analytic gradient against the production value
29//! path, at a single point, after convergence. FD is the audit instrument,
30//! not the estimator. That boundary is what makes it the runtime enforcement
31//! layer for the criterion-atom architecture (#931): atoms make desync
32//! structurally hard to write; this certificate makes any residue observable
33//! in production, where the real data shapes that trigger #901-class desyncs
34//! actually occur. It is the same relationship theta-correction atoms have to
35//! their invariants — the invariant is enforced structurally and audited at
36//! runtime.
37//!
38//! # Cost
39//!
40//! Two value-path evaluations for the central difference plus two for the
41//! Richardson step: four criterion evaluations at the single final point,
42//! seconds even at biobank scale. The value path is evaluated **without**
43//! warm-state shortcuts that would alias it to the gradient path — that
44//! aliasing is exactly what must be audited — so the probe is taken on a
45//! clone of the pristine baseline term whose caches start cold and naturally
46//! miss the gradient path's converged state.
47
48use ndarray::ArrayView1;
49
/// Outcome of the one-shot first-order self-audit taken at the converged outer
/// optimum `ρ̂`.
///
/// A disagreement recorded here is a loud diagnostic, not an automatic fit
/// failure: the certificate only stores the numbers. Callers apply their own
/// tolerance through [`Self::agreement_rel`] or [`Self::passes`].
#[derive(Clone, Copy, Debug)]
pub struct CriterionCertificate {
    /// Euclidean norm `‖∇V(ρ̂)‖₂` of the analytic gradient that was reported
    /// as converged.
    pub grad_norm: f64,
    /// Slope of the **value path** along the probe direction, estimated by the
    /// central difference `[V(ρ̂ + h v) − V(ρ̂ − h v)] / (2h)`.
    // FD-OK: FD-audit certificate oracle field verifying the analytic directional derivative
    pub fd_directional: f64, // fd-ok: FD-audit certificate, not in math path
    // END-FD-OK
    /// Slope `∇V(ρ̂)·v` predicted by the production gradient path along the
    /// same unit direction `v`.
    pub analytic_directional: f64,
    /// Richardson estimate `|D(h) − D(2h)| / 3` of the central difference's own
    /// leading `O(h²)` truncation error. A FD/analytic gap smaller than this
    /// cannot be distinguished from exact agreement.
    // FD-OK: Richardson error-bar of the FD-audit oracle (reporting only)
    pub fd_error_bar: f64, // fd-ok: FD-audit certificate, not in math path
    // END-FD-OK
    /// Probe step `h` that was actually used.
    pub step: f64,
    /// `true` when the criterion's curvature at `ρ̂` is usable: all value-path
    /// evaluations were finite and an undamped factorization succeeded.
    /// `false` marks a railed/degenerate optimum (the #748 indefinite-`H+Sλ`
    /// signature), regardless of whether the directional numbers agree.
    pub well_posed: bool,
}

impl CriterionCertificate {
    /// Relative FD-vs-analytic disagreement, as a non-negative magnitude.
    ///
    /// Computes `|fd − analytic| / max(|analytic|, |fd|, fd_error_bar, 1e-12)`.
    /// Dividing by the larger directional magnitude keeps a flat direction
    /// (both slopes near zero) from producing a spurious blow-up; including
    /// the FD error bar in the denominator prevents reporting a gap the probe
    /// cannot resolve; the `1e-12` floor keeps the division defined.
    ///
    /// The result is NaN when the gap itself is NaN (e.g. a NaN slope).
    #[must_use]
    pub fn agreement_rel(&self) -> f64 {
        // FD-OK: comparing FD-audit oracle against the analytic directional derivative
        let gap = (self.fd_directional - self.analytic_directional).abs(); // fd-ok: FD-audit certificate, not in math path
        let candidates = [
            self.analytic_directional.abs(),
            self.fd_directional.abs(), // fd-ok: FD-audit certificate, not in math path
            self.fd_error_bar,         // fd-ok: FD-audit certificate, not in math path
        ];
        // `f64::max` skips NaN operands, so the floor always survives.
        let denominator = candidates.iter().copied().fold(1e-12_f64, f64::max);
        gap / denominator
        // END-FD-OK
    }

    /// Verdict against a caller-chosen relative tolerance.
    ///
    /// Returns `true` only when the certificate is well-posed **and**
    /// [`Self::agreement_rel`] is at most `rel_tol`. A `false` result is the
    /// loud desync/rail flag; it names a suspicious fit rather than failing it.
    #[must_use]
    pub fn passes(&self, rel_tol: f64) -> bool {
        if !self.well_posed {
            return false;
        }
        self.agreement_rel() <= rel_tol
    }
}
113
114/// A deterministic unit direction on the ρ-sphere derived purely from the
115/// problem fingerprint `(ρ̂ values, dimension)`.
116///
117/// No `Date`, no thread-RNG, no global entropy: the same fit fingerprint
118/// always yields the same probe direction, so the certificate is reproducible
119/// and CI-stable. A SplitMix64 hash of each coordinate index mixed with the
120/// fingerprint seed produces a fixed pseudo-random direction; it is then
121/// normalized. If every coordinate hashes to zero (impossible for a nonempty
122/// vector with this mixer) the function falls back to the first axis.
123#[must_use]
124pub fn deterministic_probe_direction(rho_hat: ArrayView1<'_, f64>) -> Vec<f64> {
125    let n = rho_hat.len();
126    if n == 0 {
127        return Vec::new();
128    }
129    // Fold the optimum coordinates into a 64-bit seed (finite-safe: NaN/Inf
130    // bit patterns are tolerated, they only perturb the seed deterministically).
131    let mut seed: u64 = 0x9E37_79B9_7F4A_7C15;
132    for (idx, &value) in rho_hat.iter().enumerate() {
133        seed =
134            splitmix64(seed ^ value.to_bits() ^ (idx as u64).wrapping_mul(0x2545_F491_4F6C_DD1D));
135    }
136    let mut dir = vec![0.0_f64; n];
137    let mut s = seed;
138    let mut norm_sq = 0.0_f64;
139    for slot in dir.iter_mut() {
140        s = splitmix64(s);
141        // Map the hashed bits to a symmetric (−1, 1) coordinate.
142        let unit = (s >> 11) as f64 / ((1u64 << 53) as f64); // [0, 1)
143        let coord = 2.0 * unit - 1.0;
144        *slot = coord;
145        norm_sq += coord * coord;
146    }
147    let norm = norm_sq.sqrt();
148    if norm > 0.0 {
149        for slot in dir.iter_mut() {
150            *slot /= norm;
151        }
152    } else {
153        dir[0] = 1.0;
154    }
155    dir
156}
157
158/// SplitMix64 — a tiny, well-distributed integer mixer used only to make the
159/// probe direction a deterministic function of the fit fingerprint. Thin
160/// wrapper over the canonical implementation in
161/// [`gam_linalg::utils::splitmix64_hash`].
162fn splitmix64(state: u64) -> u64 {
163    gam_linalg::utils::splitmix64_hash(state)
164}
165
166/// The probe step `h` scaled to the magnitude of the optimum: log-ρ
167/// coordinates are `O(1)`–`O(10)`, so a relative step keeps the central
168/// difference inside the smooth quadratic region while staying well above
169/// double-precision round-off of the criterion (~`1e-8` relative).
170#[must_use]
171pub fn probe_step(rho_hat: ArrayView1<'_, f64>) -> f64 {
172    const BASE: f64 = 1e-4;
173    let scale = rho_hat.iter().fold(1.0_f64, |m, &x| m.max(x.abs()));
174    BASE * scale
175}
176
/// Per-coordinate probe step for a single outer-ρ axis.
///
/// [`probe_step`] collapses the whole vector to one global step
/// (`1e-4 · max_i|ρ_i|`), which under-resolves a small coordinate whenever some
/// other axis is large. When differencing one axis at a time, scale the step to
/// that coordinate's own magnitude instead.
///
/// Returns `1e-4 · max(|rho_i|, 1)`. The unit floor gives a near-zero
/// coordinate a usable step of `1e-4`. A NaN `rho_i` also yields `1e-4`
/// (because `f64::max` ignores a NaN operand); an infinite `rho_i` yields an
/// infinite step.
///
/// Marked `#[must_use]` like the other pure helpers in this module: the
/// function has no side effects, so discarding its result is always a mistake.
#[must_use]
pub fn probe_step_for(rho_i: f64) -> f64 {
    const BASE: f64 = 1e-4;
    BASE * rho_i.abs().max(1.0)
}
188
/// Raw inputs for building a certificate: the criterion **value path**
/// sampled at four points along the unit direction `v` around `ρ̂`, together
/// with the analytic gradient summary.
///
/// The struct holds plain numbers only and has no dependency on the SAE term
/// type, so certificate assembly can be unit-tested in isolation.
#[derive(Clone, Copy, Debug)]
pub struct DirectionalSamples {
    /// Criterion value one step forward, `V(ρ̂ + h v)`.
    pub plus_h: f64,
    /// Criterion value one step backward, `V(ρ̂ − h v)`.
    pub minus_h: f64,
    /// Criterion value two steps forward, `V(ρ̂ + 2h v)`.
    pub plus_2h: f64,
    /// Criterion value two steps backward, `V(ρ̂ − 2h v)`.
    pub minus_2h: f64,
    /// Probe step `h` separating the samples.
    pub step: f64,
    /// Analytic gradient norm `‖∇V(ρ̂)‖₂`.
    pub grad_norm: f64,
    /// Analytic directional derivative `∇V(ρ̂)·v`.
    pub analytic_directional: f64,
    /// Caller's verdict that every value-path evaluation produced a finite
    /// criterion at a well-conditioned (undamped-factorable) inner optimum.
    pub well_posed: bool,
}
213
214/// Assemble the [`CriterionCertificate`] from directional value samples.
215///
216/// The central difference at step `h` is `D(h) = (V₊ − V₋) / 2h`; at step
217/// `2h` it is `D(2h) = (V₊₊ − V₋₋) / 4h`. For a smooth criterion both
218/// approximate `∇V·v` with leading error `(h²/6)·V‴·v` and `(4h²/6)·V‴·v`, so
219/// `|D(h) − D(2h)| = h²/2·|V‴·v| + O(h⁴)` and the Richardson-extrapolated FD
220/// error bar of `D(h)` is `|D(h) − D(2h)| / 3` (the standard central-difference
221/// Richardson remainder). The reported `fd_directional` is `D(h)`.
222#[must_use]
223pub fn certificate_from_samples(s: &DirectionalSamples) -> CriterionCertificate {
224    // FD-OK: Richardson FD oracle constructed to audit the analytic directional derivative
225    let d_h = (s.plus_h - s.minus_h) / (2.0 * s.step);
226    let d_2h = (s.plus_2h - s.minus_2h) / (4.0 * s.step);
227    let fd_error_bar = (d_h - d_2h).abs() / 3.0; // fd-ok: FD-audit certificate, not in math path
228    CriterionCertificate {
229        grad_norm: s.grad_norm,
230        fd_directional: d_h, // fd-ok: FD-audit certificate, not in math path
231        analytic_directional: s.analytic_directional,
232        fd_error_bar, // fd-ok: FD-audit certificate, not in math path
233        // END-FD-OK
234        step: s.step,
235        well_posed: s.well_posed
236            && s.plus_h.is_finite()
237            && s.minus_h.is_finite()
238            && s.plus_2h.is_finite()
239            && s.minus_2h.is_finite(),
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Samples a two-dimensional criterion `value` at `ρ ± h·v` and
    /// `ρ ± 2h·v`, where `v` and `h` come from the production helpers, and
    /// pairs the samples with the supplied (possibly deliberately wrong)
    /// analytic gradient.
    fn probe_2d(
        value: impl Fn(&[f64]) -> f64,
        rho: &Array1<f64>,
        grad: [f64; 2],
    ) -> DirectionalSamples {
        let dir = deterministic_probe_direction(rho.view());
        let h = probe_step(rho.view());
        let eval = |sign: f64, mult: f64| {
            let point: Vec<f64> = rho
                .iter()
                .zip(dir.iter())
                .map(|(&r, &d)| r + sign * mult * h * d)
                .collect();
            value(&point)
        };
        DirectionalSamples {
            plus_h: eval(1.0, 1.0),
            minus_h: eval(-1.0, 1.0),
            plus_2h: eval(1.0, 2.0),
            minus_2h: eval(-1.0, 2.0),
            step: h,
            grad_norm: (grad[0] * grad[0] + grad[1] * grad[1]).sqrt(),
            analytic_directional: grad[0] * dir[0] + grad[1] * dir[1],
            well_posed: true,
        }
    }

    /// Hand-written samples in the order `[V₊, V₋, V₊₊, V₋₋]` with a fixed
    /// step of `1e-4`, unit gradient norm and a caller-side `well_posed`
    /// of `true`.
    fn raw_samples(values: [f64; 4], analytic_directional: f64) -> DirectionalSamples {
        let [plus_h, minus_h, plus_2h, minus_2h] = values;
        DirectionalSamples {
            plus_h,
            minus_h,
            plus_2h,
            minus_2h,
            step: 1e-4,
            grad_norm: 1.0,
            analytic_directional,
            well_posed: true,
        }
    }

    /// For a quadratic `V(ρ) = ½ ρᵀ A ρ + bᵀρ` the third derivative vanishes,
    /// so the central difference reproduces the analytic slope up to
    /// round-off: the relative gap and the Richardson error bar must both be
    /// tiny and the certificate must pass.
    #[test]
    fn quadratic_certificate_agrees_exactly() {
        // V(ρ) = ½(2ρ₀² + 3ρ₁²) + (ρ₀ − 2ρ₁); ∇V = (2ρ₀+1, 3ρ₁−2).
        let quadratic =
            |r: &[f64]| 0.5 * (2.0 * r[0] * r[0] + 3.0 * r[1] * r[1]) + (r[0] - 2.0 * r[1]);
        let rho = Array1::from(vec![0.7_f64, -1.3]);
        let exact_grad = [2.0 * rho[0] + 1.0, 3.0 * rho[1] - 2.0];

        let cert = certificate_from_samples(&probe_2d(quadratic, &rho, exact_grad));

        assert!(
            cert.agreement_rel() < 1e-6,
            "quadratic FD must match analytic: rel {}, fd {}, analytic {}",
            cert.agreement_rel(),
            cert.fd_directional, // fd-ok: FD-audit certificate, not in math path
            cert.analytic_directional
        );
        assert!(
            cert.fd_error_bar < 1e-6, // fd-ok: FD-audit certificate, not in math path
            "quadratic has zero third derivative, error bar must be tiny: {}",
            cert.fd_error_bar // fd-ok: FD-audit certificate, not in math path
        );
        assert!(cert.passes(1e-4), "well-posed quadratic must certify");
    }

    /// Planted desync: the analytic gradient handed to the certificate is 30%
    /// too large in coordinate 0 while the value path stays perfectly smooth.
    /// The certificate must refuse to pass at a sane tolerance.
    #[test]
    fn planted_desync_is_caught() {
        let smooth = |r: &[f64]| r[0].sin() + 0.5 * r[1] * r[1];
        let rho = Array1::from(vec![0.4_f64, 0.9]);
        let true_grad = [rho[0].cos(), rho[1]];
        // Desynced analytic gradient: 30% too large in coord 0.
        let bad_grad = [1.3 * true_grad[0], true_grad[1]];

        let cert = certificate_from_samples(&probe_2d(smooth, &rho, bad_grad));

        assert!(
            !cert.passes(1e-3),
            "30% desync must fail the certificate: rel {}, fd {}, analytic {}",
            cert.agreement_rel(),
            cert.fd_directional, // fd-ok: FD-audit certificate, not in math path
            cert.analytic_directional
        );
    }

    /// The probe direction is reproducible for a given optimum, has unit
    /// length, and changes when the optimum changes.
    #[test]
    fn probe_direction_is_deterministic_unit() {
        let rho = Array1::from(vec![1.0_f64, -2.0, 0.5, 3.3]);
        let first = deterministic_probe_direction(rho.view());
        let second = deterministic_probe_direction(rho.view());
        assert_eq!(first, second, "same fingerprint must give same direction");

        let norm: f64 = first.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-12,
            "direction must be unit, got {norm}"
        );

        // A different optimum gives a different direction.
        let shifted = Array1::from(vec![1.0_f64, -2.0, 0.5, 3.4]);
        let other = deterministic_probe_direction(shifted.view());
        assert_ne!(
            first, other,
            "different fingerprint must give different direction"
        );
    }

    /// A NaN among the ±h samples forces `well_posed = false`, so the
    /// certificate cannot pass at any tolerance.
    #[test]
    fn nonfinite_sample_marks_not_well_posed() {
        let cert = certificate_from_samples(&raw_samples([f64::NAN, 1.0, 2.0, 0.0], 0.0));
        assert!(
            !cert.well_posed,
            "NaN value sample must flag not-well-posed"
        );
        assert!(!cert.passes(1.0), "not-well-posed never certifies");
    }

    /// A NaN among the ±2h (Richardson) samples also forces
    /// `well_posed = false`: the error bar is meaningless without them even
    /// though the ±h samples are finite.
    #[test]
    fn nonfinite_richardson_sample_marks_not_well_posed() {
        let cert = certificate_from_samples(&raw_samples([1.1, 0.9, f64::NAN, 0.8], 1.0));
        assert!(
            !cert.well_posed,
            "NaN Richardson sample must flag not-well-posed"
        );
        assert!(!cert.passes(1.0), "not-well-posed never certifies");
    }
}
394}