Skip to main content

gam_problem/
rho_posterior.rs

1//! `ρ`-posterior certificate / escalation DATA types (contract-down #1521).
2//!
3//! These are the plain-data carriers that a fit result STORES
4//! (`UnifiedFitResult::rho_posterior_{certificate,escalation}`) and that the
5//! gam-solve REML evaluator returns. The COMPUTATION that produces them — the
6//! PSIS certificate, the Tier-1 Gauss-Hermite quadrature, and the Tier-2 NUTS
7//! escalation (which pulls the gam-inference `hmc_io` sampler) — stays UP in the
8//! monolith `inference::rho_posterior`, which re-exports these types so its
9//! construction sites name them unchanged. Contract-downed here (the neutral
10//! criterion-contract crate) so gam-solve can store/return them without a
11//! back-edge into gam-inference.
12
13use ndarray::{Array1, Array2};
14use std::sync::OnceLock;
15
/// How far the Laplace proposal for `ρ` can be trusted, as judged by the
/// Pareto tail-shape `k̂` of the `ρ`-importance weights.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RhoCertificate {
    /// `k̂ < 0.5`: the Laplace proposal is excellent. The plug-in (REML
    /// conditional) intervals together with the first-order `V_ρ` correction
    /// are certified adequate, so no heavier `ρ`-uncertainty treatment is
    /// needed.
    PlugInCertified,
    /// `0.5 ≤ k̂ ≤ 0.7`: the proposal is usable, but moments should be
    /// corrected with the self-normalized importance weights.
    ImportanceCorrect,
    /// `k̂ > 0.7`: the Laplace proposal captures `π(ρ|y)` poorly; escalate to
    /// quadrature (small `K`) or NUTS over `ρ`.
    Escalate,
}

impl RhoCertificate {
    /// Maps a Pareto tail-shape estimate `k_hat` to its reliability tier.
    ///
    /// A non-finite `k_hat` (NaN or ±∞) is treated like a value above `0.7`
    /// and yields [`RhoCertificate::Escalate`]. Both interval endpoints `0.5`
    /// and `0.7` belong to [`RhoCertificate::ImportanceCorrect`].
    pub fn from_k_hat(k_hat: f64) -> Self {
        // Handle NaN/±∞ up front so the comparisons below only ever see
        // ordinary finite values.
        if k_hat.is_nan() || k_hat.is_infinite() {
            return Self::Escalate;
        }
        match k_hat {
            k if k > 0.7 => Self::Escalate,
            k if k < 0.5 => Self::PlugInCertified,
            _ => Self::ImportanceCorrect,
        }
    }
}
43
44/// The Tier-0 `ρ`-uncertainty certificate for a fit.
45#[derive(Debug, Clone)]
46pub struct RhoPosteriorCertificate {
47    /// Pareto tail-shape of the importance weights — the reliability diagnostic.
48    pub k_hat: f64,
49    /// The reliability tier derived from `k_hat`.
50    pub certificate: RhoCertificate,
51    /// Number of proposal draws `M`.
52    pub n_samples: usize,
53    /// Self-normalized importance weights (length `M`), Pareto-smoothed. These
54    /// turn the `M` conditional Gaussians into a free self-normalized mixture
55    /// when the tier is `ImportanceCorrect`.
56    pub weights: Array1<f64>,
57    /// Kish effective sample size `(Σw)² / Σw²` — how many of the `M` draws are
58    /// "really" contributing after importance weighting.
59    pub effective_sample_size: f64,
60}
61
62/// One node of the criterion-closure Tier-1 mixture (#938): a `ρ` location, its
63/// normalized posterior mass, and the exact profiled criterion value there.
64#[derive(Debug, Clone)]
65pub struct RhoMixtureNode {
66    /// Smoothing parameters at this node.
67    pub rho: Array1<f64>,
68    /// Normalized node probability `w_m ∝ exp(−criterion(ρ_m) + criterion(ρ̂)) ×
69    /// GH weight × exp(½‖z_m‖²)`.
70    pub weight: f64,
71    /// Normalized log node probability.
72    pub log_weight: f64,
73    /// Exact profiled criterion value at the node (`+∞` for infeasible nodes,
74    /// which carry zero weight).
75    pub cost: f64,
76}
77
78/// Tier-1 deliverable (#938): `π(ρ|y)` as a discrete mixture of conditional
79/// Gaussians, with the posterior moment summary of `ρ` itself.
80///
81/// The conditional Gaussian at each node is exactly what the engine already
82/// produces at fixed `ρ`; this struct owns the node locations and weights, and
83/// `mixture_coefficient_covariance` (monolith `inference::rho_posterior`)
84/// assembles the mixture-corrected coefficient covariance from per-node
85/// conditionals supplied by the caller.
86#[derive(Debug, Clone)]
87pub struct RhoPosteriorMixture {
88    /// Quadrature nodes with normalized weights (weights sum to 1).
89    pub nodes: Vec<RhoMixtureNode>,
90    /// Posterior mean of `ρ`: `Σ_m w_m ρ_m`.
91    pub mean: Array1<f64>,
92    /// Posterior covariance of `ρ`: `Σ_m w_m (ρ_m−ρ̄)(ρ_m−ρ̄)ᵀ`.
93    pub covariance: Array2<f64>,
94    /// Kish ESS of the node weights `(Σw)²/Σw²` — how non-Gaussian the exact
95    /// posterior is relative to the Laplace proposal (max = node count).
96    pub effective_sample_size: f64,
97}
98
99/// Tier-2 deliverable (#938): `π(ρ|y)` draws from NUTS with the exact profiled
100/// gradient, whitened by the exact outer Hessian at `ρ̂`.
101#[derive(Debug, Clone)]
102pub struct RhoPosteriorSamples {
103    /// Draws in ρ space: `(n_draws, K)`.
104    pub samples: Array2<f64>,
105    /// Posterior mean of `ρ`.
106    pub mean: Array1<f64>,
107    /// Posterior covariance of `ρ` (sample covariance of the draws).
108    pub covariance: Array2<f64>,
109    /// Split-chain R̂ mixing diagnostic.
110    pub rhat: f64,
111    /// Effective sample size.
112    pub ess: f64,
113    /// Whether the chains mixed (R̂ < 1.1).
114    pub converged: bool,
115}
116
117/// The auto-selected escalation outcome when the Tier-0 certificate reads
118/// [`RhoCertificate::Escalate`] (#938): Tier 1 (deterministic quadrature) for
119/// `K ≤ 4`, Tier 2 (NUTS over `ρ`) for `K ≤ 16`, and an HONEST report that
120/// escalation is unavailable beyond that — never a silently-degraded answer.
121#[derive(Debug, Clone)]
122pub enum RhoPosteriorEscalation {
123    /// Tier 1: deterministic Gauss-Hermite mixture (`K ≤ 4`).
124    Quadrature(RhoPosteriorMixture),
125    /// Tier 2: NUTS draws with the exact profiled gradient (`5 ≤ K ≤ 16`).
126    Nuts(RhoPosteriorSamples),
127    /// Escalation could not run (dimension beyond the NUTS cap, or the chosen
128    /// tier failed); intervals remain plug-in + first-order corrected, and the
129    /// fit reports WHY.
130    Unavailable { n_params: usize, reason: String },
131}
132
133// ───────────────────────── injected escalator trait (#1521) ──────────────────
134
135/// The gam-inference-tier producer of the Tier-0 `ρ`-certificate and the
136/// auto-selected Tier-1/Tier-2 escalation (trait-inversion #1521).
137///
138/// The COMPUTATION — the PSIS certificate, the Gauss-Hermite quadrature, and
139/// the Tier-2 NUTS over `ρ` — pulls the gam-inference `hmc_io` sampler, so it
140/// STAYS UP in the monolith `inference::rho_posterior`. That module implements
141/// this trait over its real `rho_posterior_certificate` / `escalate_rho_posterior`
142/// functions and injects the impl DOWN via [`set_rho_posterior_escalator`];
143/// gam-solve's REML evaluator calls THROUGH [`rho_posterior_escalator`]. Only
144/// neutral types (ndarray + the contract-downed `ρ`-posterior carriers) and
145/// caller-supplied criterion closures cross this surface — no gam-inference type
146/// is threaded, so the trait can live in this neutral crate.
147///
148/// When no impl is registered (a build that never links the sampler tier) the
149/// getter returns `None` and gam-solve declines the certificate/escalation
150/// entirely (`(None, None)`), leaving the plug-in + first-order intervals — its
151/// existing decline outcome, no behavioral cliff and no stub.
152pub trait RhoPosteriorEscalator: Send + Sync {
153    /// Tier-0 PSIS `ρ`-certificate. `criterion` evaluates the outer criterion
154    /// `−log π(ρ|y)` at a trial `ρ` (`None` for infeasible `ρ`). Returns `None`
155    /// when the certificate cannot be formed (see the monolith implementation).
156    fn rho_posterior_certificate(
157        &self,
158        rho_hat: &Array1<f64>,
159        outer_hessian: &Array2<f64>,
160        criterion: &dyn Fn(&Array1<f64>) -> Option<f64>,
161        n_samples: Option<usize>,
162    ) -> Option<RhoPosteriorCertificate>;
163
164    /// Auto-selected escalation (Tier-1 quadrature / Tier-2 NUTS / honest
165    /// `Unavailable`). `criterion` returns the exact profiled criterion value,
166    /// `criterion_and_grad` the value plus the exact LAML `ρ`-gradient; both are
167    /// `None` for infeasible `ρ`.
168    fn escalate_rho_posterior(
169        &self,
170        rho_hat: &Array1<f64>,
171        outer_hessian: &Array2<f64>,
172        criterion: &mut dyn FnMut(&Array1<f64>) -> Option<f64>,
173        criterion_and_grad: &mut (dyn FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send),
174    ) -> RhoPosteriorEscalation;
175}
176
177static RHO_POSTERIOR_ESCALATOR: OnceLock<Box<dyn RhoPosteriorEscalator>> = OnceLock::new();
178
179/// Register the monolith's `hmc_io`-backed `ρ`-posterior certificate/escalation
180/// producer. Called once at process init by the gam-inference tier. First writer
181/// wins; a later call is ignored (returns `Err` with the boxed value) so a
182/// re-init can never swap a live producer mid-run.
183pub fn set_rho_posterior_escalator(
184    escalator: Box<dyn RhoPosteriorEscalator>,
185) -> Result<(), Box<dyn RhoPosteriorEscalator>> {
186    RHO_POSTERIOR_ESCALATOR.set(escalator)
187}
188
189/// The registered `ρ`-posterior certificate/escalation producer, or `None` when
190/// the sampler tier is not linked / not yet initialized (gam-solve then declines
191/// the certificate and escalation — a safe no-op leaving plug-in intervals).
192pub fn rho_posterior_escalator() -> Option<&'static dyn RhoPosteriorEscalator> {
193    RHO_POSTERIOR_ESCALATOR.get().map(|b| b.as_ref())
194}