gam_problem/laplace_sampler_contract.rs
1//! Laplace-correction / mode-posterior sampler contract (trait-inversion #1521).
2//!
3//! gam-solve's REML inner loop (`#784` block-local sampled marginal correction)
4//! and the custom-family never-fail covariance path call into the
5//! gam-inference-tier NUTS / importance-sampling engine (`inference::hmc_io`,
6//! ~8k lines) — an UP-edge that keeps gam-solve in the inference SCC.
7//!
8//! The COMPUTATION (NUTS, importance sampling, the directional-cubic eigen
9//! diagnostic) is irreducibly above gam-solve and STAYS UP in `hmc_io`. Only
10//! the neutral surface is contract-downed here, mirroring the `rho_posterior`
11//! data-down (#1521):
12//!
13//! * the plain-DATA result carriers gam-solve reads
14//! ([`BlockSampledMarginal`], [`BlockSampledMoments`], [`GaussianModePosterior`],
15//! [`LaplaceTrustworthiness`]);
16//! * the caller-supplied [`BlockExcessTarget`] evaluator gam-solve IMPLEMENTS
17//! (its `Gam784BlockTarget`), so the trait must live below both;
18//! * the two SAMPLER TRAITS ([`LaplaceMarginalSampler`],
19//! [`GaussianModePosteriorSampler`]) gam-solve calls THROUGH; the monolith /
20//! gam-inference implements them over `hmc_io` and injects the impl via the
21//! process-level registry below.
22//!
23//! The pure threshold math ([`laplace_skewness_threshold`],
24//! [`laplace_trustworthiness_from_skewness`]) has no sampler dependency, so it is
25//! moved down outright (gam-solve calls it directly).
26//!
27//! When no impl is registered (e.g. a build that never links the sampler tier)
28//! the sampler getters return `None` and gam-solve degrades to its existing
29//! decline paths — the `#784` correction returns zero (already a frequent
30//! decline outcome) and the never-fail covariance path keeps the
31//! optimizer-conditional covariance (already the `Err(reason)` fallback). The
32//! contract therefore introduces no behavioral cliff and no stub.
33
34use std::sync::OnceLock;
35
36use gam_linalg::matrix::DesignMatrix;
37use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
38
39// ───────────────────────── data carriers (contract-down) ─────────────────────
40
41/// Adaptive, block-local Laplace-trustworthiness verdict (issue #784): which
42/// curvature directions are too non-Gaussian for the plain Laplace summary.
43///
44/// Field-for-field the monolith `hmc_io` type; that module re-exports this so
45/// its construction sites name it unchanged.
46#[derive(Clone, Debug)]
47pub struct LaplaceTrustworthiness {
48 /// Per-direction standardized skewness `γ_r`.
49 pub directional_skewness: Array1<f64>,
50 /// Indices of the directions whose skewness exceeds the auto-derived
51 /// validity threshold (the curvature-heavy, non-Gaussian block).
52 pub untrustworthy_directions: Vec<usize>,
53 /// The auto-derived per-direction skewness threshold `τ(n)` actually used.
54 pub threshold: f64,
55 /// `max_r |γ_r|` across all directions (the global non-Gaussianity scale).
56 pub max_abs_skewness: f64,
57}
58
59impl LaplaceTrustworthiness {
60 /// Whether any curvature direction is too non-Gaussian for the plain
61 /// Laplace summary, i.e. whether the higher-order correction / directional
62 /// sampling fallback should engage at all.
63 pub fn fallback_required(&self) -> bool {
64 !self.untrustworthy_directions.is_empty()
65 }
66}
67
68/// Self-normalized importance-weighted moments of the per-draw gradient channels
69/// — the sampler-side half of the #784 exact-gradient seam. All expectations are
70/// under `p ∝ q·e^{−ΔF}` over the SAME fixed-seed draws that produced the value,
71/// so the spliced value and its assembled gradient can never desync (#901).
72#[derive(Clone, Debug)]
73pub struct BlockSampledMoments {
74 /// `E_p[t]`, length `m`.
75 pub e_t: Array1<f64>,
76 /// `E_p[t tᵀ]`, shape `m × m`.
77 pub e_tt: Array2<f64>,
78 /// `E_p[ngs(η̂+s)]`, length n — the displaced per-row score moment.
79 pub e_neg_score: Array1<f64>,
80 /// Column `r` = `E_p[t_r · ngs(η̂+s)]`, shape `n × m`.
81 pub e_t_neg_score: Array2<f64>,
82}
83
84/// Block-local sampled marginal correction (issue #784).
85///
86/// `value` is `Δ_b` (added to the block marginal log-likelihood, subtracted from
87/// the REML/LAML cost); `rho_gradient` is the explicit penalty-score channel (a)
88/// of the gradient exactness contract; `moments` carries the channels (b)–(d) the
89/// gam-solve assembly contracts against fields it already owns.
90#[derive(Clone, Debug)]
91pub struct BlockSampledMarginal {
92 /// `Δ_b`: additive correction to the block marginal log-likelihood.
93 pub value: f64,
94 /// `∂Δ_b/∂ρ`, length `rho_dim()` — explicit channel (a) ONLY.
95 pub rho_gradient: Array1<f64>,
96 /// Importance-sampling effective sample size (draws), for trust gating.
97 pub importance_ess: f64,
98 /// Number of draws used.
99 pub n_draws: usize,
100 /// Gradient-channel moments for the exact (b)–(d) assembly; `None` only when
101 /// the block is empty (`m == 0`, where the correction is zero).
102 pub moments: Option<BlockSampledMoments>,
103}
104
105/// Honest posterior summary from sampling the proper Gaussian posterior
106/// `N(mode, H⁻¹)` — the terminal never-fail rung of the custom-family
107/// covariance escalation. Field-for-field the monolith `hmc_io` type.
108#[derive(Clone, Debug)]
109pub struct GaussianModePosterior {
110 /// Coefficient draws in original (un-whitened) space: `(n_draws, dim)`.
111 pub samples: Array2<f64>,
112 /// Posterior mean (≈ the seeded mode for a Gaussian target).
113 pub posterior_mean: Array1<f64>,
114 /// Per-coordinate posterior standard deviation (honest SEs).
115 pub posterior_std: Array1<f64>,
116 /// Split-chain R̂ mixing diagnostic.
117 pub rhat: f64,
118 /// Effective sample size.
119 pub ess: f64,
120}
121
122// ───────────────────────── pure threshold math (moved down) ──────────────────
123
/// Per-direction skewness threshold `τ(n)` that splits Laplace-trustworthy
/// directions from those needing the higher-order correction / sampling
/// fallback.
///
/// There is no tunable flag; the threshold follows from the effective sample
/// size alone: `(5/24)γ_r² > 1/n_eff ⇔ |γ_r| > sqrt((24/5)/n_eff)`.
///
/// Returns `f64::INFINITY` whenever `n_eff` is not strictly positive (zero,
/// negative or NaN), so no finite skewness can exceed the threshold.
pub fn laplace_skewness_threshold(n_eff: f64) -> f64 {
    /// The `24/5` factor from the inequality above.
    const SKEWNESS_SCALE: f64 = 24.0 / 5.0;

    // `n_eff > 0.0` is false for NaN too, so NaN lands in the infinite branch.
    if n_eff > 0.0 {
        (SKEWNESS_SCALE / n_eff).sqrt()
    } else {
        f64::INFINITY
    }
}
134
135/// Adaptive, block-local Laplace-trustworthiness verdict (issue #784): flag the
136/// directions whose standardized skewness exceeds [`laplace_skewness_threshold`].
137/// No linear algebra of its own — consumes the directional cubic diagnostic.
138pub fn laplace_trustworthiness_from_skewness(
139 directional_skewness: &Array1<f64>,
140 n_eff: f64,
141) -> LaplaceTrustworthiness {
142 let threshold = laplace_skewness_threshold(n_eff);
143 let mut untrustworthy_directions = Vec::new();
144 let mut max_abs_skewness = 0.0_f64;
145 for (r, &gamma) in directional_skewness.iter().enumerate() {
146 let abs_gamma = if gamma.is_finite() { gamma.abs() } else { 0.0 };
147 max_abs_skewness = max_abs_skewness.max(abs_gamma);
148 if abs_gamma > threshold {
149 untrustworthy_directions.push(r);
150 }
151 }
152 LaplaceTrustworthiness {
153 directional_skewness: directional_skewness.clone(),
154 untrustworthy_directions,
155 threshold,
156 max_abs_skewness,
157 }
158}
159
160// ───────────────────────── caller-supplied excess evaluator ──────────────────
161
162/// Caller-supplied evaluator for the non-Gaussian remainder `ΔF(t)` of the local
163/// log-posterior, restricted to the curvature-heavy block subspace (issue #784).
164///
165/// Implemented by gam-solve's `Gam784BlockTarget`; consumed by
166/// [`LaplaceMarginalSampler::block_sampled_marginal_correction`]. Lives in this
167/// neutral crate so both the implementor (gam-solve) and the sampler impl (the
168/// gam-inference monolith) name the same trait without an SCC edge.
169pub trait BlockExcessTarget {
170 /// Dimension `m` of the block subspace (number of untrustworthy directions
171 /// being sampled).
172 fn block_dim(&self) -> usize;
173 /// Number of outer ρ coordinates the gradient is reported against.
174 fn rho_dim(&self) -> usize;
175 /// Block curvatures `λ_r` (the H-eigenvalues of the sampled directions),
176 /// length `block_dim()`.
177 fn block_curvatures(&self) -> &Array1<f64>;
178 /// Non-Gaussian remainder `ΔF(t)` at whitened block displacement `t`
179 /// (length `block_dim()`).
180 fn excess(&self, t: &Array1<f64>) -> f64;
181 /// ρ-gradient `∂ΔF/∂ρ_k` at the same `t`, length `rho_dim()` — the explicit
182 /// penalty-score channel (a).
183 fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64>;
184 /// Per-row displaced score `∂(D(η̂+s(t))/2φ)/∂η` evaluated at `η̂ + s(t)`
185 /// (length = number of observation rows): the only per-draw ingredient of
186 /// the exact-gradient channels (b)–(d) the assembly side cannot reconstruct.
187 fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64>;
188 /// The same per-row score channel at the undisplaced mode `η̂`.
189 fn base_neg_score(&self) -> Array1<f64>;
190
191 /// Fused `(excess(t), displaced_neg_score(t))`. The returned score is `None`
192 /// exactly when the excess is non-finite (an infeasible draw the sampler
193 /// discards before reading the score). The default preserves the two-call
194 /// behavior; implementors override to share the displacement + jet.
195 fn excess_with_displaced_neg_score(&self, t: &Array1<f64>) -> (f64, Option<Array1<f64>>) {
196 let excess = self.excess(t);
197 if excess.is_finite() {
198 (excess, Some(self.displaced_neg_score(t)))
199 } else {
200 (excess, None)
201 }
202 }
203
204 /// Batched [`Self::excess_with_displaced_neg_score`] over many whitened draws
205 /// (one draw per COLUMN, shape `block_dim() × n_draws`). Batching may only
206 /// change HOW the shared linear algebra is computed (one BLAS-3 product over
207 /// all columns), never WHAT is computed. The default preserves the per-column
208 /// behavior exactly; the GLM implementor overrides it.
209 fn excess_with_displaced_neg_score_batch(
210 &self,
211 draws: &Array2<f64>,
212 ) -> Vec<(f64, Option<Array1<f64>>)> {
213 let n_draws = draws.ncols();
214 let mut out = Vec::with_capacity(n_draws);
215 let mut t = Array1::<f64>::zeros(draws.nrows());
216 for s in 0..n_draws {
217 t.assign(&draws.column(s));
218 out.push(self.excess_with_displaced_neg_score(&t));
219 }
220 out
221 }
222}
223
224// ───────────────────────── injected sampler traits ───────────────────────────
225
226/// The gam-inference-tier sampler for the #784 block-local Laplace correction.
227///
228/// Implemented UP in the monolith over `hmc_io`
229/// (`laplace_directional_cubic_diagnostic` + `block_sampled_marginal_correction`)
230/// and injected DOWN via [`set_laplace_marginal_sampler`]. gam-solve calls
231/// through [`laplace_marginal_sampler`].
232pub trait LaplaceMarginalSampler: Send + Sync {
233 /// Per-direction standardized cubic skewness `γ_r` of the local posterior:
234 /// returns `(max_r |γ_r|, γ)`. Pure eigen-diagnostic (no sampling), but kept
235 /// behind the trait because it lives in the sampler module up-tier.
236 fn directional_cubic_diagnostic(
237 &self,
238 hessian: &Array2<f64>,
239 design: &DesignMatrix,
240 c_weights: &Array1<f64>,
241 refine_supremum: bool,
242 ) -> Result<(f64, Array1<f64>), String>;
243
244 /// Estimate `Δ_b` and its ρ-gradient by importance sampling against the local
245 /// Laplace Gaussian, contracting the caller-supplied [`BlockExcessTarget`].
246 fn block_sampled_marginal_correction(
247 &self,
248 target: &dyn BlockExcessTarget,
249 ) -> Result<BlockSampledMarginal, String>;
250}
251
252/// The gam-inference-tier sampler for the never-fail Gaussian mode posterior
253/// (custom-family covariance escalation). Implemented UP over
254/// `hmc_io::sample_gaussian_mode_posterior` (which auto-derives its
255/// `NutsConfig::for_dimension(mode.len())` internally — that NUTS config never
256/// crosses the contract) and injected DOWN via
257/// [`set_gaussian_mode_posterior_sampler`].
258pub trait GaussianModePosteriorSampler: Send + Sync {
259 /// Sample `N(mode, precision⁻¹)`. `Err` only for a structurally impossible
260 /// request (dimension mismatch, non-PSD precision after symmetrization) —
261 /// never for "did not converge".
262 fn sample_gaussian_mode_posterior(
263 &self,
264 mode: ArrayView1<f64>,
265 precision: ArrayView2<f64>,
266 ) -> Result<GaussianModePosterior, String>;
267}
268
269// ───────────────────────── process-level injection registry ──────────────────
270
271static LAPLACE_MARGINAL_SAMPLER: OnceLock<Box<dyn LaplaceMarginalSampler>> = OnceLock::new();
272static GAUSSIAN_MODE_POSTERIOR_SAMPLER: OnceLock<Box<dyn GaussianModePosteriorSampler>> =
273 OnceLock::new();
274
275/// Register the monolith's `hmc_io`-backed #784 Laplace-correction sampler.
276/// Called once at process init by the gam-inference tier. First writer wins;
277/// a later call is ignored (returns `Err` with the boxed value) so a re-init can
278/// never swap a live sampler mid-run.
279pub fn set_laplace_marginal_sampler(
280 sampler: Box<dyn LaplaceMarginalSampler>,
281) -> Result<(), Box<dyn LaplaceMarginalSampler>> {
282 LAPLACE_MARGINAL_SAMPLER.set(sampler)
283}
284
285/// The registered #784 Laplace-correction sampler, or `None` when the sampler
286/// tier is not linked / not yet initialized (gam-solve then declines the
287/// correction, returning the zero contribution — a safe no-op).
288pub fn laplace_marginal_sampler() -> Option<&'static dyn LaplaceMarginalSampler> {
289 LAPLACE_MARGINAL_SAMPLER.get().map(|b| b.as_ref())
290}
291
292/// Register the monolith's `hmc_io`-backed never-fail Gaussian-mode-posterior
293/// sampler. First writer wins (see [`set_laplace_marginal_sampler`]).
294pub fn set_gaussian_mode_posterior_sampler(
295 sampler: Box<dyn GaussianModePosteriorSampler>,
296) -> Result<(), Box<dyn GaussianModePosteriorSampler>> {
297 GAUSSIAN_MODE_POSTERIOR_SAMPLER.set(sampler)
298}
299
300/// The registered never-fail Gaussian-mode-posterior sampler, or `None` when the
301/// sampler tier is not linked (the custom-family path then retains the
302/// optimizer-conditional covariance — its existing fallback).
303pub fn gaussian_mode_posterior_sampler() -> Option<&'static dyn GaussianModePosteriorSampler> {
304 GAUSSIAN_MODE_POSTERIOR_SAMPLER.get().map(|b| b.as_ref())
305}