Skip to main content

gam_inference/
hmc_io.rs

1//! NUTS Sampler using general-mcmc
2//!
3//! This module provides NUTS (No-U-Turn Sampler) for honest uncertainty
4//! quantification after PIRLS convergence.
5//!
6//! # Design
7//!
8//! Since general-mcmc's NUTS uses an identity mass matrix, we whiten the
9//! parameter space using the Cholesky decomposition of the inverse Hessian:
10//!
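//! - Transform: β = μ + L @ z, where L L^T = c · H^{-1} (c = σ̂² for the
//!   profiled Gaussian model, c = 1 for every other family)
//! - Near the mode the whitened space has approximately unit covariance, so
//!   NUTS mixes efficiently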
13//! - Samples are un-transformed back to the original space
14//!
15//! # Analytical Gradients
16//!
17//! We override `unnorm_logp_and_grad` to compute gradients analytically using
18//! ndarray, avoiding burn's autodiff overhead. The gradient computation mirrors
19//! the true log-posterior gradient (not the PIRLS working gradient).
20//!
21//! # Memory Efficiency
22//!
23//! Large data (design matrix, response, etc.) is wrapped in `Arc` to allow
24//! sharing across chains without duplication when general-mcmc clones the target.
25
26use crate::gpu_polya_gamma::{PgSeed, PolyaGammaBatchInput};
27use faer::Side;
28use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh, fast_ata_into, fast_atv, fast_av_into};
29use gam_linalg::matrix::DesignMatrix;
30use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
31use gam_models::wiggle::monotone_wiggle_basis_with_derivative_order;
32use gam_problem::types::{
33    InverseLink, LikelihoodSpec, ResponseFamily, RhoPrior, StandardLink, is_valid_tweedie_power,
34};
35use gam_solve::estimate::reml::FirthDenseOperator;
36use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
37use gam_solve::estimate::{
38    EstimationError, UnifiedFitResult, validate_explicit_dense_hessian_for_whitening,
39};
40use gam_solve::mixture_link::{
41    InverseLinkKernel, LinkParamPartials, inverse_link_jet_for_inverse_link, softmax_last_fixedzero,
42};
43use gam_terms::construction::CanonicalPenalty;
44use general_mcmc::generic_hmc::HamiltonianTarget;
45pub use general_mcmc::generic_nuts::NUTSMassMatrixConfig;
46use general_mcmc::generic_nuts::{GenericNUTS, MassMatrixAdaptation};
47use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
48use rand::{RngExt, SeedableRng, rngs::StdRng};
49use serde::{Deserialize, Serialize};
50use std::cell::RefCell;
51use std::fmt;
52use std::sync::{Arc, Mutex};
53
54/// Binomial families whose inverse link has a Fisher-weight jet
55/// (`fisher_weight_jet5`) support the Jeffreys/Firth term. This is the
56/// link-general set shared with the REML/PIRLS Firth operator; the canonical
57/// logit case is unchanged.
58#[inline]
59fn likelihood_spec_supports_firth(spec: &LikelihoodSpec) -> bool {
60    spec.supports_firth()
61}
62
63/// Inverse link to evaluate the Fisher working weight with for the Jeffreys
64/// term. Returns `None` for unsupported specs.
65#[inline]
66fn likelihood_spec_jeffreys_link(spec: &LikelihoodSpec) -> Option<InverseLink> {
67    if likelihood_spec_supports_firth(spec) {
68        Some(spec.link.clone())
69    } else {
70        None
71    }
72}
73
74/// Typed error variants for the HMC / NUTS sampling module.
75///
76/// External-facing helpers in this module continue to return
77/// `Result<_, String>`; this enum is materialized internally and converted
78/// at the public boundary via `.map_err(String::from)` so that the error
79/// text remains byte-identical to the previous `format!` output.
80#[derive(Debug, Clone)]
81pub enum HmcError {
82    /// Sampler state (penalty / Hessian / mode / posterior values) contains
83    /// NaN or Inf where finiteness is required.
84    NonFiniteState { reason: String },
85    /// Configuration value (e.g. `target_accept`, unit-weight requirement)
86    /// is out of range or otherwise invalid.
87    InvalidConfig { reason: String },
88    /// Dimensions of the supplied matrices / vectors are inconsistent.
89    DimensionMismatch { reason: String },
90    /// Firth/Jeffreys correction was requested for a family that does not
91    /// support it.
92    FirthUnsupported { reason: String },
93    /// Inverse-link state does not match the requested likelihood family in
94    /// the joint (β, ρ) sampler.
95    LinkMismatch { reason: String },
96    /// Likelihood family is not implemented in the current sampling path.
97    UnsupportedFamily { reason: String },
98    /// Sampling produced no usable output (empty kept set, non-finite
99    /// summary statistic, etc.).
100    SamplingFailed { reason: String },
101}
102
103impl fmt::Display for HmcError {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match self {
106            HmcError::NonFiniteState { reason }
107            | HmcError::InvalidConfig { reason }
108            | HmcError::DimensionMismatch { reason }
109            | HmcError::FirthUnsupported { reason }
110            | HmcError::LinkMismatch { reason }
111            | HmcError::UnsupportedFamily { reason }
112            | HmcError::SamplingFailed { reason } => f.write_str(reason),
113        }
114    }
115}
116
117impl From<HmcError> for String {
118    fn from(err: HmcError) -> String {
119        err.to_string()
120    }
121}
122
123/// Upper bound on the autocorrelation lag summed in the effective-sample-size
124/// estimate. The Geyer initial-positive-sequence sum normally self-truncates
125/// long before this, but a hard cap bounds the `O(n·lag)` work for very long
126/// chains where the autocorrelation tail is numerical noise.
127const MAX_AUTOCORRELATION_LAG: usize = 1000;
128
129/// Floor on the lag-0 autocovariance (chain variance) used as the denominator in
130/// the autocorrelation ratios, guarding against division by zero for a chain
131/// that is numerically constant.
132const AUTOCOVARIANCE_FLOOR: f64 = 1e-16;
133
134/// Compute split-chain R-hat and ESS using the Gelman-Rubin diagnostic.
135///
136/// This is the standard split-chain formulation (no rank normalization).
137/// Returns (max_rhat, min_ess) across dimensions.
138fn compute_split_rhat_and_ess(samples: &Array3<f64>) -> (f64, f64) {
139    let n_chains = samples.shape()[0];
140    let n_samples = samples.shape()[1];
141    let dim = samples.shape()[2];
142
143    if n_chains < 2 || n_samples < 4 {
144        return (1.0, n_chains as f64 * n_samples as f64 * 0.5);
145    }
146
147    // Split each chain in half to detect non-stationarity
148    let half = n_samples / 2;
149    let n_split_chains = n_chains * 2;
150    let n_split_samples = half;
151
152    let mut max_rhat = 0.0f64;
153    let mut min_ess = f64::INFINITY;
154
155    #[inline]
156    fn splitvalue(
157        samples: &Array3<f64>,
158        n_chains: usize,
159        half: usize,
160        dim: usize,
161        sc: usize,
162        t: usize,
163    ) -> f64 {
164        let chain = sc % n_chains;
165        if sc < n_chains {
166            samples[[chain, t, dim]]
167        } else {
168            samples[[chain, half + t, dim]]
169        }
170    }
171
172    fn ess_from_split_dimension(
173        samples: &Array3<f64>,
174        n_chains: usize,
175        half: usize,
176        dim: usize,
177    ) -> f64 {
178        let m = n_chains * 2;
179        let n = half;
180        if m == 0 || n < 4 {
181            return (m * n).max(1) as f64;
182        }
183
184        let mut means = vec![0.0_f64; m];
185        let mut gamma0 = vec![0.0_f64; m];
186        for sc in 0..m {
187            let mut sum = 0.0;
188            for t in 0..n {
189                sum += splitvalue(samples, n_chains, half, dim, sc, t);
190            }
191            let mean = sum / n as f64;
192            means[sc] = mean;
193            let mut g0 = 0.0;
194            for t in 0..n {
195                let d = splitvalue(samples, n_chains, half, dim, sc, t) - mean;
196                g0 += d * d;
197            }
198            gamma0[sc] = (g0 / n as f64).max(AUTOCOVARIANCE_FLOOR);
199        }
200
201        let max_lag = (n - 1).min(MAX_AUTOCORRELATION_LAG);
202        let mut tau = 1.0_f64;
203        let mut lag = 1usize;
204        while lag < max_lag {
205            let mut pair = 0.0_f64;
206            for l in [lag, lag + 1] {
207                if l > max_lag {
208                    continue;
209                }
210                let mut rho_l = 0.0;
211                for sc in 0..m {
212                    let mu = means[sc];
213                    let mut cov = 0.0;
214                    let denom = (n - l) as f64;
215                    for t in 0..(n - l) {
216                        let x0 = splitvalue(samples, n_chains, half, dim, sc, t);
217                        let x1 = splitvalue(samples, n_chains, half, dim, sc, t + l);
218                        cov += (x0 - mu) * (x1 - mu);
219                    }
220                    cov /= denom;
221                    rho_l += cov / gamma0[sc];
222                }
223                rho_l /= m as f64;
224                pair += rho_l;
225            }
226            if !pair.is_finite() || pair <= 0.0 {
227                break;
228            }
229            tau += 2.0 * pair;
230            lag += 2;
231        }
232        if !tau.is_finite() || tau <= 0.0 {
233            return 1.0;
234        }
235        let total = (m * n) as f64;
236        (total / tau).clamp(1.0, total)
237    }
238
239    let mut chain_means = vec![0.0_f64; n_split_chains];
240    let mut chainvars = vec![0.0_f64; n_split_chains];
241    for d in 0..dim {
242        for chain in 0..n_chains {
243            // First half
244            let mut sum1 = 0.0;
245            for i in 0..half {
246                sum1 += samples[[chain, i, d]];
247            }
248            let mean1 = sum1 / half as f64;
249            let mut var1 = 0.0;
250            for i in 0..half {
251                let diff = samples[[chain, i, d]] - mean1;
252                var1 += diff * diff;
253            }
254            var1 /= (half - 1).max(1) as f64;
255            let first_idx = chain;
256            chain_means[first_idx] = mean1;
257            chainvars[first_idx] = var1;
258
259            // Second half
260            let mut sum2 = 0.0;
261            for i in half..(2 * half) {
262                sum2 += samples[[chain, i, d]];
263            }
264            let mean2 = sum2 / half as f64;
265            let mut var2 = 0.0;
266            for i in half..(2 * half) {
267                let diff = samples[[chain, i, d]] - mean2;
268                var2 += diff * diff;
269            }
270            var2 /= (half - 1).max(1) as f64;
271            let second_idx = n_chains + chain;
272            chain_means[second_idx] = mean2;
273            chainvars[second_idx] = var2;
274        }
275
276        // Within-chain variance W
277        let w: f64 = chainvars.iter().copied().sum::<f64>() / n_split_chains as f64;
278
279        // Between-chain variance B
280        let overall_mean: f64 = chain_means.iter().copied().sum::<f64>() / n_split_chains as f64;
281        let b: f64 = chain_means
282            .iter()
283            .map(|m| (m - overall_mean).powi(2))
284            .sum::<f64>()
285            * n_split_samples as f64
286            / (n_split_chains - 1) as f64;
287
288        // Estimated variance
289        let var_hat = (n_split_samples as f64 - 1.0) / n_split_samples as f64 * w
290            + b / n_split_samples as f64;
291
292        // R-hat
293        let rhat_d = if w > 1e-10 { (var_hat / w).sqrt() } else { 1.0 };
294        max_rhat = max_rhat.max(rhat_d);
295
296        // Real ESS via split-chain autocorrelation with Geyer IPS truncation.
297        let ess_d = ess_from_split_dimension(samples, n_chains, half, d);
298        min_ess = min_ess.min(ess_d);
299    }
300
301    (max_rhat, min_ess.max(1.0))
302}
303
304/// Solve L^T * X = I where L is lower triangular.
305///
306/// Returns X = L^{-T} (the inverse transpose of L).
307///
308/// This is the correct way to compute the whitening transform matrix:
309/// Given H = L L^T (Cholesky), we need W where W W^T = H^{-1}
310/// Since H^{-1} = L^{-T} L^{-1}, we have W = L^{-T}.
311///
312/// Implementation strategy (math-equivalent to back-substitution on L^T):
313/// We compute L^{-1} column-wise via forward substitution on L, then the
314/// result is `L^{-1}` transposed. Forward-substituting column `c` of L^{-1}
315/// uses `L`'s rows (which are contiguous in row-major `Array2`), giving
316/// stride-1 inner loops instead of the strided `l[[j, i]]` (column-major
317/// access pattern) and double-indexed writes of the original. We also
318/// exploit the triangular structure of `L^{-1}` (entries above the diagonal
319/// are zero), skipping ~half of the inner work compared to the previous
320/// version which traversed `i = (0..dim).rev()` for every column.
321///
322/// Total cost: ~dim^3 / 6 multiply-adds (down from dim^3 / 2), with all
323/// inner loops on contiguous slices.
324fn solve_upper_triangular_transpose(l: &Array2<f64>, dim: usize) -> Array2<f64> {
325    let mut result = Array2::<f64>::zeros((dim, dim));
326    if dim == 0 {
327        return result;
328    }
329
330    // Pull contiguous row slice access from L (row-major standard layout).
331    // Falls back to a one-time owned copy if `l` is not standard-layout
332    // (e.g. a transposed view); both branches feed the same inner loop.
333    let l_owned;
334    let l_rows: &[f64] = if let Some(s) = l.as_slice() {
335        s
336    } else {
337        l_owned = l.to_owned();
338        l_owned
339            .as_slice()
340            .expect("owned standard-layout Array2 has contiguous storage")
341    };
342
343    // Scratch column for L^{-1}[:, col]; reused across columns.
344    let mut y = vec![0.0_f64; dim];
345
346    for col in 0..dim {
347        // Forward-substitute L * y = e_col. y[i] = 0 for i < col.
348        // Diagonal term:
349        let d_col = l_rows[col * dim + col];
350        let inv_d_col = if d_col.abs() > 1e-15 {
351            1.0 / d_col
352        } else {
353            0.0
354        };
355        y[col] = inv_d_col;
356
357        // Below-diagonal entries: y[i] = -(sum_{j=col..i} L[i,j] * y[j]) / L[i,i].
358        // Each inner loop is a stride-1 dot product on row `i` of L (contiguous).
359        for i in (col + 1)..dim {
360            let row_off = i * dim;
361            let l_row = &l_rows[row_off + col..row_off + i];
362            let y_seg = &y[col..i];
363            // Both operands are contiguous slices of equal length; the loop
364            // is a straight-line stride-1 reduction the optimizer can
365            // auto-vectorize.
366            let mut sum = 0.0_f64;
367            for k in 0..l_row.len() {
368                sum += l_row[k] * y_seg[k];
369            }
370            let d = l_rows[row_off + i];
371            y[i] = if d.abs() > 1e-15 { -sum / d } else { 0.0 };
372        }
373
374        // Write the column into result transposed: result[col, i] = y[i] for i >= col.
375        // result[i, col] is left at zero for i < col (upper-triangular L^{-T}).
376        // That matches `result[col, i]` filling row `col` from column `col` rightward.
377        let res_row_start = col * dim + col;
378        let res_row = &mut result.as_slice_mut().expect("owned Array2 contiguous")
379            [res_row_start..res_row_start + (dim - col)];
380        for (k, slot) in res_row.iter_mut().enumerate() {
381            *slot = y[col + k];
382        }
383
384        // Clear scratch positions we wrote, so the next column starts clean above.
385        for slot in &mut y[col..dim] {
386            *slot = 0.0;
387        }
388    }
389
390    result
391}
392
393struct WhiteningTransform {
394    chol: Array2<f64>,
395    chol_t: Array2<f64>,
396}
397
398fn hessian_whitening_transform(
399    hessian: ArrayView2<f64>,
400    dim: usize,
401    cov_scale: f64,
402    cholesky_error_prefix: &str,
403) -> Result<WhiteningTransform, String> {
404    let hessian_owned = hessian.to_owned();
405    let chol_factor = hessian_owned
406        .cholesky(Side::Lower)
407        .map_err(|e| format!("{cholesky_error_prefix}: {:?}", e))?;
408    let l_h = chol_factor.lower_triangular();
409    let mut chol = solve_upper_triangular_transpose(&l_h, dim);
410    let sqrt_cov_scale = cov_scale.max(0.0).sqrt();
411    if (sqrt_cov_scale - 1.0).abs() > 0.0 {
412        chol.mapv_inplace(|v| v * sqrt_cov_scale);
413    }
414    let chol_t = chol.t().to_owned();
415    Ok(WhiteningTransform { chol, chol_t })
416}
417
418/// Shared data for NUTS posterior (wrapped in Arc to prevent cloning).
419///
420/// This struct holds read-only data that is shared across all chains.
421/// Using Arc prevents memory explosion when general-mcmc clones the target.
422#[derive(Clone)]
423struct SharedData {
424    /// Design matrix X [n_samples, dim]
425    x: Arc<Array2<f64>>,
426    /// Response vector y [n_samples]
427    y: Arc<Array1<f64>>,
428    /// Observation/case weights [n_samples]
429    weights: Arc<Array1<f64>>,
430    /// MAP estimate (mode) μ [dim]
431    mode: Arc<Array1<f64>>,
432    /// Fixed additive offset on the linear predictor: η = Xβ + offset
433    /// [n_samples]. `None` when the model was fit without an offset (the common
434    /// case), avoiding a per-step O(n) add of zeros. The offset shifts η only —
435    /// it is constant in β, so ∂η/∂β = X is unchanged and no gradient,
436    /// Hessian, or penalty term is affected. Dropping it (the historical
437    /// behaviour) silently sampled the wrong posterior for any `--offset-column`
438    /// fit (#882).
439    offset: Option<Arc<Array1<f64>>>,
440    /// Auxiliary log-link family parameter: Gamma shape, Tweedie power, or NB theta.
441    gamma_shape: f64,
442    /// Dispersion parameter φ (Gaussian: σ²; Gamma: 1/shape; `Known(1.0)` for
443    /// fixed-scale families). Consumed **only** by the likelihood adapters that
444    /// carry the dispersion in the data term itself: the profiled-Gaussian
445    /// log-likelihood and its gradient multiply through by `1/φ`, and the
446    /// Tweedie quasi-likelihood folds `1/φ` into its weight. It does NOT drive
447    /// the whitening or penalty scaling — those use the `cov_scale` invariant
448    /// (`NutsFamily::coefficient_covariance_scale`), which is `1.0` for Gamma
449    /// even though `φ ≠ 1`, because Gamma's dispersion already lives inside the
450    /// working weight (the `shape` factor in `gamma_log_logp_and_grad`). See
451    /// `inference::dispersion_cov` for the ownership invariants.
452    dispersion: gam_solve::model_types::Dispersion,
453    /// Number of samples
454    n_samples: usize,
455    /// Number of coefficients
456    dim: usize,
457}
458
459thread_local! {
460    static NUTS_RESIDUAL_SCRATCH: RefCell<Array1<f64>> = RefCell::new(Array1::zeros(0));
461}
462
463/// Whitened log-posterior target with analytical gradients.
464///
465/// Uses Arc for shared data to prevent memory explosion when cloned for chains.
466/// Uses faer for numerically stable Cholesky decomposition.
467/// Family mode for NUTS log-likelihood computation.
468#[derive(Debug, Clone, Copy, PartialEq, Eq)]
469pub enum NutsFamily {
470    Gaussian,
471    BinomialLogit,
472    BinomialProbit,
473    BinomialCLogLog,
474    PoissonLog,
475    TweedieLog,
476    NegativeBinomialLog,
477    GammaLog,
478}
479
480impl NutsFamily {
481    #[inline]
482    fn likelihood_spec(self) -> LikelihoodSpec {
483        match self {
484            Self::Gaussian => LikelihoodSpec {
485                response: ResponseFamily::Gaussian,
486                link: InverseLink::Standard(StandardLink::Identity),
487            },
488            Self::BinomialLogit => LikelihoodSpec {
489                response: ResponseFamily::Binomial,
490                link: InverseLink::Standard(StandardLink::Logit),
491            },
492            Self::BinomialProbit => LikelihoodSpec {
493                response: ResponseFamily::Binomial,
494                link: InverseLink::Standard(StandardLink::Probit),
495            },
496            Self::BinomialCLogLog => LikelihoodSpec {
497                response: ResponseFamily::Binomial,
498                link: InverseLink::Standard(StandardLink::CLogLog),
499            },
500            Self::PoissonLog => LikelihoodSpec {
501                response: ResponseFamily::Poisson,
502                link: InverseLink::Standard(StandardLink::Log),
503            },
504            Self::TweedieLog => LikelihoodSpec {
505                response: ResponseFamily::Tweedie { p: 1.5 },
506                link: InverseLink::Standard(StandardLink::Log),
507            },
508            Self::NegativeBinomialLog => LikelihoodSpec {
509                response: ResponseFamily::NegativeBinomial {
510                    theta: 1.0,
511                    theta_fixed: false,
512                },
513                link: InverseLink::Standard(StandardLink::Log),
514            },
515            Self::GammaLog => LikelihoodSpec {
516                response: ResponseFamily::Gamma,
517                link: InverseLink::Standard(StandardLink::Log),
518            },
519        }
520    }
521
522    /// Coefficient-covariance scale for the whitened NUTS target — the
523    /// NUTS-family counterpart of
524    /// [`gam_problem::types::GlmLikelihoodSpec::coefficient_covariance_scale`] (#679).
525    ///
526    /// The sampler must reproduce the posterior `N(mode, Vb)` with
527    /// `Vb = scale · H⁻¹`, where `H = XᵀWX + S_λ` is the stored penalized
528    /// Hessian (penalty `S_λ` added **unscaled**). The returned `scale` is:
529    ///
530    /// * `profiled_gaussian_phi` (= σ̂²) for the **profiled Gaussian** identity
531    ///   model, whose working weight is scale-free (`W = priorweights`), so the
532    ///   stored `H` omits the dispersion and `Vb = σ̂²·H⁻¹`. The NUTS Gaussian
533    ///   log-likelihood is structurally this profiled form: scale-free
534    ///   residuals multiplied by `1/φ` (see `gaussian_logp_and_grad_into`).
535    /// * `1.0` for every **weight-carries-dispersion** family (Gamma, Tweedie,
536    ///   Negative-Binomial, and the fixed-scale Poisson/Binomial). Their
537    ///   working weight already folds in the reciprocal dispersion / full
538    ///   Fisher information — for Gamma-log the shape `ν = 1/φ` is baked into
539    ///   the likelihood score `∂ℓ/∂η = ν·(y/μ − 1)` — so the stored `H` is
540    ///   already the true penalized Hessian and `Vb = H⁻¹`. Multiplying by the
541    ///   dispersion again double-counts it and shrinks every posterior SD by
542    ///   `√dispersion` — exactly the Gamma-log defect addressed in #680.
543    ///
544    /// This single scalar governs BOTH the whitening preconditioner
545    /// (`L Lᵀ = scale·H⁻¹`, so `L` is scaled by `√scale`) and the target's
546    /// penalty weight (`penalty_scale = 1/scale`), keeping the sampled
547    /// posterior, its whitening metric, and the Wald `Vb` of #679 mutually
548    /// consistent. Crucially it does NOT key off the statistical dispersion
549    /// `φ`: Gamma carries `φ = 1/shape ≠ 1` yet still has `scale = 1`, because
550    /// that `φ` already lives inside `W`.
551    #[inline]
552    fn coefficient_covariance_scale(self, profiled_gaussian_phi: f64) -> f64 {
553        match self {
554            NutsFamily::Gaussian => profiled_gaussian_phi,
555            _ => 1.0,
556        }
557    }
558}
559
560/// Whitened-coordinate target for the No-U-Turn HMC sampler.
561///
562/// The posterior over β is reparameterized via `β = L z` where `L Lᵀ = H⁻¹`
563/// (Cholesky factor of the inverse posterior Hessian at the MAP), so that
564/// in `z`-coordinates the local curvature is approximately the identity.
565/// The struct holds the shared design, the whitening factor `L` and its
566/// transpose (for gradient chain-rule pull-back `∇_z = Lᵀ ∇_β`), the
567/// family-specific log-likelihood adapter, and a precomputed
568/// `M = Lᵀ S L` so the smoothing penalty `−½ βᵀSβ` becomes the cheap
569/// quadratic `−½ zᵀMz` inside the leapfrog hot loop.  Optionally adds
570/// the identifiable-subspace Firth/Jeffreys term to keep posterior modes
571/// away from infinity under separation.
572pub struct NutsPosterior {
573    /// Shared read-only data (Arc prevents duplication)
574    data: SharedData,
575    /// Transform: L where L L^T = H^{-1} (computed from Hessian)
576    /// This is the inverse-transpose of the Cholesky of H.
577    chol: Array2<f64>,
578    /// L^T for gradient chain rule: ∇z = L^T @ ∇_β
579    chol_t: Array2<f64>,
580    /// Family for log-likelihood computation
581    nuts_family: NutsFamily,
582    /// Whether to add the identifiable-subspace Jeffreys/Firth term to the
583    /// target
584    firth_enabled: bool,
585    /// Precomputed whitened-penalty operator `M = L^T S L` (dim×dim, symmetric
586    /// positive-semidefinite). The penalty term in z-coordinates is
587    ///   −0.5 βᵀSβ = −[c0 + (Lᵀ S μ)ᵀ z + 0.5 zᵀ M z],
588    /// so its z-gradient is just `−(L^T S μ + M z)` — no per-step `S·β` matvec
589    /// or `L^T·∇_β penalty` map is needed.
590    penalty_z_quad: Array2<f64>,
591    /// Precomputed `Lᵀ S μ` (length dim) — z-space gradient contribution from
592    /// the linear-in-z portion of the penalty.
593    penalty_z_lin: Array1<f64>,
594    /// Precomputed `0.5 μᵀ S μ` (scalar) — constant term of the penalty.
595    penalty_z_const: f64,
596    /// Coefficient-covariance scale `cov_scale` (#679/#680 invariant): the
597    /// `Vb = cov_scale·H⁻¹` multiplier. `σ̂²` for profiled Gaussian, `1.0` for
598    /// every weight-carries-dispersion family. Drives both the whitening
599    /// (`L Lᵀ = cov_scale·H⁻¹`) and the target penalty weight
600    /// (`penalty_scale = 1/cov_scale`).
601    cov_scale: f64,
602}
603
604impl NutsPosterior {
605    /// Creates a new posterior target from ndarray data.
606    ///
607    /// # Arguments
608    /// * `x` - Design matrix [n_samples, dim]
609    /// * `y` - Response vector [n_samples]
610    /// * `weights` - Observation/case weights [n_samples]
611    /// * `penalty_matrix` - Combined penalty S [dim, dim]
612    /// * `mode` - MAP estimate μ [dim]
613    /// * `hessian` - Hessian H [dim, dim] (NOT the inverse!)
614    /// * `nuts_family` - Family for log-likelihood computation
615    ///
616    /// # Numerical Stability
617    /// Accepts the Hessian directly and computes L = (chol(H))^{-T} via
618    /// triangular solves, which is more stable than explicitly inverting H.
619    pub fn new(
620        x: ArrayView2<f64>,
621        y: ArrayView1<f64>,
622        weights: ArrayView1<f64>,
623        penalty_matrix: ArrayView2<f64>,
624        mode: ArrayView1<f64>,
625        hessian: ArrayView2<f64>,
626        nuts_family: NutsFamily,
627        gamma_shape: f64,
628        dispersion: gam_solve::model_types::Dispersion,
629        firth_enabled: bool,
630    ) -> Result<Self, String> {
631        let n_samples = x.nrows();
632        let dim = x.ncols();
633
634        // Validate inputs are finite
635        if !penalty_matrix.iter().all(|x| x.is_finite()) {
636            return Err(HmcError::NonFiniteState {
637                reason: "Penalty matrix contains NaN or Inf values".to_string(),
638            }
639            .into());
640        }
641        if !hessian.iter().all(|x| x.is_finite()) {
642            return Err(HmcError::NonFiniteState {
643                reason: "Hessian matrix contains NaN or Inf values".to_string(),
644            }
645            .into());
646        }
647        if !mode.iter().all(|x| x.is_finite()) {
648            return Err(HmcError::NonFiniteState {
649                reason: "Mode vector contains NaN or Inf values".to_string(),
650            }
651            .into());
652        }
653
654        validate_firth_support(nuts_family, firth_enabled).map_err(String::from)?;
655        if nuts_family.likelihood_spec().is_binomial() {
656            validate_binary_responses("binomial NUTS", &y, &weights).map_err(String::from)?;
657        }
658        if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
659            validate_count_responses("negative-binomial NUTS", &y, &weights)
660                .map_err(String::from)?;
661        }
662
663        // Whitening metric: `L Lᵀ` must equal the posterior covariance the
664        // sampler reproduces, `Vb = cov_scale · H⁻¹` (#679/#680 invariant), so
665        // scale `L` by `√cov_scale`. Only the profiled-Gaussian model carries a
666        // non-unit scale (σ̂² = `dispersion.phi()`); every weight-carries-
667        // dispersion family (Gamma/Tweedie/NB) already folds its dispersion into
668        // the stored `H`, so `cov_scale == 1` and this is a no-op. This replaces
669        // a previous `sqrt_phi()` multiply that wrongly scaled Gamma (and any
670        // φ-bearing family) by `√φ`, mis-preconditioning against `φ·H⁻¹`.
671        let cov_scale = nuts_family.coefficient_covariance_scale(dispersion.phi());
672        let whitening = hessian_whitening_transform(
673            hessian,
674            dim,
675            cov_scale,
676            "Hessian Cholesky decomposition failed",
677        )?;
678        let chol = whitening.chol;
679        let chol_t = whitening.chol_t;
680
681        // Precompute the whitened penalty operator and constants so that the
682        // penalty contribution to logp/grad becomes a single symv against z.
683        // Math identity (β = μ + L z, L L^T = H^{-1}):
684        //   0.5 β^T S β = 0.5 μ^T S μ + (L^T S μ)^T z + 0.5 z^T (L^T S L) z
685        // and ∇_z [0.5 β^T S β] = L^T S μ + (L^T S L) z.
686        // This replaces three matvecs per leapfrog step (S·β, L·z used only
687        // for that purpose, and L^T·∇_β penalty) with one dim×dim symv.
688        let penalty_owned = penalty_matrix.to_owned();
689        let mode_owned = mode.to_owned();
690        let s_mu = penalty_owned.dot(&mode_owned);
691        let penalty_z_const = 0.5 * mode_owned.dot(&s_mu);
692        let penalty_z_lin = chol_t.dot(&s_mu);
693        // M = L^T S L = chol_t · (S · chol). Computed in two GEMMs at
694        // construction time only.
695        let s_chol = penalty_owned.dot(&chol);
696        let penalty_z_quad = chol_t.dot(&s_chol);
697
698        let data = SharedData {
699            x: Arc::new(x.to_owned()),
700            y: Arc::new(y.to_owned()),
701            weights: Arc::new(weights.to_owned()),
702            mode: Arc::new(mode_owned),
703            offset: None,
704            gamma_shape,
705            dispersion,
706            n_samples,
707            dim,
708        };
709
710        Ok(Self {
711            data,
712            chol,
713            chol_t,
714            nuts_family,
715            firth_enabled,
716            penalty_z_quad,
717            penalty_z_lin,
718            penalty_z_const,
719            cov_scale,
720        })
721    }
722
723    /// Attach a fixed additive offset to the linear predictor: η = Xβ + offset.
724    ///
725    /// The offset is constant in β, so the whitening geometry (`chol`), penalty
726    /// operators, and stored Hessian are all unchanged — only η (and hence the
727    /// per-observation working residual / mean) shifts. The fitted `mode` and
728    /// `hessian` handed to [`Self::new`] already correspond to the offset-trained
729    /// fit, so this only needs to restore the offset to the likelihood
730    /// evaluation. Returns an error if the offset length disagrees with the data
731    /// or carries non-finite entries.
732    fn with_offset(mut self, offset: ArrayView1<f64>) -> Result<Self, String> {
733        if offset.len() != self.data.n_samples {
734            return Err(HmcError::DimensionMismatch {
735                reason: format!(
736                    "NUTS offset length {} does not match {} observations",
737                    offset.len(),
738                    self.data.n_samples
739                ),
740            }
741            .into());
742        }
743        if !offset.iter().all(|v| v.is_finite()) {
744            return Err(HmcError::NonFiniteState {
745                reason: "NUTS offset contains NaN or Inf values".to_string(),
746            }
747            .into());
748        }
749        self.data.offset = Some(Arc::new(offset.to_owned()));
750        Ok(self)
751    }
752
753    fn compute_logp_and_grad_nd_into(
754        &self,
755        z: &Array1<f64>,
756        residual: &mut Array1<f64>,
757        grad: &mut Array1<f64>,
758    ) -> f64 {
759        // === Step 1: Transform z (whitened) -> β (original) ===
760        // β = μ + L @ z
761        let beta = self.data.mode.as_ref() + &self.chol.dot(z);
762
763        // === Step 2: Compute η = X @ β (+ offset) ===
764        let mut eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
765        if let Some(offset) = self.data.offset.as_ref() {
766            eta += offset.as_ref();
767        }
768
769        // === Step 3: Compute log-likelihood and gradient ===
770        let (ll, mut grad_ll_beta) = self.family_logp_and_grad_into(&eta, residual);
771
772        let mut firth_logdet = 0.0;
773        if self.firth_enabled {
774            match firth_jeffreys_logp_and_grad(self.nuts_family, &self.data, &eta) {
775                Ok((value, grad_beta_firth)) => {
776                    firth_logdet = value;
777                    grad_ll_beta += &grad_beta_firth;
778                }
779                Err(err) => {
780                    log::warn!(
781                        "[NUTS/Firth] Jeffreys target became invalid at the current state: {}",
782                        err
783                    );
784                    grad.fill(0.0);
785                    return f64::NEG_INFINITY;
786                }
787            }
788        }
789
790        // === Step 4: Penalty in z-coordinates (precomputed; see `new`) ===
791        //   −0.5 βᵀ S β  =  −[c0 + lᵀ z + 0.5 zᵀ M z]
792        //   ∇_z (−0.5 βᵀ S β) = −(l + M z)
793        // where l = L^T S μ, M = L^T S L, c0 = 0.5 μᵀ S μ.
794        // This single dim×dim symmetric matvec replaces both the per-step
795        // S·β multiply and the L^T·∇_β penalty chain-rule multiply, and lets
796        // the penalty value, β-gradient and chain rule fuse into one pass.
797        //
798        // Penalty weight in the un-whitened β-target
799        // `log p(β) = loglik(β) − penalty_scale · ½ βᵀSβ`. The invariant is
800        // `Vb = cov_scale · H⁻¹` with `H = XᵀWX + S` (penalty added unscaled),
801        // so the target curvature must equal `Vb⁻¹ = H/cov_scale`. The
802        // likelihood already supplies `−∇²ℓ = (data Fisher info)/cov_scale`
803        // (explicitly `/σ²` for profiled Gaussian, implicitly via the working
804        // weight / the `shape ≡ 1/φ` baked into `gamma_log_logp_and_grad` for
805        // the dispersion-carrying families), so the penalty must match it:
806        //   penalty_scale = 1/cov_scale.
807        // That is `1/σ²` for profiled Gaussian and exactly `1.0` for
808        // Gamma/Tweedie/NB/Poisson/Binomial. The previous code used
809        // `dispersion.inv_phi()` for GammaLog (= shape = 1/φ ≠ 1), which
810        // double-counted the dispersion in the sampled posterior (#680); the
811        // statistical dispersion `φ` is NOT `1/cov_scale` for Gamma because it
812        // already lives inside `W`. Mirrors `LinkWigglePosterior`.
813        let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
814        let mz = self.penalty_z_quad.dot(z);
815        let lin_term = self.penalty_z_lin.dot(z);
816        let quad_term = 0.5 * z.dot(&mz);
817        let penalty = penalty_scale * (self.penalty_z_const + lin_term + quad_term);
818
819        // === Step 5: z-space gradient ===
820        // ∇z log p = L^T ∇_β ℓ  −  penalty_scale · (l + M z)
821        fast_av_into(&self.chol_t, &grad_ll_beta, grad);
822        // gradz -= penalty_scale · (penalty_z_lin + M z); fused parallel update.
823        let lin_view = self.penalty_z_lin.view();
824        ndarray::Zip::from(grad)
825            .and(&lin_view)
826            .and(&mz)
827            .par_for_each(|g, &l, &m| {
828                *g -= penalty_scale * (l + m);
829            });
830
831        ll + firth_logdet - penalty
832    }
833
834    fn family_logp_and_grad_into(
835        &self,
836        eta: &Array1<f64>,
837        residual: &mut Array1<f64>,
838    ) -> (f64, Array1<f64>) {
839        nuts_family_logp_and_grad_into(self.nuts_family, &self.data, eta, residual)
840    }
841
842    /// Get the Cholesky factor L for un-whitening samples
843    pub fn chol(&self) -> &Array2<f64> {
844        &self.chol
845    }
846
847    /// Get the mode
848    pub fn mode(&self) -> &Array1<f64> {
849        &self.data.mode
850    }
851
852    /// Get dimension
853    pub fn dim(&self) -> usize {
854        self.data.dim
855    }
856}
857
/// ½·ln(2π): the normalising constant of the standard normal log-density.
const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_7;

/// Standard normal log-density, ln φ(x) = −x²/2 − ½·ln(2π).
#[inline]
fn standard_normal_log_pdf(x: f64) -> f64 {
    let half_square = 0.5 * x * x;
    -(half_square + HALF_LOG_2PI)
}
864
865/// Stable log Φ(x) for the standard normal CDF.
866#[inline]
867fn log_ndtr(x: f64) -> f64 {
868    let arg = -x * std::f64::consts::FRAC_1_SQRT_2;
869    let erfc_val = statrs::function::erf::erfc(arg);
870    if erfc_val > 0.0 {
871        erfc_val.ln() - std::f64::consts::LN_2
872    } else {
873        -0.5 * x * x - (-x).ln() - HALF_LOG_2PI
874    }
875}
876
877#[inline]
878fn validate_firth_support(family: NutsFamily, firth_enabled: bool) -> Result<(), HmcError> {
879    let spec = family.likelihood_spec();
880    if firth_enabled && !likelihood_spec_supports_firth(&spec) {
881        return Err(HmcError::FirthUnsupported {
882            reason: format!(
883                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
884                spec.pretty_name()
885            ),
886        });
887    }
888    Ok::<(), _>(())
889}
890
891#[inline]
892fn validate_firth_likelihood_support(
893    likelihood: &LikelihoodSpec,
894    firth_enabled: bool,
895) -> Result<(), HmcError> {
896    if firth_enabled && !likelihood_spec_supports_firth(likelihood) {
897        return Err(HmcError::FirthUnsupported {
898            reason: format!(
899                "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
900                likelihood.pretty_name()
901            ),
902        });
903    }
904    Ok::<(), _>(())
905}
906
/// True when `y` is usable as a count: finite, non-negative, and within
/// 1e-9 of an integer.
#[inline]
fn valid_count_response(y: f64) -> bool {
    if !y.is_finite() || y < 0.0 {
        return false;
    }
    let distance_to_integer = (y - y.round()).abs();
    distance_to_integer <= 1e-9
}
911
912fn validate_count_responses(
913    family: &str,
914    y: &ArrayView1<'_, f64>,
915    weights: &ArrayView1<'_, f64>,
916) -> Result<(), HmcError> {
917    for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
918        if wi > 0.0 && !valid_count_response(yi) {
919            return Err(HmcError::InvalidConfig {
920                reason: format!(
921                    "{family} response must be a finite non-negative integer at positive-weight row {i}; got {yi}"
922                ),
923            });
924        }
925    }
926    Ok(())
927}
928
929fn validate_binary_responses(
930    family: &str,
931    y: &ArrayView1<'_, f64>,
932    weights: &ArrayView1<'_, f64>,
933) -> Result<(), HmcError> {
934    for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
935        if wi > 0.0 && !(yi == 0.0 || yi == 1.0) {
936            return Err(HmcError::InvalidConfig {
937                reason: format!(
938                    "{family} response must be exactly 0 or 1 at positive-weight row {i}; got {yi}"
939                ),
940            });
941        }
942    }
943    Ok(())
944}
945
946/// Compute the identifiable-subspace Jeffreys/Firth contribution and its
947/// β-gradient.
948///
949/// HMC uses the same `FirthDenseOperator` as the REML exact-gradient path.
950/// The operator owns the reduced identifiable Fisher factorization, the
951/// Jeffreys log-determinant, and the analytic β-gradient.
952fn firth_jeffreys_logp_and_grad(
953    family: NutsFamily,
954    data: &SharedData,
955    eta: &Array1<f64>,
956) -> Result<(f64, Array1<f64>), HmcError> {
957    if eta.len() != data.n_samples {
958        return Err(HmcError::DimensionMismatch {
959            reason: format!(
960                "Firth Jeffreys term eta length {} != number of samples {}",
961                eta.len(),
962                data.n_samples
963            ),
964        });
965    }
966    if data.dim == 0 || data.n_samples == 0 {
967        return Ok((0.0, Array1::zeros(data.dim)));
968    }
969    validate_firth_support(family, true)?;
970    if data.weights.iter().all(|w| *w == 0.0) {
971        return Ok((0.0, Array1::zeros(data.dim)));
972    }
973
974    let jeffreys_link =
975        likelihood_spec_jeffreys_link(&family.likelihood_spec()).ok_or_else(|| {
976            HmcError::FirthUnsupported {
977                reason: format!(
978                    "Firth Jeffreys term has no Fisher-weight jet for {}",
979                    family.likelihood_spec().pretty_name()
980                ),
981            }
982        })?;
983    let op = if data.weights.iter().all(|&w| w == 1.0) {
984        FirthDenseOperator::build_for_link(&jeffreys_link, data.x.as_ref(), eta)
985    } else {
986        FirthDenseOperator::build_with_observation_weights_for_link(
987            &jeffreys_link,
988            data.x.as_ref(),
989            eta,
990            data.weights.view(),
991        )
992    }
993    .map_err(|e| HmcError::SamplingFailed {
994        reason: format!("Firth Jeffreys operator failed: {e}"),
995    })?;
996    Ok(op.jeffreys_logdet_and_beta_gradient())
997}
998
999// ============================================================================
1000// Shared family log-likelihood helpers
1001// ============================================================================
1002//
1003// Freestanding functions for computing ℓ(y|β) and ∇_β ℓ for each supported
1004// family. Used by both `NutsPosterior` (fixed-ρ β-only sampling) and
1005// `JointBetaRhoPosterior` (joint β+ρ sampling).
1006
1007fn nuts_family_logp_and_grad_into(
1008    family: NutsFamily,
1009    data: &SharedData,
1010    eta: &Array1<f64>,
1011    residual: &mut Array1<f64>,
1012) -> (f64, Array1<f64>) {
1013    match family {
1014        NutsFamily::BinomialLogit => logit_logp_and_grad_into(data, eta, residual),
1015        NutsFamily::BinomialProbit => probit_logp_and_grad_into(data, eta, residual),
1016        NutsFamily::BinomialCLogLog => cloglog_logp_and_grad_into(data, eta, residual),
1017        NutsFamily::Gaussian => gaussian_logp_and_grad_into(data, eta, residual),
1018        NutsFamily::PoissonLog => poisson_log_logp_and_grad(data, eta),
1019        // Family mapping: TweedieLog stores variance power p in data.gamma_shape.
1020        // Its dispersion phi stays in data.dispersion, matching REML scale ownership.
1021        NutsFamily::TweedieLog => tweedie_log_quasilogp_and_grad(data, eta, data.gamma_shape),
1022        NutsFamily::NegativeBinomialLog => {
1023            // Family mapping: NegativeBinomialLog stores theta in data.gamma_shape.
1024            // NB has unit REML scale; theta is never sourced from fixed_phi.
1025            negative_binomial_log_logp_and_grad(data, eta, data.gamma_shape)
1026        }
1027        NutsFamily::GammaLog => gamma_log_logp_and_grad(data, eta),
1028    }
1029}
1030
/// Per-row log-probability pieces for a Bernoulli response under a binomial
/// inverse link μ = g⁻¹(η).
#[derive(Clone, Debug)]
struct BinomialLinkTerms {
    /// ln μ (−∞ when μ == 0).
    log_mu: f64,
    /// ln(1 − μ) (−∞ when μ == 1).
    log1m_mu: f64,
    /// ∂ ln μ / ∂η = μ′/μ.
    dlog_mu_deta: f64,
    /// ∂ ln(1 − μ) / ∂η = −μ′/(1 − μ).
    dlog1m_mu_deta: f64,
    /// ∂μ/∂θ for each adaptive link parameter θ (empty when there are none).
    dmu_dlink: Vec<f64>,
}

/// Build [`BinomialLinkTerms`] from μ, μ′ = ∂μ/∂η, and ∂μ/∂θ.
///
/// At the endpoints the log terms are −∞, and their η-derivatives take the
/// signed infinite limits:
/// * μ → 0⁺: ∂ln μ/∂η = μ′/μ → +∞·sign(μ′);
/// * μ → 1⁻: ∂ln(1−μ)/∂η = −μ′/(1−μ) → −∞·sign(μ′).
///
/// # Errors
/// Returns a message when μ is non-finite or outside [0, 1], or when μ′ is
/// non-finite.
#[inline]
fn log_terms_from_mu_and_dmu(
    mu: f64,
    dmu_deta: f64,
    dmu_dlink: Vec<f64>,
) -> Result<BinomialLinkTerms, String> {
    if !(mu.is_finite() && (0.0..=1.0).contains(&mu) && dmu_deta.is_finite()) {
        return Err(format!(
            "binomial inverse link returned invalid mu/deta derivative: mu={mu}, dmu_deta={dmu_deta}"
        ));
    }
    let one_minus_mu = 1.0 - mu;
    let (log_mu, dlog_mu_deta) = if mu == 0.0 {
        (f64::NEG_INFINITY, f64::INFINITY.copysign(dmu_deta))
    } else {
        (mu.ln(), dmu_deta / mu)
    };
    let (log1m_mu, dlog1m_mu_deta) = if one_minus_mu == 0.0 {
        // −μ′/(1−μ) diverges with the sign OPPOSITE to μ′. Note that
        // `NEG_INFINITY.copysign(μ′)` would take μ′'s sign and flip the limit.
        (f64::NEG_INFINITY, f64::INFINITY.copysign(-dmu_deta))
    } else {
        (one_minus_mu.ln(), -dmu_deta / one_minus_mu)
    };
    Ok(BinomialLinkTerms {
        log_mu,
        log1m_mu,
        dlog_mu_deta,
        dlog1m_mu_deta,
        dmu_dlink,
    })
}
1080
1081#[inline]
1082fn binomial_link_terms(
1083    inverse_link: &InverseLink,
1084    eta: f64,
1085    n_link_params: usize,
1086) -> Result<BinomialLinkTerms, String> {
1087    let jet =
1088        inverse_link_jet_for_inverse_link(inverse_link, eta).map_err(|err| err.to_string())?;
1089    let mut dmu_dlink = vec![0.0; n_link_params];
1090    if n_link_params > 0 {
1091        match inverse_link
1092            .param_partials(eta)
1093            .map_err(|err| err.to_string())?
1094        {
1095            Some(LinkParamPartials::Sas(partials)) => {
1096                if n_link_params != 2 {
1097                    return Err(format!(
1098                        "SAS/Beta-Logistic link parameter dimension mismatch: expected 2, got {n_link_params}"
1099                    ));
1100                }
1101                dmu_dlink[0] = partials.djet_depsilon.mu;
1102                dmu_dlink[1] = partials.djet_dlog_delta.mu;
1103            }
1104            Some(LinkParamPartials::Mixture(partials)) => {
1105                if partials.djet_drho.len() != n_link_params {
1106                    return Err(format!(
1107                        "mixture link parameter dimension mismatch: expected {}, got {n_link_params}",
1108                        partials.djet_drho.len()
1109                    ));
1110                }
1111                for (slot, partial) in dmu_dlink.iter_mut().zip(partials.djet_drho.iter()) {
1112                    *slot = partial.mu;
1113                }
1114            }
1115            None => {
1116                return Err(format!(
1117                    "joint HMC expected {n_link_params} adaptive link parameters, but the inverse link exposes none"
1118                ));
1119            }
1120        }
1121    }
1122    log_terms_from_mu_and_dmu(jet.mu, jet.d1, dmu_dlink)
1123}
1124
1125fn joint_binomial_logp_grad_and_link_grad(
1126    inverse_link: &InverseLink,
1127    data: &SharedData,
1128    eta: &Array1<f64>,
1129    n_link_params: usize,
1130) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1131    let n = data.n_samples;
1132    // Per-row: compute stable log-tail terms and derivatives without endpoint
1133    // clamping. Positive-weight responses were validated as Bernoulli before
1134    // target construction, so each row selects exactly one log branch.
1135    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1136    let per_row: Result<Vec<(f64, f64, Vec<f64>)>, String> = (0..n)
1137        .into_par_iter()
1138        .map(|i| {
1139            let y_i = data.y[i];
1140            let w_i = data.weights[i];
1141            if w_i <= 0.0 {
1142                return Ok((0.0, 0.0, vec![0.0; n_link_params]));
1143            }
1144            let terms = binomial_link_terms(inverse_link, eta[i], n_link_params)?;
1145            if y_i == 1.0 {
1146                let inv_mu = terms.log_mu.exp().recip();
1147                let log_mu = terms.log_mu;
1148                let dlog_mu_deta = terms.dlog_mu_deta;
1149                let grad_link = terms
1150                    .dmu_dlink
1151                    .into_iter()
1152                    .map(|dmu| w_i * dmu * inv_mu)
1153                    .collect();
1154                Ok((w_i * log_mu, w_i * dlog_mu_deta, grad_link))
1155            } else if y_i == 0.0 {
1156                let inv_one_minus_mu = terms.log1m_mu.exp().recip();
1157                let log1m_mu = terms.log1m_mu;
1158                let dlog1m_mu_deta = terms.dlog1m_mu_deta;
1159                let grad_link = terms
1160                    .dmu_dlink
1161                    .into_iter()
1162                    .map(|dmu| -w_i * dmu * inv_one_minus_mu)
1163                    .collect();
1164                Ok((w_i * log1m_mu, w_i * dlog1m_mu_deta, grad_link))
1165            } else {
1166                Err(format!(
1167                    "binomial joint HMC response must be exactly 0 or 1 after validation; got {y_i}"
1168                ))
1169            }
1170        })
1171        .collect();
1172    let per_row = per_row?;
1173    let mut residual = Array1::<f64>::zeros(n);
1174    let mut grad_link = Array1::<f64>::zeros(n_link_params);
1175    let mut ll = 0.0;
1176    for (i, (ll_i, residual_i, grad_link_i)) in per_row.into_iter().enumerate() {
1177        ll += ll_i;
1178        residual[i] = residual_i;
1179        for (slot, value) in grad_link.iter_mut().zip(grad_link_i.iter()) {
1180            *slot += *value;
1181        }
1182    }
1183
1184    Ok((ll, fast_atv(&data.x, &residual), grad_link))
1185}
1186
1187fn joint_binomial_logp_and_grad(
1188    likelihood: &LikelihoodSpec,
1189    data: &SharedData,
1190    eta: &Array1<f64>,
1191) -> Result<(f64, Array1<f64>), String> {
1192    if !matches!(likelihood.response, ResponseFamily::Binomial) {
1193        return Err(HmcError::UnsupportedFamily {
1194            reason: format!(
1195                "{} is not a binomial joint-HMC family",
1196                likelihood.pretty_name()
1197            ),
1198        }
1199        .into());
1200    }
1201    match &likelihood.link {
1202        InverseLink::Standard(StandardLink::Logit) => Ok(logit_logp_and_grad(data, eta)),
1203        InverseLink::Standard(StandardLink::Probit) => Ok(probit_logp_and_grad(data, eta)),
1204        InverseLink::Standard(StandardLink::CLogLog) => Ok(cloglog_logp_and_grad(data, eta)),
1205        InverseLink::LatentCLogLog(_)
1206        | InverseLink::Sas(_)
1207        | InverseLink::BetaLogistic(_)
1208        | InverseLink::Mixture(_) => {
1209            let (ll, grad_beta, _) =
1210                joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, 0)?;
1211            Ok((ll, grad_beta))
1212        }
1213        InverseLink::Standard(_) => Err(HmcError::UnsupportedFamily {
1214            reason: format!(
1215                "{} is not a binomial joint-HMC family",
1216                likelihood.pretty_name()
1217            ),
1218        }
1219        .into()),
1220    }
1221}
1222
1223fn joint_family_logp_grad_and_link_grad(
1224    likelihood: &LikelihoodSpec,
1225    data: &SharedData,
1226    eta: &Array1<f64>,
1227    n_link_params: usize,
1228) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1229    match (&likelihood.response, &likelihood.link) {
1230        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
1231            let (ll, grad) = logit_logp_and_grad(data, eta);
1232            Ok((ll, grad, Array1::zeros(n_link_params)))
1233        }
1234        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
1235            let (ll, grad) = probit_logp_and_grad(data, eta);
1236            Ok((ll, grad, Array1::zeros(n_link_params)))
1237        }
1238        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
1239            let (ll, grad) = cloglog_logp_and_grad(data, eta);
1240            Ok((ll, grad, Array1::zeros(n_link_params)))
1241        }
1242        (
1243            ResponseFamily::Binomial,
1244            InverseLink::LatentCLogLog(_)
1245            | InverseLink::Sas(_)
1246            | InverseLink::BetaLogistic(_)
1247            | InverseLink::Mixture(_),
1248        ) => joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, n_link_params),
1249        _ => {
1250            let (ll, grad) = joint_family_logp_and_grad(likelihood, data, eta)?;
1251            Ok((ll, grad, Array1::zeros(n_link_params)))
1252        }
1253    }
1254}
1255
1256fn joint_family_logp_and_grad(
1257    likelihood: &LikelihoodSpec,
1258    data: &SharedData,
1259    eta: &Array1<f64>,
1260) -> Result<(f64, Array1<f64>), String> {
1261    match &likelihood.response {
1262        ResponseFamily::Binomial => joint_binomial_logp_and_grad(likelihood, data, eta),
1263        ResponseFamily::Gaussian => Ok(gaussian_logp_and_grad(data, eta)),
1264        ResponseFamily::Poisson => Ok(poisson_log_logp_and_grad(data, eta)),
1265        ResponseFamily::Tweedie { p } => {
1266            // Family mapping: Tweedie payload p is the variance power.
1267            // Its dispersion phi stays in data.dispersion, matching REML.
1268            let p = *p;
1269            if !is_valid_tweedie_power(p) {
1270                return Err(HmcError::InvalidConfig {
1271                    reason: format!(
1272                        "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
1273                    ),
1274                }
1275                .into());
1276            }
1277            Ok(tweedie_log_quasilogp_and_grad(data, eta, p))
1278        }
1279        ResponseFamily::NegativeBinomial { theta, .. } => {
1280            // Family mapping: NegativeBinomial payload theta is overdispersion.
1281            // NB keeps unit REML scale and never reads fixed_phi for theta.
1282            Ok(negative_binomial_log_logp_and_grad(data, eta, *theta))
1283        }
1284        ResponseFamily::Beta { .. } => Err(HmcError::UnsupportedFamily {
1285            reason: "Joint HMC fallback is not implemented for BetaLogit".to_string(),
1286        }
1287        .into()),
1288        ResponseFamily::Gamma => Ok(gamma_log_logp_and_grad(data, eta)),
1289        ResponseFamily::RoystonParmar => Err(HmcError::UnsupportedFamily {
1290            reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
1291        }
1292        .into()),
1293    }
1294}
1295
1296/// Logistic regression log-likelihood and gradient.
1297///
1298/// log p(y|η) = y·η − log(1 + exp(η)), gradient = X'(w ⊙ (y − μ))
1299fn logit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1300    let mut residual = Array1::<f64>::zeros(data.n_samples);
1301    logit_logp_and_grad_into(data, eta, &mut residual)
1302}
1303
1304fn logit_logp_and_grad_into(
1305    data: &SharedData,
1306    eta: &Array1<f64>,
1307    residual: &mut Array1<f64>,
1308) -> (f64, Array1<f64>) {
1309    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1310    let n = data.n_samples;
1311    assert_eq!(residual.len(), n);
1312    // Per-row independent: write residual entry into a pre-allocated buffer and
1313    // reduce the ll contribution in parallel — avoids materialising a
1314    // Vec<(f64, f64)> and the serial scatter that follows.
1315    let ll: f64 = residual
1316        .as_slice_mut()
1317        .unwrap()
1318        .par_iter_mut()
1319        .enumerate()
1320        .map(|(i, slot)| {
1321            let eta_i = eta[i];
1322            let y_i = data.y[i];
1323            let w_i = data.weights[i];
1324            let mu = gam_linalg::utils::stable_logistic(eta_i);
1325            *slot = w_i * (y_i - mu);
1326            w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i))
1327        })
1328        .sum();
1329
1330    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1331    (ll, grad_ll)
1332}
1333
1334/// Probit regression log-likelihood and gradient.
1335///
1336/// log p(y|η) = Σ [y·log Φ(η) + (1-y)·log(1-Φ(η))],
1337/// gradient_i = w_i · [y_i · φ(η_i)/Φ(η_i) − (1-y_i) · φ(η_i)/(1−Φ(η_i))]
1338///
1339/// Uses erfc-based log Φ for numerical stability.
1340fn probit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1341    let mut residual = Array1::<f64>::zeros(data.n_samples);
1342    probit_logp_and_grad_into(data, eta, &mut residual)
1343}
1344
1345fn probit_logp_and_grad_into(
1346    data: &SharedData,
1347    eta: &Array1<f64>,
1348    residual: &mut Array1<f64>,
1349) -> (f64, Array1<f64>) {
1350    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1351    let n = data.n_samples;
1352    assert_eq!(residual.len(), n);
1353    let ll: f64 = residual
1354        .as_slice_mut()
1355        .unwrap()
1356        .par_iter_mut()
1357        .enumerate()
1358        .map(|(i, slot)| {
1359            let eta_i = eta[i];
1360            let y_i = data.y[i];
1361            let w_i = data.weights[i];
1362            let log_phi_pos = log_ndtr(eta_i);
1363            let log_phi_neg = log_ndtr(-eta_i);
1364            let log_phi_val = standard_normal_log_pdf(eta_i);
1365            let ratio_pos = (log_phi_val - log_phi_pos).exp();
1366            let ratio_neg = (log_phi_val - log_phi_neg).exp();
1367            let grad_i = y_i * ratio_pos - (1.0 - y_i) * ratio_neg;
1368            *slot = w_i * grad_i;
1369            w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg)
1370        })
1371        .sum();
1372
1373    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1374    (ll, grad_ll)
1375}
1376
1377/// Complementary log-log regression log-likelihood and gradient.
1378///
1379/// CLogLog link: μ = 1 − exp(−exp(η))
1380/// log p(y|η) = Σ [y·log(1−exp(−exp(η))) + (1−y)·(−exp(η))]
1381/// gradient_i = w_i · [y_i · exp(η_i)·exp(−exp(η_i)) / (1−exp(−exp(η_i))) − (1−y_i)·exp(η_i)]
1382#[inline]
1383fn cloglog_bernoulli_logp_and_residual(eta: f64, y: f64) -> Result<(f64, f64), EstimationError> {
1384    if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
1385        gam_problem::bail_invalid_estim!(
1386            "cloglog eta must be finite and within [-700, 700]; got {eta}"
1387        );
1388    }
1389    let exp_eta = eta.exp();
1390    // log_mu = log(1 - exp(-exp_eta)); exp_eta > 0 on the guarded domain, so this is
1391    // exactly the canonical cancellation-free log1mexp (single source of truth).
1392    let log_mu = crate::probability::log1mexp_positive(exp_eta);
1393    let log_one_minus_mu = -exp_eta;
1394    let grad_log_mu = (eta - exp_eta - log_mu).exp();
1395    let ll_i = y * log_mu + (1.0 - y) * log_one_minus_mu;
1396    let residual_i = y * grad_log_mu - (1.0 - y) * exp_eta;
1397    Ok((ll_i, residual_i))
1398}
1399
1400fn cloglog_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1401    let mut residual = Array1::<f64>::zeros(data.n_samples);
1402    cloglog_logp_and_grad_into(data, eta, &mut residual)
1403}
1404
1405fn cloglog_logp_and_grad_into(
1406    data: &SharedData,
1407    eta: &Array1<f64>,
1408    residual: &mut Array1<f64>,
1409) -> (f64, Array1<f64>) {
1410    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1411    let n = data.n_samples;
1412    assert_eq!(residual.len(), n);
1413    if eta
1414        .iter()
1415        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1416    {
1417        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1418    }
1419    let ll: f64 = residual
1420        .as_slice_mut()
1421        .unwrap()
1422        .par_iter_mut()
1423        .enumerate()
1424        .map(|(i, slot)| {
1425            let y_i = data.y[i];
1426            let w_i = data.weights[i];
1427            let (ll_i, residual_i) =
1428                cloglog_bernoulli_logp_and_residual(eta[i], y_i).expect("validated cloglog eta");
1429            *slot = w_i * residual_i;
1430            w_i * ll_i
1431        })
1432        .sum();
1433
1434    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1435    (ll, grad_ll)
1436}
1437
1438/// Gaussian log-likelihood and gradient.
1439///
1440/// log p(y|η) = −½ (w/φ)·(y − η)²,  gradient = (1/φ)·X'(w ⊙ (y − η))
1441///
1442/// Both the log-likelihood and its β-gradient are scaled by `1/φ` so that
1443/// the working likelihood matches the φ-scaled posterior covariance the
1444/// HMC whitening transform targets. With `φ == 1` (the only value
1445/// passed by the pre-refactor call sites) this collapses to the original
1446/// `−½ w·(y − η)²` expression; with an estimated dispersion (the
1447/// `Dispersion::Estimated(σ²)` branch) it removes the silent unit-σ
1448/// approximation the Gaussian NUTS log-density used previously.
1449fn gaussian_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1450    let mut weighted_residual = Array1::<f64>::zeros(data.n_samples);
1451    gaussian_logp_and_grad_into(data, eta, &mut weighted_residual)
1452}
1453
1454fn gaussian_logp_and_grad_into(
1455    data: &SharedData,
1456    eta: &Array1<f64>,
1457    weighted_residual: &mut Array1<f64>,
1458) -> (f64, Array1<f64>) {
1459    use gam_problem::dispersion_cov::DispersionExt as _;
1460    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1461    let n = data.n_samples;
1462    let inv_phi = data.dispersion.inv_phi();
1463    assert_eq!(weighted_residual.len(), n);
1464    // Per-row: residual = y - η, weighted_residual = (w/φ)·residual,
1465    // ll contribution = -0.5·(w/φ)·residual². All independent across rows.
1466    let ll: f64 = weighted_residual
1467        .as_slice_mut()
1468        .unwrap()
1469        .par_iter_mut()
1470        .enumerate()
1471        .map(|(i, slot)| {
1472            let residual = data.y[i] - eta[i];
1473            let w_i = data.weights[i];
1474            let scaled = w_i * inv_phi;
1475            *slot = scaled * residual;
1476            -0.5 * scaled * residual * residual
1477        })
1478        .sum();
1479
1480    let grad_ll = fast_atv(data.x.as_ref(), &*weighted_residual);
1481    (ll, grad_ll)
1482}
1483
1484/// Poisson(log) log-likelihood and gradient.
1485///
1486/// log p(y|η) = y·η − exp(η), gradient = X'(w ⊙ (y − μ))
1487fn poisson_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1488    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1489    let n = data.n_samples;
1490    if eta
1491        .iter()
1492        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1493    {
1494        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1495    }
1496    let mut residual = Array1::<f64>::zeros(n);
1497    let ll: f64 = residual
1498        .as_slice_mut()
1499        .unwrap()
1500        .par_iter_mut()
1501        .enumerate()
1502        .map(|(i, slot)| {
1503            let eta_i = eta[i];
1504            let mu_i = eta_i.exp();
1505            let y_i = data.y[i];
1506            let w_i = data.weights[i];
1507            *slot = w_i * (y_i - mu_i);
1508            w_i * (y_i * eta_i - mu_i)
1509        })
1510        .sum();
1511
1512    let grad_ll = fast_atv(&data.x, &residual);
1513    (ll, grad_ll)
1514}
1515
1516fn tweedie_log_quasilogp_and_grad(
1517    data: &SharedData,
1518    eta: &Array1<f64>,
1519    p: f64,
1520) -> (f64, Array1<f64>) {
1521    use gam_problem::dispersion_cov::DispersionExt as _;
1522    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1523    let n = data.n_samples;
1524    // Family mapping: Tweedie p is the variant payload; phi is data.dispersion.
1525    // Invalid payloads invalidate the target instead of falling back to p=1.5.
1526    if !is_valid_tweedie_power(p) {
1527        return (f64::NAN, Array1::from_elem(data.dim, f64::NAN));
1528    }
1529    if eta
1530        .iter()
1531        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1532    {
1533        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1534    }
1535    let inv_phi = data.dispersion.inv_phi();
1536    let mut residual = Array1::<f64>::zeros(n);
1537    let ll: f64 = residual
1538        .as_slice_mut()
1539        .unwrap()
1540        .par_iter_mut()
1541        .enumerate()
1542        .map(|(i, slot)| {
1543            let eta_i = eta[i];
1544            let mu_i = eta_i.exp().max(1e-300);
1545            let y_i = data.y[i];
1546            let w_i = data.weights[i] * inv_phi;
1547            *slot = w_i * (y_i - mu_i) * mu_i.powf(1.0 - p);
1548            let qll = y_i * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p);
1549            w_i * qll
1550        })
1551        .sum();
1552
1553    let grad_ll = fast_atv(&data.x, &residual);
1554    (ll, grad_ll)
1555}
1556
1557fn negative_binomial_log_logp_and_grad(
1558    data: &SharedData,
1559    eta: &Array1<f64>,
1560    theta: f64,
1561) -> (f64, Array1<f64>) {
1562    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1563    let n = data.n_samples;
1564    if !(theta.is_finite() && theta > 0.0)
1565        || eta
1566            .iter()
1567            .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1568        || data
1569            .y
1570            .iter()
1571            .zip(data.weights.iter())
1572            .any(|(&y_i, &w_i)| w_i > 0.0 && !valid_count_response(y_i))
1573    {
1574        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1575    }
1576    let mut residual = Array1::<f64>::zeros(n);
1577    let ll: f64 = residual
1578        .as_slice_mut()
1579        .unwrap()
1580        .par_iter_mut()
1581        .enumerate()
1582        .map(|(i, slot)| {
1583            let eta_i = eta[i];
1584            let mu_i = eta_i.exp().max(1e-12);
1585            let y_i = data.y[i];
1586            let w_i = data.weights[i];
1587            if w_i <= 0.0 {
1588                *slot = 0.0;
1589                return 0.0;
1590            }
1591            let log_mu_term = if y_i > 0.0 { y_i * mu_i.ln() } else { 0.0 };
1592            *slot = w_i * theta * (y_i - mu_i) / (theta + mu_i);
1593            w_i * (statrs::function::gamma::ln_gamma(y_i + theta)
1594                - statrs::function::gamma::ln_gamma(theta)
1595                - statrs::function::gamma::ln_gamma(y_i + 1.0)
1596                + theta * (theta.ln() - (theta + mu_i).ln())
1597                + log_mu_term
1598                - y_i * (theta + mu_i).ln())
1599        })
1600        .sum();
1601
1602    let grad_ll = fast_atv(&data.x, &residual);
1603    (ll, grad_ll)
1604}
1605
1606fn gamma_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1607    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1608    let n = data.n_samples;
1609    if eta
1610        .iter()
1611        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1612    {
1613        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1614    }
1615    let shape = data.gamma_shape.max(1e-10);
1616    // Hoist shape-only constants out of the per-sample loop: ln Γ(shape) and
1617    // shape · ln(shape) are independent of i, so previously each sample paid
1618    // an extra `ln_gamma` and `ln` plus a multiply. n is typically large-scale-
1619    // scale, so this collapses Θ(n) gamma-function evaluations to one.
1620    let shape_ln_shape = shape * shape.ln();
1621    let log_gamma_shape = statrs::function::gamma::ln_gamma(shape);
1622    let shape_minus_one = shape - 1.0;
1623    let mut residual = Array1::<f64>::zeros(n);
1624    let ll: f64 = residual
1625        .as_slice_mut()
1626        .unwrap()
1627        .par_iter_mut()
1628        .enumerate()
1629        .map(|(i, slot)| {
1630            let eta_i = eta[i];
1631            let mu_i = eta_i.exp();
1632            let y_i = data.y[i];
1633            let w_i = data.weights[i];
1634            let ll_i = w_i
1635                * (shape_ln_shape - log_gamma_shape - shape * eta_i
1636                    + shape_minus_one * y_i.max(1e-12).ln()
1637                    - shape * y_i / mu_i);
1638            *slot = w_i * shape * (y_i / mu_i - 1.0);
1639            ll_i
1640        })
1641        .sum();
1642
1643    let grad_ll = fast_atv(&data.x, &residual);
1644    (ll, grad_ll)
1645}
1646
1647#[cfg(test)]
1648mod tests {
1649    use super::{
1650        FamilyNutsInputs, GlmFlatInputs, JointBetaRhoInputs, JointBetaRhoPosterior,
1651        LinkWigglePosterior, LinkWiggleSplineArtifacts, NutsConfig, NutsFamily, NutsPosterior,
1652        SharedData, cloglog_bernoulli_logp_and_residual, firth_jeffreys_logp_and_grad,
1653        joint_family_logp_and_grad, laplace_directional_cubic_diagnostic,
1654        laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
1655        run_joint_beta_rho_sampling, run_logit_polya_gamma_gibbs,
1656        run_nuts_sampling_flattened_family,
1657    };
1658    use gam_linalg::matrix::DesignMatrix;
1659    use gam_models::survival::{PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec};
1660    use gam_problem::types::{
1661        InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LogLikelihoodNormalization,
1662        ResponseFamily, RhoPrior, StandardLink,
1663    };
1664    use gam_solve::estimate::{
1665        BlockRole, FitGeometry, FitInference, FittedBlock, FittedLinkState, UnifiedFitResult,
1666        UnifiedFitResultParts,
1667    };
1668    use gam_terms::construction::CanonicalPenalty;
1669    use general_mcmc::generic_hmc::HamiltonianTarget;
1670    use ndarray::{Array1, Array2, array};
1671    use std::sync::Arc;
1672
1673    impl NutsPosterior {
1674        /// Test-only allocation wrapper around `compute_logp_and_grad_nd_into`.
1675        pub(super) fn compute_logp_and_grad_nd(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1676            let mut residual = Array1::<f64>::zeros(self.data.n_samples);
1677            let mut grad = Array1::<f64>::zeros(z.len());
1678            let logp = self.compute_logp_and_grad_nd_into(z, &mut residual, &mut grad);
1679            (logp, grad)
1680        }
1681    }
1682
1683    impl LinkWigglePosterior {
1684        /// Test-only allocation wrapper around `compute_logp_and_grad_into`.
1685        pub(super) fn compute_logp_and_grad(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1686            let dim = self.p_base + self.p_link;
1687            let mut grad = Array1::<f64>::zeros(dim);
1688            let logp = self.compute_logp_and_grad_into(z, &mut grad);
1689            (logp, grad)
1690        }
1691    }
1692
1693    impl JointBetaRhoPosterior {
1694        /// Test-only allocation wrapper around `compute_joint_logp_and_grad_into`.
1695        pub(super) fn compute_joint_logp_and_grad(
1696            &self,
1697            params: &Array1<f64>,
1698        ) -> (f64, Array1<f64>) {
1699            let total_dim = self.n_beta + self.n_rho + self.n_link_params;
1700            let mut grad = Array1::<f64>::zeros(total_dim);
1701            let logp = self.compute_joint_logp_and_grad_into(params, &mut grad);
1702            (logp, grad)
1703        }
1704    }
1705
    /// Assembles a minimal Gaussian/identity `UnifiedFitResult` for the HMC
    /// whitening tests.
    ///
    /// Callers choose the coefficient `blocks` and the optional `inference` /
    /// `geometry` payloads. Every other field is a fixed placeholder: no
    /// smoothing parameters, converged PIRLS, and no covariance matrices.
    /// Panics if `try_from_parts` rejects the assembled parts.
    fn hmc_test_fit(
        blocks: Vec<FittedBlock>,
        inference: Option<FitInference>,
        geometry: Option<FitGeometry>,
    ) -> UnifiedFitResult {
        // No smoothing parameters: both `log_lambdas` and `lambdas` are empty.
        let lambdas = Array1::zeros(0);
        UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
            blocks,
            log_lambdas: lambdas.clone(),
            lambdas,
            likelihood_family: Some(LikelihoodSpec::new(
                ResponseFamily::Gaussian,
                InverseLink::Standard(StandardLink::Identity),
            )),
            likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
            log_likelihood_normalization: LogLikelihoodNormalization::Full,
            log_likelihood: -1.0,
            deviance: 2.0,
            reml_score: 0.0,
            stable_penalty_term: 0.0,
            penalized_objective: 0.0,
            used_device: false,
            outer_iterations: 1,
            outer_converged: true,
            outer_gradient_norm: None,
            standard_deviation: 1.0,
            covariance_conditional: None,
            covariance_corrected: None,
            inference,
            fitted_link: FittedLinkState::Standard(None),
            geometry,
            block_states: Vec::new(),
            pirls_status: gam_solve::pirls::PirlsStatus::Converged,
            max_abs_eta: 0.0,
            constraint_kkt: None,
            artifacts: Default::default(),
            inner_cycles: 0,
        })
        .expect("valid HMC handoff test fit")
    }
1746
    /// A single-block fit that stores its penalized Hessian in `FitInference`
    /// must export exactly that matrix for whitening. `NutsPosterior::new` must
    /// then accept it.
    #[test]
    fn hmc_whitening_consumes_standard_fit_inference_hessian() {
        let hessian = array![[2.0, 0.1], [0.1, 1.6]];
        let fit = hmc_test_fit(
            vec![FittedBlock {
                beta: array![0.05, -0.1],
                role: BlockRole::Mean,
                edf: 2.0,
                lambdas: Array1::zeros(0),
            }],
            Some(FitInference {
                edf_by_block: vec![],
                penalty_block_trace: vec![],
                edf_total: 2.0,
                smoothing_correction: None,
                penalized_hessian: hessian.clone().into(),
                working_weights: array![1.0, 1.0, 1.0],
                working_response: array![0.0, 0.1, -0.2],
                reparam_qs: None,
                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
                beta_covariance: None,
                beta_standard_errors: None,
                beta_covariance_corrected: None,
                beta_standard_errors_corrected: None,
                beta_covariance_frequentist: None,
                coefficient_influence: None,
                weighted_gram: None,
                bias_correction_beta: None,
            }),
            None,
        );

        // The exported matrix must equal the stored Hessian exactly.
        let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "standard fit")
            .expect("standard fit exports explicit Hessian");
        assert_eq!(explicit, &hessian);

        // Building the target checks that whitening succeeds with that Hessian.
        let x = array![[1.0, 0.0], [1.0, 0.5], [1.0, -0.5]];
        let y = array![0.0, 0.2, -0.1];
        let weights = Array1::ones(3);
        let penalty = Array2::eye(2);
        NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty.view(),
            fit.beta.view(),
            explicit.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
        )
        .expect("HMC target whitens with upstream Hessian");
    }
1801
    /// A location/scale (blockwise) fit with no `FitInference` must still
    /// export the Hessian held in its `FitGeometry`, unchanged.
    #[test]
    fn hmc_whitening_consumes_blockwise_geometry_hessian() {
        let hessian = array![[3.0, 0.2], [0.2, 2.0]];
        let fit = hmc_test_fit(
            vec![
                FittedBlock {
                    beta: array![0.1],
                    role: BlockRole::Location,
                    edf: 1.0,
                    lambdas: Array1::zeros(0),
                },
                FittedBlock {
                    beta: array![-0.2],
                    role: BlockRole::Scale,
                    edf: 1.0,
                    lambdas: Array1::zeros(0),
                },
            ],
            // No inference payload: the geometry is the only Hessian source.
            None,
            Some(FitGeometry {
                penalized_hessian: hessian.clone().into(),
                working_weights: array![1.0, 0.8],
                working_response: array![0.0, 0.1],
            }),
        );

        let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "blockwise fit")
            .expect("blockwise fit exports materialized Hessian");
        assert_eq!(explicit, &hessian);
    }
1832
    /// A fit with only a conditional covariance (no inference and no geometry)
    /// must be rejected for whitening. HMC may not invert that covariance to
    /// synthesize a Hessian.
    #[test]
    fn hmc_whitening_rejects_covariance_only_fit_without_synthesizing_hessian() {
        let fit = UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
            blocks: vec![FittedBlock {
                beta: array![0.0],
                role: BlockRole::Mean,
                edf: 1.0,
                lambdas: Array1::zeros(0),
            }],
            log_lambdas: Array1::zeros(0),
            lambdas: Array1::zeros(0),
            likelihood_family: Some(LikelihoodSpec::new(
                ResponseFamily::Gaussian,
                InverseLink::Standard(StandardLink::Identity),
            )),
            likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
            log_likelihood_normalization: LogLikelihoodNormalization::Full,
            log_likelihood: -1.0,
            deviance: 2.0,
            reml_score: 0.0,
            stable_penalty_term: 0.0,
            penalized_objective: 0.0,
            used_device: false,
            outer_iterations: 1,
            outer_converged: true,
            outer_gradient_norm: None,
            standard_deviation: 1.0,
            // Only a covariance is present; `inference` and `geometry` are None.
            covariance_conditional: Some(array![[0.5]]),
            covariance_corrected: None,
            inference: None,
            fitted_link: FittedLinkState::Standard(None),
            geometry: None,
            block_states: Vec::new(),
            pirls_status: gam_solve::pirls::PirlsStatus::Converged,
            max_abs_eta: 0.0,
            constraint_kkt: None,
            artifacts: Default::default(),
            inner_cycles: 0,
        })
        .expect("covariance-only fit can exist for prediction");

        let err = super::explicit_fit_hessian_for_whitening(&fit, 1, "covariance-only fit")
            .expect_err("HMC must not invert covariance as a Hessian fallback");
        assert!(
            err.contains("missing an explicit penalized Hessian"),
            "unexpected error: {err}"
        );
    }
1881
1882    #[test]
1883    fn log1pexp_is_finite_for_extreme_eta() {
1884        assert!(gam_linalg::utils::stable_softplus(1000.0).is_finite());
1885        assert!(gam_linalg::utils::stable_softplus(-1000.0).is_finite());
1886        assert!((gam_linalg::utils::stable_softplus(-1000.0) - 0.0).abs() < 1e-12);
1887    }
1888
1889    #[test]
1890    fn sigmoid_stable_behaves_at_extremes() {
1891        let hi = gam_linalg::utils::stable_logistic(1000.0);
1892        let lo = gam_linalg::utils::stable_logistic(-1000.0);
1893        assert!((1.0 - 1e-12..=1.0).contains(&hi));
1894        assert!((0.0..=1e-12).contains(&lo));
1895    }
1896
    /// For y = 1 the cloglog Bernoulli log-probability must be
    /// ln(1 − exp(−exp η)), not ln(1 − exp η). The returned residual must also
    /// equal its η-derivative, checked by central finite differences.
    #[test]
    fn cloglog_log_mu_uses_complementary_loglog_inverse_link() {
        let eta = -1.0_f64;
        let (ll_y1, residual_y1) =
            cloglog_bernoulli_logp_and_residual(eta, 1.0).expect("valid eta");
        let expected = (1.0 - (-eta.exp()).exp()).ln();
        // The naive (wrong) inverse-link form, which must be clearly different.
        let wrong_log_one_minus_exp_eta = (1.0 - eta.exp()).ln();

        assert!((ll_y1 - expected).abs() < 1e-14);
        assert!((ll_y1 - wrong_log_one_minus_exp_eta).abs() > 0.5);

        let eps = 1e-6;
        let (lp, _) = cloglog_bernoulli_logp_and_residual(eta + eps, 1.0).expect("valid eta");
        let (lm, _) = cloglog_bernoulli_logp_and_residual(eta - eps, 1.0).expect("valid eta");
        let fd = (lp - lm) / (2.0 * eps);
        assert!(
            (residual_y1 - fd).abs() < 1e-9,
            "cloglog residual is not the derivative of log μ: analytic={residual_y1}, fd={fd}"
        );
    }
1917
    /// `LinkWigglePosterior::new` must whiten with the caller-supplied joint
    /// Hessian. With `chol` as the whitening factor, `H · (chol · cholᵀ)` must
    /// reproduce the identity.
    #[test]
    fn link_wiggle_posterior_whitening_uses_supplied_explicit_joint_hessian() {
        let x = array![[1.0], [1.0], [1.0]];
        let y = array![0.0, 1.0, 1.0];
        let weights = Array1::ones(3);
        let penalty_base = Array2::zeros((1, 1));
        let penalty_link = Array2::zeros((1, 1));
        let mode_beta = array![0.2];
        let mode_theta = array![0.05];
        // Joint Hessian over one base coefficient and one link coefficient.
        let hessian = array![[4.0, 1.0], [1.0, 3.0]];
        let spline = LinkWiggleSplineArtifacts {
            knot_range: (-1.0, 1.0),
            knot_vector: Array1::from_vec(vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
            degree: 2,
        };

        let posterior = LinkWigglePosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty_base.view(),
            penalty_link.view(),
            mode_beta.view(),
            mode_theta.view(),
            hessian.view(),
            spline,
            NutsFamily::BinomialLogit,
            1.0,
        )
        .expect("link-wiggle posterior should accept explicit SPD joint Hessian");

        let reconstructed_cov = posterior.chol().dot(&posterior.chol().t());
        let eye_from_hessian = hessian.dot(&reconstructed_cov);
        for r in 0..2 {
            for c in 0..2 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert!(
                    (eye_from_hessian[[r, c]] - expected).abs() < 1e-10,
                    "whitening did not use the supplied explicit joint Hessian at ({r},{c}): got {} expected {}",
                    eye_from_hessian[[r, c]],
                    expected
                );
            }
        }
    }
1963
    /// The analytic gradient of the cloglog link-wiggle target must match
    /// central finite differences of its own log-density, coordinate by
    /// coordinate in `z`.
    #[test]
    fn link_wiggle_cloglog_gradient_matches_its_log_likelihood() {
        let x = array![[1.0], [1.0], [1.0], [1.0]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let weights = array![1.0, 1.2, 0.8, 1.4];
        let penalty_base = Array2::zeros((1, 1));
        let penalty_link = Array2::zeros((1, 1));
        let mode_beta = array![-0.8];
        let mode_theta = array![0.04];
        let hessian = Array2::eye(2);
        let spline = LinkWiggleSplineArtifacts {
            knot_range: (-1.5, 0.5),
            knot_vector: Array1::from_vec(vec![-1.5, -1.5, -1.5, 0.5, 0.5, 0.5]),
            degree: 2,
        };

        let posterior = LinkWigglePosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty_base.view(),
            penalty_link.view(),
            mode_beta.view(),
            mode_theta.view(),
            hessian.view(),
            spline,
            NutsFamily::BinomialCLogLog,
            1.0,
        )
        .expect("cloglog link-wiggle posterior");

        let z = array![0.2, -0.03];
        let (_, grad) = posterior.compute_logp_and_grad(&z);
        let eps = 1e-6;
        // Central difference per coordinate; tolerance 1e-6.
        for j in 0..z.len() {
            let mut z_plus = z.clone();
            let mut z_minus = z.clone();
            z_plus[j] += eps;
            z_minus[j] -= eps;
            let (lp, _) = posterior.compute_logp_and_grad(&z_plus);
            let (lm, _) = posterior.compute_logp_and_grad(&z_minus);
            let fd = (lp - lm) / (2.0 * eps);
            assert!(
                (grad[j] - fd).abs() < 1e-6,
                "link-wiggle cloglog gradient mismatch at {j}: analytic={}, fd={}",
                grad[j],
                fd
            );
        }
    }
2014
    /// For the logit NUTS target, the analytic gradient in `z` must agree with
    /// central finite differences, both in sign and to within 1e-5.
    #[test]
    fn nuts_logitgradient_matches_finite_difference() {
        let x = array![[1.0, -0.5], [0.2, 0.7], [-1.0, 0.3], [0.5, -1.2]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let w = array![1.0, 1.5, 0.8, 1.2];
        let penalty = array![[0.4, 0.0], [0.0, 0.6]];
        let mode = array![0.1, -0.2];
        let hessian = array![[2.0, 0.2], [0.2, 1.7]]; // SPD

        // NOTE(review): the trailing `true` is presumably the Firth flag
        // (cf. `firth_bias_reduction` in `GlmFlatInputs`) — confirm.
        let posterior = NutsPosterior::new(
            x.view(),
            y.view(),
            w.view(),
            penalty.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::BinomialLogit,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            true,
        )
        .expect("posterior");

        let z = array![0.15, -0.35];
        let (_, grad) = posterior.compute_logp_and_grad_nd(&z);

        let eps = 1e-6;
        for j in 0..z.len() {
            let mut z_plus = z.clone();
            let mut z_minus = z.clone();
            z_plus[j] += eps;
            z_minus[j] -= eps;
            let (lp, _) = posterior.compute_logp_and_grad_nd(&z_plus);
            let (lm, _) = posterior.compute_logp_and_grad_nd(&z_minus);
            let fd = (lp - lm) / (2.0 * eps);
            assert_eq!(
                grad[j].signum(),
                fd.signum(),
                "gradient sign mismatch at {}: analytic={}, fd={}",
                j,
                grad[j],
                fd
            );
            assert!(
                (grad[j] - fd).abs() < 1e-5,
                "gradient mismatch at {}: analytic={}, fd={}",
                j,
                grad[j],
                fd
            );
        }
    }
2067
    /// `gamma_log_logp_and_grad` must use `data.gamma_shape` and agree with a
    /// direct per-row evaluation of the Gamma(log) log-density and its score.
    #[test]
    fn gamma_log_logp_and_grad_uses_fitted_shape() {
        let x = array![[1.0_f64], [1.0_f64]];
        let y = array![1.5_f64, 2.5_f64];
        let weights = array![1.0_f64, 2.0_f64];
        let eta = array![0.2_f64, 0.4_f64];
        let shape = 3.5_f64;
        let data = SharedData {
            x: Arc::new(x.clone()),
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            mode: Arc::new(Array1::zeros(1)),
            offset: None,
            gamma_shape: shape,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            n_samples: x.nrows(),
            dim: x.ncols(),
        };

        let (ll, grad) = super::gamma_log_logp_and_grad(&data, &eta);

        // Reference values. With a single all-ones column, the gradient is
        // just the summed per-row score.
        let mut expected_ll = 0.0;
        let mut expected_score = 0.0;
        for i in 0..eta.len() {
            let mu = eta[i].exp();
            expected_ll += weights[i]
                * (shape * shape.ln() - statrs::function::gamma::ln_gamma(shape) - shape * eta[i]
                    + (shape - 1.0) * y[i].ln()
                    - shape * y[i] / mu);
            expected_score += weights[i] * shape * (y[i] / mu - 1.0);
        }

        assert!((ll - expected_ll).abs() < 1e-12);
        assert_eq!(grad.len(), 1);
        assert!((grad[0] - expected_score).abs() < 1e-12);
    }
2104
2105    /// Gamma observed information at the mode, `Xᵀ diag(w·ν·y/μ) X`, where the
2106    /// per-point curvature `w·ν·y/μ` is exactly `−∂/∂η` of the analytic score
2107    /// slot `w·ν·(y/μ − 1)` used by `gamma_log_logp_and_grad`.
2108    fn gamma_log_observed_information(
2109        x: &Array2<f64>,
2110        mode: &Array1<f64>,
2111        y: &Array1<f64>,
2112        weights: &Array1<f64>,
2113        shape: f64,
2114    ) -> Array2<f64> {
2115        let p = x.ncols();
2116        let eta = x.dot(mode);
2117        let mut h = Array2::<f64>::zeros((p, p));
2118        for i in 0..x.nrows() {
2119            let mu = eta[i].exp();
2120            let wt = weights[i] * shape * y[i] / mu;
2121            for a in 0..p {
2122                for b in 0..p {
2123                    h[[a, b]] += wt * x[[i, a]] * x[[i, b]];
2124                }
2125            }
2126        }
2127        h
2128    }
2129
    /// Regression for #680: the whitened GammaLog NUTS target must reproduce
    /// the #679 coefficient-covariance contract `Vb = H⁻¹` (scale `1.0`), NOT
    /// the dispersion-double-counted `(1/ν)(XᵀΛX + S)⁻¹`.
    ///
    /// We set the stored Hessian to the *true* penalized curvature of the
    /// target at the mode, `H = Xᵀ diag(w·ν·y/μ) X + S` (Gamma observed
    /// information + the penalty added **unscaled** — exactly the #679 `H`).
    /// The whitened target's curvature in z at the mode is `Lᵀ Hβ L`. The fix
    /// makes `L Lᵀ = H⁻¹` and `Hβ = H`, so this is the identity. The pre-fix
    /// code scaled the penalty by `ν` and the whitening by `√φ`, turning the
    /// z-curvature into `φ·(I + (ν−1)·L_H⁻¹ S L_H⁻ᵀ) ≠ I` (for ν=4 the
    /// diagonal collapses toward ~0.25, never 1).
    #[test]
    fn gamma_log_nuts_target_curvature_matches_unscaled_hessian_issue_680() {
        let x = array![[1.0, -0.7], [1.0, 0.3], [1.0, 1.1], [1.0, -0.2], [1.0, 0.8],];
        let mode = array![0.4_f64, -0.6_f64];
        let y = array![1.2_f64, 0.7, 2.3, 0.9, 1.6];
        let weights = array![1.0_f64, 1.5, 0.8, 1.2, 1.0];
        // ν = 1/φ = 4 ⇒ φ = 0.25: a large, easily-detectable double-count.
        let shape = 4.0_f64;
        let p = x.ncols();

        let h_data = gamma_log_observed_information(&x, &mode, &y, &weights, shape);
        // A genuine PD smoothing penalty so the ×ν double-count is detectable.
        let s = array![[0.5_f64, 0.1], [0.1, 0.9]];
        let hessian = &h_data + &s;

        // `s` is passed as the penalty and `hessian` as the whitening Hessian,
        // with dispersion φ = 1/ν.
        let target = NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            s.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::GammaLog,
            shape,
            gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
            false,
        )
        .expect("GammaLog NUTS target builds");

        // z-space precision at the mode (z = 0) via central differences of the
        // analytic gradient: `−∂(∇_z logp)/∂z = Lᵀ Hβ L`. Correct value: I.
        let eps = 1e-6;
        let z0 = Array1::<f64>::zeros(p);
        let mut hz = Array2::<f64>::zeros((p, p));
        for j in 0..p {
            let mut zp = z0.clone();
            let mut zm = z0.clone();
            zp[j] += eps;
            zm[j] -= eps;
            let (_, gp) = target.compute_logp_and_grad_nd(&zp);
            let (_, gm) = target.compute_logp_and_grad_nd(&zm);
            for a in 0..p {
                hz[[a, j]] = -(gp[a] - gm[a]) / (2.0 * eps);
            }
        }

        for a in 0..p {
            for b in 0..p {
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!(
                    (hz[[a, b]] - expected).abs() < 1e-4,
                    "z-curvature[{a},{b}] = {} (expected {expected}); a non-identity \
                     value means the GammaLog target re-introduced the #680 dispersion \
                     double-count (penalty ×ν and/or whitening ×√φ)",
                    hz[[a, b]]
                );
            }
        }
        // Trace = p (identity) rejects the φ-scaled `φ·tr(...)` signature.
        let trace: f64 = (0..p).map(|i| hz[[i, i]]).sum();
        assert!(
            (trace - p as f64).abs() < 1e-3,
            "z-curvature trace {trace} ≠ {p}: dispersion double-count signature"
        );
    }
2207
    /// Regression for #680 (whitening half, isolated): for a weight-carries-
    /// dispersion family the whitening must satisfy `L Lᵀ = H⁻¹` — i.e.
    /// `cov_scale = 1` — so the sampler whitens against the same `H⁻¹` it
    /// targets. The pre-fix Gamma path scaled `L` by `√φ`, giving
    /// `L Lᵀ = φ·H⁻¹` and `chol·cholᵀ·H = φ·I ≠ I`.
    #[test]
    fn gamma_log_nuts_whitening_targets_unscaled_inverse_hessian_issue_680() {
        let x = array![[1.0, -0.4], [1.0, 0.6], [1.0, 0.1], [1.0, 1.3]];
        let mode = array![0.2_f64, 0.3_f64];
        let y = array![0.8_f64, 1.7, 1.1, 2.2];
        let weights = array![1.0_f64, 1.0, 1.5, 0.7];
        let shape = 6.25_f64; // φ = 0.16
        let p = x.ncols();
        let s = array![[0.3_f64, 0.0], [0.0, 0.7]];
        // Unscaled penalized Hessian: observed information plus S.
        let hessian = &gamma_log_observed_information(&x, &mode, &y, &weights, shape) + &s;

        let target = NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            s.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::GammaLog,
            shape,
            gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
            false,
        )
        .expect("GammaLog NUTS target builds");

        // chol = L with L Lᵀ = H⁻¹  ⇒  (L Lᵀ) H = I.
        let l = target.chol();
        let llt = l.dot(&l.t());
        let prod = llt.dot(&hessian);
        for a in 0..p {
            for b in 0..p {
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!(
                    (prod[[a, b]] - expected).abs() < 1e-8,
                    "L Lᵀ H[{a},{b}] = {} (expected {expected}); a φ·I result means \
                     the Gamma whitening still scales by √φ (#680)",
                    prod[[a, b]]
                );
            }
        }
    }
2254
    /// The Firth/Jeffreys term for the logit family must stay finite for a
    /// rank-deficient design, returning one finite gradient entry per column.
    #[test]
    fn firth_jeffreys_logit_is_finite_for_rank_deficient_design() {
        // Columns 0 and 2 are identical, so X has rank 2 with 3 columns.
        let x = array![
            [1.0, -0.5, 1.0],
            [1.0, 0.3, 1.0],
            [1.0, 0.8, 1.0],
            [1.0, -1.2, 1.0],
        ];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let weights = array![1.0, 2.0, 0.5, 1.5];
        let eta = array![0.2, -0.1, 0.4, -0.3];

        let data = SharedData {
            x: Arc::new(x.clone()),
            y: Arc::new(y),
            weights: Arc::new(weights.clone()),
            mode: Arc::new(Array1::zeros(x.ncols())),
            offset: None,
            gamma_shape: 1.0,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            n_samples: x.nrows(),
            dim: x.ncols(),
        };

        let (value, grad) =
            firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &data, &eta).expect("firth");

        assert!(value.is_finite());
        assert_eq!(grad.len(), x.ncols());
        assert!(grad.iter().all(|v| v.is_finite()));
    }
2286
    /// Pólya–Gamma Gibbs sampling on a small logit problem must return
    /// `n_samples × n_chains` rows of finite draws (one column per
    /// coefficient), with a finite posterior mean and standard deviation.
    #[test]
    fn logit_pg_gibbs_returns_finite_samples() {
        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let w = array![1.0, 1.0, 1.0, 1.0];
        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
        let mode = array![0.0, 0.0];
        // Small run to keep the test fast; the seed makes it deterministic.
        let cfg = NutsConfig {
            n_samples: 30,
            nwarmup: 30,
            n_chains: 2,
            target_accept: 0.8,
            seed: 123,
        };
        let out = run_logit_polya_gamma_gibbs(
            x.view(),
            y.view(),
            w.view(),
            penalty.view(),
            mode.view(),
            &cfg,
        )
        .expect("pg gibbs should run");
        assert_eq!(out.samples.ncols(), 2);
        assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
        assert!(out.samples.iter().all(|v| v.is_finite()));
        assert!(out.posterior_mean.iter().all(|v| v.is_finite()));
        assert!(out.posterior_std.iter().all(|v| v.is_finite()));
    }
2316
2317    #[test]
2318    fn family_pg_dispatch_rejects_non_bernoulli_response() {
2319        let x = array![[1.0], [1.0]];
2320        let y = array![2.0, 0.0];
2321        let w = array![1.0, 1.0];
2322        let penalty = array![[0.1]];
2323        let mode = array![0.0];
2324        let non_spd_hessian = array![[0.0]];
2325        let cfg = NutsConfig {
2326            n_samples: 1,
2327            nwarmup: 1,
2328            n_chains: 1,
2329            target_accept: 0.8,
2330            seed: 321,
2331        };
2332
2333        let result = run_nuts_sampling_flattened_family(
2334            LikelihoodSpec::binomial_logit(),
2335            FamilyNutsInputs::Glm(GlmFlatInputs {
2336                x: x.view(),
2337                y: y.view(),
2338                weights: w.view(),
2339                penalty_matrix: penalty.view(),
2340                mode: mode.view(),
2341                hessian: non_spd_hessian.view(),
2342                gamma_shape: None,
2343                dispersion: gam_solve::model_types::Dispersion::Known(1.0),
2344                firth_bias_reduction: false,
2345                offset: None,
2346            }),
2347            &cfg,
2348        );
2349
2350        let err = result.err().expect("PG dispatch should reject count rows");
2351        assert!(
2352            err.contains("response must be exactly 0 or 1"),
2353            "unexpected error: {err}"
2354        );
2355    }
2356
2357    #[test]
2358    fn family_dispatch_uses_pg_gibbs_for_standard_logit() {
2359        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2360        let y = array![1.0, 0.0, 1.0, 0.0];
2361        let w = array![1.0, 1.0, 1.0, 1.0];
2362        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2363        let mode = array![0.0, 0.0];
2364        let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2365        let cfg = NutsConfig {
2366            n_samples: 20,
2367            nwarmup: 20,
2368            n_chains: 2,
2369            target_accept: 0.8,
2370            seed: 456,
2371        };
2372        let out = run_nuts_sampling_flattened_family(
2373            LikelihoodSpec {
2374                response: ResponseFamily::Binomial,
2375                link: InverseLink::Standard(StandardLink::Logit),
2376            },
2377            FamilyNutsInputs::Glm(GlmFlatInputs {
2378                x: x.view(),
2379                y: y.view(),
2380                weights: w.view(),
2381                penalty_matrix: penalty.view(),
2382                mode: mode.view(),
2383                hessian: non_spdhessian.view(),
2384                gamma_shape: None,
2385                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2386                firth_bias_reduction: false,
2387                offset: None,
2388            }),
2389            &cfg,
2390        )
2391        .expect("dispatch should use PG Gibbs and not require Hessian factorization");
2392        assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
2393        assert!(out.samples.iter().all(|v| v.is_finite()));
2394    }
2395
2396    #[test]
2397    fn family_dispatch_routes_probit_to_nuts_path() {
2398        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2399        let y = array![1.0, 0.0, 1.0, 0.0];
2400        let w = array![1.0, 1.0, 1.0, 1.0];
2401        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2402        let mode = array![0.0, 0.0];
2403        let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2404        let cfg = NutsConfig {
2405            n_samples: 20,
2406            nwarmup: 20,
2407            n_chains: 2,
2408            target_accept: 0.8,
2409            seed: 654,
2410        };
2411
2412        let err = match run_nuts_sampling_flattened_family(
2413            LikelihoodSpec {
2414                response: ResponseFamily::Binomial,
2415                link: InverseLink::Standard(StandardLink::Probit),
2416            },
2417            FamilyNutsInputs::Glm(GlmFlatInputs {
2418                x: x.view(),
2419                y: y.view(),
2420                weights: w.view(),
2421                penalty_matrix: penalty.view(),
2422                mode: mode.view(),
2423                hessian: non_spdhessian.view(),
2424                gamma_shape: None,
2425                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2426                firth_bias_reduction: false,
2427                offset: None,
2428            }),
2429            &cfg,
2430        ) {
2431            Ok(_) => panic!("non-SPD Hessian should fail after probit routes to the NUTS path"),
2432            Err(err) => err,
2433        };
2434
2435        assert!(
2436            err.contains("Hessian Cholesky decomposition failed"),
2437            "unexpected error: {err}"
2438        );
2439    }
2440
2441    #[test]
2442    fn family_dispatch_rejects_nonbinomial_firth_family() {
2443        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2444        let y = array![1.0, 2.0, 0.0, 3.0];
2445        let w = array![1.0, 1.0, 1.0, 1.0];
2446        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2447        let mode = array![0.0, 0.0];
2448        let hessian = array![[1.5, 0.1], [0.1, 1.2]];
2449        let cfg = NutsConfig {
2450            n_samples: 20,
2451            nwarmup: 20,
2452            n_chains: 2,
2453            target_accept: 0.8,
2454            seed: 111,
2455        };
2456
2457        let err = match run_nuts_sampling_flattened_family(
2458            LikelihoodSpec {
2459                response: ResponseFamily::Poisson,
2460                link: InverseLink::Standard(StandardLink::Log),
2461            },
2462            FamilyNutsInputs::Glm(GlmFlatInputs {
2463                x: x.view(),
2464                y: y.view(),
2465                weights: w.view(),
2466                penalty_matrix: penalty.view(),
2467                mode: mode.view(),
2468                hessian: hessian.view(),
2469                gamma_shape: None,
2470                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2471                firth_bias_reduction: true,
2472                offset: None,
2473            }),
2474            &cfg,
2475        ) {
2476            Ok(_) => panic!("Poisson Firth should be rejected explicitly"),
2477            Err(err) => err,
2478        };
2479
2480        assert!(
2481            err.contains(
2482                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet"
2483            ),
2484            "unexpected error: {err}"
2485        );
2486    }
2487
2488    #[test]
2489    fn run_nuts_sampling_rejects_invalid_target_accept() {
2490        let x = array![[1.0], [1.0], [1.0]];
2491        let y = array![0.5, -0.5, 1.0];
2492        let weights = array![1.0, 1.0, 1.0];
2493        let penalty = array![[0.25]];
2494        let mode = array![0.0];
2495        let hessian = array![[1.25]];
2496        let cfg = NutsConfig {
2497            n_samples: 10,
2498            nwarmup: 10,
2499            n_chains: 1,
2500            target_accept: 1.0,
2501            seed: 222,
2502        };
2503
2504        let err = super::run_nuts_sampling(
2505            x.view(),
2506            y.view(),
2507            weights.view(),
2508            penalty.view(),
2509            mode.view(),
2510            hessian.view(),
2511            NutsFamily::Gaussian,
2512            1.0,
2513            gam_solve::estimate::Dispersion::Known(1.0),
2514            false,
2515            None,
2516            &cfg,
2517        )
2518        .expect_err("invalid target_accept should be rejected before sampling");
2519
2520        assert!(
2521            err.contains("target_accept must be finite and lie in (0, 1)"),
2522            "unexpected error: {err}"
2523        );
2524    }
2525
2526    #[test]
2527    fn run_nuts_sampling_rejects_zero_or_too_few_samples() {
2528        // Issue #399: `samples=0` (and `samples` in {1, 2, 3}) reached the
2529        // engine and panicked across the FFI boundary in `general-mcmc`'s
2530        // `.expect(...)` (empty stack / "split R-hat and ESS require at least 2
2531        // split chains and 2 draws per split chain"). The up-front guard must
2532        // reject anything below the split-R-hat-defined minimum of 4 draws with
2533        // a clean typed error *before* the sampler is constructed.
2534        let x = array![[1.0], [1.0], [1.0]];
2535        let y = array![0.5, -0.5, 1.0];
2536        let weights = array![1.0, 1.0, 1.0];
2537        let penalty = array![[0.25]];
2538        let mode = array![0.0];
2539        let hessian = array![[1.25]];
2540
2541        for bad_samples in [0usize, 1, 2, 3] {
2542            let cfg = NutsConfig {
2543                n_samples: bad_samples,
2544                nwarmup: 10,
2545                n_chains: 2,
2546                target_accept: 0.8,
2547                seed: 222,
2548            };
2549
2550            let err = super::run_nuts_sampling(
2551                x.view(),
2552                y.view(),
2553                weights.view(),
2554                penalty.view(),
2555                mode.view(),
2556                hessian.view(),
2557                NutsFamily::Gaussian,
2558                1.0,
2559                gam_solve::estimate::Dispersion::Known(1.0),
2560                false,
2561                None,
2562                &cfg,
2563            )
2564            .expect_err("too-few samples must be rejected before sampling");
2565
2566            assert!(
2567                err.contains("n_samples must be >= 4"),
2568                "n_samples={bad_samples} gave unexpected error: {err}"
2569            );
2570        }
2571    }
2572
2573    #[test]
2574    fn polya_gamma_gibbs_rejects_degenerate_counts_but_accepts_single_chain() {
2575        // Issue #399 (missed path): the canonical unit-weight Bernoulli-logit
2576        // GAM auto-selects the hand-rolled Pólya-Gamma Gibbs sampler, NOT the
2577        // general-mcmc NUTS engine. Pre-fix that path never validated
2578        // n_samples/n_chains, so `chains=0` / `samples=0` silently returned a
2579        // degenerate empty `(0, p)` posterior instead of the typed error the
2580        // NUTS path raised — a divergent contract on one public API. Assert PG
2581        // now rejects the degenerate counts up front, and (mirroring NUTS)
2582        // still accepts a single chain.
2583        let x = array![[1.0], [1.0], [1.0], [1.0]];
2584        let y = array![1.0, 0.0, 1.0, 0.0];
2585        let weights = array![1.0, 1.0, 1.0, 1.0];
2586        let penalty = array![[0.25]];
2587        let mode = array![0.0];
2588
2589        let zero_chain_cfg = NutsConfig {
2590            n_samples: 20,
2591            nwarmup: 10,
2592            n_chains: 0,
2593            target_accept: 0.8,
2594            seed: 7,
2595        };
2596        let err = super::run_logit_polya_gamma_gibbs(
2597            x.view(),
2598            y.view(),
2599            weights.view(),
2600            penalty.view(),
2601            mode.view(),
2602            &zero_chain_cfg,
2603        )
2604        .expect_err("PG Gibbs must reject zero chains up front, not return an empty posterior");
2605        assert!(
2606            err.contains("n_chains must be >= 1"),
2607            "PG n_chains=0 gave unexpected error: {err}"
2608        );
2609
2610        let zero_sample_cfg = NutsConfig {
2611            n_samples: 0,
2612            nwarmup: 10,
2613            n_chains: 2,
2614            target_accept: 0.8,
2615            seed: 7,
2616        };
2617        let err = super::run_logit_polya_gamma_gibbs(
2618            x.view(),
2619            y.view(),
2620            weights.view(),
2621            penalty.view(),
2622            mode.view(),
2623            &zero_sample_cfg,
2624        )
2625        .expect_err("PG Gibbs must reject zero samples up front, not return an empty posterior");
2626        assert!(
2627            err.contains("n_samples must be >= 4"),
2628            "PG n_samples=0 gave unexpected error: {err}"
2629        );
2630
2631        let single_chain_cfg = NutsConfig {
2632            n_samples: 20,
2633            nwarmup: 10,
2634            n_chains: 1,
2635            target_accept: 0.8,
2636            seed: 7,
2637        };
2638        let result = super::run_logit_polya_gamma_gibbs(
2639            x.view(),
2640            y.view(),
2641            weights.view(),
2642            penalty.view(),
2643            mode.view(),
2644            &single_chain_cfg,
2645        )
2646        .expect("PG Gibbs must accept a single chain and return draws");
2647        assert_eq!(
2648            result.samples.nrows(),
2649            20,
2650            "single-chain PG run should return all 20 requested draws"
2651        );
2652    }
2653
2654    #[test]
2655    fn run_nuts_sampling_rejects_zero_chains_but_accepts_single_chain() {
2656        // Issue #399: only `chains=0` is degenerate — it produces an empty
2657        // initial-position vector and panics in `ndarray::stack`, so it must be
2658        // rejected up front with a typed error.
2659        //
2660        // A *single* chain, by contrast, is a supported, tested configuration
2661        // (`tests/test_sample_seed_is_reproducible.py`,
2662        // `tests/test_posterior_save_no_extension_roundtrip.py`,
2663        // `tests/test_penalty_sampling_survival_diagnostics_regressions.py` all
2664        // sample with `chains=1`): the engine splits each chain in half, so one
2665        // chain still yields the two split-chains the R-hat path needs, and
2666        // `compute_split_rhat_and_ess` early-returns gracefully for
2667        // `n_chains < 2`. The original #399 fix wrongly raised the floor to 2
2668        // and regressed those tests; this asserts `chains=1` *returns draws*.
2669        let x = array![[1.0], [1.0], [1.0]];
2670        let y = array![0.5, -0.5, 1.0];
2671        let weights = array![1.0, 1.0, 1.0];
2672        let penalty = array![[0.25]];
2673        let mode = array![0.0];
2674        let hessian = array![[1.25]];
2675
2676        let zero_chain_cfg = NutsConfig {
2677            n_samples: 50,
2678            nwarmup: 10,
2679            n_chains: 0,
2680            target_accept: 0.8,
2681            seed: 222,
2682        };
2683        let err = super::run_nuts_sampling(
2684            x.view(),
2685            y.view(),
2686            weights.view(),
2687            penalty.view(),
2688            mode.view(),
2689            hessian.view(),
2690            NutsFamily::Gaussian,
2691            1.0,
2692            gam_solve::estimate::Dispersion::Known(1.0),
2693            false,
2694            None,
2695            &zero_chain_cfg,
2696        )
2697        .expect_err("zero chains must be rejected before sampling");
2698        assert!(
2699            err.contains("n_chains must be >= 1"),
2700            "n_chains=0 gave unexpected error: {err}"
2701        );
2702
2703        let single_chain_cfg = NutsConfig {
2704            n_samples: 50,
2705            nwarmup: 10,
2706            n_chains: 1,
2707            target_accept: 0.8,
2708            seed: 222,
2709        };
2710        let result = super::run_nuts_sampling(
2711            x.view(),
2712            y.view(),
2713            weights.view(),
2714            penalty.view(),
2715            mode.view(),
2716            hessian.view(),
2717            NutsFamily::Gaussian,
2718            1.0,
2719            gam_solve::estimate::Dispersion::Known(1.0),
2720            false,
2721            None,
2722            &single_chain_cfg,
2723        )
2724        .expect("a single chain is a supported configuration and must return draws");
2725        assert_eq!(
2726            result.samples.nrows(),
2727            50,
2728            "single-chain run should return all 50 requested draws"
2729        );
2730    }
2731
2732    #[test]
2733    fn joint_hmc_boundary_rejects_nonbinomial_firth_family() {
2734        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2735        let y = array![1.0, 2.0, 0.0, 3.0];
2736        let w = array![1.0, 1.0, 1.0, 1.0];
2737        let hessian = array![[1.5, 0.1], [0.1, 1.2]];
2738        let penalty_root = array![[0.4, 0.0], [0.0, 0.6]];
2739        let mode = array![0.0, 0.0];
2740        let rho_mode = array![0.0];
2741        let cfg = NutsConfig {
2742            n_samples: 20,
2743            nwarmup: 20,
2744            n_chains: 2,
2745            target_accept: 0.8,
2746            seed: 111,
2747        };
2748
2749        let inputs = JointBetaRhoInputs {
2750            x: x.view(),
2751            y: y.view(),
2752            weights: w.view(),
2753            likelihood: LikelihoodSpec {
2754                response: ResponseFamily::Poisson,
2755                link: InverseLink::Standard(StandardLink::Log),
2756            },
2757            gamma_shape: None,
2758            mode: mode.view(),
2759            hessian: hessian.view(),
2760            penalty_roots: vec![CanonicalPenalty::from_dense_root(
2761                penalty_root.clone(),
2762                penalty_root.ncols(),
2763            )],
2764            rho_mode: rho_mode.view(),
2765            rho_prior: RhoPrior::default(),
2766            firth_bias_reduction: true,
2767            trigger_skewness: 0.75,
2768        };
2769
2770        let err = match run_joint_beta_rho_sampling(&inputs, &cfg) {
2771            Ok(_) => panic!("Poisson joint HMC Firth should be rejected explicitly"),
2772            Err(err) => err,
2773        };
2774
2775        assert!(
2776            err.contains(
2777                "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet"
2778            ),
2779            "unexpected error: {err}"
2780        );
2781    }
2782
2783    #[test]
2784    fn joint_hmc_uses_combined_penalty_logdet_for_overlapping_penalties() {
2785        let x = array![[0.0, 0.0]];
2786        let y = array![0.0];
2787        let w = array![0.0];
2788        let mode = array![0.0, 0.0];
2789        let hessian = array![[1.0, 0.0], [0.0, 1.0]];
2790        let rho_mode = array![0.0, 0.0];
2791        let penalty_1 = array![[1.0, 0.0], [0.0, 1.0]];
2792        let penalty_2 = array![[2.0_f64.sqrt(), 0.0], [0.0, 1.0]];
2793        let target = JointBetaRhoPosterior::new(
2794            x.view(),
2795            y.view(),
2796            w.view(),
2797            mode.view(),
2798            hessian.view(),
2799            vec![
2800                CanonicalPenalty::from_dense_root(penalty_1, 2),
2801                CanonicalPenalty::from_dense_root(penalty_2, 2),
2802            ],
2803            rho_mode.view(),
2804            LikelihoodSpec {
2805                response: ResponseFamily::Gaussian,
2806                link: InverseLink::Standard(StandardLink::Identity),
2807            },
2808            None,
2809            RhoPrior::Flat,
2810            false,
2811        )
2812        .expect("joint target");
2813
2814        let params = array![0.0, 0.0, 0.0, 0.0];
2815        let (_, grad) = target.compute_joint_logp_and_grad(&params);
2816        assert!(
2817            (grad[2] - 5.0 / 12.0).abs() < 1.0e-10,
2818            "expected overlapping-penalty gradient 5/12, got {}",
2819            grad[2]
2820        );
2821        assert!(
2822            (grad[3] - 7.0 / 12.0).abs() < 1.0e-10,
2823            "expected overlapping-penalty gradient 7/12, got {}",
2824            grad[3]
2825        );
2826    }
2827
2828    #[test]
2829    fn joint_hmc_target_does_not_depend_on_rho_mode_when_prior_is_fixed() {
2830        let x = array![[0.0]];
2831        let y = array![0.0];
2832        let w = array![0.0];
2833        let mode = array![0.0];
2834        let hessian = array![[1.0]];
2835        let penalty = CanonicalPenalty::from_dense_root(array![[1.0]], 1);
2836        let prior = RhoPrior::Normal {
2837            mean: 0.25,
2838            sd: 1.7,
2839        };
2840
2841        let target_a = JointBetaRhoPosterior::new(
2842            x.view(),
2843            y.view(),
2844            w.view(),
2845            mode.view(),
2846            hessian.view(),
2847            vec![penalty.clone()],
2848            array![0.0].view(),
2849            LikelihoodSpec {
2850                response: ResponseFamily::Gaussian,
2851                link: InverseLink::Standard(StandardLink::Identity),
2852            },
2853            None,
2854            prior.clone(),
2855            false,
2856        )
2857        .expect("target a");
2858        let target_b = JointBetaRhoPosterior::new(
2859            x.view(),
2860            y.view(),
2861            w.view(),
2862            mode.view(),
2863            hessian.view(),
2864            vec![penalty],
2865            array![2.5].view(),
2866            LikelihoodSpec {
2867                response: ResponseFamily::Gaussian,
2868                link: InverseLink::Standard(StandardLink::Identity),
2869            },
2870            None,
2871            prior,
2872            false,
2873        )
2874        .expect("target b");
2875
2876        let params = array![0.0, -0.4];
2877        let (lp_a, grad_a) = target_a.compute_joint_logp_and_grad(&params);
2878        let (lp_b, grad_b) = target_b.compute_joint_logp_and_grad(&params);
2879        assert!((lp_a - lp_b).abs() < 1.0e-12);
2880        for i in 0..grad_a.len() {
2881            assert!(
2882                (grad_a[i] - grad_b[i]).abs() < 1.0e-12,
2883                "rho_mode leaked into target gradient at {}: {} vs {}",
2884                i,
2885                grad_a[i],
2886                grad_b[i]
2887            );
2888        }
2889    }
2890
2891    #[test]
2892    fn joint_hmc_binomial_sas_uses_runtime_link_state() {
2893        let x = array![[1.0], [1.0]];
2894        let y = array![1.0, 0.0];
2895        let weights = array![1.0, 1.0];
2896        let eta = array![0.3, -0.2];
2897        let sas_state =
2898            gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
2899                initial_epsilon: 0.4,
2900                initial_log_delta: -0.2,
2901            })
2902            .expect("sas state");
2903        let data = SharedData {
2904            x: Arc::new(x),
2905            y: Arc::new(y),
2906            weights: Arc::new(weights),
2907            mode: Arc::new(Array1::zeros(1)),
2908            offset: None,
2909            gamma_shape: 1.0,
2910            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2911            n_samples: 2,
2912            dim: 1,
2913        };
2914
2915        let (ll_sas, _) = joint_family_logp_and_grad(
2916            &LikelihoodSpec {
2917                response: ResponseFamily::Binomial,
2918                link: InverseLink::Sas(sas_state),
2919            },
2920            &data,
2921            &eta,
2922        )
2923        .expect("sas joint logp");
2924        let (ll_logit, _) = joint_family_logp_and_grad(
2925            &LikelihoodSpec {
2926                response: ResponseFamily::Binomial,
2927                link: InverseLink::Standard(StandardLink::Logit),
2928            },
2929            &data,
2930            &eta,
2931        )
2932        .expect("logit joint logp");
2933
2934        assert!(
2935            (ll_sas - ll_logit).abs() > 1.0e-6,
2936            "adaptive SAS link should not collapse to the logit likelihood"
2937        );
2938    }
2939
2940    #[test]
2941    fn directional_cubic_diagnostic_is_rotation_invariant_for_hessian_eigenvectors() {
2942        let x = array![[1.0, 0.5], [-0.3, 1.4], [0.8, -1.1]];
2943        let c = array![0.7, -0.5, 0.2];
2944        let h = array![[4.0, 0.0], [0.0, 1.0]];
2945        let theta = std::f64::consts::FRAC_PI_4;
2946        let q = array![[theta.cos(), -theta.sin()], [theta.sin(), theta.cos()],];
2947        let x_rot = x.dot(&q);
2948        let h_rot = q.t().dot(&h).dot(&q);
2949
2950        let (base_max, base_vals) = laplace_directional_cubic_diagnostic(
2951            &h,
2952            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
2953            &c,
2954            true,
2955        )
2956        .expect("base diagnostic");
2957        let (rot_max, rot_vals) = laplace_directional_cubic_diagnostic(
2958            &h_rot,
2959            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x_rot)),
2960            &c,
2961            true,
2962        )
2963        .expect("rotated diagnostic");
2964
2965        let mut base_abs: Vec<f64> = base_vals.iter().map(|v| v.abs()).collect();
2966        let mut rot_abs: Vec<f64> = rot_vals.iter().map(|v| v.abs()).collect();
2967        base_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
2968        rot_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
2969
2970        assert!((base_max - rot_max).abs() < 1.0e-10);
2971        for i in 0..base_abs.len() {
2972            assert!(
2973                (base_abs[i] - rot_abs[i]).abs() < 1.0e-10,
2974                "directional diagnostic changed under rotation at {}: {} vs {}",
2975                i,
2976                base_abs[i],
2977                rot_abs[i]
2978            );
2979        }
2980    }
2981
2982    /// Verify that joint HMC and REML compute identical penalty logdet
2983    /// derivatives for the same penalty system. This catches any divergence
2984    /// between the two code paths.
2985    #[test]
2986    fn joint_hmc_penalty_logdet_agrees_with_reml_path() {
2987        use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
2988
2989        // Two overlapping 3x3 penalties with non-trivial lambdas.
2990        let root_1 = array![[1.0, 0.5, 0.0], [0.0, 0.8, 0.3]];
2991        let root_2 = array![[0.0, 0.7, 0.0], [0.0, 0.0, 1.2]];
2992        let cp1 = CanonicalPenalty::from_dense_root(root_1, 3);
2993        let cp2 = CanonicalPenalty::from_dense_root(root_2, 3);
2994        let lambdas = [2.5_f64, 0.8];
2995        let penalties = [cp1.clone(), cp2.clone()];
2996
2997        // REML path: PenaltyPseudologdet directly.
2998        let pld =
2999            PenaltyPseudologdet::from_penalties(&penalties, &lambdas, 0.0, 3).expect("reml pld");
3000        let reml_value = pld.value();
3001        let (reml_d1, reml_d2) = pld.rho_derivatives_from_penalties(&penalties, &lambdas);
3002
3003        // Joint HMC path: build a JointBetaRhoPosterior and extract the
3004        // penalty logdet contribution. We isolate it by using zero data
3005        // (so likelihood = 0, penalty quadratic = 0) and Flat rho prior.
3006        let x = Array2::<f64>::zeros((1, 3));
3007        let y = array![0.0];
3008        let w = array![0.0];
3009        let mode = Array1::<f64>::zeros(3);
3010        let hessian = Array2::<f64>::eye(3);
3011        let rho = Array1::from_vec(lambdas.iter().map(|l| l.ln()).collect());
3012        let target = JointBetaRhoPosterior::new(
3013            x.view(),
3014            y.view(),
3015            w.view(),
3016            mode.view(),
3017            hessian.view(),
3018            vec![cp1, cp2],
3019            rho.view(),
3020            LikelihoodSpec {
3021                response: ResponseFamily::Gaussian,
3022                link: InverseLink::Standard(StandardLink::Identity),
3023            },
3024            None,
3025            RhoPrior::Flat,
3026            false,
3027        )
3028        .expect("joint target");
3029
3030        // Evaluate at beta=0, rho=ln(lambdas).
3031        let mut params = Array1::<f64>::zeros(3 + 2);
3032        params[3] = rho[0];
3033        params[4] = rho[1];
3034        let (logp, grad) = target.compute_joint_logp_and_grad(&params);
3035
3036        // logp should be 0.5 * reml_value (likelihood=0, prior=0, quadratic=0).
3037        assert!(
3038            (logp - 0.5 * reml_value).abs() < 1.0e-8,
3039            "joint HMC logdet value {} vs REML 0.5*{} = {}",
3040            logp,
3041            reml_value,
3042            0.5 * reml_value,
3043        );
3044
3045        // grad[3..5] should be 0.5 * reml_d1.
3046        for k in 0..2 {
3047            assert!(
3048                (grad[3 + k] - 0.5 * reml_d1[k]).abs() < 1.0e-8,
3049                "joint HMC logdet gradient[{}] = {} vs REML 0.5*{} = {}",
3050                k,
3051                grad[3 + k],
3052                reml_d1[k],
3053                0.5 * reml_d1[k],
3054            );
3055        }
3056
3057        // Sanity: second derivatives are available from REML but not directly
3058        // from a single HMC gradient call; just verify they're symmetric.
3059        assert!(
3060            (reml_d2[[0, 1]] - reml_d2[[1, 0]]).abs() < 1.0e-12,
3061            "REML penalty logdet Hessian not symmetric"
3062        );
3063    }
3064
3065    /// Verify the family-gating invariant: every LikelihoodSpec that
3066    /// joint_family_logp_and_grad accepts produces a result (not an error
3067    /// about missing implementation). Every family it rejects returns an
3068    /// explicit error. No family is silently remapped to a different one.
3069    #[test]
3070    fn joint_hmc_family_gating_never_remaps() {
3071        let data = SharedData {
3072            x: Arc::new(array![[1.0], [1.0]]),
3073            y: Arc::new(array![1.0, 0.0]),
3074            weights: Arc::new(array![1.0, 1.0]),
3075            mode: Arc::new(Array1::zeros(1)),
3076            offset: None,
3077            gamma_shape: 1.0,
3078            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
3079            n_samples: 2,
3080            dim: 1,
3081        };
3082        let eta = array![0.1, -0.1];
3083
3084        // These families must succeed with their own inverse link.
3085        let accepted = [
3086            LikelihoodSpec {
3087                response: ResponseFamily::Binomial,
3088                link: InverseLink::Standard(StandardLink::Logit),
3089            },
3090            LikelihoodSpec {
3091                response: ResponseFamily::Binomial,
3092                link: InverseLink::Standard(StandardLink::Probit),
3093            },
3094            LikelihoodSpec {
3095                response: ResponseFamily::Binomial,
3096                link: InverseLink::Standard(StandardLink::CLogLog),
3097            },
3098            LikelihoodSpec {
3099                response: ResponseFamily::Gaussian,
3100                link: InverseLink::Standard(StandardLink::Identity),
3101            },
3102            LikelihoodSpec {
3103                response: ResponseFamily::Poisson,
3104                link: InverseLink::Standard(StandardLink::Log),
3105            },
3106            LikelihoodSpec {
3107                response: ResponseFamily::Gamma,
3108                link: InverseLink::Standard(StandardLink::Log),
3109            },
3110        ];
3111        for spec in &accepted {
3112            let result = joint_family_logp_and_grad(spec, &data, &eta);
3113            assert!(
3114                result.is_ok(),
3115                "spec {:?} should be accepted but got error: {:?}",
3116                spec,
3117                result.err(),
3118            );
3119        }
3120
3121        // SAS/BetaLogistic/Mixture must succeed with their real link state,
3122        // NOT be remapped to logit.
3123        let sas_state =
3124            gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
3125                initial_epsilon: 0.0,
3126                initial_log_delta: 0.0,
3127            })
3128            .expect("sas state");
3129        let adaptive = [
3130            LikelihoodSpec {
3131                response: ResponseFamily::Binomial,
3132                link: InverseLink::Sas(sas_state),
3133            },
3134            LikelihoodSpec {
3135                response: ResponseFamily::Binomial,
3136                link: InverseLink::BetaLogistic(
3137                    gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
3138                        initial_epsilon: 0.0,
3139                        initial_log_delta: 0.0,
3140                    })
3141                    .expect("bl state"),
3142                ),
3143            },
3144        ];
3145        for spec in &adaptive {
3146            let result = joint_family_logp_and_grad(spec, &data, &eta);
3147            assert!(
3148                result.is_ok(),
3149                "adaptive spec {:?} should be accepted with its real link",
3150                spec,
3151            );
3152        }
3153
3154        // RoystonParmar must be explicitly rejected (not silently remapped).
3155        let rp_result = joint_family_logp_and_grad(
3156            &LikelihoodSpec {
3157                response: ResponseFamily::RoystonParmar,
3158                link: InverseLink::Standard(StandardLink::Logit),
3159            },
3160            &data,
3161            &eta,
3162        );
3163        assert!(
3164            rp_result.is_err(),
3165            "RoystonParmar should be rejected, not silently accepted"
3166        );
3167    }
3168
3169    /// The power-iteration refinement should find non-Gaussianity at least
3170    /// as large as the eigenvector-only pass (it's a supremum search).
3171    #[test]
3172    fn directional_cubic_power_iteration_finds_larger_or_equal_skewness() {
3173        // Construct a design where the maximum |gamma| occurs off-axis.
3174        // A single row with asymmetric structure makes the cubic form
3175        // peak between eigenvectors.
3176        let x = array![
3177            [2.0, 1.0],
3178            [-1.0, 2.0],
3179            [0.5, -0.5],
3180            [1.5, 0.3],
3181            [-0.8, 1.7],
3182        ];
3183        let c = array![1.0, -0.5, 0.3, -0.7, 0.4];
3184        let h = array![[3.0, 1.0], [1.0, 2.0]];
3185
3186        let (max_val, eigenvector_vals) = laplace_directional_cubic_diagnostic(
3187            &h,
3188            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
3189            &c,
3190            true,
3191        )
3192        .expect("diagnostic");
3193
3194        // max_val should be >= max of eigenvector-only values.
3195        let eig_max = eigenvector_vals
3196            .iter()
3197            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
3198        assert!(
3199            max_val >= eig_max - 1.0e-12,
3200            "power iteration result {} should be >= eigenvector max {}",
3201            max_val,
3202            eig_max,
3203        );
3204    }
3205
3206    #[test]
3207    fn laplace_trustworthiness_is_block_local_and_threshold_shrinks_with_n() {
3208        // Two directions: one nearly Gaussian (tiny skewness), one strongly
3209        // skewed. The adaptive verdict must flag ONLY the skewed direction —
3210        // this is the block-local behavior #784 requires (keep cheap Laplace
3211        // where the Gaussian summary holds, correct only the curvature-heavy
3212        // block).
3213        let skew = array![0.01, 0.9];
3214
3215        // At a modest effective sample size the skewed direction dominates the
3216        // Laplace floor and must be flagged; the near-Gaussian one must not.
3217        let verdict = laplace_trustworthiness_from_skewness(&skew, 100.0);
3218        assert_eq!(
3219            verdict.untrustworthy_directions,
3220            vec![1],
3221            "only the strongly-skewed direction should be flagged (block-local)",
3222        );
3223        assert!(verdict.fallback_required());
3224        assert!((verdict.max_abs_skewness - 0.9).abs() < 1e-12);
3225
3226        // The threshold must SHRINK as n grows (Laplace gets stricter): a
3227        // direction tolerated at small n becomes untrustworthy at large n,
3228        // because the Gaussian floor it must beat is O(1/n).
3229        let t_small = laplace_skewness_threshold(25.0);
3230        let t_large = laplace_skewness_threshold(10_000.0);
3231        assert!(
3232            t_large < t_small,
3233            "validity threshold must tighten with sample size: {t_large} !< {t_small}",
3234        );
3235
3236        // Degenerate / empty curvature support => everything trustworthy
3237        // (nothing for the Gaussian summary to be wrong about).
3238        let none = laplace_trustworthiness_from_skewness(&skew, 0.0);
3239        assert!(!none.fallback_required());
3240        assert!(none.threshold.is_infinite());
3241    }
3242
3243    /// Synthetic block-excess oracle: an anharmonicity `ΔF(t) = a·Σ_k t_k⁴`
3244    /// whose per-direction strength carries unit ρ-sensitivity, so
3245    /// `∂ΔF/∂ρ_k = a·t_k⁴`. `a = 0` is a pure Gaussian block (exactly zero
3246    /// excess and zero ρ-gradient — the consistency anchor); `a > 0` is the
3247    /// quartic correction oracle the importance sampler is checked against.
3248    struct AnharmonicBlock {
3249        lambdas: Array1<f64>,
3250        a: f64,
3251    }
3252    impl super::BlockExcessTarget for AnharmonicBlock {
3253        fn block_dim(&self) -> usize {
3254            self.lambdas.len()
3255        }
3256        fn rho_dim(&self) -> usize {
3257            self.lambdas.len()
3258        }
3259        fn block_curvatures(&self) -> &Array1<f64> {
3260            &self.lambdas
3261        }
3262        fn excess(&self, t: &Array1<f64>) -> f64 {
3263            self.a * t.iter().map(|&x| x.powi(4)).sum::<f64>()
3264        }
3265        fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3266            t.mapv(|x| self.a * x.powi(4))
3267        }
3268        fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3269            // The synthetic oracle has no observation rows: its ΔF carries no
3270            // deviance channel, so the per-row score moment is empty and the
3271            // (b)–(d) channel assembly contracts against nothing.
3272            assert_eq!(t.len(), self.block_dim(), "displacement dim mismatch");
3273            Array1::zeros(0)
3274        }
3275        fn base_neg_score(&self) -> Array1<f64> {
3276            Array1::zeros(0)
3277        }
3278    }
3279
3280    #[test]
3281    fn block_sampled_marginal_is_zero_for_gaussian_block() {
3282        // A purely Gaussian block has ΔF ≡ 0, so the sampled correction (the
3283        // log-ratio of true to Laplace block free energy) must be exactly 0,
3284        // with a zero ρ-gradient. This is the consistency anchor: where the
3285        // Gaussian summary holds, the fallback is a no-op.
3286        let target = AnharmonicBlock {
3287            lambdas: array![2.0, 0.5],
3288            a: 0.0,
3289        };
3290        let out = super::block_sampled_marginal_correction(&target).expect("correction");
3291        assert!(
3292            out.value.abs() < 1e-12,
3293            "Gaussian block value {}",
3294            out.value
3295        );
3296        assert!(out.rho_gradient.iter().all(|&g| g.abs() < 1e-12));
3297        assert!(out.n_draws > 0);
3298    }
3299
3300    #[test]
3301    fn block_sampled_marginal_recovers_analytic_quartic_correction() {
3302        // 1-D block with a quartic excess ΔF(t) = a t⁴ (a small positive
3303        // anharmonicity). Then exp(Δ_b) = E_{t~N(0,1/λ)}[exp(−a t⁴)], a known
3304        // 1-D integral the IS estimator must recover. We check the sampled Δ_b
3305        // matches a high-accuracy deterministic quadrature of the same
3306        // expectation, and that Δ_b < 0 (an added quartic penalty makes the
3307        // true block mass *smaller* than the Gaussian's).
3308        let lambda = 3.0_f64;
3309        let a = 0.05_f64;
3310        let target = AnharmonicBlock {
3311            lambdas: array![lambda],
3312            a,
3313        };
3314        let out = super::block_sampled_marginal_correction(&target).expect("correction");
3315
3316        // Deterministic reference: Δ_b = log E_{t~N(0,1/λ)}[exp(−a t⁴)] via a
3317        // fine trapezoid rule over the Gaussian density.
3318        let sigma = (1.0 / lambda).sqrt();
3319        let steps = 20_001;
3320        let lo = -8.0 * sigma;
3321        let hi = 8.0 * sigma;
3322        let h = (hi - lo) / (steps as f64 - 1.0);
3323        let mut integral = 0.0_f64;
3324        for i in 0..steps {
3325            let tt = lo + h * i as f64;
3326            let gauss = (-(tt * tt) / (2.0 * sigma * sigma)).exp()
3327                / (sigma * (2.0 * std::f64::consts::PI).sqrt());
3328            let w = if i == 0 || i == steps - 1 { 0.5 } else { 1.0 };
3329            integral += w * gauss * (-a * tt.powi(4)).exp() * h;
3330        }
3331        let reference = integral.ln();
3332        assert!(
3333            (out.value - reference).abs() < 5e-3,
3334            "sampled Δ_b {} vs reference {}",
3335            out.value,
3336            reference,
3337        );
3338        assert!(out.value < 0.0, "quartic penalty must shrink block mass");
3339    }
3340
3341    /// A block target whose excess and per-row score are driven by real design
3342    /// matvecs `s = X·(V_b·t)` — the SAME structure as the production
3343    /// `Gam784BlockTarget` — so it can compute those matvecs either serially
3344    /// (one `fast_av` per draw) or batched (one GEMM over all draws), toggled by
3345    /// `batched`. The two must yield a bit-for-bit (to FP-reassociation
3346    /// tolerance) identical correction: that is exactly the #1082 batching
3347    /// contract — GEMM changes HOW the matvec is computed, never WHAT.
3348    struct MatvecBlock {
3349        lambdas: Array1<f64>,
3350        x: Array2<f64>,
3351        v_b: Array2<f64>,
3352        y: Array1<f64>,
3353        batched: bool,
3354    }
3355    impl MatvecBlock {
3356        fn s_of(&self, t: &Array1<f64>) -> Array1<f64> {
3357            let delta = self.v_b.dot(t);
3358            gam_linalg::faer_ndarray::fast_av(&self.x, &delta)
3359        }
3360        // A smooth, finite, family-like excess + per-row score built from `s`.
3361        fn excess_and_ngs(&self, s: &Array1<f64>) -> (f64, Array1<f64>) {
3362            let mut excess = 0.0;
3363            let mut ngs = Array1::<f64>::zeros(s.len());
3364            for i in 0..s.len() {
3365                let mu = (self.y[i] + s[i]).tanh();
3366                excess += 0.5 * s[i] * s[i] - 0.1 * mu;
3367                ngs[i] = mu - self.y[i];
3368            }
3369            (excess, ngs)
3370        }
3371    }
3372    impl super::BlockExcessTarget for MatvecBlock {
3373        fn block_dim(&self) -> usize {
3374            self.lambdas.len()
3375        }
3376        fn rho_dim(&self) -> usize {
3377            self.lambdas.len()
3378        }
3379        fn block_curvatures(&self) -> &Array1<f64> {
3380            &self.lambdas
3381        }
3382        fn excess(&self, t: &Array1<f64>) -> f64 {
3383            self.excess_and_ngs(&self.s_of(t)).0
3384        }
3385        fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3386            t.mapv(|x| 0.01 * x)
3387        }
3388        fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3389            self.excess_and_ngs(&self.s_of(t)).1
3390        }
3391        fn base_neg_score(&self) -> Array1<f64> {
3392            self.excess_and_ngs(&self.s_of(&Array1::zeros(self.block_dim())))
3393                .1
3394        }
3395        fn excess_with_displaced_neg_score_batch(
3396            &self,
3397            draws: &Array2<f64>,
3398        ) -> Vec<(f64, Option<Array1<f64>>)> {
3399            if !self.batched {
3400                // Serial reference: per-column, exactly the default path.
3401                let mut out = Vec::with_capacity(draws.ncols());
3402                let mut t = Array1::<f64>::zeros(draws.nrows());
3403                for s in 0..draws.ncols() {
3404                    t.assign(&draws.column(s));
3405                    out.push(self.excess_with_displaced_neg_score(&t));
3406                }
3407                return out;
3408            }
3409            // Batched: Δ = V_b·T then S = X·Δ as two GEMMs, then per-column.
3410            let delta_all = gam_linalg::faer_ndarray::fast_ab(&self.v_b, draws);
3411            let s_all = gam_linalg::faer_ndarray::fast_ab(&self.x, &delta_all);
3412            (0..draws.ncols())
3413                .map(|c| {
3414                    let (e, ngs) = self.excess_and_ngs(&s_all.column(c).to_owned());
3415                    if e.is_finite() {
3416                        (e, Some(ngs))
3417                    } else {
3418                        (e, None)
3419                    }
3420                })
3421                .collect()
3422        }
3423    }
3424
3425    #[test]
3426    fn block_sampled_marginal_batched_matches_serial_matvec() {
3427        // Real design / block-frame matvecs, large enough that the GEMM path is
3428        // actually taken (n, p ≥ faer threshold). The batched override must give
3429        // the same correction value, ρ-gradient, and moments as the serial path.
3430        let n = 80usize;
3431        let p = 40usize;
3432        let m = 3usize;
3433        let mut x = Array2::<f64>::zeros((n, p));
3434        for i in 0..n {
3435            for j in 0..p {
3436                x[(i, j)] = ((i * 7 + j * 13) % 11) as f64 * 0.05 - 0.25;
3437            }
3438        }
3439        let mut v_b = Array2::<f64>::zeros((p, m));
3440        for i in 0..p {
3441            for r in 0..m {
3442                v_b[(i, r)] = ((i * 3 + r * 5) % 7) as f64 * 0.1 - 0.3;
3443            }
3444        }
3445        let y: Array1<f64> = (0..n).map(|i| ((i % 5) as f64) * 0.2).collect();
3446        let lambdas = array![2.0, 1.0, 0.5];
3447
3448        let serial = super::block_sampled_marginal_correction(&MatvecBlock {
3449            lambdas: lambdas.clone(),
3450            x: x.clone(),
3451            v_b: v_b.clone(),
3452            y: y.clone(),
3453            batched: false,
3454        })
3455        .expect("serial");
3456        let batched = super::block_sampled_marginal_correction(&MatvecBlock {
3457            lambdas,
3458            x,
3459            v_b,
3460            y,
3461            batched: true,
3462        })
3463        .expect("batched");
3464
3465        assert_eq!(serial.n_draws, batched.n_draws);
3466        assert!(
3467            (serial.value - batched.value).abs() <= 1e-10 * (1.0 + serial.value.abs()),
3468            "value serial {} vs batched {}",
3469            serial.value,
3470            batched.value
3471        );
3472        for k in 0..serial.rho_gradient.len() {
3473            assert!(
3474                (serial.rho_gradient[k] - batched.rho_gradient[k]).abs()
3475                    <= 1e-10 * (1.0 + serial.rho_gradient[k].abs()),
3476                "rho_gradient[{k}] serial {} vs batched {}",
3477                serial.rho_gradient[k],
3478                batched.rho_gradient[k]
3479            );
3480        }
3481        let ms = serial.moments.expect("serial moments");
3482        let mb = batched.moments.expect("batched moments");
3483        for (a, b) in ms.e_t.iter().zip(mb.e_t.iter()) {
3484            assert!((a - b).abs() <= 1e-10 * (1.0 + a.abs()), "e_t {a} vs {b}");
3485        }
3486        for (a, b) in ms.e_neg_score.iter().zip(mb.e_neg_score.iter()) {
3487            assert!(
3488                (a - b).abs() <= 1e-10 * (1.0 + a.abs()),
3489                "e_neg_score {a} vs {b}"
3490            );
3491        }
3492        for (a, b) in ms.e_t_neg_score.iter().zip(mb.e_t_neg_score.iter()) {
3493            assert!(
3494                (a - b).abs() <= 1e-10 * (1.0 + a.abs()),
3495                "e_t_neg_score {a} vs {b}"
3496            );
3497        }
3498    }
3499
3500    #[test]
3501    fn logit_pg_rao_blackwell_returns_finite_terms() {
3502        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
3503        let y = array![1.0, 0.0, 1.0, 0.0];
3504        let w = array![1.0, 1.0, 1.0, 1.0];
3505        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
3506        let mode = array![0.0, 0.0];
3507        let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
3508        let cfg = NutsConfig {
3509            n_samples: 30,
3510            nwarmup: 30,
3511            n_chains: 2,
3512            target_accept: 0.8,
3513            seed: 789,
3514        };
3515
3516        let rb = super::estimate_logit_pg_rao_blackwell_terms(
3517            x.view(),
3518            y.view(),
3519            w.view(),
3520            penalty.view(),
3521            mode.view(),
3522            &roots,
3523            &cfg,
3524        )
3525        .expect("rao-blackwell PG should run");
3526
3527        assert_eq!(rb.len(), 1);
3528        assert!(rb[0].is_finite());
3529        assert!(rb[0] >= 0.0);
3530    }
3531
3532    #[test]
3533    fn logit_pg_rao_blackwell_rejects_non_bernoulli_response() {
3534        let x = array![[1.0], [1.0]];
3535        let y = array![0.25, 1.0];
3536        let w = array![1.0, 1.0];
3537        let penalty = array![[0.1]];
3538        let mode = array![0.0];
3539        let roots = vec![array![[0.1_f64.sqrt()]]];
3540        let cfg = NutsConfig {
3541            n_samples: 1,
3542            nwarmup: 1,
3543            n_chains: 1,
3544            target_accept: 0.8,
3545            seed: 654,
3546        };
3547
3548        let result = super::estimate_logit_pg_rao_blackwell_terms(
3549            x.view(),
3550            y.view(),
3551            w.view(),
3552            penalty.view(),
3553            mode.view(),
3554            &roots,
3555            &cfg,
3556        );
3557
3558        let err = result
3559            .err()
3560            .expect("PG Rao-Blackwell should reject proportion rows");
3561        assert!(
3562            err.contains("response must be exactly 0 or 1"),
3563            "unexpected error: {err}"
3564        );
3565    }
3566
3567    #[test]
3568    fn logit_pg_rao_blackwell_matches_beta_quadratic_moment_sanity() {
3569        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
3570        let y = array![1.0, 0.0, 1.0, 0.0];
3571        let w = array![1.0, 1.0, 1.0, 1.0];
3572        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
3573        let mode = array![0.0, 0.0];
3574        let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
3575        let cfg = NutsConfig {
3576            n_samples: 120,
3577            nwarmup: 80,
3578            n_chains: 2,
3579            target_accept: 0.8,
3580            seed: 901,
3581        };
3582
3583        let gibbs = run_logit_polya_gamma_gibbs(
3584            x.view(),
3585            y.view(),
3586            w.view(),
3587            penalty.view(),
3588            mode.view(),
3589            &cfg,
3590        )
3591        .expect("pg gibbs should run");
3592        let mc_quad = gibbs
3593            .samples
3594            .rows()
3595            .into_iter()
3596            .map(|beta| {
3597                let sb = penalty.dot(&beta.to_owned());
3598                beta.dot(&sb)
3599            })
3600            .sum::<f64>()
3601            / (gibbs.samples.nrows() as f64);
3602
3603        let rb = super::estimate_logit_pg_rao_blackwell_terms(
3604            x.view(),
3605            y.view(),
3606            w.view(),
3607            penalty.view(),
3608            mode.view(),
3609            &roots,
3610            &cfg,
3611        )
3612        .expect("rao-blackwell PG should run");
3613
3614        let diff = (rb[0] - mc_quad).abs();
3615        assert!(
3616            diff < 0.35,
3617            "Rao-Blackwell vs beta-moment mismatch too large: rb={}, mc={}, diff={}",
3618            rb[0],
3619            mc_quad,
3620            diff
3621        );
3622    }
3623
3624    #[test]
3625    fn survival_hmc_structural_monotonic_returns_finitevalues() {
3626        let age_entry = array![1.0];
3627        let age_exit = array![2.0];
3628        let event_target = array![1u8];
3629        let event_competing = array![0u8];
3630        let sampleweight = array![1.0];
3631        let x_entry = array![[1.0, 0.2]];
3632        let x_exit = array![[1.0, 0.6]];
3633        let x_derivative = array![[0.0, 1.0]];
3634        let penalties = PenaltyBlocks::new(Vec::new());
3635        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3636        let mode = array![0.0, 0.0];
3637        let hessian = Array2::<f64>::eye(2);
3638
3639        let posterior = super::survival_hmc::SurvivalPosterior::new(
3640            age_entry.view(),
3641            age_exit.view(),
3642            event_target.view(),
3643            event_competing.view(),
3644            sampleweight.view(),
3645            x_entry.view(),
3646            x_exit.view(),
3647            x_derivative.view(),
3648            None,
3649            None,
3650            None,
3651            penalties,
3652            monotonicity,
3653            SurvivalSpec::Net,
3654            true,
3655            2,
3656            mode.view(),
3657            hessian.view(),
3658        )
3659        .expect("construct survival posterior");
3660
3661        let position = array![0.0, 0.0];
3662        let mut grad = Array1::<f64>::zeros(2);
3663        let logp = HamiltonianTarget::logp_and_grad(&posterior, &position, &mut grad);
3664        assert!(logp.is_finite());
3665        assert!(grad.iter().all(|v| v.is_finite()));
3666    }
3667
    /// Builds two posteriors on identical single-row survival data: a linear
    /// fallback and a structural monotone one. Both are evaluated at the same
    /// position `z`. The test asserts that:
    ///   - the first gradient coordinates differ by more than 1e-6;
    ///   - both first coordinates are finite.
    ///
    /// NOTE(review): the barrier tests in this module reject a derivative
    /// offset of 2.0 under `tolerance: 3.0`. Here `x_derivative` picks out
    /// z[0] = ln 2. If that reaches d_eta/dt unchanged (identity Hessian,
    /// zero mode), the linear fallback may also be behind the barrier and
    /// return a zeroed gradient. In that case this comparison only checks
    /// that the structural gradient is non-zero — confirm.
    #[test]
    fn survival_hmc_structural_monotonic_differs_from_linear_geometry() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[0.2, 0.1]];
        let x_exit = array![[0.6, 0.3]];
        let x_derivative = array![[1.0, 0.0]];
        // Passed by value to both constructors below.
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        // Shared evaluation point; only the first coordinate is non-zero.
        let z = array![std::f64::consts::LN_2, 0.0];

        // Linear fallback posterior. The two arguments after
        // `SurvivalSpec::Net` are `false, 0` here versus `true, 2` in the
        // structural posterior below — presumably the structural-monotone
        // toggle and its associated order; confirm against
        // `SurvivalPosterior::new`.
        let posterior_linear = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct linear posterior");
        let mut grad_linear = Array1::<f64>::zeros(2);
        // Only the gradient is compared; the returned log-density is ignored.
        HamiltonianTarget::logp_and_grad(&posterior_linear, &z, &mut grad_linear);

        // Structural posterior: identical inputs except the `true, 2` pair.
        let posterior_struct = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            true,
            2,
            mode.view(),
            hessian.view(),
        )
        .expect("construct structural posterior");
        let mut grad_struct = Array1::<f64>::zeros(2);
        HamiltonianTarget::logp_and_grad(&posterior_struct, &z, &mut grad_struct);

        assert!(
            (grad_struct[0] - grad_linear[0]).abs() > 1e-6,
            "expected structural and linear fallback gradients to differ"
        );
        assert!(grad_struct[0].is_finite());
        assert!(grad_linear[0].is_finite());
    }
3738
    /// Barrier test for the linear-fallback survival posterior.
    ///
    /// The derivative design is all zeros, so the derivative offset alone
    /// drives d_eta/dt. Offsets of 0.0 and 2.0 must both be rejected:
    ///   - the log-density is non-finite;
    ///   - the gradient is left all-zero.
    ///
    /// Presumably the guard is tied to `tolerance: 3.0`. The companion test
    /// accepts an offset of 3.1.
    #[test]
    fn survival_hmc_fallback_barrier_rejects_offsets_below_monotonicity_threshold() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[1.0, 0.0]];
        let x_exit = array![[1.0, 0.0]];
        // Zero derivative design so derivative_offset_exit drives d_eta/dt.
        let x_derivative = array![[0.0, 0.0]];
        // Cloned for the first posterior, then moved into the second.
        let penalties = PenaltyBlocks::new(Vec::new());
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        let z = array![0.0, 0.0];

        // "no_offset" passes an explicit zero offset (`Some([0.0])`), not
        // `None`.
        let posterior_no_offset = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![0.0].view()),
            penalties.clone(),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior without derivative offset");
        let mut grad_no_offset = Array1::<f64>::zeros(2);
        let logp_no_offset =
            HamiltonianTarget::logp_and_grad(&posterior_no_offset, &z, &mut grad_no_offset);

        // Same posterior with a positive offset of 2.0 that is still
        // expected to sit below the guard.
        let posteriorwith_offset = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![2.0].view()),
            penalties,
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior with derivative offset");
        let mut gradwith_offset = Array1::<f64>::zeros(2);
        let logpwith_offset =
            HamiltonianTarget::logp_and_grad(&posteriorwith_offset, &z, &mut gradwith_offset);

        // Rejection contract: non-finite log-density and the gradient left
        // at exactly zero.
        assert!(!logp_no_offset.is_finite());
        assert!(!logpwith_offset.is_finite());
        assert!(grad_no_offset.iter().all(|v| *v == 0.0));
        assert!(gradwith_offset.iter().all(|v| *v == 0.0));
    }
3811
    /// Companion barrier test for the linear-fallback survival posterior,
    /// using the same zero-derivative setup.
    ///
    /// The derivative offset is moved across the guard:
    ///   - at 2.0 the posterior is rejected: non-finite log-density and an
    ///     all-zero gradient;
    ///   - at 3.1 (just above `tolerance: 3.0`) both the log-density and the
    ///     gradient are finite.
    #[test]
    fn survival_hmc_fallback_barrier_becomes_finite_once_offset_clears_guard() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[1.0, 0.0]];
        let x_exit = array![[1.0, 0.0]];
        // All-zero derivative design: the derivative offset alone sets d_eta/dt.
        let x_derivative = array![[0.0, 0.0]];
        // Cloned for the first posterior, then moved into the second.
        let penalties = PenaltyBlocks::new(Vec::new());
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        let z = array![0.0, 0.0];

        // Offset 2.0: expected to be below the guard.
        let posterior_below_guard = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![2.0].view()),
            penalties.clone(),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior below derivative guard");
        let mut grad_below_guard = Array1::<f64>::zeros(2);
        let logp_below_guard =
            HamiltonianTarget::logp_and_grad(&posterior_below_guard, &z, &mut grad_below_guard);

        // Offset 3.1: just above the tolerance, expected to clear the guard.
        let posterior_above_guard = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            Some(array![3.1].view()),
            penalties,
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct posterior above derivative guard");
        let mut grad_above_guard = Array1::<f64>::zeros(2);
        let logp_above_guard =
            HamiltonianTarget::logp_and_grad(&posterior_above_guard, &z, &mut grad_above_guard);

        assert!(!logp_below_guard.is_finite());
        assert!(logp_above_guard.is_finite());
        assert!(grad_below_guard.iter().all(|v| *v == 0.0));
        assert!(grad_above_guard.iter().all(|v| v.is_finite()));
    }
3883
3884    #[test]
3885    fn survival_hmc_structural_monotonic_handles_sparse_multirow_geometry() {
3886        let age_entry = array![1.0, 1.2];
3887        let age_exit = array![2.0, 2.4];
3888        let event_target = array![1u8, 1u8];
3889        let event_competing = array![0u8, 0u8];
3890        let sampleweight = array![1.0, 1.0];
3891        let x_entry = array![[0.1, 0.0, 0.2], [0.2, 0.1, 0.2]];
3892        let x_exit = array![[0.4, 0.2, 0.3], [0.6, 0.1, 0.3]];
3893        // First row constrains only column 0, second row constrains columns 0 and 1.
3894        let x_derivative = array![[1.0, 0.0, 0.0], [0.5, 1.0, 0.0]];
3895        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3896        let mode = array![4.0, 2.0, 0.0];
3897        let hessian = Array2::<f64>::eye(3);
3898        let z = array![0.05, -0.1, 0.15];
3899
3900        let posterior = super::survival_hmc::SurvivalPosterior::new(
3901            age_entry.view(),
3902            age_exit.view(),
3903            event_target.view(),
3904            event_competing.view(),
3905            sampleweight.view(),
3906            x_entry.view(),
3907            x_exit.view(),
3908            x_derivative.view(),
3909            None,
3910            None,
3911            None,
3912            PenaltyBlocks::new(Vec::new()),
3913            monotonicity,
3914            SurvivalSpec::Net,
3915            true,
3916            2,
3917            mode.view(),
3918            hessian.view(),
3919        )
3920        .expect("construct structural posterior");
3921
3922        let mut grad = Array1::<f64>::zeros(3);
3923        let logp = HamiltonianTarget::logp_and_grad(&posterior, &z, &mut grad);
3924        assert!(logp.is_finite());
3925        assert!(grad.iter().all(|v| v.is_finite()));
3926    }
3927}
3928
3929/// Implement HamiltonianTarget for NUTS with analytical gradients.
3930impl HamiltonianTarget<Array1<f64>> for NutsPosterior {
3931    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
3932        NUTS_RESIDUAL_SCRATCH.with(|scratch| {
3933            let mut residual = scratch.borrow_mut();
3934            if residual.len() != self.data.n_samples {
3935                *residual = Array1::<f64>::zeros(self.data.n_samples);
3936            }
3937            self.compute_logp_and_grad_nd_into(position, &mut residual, grad)
3938        })
3939    }
3940}
3941
3942/// Configuration for NUTS sampling.
3943#[derive(Clone, Debug, Serialize, Deserialize)]
3944pub struct NutsConfig {
3945    /// Number of samples to collect (after warmup)
3946    pub n_samples: usize,
3947    /// Number of warmup samples to discard
3948    pub nwarmup: usize,
3949    /// Number of parallel chains
3950    pub n_chains: usize,
3951    /// Target acceptance probability (0.6-0.9 recommended)
3952    pub target_accept: f64,
3953    /// Seed for deterministic chain initialization
3954    #[serde(default = "default_nuts_seed")]
3955    pub seed: u64,
3956}
3957
/// Default sampler seed (42); serde also uses it when a serialized
/// `NutsConfig` omits `seed`.
fn default_nuts_seed() -> u64 {
    0x2A
}
3961
3962fn validate_nuts_target_accept(target_accept: f64) -> Result<(), HmcError> {
3963    if target_accept.is_finite() && target_accept > 0.0 && target_accept < 1.0 {
3964        Ok(())
3965    } else {
3966        Err(HmcError::InvalidConfig {
3967            reason: format!(
3968                "NUTS target_accept must be finite and lie in (0, 1), got {target_accept}"
3969            ),
3970        })
3971    }
3972}
3973
/// Smallest accepted `n_samples` (post-warmup draws per chain).
///
/// Split-R-hat / ESS cut every chain into two halves. Both
/// `compute_split_rhat_and_ess` and the engine's own run-stats path need at
/// least two draws per half, hence four draws. Fewer draws would trip the
/// engine's `.expect(...)` calls (empty-stack / "split R-hat and ESS require
/// at least 2 split chains and 2 draws per split chain"). Those calls panic
/// across the FFI boundary instead of surfacing a typed error.
const MIN_NUTS_SAMPLES: usize = 2 * 2;

/// Smallest accepted `n_chains`.
///
/// Zero chains hand the engine an empty initial-position vector, which panics
/// in `ndarray::stack`; they would also give the Laplace fallback an empty
/// `(0, p)` posterior. One chain is a supported, tested configuration. The
/// engine halves each chain for its diagnostic, so a single chain still
/// supplies the two split chains R-hat needs, and `compute_split_rhat_and_ess`
/// early-returns gracefully for `n_chains < 2`. Only the degenerate
/// `n_chains == 0` is rejected.
const MIN_NUTS_CHAINS: usize = 1;
3992
3993/// Validate the draw / chain counts of a NUTS configuration up front, mirroring
3994/// `validate_nuts_target_accept`, so that out-of-range values surface as a typed
3995/// `HmcError::InvalidConfig` *before* the sampling engine is constructed rather
3996/// than as a panic caught at the FFI boundary.
3997fn validate_nuts_draws(config: &NutsConfig) -> Result<(), HmcError> {
3998    if config.n_chains < MIN_NUTS_CHAINS {
3999        return Err(HmcError::InvalidConfig {
4000            reason: format!(
4001                "NUTS n_chains must be >= {MIN_NUTS_CHAINS}; with zero chains the \
4002                 sampler has no initial positions to run, got {}",
4003                config.n_chains
4004            ),
4005        });
4006    }
4007    if config.n_samples < MIN_NUTS_SAMPLES {
4008        return Err(HmcError::InvalidConfig {
4009            reason: format!(
4010                "NUTS n_samples must be >= {MIN_NUTS_SAMPLES} so split-R-hat / ESS \
4011                 diagnostics are defined, got {}",
4012                config.n_samples
4013            ),
4014        });
4015    }
4016    Ok(())
4017}
4018
4019/// Full up-front validation of a NUTS configuration shared by every sampling
4020/// entry point (dense NUTS, link-wiggle, joint (β, ρ), survival, the
4021/// auto-selected Pólya-Gamma Gibbs path, and the Laplace-Gaussian fallback).
4022pub(crate) fn validate_nuts_config(config: &NutsConfig) -> Result<(), HmcError> {
4023    validate_nuts_target_accept(config.target_accept)?;
4024    validate_nuts_draws(config)?;
4025    Ok(())
4026}
4027
4028#[inline]
4029fn splitmix64(x: u64) -> u64 {
4030    gam_linalg::utils::splitmix64_hash(x)
4031}
4032
4033#[inline]
4034fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
4035    splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
4036}
4037
4038#[inline]
4039fn nuts_transition_seed(seed: u64, stream: u64) -> u64 {
4040    splitmix64(seed ^ stream ^ 0xA24B_AED4_963E_E407)
4041}
4042
4043#[inline]
4044fn gibbs_pg_seed(seed: u64, chain: usize, stream: u64, iter: usize) -> u64 {
4045    chain_stream_seed(
4046        seed,
4047        chain,
4048        stream ^ ((iter as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
4049    )
4050}
4051
4052fn draw_logit_pg1_omega(
4053    shapes: ArrayView1<'_, u32>,
4054    tilts: ArrayView1<'_, f64>,
4055    seed: u64,
4056    out: &mut Array1<f64>,
4057) -> Result<(), String> {
4058    if out.len() != tilts.len() {
4059        return Err(HmcError::DimensionMismatch {
4060            reason: "draw_logit_pg1_omega: output length mismatch".to_string(),
4061        }
4062        .into());
4063    }
4064    let draws = crate::gpu_polya_gamma::draw_batch(PolyaGammaBatchInput {
4065        shapes,
4066        tilts,
4067        seed: PgSeed(seed),
4068    })?;
4069    out.assign(&draws);
4070    out.mapv_inplace(|v| v.max(1.0e-12));
4071    Ok(())
4072}
4073
/// Parameter dimension above which a posterior is treated as
/// "high-dimensional" by the sampler heuristics below: a higher
/// target-acceptance floor (smaller leapfrog steps) and a heavier mass-matrix
/// ridge. It is independent of `DENSE_MASS_MATRIX_MAX_DIM`, the larger cap on
/// dense mass-matrix adaptation; the two values are not tied together.
const HIGH_DIM_THRESHOLD: usize = 50;

/// Target-acceptance floor enforced for high-dimensional posteriors
/// (`dim > HIGH_DIM_THRESHOLD`). NUTS efficiency degrades faster with too-large
/// steps in high dimensions, so a requested target below this is raised to it.
const HIGH_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.92;
/// Target-acceptance floor for posteriors with `dim <= HIGH_DIM_THRESHOLD`.
const LOW_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.90;
/// Upper bound on the effective target acceptance. Pushing the target toward
/// 1 collapses the step size and stalls mixing, so larger requests are capped.
const MAX_TARGET_ACCEPT: f64 = 0.95;

/// Minimum warmup length for mass-matrix adaptation. Below it the mass-matrix
/// config builders return `NUTSMassMatrixConfig::disabled()`. The windowed
/// (Stan-style) schedule needs enough warmup to populate its initial and
/// terminal buffers; otherwise the estimated metric is noise.
///
/// NOTE(review): at exactly this warmup length, the buffers requested after
/// clamping exceed the warmup budget. `robust_mass_matrix_config` asks for
/// 35 + 50 + 10 = 95 iterations and `robust_survival_mass_matrix_config` for
/// 40 + 60 + 10 = 110. Confirm the engine rescales them, or raise this
/// threshold.
const MIN_WARMUP_FOR_MASS_ADAPT: usize = 80;

/// Largest parameter dimension for which the engine attempts *dense*
/// mass-matrix adaptation; above it a diagonal metric is used. An `O(p²)`
/// dense metric is neither affordable nor reliably estimable from limited
/// warmup. This only matters when dense adaptation is requested;
/// `robust_mass_matrix_config` and `robust_survival_mass_matrix_config` both
/// request `MassMatrixAdaptation::Diagonal`.
const DENSE_MASS_MATRIX_MAX_DIM: usize = 75;

/// Mass-matrix ridge (added to the diagonal of the estimated metric) for the
/// general (mean-family) sampler. The high-dimensional value is larger because
/// the warmup metric estimate is noisier relative to its scale as `p` grows.
const MASS_REGULARIZE_HIGH_DIM: f64 = 0.14;
const MASS_REGULARIZE_LOW_DIM: f64 = 0.10;
/// Mass-matrix ridge for survival posteriors. These are frequently skewed by
/// censoring / rare events, so they warrant a heavier ridge than the mean
/// family.
const SURVIVAL_MASS_REGULARIZE_HIGH_DIM: f64 = 0.18;
const SURVIVAL_MASS_REGULARIZE_LOW_DIM: f64 = 0.12;

/// Jitter added during mass-matrix inversion to keep the metric strictly
/// positive-definite against round-off in the warmup covariance estimate.
const MASS_MATRIX_JITTER: f64 = 1e-5;

/// Effective NUTS target acceptance for a `dim`-dimensional posterior.
///
/// The request is raised to a dimension-dependent floor and capped at
/// `MAX_TARGET_ACCEPT`, so the result always lies in `[0.90, 0.95]`:
/// - the floor is `LOW_DIM_TARGET_ACCEPT_FLOOR`, or
///   `HIGH_DIM_TARGET_ACCEPT_FLOOR` when `dim > HIGH_DIM_THRESHOLD`;
/// - requests below the floor therefore have no effect;
/// - a NaN request yields the floor, because `f64::max` ignores NaN.
#[inline]
fn robust_target_accept(requested: f64, dim: usize) -> f64 {
    let floor = if dim > HIGH_DIM_THRESHOLD {
        HIGH_DIM_TARGET_ACCEPT_FLOOR
    } else {
        LOW_DIM_TARGET_ACCEPT_FLOOR
    };
    // Deliberately `max`/`min` rather than `clamp`: `clamp` would propagate NaN.
    requested.max(floor).min(MAX_TARGET_ACCEPT)
}
4125
4126fn jittered_initial_positions(
4127    config: &NutsConfig,
4128    dim: usize,
4129    scale: f64,
4130    stream: u64,
4131) -> Vec<Array1<f64>> {
4132    (0..config.n_chains)
4133        .map(|chain| {
4134            let mut rng = StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, stream));
4135            Array1::from_shape_fn(dim, |_| sample_standard_normal(&mut rng) * scale)
4136        })
4137        .collect()
4138}
4139
4140fn robust_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4141    if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4142        return NUTSMassMatrixConfig::disabled();
4143    }
4144    let start_buffer = (nwarmup / 8).clamp(35, 180);
4145    let end_buffer = (nwarmup / 5).clamp(50, 250);
4146    let initial_window = (nwarmup / 20).clamp(10, 60);
4147    NUTSMassMatrixConfig {
4148        adaptation: MassMatrixAdaptation::Diagonal,
4149        start_buffer,
4150        end_buffer,
4151        initial_window,
4152        regularize: if dim > HIGH_DIM_THRESHOLD {
4153            MASS_REGULARIZE_HIGH_DIM
4154        } else {
4155            MASS_REGULARIZE_LOW_DIM
4156        },
4157        jitter: MASS_MATRIX_JITTER,
4158        dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4159    }
4160}
4161
4162fn robust_survival_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4163    if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4164        return NUTSMassMatrixConfig::disabled();
4165    }
4166    // Survival posteriors with censoring/rare events are often skewed; this
4167    // configuration uses diagonal adaptation.
4168    let start_buffer = (nwarmup / 7).clamp(40, 200);
4169    let end_buffer = (nwarmup / 4).clamp(60, 280);
4170    let initial_window = (nwarmup / 20).clamp(10, 60);
4171    NUTSMassMatrixConfig {
4172        adaptation: MassMatrixAdaptation::Diagonal,
4173        start_buffer,
4174        end_buffer,
4175        initial_window,
4176        regularize: if dim > HIGH_DIM_THRESHOLD {
4177            SURVIVAL_MASS_REGULARIZE_HIGH_DIM
4178        } else {
4179            SURVIVAL_MASS_REGULARIZE_LOW_DIM
4180        },
4181        jitter: MASS_MATRIX_JITTER,
4182        dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4183    }
4184}
4185
4186impl Default for NutsConfig {
4187    fn default() -> Self {
4188        Self {
4189            n_samples: 1000,
4190            nwarmup: 500,
4191            n_chains: 4,
4192            target_accept: 0.9,
4193            seed: 42,
4194        }
4195    }
4196}
4197
4198impl NutsConfig {
4199    /// Create a config with sample counts tuned for the model dimension.
4200    ///
4201    /// Higher dimensions need more samples because:
4202    /// - ESS decreases with dimension (autocorrelation grows)
4203    /// - Split R-hat needs enough samples per chain to be meaningful
4204    ///
4205    /// Rule of thumb: target 100 effective samples per parameter.
4206    pub fn for_dimension(n_params: usize) -> Self {
4207        // ESS ≈ n_samples / (1 + 2τ) where τ ≈ sqrt(dim) for well-tuned NUTS
4208        let effective_autocorr = (n_params as f64).sqrt().max(1.0);
4209
4210        // Target: at least 100 effective samples per parameter
4211        let target_ess = 100 * n_params;
4212
4213        // Samples needed = ESS * (1 + 2τ), with 1.5x safety factor
4214        let raw_samples = (target_ess as f64 * (1.0 + 2.0 * effective_autocorr) * 1.5) as usize;
4215
4216        // Clamp to reasonable range [500, 10000]
4217        let n_samples = raw_samples.clamp(500, 10_000);
4218
4219        // Warmup ≈ samples (standard practice for adaptation)
4220        let nwarmup = n_samples;
4221
4222        // More chains for higher dims (better R-hat estimation)
4223        let n_chains = if n_params > 50 { 4 } else { 2 };
4224
4225        Self {
4226            n_samples,
4227            nwarmup,
4228            n_chains,
4229            target_accept: 0.9,
4230            seed: 42,
4231        }
4232    }
4233}
4234
4235/// Result of NUTS sampling.
4236#[derive(Clone, Debug)]
4237pub struct NutsResult {
4238    /// Coefficient samples in ORIGINAL space: shape (n_total_samples, n_coeffs)
4239    pub samples: Array2<f64>,
4240    /// Posterior mean
4241    pub posterior_mean: Array1<f64>,
4242    /// Posterior standard deviation
4243    pub posterior_std: Array1<f64>,
4244    /// R-hat convergence diagnostic
4245    pub rhat: f64,
4246    /// Effective sample size
4247    pub ess: f64,
4248    /// Whether sampling converged (R-hat < 1.1)
4249    pub converged: bool,
4250}
4251
/// Pass/fail thresholds applied to a run's scalar R-hat and ESS.
#[derive(Clone, Copy)]
struct NutsConvergenceThresholds {
    /// R-hat must be strictly below this value.
    max_rhat: f64,
    /// When `Some`, ESS must additionally be strictly above this floor.
    min_ess: Option<f64>,
}

impl NutsConvergenceThresholds {
    /// `true` when `rhat < max_rhat` and, if an ESS floor is set,
    /// `ess > min_ess`. NaN inputs fail the comparisons and therefore count
    /// as not converged.
    #[inline]
    fn converged(self, rhat: f64, ess: f64) -> bool {
        let ess_ok = match self.min_ess {
            None => true,
            Some(floor) => ess > floor,
        };
        ess_ok && rhat < self.max_rhat
    }
}
4268
4269fn run_whitened_nuts_samples<Target>(
4270    target: Target,
4271    initial_positions: Vec<Array1<f64>>,
4272    config: &NutsConfig,
4273    dim: usize,
4274    mass_cfg: NUTSMassMatrixConfig,
4275    transition_seed_stream: u64,
4276    sampling_error_label: &str,
4277) -> Result<(Array3<f64>, String), String>
4278where
4279    Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4280{
4281    let mut sampler = GenericNUTS::new_with_mass_matrix(
4282        target,
4283        initial_positions,
4284        robust_target_accept(config.target_accept, dim),
4285        mass_cfg,
4286    )
4287    .set_seed(nuts_transition_seed(config.seed, transition_seed_stream));
4288
4289    let (samples_array, run_stats) = sampler
4290        .run_progress(config.n_samples, config.nwarmup)
4291        .map_err(|e| format!("{sampling_error_label}: {e}"))?;
4292    Ok((samples_array, run_stats.to_string()))
4293}
4294
4295fn unwhiten_samples(
4296    samples_array: &Array3<f64>,
4297    mode: &Array1<f64>,
4298    chol: &Array2<f64>,
4299    dim: usize,
4300    z_start: usize,
4301) -> Array2<f64> {
4302    let shape = samples_array.shape();
4303    let n_chains = shape[0];
4304    let n_samples_out = shape[1];
4305    let total_samples = n_chains * n_samples_out;
4306
4307    let mut samples = Array2::<f64>::zeros((total_samples, dim));
4308    let mut z_buffer = Array1::<f64>::zeros(dim);
4309    for chain in 0..n_chains {
4310        for sample_i in 0..n_samples_out {
4311            let zview = samples_array.slice(ndarray::s![chain, sample_i, z_start..z_start + dim]);
4312            z_buffer.assign(&zview);
4313            let beta = mode + &chol.dot(&z_buffer);
4314            let sample_idx = chain * n_samples_out + sample_i;
4315            samples.row_mut(sample_idx).assign(&beta);
4316        }
4317    }
4318
4319    samples
4320}
4321
4322fn summarize_unwhitened_nuts_samples(
4323    samples: Array2<f64>,
4324    samples_array: &Array3<f64>,
4325    empty_mean: Array1<f64>,
4326    convergence: NutsConvergenceThresholds,
4327) -> NutsResult {
4328    let posterior_mean = samples.mean_axis(Axis(0)).unwrap_or(empty_mean);
4329    let posterior_std = samples.std_axis(Axis(0), 0.0);
4330    let (rhat, ess) = compute_split_rhat_and_ess(samples_array);
4331    let converged = convergence.converged(rhat, ess);
4332
4333    NutsResult {
4334        samples,
4335        posterior_mean,
4336        posterior_std,
4337        rhat,
4338        ess,
4339        converged,
4340    }
4341}
4342
4343fn run_whitened_nuts_result<Target>(
4344    target: Target,
4345    mode: &Array1<f64>,
4346    chol: &Array2<f64>,
4347    initial_positions: Vec<Array1<f64>>,
4348    config: &NutsConfig,
4349    dim: usize,
4350    mass_cfg: NUTSMassMatrixConfig,
4351    transition_seed_stream: u64,
4352    sampling_error_label: &str,
4353    empty_mean: Array1<f64>,
4354    convergence: NutsConvergenceThresholds,
4355) -> Result<(NutsResult, String), String>
4356where
4357    Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4358{
4359    let (samples_array, run_stats) = run_whitened_nuts_samples(
4360        target,
4361        initial_positions,
4362        config,
4363        dim,
4364        mass_cfg,
4365        transition_seed_stream,
4366        sampling_error_label,
4367    )?;
4368    let samples = unwhiten_samples(&samples_array, mode, chol, dim, 0);
4369    let result =
4370        summarize_unwhitened_nuts_samples(samples, &samples_array, empty_mean, convergence);
4371    Ok((result, run_stats))
4372}
4373
4374impl NutsResult {
4375    /// Computes the posterior mean of a function applied to coefficients.
4376    /// Returns 0.0 if samples is empty to avoid divide-by-zero.
4377    pub fn posterior_mean_of<F>(&self, f: F) -> f64
4378    where
4379        F: Fn(ArrayView1<f64>) -> f64 + Sync,
4380    {
4381        let n = self.samples.nrows();
4382        if n == 0 {
4383            return 0.0;
4384        }
4385        // Posterior mean of a sample-function: parallel reduction over rows.
4386        // `f: Fn(ArrayView1) -> f64` is shared-access so safe across threads.
4387        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4388        let sum: f64 = (0..n).into_par_iter().map(|i| f(self.samples.row(i))).sum();
4389        sum / n as f64
4390    }
4391
4392    /// Computes percentiles of a function applied to coefficients.
4393    pub fn posterior_interval_of<F>(&self, f: F, lower_pct: f64, upper_pct: f64) -> (f64, f64)
4394    where
4395        F: Fn(ArrayView1<f64>) -> f64,
4396    {
4397        let n = self.samples.nrows();
4398        if n == 0 {
4399            return (0.0, 0.0);
4400        }
4401        let mut values: Vec<f64> = (0..n).map(|i| f(self.samples.row(i))).collect();
4402        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4403
4404        let lower_idx = ((lower_pct / 100.0) * n as f64).floor() as usize;
4405        let upper_idx = ((upper_pct / 100.0) * n as f64).ceil() as usize;
4406
4407        (
4408            values[lower_idx.min(n.saturating_sub(1))],
4409            values[upper_idx.min(n.saturating_sub(1))],
4410        )
4411    }
4412}
4413
4414#[inline]
4415fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
4416    let u1 = rng.random::<f64>().max(1e-16);
4417    let u2 = rng.random::<f64>();
4418    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
4419}
4420
/// Runs a Pólya-Gamma Gibbs sampler for Bernoulli-logit models.
///
/// This sampler is gradient-free: each iteration alternates
/// 1) ω_i | β, y ~ PG(1, x_i^T β), and
/// 2) β | ω, y ~ N(Q^{-1} b, Q^{-1}), with Q = S + X^T diag(ω) X, b = X^T(y - 1/2).
///
/// Only unit case weights are supported, because the latent variables are
/// drawn from PG(1,·). Any weight with `|w - 1| > 1e-10` is rejected; use
/// NUTS for non-unit weights.
///
/// Each chain starts from `mode` plus independent `N(0, 0.05²)` jitter and
/// runs `config.nwarmup + config.n_samples` iterations, keeping the last
/// `config.n_samples`. The returned `samples` are chain-major: draw `s` of
/// chain `c` is row `c * config.n_samples + s`. `config.target_accept` is
/// validated but otherwise unused here.
///
/// # Errors
/// Returns `Err` on:
/// - input length or coefficient/penalty dimension mismatch;
/// - non-unit weights;
/// - responses rejected by `validate_binary_responses`;
/// - an invalid `config` (see `validate_nuts_config`);
/// - a failed Pólya-Gamma batch draw;
/// - a Cholesky failure of `Q`.
pub fn run_logit_polya_gamma_gibbs(
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    weights: ArrayView1<f64>,
    penalty_matrix: ArrayView2<f64>,
    mode: ArrayView1<f64>,
    config: &NutsConfig,
) -> Result<NutsResult, String> {
    let n = x.nrows();
    let p = x.ncols();
    if y.len() != n || weights.len() != n {
        return Err(HmcError::DimensionMismatch {
            reason: "run_logit_polya_gamma_gibbs: input length mismatch".to_string(),
        }
        .into());
    }
    if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
        return Err(HmcError::DimensionMismatch {
            reason: "run_logit_polya_gamma_gibbs: coefficient/penalty dimension mismatch"
                .to_string(),
        }
        .into());
    }
    if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
        return Err(HmcError::InvalidConfig {
            reason: "run_logit_polya_gamma_gibbs requires unit weights (PG(1,·)); use NUTS for non-unit weights".to_string(),
        }
        .into());
    }
    validate_binary_responses("run_logit_polya_gamma_gibbs", &y, &weights).map_err(String::from)?;
    // Issue #399: the auto-selected PG-Gibbs path is reached for the canonical
    // unit-weight Bernoulli-logit GAM. Without this guard, `n_chains == 0` /
    // `n_samples == 0` would not panic but silently return a degenerate empty
    // `(0, p)` posterior, diverging from the typed error the NUTS path raises
    // for the same inputs. Route it through the shared validator so every
    // `Model.sample` surface rejects degenerate draw/chain counts identically.
    validate_nuts_config(config).map_err(String::from)?;

    // Warmup iterations are run but not stored.
    let n_iter = config.nwarmup + config.n_samples;

    // b = X^T (y - 1/2), constant across iterations.
    let kappa = y.mapv(|v| v - 0.5);
    let rhs_b = fast_atv(&x, &kappa);

    // Retained draws, indexed [chain, draw, coefficient].
    let mut samples_array = Array3::<f64>::zeros((config.n_chains, config.n_samples, p));
    // Work buffers reused across every chain and iteration.
    let mut eta = Array1::<f64>::zeros(n);
    let mut omega = Array1::<f64>::ones(n);
    let pg_shapes = Array1::<u32>::from_elem(n, 1);
    let mut xw = x.to_owned();
    let mut xt_omega_x = Array2::<f64>::zeros((p, p));
    let penalty = penalty_matrix.to_owned();
    let mut q = Array2::<f64>::zeros((p, p));
    let mut mean = Array1::<f64>::zeros(p);
    let mut z = Array1::<f64>::zeros(p);
    let mut noise = Array1::<f64>::zeros(p);

    for chain in 0..config.n_chains {
        // Separate per-chain streams for the starting jitter and the Gaussian
        // β draws; PG draws use their own per-iteration seeds (gibbs_pg_seed).
        let mut init_rng =
            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xB3C4_5A1F_8E9D_7632));
        let mut draw_rng =
            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x17A9_26D5_4C1B_E083));
        let mut beta = mode.to_owned();
        // Small jitter so chains are not perfectly coupled.
        for j in 0..p {
            beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
        }

        for iter in 0..n_iter {
            // ω | β: linear predictor η = Xβ is the PG tilt.
            eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
            draw_logit_pg1_omega(
                pg_shapes.view(),
                eta.view(),
                gibbs_pg_seed(config.seed, chain, 0x4D94_DF4E_5D72_81AB, iter),
                &mut omega,
            )?;

            // Build Xweighted = diag(sqrt(ω)) X and compute X^T Ω X via faer GEMM.
            // Per-row scaling is fully independent across rows.
            ndarray::Zip::indexed(xw.rows_mut())
                .and(x.rows())
                .and(&omega)
                .par_for_each(|_idx, mut xw_row, x_row, omega_i| {
                    let s = omega_i.sqrt();
                    for j in 0..p {
                        xw_row[j] = x_row[j] * s;
                    }
                });
            fast_ata_into(&xw, &mut xt_omega_x);

            q.assign(&penalty);
            q += &xt_omega_x;

            // β | ω,y ~ N(Q^{-1} b, Q^{-1})
            let factor = q
                .cholesky(Side::Lower)
                .map_err(|e| format!("PG Gibbs failed to factor Q: {:?}", e))?;
            mean.assign(&factor.solvevec(&rhs_b));

            for j in 0..p {
                z[j] = sample_standard_normal(&mut draw_rng);
            }
            // NOTE(review): presumably the helper solves Lᵀ·noise = z with
            // z ~ N(0, I), so Cov(noise) = (L Lᵀ)^{-1} = Q^{-1}. Confirm
            // against `back_substitution_lower_transpose_guarded_into`.
            let l = factor.lower_triangular();
            back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
            beta.assign(&(&mean + &noise));

            if iter >= config.nwarmup {
                let keep_idx = iter - config.nwarmup;
                samples_array
                    .slice_mut(ndarray::s![chain, keep_idx, ..])
                    .assign(&beta);
            }
        }
    }

    // Flatten to (n_chains * n_samples, p), chain-major.
    let total_samples = config.n_chains * config.n_samples;
    let mut samples = Array2::<f64>::zeros((total_samples, p));
    for chain in 0..config.n_chains {
        for s in 0..config.n_samples {
            let idx = chain * config.n_samples + s;
            samples
                .row_mut(idx)
                .assign(&samples_array.slice(ndarray::s![chain, s, ..]));
        }
    }

    let posterior_mean = samples
        .mean_axis(Axis(0))
        .unwrap_or_else(|| Array1::zeros(p));
    let posterior_std = samples.std_axis(Axis(0), 0.0);
    // NOTE(review): validation already guarantees `n_samples >= 4`, so the
    // `else` branch is reached only for a single chain. Its values are
    // placeholders, not diagnostics, and `converged` then reduces to
    // `total_samples > 200`. The NUTS path calls `compute_split_rhat_and_ess`
    // unconditionally; confirm whether this path should do the same.
    let (rhat, ess) = if config.n_chains >= 2 && config.n_samples >= 4 {
        compute_split_rhat_and_ess(&samples_array)
    } else {
        (1.0, (total_samples as f64) * 0.5)
    };
    let converged = rhat < 1.1 && ess > 100.0;

    Ok(NutsResult {
        samples,
        posterior_mean,
        posterior_std,
        rhat,
        ess,
        converged,
    })
}
4574
/// Estimate E_{ω|y,ρ}[ tr(S_k Q^{-1}) + μᵀ S_k μ ] with PG Gibbs + Rao-Blackwellization.
///
/// For each retained Gibbs state ω:
///   Q = S + Xᵀ diag(ω) X,  μ = Q^{-1} Xᵀ(y-1/2),
/// and with S_k = R_kᵀ R_k:
///   tr(S_k Q^{-1}) + μᵀ S_k μ
/// = tr(R_k Q^{-1} R_kᵀ) + ||R_k μ||².
///
/// Returns one expectation per penalty block k, averaged over retained draws.
///
/// The chain is the same unit-weight PG(1,·) Gibbs scheme as
/// `run_logit_polya_gamma_gibbs`: `mode` plus `N(0, 0.05²)` jitter per chain,
/// and `config.nwarmup` discarded iterations. It uses different seed
/// streams, so its draws differ from that function's. Blocks whose root `R_k`
/// has zero rows contribute `0`.
///
/// NOTE(review): unlike `run_logit_polya_gamma_gibbs`, this does not call
/// `validate_nuts_config`. Zero chains or zero retained draws surface as the
/// "no retained samples" error instead.
///
/// # Errors
/// Returns `Err` on:
/// - input length or coefficient/penalty dimension mismatch;
/// - non-unit weights;
/// - responses rejected by `validate_binary_responses`;
/// - a root whose column count is not `p`;
/// - a failed Pólya-Gamma batch draw;
/// - a Cholesky failure of `Q`;
/// - no retained draws;
/// - a non-finite expectation.
pub fn estimate_logit_pg_rao_blackwell_terms(
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    weights: ArrayView1<f64>,
    penalty_matrix: ArrayView2<f64>,
    mode: ArrayView1<f64>,
    penalty_roots: &[Array2<f64>],
    config: &NutsConfig,
) -> Result<Array1<f64>, String> {
    let n = x.nrows();
    let p = x.ncols();
    if y.len() != n || weights.len() != n {
        return Err(HmcError::DimensionMismatch {
            reason: "estimate_logit_pg_rao_blackwell_terms: input length mismatch".to_string(),
        }
        .into());
    }
    if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
        return Err(HmcError::DimensionMismatch {
            reason: "estimate_logit_pg_rao_blackwell_terms: coefficient/penalty dimension mismatch"
                .to_string(),
        }
        .into());
    }
    if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
        return Err(HmcError::InvalidConfig {
            reason: "estimate_logit_pg_rao_blackwell_terms requires unit weights (PG(1,·))"
                .to_string(),
        }
        .into());
    }
    validate_binary_responses("estimate_logit_pg_rao_blackwell_terms", &y, &weights)
        .map_err(String::from)?;
    if penalty_roots.iter().any(|r| r.ncols() != p) {
        return Err(HmcError::DimensionMismatch {
            reason: "estimate_logit_pg_rao_blackwell_terms: root width mismatch".to_string(),
        }
        .into());
    }
    // Precompute transposed root blocks once:
    //   R_k^T is the RHS used for batched solves Q X = R_k^T.
    let penalty_roots_t: Vec<Array2<f64>> =
        penalty_roots.iter().map(|r| r.t().to_owned()).collect();

    let n_iter = config.nwarmup + config.n_samples;

    // Logistic PG identity uses kappa_i = y_i - 1/2 so that
    // b = X^T kappa in the Gaussian conditional for beta|omega.
    let kappa = y.mapv(|v| v - 0.5);
    let rhs_b = fast_atv(&x, &kappa);

    // Work buffers reused across every chain and iteration; `rb_sum`
    // accumulates one running total per penalty block.
    let penalty = penalty_matrix.to_owned();
    let mut eta = Array1::<f64>::zeros(n);
    let mut omega = Array1::<f64>::ones(n);
    let pg_shapes = Array1::<u32>::from_elem(n, 1);
    let mut xw = x.to_owned();
    let mut xt_omega_x = Array2::<f64>::zeros((p, p));
    let mut q = Array2::<f64>::zeros((p, p));
    let mut mean = Array1::<f64>::zeros(p);
    let mut rb_sum = Array1::<f64>::zeros(penalty_roots.len());
    let mut z = Array1::<f64>::zeros(p);
    let mut noise = Array1::<f64>::zeros(p);

    // Number of post-warmup states that contributed to `rb_sum`.
    let mut kept = 0usize;
    for chain in 0..config.n_chains {
        let mut init_rng =
            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x28F0_7B65_1A4D_C93E));
        let mut draw_rng =
            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xC642_6E35_B5A9_1D80));
        // Jittered start around the mode so chains are not perfectly coupled.
        let mut beta = mode.to_owned();
        for j in 0..p {
            beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
        }

        for iter in 0..n_iter {
            // ω | β with PG tilt η = Xβ.
            eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
            draw_logit_pg1_omega(
                pg_shapes.view(),
                eta.view(),
                gibbs_pg_seed(config.seed, chain, 0x83F1_56C9_A7E0_2D4B, iter),
                &mut omega,
            )?;

            // Rows of X scaled by sqrt(ω_i), so that XwᵀXw = Xᵀ diag(ω) X.
            ndarray::Zip::from(xw.rows_mut())
                .and(x.rows())
                .and(&omega)
                .par_for_each(|mut xw_row, x_row, &omega_i| {
                    let s = omega_i.sqrt();
                    for j in 0..p {
                        xw_row[j] = x_row[j] * s;
                    }
                });
            fast_ata_into(&xw, &mut xt_omega_x);

            // Conditional precision:
            //   Q = S + X^T diag(omega) X.
            q.assign(&penalty);
            q += &xt_omega_x;

            let factor = q
                .cholesky(Side::Lower)
                .map_err(|e| format!("PG Rao-Blackwell failed to factor Q: {:?}", e))?;
            // Conditional mean:
            //   mu = Q^{-1} b,  b = X^T(y - 1/2).
            mean.assign(&factor.solvevec(&rhs_b));

            // Draw beta for the next Gibbs state.
            for j in 0..p {
                z[j] = sample_standard_normal(&mut draw_rng);
            }
            let l = factor.lower_triangular();
            back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
            beta.assign(&(&mean + &noise));

            // The Rao-Blackwell terms use the conditional (Q, μ) of this
            // state, not the sampled β, and only after warmup.
            if iter < config.nwarmup {
                continue;
            }
            kept += 1;

            for (k, r_k) in penalty_roots.iter().enumerate() {
                if r_k.nrows() == 0 {
                    continue;
                }

                // mu^T S_k mu via root form S_k = R_k^T R_k.
                let rmu = r_k.dot(&mean);
                let mu_quad = rmu.dot(&rmu);

                // Batched trace solve:
                //   V_k = Q^{-1} R_k^T  (single multi-RHS solve)
                // then tr(R_k Q^{-1} R_k^T) = <R_k, V_k^T>_F.
                let solved_mat = factor.solve_mat(&penalty_roots_t[k]); // (p, r_k)
                // `solved_t` has the same (r_k, p) shape as `r_k`, and `iter()`
                // walks both in logical row-major order, so the zip pairs
                // matching entries even though `solved_t` is a transposed view.
                let solved_t = solved_mat.t();
                let mut trace_term = 0.0_f64;
                for (&a, &b) in r_k.iter().zip(solved_t.iter()) {
                    trace_term += a * b;
                }

                rb_sum[k] += trace_term + mu_quad;
            }
        }
    }

    if kept == 0 {
        return Err(HmcError::SamplingFailed {
            reason: "estimate_logit_pg_rao_blackwell_terms: no retained samples".to_string(),
        }
        .into());
    }
    let out = rb_sum.mapv(|v| v / (kept as f64));
    if !out.iter().all(|v| v.is_finite()) {
        return Err(HmcError::NonFiniteState {
            reason: "estimate_logit_pg_rao_blackwell_terms: non-finite expectation".to_string(),
        }
        .into());
    }
    Ok(out)
}
4742
4743/// Runs NUTS sampling using general-mcmc with whitened parameter space.
4744///
4745/// # Arguments
4746/// * `x` - Design matrix [n_samples, dim]
4747/// * `y` - Response vector [n_samples]
4748/// * `weights` - Observation/case weights [n_samples]
4749/// * `penalty_matrix` - Combined penalty S [dim, dim]
4750/// * `mode` - MAP estimate μ [dim]
4751/// * `hessian` - Penalized Hessian H [dim, dim] (NOT the inverse!)
4752/// * `nuts_family` - Family for log-likelihood computation
4753/// * `firth_bias_reduction` - Whether Firth bias reduction was used in training
4754/// * `config` - NUTS configuration
4755pub(crate) fn run_nuts_sampling(
4756    x: ArrayView2<f64>,
4757    y: ArrayView1<f64>,
4758    weights: ArrayView1<f64>,
4759    penalty_matrix: ArrayView2<f64>,
4760    mode: ArrayView1<f64>,
4761    hessian: ArrayView2<f64>,
4762    nuts_family: NutsFamily,
4763    gamma_shape: f64,
4764    dispersion: gam_solve::model_types::Dispersion,
4765    firth_bias_reduction: bool,
4766    offset: Option<ArrayView1<f64>>,
4767    config: &NutsConfig,
4768) -> Result<NutsResult, String> {
4769    validate_firth_support(nuts_family, firth_bias_reduction).map_err(String::from)?;
4770    validate_nuts_config(config).map_err(String::from)?;
4771    if nuts_family == NutsFamily::TweedieLog && !is_valid_tweedie_power(gamma_shape) {
4772        return Err(format!(
4773            "Tweedie variance power must be finite and strictly between 1 and 2; got {gamma_shape}"
4774        ));
4775    }
4776    let dim = mode.len();
4777
4778    // Create posterior target with analytical gradients. When Firth is enabled,
4779    // this target includes the identifiable-subspace Jeffreys term.
4780    let target = NutsPosterior::new(
4781        x,
4782        y,
4783        weights,
4784        penalty_matrix,
4785        mode,
4786        hessian,
4787        nuts_family,
4788        gamma_shape,
4789        dispersion,
4790        firth_bias_reduction,
4791    )?;
4792    let target = match offset {
4793        Some(offset) => target.with_offset(offset)?,
4794        None => target,
4795    };
4796
4797    // Get Cholesky factor for un-whitening samples later
4798    let chol = target.chol().clone();
4799    let mode_arr = target.mode().clone();
4800
4801    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x0F65_83B2_BC71_4D9E);
4802    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4803    let (result, run_stats) = run_whitened_nuts_result(
4804        target,
4805        &mode_arr,
4806        &chol,
4807        initial_positions,
4808        config,
4809        dim,
4810        mass_cfg,
4811        0xF1D3_C2B5_A697_804E,
4812        "NUTS sampling failed",
4813        Array1::zeros(dim),
4814        NutsConvergenceThresholds {
4815            max_rhat: 1.1,
4816            min_ess: Some(100.0),
4817        },
4818    )?;
4819    log::info!("NUTS sampling complete: {}", run_stats);
4820
4821    Ok(result)
4822}
4823
4824/// Terminal never-fail Gaussian-posterior sampling target.
4825///
4826/// This is the bottom rung of the solver's geometry-driven escalation ladder.
4827/// When the outer smoothing optimizer cannot certify convergence on a custom
4828/// (BMS / general) family — typically because Strong-Wolfe stalls on an
4829/// indefinite or non-smooth LAML objective — the driver no longer dead-ends
4830/// with an `Err`. Instead it lands here: the *same* penalized objective's
4831/// curvature (its penalized joint Hessian `H = −∇²log L + Σ_k λ_k S_k`,
4832/// augmented with the proper (unconditional) Jeffreys/PC term)
4833/// is used as the precision of a proper Gaussian posterior `N(β̂, H⁻¹)` about
4834/// the best mode `β̂` the inner solve reached. Sampling a multivariate normal
4835/// cannot fail: in the worst case (a poorly conditioned `H`) the intervals come
4836/// out honestly wider, which is the intended "magic for all users" behavior —
4837/// a finite point with calibrated SEs instead of a hard error.
4838///
4839/// The target is expressed in the whitened space `z` (`β = β̂ + L z`,
4840/// `L Lᵀ = H⁻¹`), where the posterior is the standard normal `N(0, I)`. Its
4841/// log-density and gradient are then exactly `logp(z) = −½ zᵀz`,
4842/// `∇ = −z` — a smooth, globally coercive target with no failure mode. The
4843/// `chol` factor un-whitens draws back to coefficient space, identically to
4844/// the `NutsPosterior` whitening contract above.
4845struct GaussianModeTarget;
4846
4847impl HamiltonianTarget<Array1<f64>> for GaussianModeTarget {
4848    #[inline]
4849    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4850        // Standard-normal target in whitened coordinates: logp = -0.5 zᵀz,
4851        // ∇ = -z. The whitening `L` (built from the penalized Hessian) carries
4852        // all of the posterior geometry, so the sampler itself only ever sees a
4853        // unit-covariance Gaussian — which is why this rung cannot stall.
4854        let mut quad = 0.0;
4855        for (g, &zi) in grad.iter_mut().zip(position.iter()) {
4856            *g = -zi;
4857            quad += zi * zi;
4858        }
4859        -0.5 * quad
4860    }
4861}
4862
4863/// Sample the proper Gaussian posterior `N(mode, H⁻¹)` defined by a mode and a
4864/// (penalized, Jeffreys-augmented) SPD precision `hessian`.
4865///
4866/// This is the terminal, never-fail rung of the outer-optimizer escalation:
4867/// it consumes the same penalized-objective curvature the inner machinery
4868/// already computed and returns an honest posterior summary. It returns `Err`
4869/// only for a *structurally* impossible request (dimension mismatch, a Hessian
4870/// that is not even positive-definite after symmetrization, a degenerate
4871/// config) — never for "did not converge", which is precisely the dead-end this
4872/// path exists to remove.
4873///
4874/// `hessian` must be the SPD penalized joint Hessian at `mode` (e.g. from
4875/// `compute_joint_geometry`). It is symmetrized defensively and Cholesky-
4876/// factored to build the whitening `L` with `L Lᵀ = H⁻¹`.
4877pub fn sample_gaussian_mode_posterior(
4878    mode: ArrayView1<f64>,
4879    hessian: ArrayView2<f64>,
4880    config: &NutsConfig,
4881) -> Result<GaussianModePosterior, String> {
4882    validate_nuts_config(config).map_err(String::from)?;
4883    let dim = mode.len();
4884    if hessian.nrows() != dim || hessian.ncols() != dim {
4885        return Err(format!(
4886            "Gaussian-posterior fallback: hessian shape {:?} does not match mode dim {dim}",
4887            hessian.dim()
4888        ));
4889    }
4890    if dim == 0 {
4891        return Err("Gaussian-posterior fallback: zero-dimensional posterior".to_string());
4892    }
4893
4894    // Symmetrize defensively (the assembled joint Hessian may carry
4895    // floating-point asymmetry from directional-callback construction) and add
4896    // a tiny jitter on the diagonal so a Hessian that is SPD-up-to-roundoff at a
4897    // boundary optimum still factors. The jitter only ever *widens* the
4898    // posterior, consistent with the honest-interval guarantee.
4899    let mut h = hessian.to_owned();
4900    for i in 0..dim {
4901        for j in (i + 1)..dim {
4902            let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
4903            h[[i, j]] = avg;
4904            h[[j, i]] = avg;
4905        }
4906    }
4907    let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
4908    let jitter = (diag_scale * 1e-10).max(1e-12);
4909    for i in 0..dim {
4910        h[[i, i]] += jitter;
4911    }
4912
4913    let mode_owned = mode.to_owned();
4914    let whitening = hessian_whitening_transform(
4915        h.view(),
4916        dim,
4917        1.0,
4918        "Gaussian-posterior fallback Cholesky failed",
4919    )?;
4920    let chol = whitening.chol;
4921    let target = GaussianModeTarget;
4922    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x51A6_2C73_90E4_1DBF);
4923    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4924    let (result, run_stats) = run_whitened_nuts_result(
4925        target,
4926        &mode_owned,
4927        &chol,
4928        initial_positions,
4929        config,
4930        dim,
4931        mass_cfg,
4932        0x7C19_5A3E_82D6_44B1,
4933        "Gaussian-posterior fallback NUTS sampling failed",
4934        mode_owned.clone(),
4935        NutsConvergenceThresholds {
4936            max_rhat: 1.1,
4937            min_ess: None,
4938        },
4939    )?;
4940    log::info!(
4941        "never-fail Gaussian-posterior fallback: sampling complete dim={dim} {}",
4942        run_stats
4943    );
4944
4945    Ok(GaussianModePosterior {
4946        samples: result.samples,
4947        posterior_mean: result.posterior_mean,
4948        posterior_std: result.posterior_std,
4949        rhat: result.rhat,
4950        ess: result.ess,
4951    })
4952}
4953
/// Constant log-density drop used by the Tier-2 `ρ`-posterior NUTS target
/// (#938) at points the criterion closure cannot evaluate.
///
/// A point counts as infeasible when the closure returns `None`, a non-finite
/// cost, or a malformed gradient. There the target falls back to the whitened
/// standard normal `−½ zᵀz` minus this constant. Inside the infeasible region
/// the fallback is smooth and its gradient `−z` pulls back toward `ρ̂`, instead
/// of presenting a `-inf` cliff. The large offset strongly disfavours
/// infeasible states relative to feasible ones.
const RHO_NUTS_INFEASIBLE_LOGP_PENALTY: f64 = 1e8;
4960
4961/// Tier-2 of the exact marginal-smoothing inference stack (#938): the whitened
4962/// `ρ`-criterion Hamiltonian target.
4963///
4964/// This reuses the module's β-level whitening design ONE LEVEL UP: the target
4965/// log-density is `logp(ρ) = −(criterion(ρ) − criterion(ρ̂))` — i.e.
4966/// `π(ρ|y) ∝ exp(−LAML(ρ))`, the exact profiled criterion the outer optimizer
4967/// minimizes — expressed in the whitened coordinates `ρ = ρ̂ + L z` with
4968/// `L Lᵀ = H_ρ⁻¹` built from the exact outer Hessian at `ρ̂`. The gradient is
4969/// the caller's exact profiled `ρ`-gradient pushed through the chain rule:
4970/// `∇_z logp = −Lᵀ ∇_ρ criterion`.
4971///
4972/// The criterion closure is `FnMut` (each evaluation is one warm inner profile
4973/// solve with interior caches), so it is serialized behind a `Mutex`; chains
4974/// take turns evaluating, which also keeps the inner warm-start trajectory
4975/// coherent.
4976struct WhitenedRhoCriterionTarget<F> {
4977    /// `ρ ↦ (criterion(ρ), ∇_ρ criterion(ρ))`; `None` marks an infeasible point.
4978    criterion_and_grad: Mutex<F>,
4979    /// `ρ̂`, the converged smoothing parameters (the whitening center).
4980    mode: Array1<f64>,
4981    /// `L` with `L Lᵀ = H_ρ⁻¹`: maps whitened `z` to `ρ = ρ̂ + L z`.
4982    chol: Array2<f64>,
4983    /// `Lᵀ`, for the gradient chain rule.
4984    chol_t: Array2<f64>,
4985    /// `criterion(ρ̂)`, subtracted for numerical stability (cancels in MCMC).
4986    cost_hat: f64,
4987}
4988
4989impl<F> HamiltonianTarget<Array1<f64>> for WhitenedRhoCriterionTarget<F>
4990where
4991    F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
4992{
4993    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4994        let rho = &self.mode + &self.chol.dot(position);
4995        let eval = {
4996            let mut criterion = self
4997                .criterion_and_grad
4998                .lock()
4999                .expect("rho-criterion mutex poisoned");
5000            (*criterion)(&rho)
5001        };
5002        match eval {
5003            Some((cost, g))
5004                if cost.is_finite()
5005                    && g.len() == position.len()
5006                    && g.iter().all(|v| v.is_finite()) =>
5007            {
5008                let grad_z = self.chol_t.dot(&g);
5009                for (gi, &v) in grad.iter_mut().zip(grad_z.iter()) {
5010                    *gi = -v;
5011                }
5012                -(cost - self.cost_hat)
5013            }
5014            _ => {
5015                // Infeasible criterion: smooth coercive fallback toward ρ̂.
5016                let mut quad = 0.0;
5017                for (gi, &zi) in grad.iter_mut().zip(position.iter()) {
5018                    *gi = -zi;
5019                    quad += zi * zi;
5020                }
5021                -0.5 * quad - RHO_NUTS_INFEASIBLE_LOGP_PENALTY
5022            }
5023        }
5024    }
5025}
5026
5027/// Run NUTS over the smoothing parameters `ρ` with the exact profiled criterion
5028/// and gradient (#938 Tier 2).
5029///
5030/// * `rho_hat` — converged `ρ̂` (the whitening center and chain seed).
5031/// * `outer_hessian` — exact outer Hessian `H_ρ` at `ρ̂` (symmetrized and
5032///   jittered defensively, then Cholesky-factored for the whitening).
5033/// * `criterion_and_grad` — `ρ ↦ (LAML(ρ), ∇_ρ LAML(ρ))`, both exact; `None`
5034///   for infeasible `ρ`. Each call is one warm inner profile solve.
5035/// * `config` — sampler configuration; determinism comes from `config.seed`
5036///   through the same splitmix64 chain/transition streams as every other NUTS
5037///   entry point (no clock, no global RNG).
5038///
5039/// Returns draws in the ORIGINAL `ρ` space (un-whitened), with split-R̂/ESS
5040/// diagnostics.
5041pub fn run_rho_criterion_nuts<F>(
5042    rho_hat: ArrayView1<f64>,
5043    outer_hessian: ArrayView2<f64>,
5044    mut criterion_and_grad: F,
5045    config: &NutsConfig,
5046) -> Result<NutsResult, String>
5047where
5048    F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
5049{
5050    validate_nuts_config(config).map_err(String::from)?;
5051    let dim = rho_hat.len();
5052    if dim == 0 {
5053        return Err("rho-posterior NUTS: zero-dimensional rho".to_string());
5054    }
5055    if outer_hessian.nrows() != dim || outer_hessian.ncols() != dim {
5056        return Err(format!(
5057            "rho-posterior NUTS: outer Hessian shape {:?} does not match rho dim {dim}",
5058            outer_hessian.dim()
5059        ));
5060    }
5061
5062    // Symmetrize + jitter the exact outer Hessian so a boundary optimum that is
5063    // SPD-up-to-roundoff still factors; jitter only widens the proposal metric.
5064    let mut h = outer_hessian.to_owned();
5065    for i in 0..dim {
5066        for j in (i + 1)..dim {
5067            let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
5068            h[[i, j]] = avg;
5069            h[[j, i]] = avg;
5070        }
5071    }
5072    let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
5073    let jitter = (diag_scale * 1e-10).max(1e-12);
5074    for i in 0..dim {
5075        h[[i, i]] += jitter;
5076    }
5077
5078    let mode = rho_hat.to_owned();
5079    let whitening = hessian_whitening_transform(
5080        h.view(),
5081        dim,
5082        1.0,
5083        "rho-posterior NUTS: outer-Hessian Cholesky failed",
5084    )?;
5085
5086    let cost_hat = match criterion_and_grad(&mode) {
5087        Some((cost, _)) if cost.is_finite() => cost,
5088        _ => {
5089            return Err(
5090                "rho-posterior NUTS: criterion is infeasible at rho_hat itself".to_string(),
5091            );
5092        }
5093    };
5094
5095    let chol = whitening.chol;
5096    let target = WhitenedRhoCriterionTarget {
5097        criterion_and_grad: Mutex::new(criterion_and_grad),
5098        mode: mode.clone(),
5099        chol: chol.clone(),
5100        chol_t: whitening.chol_t,
5101        cost_hat,
5102    };
5103    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x3D8A_91C4_E27B_5F60);
5104    // The rho target is already whitened by the exact outer Hessian at rho_hat,
5105    // so the local mass matrix in z-space is identity. Re-adapting a diagonal or
5106    // dense metric during warmup would spend expensive profile solves estimating
5107    // curvature we have already supplied analytically.
5108    let mass_cfg = NUTSMassMatrixConfig::disabled();
5109    let (result, run_stats) = run_whitened_nuts_result(
5110        target,
5111        &mode,
5112        &chol,
5113        initial_positions,
5114        config,
5115        dim,
5116        mass_cfg,
5117        0x6B42_E9A1_05D7_C83F,
5118        "rho-posterior NUTS sampling failed",
5119        mode.clone(),
5120        NutsConvergenceThresholds {
5121            max_rhat: 1.1,
5122            min_ess: None,
5123        },
5124    )?;
5125    log::info!("rho-posterior NUTS (#938 tier 2): sampling complete dim={dim} {run_stats}");
5126    Ok(result)
5127}
5128
5129/// Flattened numeric inputs for GLM-family NUTS sampling.
5130pub struct GlmFlatInputs<'a> {
5131    pub x: ArrayView2<'a, f64>,
5132    pub y: ArrayView1<'a, f64>,
5133    pub weights: ArrayView1<'a, f64>,
5134    pub penalty_matrix: ArrayView2<'a, f64>,
5135    pub mode: ArrayView1<'a, f64>,
5136    pub hessian: ArrayView2<'a, f64>,
5137    pub gamma_shape: Option<f64>,
5138    /// Dispersion parameter φ used to scale the likelihood and the
5139    /// whitening Cholesky. For fixed-scale families (Binomial, Poisson)
5140    /// this is `Dispersion::Known(1.0)` and has no numerical effect;
5141    /// for Gaussian / Gamma it carries the estimated `phi` so that the
5142    /// sampler targets the φ-scaled posterior covariance `Vb = φ·H⁻¹`.
5143    /// See `inference::dispersion_cov` for the ownership invariants.
5144    pub dispersion: gam_solve::model_types::Dispersion,
5145    pub firth_bias_reduction: bool,
5146    /// Fixed additive offset on the linear predictor (η = Xβ + offset), or
5147    /// `None` for an offset-free fit. Carried so posterior sampling targets the
5148    /// same η the model was fit and predicts on; omitting it sampled the wrong
5149    /// posterior for any `--offset-column` model (#882).
5150    pub offset: Option<ArrayView1<'a, f64>>,
5151}
5152
5153/// Flat survival inputs for engine-facing HMC APIs.
5154pub struct SurvivalFlatInputs<'a> {
5155    pub age_entry: ArrayView1<'a, f64>,
5156    pub age_exit: ArrayView1<'a, f64>,
5157    pub event_target: ArrayView1<'a, u8>,
5158    pub event_competing: ArrayView1<'a, u8>,
5159    pub weights: ArrayView1<'a, f64>,
5160    pub x_entry: ArrayView2<'a, f64>,
5161    pub x_exit: ArrayView2<'a, f64>,
5162    pub x_derivative: ArrayView2<'a, f64>,
5163    pub eta_offset_entry: Option<ArrayView1<'a, f64>>,
5164    pub eta_offset_exit: Option<ArrayView1<'a, f64>>,
5165    pub derivative_offset_exit: Option<ArrayView1<'a, f64>>,
5166}
5167
5168/// Flattened numeric inputs for Royston-Parmar NUTS sampling.
5169pub struct SurvivalNutsInputs<'a> {
5170    pub flat: SurvivalFlatInputs<'a>,
5171    pub penalties: gam_models::survival::PenaltyBlocks,
5172    pub monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
5173    pub spec: gam_models::survival::SurvivalSpec,
5174    pub structurally_monotonic: bool,
5175    pub structural_time_columns: usize,
5176    pub mode: ArrayView1<'a, f64>,
5177    pub hessian: ArrayView2<'a, f64>,
5178}
5179
5180/// Family-dispatched flattened NUTS inputs.
5181pub enum FamilyNutsInputs<'a> {
5182    Glm(GlmFlatInputs<'a>),
5183    Survival(Box<SurvivalNutsInputs<'a>>),
5184}
5185
5186/// Return the explicit fitted penalized Hessian used for HMC/NUTS whitening.
5187///
5188/// This is the only supported upstream-to-HMC curvature handoff: callers must
5189/// pass a dense Hessian (or an already materialized exact operator stored as a
5190/// dense Hessian) exported by the fitter. We deliberately do not synthesize a
5191/// numerical Hessian and do not invert `beta_covariance` as a compatibility
5192/// fallback, because either path can silently whiten against curvature that the
5193/// upstream fit never certified.
5194pub fn explicit_fit_hessian_for_whitening<'a>(
5195    fit: &'a UnifiedFitResult,
5196    expected_dim: usize,
5197    label: &str,
5198) -> Result<&'a Array2<f64>, String> {
5199    let hessian = fit.penalized_hessian().ok_or_else(|| {
5200        format!(
5201            "{label}: fit result is missing an explicit penalized Hessian for HMC/NUTS whitening"
5202        )
5203    })?;
5204    validate_explicit_dense_hessian_for_whitening(
5205        &format!("{label} penalized Hessian"),
5206        hessian,
5207        expected_dim,
5208    )
5209    .map_err(|err| err.to_string())?;
5210    Ok(hessian)
5211}
5212
5213/// Family-agnostic flattened NUTS entrypoint across all supported likelihood families.
5214pub fn run_nuts_sampling_flattened_family(
5215    likelihood: LikelihoodSpec,
5216    inputs: FamilyNutsInputs<'_>,
5217    config: &NutsConfig,
5218) -> Result<NutsResult, String> {
5219    if let FamilyNutsInputs::Glm(glm) = &inputs
5220        && glm.firth_bias_reduction
5221        && !likelihood_spec_supports_firth(&likelihood)
5222    {
5223        return Err(HmcError::FirthUnsupported {
5224            reason: format!(
5225                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
5226                likelihood.pretty_name()
5227            ),
5228        }
5229        .into());
5230    }
5231
5232    match (likelihood.response.clone(), likelihood.link.clone(), inputs) {
5233        (
5234            ResponseFamily::Gaussian,
5235            InverseLink::Standard(StandardLink::Identity),
5236            FamilyNutsInputs::Glm(glm),
5237        ) => run_nuts_sampling(
5238            glm.x,
5239            glm.y,
5240            glm.weights,
5241            glm.penalty_matrix,
5242            glm.mode,
5243            glm.hessian,
5244            NutsFamily::Gaussian,
5245            1.0,
5246            glm.dispersion,
5247            glm.firth_bias_reduction,
5248            glm.offset,
5249            config,
5250        ),
5251        (
5252            ResponseFamily::Binomial,
5253            InverseLink::Standard(StandardLink::Logit),
5254            FamilyNutsInputs::Glm(glm),
5255        ) => {
5256            // Auto-select PG Gibbs when assumptions hold; otherwise fall back to NUTS.
5257            // This gives gradient-free posterior draws for standard Bernoulli logit GAMs.
5258            // The Pólya-Gamma augmentation here assumes η = Xβ (no offset); an
5259            // offset model routes to NUTS, which carries the offset through
5260            // `glm.offset` (#882). PG-with-offset is a valid but separate scheme
5261            // we deliberately do not duplicate.
5262            if !glm.firth_bias_reduction
5263                && glm.offset.is_none()
5264                && glm.weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10)
5265            {
5266                run_logit_polya_gamma_gibbs(
5267                    glm.x,
5268                    glm.y,
5269                    glm.weights,
5270                    glm.penalty_matrix,
5271                    glm.mode,
5272                    config,
5273                )
5274            } else {
5275                run_nuts_sampling(
5276                    glm.x,
5277                    glm.y,
5278                    glm.weights,
5279                    glm.penalty_matrix,
5280                    glm.mode,
5281                    glm.hessian,
5282                    NutsFamily::BinomialLogit,
5283                    1.0,
5284                    glm.dispersion,
5285                    glm.firth_bias_reduction,
5286                    glm.offset,
5287                    config,
5288                )
5289            }
5290        }
5291        (
5292            ResponseFamily::Binomial,
5293            InverseLink::Standard(StandardLink::Probit),
5294            FamilyNutsInputs::Glm(glm),
5295        ) => run_nuts_sampling(
5296            glm.x,
5297            glm.y,
5298            glm.weights,
5299            glm.penalty_matrix,
5300            glm.mode,
5301            glm.hessian,
5302            NutsFamily::BinomialProbit,
5303            1.0,
5304            glm.dispersion,
5305            glm.firth_bias_reduction,
5306            glm.offset,
5307            config,
5308        ),
5309        (
5310            ResponseFamily::Binomial,
5311            InverseLink::Standard(StandardLink::CLogLog),
5312            FamilyNutsInputs::Glm(glm),
5313        ) => run_nuts_sampling(
5314            glm.x,
5315            glm.y,
5316            glm.weights,
5317            glm.penalty_matrix,
5318            glm.mode,
5319            glm.hessian,
5320            NutsFamily::BinomialCLogLog,
5321            1.0,
5322            glm.dispersion,
5323            glm.firth_bias_reduction,
5324            glm.offset,
5325            config,
5326        ),
5327        (
5328            ResponseFamily::Binomial,
5329            InverseLink::LatentCLogLog(_),
5330            FamilyNutsInputs::Glm(glm),
5331        ) => run_nuts_sampling(
5332            glm.x,
5333            glm.y,
5334            glm.weights,
5335            glm.penalty_matrix,
5336            glm.mode,
5337            glm.hessian,
5338            NutsFamily::BinomialCLogLog,
5339            1.0,
5340            glm.dispersion,
5341            glm.firth_bias_reduction,
5342            glm.offset,
5343            config,
5344        ),
5345        (ResponseFamily::Binomial, InverseLink::Mixture(_), FamilyNutsInputs::Glm(_)) => Err(
5346            "BinomialMixture NUTS is not implemented yet; use fit_gam/predict_gam for blended inverse-link models"
5347                .to_string(),
5348        ),
5349        (ResponseFamily::Binomial, InverseLink::Sas(_), FamilyNutsInputs::Glm(_)) => Err(
5350            "BinomialSas NUTS is not implemented yet; use fit_gam/predict_gam for SAS-link models"
5351                .to_string(),
5352        ),
5353        (ResponseFamily::Binomial, InverseLink::BetaLogistic(_), FamilyNutsInputs::Glm(_)) => Err(
5354            "BinomialBetaLogistic NUTS is not implemented yet; use fit_gam/predict_gam for beta-logistic-link models"
5355                .to_string(),
5356        ),
5357        (ResponseFamily::Binomial, InverseLink::Standard(_), FamilyNutsInputs::Glm(_)) => Err(
5358            "NUTS sampling is not implemented for this binomial inverse link".to_string(),
5359        ),
5360        (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Survival(survival)) => {
5361            survival_hmc::run_survival_nuts_sampling(
5362                survival.flat.age_entry,
5363                survival.flat.age_exit,
5364                survival.flat.event_target,
5365                survival.flat.event_competing,
5366                survival.flat.weights,
5367                survival.flat.x_entry,
5368                survival.flat.x_exit,
5369                survival.flat.x_derivative,
5370                survival.flat.eta_offset_entry,
5371                survival.flat.eta_offset_exit,
5372                survival.flat.derivative_offset_exit,
5373                survival.penalties,
5374                survival.monotonicity,
5375                survival.spec,
5376                survival.structurally_monotonic,
5377                survival.structural_time_columns,
5378                survival.mode,
5379                survival.hessian,
5380                config,
5381            )
5382        }
5383        (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Glm(_)) => Err(
5384            "RoystonParmar family requires FamilyNutsInputs::Survival flattened inputs".to_string(),
5385        ),
5386        (_, _, FamilyNutsInputs::Survival(_)) => Err(
5387            "Survival flattened inputs are only valid for the Royston-Parmar response family"
5388                .to_string(),
5389        ),
5390        (ResponseFamily::Poisson, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5391            glm.x,
5392            glm.y,
5393            glm.weights,
5394            glm.penalty_matrix,
5395            glm.mode,
5396            glm.hessian,
5397            NutsFamily::PoissonLog,
5398            1.0,
5399            glm.dispersion,
5400            glm.firth_bias_reduction,
5401            glm.offset,
5402            config,
5403        ),
5404        (ResponseFamily::Tweedie { p }, _, FamilyNutsInputs::Glm(glm)) => {
5405            // Family mapping: Tweedie payload p is passed through the family-parameter slot.
5406            // The Tweedie dispersion phi remains in glm.dispersion, matching REML.
5407            if !is_valid_tweedie_power(p) {
5408                return Err(format!(
5409                    "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
5410                ));
5411            }
5412            run_nuts_sampling(
5413                glm.x,
5414                glm.y,
5415                glm.weights,
5416                glm.penalty_matrix,
5417                glm.mode,
5418                glm.hessian,
5419                NutsFamily::TweedieLog,
5420                p,
5421                glm.dispersion,
5422                glm.firth_bias_reduction,
5423                glm.offset,
5424                config,
5425            )
5426        }
5427        (ResponseFamily::NegativeBinomial { theta, .. }, _, FamilyNutsInputs::Glm(glm)) => {
5428            // Family mapping: NegativeBinomial payload theta is passed through the family slot.
5429            // NB dispersion scale is unit; theta is not derived from fixed_phi.
5430            run_nuts_sampling(
5431                glm.x,
5432                glm.y,
5433                glm.weights,
5434                glm.penalty_matrix,
5435                glm.mode,
5436                glm.hessian,
5437                NutsFamily::NegativeBinomialLog,
5438                theta,
5439                glm.dispersion,
5440                glm.firth_bias_reduction,
5441                glm.offset,
5442                config,
5443            )
5444        }
5445        (ResponseFamily::Beta { .. }, _, FamilyNutsInputs::Glm(_)) => Err(
5446            "NUTS sampling is not implemented for beta-regression logit".to_string(),
5447        ),
5448        (ResponseFamily::Gamma, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5449            glm.x,
5450            glm.y,
5451            glm.weights,
5452            glm.penalty_matrix,
5453            glm.mode,
5454            glm.hessian,
5455            NutsFamily::GammaLog,
5456            glm.gamma_shape.unwrap_or(1.0),
5457            glm.dispersion,
5458            glm.firth_bias_reduction,
5459            glm.offset,
5460            config,
5461        ),
5462        (ResponseFamily::Gaussian, _, FamilyNutsInputs::Glm(_)) => Err(
5463            "NUTS sampling is only implemented for Gaussian with identity link".to_string(),
5464        ),
5465    }
5466}
5467
5468// ============================================================================
5469// Joint (β, θ) Link-Wiggle HMC
5470// ============================================================================
5471//
5472// NUTS sampling over the joint parameter space [β_eta; β_wiggle] for models
5473// with a structurally monotone I-spline link wiggle. The wiggle introduces a
5474// nonlinear coupling:
5475//
5476//   η(β_eta, β_wiggle) = q₀(β_eta) + B(q₀(β_eta)) · β_wiggle
5477//
5478// where B is the shared monotone wiggle basis evaluated at the base linear
5479// predictor q₀ = X · β_eta. The gradient of log p(y|β_eta, β_wiggle) w.r.t.
5480// β_eta picks up a chain-rule factor g'(q₀) = 1 + B'(q₀) · β_wiggle / range_width
5481// from the dependence of B on q₀.
5482//
5483// Whitening uses the Cholesky of the joint Hessian at the mode, exactly as for
5484// the standard NutsPosterior. C^1 linear extension outside the training knot
5485// range prevents basis evaluation discontinuities.
5486
5487/// Fixed spline artifacts for link-wiggle posterior sampling.
5488#[derive(Clone)]
5489pub struct LinkWiggleSplineArtifacts {
5490    /// Knot range (min, max) from training (in standardized [0,1] space of q₀)
5491    pub knot_range: (f64, f64),
5492    /// Full knot vector for the shared monotone I-spline basis
5493    pub knot_vector: Array1<f64>,
5494    /// I-spline degree
5495    pub degree: usize,
5496}
5497
5498/// Whitened log-posterior target for joint (β_eta, β_wiggle) with analytical gradients.
5499#[derive(Clone)]
5500pub struct LinkWigglePosterior {
5501    /// Main design matrix X (n × p_main)
5502    x: Arc<Array2<f64>>,
5503    y: Arc<Array1<f64>>,
5504    weights: Arc<Array1<f64>>,
5505    /// Penalty for main coefficients (p_main × p_main)
5506    penalty_base: Arc<Array2<f64>>,
5507    /// Penalty for wiggle coefficients (p_wiggle × p_wiggle)
5508    penalty_link: Arc<Array2<f64>>,
5509    mode_beta: Arc<Array1<f64>>,
5510    mode_theta: Arc<Array1<f64>>,
5511    spline: LinkWiggleSplineArtifacts,
5512    /// L where LL^T = H^{-1} (joint Hessian)
5513    chol: Array2<f64>,
5514    /// L^T for gradient chain rule
5515    chol_t: Array2<f64>,
5516    p_base: usize,
5517    p_link: usize,
5518    n_samples: usize,
5519    nuts_family: NutsFamily,
5520    /// Family-specific noise parameter: Gaussian sigma or Gamma shape.
5521    scale: f64,
5522    /// Coefficient-covariance scale `cov_scale` (#679/#680 invariant): the
5523    /// `Vb = cov_scale·H⁻¹` multiplier driving both the whitening
5524    /// (`L Lᵀ = cov_scale·H⁻¹`) and the target penalty weight
5525    /// (`penalty_scale = 1/cov_scale`). `σ²` for profiled Gaussian, `1.0` for
5526    /// every weight-carries-dispersion family (Gamma/Tweedie/NB).
5527    cov_scale: f64,
5528}
5529
5530impl LinkWigglePosterior {
5531    /// Standardize q₀ values to [0,1] range using training knot bounds.
5532    #[inline]
5533    fn standardized_z(&self, u: &Array1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
5534        let (min_u, max_u) = self.spline.knot_range;
5535        let rw = (max_u - min_u).max(1e-6);
5536        let z_raw: Array1<f64> = u.mapv(|v| (v - min_u) / rw);
5537        let z_c: Array1<f64> = z_raw.mapv(|z| z.clamp(0.0, 1.0));
5538        (z_raw, z_c, rw)
5539    }
5540
5541    /// Creates a new link-wiggle posterior target.
5542    pub fn new(
5543        x: ArrayView2<f64>,
5544        y: ArrayView1<f64>,
5545        weights: ArrayView1<f64>,
5546        penalty_base: ArrayView2<f64>,
5547        penalty_link: ArrayView2<f64>,
5548        mode_beta: ArrayView1<f64>,
5549        mode_theta: ArrayView1<f64>,
5550        hessian: ArrayView2<f64>,
5551        spline: LinkWiggleSplineArtifacts,
5552        nuts_family: NutsFamily,
5553        scale: f64,
5554    ) -> Result<Self, String> {
5555        let n_samples = x.nrows();
5556        let p_base = x.ncols();
5557        let p_link = mode_theta.len();
5558        let dim = p_base + p_link;
5559        if hessian.nrows() != dim || hessian.ncols() != dim {
5560            return Err(HmcError::DimensionMismatch {
5561                reason: format!(
5562                    "LinkWigglePosterior: Hessian dim mismatch: {}x{} vs expected {}x{}",
5563                    hessian.nrows(),
5564                    hessian.ncols(),
5565                    dim,
5566                    dim,
5567                ),
5568            }
5569            .into());
5570        }
5571        if nuts_family.likelihood_spec().is_binomial() {
5572            validate_binary_responses("binomial link-wiggle NUTS", &y, &weights)
5573                .map_err(String::from)?;
5574        }
5575        if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
5576            validate_count_responses("negative-binomial link-wiggle NUTS", &y, &weights)
5577                .map_err(String::from)?;
5578        }
5579        // Whitening metric `L Lᵀ = cov_scale · H⁻¹` (#679/#680 invariant), so
5580        // scale `L` by `√cov_scale`. For the link-wiggle joint target `scale`
5581        // is σ (Gaussian), so the profiled-Gaussian covariance scale is
5582        // `cov_scale = σ²`. Every other family folds its dispersion into the
5583        // working weight / the `shape`/`theta` already inside its
5584        // log-likelihood, so `cov_scale = 1` and this is a no-op. The previous
5585        // Gamma branch scaled `L` by `1/√shape = √φ`, mis-preconditioning the
5586        // sampler against `φ·H⁻¹` instead of the correct `H⁻¹` (#680).
5587        let cov_scale = match nuts_family {
5588            NutsFamily::Gaussian => scale * scale,
5589            _ => 1.0,
5590        };
5591        let whitening = hessian_whitening_transform(
5592            hessian,
5593            dim,
5594            cov_scale,
5595            "LinkWigglePosterior Cholesky failed",
5596        )?;
5597        let chol = whitening.chol;
5598        let chol_t = whitening.chol_t;
5599        Ok(Self {
5600            x: Arc::new(x.to_owned()),
5601            y: Arc::new(y.to_owned()),
5602            weights: Arc::new(weights.to_owned()),
5603            penalty_base: Arc::new(penalty_base.to_owned()),
5604            penalty_link: Arc::new(penalty_link.to_owned()),
5605            mode_beta: Arc::new(mode_beta.to_owned()),
5606            mode_theta: Arc::new(mode_theta.to_owned()),
5607            spline,
5608            chol,
5609            chol_t,
5610            p_base,
5611            p_link,
5612            n_samples,
5613            nuts_family,
5614            scale,
5615            cov_scale,
5616        })
5617    }
5618
5619    /// Evaluate the wiggle basis and compute η = q₀ + B(q₀)·θ with C^1 linear extension.
5620    fn evaluate_link(&self, u: &Array1<f64>, theta: &Array1<f64>) -> (Array2<f64>, Array1<f64>) {
5621        let n = u.len();
5622        if theta.is_empty() {
5623            return (Array2::zeros((n, 0)), u.clone());
5624        }
5625
5626        let (z_raw, z_c, _) = self.standardized_z(u);
5627        let Ok(mut basis) = monotone_wiggle_basis_with_derivative_order(
5628            z_c.view(),
5629            &self.spline.knot_vector,
5630            self.spline.degree,
5631            0,
5632        ) else {
5633            return (Array2::zeros((n, theta.len())), u.clone());
5634        };
5635        if basis.ncols() != theta.len() {
5636            return (Array2::zeros((n, theta.len())), u.clone());
5637        }
5638
5639        // C^1 linear extension outside [0, 1]:
5640        // B_ext(z_raw) = B(z_c) + (z_raw - z_c) * B'(z_c)
5641        let mut needs_ext = false;
5642        for i in 0..n {
5643            if (z_raw[i] - z_c[i]).abs() > 1e-12 {
5644                needs_ext = true;
5645                break;
5646            }
5647        }
5648        if needs_ext
5649            && let Ok(b_prime) = monotone_wiggle_basis_with_derivative_order(
5650                z_c.view(),
5651                &self.spline.knot_vector,
5652                self.spline.degree,
5653                1,
5654            )
5655        {
5656            for i in 0..n {
5657                let dz = z_raw[i] - z_c[i];
5658                if dz.abs() <= 1e-12 {
5659                    continue;
5660                }
5661                for j in 0..basis.ncols().min(b_prime.ncols()) {
5662                    basis[[i, j]] += dz * b_prime[[i, j]];
5663                }
5664            }
5665        }
5666        (
5667            basis.clone(),
5668            u + &gam_linalg::faer_ndarray::fast_av(&basis, theta),
5669        )
5670    }
5671
5672    /// Compute dη/dq₀ = 1 + B'(q₀)·θ / range_width (chain-rule factor for β_eta gradient).
5673    fn compute_g_prime(&self, u: &Array1<f64>, theta: &Array1<f64>) -> Array1<f64> {
5674        let n = u.len();
5675        let mut g = Array1::<f64>::ones(n);
5676        let (_, z_c, rw) = self.standardized_z(u);
5677        if theta.is_empty() {
5678            return g;
5679        }
5680
5681        let Ok(b_prime_constrained) = monotone_wiggle_basis_with_derivative_order(
5682            z_c.view(),
5683            &self.spline.knot_vector,
5684            self.spline.degree,
5685            1,
5686        ) else {
5687            return g;
5688        };
5689        if b_prime_constrained.ncols() != theta.len() {
5690            return g;
5691        }
5692        let dwiggle_dz = gam_linalg::faer_ndarray::fast_av(&b_prime_constrained, theta);
5693        ndarray::Zip::from(&mut g)
5694            .and(&dwiggle_dz)
5695            .par_for_each(|gi, &dw| *gi = 1.0 + dw / rw);
5696        g
5697    }
5698
5699    fn compute_logp_and_grad_into(&self, z: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5700        let dim = self.p_base + self.p_link;
5701
5702        // Un-whiten: q = mode + L·z
5703        let mut mode = Array1::<f64>::zeros(dim);
5704        mode.slice_mut(ndarray::s![0..self.p_base])
5705            .assign(&self.mode_beta);
5706        mode.slice_mut(ndarray::s![self.p_base..])
5707            .assign(&self.mode_theta);
5708        let q = &mode + &self.chol.dot(z);
5709        let beta = q.slice(ndarray::s![0..self.p_base]).to_owned();
5710        let theta = q.slice(ndarray::s![self.p_base..]).to_owned();
5711
5712        // Compute η = q₀ + B(q₀)·θ where q₀ = X·β
5713        let u = gam_linalg::faer_ndarray::fast_av(self.x.as_ref(), &beta);
5714        let (bwiggle, eta) = self.evaluate_link(&u, &theta);
5715
5716        // Log-likelihood and residuals via family dispatch
5717        let ll;
5718        let mut residual = Array1::<f64>::zeros(self.n_samples);
5719        match self.nuts_family {
5720            NutsFamily::Gaussian => {
5721                let inv_scale_sq = 1.0 / (self.scale * self.scale).max(1e-10);
5722                let mut ll_acc = 0.0;
5723                for i in 0..self.n_samples {
5724                    let r = self.y[i] - eta[i];
5725                    let w = self.weights[i];
5726                    ll_acc -= 0.5 * w * r * r * inv_scale_sq;
5727                    residual[i] = w * r * inv_scale_sq;
5728                }
5729                ll = ll_acc;
5730            }
5731            NutsFamily::BinomialLogit => {
5732                let mut ll_acc = 0.0;
5733                for i in 0..self.n_samples {
5734                    let eta_i = eta[i];
5735                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5736                    ll_acc += w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i));
5737                    let mu = gam_linalg::utils::stable_logistic(eta_i);
5738                    residual[i] = w_i * (y_i - mu);
5739                }
5740                ll = ll_acc;
5741            }
5742            NutsFamily::BinomialProbit => {
5743                let mut ll_acc = 0.0;
5744                for i in 0..self.n_samples {
5745                    let eta_i = eta[i];
5746                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5747                    let log_phi_pos = log_ndtr(eta_i);
5748                    let log_phi_neg = log_ndtr(-eta_i);
5749                    ll_acc += w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg);
5750                    let log_phi = standard_normal_log_pdf(eta_i);
5751                    let ratio_pos = (log_phi - log_phi_pos).exp();
5752                    let ratio_neg = (log_phi - log_phi_neg).exp();
5753                    residual[i] = w_i * (y_i * ratio_pos - (1.0 - y_i) * ratio_neg);
5754                }
5755                ll = ll_acc;
5756            }
5757            NutsFamily::BinomialCLogLog => {
5758                let mut ll_acc = 0.0;
5759                for i in 0..self.n_samples {
5760                    let eta_i = eta[i];
5761                    if !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)) {
5762                        grad.fill(0.0);
5763                        return f64::NEG_INFINITY;
5764                    }
5765                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5766                    let (ll_i, residual_i) = match cloglog_bernoulli_logp_and_residual(eta_i, y_i) {
5767                        Ok(values) => values,
5768                        Err(_) => {
5769                            grad.fill(0.0);
5770                            return f64::NEG_INFINITY;
5771                        }
5772                    };
5773                    ll_acc += w_i * ll_i;
5774                    residual[i] = w_i * residual_i;
5775                }
5776                ll = ll_acc;
5777            }
5778            NutsFamily::PoissonLog => {
5779                let mut ll_acc = 0.0;
5780                for i in 0..self.n_samples {
5781                    let eta_i = eta[i];
5782                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5783                        grad.fill(0.0);
5784                        return f64::NEG_INFINITY;
5785                    }
5786                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5787                    let mu = eta_i.exp();
5788                    ll_acc += w_i * (y_i * eta_i - mu);
5789                    residual[i] = w_i * (y_i - mu);
5790                }
5791                ll = ll_acc;
5792            }
5793            NutsFamily::TweedieLog => {
5794                let mut ll_acc = 0.0;
5795                // Family mapping: Tweedie scale carries payload p; phi is not stored here.
5796                // Invalid p makes the link-wiggle target invalid instead of defaulting.
5797                if !is_valid_tweedie_power(self.scale) {
5798                    grad.fill(0.0);
5799                    return f64::NEG_INFINITY;
5800                }
5801                let p = self.scale;
5802                for i in 0..self.n_samples {
5803                    let eta_i = eta[i];
5804                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5805                        grad.fill(0.0);
5806                        return f64::NEG_INFINITY;
5807                    }
5808                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5809                    let mu = eta_i.exp().max(1e-300);
5810                    ll_acc +=
5811                        w_i * (y_i * mu.powf(1.0 - p) / (1.0 - p) - mu.powf(2.0 - p) / (2.0 - p));
5812                    residual[i] = w_i * (y_i - mu) * mu.powf(1.0 - p);
5813                }
5814                ll = ll_acc;
5815            }
5816            NutsFamily::NegativeBinomialLog => {
5817                let mut ll_acc = 0.0;
5818                // Family mapping: NegativeBinomial scale carries payload theta.
5819                // Invalid theta makes the link-wiggle target invalid instead of clamping.
5820                if !(self.scale.is_finite() && self.scale > 0.0) {
5821                    grad.fill(0.0);
5822                    return f64::NEG_INFINITY;
5823                }
5824                let theta = self.scale;
5825                for i in 0..self.n_samples {
5826                    let eta_i = eta[i];
5827                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5828                        grad.fill(0.0);
5829                        return f64::NEG_INFINITY;
5830                    }
5831                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5832                    if w_i <= 0.0 {
5833                        residual[i] = 0.0;
5834                        continue;
5835                    }
5836                    let mu = eta_i.exp().max(1e-12);
5837                    let log_mu_term = if y_i > 0.0 { y_i * mu.ln() } else { 0.0 };
5838                    ll_acc += w_i
5839                        * (statrs::function::gamma::ln_gamma(y_i + theta)
5840                            - statrs::function::gamma::ln_gamma(theta)
5841                            - statrs::function::gamma::ln_gamma(y_i + 1.0)
5842                            + theta * (theta.ln() - (theta + mu).ln())
5843                            + log_mu_term
5844                            - y_i * (theta + mu).ln());
5845                    residual[i] = w_i * theta * (y_i - mu) / (theta + mu);
5846                }
5847                ll = ll_acc;
5848            }
5849            NutsFamily::GammaLog => {
5850                let mut ll_acc = 0.0;
5851                let shape = self.scale.max(1e-10);
5852                for i in 0..self.n_samples {
5853                    let eta_i = eta[i];
5854                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5855                        grad.fill(0.0);
5856                        return f64::NEG_INFINITY;
5857                    }
5858                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5859                    let mu = eta_i.exp();
5860                    ll_acc += w_i * shape * (-y_i / mu - eta_i);
5861                    residual[i] = w_i * shape * (y_i / mu - 1.0);
5862                }
5863                ll = ll_acc;
5864            }
5865        }
5866
5867        // Penalty weight = 1/cov_scale (#679/#680 invariant), matching the
5868        // factor the likelihood already carries so the prior and likelihood
5869        // live on the same scale and the MAP-anchored target curvature equals
5870        // `Vb⁻¹ = H/cov_scale`. The Gaussian block above multiplies through by
5871        // `1/σ²`, so `penalty_scale = 1/σ²`. The Gamma block carries an explicit
5872        // `shape = 1/φ` factor in its score (`w_i·shape·(y/μ − 1)`) — that is
5873        // the *data* Fisher information, already folded into the working
5874        // weight, so the penalty must stay UNSCALED (`cov_scale = 1`,
5875        // `penalty_scale = 1`). The previous code used `penalty_scale = shape`
5876        // for Gamma, double-counting the dispersion in the sampled posterior
5877        // and shrinking every posterior SD by `√φ` (#680). Tweedie/NB/Poisson/
5878        // Binomial are unit-scale and unchanged.
5879        let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
5880
5881        // Gradient w.r.t. θ (wiggle): ∂ℓ/∂θ = B(q₀)^T · residual − S_link · θ
5882        let s_link_theta = self.penalty_link.dot(&theta);
5883        let grad_theta = &fast_atv(&bwiggle, &residual) - &(&s_link_theta * penalty_scale);
5884
5885        // Gradient w.r.t. β_eta: ∂ℓ/∂β = X^T · (residual ⊙ g'(q₀)) − S_base · β
5886        // where g'(q₀) = dη/dq₀ is the chain-rule factor
5887        let g_prime = self.compute_g_prime(&u, &theta);
5888        let r_scaled: Array1<f64> = residual
5889            .iter()
5890            .zip(g_prime.iter())
5891            .map(|(&r, &g)| r * g)
5892            .collect();
5893        let s_base_beta = self.penalty_base.dot(&beta);
5894        let grad_beta = &fast_atv(&self.x, &r_scaled) - &(&s_base_beta * penalty_scale);
5895
5896        // Penalty (also φ-scaled for Gaussian; see `penalty_scale` above).
5897        let penalty =
5898            penalty_scale * (0.5 * beta.dot(&s_base_beta) + 0.5 * theta.dot(&s_link_theta));
5899
5900        // Assemble joint gradient and transform to whitened space
5901        let mut grad_q = Array1::<f64>::zeros(dim);
5902        grad_q
5903            .slice_mut(ndarray::s![0..self.p_base])
5904            .assign(&grad_beta);
5905        grad_q
5906            .slice_mut(ndarray::s![self.p_base..])
5907            .assign(&grad_theta);
5908        fast_av_into(&self.chol_t, &grad_q, grad);
5909        ll - penalty
5910    }
5911
5912    /// Get the Cholesky factor L for un-whitening samples.
5913    pub fn chol(&self) -> &Array2<f64> {
5914        &self.chol
5915    }
5916
5917    /// Get the mode [β_eta; β_wiggle].
5918    pub fn mode_joint(&self) -> Array1<f64> {
5919        let dim = self.p_base + self.p_link;
5920        let mut mode = Array1::<f64>::zeros(dim);
5921        mode.slice_mut(ndarray::s![0..self.p_base])
5922            .assign(&self.mode_beta);
5923        mode.slice_mut(ndarray::s![self.p_base..])
5924            .assign(&self.mode_theta);
5925        mode
5926    }
5927}
5928
5929impl HamiltonianTarget<Array1<f64>> for LinkWigglePosterior {
5930    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5931        self.compute_logp_and_grad_into(position, grad)
5932    }
5933}
5934
5935/// Runs NUTS sampling for joint (β_eta, β_wiggle) in a link-wiggle model.
5936pub fn run_link_wiggle_nuts_sampling(
5937    x: ArrayView2<f64>,
5938    y: ArrayView1<f64>,
5939    weights: ArrayView1<f64>,
5940    penalty_base: ArrayView2<f64>,
5941    penalty_link: ArrayView2<f64>,
5942    mode_beta: ArrayView1<f64>,
5943    mode_theta: ArrayView1<f64>,
5944    hessian: ArrayView2<f64>,
5945    spline: LinkWiggleSplineArtifacts,
5946    nuts_family: NutsFamily,
5947    scale: f64,
5948    config: &NutsConfig,
5949) -> Result<NutsResult, String> {
5950    validate_nuts_config(config).map_err(String::from)?;
5951    let dim = mode_beta.len() + mode_theta.len();
5952    let target = LinkWigglePosterior::new(
5953        x,
5954        y,
5955        weights,
5956        penalty_base,
5957        penalty_link,
5958        mode_beta,
5959        mode_theta,
5960        hessian,
5961        spline,
5962        nuts_family,
5963        scale,
5964    )?;
5965    let chol = target.chol().clone();
5966    let mode_arr = target.mode_joint();
5967
5968    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x8C48_0F65_3A2B_D917);
5969
5970    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
5971    let (result, run_stats) = run_whitened_nuts_result(
5972        target,
5973        &mode_arr,
5974        &chol,
5975        initial_positions,
5976        config,
5977        dim,
5978        mass_cfg,
5979        0x2E31_A4B6_C908_F57D,
5980        "Link-wiggle NUTS sampling failed",
5981        Array1::zeros(dim),
5982        NutsConvergenceThresholds {
5983            max_rhat: 1.1,
5984            min_ess: Some(100.0),
5985        },
5986    )?;
5987    log::info!("Link-wiggle NUTS sampling complete: {}", run_stats);
5988
5989    Ok(result)
5990}
5991
5992// ============================================================================
5993// Joint (β, ρ) HMC for Skewed Posteriors
5994// ============================================================================
5995//
5996// When the Laplace approximation to the marginal likelihood is unreliable
5997// (high posterior skewness), we bypass LAML entirely and sample from the
5998// joint posterior p(β, ρ | y) ∝ p(y|β) p(β|ρ) p(ρ).
5999//
6000// The joint log-posterior is:
6001//   log p(β, ρ | y) = ℓ(y|β) + Φ(β) [if Firth]
6002//                    - 0.5 β'S(ρ)β + 0.5 log|S(ρ)|_+ + log p(ρ) + const
6003//
6004// Gradients:
6005//   ∇_β: ∇_β ℓ + ∇_β Φ(β) [if Firth] - S(ρ) β
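//   ∂/∂ρ_k: -0.5 λ_k β'S_k β + 0.5 tr(S_+⁻¹ A_k) + ∂log p(ρ)/∂ρ_k
//   where A_k = ∂S/∂ρ_k = λ_k S_k (S(ρ) = Σ_k λ_k S_k with λ_k = exp(ρ_k)) and
//   S_+⁻¹ is the pseudo-inverse of S(ρ) on its range.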
6007//
6008// This completely avoids the Laplace approximation. When Firth bias reduction
6009// is active, the sampled target also includes the Jeffreys term Φ(β) in
6010// addition to the smoothing-parameter prior.
6011
6012/// Directional cubic non-Gaussianity diagnostic for the Laplace approximation.
6013///
6014/// For each positive-curvature Hessian eigenpair `(lambda_r, v_r)`, this computes
6015///
6016///   gamma_r = T[v_r, v_r, v_r] / lambda_r^(3/2)
6017///            = Σ_i c_i (x_i^T v_r)^3 / lambda_r^(3/2),
6018///
6019/// and reports `max_r |gamma_r|`. This is invariant to arbitrary coordinate
6020/// relabeling and uses the full directional cubic contraction rather than only
6021/// diagonal tensor entries.
6022/// `refine_supremum` controls Phase 2, the cubic power-iteration that sharpens
6023/// the returned scalar `max_abs` toward the true supremum of `|γ(u)|` over the
6024/// H-unit sphere (which can exceed the per-eigenvector maximum). That scalar is
6025/// the ONLY thing Phase 2 affects — the per-direction `directional` vector,
6026/// which drives [`laplace_trustworthiness_from_skewness`]'s direction selection
6027/// AND its own internally-recomputed `max_abs_skewness`, comes entirely from
6028/// Phase 1. The #784 block-local REML correction
6029/// (`block_local_sampled_correction`) consumes `directional` and uses `max_abs`
6030/// only for a `> 0` finiteness guard that Phase 1 already satisfies, so it
6031/// passes `false` and skips Phase 2's multi-probe O(probes·iters·np) refinement
6032/// on every inner evaluation. Diagnostic callers that report the true supremum
6033/// pass `true`.
6034pub fn laplace_directional_cubic_diagnostic(
6035    hessian: &Array2<f64>,
6036    design: &DesignMatrix,
6037    c_weights: &Array1<f64>,
6038    refine_supremum: bool,
6039) -> Result<(f64, Array1<f64>), String> {
6040    let p = hessian.nrows();
6041    if p == 0 || hessian.ncols() != p {
6042        return Ok((0.0, Array1::zeros(0)));
6043    }
6044
6045    let sym_h = (hessian + &hessian.t()) * 0.5;
6046    let (evals, evecs) = sym_h
6047        .eigh(Side::Lower)
6048        .map_err(|e| format!("directional cubic diagnostic eigendecomposition failed: {e}"))?;
6049    let max_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
6050    let tol = (max_eval * 1.0e-12).max(1.0e-14);
6051    let mut directional = Array1::<f64>::zeros(p);
6052    let mut max_abs = 0.0_f64;
6053
6054    // Build the whitening transform L^{-1} where H = L L^T, so that
6055    // the standardized cubic along whitened direction u is:
6056    //   gamma(u) = T[L^{-T}u, L^{-T}u, L^{-T}u]  for ||u||=1
6057    // Eigenvector directions v_r satisfy u_r = lambda_r^{1/2} v_r (after
6058    // appropriate normalization), so gamma_r = T[v_r,v_r,v_r] / lambda_r^{3/2}.
6059
6060    // Phase 1: evaluate gamma_r for all positive-curvature eigenvectors.
6061    for r in 0..p {
6062        let lambda = evals[r];
6063        if lambda <= tol {
6064            continue;
6065        }
6066        let v = evecs.column(r);
6067        let gamma = directional_cubic_contraction(design, c_weights, &v) / lambda.powf(1.5);
6068        directional[r] = if gamma.is_finite() { gamma } else { 0.0 };
6069        max_abs = max_abs.max(directional[r].abs());
6070    }
6071
6072    // Phase 2: power-iteration refinement in whitened space.
6073    //
6074    // The supremum of |gamma(u)| over ||u||_H=1 can exceed the max over
6075    // eigenvectors. We approximate it with a few rounds of cubic power
6076    // iteration: given current direction v, the gradient of T[v,v,v] w.r.t.
6077    // v on the H-unit sphere is 3 T[·,v,v] projected onto the tangent space.
6078    // Since T[·,v,v] = X^T diag(c_i (x_i^T v)^2) which is a matrix-vector
6079    // product, each iteration is O(np).
6080    //
6081    // We seed from the eigenvector with largest |gamma_r| and also from a
6082    // few random probe directions.
6083    if refine_supremum && p >= 2 {
6084        // Build H^{-1/2} columns for whitening: H^{-1/2} = V diag(1/sqrt(lam)) V^T
6085        // We need it to map whitened u -> original v = H^{-1/2} u, and
6086        // H^{1/2} to project back: H^{1/2} v = V diag(sqrt(lam)) V^T v.
6087        let positive_mask: Vec<bool> = evals.iter().map(|&ev| ev > tol).collect();
6088        let n_pos = positive_mask.iter().filter(|&&m| m).count();
6089        if n_pos >= 2 {
6090            let max_abs_from_probes = cubic_power_iteration_refinement(
6091                design,
6092                c_weights,
6093                &evals,
6094                &evecs,
6095                &positive_mask,
6096                n_pos,
6097            );
6098            if max_abs_from_probes > max_abs {
6099                max_abs = max_abs_from_probes;
6100            }
6101        }
6102    }
6103
6104    Ok((max_abs, directional))
6105}
6106
6107/// Compute T[v,v,v] = Σ_i c_i (x_i^T v)^3 for a given direction v.
6108fn directional_cubic_contraction(
6109    design: &DesignMatrix,
6110    c_weights: &Array1<f64>,
6111    v: &ArrayView1<f64>,
6112) -> f64 {
6113    match design.as_sparse() {
6114        Some(x_sparse) => {
6115            let (symbolic, values) = x_sparse.as_ref().parts();
6116            let col_ptr = symbolic.col_ptr();
6117            let row_idx = symbolic.row_idx();
6118            let mut row_scores = vec![0.0_f64; x_sparse.nrows()];
6119            for col in 0..x_sparse.ncols() {
6120                let coeff = v[col];
6121                for ptr in col_ptr[col]..col_ptr[col + 1] {
6122                    row_scores[row_idx[ptr]] += values[ptr] * coeff;
6123                }
6124            }
6125            let mut cubic = 0.0_f64;
6126            for i in 0..row_scores.len().min(c_weights.len()) {
6127                cubic += c_weights[i] * row_scores[i].powi(3);
6128            }
6129            cubic
6130        }
6131        None => {
6132            let x_dense = design.to_dense_cow();
6133            let x_dense = x_dense.as_ref();
6134            let mut cubic = 0.0_f64;
6135            for i in 0..x_dense.nrows().min(c_weights.len()) {
6136                let proj = x_dense.row(i).dot(v);
6137                cubic += c_weights[i] * proj.powi(3);
6138            }
6139            cubic
6140        }
6141    }
6142}
6143
6144/// Compute the gradient of T[v,v,v] w.r.t. v:  3 X^T diag(c_i (x_i^T v)^2) 1.
6145/// More precisely: ∂/∂v T[v,v,v] = 3 Σ_i c_i (x_i^T v)^2 x_i.
6146fn directional_cubic_gradient(
6147    design: &DesignMatrix,
6148    c_weights: &Array1<f64>,
6149    v: &Array1<f64>,
6150) -> Array1<f64> {
6151    let p = v.len();
6152    match design.as_sparse() {
6153        Some(x_sparse) => {
6154            let (symbolic, values) = x_sparse.as_ref().parts();
6155            let col_ptr = symbolic.col_ptr();
6156            let row_idx = symbolic.row_idx();
6157            let n = x_sparse.nrows();
6158            let mut row_scores = vec![0.0_f64; n];
6159            for col in 0..x_sparse.ncols() {
6160                let coeff = v[col];
6161                for ptr in col_ptr[col]..col_ptr[col + 1] {
6162                    row_scores[row_idx[ptr]] += values[ptr] * coeff;
6163                }
6164            }
6165            // quadratic weights: 3 c_i (x_i^T v)^2
6166            let mut quad_weights = vec![0.0_f64; n];
6167            for i in 0..n.min(c_weights.len()) {
6168                quad_weights[i] = 3.0 * c_weights[i] * row_scores[i] * row_scores[i];
6169            }
6170            // X^T quad_weights
6171            let mut grad = Array1::<f64>::zeros(p);
6172            for col in 0..x_sparse.ncols() {
6173                let mut acc = 0.0_f64;
6174                for ptr in col_ptr[col]..col_ptr[col + 1] {
6175                    acc += values[ptr] * quad_weights[row_idx[ptr]];
6176                }
6177                grad[col] = acc;
6178            }
6179            grad
6180        }
6181        None => {
6182            let x_dense = design.to_dense_cow();
6183            let x_dense = x_dense.as_ref();
6184            let n = x_dense.nrows();
6185            let mut grad = Array1::<f64>::zeros(p);
6186            for i in 0..n.min(c_weights.len()) {
6187                let proj = x_dense.row(i).dot(v);
6188                let w = 3.0 * c_weights[i] * proj * proj;
6189                // scaled_add works with any ArrayBase reference.
6190                let row = x_dense.row(i);
6191                for j in 0..p {
6192                    grad[j] += w * row[j];
6193                }
6194            }
6195            grad
6196        }
6197    }
6198}
6199
6200/// Power-iteration refinement for the supremum of |gamma(u)| over ||u||_H = 1.
6201///
6202/// Seeds from the best eigenvector direction plus deterministic probe
6203/// directions constructed from pairs of eigenvectors. Runs a few Riemannian
6204/// gradient ascent steps on the whitened unit sphere.
6205fn cubic_power_iteration_refinement(
6206    design: &DesignMatrix,
6207    c_weights: &Array1<f64>,
6208    evals: &Array1<f64>,
6209    evecs: &Array2<f64>,
6210    positive_mask: &[bool],
6211    n_pos: usize,
6212) -> f64 {
6213    let p = evals.len();
6214    let max_probes = 8;
6215    let max_iters = 5;
6216
6217    // Helper: convert whitened u -> original v = Σ_r (u_r / sqrt(lam_r)) * evec_r
6218    // (only over positive eigenspace).
6219    let to_original = |u: &Array1<f64>| -> Array1<f64> {
6220        let mut v = Array1::<f64>::zeros(p);
6221        let mut idx = 0;
6222        for r in 0..p {
6223            if positive_mask[r] {
6224                let scale = u[idx] / evals[r].sqrt();
6225                let col = evecs.column(r);
6226                for j in 0..p {
6227                    v[j] += scale * col[j];
6228                }
6229                idx += 1;
6230            }
6231        }
6232        v
6233    };
6234
6235    // Helper: project original-space vector to whitened: u_j = sqrt(lam_r) (evec_r^T g)
6236    let to_whitened = |g: &Array1<f64>| -> Array1<f64> {
6237        let mut u = Array1::<f64>::zeros(n_pos);
6238        let mut idx = 0;
6239        for r in 0..p {
6240            if positive_mask[r] {
6241                u[idx] = evals[r].sqrt() * evecs.column(r).dot(g);
6242                idx += 1;
6243            }
6244        }
6245        u
6246    };
6247
6248    // Evaluate |gamma(u)| for whitened direction u.
6249    let eval_gamma = |u: &Array1<f64>| -> f64 {
6250        let norm = u.dot(u).sqrt();
6251        if norm < 1e-30 {
6252            return 0.0;
6253        }
6254        let u_normed: Array1<f64> = u / norm;
6255        let v = to_original(&u_normed);
6256        // gamma = T[v,v,v] since v already has ||v||_H = 1
6257        let cubic = directional_cubic_contraction(design, c_weights, &v.view());
6258        if cubic.is_finite() { cubic.abs() } else { 0.0 }
6259    };
6260
6261    // One step of Riemannian gradient ascent on the whitened sphere for |T[v,v,v]|.
6262    let refine_step = |u: &Array1<f64>| -> Array1<f64> {
6263        let norm = u.dot(u).sqrt();
6264        if norm < 1e-30 {
6265            return u.clone();
6266        }
6267        let u_normed: Array1<f64> = u / norm;
6268        let v = to_original(&u_normed);
6269        // Gradient of T[v,v,v] w.r.t. v in original space
6270        let grad_v = directional_cubic_gradient(design, c_weights, &v);
6271        // Map to whitened space
6272        let mut grad_u = to_whitened(&grad_v);
6273        // Project onto tangent plane of sphere: grad - (grad . u) u
6274        let dot = grad_u.dot(&u_normed);
6275        grad_u.scaled_add(-dot, &u_normed);
6276        // Sign: we want to maximize |T|, so follow sign(T) * grad
6277        let cubic_val = directional_cubic_contraction(design, c_weights, &v.view());
6278        let sign = if cubic_val >= 0.0 { 1.0 } else { -1.0 };
6279        let step_size = 0.3;
6280        let mut u_new = &u_normed + &(&grad_u * (sign * step_size));
6281        let new_norm = u_new.dot(&u_new).sqrt();
6282        if new_norm > 1e-30 {
6283            u_new /= new_norm;
6284        }
6285        u_new
6286    };
6287
6288    let mut best = 0.0_f64;
6289
6290    // Build seed directions:
6291    // (a) The eigenvector with largest |gamma_r| (already computed by caller,
6292    //     but we re-derive the whitened form here).
6293    // (b) Deterministic probe directions from pairs of top eigenvectors:
6294    //     (e_i + e_j) / sqrt(2) and (e_i - e_j) / sqrt(2) in whitened space.
6295    let mut seeds: Vec<Array1<f64>> = Vec::with_capacity(max_probes);
6296
6297    // Seed (a): each eigenvector is a standard basis vector in whitened space.
6298    // Find the one with largest |gamma|.
6299    let mut best_eig_idx = 0;
6300    let mut best_eig_gamma = 0.0_f64;
6301    for j in 0..n_pos {
6302        let mut u = Array1::<f64>::zeros(n_pos);
6303        u[j] = 1.0;
6304        let g = eval_gamma(&u);
6305        if g > best_eig_gamma {
6306            best_eig_gamma = g;
6307            best_eig_idx = j;
6308        }
6309    }
6310    best = best.max(best_eig_gamma);
6311    let mut u_best = Array1::<f64>::zeros(n_pos);
6312    u_best[best_eig_idx] = 1.0;
6313    seeds.push(u_best);
6314
6315    // Seed (b): pairwise combinations of the top few eigenvectors.
6316    let n_top = n_pos.min(4);
6317    for i in 0..n_top {
6318        for j in (i + 1)..n_top {
6319            if seeds.len() >= max_probes {
6320                break;
6321            }
6322            let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
6323            let mut u_plus = Array1::<f64>::zeros(n_pos);
6324            u_plus[i] = inv_sqrt2;
6325            u_plus[j] = inv_sqrt2;
6326            seeds.push(u_plus);
6327            if seeds.len() < max_probes {
6328                let mut u_minus = Array1::<f64>::zeros(n_pos);
6329                u_minus[i] = inv_sqrt2;
6330                u_minus[j] = -inv_sqrt2;
6331                seeds.push(u_minus);
6332            }
6333        }
6334    }
6335
6336    // Run power iteration from each seed.
6337    for seed in &seeds {
6338        let mut u = seed.clone();
6339        for _ in 0..max_iters {
6340            u = refine_step(&u);
6341        }
6342        let g = eval_gamma(&u);
6343        best = best.max(g);
6344    }
6345
6346    best
6347}
6348
6349// ───────────────── #1521 laplace-sampler contract re-exports ─────────────────
6350//
6351// The neutral DATA carriers + the caller-supplied [`BlockExcessTarget`]
6352// evaluator + the pure threshold math were contract-downed to the neutral
6353// `gam-problem` crate (#1521) so gam-solve (whose `Gam784BlockTarget`
6354// IMPLEMENTS `BlockExcessTarget`) and this gam-inference-tier sampler share one
6355// set of types without an SCC edge. The COMPUTATION (NUTS, importance sampling,
6356// the directional-cubic eigen diagnostic) stays UP in this module and
6357// constructs these types under their original names via this re-export.
6358pub use gam_problem::laplace_sampler_contract::{
6359    BlockExcessTarget, BlockSampledMarginal, BlockSampledMoments, GaussianModePosterior,
6360    LaplaceTrustworthiness, laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
6361};
6362
6363/// Monolith (gam-inference-tier) implementor of the contract-downed
6364/// [`LaplaceMarginalSampler`](gam_problem::laplace_sampler_contract::LaplaceMarginalSampler):
6365/// wraps the `hmc_io` directional-cubic eigen diagnostic and the
6366/// importance-sampled #784 block correction. Registered at process init via
6367/// `gam_problem::laplace_sampler_contract::set_laplace_marginal_sampler`.
6368pub struct HmcIoLaplaceMarginalSampler;
6369
6370impl gam_problem::laplace_sampler_contract::LaplaceMarginalSampler for HmcIoLaplaceMarginalSampler {
6371    fn directional_cubic_diagnostic(
6372        &self,
6373        hessian: &Array2<f64>,
6374        design: &DesignMatrix,
6375        c_weights: &Array1<f64>,
6376        refine_supremum: bool,
6377    ) -> Result<(f64, Array1<f64>), String> {
6378        laplace_directional_cubic_diagnostic(hessian, design, c_weights, refine_supremum)
6379    }
6380
6381    fn block_sampled_marginal_correction(
6382        &self,
6383        target: &dyn BlockExcessTarget,
6384    ) -> Result<BlockSampledMarginal, String> {
6385        block_sampled_marginal_correction(target)
6386    }
6387}
6388
6389/// Monolith (gam-inference-tier) implementor of the contract-downed
6390/// [`GaussianModePosteriorSampler`](gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler):
6391/// the never-fail Gaussian mode-posterior rung. Builds the NUTS config from the
6392/// problem dimension internally (so `NutsConfig` never crosses the contract)
6393/// and wraps `hmc_io::sample_gaussian_mode_posterior`. Registered at process
6394/// init via `gam_problem::laplace_sampler_contract::set_gaussian_mode_posterior_sampler`.
6395pub struct HmcIoGaussianModePosteriorSampler;
6396
6397impl gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler
6398    for HmcIoGaussianModePosteriorSampler
6399{
6400    fn sample_gaussian_mode_posterior(
6401        &self,
6402        mode: ArrayView1<f64>,
6403        precision: ArrayView2<f64>,
6404    ) -> Result<GaussianModePosterior, String> {
6405        let config = NutsConfig::for_dimension(mode.len());
6406        sample_gaussian_mode_posterior(mode, precision, &config)
6407    }
6408}
6409
/// Auto-derive the number of importance draws for the block-local sampled
/// marginalization from the block dimension.
///
/// MAGIC: more directions need more draws to control the importance-weight
/// variance. The block is small by construction (only the curvature-heavy
/// directions), so this stays cheap. There is no CLI flag.
///
/// Returns `256 + 256 * block_dim`, clamped to at most 4096.
fn block_sampling_draws(block_dim: usize) -> usize {
    // Base budget plus a per-direction allowance, capped so a pathological
    // block can never make a single inner evaluation explode.
    const BASE: usize = 256;
    const PER_DIM: usize = 256;
    const CAP: usize = 4096;
    // Saturating arithmetic: an absurdly large `block_dim` must clamp to CAP
    // rather than overflow. Overflow would be a panic in debug builds and a
    // wrapped-around, tiny budget in release.
    PER_DIM
        .saturating_mul(block_dim)
        .saturating_add(BASE)
        .min(CAP)
}
6423
6424/// Estimate the block-local sampled marginal correction `Δ_b` and its
6425/// ρ-gradient by importance sampling against the local Laplace Gaussian
6426/// (issue #784).
6427///
6428/// # Math
6429///
6430/// Draw `t_s ~ q = N(0, diag(1/λ_r))` (the local Laplace Gaussian in the block
6431/// subspace; whitened draws `z_s ~ N(0, I)` give `t_{s,r} = z_{s,r}/√λ_r`).
6432/// With the non-Gaussian remainder `ΔF` defined on [`BlockExcessTarget`],
6433///
6434///   exp(Δ_b) = E_q[ exp(−ΔF(t)) ]  ⇒  Δ_b = log mean_s exp(−ΔF(t_s)),
6435///
6436/// computed via a numerically-stable log-mean-exp.  The ρ-gradient follows
6437/// from differentiating `Δ_b = log E_q[e^{−ΔF}]` (the `q`-Gaussian normalizer
6438/// `½Σ log(2π/λ_r)` cancels against `A_Lap`, leaving only the `ΔF` channel):
6439///
6440///   ∂Δ_b/∂ρ_k = E_p[ −∂ΔF/∂ρ_k ],   p ∝ q·e^{−ΔF},
6441///
6442/// i.e. the self-normalized importance-weighted average of `−∂ΔF/∂ρ_k` over the
6443/// same draws.  Because value and gradient come from one set of draws and one
6444/// target, they are mutually consistent — the contract the outer REML needs.
6445///
6446/// Determinism: draws come from a fixed-seed RNG so the inner evaluation is a
6447/// pure function of `(β̂, H, ρ)` and the outer optimizer sees a smooth,
6448/// reproducible objective rather than Monte-Carlo jitter across evaluations.
6449pub fn block_sampled_marginal_correction<T: BlockExcessTarget + ?Sized>(
6450    target: &T,
6451) -> Result<BlockSampledMarginal, String> {
6452    use rand::SeedableRng;
6453    use rand::rngs::StdRng;
6454
6455    let m = target.block_dim();
6456    let k = target.rho_dim();
6457    if m == 0 {
6458        return Ok(BlockSampledMarginal {
6459            value: 0.0,
6460            rho_gradient: Array1::zeros(k),
6461            importance_ess: 0.0,
6462            n_draws: 0,
6463            moments: None,
6464        });
6465    }
6466    let lambdas = target.block_curvatures();
6467    if lambdas.len() != m {
6468        return Err(format!(
6469            "block_sampled_marginal_correction: block_curvatures len {} != block_dim {m}",
6470            lambdas.len()
6471        ));
6472    }
6473    let inv_sqrt_lambda: Array1<f64> = lambdas.mapv(|l| {
6474        if l > 0.0 {
6475            1.0 / l.sqrt()
6476        } else {
6477            // A non-positive block curvature means the mode is not a strict
6478            // minimum in this direction; the Laplace Gaussian is undefined
6479            // there. Reject rather than fabricate a correction.
6480            f64::NAN
6481        }
6482    });
6483    if inv_sqrt_lambda.iter().any(|v| !v.is_finite()) {
6484        return Err(
6485            "block_sampled_marginal_correction: non-positive block curvature (mode is not a \
6486             strict local minimum in a sampled direction)"
6487                .to_string(),
6488        );
6489    }
6490
6491    let n_draws = block_sampling_draws(m);
6492    // ρ-invariant fixed seed → deterministic AND smooth-in-ρ objective.
6493    //
6494    // The doc comment above promises "the outer optimizer sees a smooth,
6495    // reproducible objective rather than Monte-Carlo jitter across
6496    // evaluations." That smoothness holds only if the importance draws
6497    // `z_s` themselves do NOT depend on ρ — ρ may enter the estimator
6498    // only through the per-sample importance weights `exp(−ΔF(t_s))` and
6499    // the rescaling `t_s = z_s / √λ_r`, both of which are continuous in
6500    // ρ for fixed `z_s`. A seed mixed from `λ_r = exp(ρ_k)` (or any
6501    // other ρ-dependent quantity such as the H-eigenvalues) permutes
6502    // `z_s` for every ρ probe, so the FD `(F(ρ+h) − F(ρ−h))/2h`
6503    // identity fails by O(MC_stdev/h) — exactly the order-10²–10³ FD
6504    // blow-up observed in the iso-κ Duchon binomial FD probes — and
6505    // every outer trust-region step lands on a different random face of
6506    // the objective. Mix only the (ρ-invariant) block / outer dimensions
6507    // so different problems still get independent streams.
6508    let mut seed_bits: u64 = 0x9E37_79B9_7F4A_7C15;
6509    seed_bits ^= (m as u64).rotate_left(17);
6510    seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6511    seed_bits ^= (k as u64).rotate_left(31);
6512    seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
6513    let mut rng = StdRng::seed_from_u64(seed_bits);
6514
6515    // Streaming, numerically-stable accumulation of the log-mean-exp value,
6516    // the explicit gradient channel `E_p[−∂ΔF/∂ρ]`, AND the gradient-channel
6517    // moments `E_p[t]`, `E_p[t tᵀ]`, `E_p[ngs]`, `E_p[t ⊗ ngs]` needed by the
6518    // exact (b)–(d) channel assembly (gradient exactness contract above).
6519    // Weights are kept relative to a running maximum log-weight: whenever a
6520    // new maximum arrives, every accumulator is rescaled by
6521    // `exp(max_old − max_new) ≤ 1`, so each per-draw relative weight is ≤ 1
6522    // and the sums never overflow. Infeasible / divergent draws contribute
6523    // zero weight rather than poisoning the estimate.
6524    let n_obs = target.base_neg_score().len();
6525    let mut max_lw = f64::NEG_INFINITY;
6526    let mut sum_w = 0.0_f64;
6527    let mut sum_w2 = 0.0_f64;
6528    let mut grad_acc = Array1::<f64>::zeros(k);
6529    let mut e_t_acc = Array1::<f64>::zeros(m);
6530    let mut e_tt_acc = Array2::<f64>::zeros((m, m));
6531    let mut e_ngs_acc = Array1::<f64>::zeros(n_obs);
6532    let mut e_t_ngs_acc = Array2::<f64>::zeros((n_obs, m));
6533
6534    // Pre-generate ALL whitened draws into the columns of `draws` (m × n_draws)
6535    // in the EXACT same RNG order as the serial loop (draw 0: r=0..m, draw 1:
6536    // r=0..m, …). The per-draw design matvec `s = X_t·(V_b·t_s)` is then batched
6537    // into two BLAS-3 products over all columns at once (the #1082 hot path),
6538    // instead of n_draws separate BLAS-2 matvecs — the draws, seed, budget, and
6539    // importance weights are byte-for-byte unchanged; only the matvecs are
6540    // reassociated into a GEMM.
6541    let mut draws = Array2::<f64>::zeros((m, n_draws));
6542    for s in 0..n_draws {
6543        let mut col = draws.column_mut(s);
6544        for r in 0..m {
6545            let z = sample_standard_normal(&mut rng);
6546            col[r] = z * inv_sqrt_lambda[r];
6547        }
6548    }
6549    let batched = target.excess_with_displaced_neg_score_batch(&draws);
6550
6551    let mut t = Array1::<f64>::zeros(m);
6552    for (sidx, (excess, displaced_ngs)) in batched.into_iter().enumerate() {
6553        t.assign(&draws.column(sidx));
6554        if !excess.is_finite() {
6555            continue;
6556        }
6557        let Some(ngs) = displaced_ngs else {
6558            // A finite excess always carries a score; absence means infeasible.
6559            continue;
6560        };
6561        let lw = -excess;
6562        if lw > max_lw {
6563            // exp(−∞ − lw) = 0 zeroes the (empty) accumulators on the first
6564            // feasible draw, so no special-casing is needed.
6565            let rescale = (max_lw - lw).exp();
6566            sum_w *= rescale;
6567            sum_w2 *= rescale * rescale;
6568            grad_acc *= rescale;
6569            e_t_acc *= rescale;
6570            e_tt_acc *= rescale;
6571            e_ngs_acc *= rescale;
6572            e_t_ngs_acc *= rescale;
6573            max_lw = lw;
6574        }
6575        let w = (lw - max_lw).exp();
6576        sum_w += w;
6577        sum_w2 += w * w;
6578        // Explicit channel: −∂ΔF/∂ρ.
6579        grad_acc.scaled_add(-w, &target.excess_rho_gradient(&t));
6580        // Moment channels (score already computed in the fused call above).
6581        if ngs.len() != n_obs {
6582            return Err(format!(
6583                "block_sampled_marginal_correction: displaced_neg_score len {} != {n_obs}",
6584                ngs.len()
6585            ));
6586        }
6587        e_t_acc.scaled_add(w, &t);
6588        e_ngs_acc.scaled_add(w, &ngs);
6589        for r in 0..m {
6590            let wt_r = w * t[r];
6591            for q in 0..m {
6592                e_tt_acc[(q, r)] += wt_r * t[q];
6593            }
6594            e_t_ngs_acc.column_mut(r).scaled_add(wt_r, &ngs);
6595        }
6596    }
6597    if !max_lw.is_finite() {
6598        return Err(
6599            "block_sampled_marginal_correction: all importance draws were infeasible".to_string(),
6600        );
6601    }
6602    let value = max_lw + (sum_w / n_draws as f64).ln();
6603    // Self-normalized importance-weighted gradient E_p[−∂ΔF/∂ρ] and moments.
6604    let (rho_gradient, moments) = if sum_w > 0.0 {
6605        (
6606            grad_acc / sum_w,
6607            Some(BlockSampledMoments {
6608                e_t: e_t_acc / sum_w,
6609                e_tt: e_tt_acc / sum_w,
6610                e_neg_score: e_ngs_acc / sum_w,
6611                e_t_neg_score: e_t_ngs_acc / sum_w,
6612            }),
6613        )
6614    } else {
6615        (Array1::zeros(k), None)
6616    };
6617    // Kish effective sample size of the importance weights.
6618    let importance_ess = if sum_w2 > 0.0 {
6619        (sum_w * sum_w) / sum_w2
6620    } else {
6621        0.0
6622    };
6623
6624    if !value.is_finite() || rho_gradient.iter().any(|v| !v.is_finite()) {
6625        return Err(
6626            "block_sampled_marginal_correction: produced a non-finite correction or gradient"
6627                .to_string(),
6628        );
6629    }
6630    if let Some(mo) = moments.as_ref()
6631        && (mo.e_t.iter().any(|v| !v.is_finite())
6632            || mo.e_tt.iter().any(|v| !v.is_finite())
6633            || mo.e_neg_score.iter().any(|v| !v.is_finite())
6634            || mo.e_t_neg_score.iter().any(|v| !v.is_finite()))
6635    {
6636        return Err(
6637            "block_sampled_marginal_correction: produced non-finite gradient-channel moments"
6638                .to_string(),
6639        );
6640    }
6641
6642    Ok(BlockSampledMarginal {
6643        value,
6644        rho_gradient,
6645        importance_ess,
6646        n_draws,
6647        moments,
6648    })
6649}
6650
/// Result of joint (β, ρ) NUTS sampling.
///
/// Sample matrices are row-per-draw: each row of `beta_samples`,
/// `rho_samples`, and `link_param_samples` is one pooled posterior draw.
#[derive(Clone, Debug)]
pub struct JointBetaRhoResult {
    /// Coefficient samples: shape (n_total_samples, n_beta)
    pub beta_samples: Array2<f64>,
    /// Log-smoothing parameter samples: shape (n_total_samples, n_rho)
    pub rho_samples: Array2<f64>,
    /// Posterior mean of β
    pub beta_mean: Array1<f64>,
    /// Adaptive inverse-link parameter samples: shape (n_total_samples, n_link_params).
    /// Presumably has zero columns for links without adaptive parameters — confirm.
    pub link_param_samples: Array2<f64>,
    /// Posterior mean of adaptive inverse-link parameters
    pub link_param_mean: Array1<f64>,
    /// Posterior mean of ρ
    pub rho_mean: Array1<f64>,
    /// R-hat (potential scale reduction) convergence diagnostic
    pub rhat: f64,
    /// Effective sample size
    pub ess: f64,
    /// Whether sampling was judged to have converged
    pub converged: bool,
    /// Max skewness that triggered this sampling
    pub trigger_skewness: f64,
}
6675
/// Joint (β, ρ) posterior target for NUTS.
///
/// Samples from p(β, ρ | y) ∝ p(y|β) p(β|ρ) p(ρ) directly,
/// completely bypassing the Laplace approximation.
///
/// The parameter vector is [z_β; ρ; link params]:
/// - z_β = L⁻¹(β - μ) is the whitened β;
/// - ρ holds the raw log-smoothing parameters;
/// - adaptive inverse-link parameters follow when the binomial link has
///   fitted shape/mixing parameters (SAS, Beta-Logistic, Mixture).
struct JointBetaRhoPosterior {
    /// Shared (`Arc`-wrapped) design, response, weights and mode.
    data: SharedData,
    /// L where LL' = H⁻¹ (whitening for β block)
    chol: Array2<f64>,
    /// L', cached for the chain rule from ∇_β back to ∇_z
    chol_t: Array2<f64>,
    /// Joint likelihood specification (response + parameterized link).
    likelihood: LikelihoodSpec,
    /// Dimension of β
    n_beta: usize,
    /// Dimension of ρ (one entry per canonical penalty)
    n_rho: usize,
    /// Dimension of adaptive inverse-link parameters (0 for fixed links)
    n_link_params: usize,
    /// LAML-converged adaptive inverse-link parameters (used only to initialize chains)
    link_param_mode: Array1<f64>,
    /// Canonical penalties in the transformed basis.
    penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
    /// Fixed prior on rho used by the sampled target.
    rho_prior: RhoPrior,
    /// LAML-converged ρ (used only to initialize chains)
    rho_mode: Array1<f64>,
    /// Whether to add the identifiable-subspace Jeffreys/Firth term to the
    /// target
    firth_enabled: bool,
    /// One-slot cache of the structural penalty pseudo-logdet and its
    /// ρ-gradient, stored as `(key, logdet, gradient)`. NUTS tree-doubling and
    /// U-turn checks repeatedly evaluate the joint log-posterior at the same
    /// `rho` bytes, so the cache avoids redundant SVD/eigendecompositions
    /// inside `PenaltyPseudologdet::from_penalties`.
    ///
    /// The key is a `u64` digest of `rho`, presumably the FNV-1a `hash_rho`
    /// of its f64 bit patterns, not the bit pattern itself.
    /// NOTE(review): a hash collision would return a stale logdet/gradient.
    /// Consider storing the `rho` bits alongside the digest and comparing
    /// them on a hit.
    ///
    /// `Mutex` (not `RefCell`) because chains share the target via
    /// `Arc<Target>` and run in parallel via rayon.
    penalty_logdet_cache: Mutex<Option<(u64, f64, Array1<f64>)>>,
}
6718
6719impl JointBetaRhoPosterior {
6720    fn new(
6721        x: ArrayView2<f64>,
6722        y: ArrayView1<f64>,
6723        weights: ArrayView1<f64>,
6724        mode: ArrayView1<f64>,
6725        hessian: ArrayView2<f64>,
6726        penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
6727        rho_mode: ArrayView1<f64>,
6728        likelihood: LikelihoodSpec,
6729        gamma_shape: Option<f64>,
6730        rho_prior: RhoPrior,
6731        firth_enabled: bool,
6732    ) -> Result<Self, String> {
6733        let n_samples = x.nrows();
6734        let n_beta = x.ncols();
6735        let n_rho = penalty_canonical.len();
6736
6737        if rho_mode.len() != n_rho {
6738            return Err(HmcError::DimensionMismatch {
6739                reason: format!(
6740                    "rho_mode length {} != penalty count {}",
6741                    rho_mode.len(),
6742                    n_rho
6743                ),
6744            }
6745            .into());
6746        }
6747
6748        match (&likelihood.response, &likelihood.link) {
6749            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {}
6750            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {}
6751            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {}
6752            (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {}
6753            (ResponseFamily::Binomial, InverseLink::Sas(_)) => {}
6754            (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {}
6755            (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {}
6756            (ResponseFamily::Binomial, InverseLink::Standard(other)) => {
6757                return Err(HmcError::LinkMismatch {
6758                    reason: format!(
6759                        "Joint HMC binomial response requires a binomial-compatible inverse link; got {:?}",
6760                        other
6761                    ),
6762                }
6763                .into());
6764            }
6765            (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {}
6766            (ResponseFamily::Gaussian, _) => {
6767                return Err(HmcError::LinkMismatch {
6768                    reason: "Joint HMC Gaussian requires an identity inverse link".to_string(),
6769                }
6770                .into());
6771            }
6772            (
6773                ResponseFamily::Poisson
6774                | ResponseFamily::Tweedie { .. }
6775                | ResponseFamily::NegativeBinomial { .. }
6776                | ResponseFamily::Gamma,
6777                InverseLink::Standard(StandardLink::Log),
6778            ) => {}
6779            (
6780                ResponseFamily::Poisson
6781                | ResponseFamily::Tweedie { .. }
6782                | ResponseFamily::NegativeBinomial { .. }
6783                | ResponseFamily::Gamma,
6784                _,
6785            ) => {
6786                return Err(HmcError::LinkMismatch {
6787                    reason: "Joint HMC log-link family requires a log inverse link".to_string(),
6788                }
6789                .into());
6790            }
6791            (ResponseFamily::Beta { .. }, InverseLink::Standard(StandardLink::Logit)) => {}
6792            (ResponseFamily::Beta { .. }, _) => {
6793                return Err(HmcError::LinkMismatch {
6794                    reason: "Joint HMC Beta requires a logit inverse link".to_string(),
6795                }
6796                .into());
6797            }
6798            (ResponseFamily::RoystonParmar, _) => {
6799                return Err(HmcError::UnsupportedFamily {
6800                    reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
6801                }
6802                .into());
6803            }
6804        }
6805
6806        validate_firth_likelihood_support(&likelihood, firth_enabled).map_err(String::from)?;
6807        if matches!(likelihood.response, ResponseFamily::NegativeBinomial { .. }) {
6808            validate_count_responses("negative-binomial joint HMC", &y, &weights)
6809                .map_err(String::from)?;
6810        }
6811        if likelihood.is_binomial() {
6812            validate_binary_responses("binomial joint HMC", &y, &weights).map_err(String::from)?;
6813        }
6814
6815        let whitening = hessian_whitening_transform(
6816            hessian,
6817            n_beta,
6818            1.0,
6819            "Joint HMC: Hessian Cholesky failed",
6820        )?;
6821        let chol = whitening.chol;
6822        let chol_t = whitening.chol_t;
6823
6824        let data = SharedData {
6825            x: Arc::new(x.to_owned()),
6826            y: Arc::new(y.to_owned()),
6827            weights: Arc::new(weights.to_owned()),
6828            mode: Arc::new(mode.to_owned()),
6829            offset: None,
6830            gamma_shape: gamma_shape.unwrap_or(1.0),
6831            // Joint (β, ρ) HMC keeps the likelihood on its native scale;
6832            // dispersion enters via the per-family scale parameter, not
6833            // via the whitening transform here. `Known(1.0)` matches the
6834            // pre-refactor behaviour for this code path.
6835            dispersion: gam_solve::model_types::Dispersion::Known(1.0),
6836            n_samples,
6837            dim: n_beta,
6838        };
6839        let link_param_mode = Self::link_param_mode(&likelihood.link);
6840
6841        Ok(Self {
6842            data,
6843            chol,
6844            chol_t,
6845            likelihood,
6846            n_beta,
6847            n_rho,
6848            n_link_params: link_param_mode.len(),
6849            link_param_mode,
6850            penalty_canonical,
6851            rho_prior,
6852            rho_mode: rho_mode.to_owned(),
6853            firth_enabled,
6854            penalty_logdet_cache: Mutex::new(None),
6855        })
6856    }
6857
6858    /// FNV-1a hash over the raw f64 bit pattern of `rho`.
6859    ///
6860    /// NUTS leapfrog / tree-doubling / U-turn checks revisit identical
6861    /// position vectors byte-for-byte, so exact-equality on `to_bits()`
6862    /// captures the dominant repetition pattern without any tolerance.
6863    #[inline]
6864    fn hash_rho(rho: ndarray::ArrayView1<f64>) -> u64 {
6865        let mut h: u64 = 0xcbf2_9ce4_8422_2325;
6866        for &x in rho.iter() {
6867            h ^= x.to_bits();
6868            h = h.wrapping_mul(0x0000_0100_0000_01b3);
6869        }
6870        h
6871    }
6872
6873    fn link_param_mode(inverse_link: &InverseLink) -> Array1<f64> {
6874        match inverse_link {
6875            InverseLink::Sas(state) | InverseLink::BetaLogistic(state) => {
6876                Array1::from_vec(vec![state.epsilon, state.log_delta])
6877            }
6878            InverseLink::Mixture(state) => state.rho.clone(),
6879            InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => Array1::zeros(0),
6880        }
6881    }
6882
6883    fn inverse_link_with_params(
6884        &self,
6885        link_params: ndarray::ArrayView1<'_, f64>,
6886    ) -> Result<InverseLink, String> {
6887        match &self.likelihood.link {
6888            InverseLink::Sas(_) => {
6889                if link_params.len() != 2 {
6890                    return Err(format!(
6891                        "SAS link parameter length must be 2, got {}",
6892                        link_params.len()
6893                    ));
6894                }
6895                Ok(InverseLink::Sas(
6896                    gam_solve::mixture_link::sas_link_state_from_raw(
6897                        link_params[0],
6898                        link_params[1],
6899                    )?,
6900                ))
6901            }
6902            InverseLink::BetaLogistic(_) => {
6903                if link_params.len() != 2 || !link_params.iter().all(|v| v.is_finite()) {
6904                    return Err(
6905                        "Beta-Logistic link parameters must be finite with length 2".to_string()
6906                    );
6907                }
6908                Ok(InverseLink::BetaLogistic(
6909                    gam_problem::types::SasLinkState {
6910                        epsilon: link_params[0],
6911                        log_delta: link_params[1],
6912                        delta: link_params[1].exp(),
6913                    },
6914                ))
6915            }
6916            InverseLink::Mixture(state) => {
6917                let rho = link_params.to_owned();
6918                Ok(InverseLink::Mixture(gam_problem::types::MixtureLinkState {
6919                    components: state.components.clone(),
6920                    pi: softmax_last_fixedzero(&rho),
6921                    rho,
6922                }))
6923            }
6924            InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => {
6925                Ok(self.likelihood.link.clone())
6926            }
6927        }
6928    }
6929
6930    /// Compute the joint log-posterior and gradient.
6931    ///
6932    /// The joint log-posterior is:
6933    ///   log p(β, ρ | y) = ℓ(y|β) + ½ log|I(β)| [if Firth]
6934    ///                    − ½β'S(ρ)β + ½ log|S(ρ)|₊ + log p(ρ) + const
6935    ///
6936    /// This is NOT the REML/LAML objective (which integrates out β). Here β is
6937    /// an explicit parameter being sampled, evaluated at arbitrary values — not
6938    /// just at the mode β̂(ρ).
6939    ///
6940    /// Parameter vector layout: [z_β (whitened, length n_beta); ρ (length n_rho);
6941    /// adaptive inverse-link params (length n_link_params)]
6942    fn compute_joint_logp_and_grad_into(
6943        &self,
6944        params: &Array1<f64>,
6945        out_grad: &mut Array1<f64>,
6946    ) -> f64 {
6947        let n_beta = self.n_beta;
6948        let n_rho = self.n_rho;
6949        let n_link_params = self.n_link_params;
6950
6951        // Split parameter vector — keep as views to avoid two per-step
6952        // `to_owned()` allocations of size n_beta and n_rho.
6953        let z = params.slice(ndarray::s![..n_beta]);
6954        let rho = params.slice(ndarray::s![n_beta..n_beta + n_rho]);
6955        let link_params = params.slice(ndarray::s![n_beta + n_rho..]);
6956        let lambdas: Array1<f64> = rho.mapv(f64::exp);
6957
6958        let inverse_link = match self.inverse_link_with_params(link_params) {
6959            Ok(link) => link,
6960            Err(err) => {
6961                log::warn!(
6962                    "[Joint HMC] adaptive inverse-link parameters are invalid: {}",
6963                    err
6964                );
6965                out_grad.fill(0.0);
6966                return f64::NEG_INFINITY;
6967            }
6968        };
6969
6970        // Un-whiten: β = μ + L z
6971        let beta = self.data.mode.as_ref() + &self.chol.dot(&z);
6972
6973        // η = X β
6974        let eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
6975
6976        // ---- Log-likelihood ℓ(y|β) and ∇_β ℓ ----
6977        let step_likelihood = LikelihoodSpec {
6978            response: self.likelihood.response.clone(),
6979            link: inverse_link,
6980        };
6981        let (ll, mut grad_ll_beta, grad_link) = match joint_family_logp_grad_and_link_grad(
6982            &step_likelihood,
6983            &self.data,
6984            &eta,
6985            n_link_params,
6986        ) {
6987            Ok(value) => value,
6988            Err(err) => {
6989                log::warn!(
6990                    "[Joint HMC] likelihood target became invalid at the current state: {}",
6991                    err
6992                );
6993                out_grad.fill(0.0);
6994                return f64::NEG_INFINITY;
6995            }
6996        };
6997
6998        let mut firth_logdet = 0.0;
6999        if self.firth_enabled {
7000            match firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &self.data, &eta) {
7001                Ok((value, grad_beta_firth)) => {
7002                    firth_logdet = value;
7003                    grad_ll_beta += &grad_beta_firth;
7004                }
7005                Err(err) => {
7006                    log::warn!(
7007                        "[Joint HMC/Firth] Jeffreys target became invalid at the current state: {}",
7008                        err
7009                    );
7010                    out_grad.fill(0.0);
7011                    return f64::NEG_INFINITY;
7012                }
7013            }
7014        }
7015
7016        // ---- Penalty: -0.5 β'S(ρ)β ----
7017        // S(ρ) = Σ_k λ_k S_k where S_k = R_k'R_k (precomputed in penalty_matrices).
7018        // Uses penalty_roots for the efficient ||R_k β||² form.
7019        let mut penalty_val = 0.0;
7020        let mut s_beta = Array1::<f64>::zeros(n_beta);
7021        let mut grad_rho = Array1::<f64>::zeros(n_rho);
7022
7023        // Reuse one max-rank scratch buffer for r_beta = R_k · β_block across
7024        // all penalty blocks instead of allocating a fresh Array1 per block
7025        // per HMC step.
7026        let max_rank = self
7027            .penalty_canonical
7028            .iter()
7029            .map(|cp| cp.rank())
7030            .max()
7031            .unwrap_or(0);
7032        let mut r_beta_scratch = Array1::<f64>::zeros(max_rank);
7033
7034        for (k, cp) in self.penalty_canonical.iter().enumerate() {
7035            // Block-local quadratic: β'S_k β via root
7036            let r = &cp.col_range;
7037            let beta_block = beta.slice(ndarray::s![r.start..r.end]);
7038            let rank_k = cp.rank();
7039            gam_linalg::faer_ndarray::fast_av_view_into(
7040                &cp.root,
7041                &beta_block,
7042                r_beta_scratch.slice_mut(ndarray::s![..rank_k]),
7043            );
7044            let r_beta = r_beta_scratch.slice(ndarray::s![..rank_k]);
7045            let quad_k = r_beta.dot(&r_beta);
7046            penalty_val += 0.5 * lambdas[k] * quad_k;
7047
7048            // Accumulate S(ρ)β for β-gradient — block-local
7049            for a in 0..cp.block_dim() {
7050                let val: f64 = (0..rank_k).map(|row| cp.root[[row, a]] * r_beta[row]).sum();
7051                s_beta[r.start + a] += lambdas[k] * val;
7052            }
7053
7054            // ρ_k gradient from penalty
7055            grad_rho[k] = -0.5 * lambdas[k] * quad_k;
7056        }
7057
7058        // ---- Structural penalty log-determinant: +0.5 log|S(ρ)|₊ and ρ-derivatives ----
7059        //
7060        // One-deep cache keyed on the exact f64 bits of `rho`: NUTS tree
7061        // doubling revisits identical positions byte-for-byte, so an
7062        // exact-equality cache eliminates the dominant SVD/eigendecomp
7063        // cost in `PenaltyPseudologdet::from_penalties` across leapfrog
7064        // half-steps.
7065        let log_det_s = if self.penalty_canonical.is_empty() {
7066            0.0
7067        } else {
7068            let rho_hash = Self::hash_rho(rho);
7069            let cached = self.penalty_logdet_cache.lock().ok().and_then(|guard| {
7070                guard.as_ref().and_then(|(h, v, g)| {
7071                    if *h == rho_hash && g.len() == n_rho {
7072                        for k in 0..n_rho {
7073                            grad_rho[k] += 0.5 * g[k];
7074                        }
7075                        Some(*v)
7076                    } else {
7077                        None
7078                    }
7079                })
7080            });
7081            if let Some(hit) = cached {
7082                hit
7083            } else {
7084                match PenaltyPseudologdet::from_penalties(
7085                    &self.penalty_canonical,
7086                    lambdas.as_slice().unwrap_or(&[]),
7087                    0.0,
7088                    n_beta,
7089                ) {
7090                    Ok(pld) => {
7091                        let (det1, _) = pld.rho_derivatives_from_penalties(
7092                            &self.penalty_canonical,
7093                            lambdas.as_slice().unwrap_or(&[]),
7094                        );
7095                        let value = pld.value();
7096                        if let Ok(mut guard) = self.penalty_logdet_cache.lock() {
7097                            *guard = Some((rho_hash, value, det1.clone()));
7098                        }
7099                        for k in 0..n_rho {
7100                            grad_rho[k] += 0.5 * det1[k];
7101                        }
7102                        value
7103                    }
7104                    Err(err) => {
7105                        log::warn!(
7106                            "[Joint HMC] structural penalty logdet became invalid at the current state: {}",
7107                            err
7108                        );
7109                        out_grad.fill(0.0);
7110                        return f64::NEG_INFINITY;
7111                    }
7112                }
7113            }
7114        };
7115
7116        // ---- Prior on ρ ----
7117        let mut rho_prior = 0.0;
7118        match &self.rho_prior {
7119            RhoPrior::Flat => {}
7120            RhoPrior::Normal { mean, sd } => {
7121                let inv_var = 1.0 / (*sd * *sd);
7122                for k in 0..n_rho {
7123                    let d = rho[k] - *mean;
7124                    rho_prior -= 0.5 * inv_var * d * d;
7125                    grad_rho[k] -= inv_var * d;
7126                }
7127            }
7128            RhoPrior::GammaPrecision { shape, rate } => {
7129                for k in 0..n_rho {
7130                    let lambda = rho[k].exp();
7131                    // Density over sampled rho includes the e^rho Jacobian (Gamma is on lambda = e^rho).
7132                    rho_prior += *shape * rho[k] - *rate * lambda;
7133                    grad_rho[k] += *shape - *rate * lambda;
7134                }
7135            }
7136            RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7137                if !pc_prior_params_valid(*upper, *tail_prob) {
7138                    out_grad.fill(0.0);
7139                    return f64::NEG_INFINITY;
7140                }
7141                let theta = -tail_prob.ln() / *upper;
7142                for k in 0..n_rho {
7143                    // log p(ρ) = const − ρ/2 − θ exp(−ρ/2).
7144                    let e = (-0.5 * rho[k]).exp();
7145                    rho_prior += -0.5 * rho[k] - theta * e;
7146                    grad_rho[k] += -0.5 + 0.5 * theta * e;
7147                }
7148            }
7149            RhoPrior::Independent(priors) => {
7150                if priors.len() != n_rho {
7151                    out_grad.fill(0.0);
7152                    return f64::NEG_INFINITY;
7153                }
7154                for k in 0..n_rho {
7155                    match &priors[k] {
7156                        RhoPrior::Flat => {}
7157                        RhoPrior::Normal { mean, sd } => {
7158                            let inv_var = 1.0 / (*sd * *sd);
7159                            let d = rho[k] - *mean;
7160                            rho_prior -= 0.5 * inv_var * d * d;
7161                            grad_rho[k] -= inv_var * d;
7162                        }
7163                        RhoPrior::GammaPrecision { shape, rate } => {
7164                            let lambda = rho[k].exp();
7165                            // Density over sampled rho includes the e^rho Jacobian (Gamma is on lambda = e^rho).
7166                            rho_prior += *shape * rho[k] - *rate * lambda;
7167                            grad_rho[k] += *shape - *rate * lambda;
7168                        }
7169                        RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7170                            if !pc_prior_params_valid(*upper, *tail_prob) {
7171                                out_grad.fill(0.0);
7172                                return f64::NEG_INFINITY;
7173                            }
7174                            let theta = -tail_prob.ln() / *upper;
7175                            let e = (-0.5 * rho[k]).exp();
7176                            rho_prior += -0.5 * rho[k] - theta * e;
7177                            grad_rho[k] += -0.5 + 0.5 * theta * e;
7178                        }
7179                        RhoPrior::Independent(_) => {
7180                            out_grad.fill(0.0);
7181                            return f64::NEG_INFINITY;
7182                        }
7183                    }
7184                }
7185            }
7186        }
7187
7188        // ---- Assemble ----
7189        let logp = ll + firth_logdet - penalty_val + 0.5 * log_det_s + rho_prior;
7190
7191        // β-gradient in original space: ∇_β ℓ - S(ρ)β
7192        let grad_beta = &grad_ll_beta - &s_beta;
7193
7194        // Combined gradient: [∇_z; ∇_ρ; ∇_link]
7195        gam_linalg::faer_ndarray::fast_av_view_into(
7196            &self.chol_t,
7197            &grad_beta,
7198            out_grad.slice_mut(ndarray::s![..n_beta]),
7199        );
7200        out_grad
7201            .slice_mut(ndarray::s![n_beta..n_beta + n_rho])
7202            .assign(&grad_rho);
7203        out_grad
7204            .slice_mut(ndarray::s![n_beta + n_rho..])
7205            .assign(&grad_link);
7206
7207        logp
7208    }
7209}
7210
/// Returns `true` when penalized-complexity hyperparameters can be used.
///
/// `upper` must be finite and strictly positive, and `tail_prob` must lie
/// strictly inside `(0, 1)`. This mirrors the validation in the shared
/// `rho_prior_eval` engine. An invalid configuration therefore repels the
/// sampler (`-∞` potential) instead of producing a non-finite gradient.
fn pc_prior_params_valid(upper: f64, tail_prob: f64) -> bool {
    // Every comparison against NaN is false, and `< INFINITY` excludes +∞, so
    // these two open ranges reject all non-finite inputs.
    let upper_in_range = upper > 0.0 && upper < f64::INFINITY;
    let tail_in_range = tail_prob > 0.0 && tail_prob < 1.0;
    upper_in_range && tail_in_range
}
7219
7220impl HamiltonianTarget<Array1<f64>> for JointBetaRhoPosterior {
7221    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7222        self.compute_joint_logp_and_grad_into(position, grad)
7223    }
7224}
7225
/// Inputs for joint (β, ρ) sampling.
///
/// Consumed by [`run_joint_beta_rho_sampling`], which clones the owned fields
/// (`likelihood`, `penalty_roots`, `rho_prior`) into the posterior target.
pub struct JointBetaRhoInputs<'a> {
    /// Design matrix forwarded to the posterior target, which forms η = Xβ
    /// from it (presumably n_obs × n_beta — TODO confirm).
    pub x: ArrayView2<'a, f64>,
    /// Response vector forwarded to the posterior target.
    pub y: ArrayView1<'a, f64>,
    /// Observation weights forwarded to the posterior target.
    pub weights: ArrayView1<'a, f64>,
    /// Response family and inverse link. Any adaptive link parameters
    /// reported by `JointBetaRhoPosterior::link_param_mode` are sampled
    /// jointly with (β, ρ).
    pub likelihood: LikelihoodSpec,
    /// Optional shape parameter forwarded to the posterior target
    /// (presumably the Gamma-family shape — confirm).
    pub gamma_shape: Option<f64>,
    /// β mode around which draws are un-whitened; its length defines n_beta.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at the mode, used by the posterior target to build the
    /// whitening factor (presumably L Lᵀ = H⁻¹ as in the module docs — confirm).
    pub hessian: ArrayView2<'a, f64>,
    /// Canonical penalty roots; one smoothing parameter ρ is sampled per entry.
    pub penalty_roots: Vec<CanonicalPenalty>,
    /// ρ mode; chains are initialised with small jitter around it.
    pub rho_mode: ArrayView1<'a, f64>,
    /// Prior placed on the sampled log smoothing parameters ρ.
    pub rho_prior: RhoPrior,
    /// Enable the Firth/Jeffreys term; the likelihood is checked for support
    /// by `validate_firth_likelihood_support` before sampling starts.
    pub firth_bias_reduction: bool,
    /// Max posterior skewness that triggered this sampling
    pub trigger_skewness: f64,
}
7242
7243/// Run joint (β, ρ) NUTS sampling.
7244///
7245/// This is the automatic fallback when the Laplace approximation has high
7246/// skewness. It samples from the true joint posterior, completely bypassing
7247/// the Laplace approximation for smoothing parameter selection.
7248pub fn run_joint_beta_rho_sampling(
7249    inputs: &JointBetaRhoInputs<'_>,
7250    config: &NutsConfig,
7251) -> Result<JointBetaRhoResult, String> {
7252    validate_firth_likelihood_support(&inputs.likelihood, inputs.firth_bias_reduction)
7253        .map_err(String::from)?;
7254    validate_nuts_config(config).map_err(String::from)?;
7255    let n_beta = inputs.mode.len();
7256    let n_rho = inputs.penalty_roots.len();
7257    let n_link_params = JointBetaRhoPosterior::link_param_mode(&inputs.likelihood.link).len();
7258    let total_dim = n_beta + n_rho + n_link_params;
7259
7260    log::info!(
7261        "[Joint HMC] Sampling (β, ρ, link) jointly: {} β-params + {} ρ-params + {} link-params = {} total (triggered by skewness {:.3})",
7262        n_beta,
7263        n_rho,
7264        n_link_params,
7265        total_dim,
7266        inputs.trigger_skewness,
7267    );
7268
7269    let target = JointBetaRhoPosterior::new(
7270        inputs.x,
7271        inputs.y,
7272        inputs.weights,
7273        inputs.mode,
7274        inputs.hessian,
7275        inputs.penalty_roots.clone(),
7276        inputs.rho_mode,
7277        inputs.likelihood.clone(),
7278        inputs.gamma_shape,
7279        inputs.rho_prior.clone(),
7280        inputs.firth_bias_reduction,
7281    )?;
7282
7283    let chol = target.chol.clone();
7284    let mode_arr = target.data.mode.clone();
7285    let rho_mode = target.rho_mode.clone();
7286    let link_param_mode = target.link_param_mode.clone();
7287
7288    // Initialize chains: z_β at 0 (= mode), ρ at rho_mode, link params at fitted state.
7289    let initial_positions: Vec<Array1<f64>> = (0..config.n_chains)
7290        .map(|chain| {
7291            let mut rng =
7292                StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x9B51_6E37_F2D0_A48C));
7293            let mut pos = Array1::<f64>::zeros(total_dim);
7294            // Small jitter for β (whitened space)
7295            for j in 0..n_beta {
7296                pos[j] = sample_standard_normal(&mut rng) * 0.1;
7297            }
7298            // Small jitter for ρ around mode
7299            for k in 0..n_rho {
7300                pos[n_beta + k] = rho_mode[k] + sample_standard_normal(&mut rng) * 0.2;
7301            }
7302            // Small jitter for adaptive link parameters around fitted state
7303            for k in 0..n_link_params {
7304                pos[n_beta + n_rho + k] =
7305                    link_param_mode[k] + sample_standard_normal(&mut rng) * 0.05;
7306            }
7307            pos
7308        })
7309        .collect();
7310
7311    // Keep warmup covariance phase-local: diagonal windows are less likely to
7312    // encode cross-block covariance from a transient mode switch.
7313    let mass_cfg = robust_mass_matrix_config(total_dim, config.nwarmup);
7314
7315    let (samples_array, run_stats) = run_whitened_nuts_samples(
7316        target,
7317        initial_positions,
7318        config,
7319        total_dim,
7320        mass_cfg,
7321        0x63AF_175B_D820_C94E,
7322        "Joint (β,ρ) NUTS sampling failed",
7323    )?;
7324    log::info!("[Joint HMC] Sampling complete: {}", run_stats);
7325
7326    // Unpack samples
7327    let shape = samples_array.shape();
7328    let n_chains = shape[0];
7329    let n_samples_out = shape[1];
7330    let total_samples = n_chains * n_samples_out;
7331
7332    let beta_samples = unwhiten_samples(&samples_array, mode_arr.as_ref(), &chol, n_beta, 0);
7333    let mut rho_samples = Array2::<f64>::zeros((total_samples, n_rho));
7334    let mut link_param_samples = Array2::<f64>::zeros((total_samples, n_link_params));
7335
7336    for chain in 0..n_chains {
7337        for sample_i in 0..n_samples_out {
7338            let sample_idx = chain * n_samples_out + sample_i;
7339            let zview = samples_array.slice(ndarray::s![chain, sample_i, ..]);
7340
7341            // ρ and adaptive link parameters are stored directly
7342            let rho_slice = zview.slice(ndarray::s![n_beta..n_beta + n_rho]);
7343            rho_samples.row_mut(sample_idx).assign(&rho_slice);
7344            let link_slice = zview.slice(ndarray::s![n_beta + n_rho..]);
7345            link_param_samples.row_mut(sample_idx).assign(&link_slice);
7346        }
7347    }
7348
7349    let beta_mean = beta_samples
7350        .mean_axis(Axis(0))
7351        .unwrap_or_else(|| Array1::zeros(n_beta));
7352    let rho_mean = rho_samples
7353        .mean_axis(Axis(0))
7354        .unwrap_or_else(|| Array1::zeros(n_rho));
7355    let link_param_mean = link_param_samples
7356        .mean_axis(Axis(0))
7357        .unwrap_or_else(|| Array1::zeros(n_link_params));
7358
7359    let (rhat, ess) = compute_split_rhat_and_ess(&samples_array);
7360
7361    let converged = NutsConvergenceThresholds {
7362        max_rhat: 1.1,
7363        min_ess: Some(50.0),
7364    }
7365    .converged(rhat, ess);
7366    if !converged {
7367        log::warn!(
7368            "[Joint HMC] Convergence warning: R-hat={:.3}, ESS={:.1}",
7369            rhat,
7370            ess,
7371        );
7372    }
7373
7374    Ok(JointBetaRhoResult {
7375        beta_samples,
7376        rho_samples,
7377        beta_mean,
7378        link_param_samples,
7379        link_param_mean,
7380        rho_mean,
7381        rhat,
7382        ess,
7383        converged,
7384        trigger_skewness: inputs.trigger_skewness,
7385    })
7386}
7387
7388// ============================================================================
7389// Survival Model HMC Support
7390// ============================================================================
7391
7392mod survival_hmc {
7393    use super::*;
7394    use gam_models::survival::{
7395        PenaltyBlocks, SurvivalEngineInputs, SurvivalMonotonicityPenalty, SurvivalSpec,
7396        WorkingModelSurvival,
7397    };
7398
7399    /// Shared data for survival NUTS posterior (wrapped in Arc to prevent cloning).
7400    #[derive(Clone)]
7401    struct SharedSurvivalData {
7402        /// Exact survival model in original spline coordinates.
7403        base_model: Arc<WorkingModelSurvival>,
7404        /// MAP estimate in coefficient coordinates.
7405        mode: Arc<Array1<f64>>,
7406    }
7407
7408    /// Whitened log-posterior target for survival models with analytical gradients.
7409    #[derive(Clone)]
7410    pub struct SurvivalPosterior {
7411        /// Shared read-only data (Arc prevents duplication)
7412        data: SharedSurvivalData,
7413        /// Transform: L where L L^T = H^{-1}
7414        chol: Array2<f64>,
7415        /// L^T for gradient chain rule: ∇z = L^T @ ∇_β
7416        chol_t: Array2<f64>,
7417    }
7418
7419    impl SurvivalPosterior {
7420        /// Creates a new survival posterior target.
7421        pub fn new(
7422            age_entry: ArrayView1<'_, f64>,
7423            age_exit: ArrayView1<'_, f64>,
7424            event_target: ArrayView1<'_, u8>,
7425            event_competing: ArrayView1<'_, u8>,
7426            sampleweight: ArrayView1<'_, f64>,
7427            x_entry: ArrayView2<'_, f64>,
7428            x_exit: ArrayView2<'_, f64>,
7429            x_derivative: ArrayView2<'_, f64>,
7430            offset_eta_entry: Option<ArrayView1<'_, f64>>,
7431            offset_eta_exit: Option<ArrayView1<'_, f64>>,
7432            offset_derivative_exit: Option<ArrayView1<'_, f64>>,
7433            penalties: PenaltyBlocks,
7434            monotonicity: SurvivalMonotonicityPenalty,
7435            spec: SurvivalSpec,
7436            structurally_monotonic: bool,
7437            structural_time_columns: usize,
7438            mode: ArrayView1<f64>,
7439            hessian: ArrayView2<f64>,
7440        ) -> Result<Self, String> {
7441            let n = age_entry.len();
7442            let off_eta_entry = offset_eta_entry
7443                .map(|v| v.to_owned())
7444                .unwrap_or_else(|| Array1::zeros(n));
7445            let off_eta_exit = offset_eta_exit
7446                .map(|v| v.to_owned())
7447                .unwrap_or_else(|| Array1::zeros(n));
7448            let off_deriv_exit = offset_derivative_exit
7449                .map(|v| v.to_owned())
7450                .unwrap_or_else(|| Array1::zeros(n));
7451
7452            let mut base_model = WorkingModelSurvival::from_engine_inputswith_offsets(
7453                SurvivalEngineInputs {
7454                    age_entry,
7455                    age_exit,
7456                    event_target,
7457                    event_competing,
7458                    sampleweight,
7459                    x_entry,
7460                    x_exit,
7461                    x_derivative,
7462                    monotonicity_constraint_rows: None,
7463                    monotonicity_constraint_offsets: None,
7464                },
7465                Some(gam_models::survival::SurvivalBaselineOffsets {
7466                    eta_entry: off_eta_entry.view(),
7467                    eta_exit: off_eta_exit.view(),
7468                    derivative_exit: off_deriv_exit.view(),
7469                }),
7470                penalties,
7471                monotonicity,
7472                spec,
7473            )
7474            .map_err(|e| format!("Survival state construction failed: {:?}", e))?;
7475            if structurally_monotonic {
7476                base_model
7477                    .set_structural_monotonicity(true, structural_time_columns)
7478                    .map_err(|e| {
7479                        format!("Failed to enable structural monotonicity in survival HMC: {e}")
7480                    })?;
7481            }
7482
7483            let sampler_mode = mode.to_owned();
7484            let dim = sampler_mode.len();
7485
7486            let whitening = hessian_whitening_transform(
7487                hessian,
7488                dim,
7489                1.0,
7490                "Hessian Cholesky decomposition failed",
7491            )?;
7492            let chol = whitening.chol;
7493            let chol_t = whitening.chol_t;
7494
7495            let data = SharedSurvivalData {
7496                base_model: Arc::new(base_model),
7497                mode: Arc::new(sampler_mode),
7498            };
7499
7500            Ok(Self { data, chol, chol_t })
7501        }
7502
7503        fn compute_logp_and_grad_into(
7504            &self,
7505            z: &Array1<f64>,
7506            grad: &mut Array1<f64>,
7507        ) -> Result<f64, String> {
7508            let sampler_position = self.data.mode.as_ref() + &self.chol.dot(z);
7509            let state = self
7510                .data
7511                .base_model
7512                .update_state(&sampler_position)
7513                .map_err(|e| format!("Survival state update failed: {:?}", e))?;
7514            let logp = state.log_likelihood - state.penalty_term;
7515            let grad_beta = state.gradient.mapv(|g| -g);
7516            fast_av_into(&self.chol_t, &grad_beta, grad);
7517            Ok(logp)
7518        }
7519
7520        /// Get the Cholesky factor L for un-whitening samples
7521        pub fn chol(&self) -> &Array2<f64> {
7522            &self.chol
7523        }
7524
7525        /// Get the mode
7526        pub fn mode(&self) -> &Array1<f64> {
7527            &self.data.mode
7528        }
7529    }
7530
7531    impl HamiltonianTarget<Array1<f64>> for SurvivalPosterior {
7532        fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7533            match self.compute_logp_and_grad_into(position, grad) {
7534                Ok(logp) => logp,
7535                Err(e) => {
7536                    log::warn!("Survival posterior evaluation failed: {}", e);
7537                    grad.fill(0.0);
7538                    f64::NEG_INFINITY
7539                }
7540            }
7541        }
7542    }
7543
7544    /// Runs NUTS sampling for survival models with whitened parameter space.
7545    pub(crate) fn run_survival_nuts_sampling(
7546        age_entry: ArrayView1<'_, f64>,
7547        age_exit: ArrayView1<'_, f64>,
7548        event_target: ArrayView1<'_, u8>,
7549        event_competing: ArrayView1<'_, u8>,
7550        sampleweight: ArrayView1<'_, f64>,
7551        x_entry: ArrayView2<'_, f64>,
7552        x_exit: ArrayView2<'_, f64>,
7553        x_derivative: ArrayView2<'_, f64>,
7554        eta_offset_entry: Option<ArrayView1<'_, f64>>,
7555        eta_offset_exit: Option<ArrayView1<'_, f64>>,
7556        derivative_offset_exit: Option<ArrayView1<'_, f64>>,
7557        penalties: PenaltyBlocks,
7558        monotonicity: SurvivalMonotonicityPenalty,
7559        spec: SurvivalSpec,
7560        structurally_monotonic: bool,
7561        structural_time_columns: usize,
7562        mode: ArrayView1<f64>,
7563        hessian: ArrayView2<f64>,
7564        config: &NutsConfig,
7565    ) -> Result<NutsResult, String> {
7566        validate_nuts_config(config).map_err(String::from)?;
7567        // Create posterior target
7568        let target = SurvivalPosterior::new(
7569            age_entry,
7570            age_exit,
7571            event_target,
7572            event_competing,
7573            sampleweight,
7574            x_entry,
7575            x_exit,
7576            x_derivative,
7577            eta_offset_entry,
7578            eta_offset_exit,
7579            derivative_offset_exit,
7580            penalties,
7581            monotonicity,
7582            spec,
7583            structurally_monotonic,
7584            structural_time_columns,
7585            mode,
7586            hessian,
7587        )?;
7588
7589        // Get Cholesky factor for un-whitening samples later
7590        let chol = target.chol().clone();
7591        let mode_arr = target.mode().clone();
7592        let dim = mode_arr.len();
7593
7594        let initial_positions = jittered_initial_positions(config, dim, 0.1, 0xEC2D_7A9B_4051_F638);
7595
7596        let mass_cfg = robust_survival_mass_matrix_config(dim, config.nwarmup);
7597        let (result, run_stats) = run_whitened_nuts_result(
7598            target,
7599            &mode_arr,
7600            &chol,
7601            initial_positions,
7602            config,
7603            dim,
7604            mass_cfg,
7605            0x731B_60D4_AE52_9C8F,
7606            "NUTS sampling failed",
7607            Array1::zeros(dim),
7608            NutsConvergenceThresholds {
7609                max_rhat: 1.1,
7610                min_ess: None,
7611            },
7612        )?;
7613
7614        log::info!("Survival NUTS sampling complete: {}", run_stats);
7615
7616        Ok(result)
7617    }
7618}
7619
7620/// Engine-facing flattened survival NUTS entrypoint.
7621pub fn run_survival_nuts_sampling_flattened<'a>(
7622    flat: SurvivalFlatInputs<'a>,
7623    penalties: gam_models::survival::PenaltyBlocks,
7624    monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
7625    spec: gam_models::survival::SurvivalSpec,
7626    structurally_monotonic: bool,
7627    structural_time_columns: usize,
7628    mode: ArrayView1<'a, f64>,
7629    hessian: ArrayView2<'a, f64>,
7630    config: &NutsConfig,
7631) -> Result<NutsResult, String> {
7632    run_nuts_sampling_flattened_family(
7633        LikelihoodSpec {
7634            response: ResponseFamily::RoystonParmar,
7635            link: InverseLink::Standard(StandardLink::Identity),
7636        },
7637        FamilyNutsInputs::Survival(Box::new(SurvivalNutsInputs {
7638            flat,
7639            penalties,
7640            monotonicity,
7641            spec,
7642            structurally_monotonic,
7643            structural_time_columns,
7644            mode,
7645            hessian,
7646        })),
7647        config,
7648    )
7649}