Skip to main content

gam_problem/
laplace_sampler_contract.rs

1//! Laplace-correction / mode-posterior sampler contract (trait-inversion #1521).
2//!
3//! gam-solve's REML inner loop (`#784` block-local sampled marginal correction)
4//! and the custom-family never-fail covariance path call into the
5//! gam-inference-tier NUTS / importance-sampling engine (`inference::hmc_io`,
6//! ~8k lines) — an UP-edge that keeps gam-solve in the inference SCC.
7//!
8//! The COMPUTATION (NUTS, importance sampling, the directional-cubic eigen
9//! diagnostic) is irreducibly above gam-solve and STAYS UP in `hmc_io`. Only
10//! the neutral surface is contract-downed here, mirroring the `rho_posterior`
11//! data-down (#1521):
12//!
13//! * the plain-DATA result carriers gam-solve reads
14//!   ([`BlockSampledMarginal`], [`BlockSampledMoments`], [`GaussianModePosterior`],
15//!   [`LaplaceTrustworthiness`]);
16//! * the caller-supplied [`BlockExcessTarget`] evaluator gam-solve IMPLEMENTS
17//!   (its `Gam784BlockTarget`), so the trait must live below both;
18//! * the two SAMPLER TRAITS ([`LaplaceMarginalSampler`],
19//!   [`GaussianModePosteriorSampler`]) gam-solve calls THROUGH; the monolith /
20//!   gam-inference implements them over `hmc_io` and injects the impl via the
21//!   process-level registry below.
22//!
23//! The pure threshold math ([`laplace_skewness_threshold`],
24//! [`laplace_trustworthiness_from_skewness`]) has no sampler dependency, so it is
25//! moved down outright (gam-solve calls it directly).
26//!
27//! When no impl is registered (e.g. a build that never links the sampler tier)
28//! the sampler getters return `None` and gam-solve degrades to its existing
29//! decline paths — the `#784` correction returns zero (already a frequent
30//! decline outcome) and the never-fail covariance path keeps the
31//! optimizer-conditional covariance (already the `Err(reason)` fallback). The
32//! contract therefore introduces no behavioral cliff and no stub.
33
34use std::sync::OnceLock;
35
36use gam_linalg::matrix::DesignMatrix;
37use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
38
39// ───────────────────────── data carriers (contract-down) ─────────────────────
40
41/// Adaptive, block-local Laplace-trustworthiness verdict (issue #784): which
42/// curvature directions are too non-Gaussian for the plain Laplace summary.
43///
44/// Field-for-field the monolith `hmc_io` type; that module re-exports this so
45/// its construction sites name it unchanged.
46#[derive(Clone, Debug)]
47pub struct LaplaceTrustworthiness {
48    /// Per-direction standardized skewness `γ_r`.
49    pub directional_skewness: Array1<f64>,
50    /// Indices of the directions whose skewness exceeds the auto-derived
51    /// validity threshold (the curvature-heavy, non-Gaussian block).
52    pub untrustworthy_directions: Vec<usize>,
53    /// The auto-derived per-direction skewness threshold `τ(n)` actually used.
54    pub threshold: f64,
55    /// `max_r |γ_r|` across all directions (the global non-Gaussianity scale).
56    pub max_abs_skewness: f64,
57}
58
59impl LaplaceTrustworthiness {
60    /// Whether any curvature direction is too non-Gaussian for the plain
61    /// Laplace summary, i.e. whether the higher-order correction / directional
62    /// sampling fallback should engage at all.
63    pub fn fallback_required(&self) -> bool {
64        !self.untrustworthy_directions.is_empty()
65    }
66}
67
68/// Self-normalized importance-weighted moments of the per-draw gradient channels
69/// — the sampler-side half of the #784 exact-gradient seam. All expectations are
70/// under `p ∝ q·e^{−ΔF}` over the SAME fixed-seed draws that produced the value,
71/// so the spliced value and its assembled gradient can never desync (#901).
72#[derive(Clone, Debug)]
73pub struct BlockSampledMoments {
74    /// `E_p[t]`, length `m`.
75    pub e_t: Array1<f64>,
76    /// `E_p[t tᵀ]`, shape `m × m`.
77    pub e_tt: Array2<f64>,
78    /// `E_p[ngs(η̂+s)]`, length n — the displaced per-row score moment.
79    pub e_neg_score: Array1<f64>,
80    /// Column `r` = `E_p[t_r · ngs(η̂+s)]`, shape `n × m`.
81    pub e_t_neg_score: Array2<f64>,
82}
83
84/// Block-local sampled marginal correction (issue #784).
85///
86/// `value` is `Δ_b` (added to the block marginal log-likelihood, subtracted from
87/// the REML/LAML cost); `rho_gradient` is the explicit penalty-score channel (a)
88/// of the gradient exactness contract; `moments` carries the channels (b)–(d) the
89/// gam-solve assembly contracts against fields it already owns.
90#[derive(Clone, Debug)]
91pub struct BlockSampledMarginal {
92    /// `Δ_b`: additive correction to the block marginal log-likelihood.
93    pub value: f64,
94    /// `∂Δ_b/∂ρ`, length `rho_dim()` — explicit channel (a) ONLY.
95    pub rho_gradient: Array1<f64>,
96    /// Importance-sampling effective sample size (draws), for trust gating.
97    pub importance_ess: f64,
98    /// Number of draws used.
99    pub n_draws: usize,
100    /// Gradient-channel moments for the exact (b)–(d) assembly; `None` only when
101    /// the block is empty (`m == 0`, where the correction is zero).
102    pub moments: Option<BlockSampledMoments>,
103}
104
105/// Honest posterior summary from sampling the proper Gaussian posterior
106/// `N(mode, H⁻¹)` — the terminal never-fail rung of the custom-family
107/// covariance escalation. Field-for-field the monolith `hmc_io` type.
108#[derive(Clone, Debug)]
109pub struct GaussianModePosterior {
110    /// Coefficient draws in original (un-whitened) space: `(n_draws, dim)`.
111    pub samples: Array2<f64>,
112    /// Posterior mean (≈ the seeded mode for a Gaussian target).
113    pub posterior_mean: Array1<f64>,
114    /// Per-coordinate posterior standard deviation (honest SEs).
115    pub posterior_std: Array1<f64>,
116    /// Split-chain R̂ mixing diagnostic.
117    pub rhat: f64,
118    /// Effective sample size.
119    pub ess: f64,
120}
121
122// ───────────────────────── pure threshold math (moved down) ──────────────────
123
/// Auto-derive the per-direction skewness threshold `τ(n)` that separates
/// Laplace-trustworthy directions from those needing the higher-order
/// correction / sampling fallback.
///
/// The threshold comes purely from the effective sample size, with no tunable
/// flag:
/// `(5/24)·γ_r² > 1/n_eff ⇔ |γ_r| > sqrt((24/5)/n_eff)`.
///
/// Edge cases:
/// * `n_eff` that is NaN, zero (either sign) or negative yields
///   `f64::INFINITY`, so no direction can exceed the threshold.
/// * `n_eff = +∞` yields `0.0`, so any non-zero skewness exceeds it.
pub fn laplace_skewness_threshold(n_eff: f64) -> f64 {
    // Written out explicitly instead of `!(n_eff > 0.0)` (clippy
    // `neg_cmp_op_on_partial_ord`). The explicit NaN test keeps routing NaN to
    // the "no valid sample size" branch, exactly as the negated form did.
    if n_eff.is_nan() || n_eff <= 0.0 {
        return f64::INFINITY;
    }
    ((24.0 / 5.0) / n_eff).sqrt()
}
134
135/// Adaptive, block-local Laplace-trustworthiness verdict (issue #784): flag the
136/// directions whose standardized skewness exceeds [`laplace_skewness_threshold`].
137/// No linear algebra of its own — consumes the directional cubic diagnostic.
138pub fn laplace_trustworthiness_from_skewness(
139    directional_skewness: &Array1<f64>,
140    n_eff: f64,
141) -> LaplaceTrustworthiness {
142    let threshold = laplace_skewness_threshold(n_eff);
143    let mut untrustworthy_directions = Vec::new();
144    let mut max_abs_skewness = 0.0_f64;
145    for (r, &gamma) in directional_skewness.iter().enumerate() {
146        let abs_gamma = if gamma.is_finite() { gamma.abs() } else { 0.0 };
147        max_abs_skewness = max_abs_skewness.max(abs_gamma);
148        if abs_gamma > threshold {
149            untrustworthy_directions.push(r);
150        }
151    }
152    LaplaceTrustworthiness {
153        directional_skewness: directional_skewness.clone(),
154        untrustworthy_directions,
155        threshold,
156        max_abs_skewness,
157    }
158}
159
160// ───────────────────────── caller-supplied excess evaluator ──────────────────
161
162/// Caller-supplied evaluator for the non-Gaussian remainder `ΔF(t)` of the local
163/// log-posterior, restricted to the curvature-heavy block subspace (issue #784).
164///
165/// Implemented by gam-solve's `Gam784BlockTarget`; consumed by
166/// [`LaplaceMarginalSampler::block_sampled_marginal_correction`]. Lives in this
167/// neutral crate so both the implementor (gam-solve) and the sampler impl (the
168/// gam-inference monolith) name the same trait without an SCC edge.
169pub trait BlockExcessTarget {
170    /// Dimension `m` of the block subspace (number of untrustworthy directions
171    /// being sampled).
172    fn block_dim(&self) -> usize;
173    /// Number of outer ρ coordinates the gradient is reported against.
174    fn rho_dim(&self) -> usize;
175    /// Block curvatures `λ_r` (the H-eigenvalues of the sampled directions),
176    /// length `block_dim()`.
177    fn block_curvatures(&self) -> &Array1<f64>;
178    /// Non-Gaussian remainder `ΔF(t)` at whitened block displacement `t`
179    /// (length `block_dim()`).
180    fn excess(&self, t: &Array1<f64>) -> f64;
181    /// ρ-gradient `∂ΔF/∂ρ_k` at the same `t`, length `rho_dim()` — the explicit
182    /// penalty-score channel (a).
183    fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64>;
184    /// Per-row displaced score `∂(D(η̂+s(t))/2φ)/∂η` evaluated at `η̂ + s(t)`
185    /// (length = number of observation rows): the only per-draw ingredient of
186    /// the exact-gradient channels (b)–(d) the assembly side cannot reconstruct.
187    fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64>;
188    /// The same per-row score channel at the undisplaced mode `η̂`.
189    fn base_neg_score(&self) -> Array1<f64>;
190
191    /// Fused `(excess(t), displaced_neg_score(t))`. The returned score is `None`
192    /// exactly when the excess is non-finite (an infeasible draw the sampler
193    /// discards before reading the score). The default preserves the two-call
194    /// behavior; implementors override to share the displacement + jet.
195    fn excess_with_displaced_neg_score(&self, t: &Array1<f64>) -> (f64, Option<Array1<f64>>) {
196        let excess = self.excess(t);
197        if excess.is_finite() {
198            (excess, Some(self.displaced_neg_score(t)))
199        } else {
200            (excess, None)
201        }
202    }
203
204    /// Batched [`Self::excess_with_displaced_neg_score`] over many whitened draws
205    /// (one draw per COLUMN, shape `block_dim() × n_draws`). Batching may only
206    /// change HOW the shared linear algebra is computed (one BLAS-3 product over
207    /// all columns), never WHAT is computed. The default preserves the per-column
208    /// behavior exactly; the GLM implementor overrides it.
209    fn excess_with_displaced_neg_score_batch(
210        &self,
211        draws: &Array2<f64>,
212    ) -> Vec<(f64, Option<Array1<f64>>)> {
213        let n_draws = draws.ncols();
214        let mut out = Vec::with_capacity(n_draws);
215        let mut t = Array1::<f64>::zeros(draws.nrows());
216        for s in 0..n_draws {
217            t.assign(&draws.column(s));
218            out.push(self.excess_with_displaced_neg_score(&t));
219        }
220        out
221    }
222}
223
224// ───────────────────────── injected sampler traits ───────────────────────────
225
226/// The gam-inference-tier sampler for the #784 block-local Laplace correction.
227///
228/// Implemented UP in the monolith over `hmc_io`
229/// (`laplace_directional_cubic_diagnostic` + `block_sampled_marginal_correction`)
230/// and injected DOWN via [`set_laplace_marginal_sampler`]. gam-solve calls
231/// through [`laplace_marginal_sampler`].
232pub trait LaplaceMarginalSampler: Send + Sync {
233    /// Per-direction standardized cubic skewness `γ_r` of the local posterior:
234    /// returns `(max_r |γ_r|, γ)`. Pure eigen-diagnostic (no sampling), but kept
235    /// behind the trait because it lives in the sampler module up-tier.
236    fn directional_cubic_diagnostic(
237        &self,
238        hessian: &Array2<f64>,
239        design: &DesignMatrix,
240        c_weights: &Array1<f64>,
241        refine_supremum: bool,
242    ) -> Result<(f64, Array1<f64>), String>;
243
244    /// Estimate `Δ_b` and its ρ-gradient by importance sampling against the local
245    /// Laplace Gaussian, contracting the caller-supplied [`BlockExcessTarget`].
246    fn block_sampled_marginal_correction(
247        &self,
248        target: &dyn BlockExcessTarget,
249    ) -> Result<BlockSampledMarginal, String>;
250}
251
252/// The gam-inference-tier sampler for the never-fail Gaussian mode posterior
253/// (custom-family covariance escalation). Implemented UP over
254/// `hmc_io::sample_gaussian_mode_posterior` (which auto-derives its
255/// `NutsConfig::for_dimension(mode.len())` internally — that NUTS config never
256/// crosses the contract) and injected DOWN via
257/// [`set_gaussian_mode_posterior_sampler`].
258pub trait GaussianModePosteriorSampler: Send + Sync {
259    /// Sample `N(mode, precision⁻¹)`. `Err` only for a structurally impossible
260    /// request (dimension mismatch, non-PSD precision after symmetrization) —
261    /// never for "did not converge".
262    fn sample_gaussian_mode_posterior(
263        &self,
264        mode: ArrayView1<f64>,
265        precision: ArrayView2<f64>,
266    ) -> Result<GaussianModePosterior, String>;
267}
268
269// ───────────────────────── process-level injection registry ──────────────────
270
271static LAPLACE_MARGINAL_SAMPLER: OnceLock<Box<dyn LaplaceMarginalSampler>> = OnceLock::new();
272static GAUSSIAN_MODE_POSTERIOR_SAMPLER: OnceLock<Box<dyn GaussianModePosteriorSampler>> =
273    OnceLock::new();
274
275/// Register the monolith's `hmc_io`-backed #784 Laplace-correction sampler.
276/// Called once at process init by the gam-inference tier. First writer wins;
277/// a later call is ignored (returns `Err` with the boxed value) so a re-init can
278/// never swap a live sampler mid-run.
279pub fn set_laplace_marginal_sampler(
280    sampler: Box<dyn LaplaceMarginalSampler>,
281) -> Result<(), Box<dyn LaplaceMarginalSampler>> {
282    LAPLACE_MARGINAL_SAMPLER.set(sampler)
283}
284
285/// The registered #784 Laplace-correction sampler, or `None` when the sampler
286/// tier is not linked / not yet initialized (gam-solve then declines the
287/// correction, returning the zero contribution — a safe no-op).
288pub fn laplace_marginal_sampler() -> Option<&'static dyn LaplaceMarginalSampler> {
289    LAPLACE_MARGINAL_SAMPLER.get().map(|b| b.as_ref())
290}
291
292/// Register the monolith's `hmc_io`-backed never-fail Gaussian-mode-posterior
293/// sampler. First writer wins (see [`set_laplace_marginal_sampler`]).
294pub fn set_gaussian_mode_posterior_sampler(
295    sampler: Box<dyn GaussianModePosteriorSampler>,
296) -> Result<(), Box<dyn GaussianModePosteriorSampler>> {
297    GAUSSIAN_MODE_POSTERIOR_SAMPLER.set(sampler)
298}
299
300/// The registered never-fail Gaussian-mode-posterior sampler, or `None` when the
301/// sampler tier is not linked (the custom-family path then retains the
302/// optimizer-conditional covariance — its existing fallback).
303pub fn gaussian_mode_posterior_sampler() -> Option<&'static dyn GaussianModePosteriorSampler> {
304    GAUSSIAN_MODE_POSTERIOR_SAMPLER.get().map(|b| b.as_ref())
305}