Skip to main content

gam_inference/
hmc_io.rs

1//! NUTS Sampler using general-mcmc
2//!
3//! This module provides NUTS (No-U-Turn Sampler) for honest uncertainty
4//! quantification after PIRLS convergence.
5//!
6//! # Design
7//!
8//! Since general-mcmc's NUTS uses an identity mass matrix, we whiten the
9//! parameter space using the Cholesky decomposition of the inverse Hessian:
10//!
11//! - Transform: β = μ + L @ z  (where L L^T = H^{-1})
12//! - The whitened space has unit covariance, so NUTS mixes efficiently
13//! - Samples are un-transformed back to the original space
14//!
15//! # Analytical Gradients
16//!
17//! We override `unnorm_logp_and_grad` to compute gradients analytically using
18//! ndarray, avoiding burn's autodiff overhead. The gradient computation mirrors
19//! the true log-posterior gradient (not the PIRLS working gradient).
20//!
21//! # Memory Efficiency
22//!
23//! Large data (design matrix, response, etc.) is wrapped in `Arc` to allow
24//! sharing across chains without duplication when general-mcmc clones the target.
25
26use faer::Side;
27use gam_terms::construction::CanonicalPenalty;
28use gam_solve::estimate::reml::FirthDenseOperator;
29use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
30use gam_solve::estimate::{
31    EstimationError, UnifiedFitResult, validate_explicit_dense_hessian_for_whitening,
32};
33use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh, fast_ata_into, fast_atv, fast_av_into};
34use gam_models::wiggle::monotone_wiggle_basis_with_derivative_order;
35use crate::gpu_polya_gamma::{PgSeed, PolyaGammaBatchInput};
36use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
37use gam_linalg::matrix::DesignMatrix;
38use gam_solve::mixture_link::{
39    InverseLinkKernel, LinkParamPartials, inverse_link_jet_for_inverse_link, softmax_last_fixedzero,
40};
41use gam_problem::types::{
42    InverseLink, LikelihoodSpec, ResponseFamily, RhoPrior, StandardLink, is_valid_tweedie_power,
43};
44use general_mcmc::generic_hmc::HamiltonianTarget;
45pub use general_mcmc::generic_nuts::NUTSMassMatrixConfig;
46use general_mcmc::generic_nuts::{GenericNUTS, MassMatrixAdaptation};
47use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
48use rand::{RngExt, SeedableRng, rngs::StdRng};
49use serde::{Deserialize, Serialize};
50use std::cell::RefCell;
51use std::fmt;
52use std::sync::{Arc, Mutex};
53
/// Whether the Jeffreys/Firth term is available for `spec`.
///
/// Thin delegation to `LikelihoodSpec::supports_firth`; it gates
/// `likelihood_spec_jeffreys_link` below. Binomial families whose inverse link
/// has a Fisher-weight jet (`fisher_weight_jet5`) support the Jeffreys/Firth
/// term. This is the link-general set shared with the REML/PIRLS Firth
/// operator; the canonical logit case is unchanged.
#[inline]
fn likelihood_spec_supports_firth(spec: &LikelihoodSpec) -> bool {
    spec.supports_firth()
}
62
63/// Inverse link to evaluate the Fisher working weight with for the Jeffreys
64/// term. Returns `None` for unsupported specs.
65#[inline]
66fn likelihood_spec_jeffreys_link(spec: &LikelihoodSpec) -> Option<InverseLink> {
67    if likelihood_spec_supports_firth(spec) {
68        Some(spec.link.clone())
69    } else {
70        None
71    }
72}
73
/// Typed error variants for the HMC / NUTS sampling module.
///
/// External-facing helpers in this module continue to return
/// `Result<_, String>`; this enum is materialized internally and converted
/// at the public boundary via `.map_err(String::from)` (or `.into()`) so that
/// the error text remains byte-identical to the previous `format!` output.
///
/// Every variant carries a single human-readable `reason`. The `Display`
/// impl writes that text verbatim without naming the variant, so the variant
/// only classifies the failure for code that matches on it.
#[derive(Debug, Clone)]
pub enum HmcError {
    /// Sampler state (penalty / Hessian / mode / posterior values) contains
    /// NaN or Inf where finiteness is required.
    NonFiniteState { reason: String },
    /// Configuration value (e.g. `target_accept`, unit-weight requirement)
    /// is out of range or otherwise invalid.
    InvalidConfig { reason: String },
    /// Dimensions of the supplied matrices / vectors are inconsistent.
    DimensionMismatch { reason: String },
    /// Firth/Jeffreys correction was requested for a family that does not
    /// support it.
    FirthUnsupported { reason: String },
    /// Inverse-link state does not match the requested likelihood family in
    /// the joint (β, ρ) sampler.
    LinkMismatch { reason: String },
    /// Likelihood family is not implemented in the current sampling path.
    UnsupportedFamily { reason: String },
    /// Sampling produced no usable output (empty kept set, non-finite
    /// summary statistic, etc.).
    SamplingFailed { reason: String },
}
102
103impl fmt::Display for HmcError {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match self {
106            HmcError::NonFiniteState { reason }
107            | HmcError::InvalidConfig { reason }
108            | HmcError::DimensionMismatch { reason }
109            | HmcError::FirthUnsupported { reason }
110            | HmcError::LinkMismatch { reason }
111            | HmcError::UnsupportedFamily { reason }
112            | HmcError::SamplingFailed { reason } => f.write_str(reason),
113        }
114    }
115}
116
/// Converts an `HmcError` into its `Display` text.
///
/// This is what lets `HmcError` values leave functions that return
/// `Result<_, String>` via `.into()` or `.map_err(String::from)`.
impl From<HmcError> for String {
    fn from(err: HmcError) -> String {
        err.to_string()
    }
}
122
/// Hard cap on the autocorrelation lag accumulated by the ESS estimate.
///
/// Geyer's initial-positive-sequence sum usually stops on its own well before
/// this lag. The cap only bounds the `O(n · lag)` cost on very long chains,
/// whose far autocorrelation tail is numerical noise anyway.
const MAX_AUTOCORRELATION_LAG: usize = 1_000;

/// Lower bound applied to each split chain's lag-0 autocovariance (its
/// variance) before it divides the lagged autocovariances, so a numerically
/// constant chain cannot trigger a division by zero.
const AUTOCOVARIANCE_FLOOR: f64 = 1.0e-16;
133
134/// Compute split-chain R-hat and ESS using the Gelman-Rubin diagnostic.
135///
136/// This is the standard split-chain formulation (no rank normalization).
137/// Returns (max_rhat, min_ess) across dimensions.
138fn compute_split_rhat_and_ess(samples: &Array3<f64>) -> (f64, f64) {
139    let n_chains = samples.shape()[0];
140    let n_samples = samples.shape()[1];
141    let dim = samples.shape()[2];
142
143    if n_chains < 2 || n_samples < 4 {
144        return (1.0, n_chains as f64 * n_samples as f64 * 0.5);
145    }
146
147    // Split each chain in half to detect non-stationarity
148    let half = n_samples / 2;
149    let n_split_chains = n_chains * 2;
150    let n_split_samples = half;
151
152    let mut max_rhat = 0.0f64;
153    let mut min_ess = f64::INFINITY;
154
155    #[inline]
156    fn splitvalue(
157        samples: &Array3<f64>,
158        n_chains: usize,
159        half: usize,
160        dim: usize,
161        sc: usize,
162        t: usize,
163    ) -> f64 {
164        let chain = sc % n_chains;
165        if sc < n_chains {
166            samples[[chain, t, dim]]
167        } else {
168            samples[[chain, half + t, dim]]
169        }
170    }
171
172    fn ess_from_split_dimension(
173        samples: &Array3<f64>,
174        n_chains: usize,
175        half: usize,
176        dim: usize,
177    ) -> f64 {
178        let m = n_chains * 2;
179        let n = half;
180        if m == 0 || n < 4 {
181            return (m * n).max(1) as f64;
182        }
183
184        let mut means = vec![0.0_f64; m];
185        let mut gamma0 = vec![0.0_f64; m];
186        for sc in 0..m {
187            let mut sum = 0.0;
188            for t in 0..n {
189                sum += splitvalue(samples, n_chains, half, dim, sc, t);
190            }
191            let mean = sum / n as f64;
192            means[sc] = mean;
193            let mut g0 = 0.0;
194            for t in 0..n {
195                let d = splitvalue(samples, n_chains, half, dim, sc, t) - mean;
196                g0 += d * d;
197            }
198            gamma0[sc] = (g0 / n as f64).max(AUTOCOVARIANCE_FLOOR);
199        }
200
201        let max_lag = (n - 1).min(MAX_AUTOCORRELATION_LAG);
202        let mut tau = 1.0_f64;
203        let mut lag = 1usize;
204        while lag < max_lag {
205            let mut pair = 0.0_f64;
206            for l in [lag, lag + 1] {
207                if l > max_lag {
208                    continue;
209                }
210                let mut rho_l = 0.0;
211                for sc in 0..m {
212                    let mu = means[sc];
213                    let mut cov = 0.0;
214                    let denom = (n - l) as f64;
215                    for t in 0..(n - l) {
216                        let x0 = splitvalue(samples, n_chains, half, dim, sc, t);
217                        let x1 = splitvalue(samples, n_chains, half, dim, sc, t + l);
218                        cov += (x0 - mu) * (x1 - mu);
219                    }
220                    cov /= denom;
221                    rho_l += cov / gamma0[sc];
222                }
223                rho_l /= m as f64;
224                pair += rho_l;
225            }
226            if !pair.is_finite() || pair <= 0.0 {
227                break;
228            }
229            tau += 2.0 * pair;
230            lag += 2;
231        }
232        if !tau.is_finite() || tau <= 0.0 {
233            return 1.0;
234        }
235        let total = (m * n) as f64;
236        (total / tau).clamp(1.0, total)
237    }
238
239    let mut chain_means = vec![0.0_f64; n_split_chains];
240    let mut chainvars = vec![0.0_f64; n_split_chains];
241    for d in 0..dim {
242        for chain in 0..n_chains {
243            // First half
244            let mut sum1 = 0.0;
245            for i in 0..half {
246                sum1 += samples[[chain, i, d]];
247            }
248            let mean1 = sum1 / half as f64;
249            let mut var1 = 0.0;
250            for i in 0..half {
251                let diff = samples[[chain, i, d]] - mean1;
252                var1 += diff * diff;
253            }
254            var1 /= (half - 1).max(1) as f64;
255            let first_idx = chain;
256            chain_means[first_idx] = mean1;
257            chainvars[first_idx] = var1;
258
259            // Second half
260            let mut sum2 = 0.0;
261            for i in half..(2 * half) {
262                sum2 += samples[[chain, i, d]];
263            }
264            let mean2 = sum2 / half as f64;
265            let mut var2 = 0.0;
266            for i in half..(2 * half) {
267                let diff = samples[[chain, i, d]] - mean2;
268                var2 += diff * diff;
269            }
270            var2 /= (half - 1).max(1) as f64;
271            let second_idx = n_chains + chain;
272            chain_means[second_idx] = mean2;
273            chainvars[second_idx] = var2;
274        }
275
276        // Within-chain variance W
277        let w: f64 = chainvars.iter().copied().sum::<f64>() / n_split_chains as f64;
278
279        // Between-chain variance B
280        let overall_mean: f64 = chain_means.iter().copied().sum::<f64>() / n_split_chains as f64;
281        let b: f64 = chain_means
282            .iter()
283            .map(|m| (m - overall_mean).powi(2))
284            .sum::<f64>()
285            * n_split_samples as f64
286            / (n_split_chains - 1) as f64;
287
288        // Estimated variance
289        let var_hat = (n_split_samples as f64 - 1.0) / n_split_samples as f64 * w
290            + b / n_split_samples as f64;
291
292        // R-hat
293        let rhat_d = if w > 1e-10 { (var_hat / w).sqrt() } else { 1.0 };
294        max_rhat = max_rhat.max(rhat_d);
295
296        // Real ESS via split-chain autocorrelation with Geyer IPS truncation.
297        let ess_d = ess_from_split_dimension(samples, n_chains, half, d);
298        min_ess = min_ess.min(ess_d);
299    }
300
301    (max_rhat, min_ess.max(1.0))
302}
303
304/// Solve L^T * X = I where L is lower triangular.
305///
306/// Returns X = L^{-T} (the inverse transpose of L).
307///
308/// This is the correct way to compute the whitening transform matrix:
309/// Given H = L L^T (Cholesky), we need W where W W^T = H^{-1}
310/// Since H^{-1} = L^{-T} L^{-1}, we have W = L^{-T}.
311///
312/// Implementation strategy (math-equivalent to back-substitution on L^T):
313/// We compute L^{-1} column-wise via forward substitution on L, then the
314/// result is `L^{-1}` transposed. Forward-substituting column `c` of L^{-1}
315/// uses `L`'s rows (which are contiguous in row-major `Array2`), giving
316/// stride-1 inner loops instead of the strided `l[[j, i]]` (column-major
317/// access pattern) and double-indexed writes of the original. We also
318/// exploit the triangular structure of `L^{-1}` (entries above the diagonal
319/// are zero), skipping ~half of the inner work compared to the previous
320/// version which traversed `i = (0..dim).rev()` for every column.
321///
322/// Total cost: ~dim^3 / 6 multiply-adds (down from dim^3 / 2), with all
323/// inner loops on contiguous slices.
324fn solve_upper_triangular_transpose(l: &Array2<f64>, dim: usize) -> Array2<f64> {
325    let mut result = Array2::<f64>::zeros((dim, dim));
326    if dim == 0 {
327        return result;
328    }
329
330    // Pull contiguous row slice access from L (row-major standard layout).
331    // Falls back to a one-time owned copy if `l` is not standard-layout
332    // (e.g. a transposed view); both branches feed the same inner loop.
333    let l_owned;
334    let l_rows: &[f64] = if let Some(s) = l.as_slice() {
335        s
336    } else {
337        l_owned = l.to_owned();
338        l_owned
339            .as_slice()
340            .expect("owned standard-layout Array2 has contiguous storage")
341    };
342
343    // Scratch column for L^{-1}[:, col]; reused across columns.
344    let mut y = vec![0.0_f64; dim];
345
346    for col in 0..dim {
347        // Forward-substitute L * y = e_col. y[i] = 0 for i < col.
348        // Diagonal term:
349        let d_col = l_rows[col * dim + col];
350        let inv_d_col = if d_col.abs() > 1e-15 {
351            1.0 / d_col
352        } else {
353            0.0
354        };
355        y[col] = inv_d_col;
356
357        // Below-diagonal entries: y[i] = -(sum_{j=col..i} L[i,j] * y[j]) / L[i,i].
358        // Each inner loop is a stride-1 dot product on row `i` of L (contiguous).
359        for i in (col + 1)..dim {
360            let row_off = i * dim;
361            let l_row = &l_rows[row_off + col..row_off + i];
362            let y_seg = &y[col..i];
363            // Both operands are contiguous slices of equal length; the loop
364            // is a straight-line stride-1 reduction the optimizer can
365            // auto-vectorize.
366            let mut sum = 0.0_f64;
367            for k in 0..l_row.len() {
368                sum += l_row[k] * y_seg[k];
369            }
370            let d = l_rows[row_off + i];
371            y[i] = if d.abs() > 1e-15 { -sum / d } else { 0.0 };
372        }
373
374        // Write the column into result transposed: result[col, i] = y[i] for i >= col.
375        // result[i, col] is left at zero for i < col (upper-triangular L^{-T}).
376        // That matches `result[col, i]` filling row `col` from column `col` rightward.
377        let res_row_start = col * dim + col;
378        let res_row = &mut result.as_slice_mut().expect("owned Array2 contiguous")
379            [res_row_start..res_row_start + (dim - col)];
380        for (k, slot) in res_row.iter_mut().enumerate() {
381            *slot = y[col + k];
382        }
383
384        // Clear scratch positions we wrote, so the next column starts clean above.
385        for slot in &mut y[col..dim] {
386            *slot = 0.0;
387        }
388    }
389
390    result
391}
392
/// Whitening factors produced by `hessian_whitening_transform`.
///
/// `chol` is `L = s · (chol(H))^{-T}` with `s = √cov_scale` (`cov_scale` as
/// passed to that function), an upper-triangular square root of the scaled
/// inverse Hessian. `chol_t` is its transpose, materialized once as an owned
/// array.
struct WhiteningTransform {
    /// `L`, mapping whitened `z` to coefficient offsets (`β − μ = L z`).
    chol: Array2<f64>,
    /// `Lᵀ`, used to pull β-gradients back to z (`∇_z = Lᵀ ∇_β`).
    chol_t: Array2<f64>,
}
397
398fn hessian_whitening_transform(
399    hessian: ArrayView2<f64>,
400    dim: usize,
401    cov_scale: f64,
402    cholesky_error_prefix: &str,
403) -> Result<WhiteningTransform, String> {
404    let hessian_owned = hessian.to_owned();
405    let chol_factor = hessian_owned
406        .cholesky(Side::Lower)
407        .map_err(|e| format!("{cholesky_error_prefix}: {:?}", e))?;
408    let l_h = chol_factor.lower_triangular();
409    let mut chol = solve_upper_triangular_transpose(&l_h, dim);
410    let sqrt_cov_scale = cov_scale.max(0.0).sqrt();
411    if (sqrt_cov_scale - 1.0).abs() > 0.0 {
412        chol.mapv_inplace(|v| v * sqrt_cov_scale);
413    }
414    let chol_t = chol.t().to_owned();
415    Ok(WhiteningTransform { chol, chol_t })
416}
417
/// Shared data for NUTS posterior (wrapped in Arc to prevent cloning).
///
/// This struct holds read-only data that is shared across all chains.
/// Using Arc prevents memory explosion when general-mcmc clones the target:
/// the array fields are `Arc`-wrapped, so the derived `Clone` only bumps
/// their reference counts and clones the small scalar fields.
#[derive(Clone)]
struct SharedData {
    /// Design matrix X [n_samples, dim]
    x: Arc<Array2<f64>>,
    /// Response vector y [n_samples]
    y: Arc<Array1<f64>>,
    /// Observation/case weights [n_samples]
    weights: Arc<Array1<f64>>,
    /// MAP estimate (mode) μ [dim]
    mode: Arc<Array1<f64>>,
    /// Fixed additive offset on the linear predictor: η = Xβ + offset
    /// [n_samples]. `None` when the model was fit without an offset (the common
    /// case), avoiding a per-step O(n) add of zeros. The offset shifts η only —
    /// it is constant in β, so ∂η/∂β = X is unchanged and no gradient,
    /// Hessian, or penalty term is affected. Dropping it (the historical
    /// behaviour) silently sampled the wrong posterior for any `--offset-column`
    /// fit (#882).
    offset: Option<Arc<Array1<f64>>>,
    /// Auxiliary log-link family parameter: Gamma shape, Tweedie power, or NB theta.
    gamma_shape: f64,
    /// Dispersion parameter φ (Gaussian: σ²; Gamma: 1/shape; `Known(1.0)` for
    /// fixed-scale families). Consumed **only** by the likelihood adapters that
    /// carry the dispersion in the data term itself: the profiled-Gaussian
    /// log-likelihood and its gradient multiply through by `1/φ`, and the
    /// Tweedie quasi-likelihood folds `1/φ` into its weight. It does NOT drive
    /// the whitening or penalty scaling — those use the `cov_scale` invariant
    /// (`NutsFamily::coefficient_covariance_scale`), which is `1.0` for Gamma
    /// even though `φ ≠ 1`, because Gamma's dispersion already lives inside the
    /// working weight (the `shape` factor in `gamma_log_logp_and_grad`). See
    /// `inference::dispersion_cov` for the ownership invariants.
    dispersion: gam_solve::model_types::Dispersion,
    /// Number of observations (rows of `x`)
    n_samples: usize,
    /// Number of coefficients (columns of `x`)
    dim: usize,
}
458
// Per-thread scratch vector, initially empty, named for NUTS residuals.
// NOTE(review): its readers/writers are outside this chunk; presumably it is
// borrowed and resized by the logp/grad path to avoid a per-step allocation
// of an n-length residual buffer — confirm against its uses.
thread_local! {
    static NUTS_RESIDUAL_SCRATCH: RefCell<Array1<f64>> = RefCell::new(Array1::zeros(0));
}
462
/// Likelihood family (response distribution + link) of a NUTS target.
///
/// Selects the log-likelihood adapter used by the sampler. Each variant maps
/// to a fixed response family and standard link via
/// `NutsFamily::likelihood_spec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NutsFamily {
    /// Gaussian response, identity link (profiled dispersion).
    Gaussian,
    /// Binomial response, logit link.
    BinomialLogit,
    /// Binomial response, probit link.
    BinomialProbit,
    /// Binomial response, complementary log-log link.
    BinomialCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Tweedie response, log link.
    TweedieLog,
    /// Negative-binomial response, log link.
    NegativeBinomialLog,
    /// Gamma response, log link.
    GammaLog,
}
479
480impl NutsFamily {
481    #[inline]
482    fn likelihood_spec(self) -> LikelihoodSpec {
483        match self {
484            Self::Gaussian => LikelihoodSpec {
485                response: ResponseFamily::Gaussian,
486                link: InverseLink::Standard(StandardLink::Identity),
487            },
488            Self::BinomialLogit => LikelihoodSpec {
489                response: ResponseFamily::Binomial,
490                link: InverseLink::Standard(StandardLink::Logit),
491            },
492            Self::BinomialProbit => LikelihoodSpec {
493                response: ResponseFamily::Binomial,
494                link: InverseLink::Standard(StandardLink::Probit),
495            },
496            Self::BinomialCLogLog => LikelihoodSpec {
497                response: ResponseFamily::Binomial,
498                link: InverseLink::Standard(StandardLink::CLogLog),
499            },
500            Self::PoissonLog => LikelihoodSpec {
501                response: ResponseFamily::Poisson,
502                link: InverseLink::Standard(StandardLink::Log),
503            },
504            Self::TweedieLog => LikelihoodSpec {
505                response: ResponseFamily::Tweedie { p: 1.5 },
506                link: InverseLink::Standard(StandardLink::Log),
507            },
508            Self::NegativeBinomialLog => LikelihoodSpec {
509                response: ResponseFamily::NegativeBinomial {
510                    theta: 1.0,
511                    theta_fixed: false,
512                },
513                link: InverseLink::Standard(StandardLink::Log),
514            },
515            Self::GammaLog => LikelihoodSpec {
516                response: ResponseFamily::Gamma,
517                link: InverseLink::Standard(StandardLink::Log),
518            },
519        }
520    }
521
522    /// Coefficient-covariance scale for the whitened NUTS target — the
523    /// NUTS-family counterpart of
524    /// [`gam_problem::types::GlmLikelihoodSpec::coefficient_covariance_scale`] (#679).
525    ///
526    /// The sampler must reproduce the posterior `N(mode, Vb)` with
527    /// `Vb = scale · H⁻¹`, where `H = XᵀWX + S_λ` is the stored penalized
528    /// Hessian (penalty `S_λ` added **unscaled**). The returned `scale` is:
529    ///
530    /// * `profiled_gaussian_phi` (= σ̂²) for the **profiled Gaussian** identity
531    ///   model, whose working weight is scale-free (`W = priorweights`), so the
532    ///   stored `H` omits the dispersion and `Vb = σ̂²·H⁻¹`. The NUTS Gaussian
533    ///   log-likelihood is structurally this profiled form: scale-free
534    ///   residuals multiplied by `1/φ` (see `gaussian_logp_and_grad_into`).
535    /// * `1.0` for every **weight-carries-dispersion** family (Gamma, Tweedie,
536    ///   Negative-Binomial, and the fixed-scale Poisson/Binomial). Their
537    ///   working weight already folds in the reciprocal dispersion / full
538    ///   Fisher information — for Gamma-log the shape `ν = 1/φ` is baked into
539    ///   the likelihood score `∂ℓ/∂η = ν·(y/μ − 1)` — so the stored `H` is
540    ///   already the true penalized Hessian and `Vb = H⁻¹`. Multiplying by the
541    ///   dispersion again double-counts it and shrinks every posterior SD by
542    ///   `√dispersion` — exactly the Gamma-log defect addressed in #680.
543    ///
544    /// This single scalar governs BOTH the whitening preconditioner
545    /// (`L Lᵀ = scale·H⁻¹`, so `L` is scaled by `√scale`) and the target's
546    /// penalty weight (`penalty_scale = 1/scale`), keeping the sampled
547    /// posterior, its whitening metric, and the Wald `Vb` of #679 mutually
548    /// consistent. Crucially it does NOT key off the statistical dispersion
549    /// `φ`: Gamma carries `φ = 1/shape ≠ 1` yet still has `scale = 1`, because
550    /// that `φ` already lives inside `W`.
551    #[inline]
552    fn coefficient_covariance_scale(self, profiled_gaussian_phi: f64) -> f64 {
553        match self {
554            NutsFamily::Gaussian => profiled_gaussian_phi,
555            _ => 1.0,
556        }
557    }
558}
559
/// Whitened-coordinate target for the No-U-Turn HMC sampler, with analytical
/// gradients.
///
/// The posterior over β is reparameterized as `β = μ + L z`.
/// * μ is the MAP mode.
/// * `L Lᵀ = cov_scale · H⁻¹`, with `L = √cov_scale · (chol(H))^{-T}` built
///   from the penalized Hessian `H` at the mode.
///
/// In `z`-coordinates the local curvature is therefore approximately the
/// identity. The struct holds:
/// * the `Arc`-shared design data;
/// * the whitening factor `L` and its transpose (for the gradient pull-back
///   `∇_z = Lᵀ ∇_β`);
/// * the family tag that selects the log-likelihood adapter;
/// * a precomputed `M = Lᵀ S L`, plus its linear and constant companions, so
///   the smoothing penalty `−½ βᵀSβ` becomes a cheap quadratic in `z` inside
///   the leapfrog hot loop.
///
/// Optionally adds the identifiable-subspace Firth/Jeffreys term to keep
/// posterior modes away from infinity under separation.
pub struct NutsPosterior {
    /// Shared read-only data (Arc prevents duplication)
    data: SharedData,
    /// Transform `L` with `L Lᵀ = cov_scale · H⁻¹`: the inverse-transpose of
    /// the Cholesky factor of H, scaled by `√cov_scale`.
    chol: Array2<f64>,
    /// L^T for gradient chain rule: ∇z = L^T @ ∇_β
    chol_t: Array2<f64>,
    /// Family for log-likelihood computation
    nuts_family: NutsFamily,
    /// Whether to add the identifiable-subspace Jeffreys/Firth term to the
    /// target
    firth_enabled: bool,
    /// Precomputed whitened-penalty operator `M = L^T S L` (dim×dim, symmetric
    /// positive-semidefinite). The penalty term in z-coordinates is
    ///   −0.5 βᵀSβ = −[c0 + (Lᵀ S μ)ᵀ z + 0.5 zᵀ M z],
    /// so its z-gradient is just `−(L^T S μ + M z)` — no per-step `S·β` matvec
    /// or `L^T·∇_β penalty` map is needed.
    penalty_z_quad: Array2<f64>,
    /// Precomputed `Lᵀ S μ` (length dim) — z-space gradient contribution from
    /// the linear-in-z portion of the penalty.
    penalty_z_lin: Array1<f64>,
    /// Precomputed `0.5 μᵀ S μ` (scalar) — constant term `c0` of the penalty.
    penalty_z_const: f64,
    /// Coefficient-covariance scale `cov_scale` (#679/#680 invariant): the
    /// `Vb = cov_scale·H⁻¹` multiplier. `σ̂²` for profiled Gaussian, `1.0` for
    /// every weight-carries-dispersion family. Drives both the whitening
    /// (`L Lᵀ = cov_scale·H⁻¹`) and the target penalty weight
    /// (`penalty_scale = 1/cov_scale`).
    cov_scale: f64,
}
603
604impl NutsPosterior {
605    /// Creates a new posterior target from ndarray data.
606    ///
607    /// # Arguments
608    /// * `x` - Design matrix [n_samples, dim]
609    /// * `y` - Response vector [n_samples]
610    /// * `weights` - Observation/case weights [n_samples]
611    /// * `penalty_matrix` - Combined penalty S [dim, dim]
612    /// * `mode` - MAP estimate μ [dim]
613    /// * `hessian` - Hessian H [dim, dim] (NOT the inverse!)
614    /// * `nuts_family` - Family for log-likelihood computation
615    ///
616    /// # Numerical Stability
617    /// Accepts the Hessian directly and computes L = (chol(H))^{-T} via
618    /// triangular solves, which is more stable than explicitly inverting H.
619    pub fn new(
620        x: ArrayView2<f64>,
621        y: ArrayView1<f64>,
622        weights: ArrayView1<f64>,
623        penalty_matrix: ArrayView2<f64>,
624        mode: ArrayView1<f64>,
625        hessian: ArrayView2<f64>,
626        nuts_family: NutsFamily,
627        gamma_shape: f64,
628        dispersion: gam_solve::model_types::Dispersion,
629        firth_enabled: bool,
630    ) -> Result<Self, String> {
631        let n_samples = x.nrows();
632        let dim = x.ncols();
633
634        // Validate inputs are finite
635        if !penalty_matrix.iter().all(|x| x.is_finite()) {
636            return Err(HmcError::NonFiniteState {
637                reason: "Penalty matrix contains NaN or Inf values".to_string(),
638            }
639            .into());
640        }
641        if !hessian.iter().all(|x| x.is_finite()) {
642            return Err(HmcError::NonFiniteState {
643                reason: "Hessian matrix contains NaN or Inf values".to_string(),
644            }
645            .into());
646        }
647        if !mode.iter().all(|x| x.is_finite()) {
648            return Err(HmcError::NonFiniteState {
649                reason: "Mode vector contains NaN or Inf values".to_string(),
650            }
651            .into());
652        }
653
654        validate_firth_support(nuts_family, firth_enabled).map_err(String::from)?;
655        if nuts_family.likelihood_spec().is_binomial() {
656            validate_binary_responses("binomial NUTS", &y, &weights).map_err(String::from)?;
657        }
658        if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
659            validate_count_responses("negative-binomial NUTS", &y, &weights)
660                .map_err(String::from)?;
661        }
662
663        // Whitening metric: `L Lᵀ` must equal the posterior covariance the
664        // sampler reproduces, `Vb = cov_scale · H⁻¹` (#679/#680 invariant), so
665        // scale `L` by `√cov_scale`. Only the profiled-Gaussian model carries a
666        // non-unit scale (σ̂² = `dispersion.phi()`); every weight-carries-
667        // dispersion family (Gamma/Tweedie/NB) already folds its dispersion into
668        // the stored `H`, so `cov_scale == 1` and this is a no-op. This replaces
669        // a previous `sqrt_phi()` multiply that wrongly scaled Gamma (and any
670        // φ-bearing family) by `√φ`, mis-preconditioning against `φ·H⁻¹`.
671        let cov_scale = nuts_family.coefficient_covariance_scale(dispersion.phi());
672        let whitening = hessian_whitening_transform(
673            hessian,
674            dim,
675            cov_scale,
676            "Hessian Cholesky decomposition failed",
677        )?;
678        let chol = whitening.chol;
679        let chol_t = whitening.chol_t;
680
681        // Precompute the whitened penalty operator and constants so that the
682        // penalty contribution to logp/grad becomes a single symv against z.
683        // Math identity (β = μ + L z, L L^T = H^{-1}):
684        //   0.5 β^T S β = 0.5 μ^T S μ + (L^T S μ)^T z + 0.5 z^T (L^T S L) z
685        // and ∇_z [0.5 β^T S β] = L^T S μ + (L^T S L) z.
686        // This replaces three matvecs per leapfrog step (S·β, L·z used only
687        // for that purpose, and L^T·∇_β penalty) with one dim×dim symv.
688        let penalty_owned = penalty_matrix.to_owned();
689        let mode_owned = mode.to_owned();
690        let s_mu = penalty_owned.dot(&mode_owned);
691        let penalty_z_const = 0.5 * mode_owned.dot(&s_mu);
692        let penalty_z_lin = chol_t.dot(&s_mu);
693        // M = L^T S L = chol_t · (S · chol). Computed in two GEMMs at
694        // construction time only.
695        let s_chol = penalty_owned.dot(&chol);
696        let penalty_z_quad = chol_t.dot(&s_chol);
697
698        let data = SharedData {
699            x: Arc::new(x.to_owned()),
700            y: Arc::new(y.to_owned()),
701            weights: Arc::new(weights.to_owned()),
702            mode: Arc::new(mode_owned),
703            offset: None,
704            gamma_shape,
705            dispersion,
706            n_samples,
707            dim,
708        };
709
710        Ok(Self {
711            data,
712            chol,
713            chol_t,
714            nuts_family,
715            firth_enabled,
716            penalty_z_quad,
717            penalty_z_lin,
718            penalty_z_const,
719            cov_scale,
720        })
721    }
722
723    /// Attach a fixed additive offset to the linear predictor: η = Xβ + offset.
724    ///
725    /// The offset is constant in β, so the whitening geometry (`chol`), penalty
726    /// operators, and stored Hessian are all unchanged — only η (and hence the
727    /// per-observation working residual / mean) shifts. The fitted `mode` and
728    /// `hessian` handed to [`Self::new`] already correspond to the offset-trained
729    /// fit, so this only needs to restore the offset to the likelihood
730    /// evaluation. Returns an error if the offset length disagrees with the data
731    /// or carries non-finite entries.
732    fn with_offset(mut self, offset: ArrayView1<f64>) -> Result<Self, String> {
733        if offset.len() != self.data.n_samples {
734            return Err(HmcError::DimensionMismatch {
735                reason: format!(
736                    "NUTS offset length {} does not match {} observations",
737                    offset.len(),
738                    self.data.n_samples
739                ),
740            }
741            .into());
742        }
743        if !offset.iter().all(|v| v.is_finite()) {
744            return Err(HmcError::NonFiniteState {
745                reason: "NUTS offset contains NaN or Inf values".to_string(),
746            }
747            .into());
748        }
749        self.data.offset = Some(Arc::new(offset.to_owned()));
750        Ok(self)
751    }
752
    /// Evaluate the unnormalised log target and its z-gradient at the
    /// whitened state `z`.
    ///
    /// The state is mapped back with `β = mode + chol·z`, then
    /// `η = Xβ (+ offset)`. The target is
    /// `ℓ(β) [+ Jeffreys/Firth log-det when enabled] − penalty_scale·½ βᵀSβ`.
    /// No Jacobian term is added, because `z ↦ β` is affine and so has a
    /// constant log-determinant.
    ///
    /// * `residual`: per-observation scratch buffer forwarded to the family
    ///   kernel, which may overwrite it.
    /// * `grad`: output buffer, overwritten with `∇_z log p`.
    ///
    /// Returns the log target value.
    /// * If the Firth/Jeffreys term fails to evaluate, `grad` is zero-filled
    ///   and `f64::NEG_INFINITY` is returned, so the sampler rejects the state.
    /// * A `−∞` likelihood from the family kernel itself is not special-cased;
    ///   it flows through the final sum.
    fn compute_logp_and_grad_nd_into(
        &self,
        z: &Array1<f64>,
        residual: &mut Array1<f64>,
        grad: &mut Array1<f64>,
    ) -> f64 {
        // === Step 1: Transform z (whitened) -> β (original) ===
        // β = μ + L @ z
        let beta = self.data.mode.as_ref() + &self.chol.dot(z);

        // === Step 2: Compute η = X @ β (+ offset) ===
        let mut eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
        if let Some(offset) = self.data.offset.as_ref() {
            eta += offset.as_ref();
        }

        // === Step 3: Compute log-likelihood and gradient ===
        // `grad_ll_beta` is ∇_β ℓ; it is mapped to z-space in Step 5.
        let (ll, mut grad_ll_beta) = self.family_logp_and_grad_into(&eta, residual);

        let mut firth_logdet = 0.0;
        if self.firth_enabled {
            match firth_jeffreys_logp_and_grad(self.nuts_family, &self.data, &eta) {
                Ok((value, grad_beta_firth)) => {
                    firth_logdet = value;
                    grad_ll_beta += &grad_beta_firth;
                }
                Err(err) => {
                    // An unevaluable Jeffreys term makes the state inadmissible:
                    // report −∞ with a zeroed (finite) gradient.
                    log::warn!(
                        "[NUTS/Firth] Jeffreys target became invalid at the current state: {}",
                        err
                    );
                    grad.fill(0.0);
                    return f64::NEG_INFINITY;
                }
            }
        }

        // === Step 4: Penalty in z-coordinates (precomputed; see `new`) ===
        //   −0.5 βᵀ S β  =  −[c0 + lᵀ z + 0.5 zᵀ M z]
        //   ∇_z (−0.5 βᵀ S β) = −(l + M z)
        // where l = L^T S μ, M = L^T S L, c0 = 0.5 μᵀ S μ.
        // This single dim×dim symmetric matvec replaces both the per-step
        // S·β multiply and the L^T·∇_β penalty chain-rule multiply, and lets
        // the penalty value, β-gradient and chain rule fuse into one pass.
        //
        // Penalty weight in the un-whitened β-target
        // `log p(β) = loglik(β) − penalty_scale · ½ βᵀSβ`. The invariant is
        // `Vb = cov_scale · H⁻¹` with `H = XᵀWX + S` (penalty added unscaled),
        // so the target curvature must equal `Vb⁻¹ = H/cov_scale`. The
        // likelihood already supplies `−∇²ℓ = (data Fisher info)/cov_scale`
        // (explicitly `/σ²` for profiled Gaussian, implicitly via the working
        // weight / the `shape ≡ 1/φ` baked into `gamma_log_logp_and_grad` for
        // the dispersion-carrying families), so the penalty must match it:
        //   penalty_scale = 1/cov_scale.
        // That is `1/σ²` for profiled Gaussian and exactly `1.0` for
        // Gamma/Tweedie/NB/Poisson/Binomial. The previous code used
        // `dispersion.inv_phi()` for GammaLog (= shape = 1/φ ≠ 1), which
        // double-counted the dispersion in the sampled posterior (#680); the
        // statistical dispersion `φ` is NOT `1/cov_scale` for Gamma because it
        // already lives inside `W`. Mirrors `LinkWigglePosterior`.
        // The `max(1e-300)` floor keeps the division finite if `cov_scale` is
        // zero, negative, or NaN.
        let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
        let mz = self.penalty_z_quad.dot(z);
        let lin_term = self.penalty_z_lin.dot(z);
        let quad_term = 0.5 * z.dot(&mz);
        let penalty = penalty_scale * (self.penalty_z_const + lin_term + quad_term);

        // === Step 5: z-space gradient ===
        // ∇z log p = L^T ∇_β ℓ  −  penalty_scale · (l + M z)
        // `fast_av_into` overwrites `grad` with Lᵀ·∇_β ℓ.
        fast_av_into(&self.chol_t, &grad_ll_beta, grad);
        // gradz -= penalty_scale · (penalty_z_lin + M z); fused parallel update.
        let lin_view = self.penalty_z_lin.view();
        ndarray::Zip::from(grad)
            .and(&lin_view)
            .and(&mz)
            .par_for_each(|g, &l, &m| {
                *g -= penalty_scale * (l + m);
            });

        ll + firth_logdet - penalty
    }
833
834    fn family_logp_and_grad_into(
835        &self,
836        eta: &Array1<f64>,
837        residual: &mut Array1<f64>,
838    ) -> (f64, Array1<f64>) {
839        nuts_family_logp_and_grad_into(self.nuts_family, &self.data, eta, residual)
840    }
841
842    /// Get the Cholesky factor L for un-whitening samples
843    pub fn chol(&self) -> &Array2<f64> {
844        &self.chol
845    }
846
847    /// Get the mode
848    pub fn mode(&self) -> &Array1<f64> {
849        &self.data.mode
850    }
851
852    /// Get dimension
853    pub fn dim(&self) -> usize {
854        self.data.dim
855    }
856}
857
/// `½·ln(2π)`, the normalising constant of the standard normal log density.
const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_7;

/// Log density of the standard normal distribution:
/// `log φ(x) = −x²/2 − ½·ln(2π)`.
#[inline]
fn standard_normal_log_pdf(x: f64) -> f64 {
    let half_square = 0.5 * x * x;
    -(half_square + HALF_LOG_2PI)
}
864
865/// Stable log Φ(x) for the standard normal CDF.
866#[inline]
867fn log_ndtr(x: f64) -> f64 {
868    let arg = -x * std::f64::consts::FRAC_1_SQRT_2;
869    let erfc_val = statrs::function::erf::erfc(arg);
870    if erfc_val > 0.0 {
871        erfc_val.ln() - std::f64::consts::LN_2
872    } else {
873        -0.5 * x * x - (-x).ln() - HALF_LOG_2PI
874    }
875}
876
877#[inline]
878fn validate_firth_support(family: NutsFamily, firth_enabled: bool) -> Result<(), HmcError> {
879    let spec = family.likelihood_spec();
880    if firth_enabled && !likelihood_spec_supports_firth(&spec) {
881        return Err(HmcError::FirthUnsupported {
882            reason: format!(
883                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
884                spec.pretty_name()
885            ),
886        });
887    }
888    Ok::<(), _>(())
889}
890
891#[inline]
892fn validate_firth_likelihood_support(
893    likelihood: &LikelihoodSpec,
894    firth_enabled: bool,
895) -> Result<(), HmcError> {
896    if firth_enabled && !likelihood_spec_supports_firth(likelihood) {
897        return Err(HmcError::FirthUnsupported {
898            reason: format!(
899                "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
900                likelihood.pretty_name()
901            ),
902        });
903    }
904    Ok::<(), _>(())
905}
906
/// `true` when `y` is usable as a count response: finite, non-negative, and
/// within `1e-9` of a whole number.
#[inline]
fn valid_count_response(y: f64) -> bool {
    if !y.is_finite() || y < 0.0 {
        return false;
    }
    let distance_to_integer = (y - y.round()).abs();
    distance_to_integer <= 1e-9
}
911
912fn validate_count_responses(
913    family: &str,
914    y: &ArrayView1<'_, f64>,
915    weights: &ArrayView1<'_, f64>,
916) -> Result<(), HmcError> {
917    for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
918        if wi > 0.0 && !valid_count_response(yi) {
919            return Err(HmcError::InvalidConfig {
920                reason: format!(
921                    "{family} response must be a finite non-negative integer at positive-weight row {i}; got {yi}"
922                ),
923            });
924        }
925    }
926    Ok(())
927}
928
929fn validate_binary_responses(
930    family: &str,
931    y: &ArrayView1<'_, f64>,
932    weights: &ArrayView1<'_, f64>,
933) -> Result<(), HmcError> {
934    for (i, (&yi, &wi)) in y.iter().zip(weights.iter()).enumerate() {
935        if wi > 0.0 && !(yi == 0.0 || yi == 1.0) {
936            return Err(HmcError::InvalidConfig {
937                reason: format!(
938                    "{family} response must be exactly 0 or 1 at positive-weight row {i}; got {yi}"
939                ),
940            });
941        }
942    }
943    Ok(())
944}
945
946/// Compute the identifiable-subspace Jeffreys/Firth contribution and its
947/// β-gradient.
948///
949/// HMC uses the same `FirthDenseOperator` as the REML exact-gradient path.
950/// The operator owns the reduced identifiable Fisher factorization, the
951/// Jeffreys log-determinant, and the analytic β-gradient.
952fn firth_jeffreys_logp_and_grad(
953    family: NutsFamily,
954    data: &SharedData,
955    eta: &Array1<f64>,
956) -> Result<(f64, Array1<f64>), HmcError> {
957    if eta.len() != data.n_samples {
958        return Err(HmcError::DimensionMismatch {
959            reason: format!(
960                "Firth Jeffreys term eta length {} != number of samples {}",
961                eta.len(),
962                data.n_samples
963            ),
964        });
965    }
966    if data.dim == 0 || data.n_samples == 0 {
967        return Ok((0.0, Array1::zeros(data.dim)));
968    }
969    validate_firth_support(family, true)?;
970    if data.weights.iter().all(|w| *w == 0.0) {
971        return Ok((0.0, Array1::zeros(data.dim)));
972    }
973
974    let jeffreys_link =
975        likelihood_spec_jeffreys_link(&family.likelihood_spec()).ok_or_else(|| {
976            HmcError::FirthUnsupported {
977                reason: format!(
978                    "Firth Jeffreys term has no Fisher-weight jet for {}",
979                    family.likelihood_spec().pretty_name()
980                ),
981            }
982        })?;
983    let op = if data.weights.iter().all(|&w| w == 1.0) {
984        FirthDenseOperator::build_for_link(&jeffreys_link, data.x.as_ref(), eta)
985    } else {
986        FirthDenseOperator::build_with_observation_weights_for_link(
987            &jeffreys_link,
988            data.x.as_ref(),
989            eta,
990            data.weights.view(),
991        )
992    }
993    .map_err(|e| HmcError::SamplingFailed {
994        reason: format!("Firth Jeffreys operator failed: {e}"),
995    })?;
996    Ok(op.jeffreys_logdet_and_beta_gradient())
997}
998
999// ============================================================================
1000// Shared family log-likelihood helpers
1001// ============================================================================
1002//
1003// Freestanding functions for computing ℓ(y|β) and ∇_β ℓ for each supported
1004// family. Used by both `NutsPosterior` (fixed-ρ β-only sampling) and
1005// `JointBetaRhoPosterior` (joint β+ρ sampling).
1006
1007fn nuts_family_logp_and_grad_into(
1008    family: NutsFamily,
1009    data: &SharedData,
1010    eta: &Array1<f64>,
1011    residual: &mut Array1<f64>,
1012) -> (f64, Array1<f64>) {
1013    match family {
1014        NutsFamily::BinomialLogit => logit_logp_and_grad_into(data, eta, residual),
1015        NutsFamily::BinomialProbit => probit_logp_and_grad_into(data, eta, residual),
1016        NutsFamily::BinomialCLogLog => cloglog_logp_and_grad_into(data, eta, residual),
1017        NutsFamily::Gaussian => gaussian_logp_and_grad_into(data, eta, residual),
1018        NutsFamily::PoissonLog => poisson_log_logp_and_grad(data, eta),
1019        // Family mapping: TweedieLog stores variance power p in data.gamma_shape.
1020        // Its dispersion phi stays in data.dispersion, matching REML scale ownership.
1021        NutsFamily::TweedieLog => tweedie_log_quasilogp_and_grad(data, eta, data.gamma_shape),
1022        NutsFamily::NegativeBinomialLog => {
1023            // Family mapping: NegativeBinomialLog stores theta in data.gamma_shape.
1024            // NB has unit REML scale; theta is never sourced from fixed_phi.
1025            negative_binomial_log_logp_and_grad(data, eta, data.gamma_shape)
1026        }
1027        NutsFamily::GammaLog => gamma_log_logp_and_grad(data, eta),
1028    }
1029}
1030
/// Per-observation Bernoulli log terms under a (possibly parameterised)
/// inverse link `μ = g⁻¹(η)`.
#[derive(Clone, Debug)]
struct BinomialLinkTerms {
    /// `log μ`, or `−∞` when `μ == 0`.
    log_mu: f64,
    /// `log(1 − μ)`, or `−∞` when `μ == 1`.
    log1m_mu: f64,
    /// `∂ log μ / ∂η = μ'/μ`.
    dlog_mu_deta: f64,
    /// `∂ log(1 − μ) / ∂η = −μ'/(1 − μ)`.
    dlog1m_mu_deta: f64,
    /// `∂μ/∂θ_k` for each adaptive link parameter `θ_k` (empty when there are none).
    dmu_dlink: Vec<f64>,
}

/// Build [`BinomialLinkTerms`] from `μ`, `∂μ/∂η`, and the link-parameter
/// partials.
///
/// At the endpoints `μ ∈ {0, 1}`, the matching log term is `−∞` and its
/// η-derivative is the signed infinite limit of the finite formula:
/// * `μ'/μ → +∞·sign(μ')` as `μ → 0⁺`;
/// * `−μ'/(1 − μ) → −∞·sign(μ')` as `μ → 1⁻`.
///
/// # Errors
/// Returns a message when `μ` is non-finite or outside `[0, 1]`, or when
/// `∂μ/∂η` is non-finite.
#[inline]
fn log_terms_from_mu_and_dmu(
    mu: f64,
    dmu_deta: f64,
    dmu_dlink: Vec<f64>,
) -> Result<BinomialLinkTerms, String> {
    if !(mu.is_finite() && (0.0..=1.0).contains(&mu) && dmu_deta.is_finite()) {
        return Err(format!(
            "binomial inverse link returned invalid mu/deta derivative: mu={mu}, dmu_deta={dmu_deta}"
        ));
    }
    let one_minus_mu = 1.0 - mu;
    let (log_mu, dlog_mu_deta) = if mu == 0.0 {
        (f64::NEG_INFINITY, f64::INFINITY.copysign(dmu_deta))
    } else {
        (mu.ln(), dmu_deta / mu)
    };
    let (log1m_mu, dlog1m_mu_deta) = if one_minus_mu == 0.0 {
        // `copysign` keeps the receiver's magnitude and takes ONLY the sign of
        // its argument. The limit of −μ'/(1−μ) as (1−μ) → 0⁺ has the sign
        // opposite to μ', so the sign source must be `-dmu_deta`.
        // (`NEG_INFINITY.copysign(dmu_deta)` would wrongly give +∞ for μ' > 0.)
        (f64::NEG_INFINITY, f64::INFINITY.copysign(-dmu_deta))
    } else {
        (one_minus_mu.ln(), -dmu_deta / one_minus_mu)
    };
    Ok(BinomialLinkTerms {
        log_mu,
        log1m_mu,
        dlog_mu_deta,
        dlog1m_mu_deta,
        dmu_dlink,
    })
}
1080
1081#[inline]
1082fn binomial_link_terms(
1083    inverse_link: &InverseLink,
1084    eta: f64,
1085    n_link_params: usize,
1086) -> Result<BinomialLinkTerms, String> {
1087    let jet =
1088        inverse_link_jet_for_inverse_link(inverse_link, eta).map_err(|err| err.to_string())?;
1089    let mut dmu_dlink = vec![0.0; n_link_params];
1090    if n_link_params > 0 {
1091        match inverse_link
1092            .param_partials(eta)
1093            .map_err(|err| err.to_string())?
1094        {
1095            Some(LinkParamPartials::Sas(partials)) => {
1096                if n_link_params != 2 {
1097                    return Err(format!(
1098                        "SAS/Beta-Logistic link parameter dimension mismatch: expected 2, got {n_link_params}"
1099                    ));
1100                }
1101                dmu_dlink[0] = partials.djet_depsilon.mu;
1102                dmu_dlink[1] = partials.djet_dlog_delta.mu;
1103            }
1104            Some(LinkParamPartials::Mixture(partials)) => {
1105                if partials.djet_drho.len() != n_link_params {
1106                    return Err(format!(
1107                        "mixture link parameter dimension mismatch: expected {}, got {n_link_params}",
1108                        partials.djet_drho.len()
1109                    ));
1110                }
1111                for (slot, partial) in dmu_dlink.iter_mut().zip(partials.djet_drho.iter()) {
1112                    *slot = partial.mu;
1113                }
1114            }
1115            None => {
1116                return Err(format!(
1117                    "joint HMC expected {n_link_params} adaptive link parameters, but the inverse link exposes none"
1118                ));
1119            }
1120        }
1121    }
1122    log_terms_from_mu_and_dmu(jet.mu, jet.d1, dmu_dlink)
1123}
1124
/// Bernoulli log-likelihood, β-gradient, and adaptive-link-parameter gradient
/// for a general binomial inverse link.
///
/// For each row with weight `w_i > 0`:
/// * `y_i == 1` adds `w_i·log μ_i`, with η-derivative `w_i·∂log μ/∂η` and
///   link gradient `+w_i·(∂μ/∂θ)/μ_i`.
/// * `y_i == 0` adds `w_i·log(1 − μ_i)`, with η-derivative
///   `w_i·∂log(1−μ)/∂η` and link gradient `−w_i·(∂μ/∂θ)/(1 − μ_i)`.
///
/// Rows with `w_i ≤ 0` contribute nothing.
///
/// Returns `(ℓ, Xᵀ·residual, ∇_θ ℓ)`, where `residual[i]` is the row's
/// η-derivative and `∇_θ ℓ` has length `n_link_params`.
///
/// # Errors
/// Propagates link-evaluation errors from `binomial_link_terms`. Also reports
/// any positive-weight response that is not exactly 0 or 1.
fn joint_binomial_logp_grad_and_link_grad(
    inverse_link: &InverseLink,
    data: &SharedData,
    eta: &Array1<f64>,
    n_link_params: usize,
) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
    let n = data.n_samples;
    // Per-row: compute stable log-tail terms and derivatives without endpoint
    // clamping. Positive-weight responses were validated as Bernoulli before
    // target construction, so each row selects exactly one log branch.
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let per_row: Result<Vec<(f64, f64, Vec<f64>)>, String> = (0..n)
        .into_par_iter()
        .map(|i| {
            let y_i = data.y[i];
            let w_i = data.weights[i];
            if w_i <= 0.0 {
                return Ok((0.0, 0.0, vec![0.0; n_link_params]));
            }
            let terms = binomial_link_terms(inverse_link, eta[i], n_link_params)?;
            if y_i == 1.0 {
                // 1/μ recovered from log μ (μ itself is not kept in the terms);
                // this is +∞ when μ == 0.
                let inv_mu = terms.log_mu.exp().recip();
                let log_mu = terms.log_mu;
                let dlog_mu_deta = terms.dlog_mu_deta;
                let grad_link = terms
                    .dmu_dlink
                    .into_iter()
                    .map(|dmu| w_i * dmu * inv_mu)
                    .collect();
                Ok((w_i * log_mu, w_i * dlog_mu_deta, grad_link))
            } else if y_i == 0.0 {
                // 1/(1 − μ) recovered from log(1 − μ); this is +∞ when μ == 1.
                let inv_one_minus_mu = terms.log1m_mu.exp().recip();
                let log1m_mu = terms.log1m_mu;
                let dlog1m_mu_deta = terms.dlog1m_mu_deta;
                let grad_link = terms
                    .dmu_dlink
                    .into_iter()
                    .map(|dmu| -w_i * dmu * inv_one_minus_mu)
                    .collect();
                Ok((w_i * log1m_mu, w_i * dlog1m_mu_deta, grad_link))
            } else {
                Err(format!(
                    "binomial joint HMC response must be exactly 0 or 1 after validation; got {y_i}"
                ))
            }
        })
        .collect();
    let per_row = per_row?;
    // Serial reduction in row order: scatter the residuals and sum ℓ and the
    // link-parameter gradient.
    let mut residual = Array1::<f64>::zeros(n);
    let mut grad_link = Array1::<f64>::zeros(n_link_params);
    let mut ll = 0.0;
    for (i, (ll_i, residual_i, grad_link_i)) in per_row.into_iter().enumerate() {
        ll += ll_i;
        residual[i] = residual_i;
        for (slot, value) in grad_link.iter_mut().zip(grad_link_i.iter()) {
            *slot += *value;
        }
    }

    Ok((ll, fast_atv(&data.x, &residual), grad_link))
}
1186
1187fn joint_binomial_logp_and_grad(
1188    likelihood: &LikelihoodSpec,
1189    data: &SharedData,
1190    eta: &Array1<f64>,
1191) -> Result<(f64, Array1<f64>), String> {
1192    if !matches!(likelihood.response, ResponseFamily::Binomial) {
1193        return Err(HmcError::UnsupportedFamily {
1194            reason: format!(
1195                "{} is not a binomial joint-HMC family",
1196                likelihood.pretty_name()
1197            ),
1198        }
1199        .into());
1200    }
1201    match &likelihood.link {
1202        InverseLink::Standard(StandardLink::Logit) => Ok(logit_logp_and_grad(data, eta)),
1203        InverseLink::Standard(StandardLink::Probit) => Ok(probit_logp_and_grad(data, eta)),
1204        InverseLink::Standard(StandardLink::CLogLog) => Ok(cloglog_logp_and_grad(data, eta)),
1205        InverseLink::LatentCLogLog(_)
1206        | InverseLink::Sas(_)
1207        | InverseLink::BetaLogistic(_)
1208        | InverseLink::Mixture(_) => {
1209            let (ll, grad_beta, _) =
1210                joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, 0)?;
1211            Ok((ll, grad_beta))
1212        }
1213        InverseLink::Standard(_) => Err(HmcError::UnsupportedFamily {
1214            reason: format!(
1215                "{} is not a binomial joint-HMC family",
1216                likelihood.pretty_name()
1217            ),
1218        }
1219        .into()),
1220    }
1221}
1222
1223fn joint_family_logp_grad_and_link_grad(
1224    likelihood: &LikelihoodSpec,
1225    data: &SharedData,
1226    eta: &Array1<f64>,
1227    n_link_params: usize,
1228) -> Result<(f64, Array1<f64>, Array1<f64>), String> {
1229    match (&likelihood.response, &likelihood.link) {
1230        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
1231            let (ll, grad) = logit_logp_and_grad(data, eta);
1232            Ok((ll, grad, Array1::zeros(n_link_params)))
1233        }
1234        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
1235            let (ll, grad) = probit_logp_and_grad(data, eta);
1236            Ok((ll, grad, Array1::zeros(n_link_params)))
1237        }
1238        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
1239            let (ll, grad) = cloglog_logp_and_grad(data, eta);
1240            Ok((ll, grad, Array1::zeros(n_link_params)))
1241        }
1242        (
1243            ResponseFamily::Binomial,
1244            InverseLink::LatentCLogLog(_)
1245            | InverseLink::Sas(_)
1246            | InverseLink::BetaLogistic(_)
1247            | InverseLink::Mixture(_),
1248        ) => joint_binomial_logp_grad_and_link_grad(&likelihood.link, data, eta, n_link_params),
1249        _ => {
1250            let (ll, grad) = joint_family_logp_and_grad(likelihood, data, eta)?;
1251            Ok((ll, grad, Array1::zeros(n_link_params)))
1252        }
1253    }
1254}
1255
1256fn joint_family_logp_and_grad(
1257    likelihood: &LikelihoodSpec,
1258    data: &SharedData,
1259    eta: &Array1<f64>,
1260) -> Result<(f64, Array1<f64>), String> {
1261    match &likelihood.response {
1262        ResponseFamily::Binomial => joint_binomial_logp_and_grad(likelihood, data, eta),
1263        ResponseFamily::Gaussian => Ok(gaussian_logp_and_grad(data, eta)),
1264        ResponseFamily::Poisson => Ok(poisson_log_logp_and_grad(data, eta)),
1265        ResponseFamily::Tweedie { p } => {
1266            // Family mapping: Tweedie payload p is the variance power.
1267            // Its dispersion phi stays in data.dispersion, matching REML.
1268            let p = *p;
1269            if !is_valid_tweedie_power(p) {
1270                return Err(HmcError::InvalidConfig {
1271                    reason: format!(
1272                        "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
1273                    ),
1274                }
1275                .into());
1276            }
1277            Ok(tweedie_log_quasilogp_and_grad(data, eta, p))
1278        }
1279        ResponseFamily::NegativeBinomial { theta, .. } => {
1280            // Family mapping: NegativeBinomial payload theta is overdispersion.
1281            // NB keeps unit REML scale and never reads fixed_phi for theta.
1282            Ok(negative_binomial_log_logp_and_grad(data, eta, *theta))
1283        }
1284        ResponseFamily::Beta { .. } => Err(HmcError::UnsupportedFamily {
1285            reason: "Joint HMC fallback is not implemented for BetaLogit".to_string(),
1286        }
1287        .into()),
1288        ResponseFamily::Gamma => Ok(gamma_log_logp_and_grad(data, eta)),
1289        ResponseFamily::RoystonParmar => Err(HmcError::UnsupportedFamily {
1290            reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
1291        }
1292        .into()),
1293    }
1294}
1295
1296/// Logistic regression log-likelihood and gradient.
1297///
1298/// log p(y|η) = y·η − log(1 + exp(η)), gradient = X'(w ⊙ (y − μ))
1299fn logit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1300    let mut residual = Array1::<f64>::zeros(data.n_samples);
1301    logit_logp_and_grad_into(data, eta, &mut residual)
1302}
1303
1304fn logit_logp_and_grad_into(
1305    data: &SharedData,
1306    eta: &Array1<f64>,
1307    residual: &mut Array1<f64>,
1308) -> (f64, Array1<f64>) {
1309    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1310    let n = data.n_samples;
1311    assert_eq!(residual.len(), n);
1312    // Per-row independent: write residual entry into a pre-allocated buffer and
1313    // reduce the ll contribution in parallel — avoids materialising a
1314    // Vec<(f64, f64)> and the serial scatter that follows.
1315    let ll: f64 = residual
1316        .as_slice_mut()
1317        .unwrap()
1318        .par_iter_mut()
1319        .enumerate()
1320        .map(|(i, slot)| {
1321            let eta_i = eta[i];
1322            let y_i = data.y[i];
1323            let w_i = data.weights[i];
1324            let mu = gam_linalg::utils::stable_logistic(eta_i);
1325            *slot = w_i * (y_i - mu);
1326            w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i))
1327        })
1328        .sum();
1329
1330    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1331    (ll, grad_ll)
1332}
1333
1334/// Probit regression log-likelihood and gradient.
1335///
1336/// log p(y|η) = Σ [y·log Φ(η) + (1-y)·log(1-Φ(η))],
1337/// gradient_i = w_i · [y_i · φ(η_i)/Φ(η_i) − (1-y_i) · φ(η_i)/(1−Φ(η_i))]
1338///
1339/// Uses erfc-based log Φ for numerical stability.
1340fn probit_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1341    let mut residual = Array1::<f64>::zeros(data.n_samples);
1342    probit_logp_and_grad_into(data, eta, &mut residual)
1343}
1344
1345fn probit_logp_and_grad_into(
1346    data: &SharedData,
1347    eta: &Array1<f64>,
1348    residual: &mut Array1<f64>,
1349) -> (f64, Array1<f64>) {
1350    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1351    let n = data.n_samples;
1352    assert_eq!(residual.len(), n);
1353    let ll: f64 = residual
1354        .as_slice_mut()
1355        .unwrap()
1356        .par_iter_mut()
1357        .enumerate()
1358        .map(|(i, slot)| {
1359            let eta_i = eta[i];
1360            let y_i = data.y[i];
1361            let w_i = data.weights[i];
1362            let log_phi_pos = log_ndtr(eta_i);
1363            let log_phi_neg = log_ndtr(-eta_i);
1364            let log_phi_val = standard_normal_log_pdf(eta_i);
1365            let ratio_pos = (log_phi_val - log_phi_pos).exp();
1366            let ratio_neg = (log_phi_val - log_phi_neg).exp();
1367            let grad_i = y_i * ratio_pos - (1.0 - y_i) * ratio_neg;
1368            *slot = w_i * grad_i;
1369            w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg)
1370        })
1371        .sum();
1372
1373    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1374    (ll, grad_ll)
1375}
1376
1377/// Complementary log-log regression log-likelihood and gradient.
1378///
1379/// CLogLog link: μ = 1 − exp(−exp(η))
1380/// log p(y|η) = Σ [y·log(1−exp(−exp(η))) + (1−y)·(−exp(η))]
1381/// gradient_i = w_i · [y_i · exp(η_i)·exp(−exp(η_i)) / (1−exp(−exp(η_i))) − (1−y_i)·exp(η_i)]
1382#[inline]
1383fn cloglog_bernoulli_logp_and_residual(eta: f64, y: f64) -> Result<(f64, f64), EstimationError> {
1384    if !(eta.is_finite() && (-700.0..=700.0).contains(&eta)) {
1385        gam_problem::bail_invalid_estim!("cloglog eta must be finite and within [-700, 700]; got {eta}");
1386    }
1387    let exp_eta = eta.exp();
1388    // log_mu = log(1 - exp(-exp_eta)); exp_eta > 0 on the guarded domain, so this is
1389    // exactly the canonical cancellation-free log1mexp (single source of truth).
1390    let log_mu = crate::probability::log1mexp_positive(exp_eta);
1391    let log_one_minus_mu = -exp_eta;
1392    let grad_log_mu = (eta - exp_eta - log_mu).exp();
1393    let ll_i = y * log_mu + (1.0 - y) * log_one_minus_mu;
1394    let residual_i = y * grad_log_mu - (1.0 - y) * exp_eta;
1395    Ok((ll_i, residual_i))
1396}
1397
1398fn cloglog_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1399    let mut residual = Array1::<f64>::zeros(data.n_samples);
1400    cloglog_logp_and_grad_into(data, eta, &mut residual)
1401}
1402
1403fn cloglog_logp_and_grad_into(
1404    data: &SharedData,
1405    eta: &Array1<f64>,
1406    residual: &mut Array1<f64>,
1407) -> (f64, Array1<f64>) {
1408    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1409    let n = data.n_samples;
1410    assert_eq!(residual.len(), n);
1411    if eta
1412        .iter()
1413        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1414    {
1415        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1416    }
1417    let ll: f64 = residual
1418        .as_slice_mut()
1419        .unwrap()
1420        .par_iter_mut()
1421        .enumerate()
1422        .map(|(i, slot)| {
1423            let y_i = data.y[i];
1424            let w_i = data.weights[i];
1425            let (ll_i, residual_i) =
1426                cloglog_bernoulli_logp_and_residual(eta[i], y_i).expect("validated cloglog eta");
1427            *slot = w_i * residual_i;
1428            w_i * ll_i
1429        })
1430        .sum();
1431
1432    let grad_ll = fast_atv(data.x.as_ref(), &*residual);
1433    (ll, grad_ll)
1434}
1435
1436/// Gaussian log-likelihood and gradient.
1437///
1438/// log p(y|η) = −½ (w/φ)·(y − η)²,  gradient = (1/φ)·X'(w ⊙ (y − η))
1439///
1440/// Both the log-likelihood and its β-gradient are scaled by `1/φ` so that
1441/// the working likelihood matches the φ-scaled posterior covariance the
1442/// HMC whitening transform targets. With `φ == 1` (the only value
1443/// passed by the pre-refactor call sites) this collapses to the original
1444/// `−½ w·(y − η)²` expression; with an estimated dispersion (the
1445/// `Dispersion::Estimated(σ²)` branch) it removes the silent unit-σ
1446/// approximation the Gaussian NUTS log-density used previously.
1447fn gaussian_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1448    let mut weighted_residual = Array1::<f64>::zeros(data.n_samples);
1449    gaussian_logp_and_grad_into(data, eta, &mut weighted_residual)
1450}
1451
1452fn gaussian_logp_and_grad_into(
1453    data: &SharedData,
1454    eta: &Array1<f64>,
1455    weighted_residual: &mut Array1<f64>,
1456) -> (f64, Array1<f64>) {
1457    use gam_problem::dispersion_cov::DispersionExt as _;
1458    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1459    let n = data.n_samples;
1460    let inv_phi = data.dispersion.inv_phi();
1461    assert_eq!(weighted_residual.len(), n);
1462    // Per-row: residual = y - η, weighted_residual = (w/φ)·residual,
1463    // ll contribution = -0.5·(w/φ)·residual². All independent across rows.
1464    let ll: f64 = weighted_residual
1465        .as_slice_mut()
1466        .unwrap()
1467        .par_iter_mut()
1468        .enumerate()
1469        .map(|(i, slot)| {
1470            let residual = data.y[i] - eta[i];
1471            let w_i = data.weights[i];
1472            let scaled = w_i * inv_phi;
1473            *slot = scaled * residual;
1474            -0.5 * scaled * residual * residual
1475        })
1476        .sum();
1477
1478    let grad_ll = fast_atv(data.x.as_ref(), &*weighted_residual);
1479    (ll, grad_ll)
1480}
1481
1482/// Poisson(log) log-likelihood and gradient.
1483///
1484/// log p(y|η) = y·η − exp(η), gradient = X'(w ⊙ (y − μ))
1485fn poisson_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1486    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1487    let n = data.n_samples;
1488    if eta
1489        .iter()
1490        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1491    {
1492        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1493    }
1494    let mut residual = Array1::<f64>::zeros(n);
1495    let ll: f64 = residual
1496        .as_slice_mut()
1497        .unwrap()
1498        .par_iter_mut()
1499        .enumerate()
1500        .map(|(i, slot)| {
1501            let eta_i = eta[i];
1502            let mu_i = eta_i.exp();
1503            let y_i = data.y[i];
1504            let w_i = data.weights[i];
1505            *slot = w_i * (y_i - mu_i);
1506            w_i * (y_i * eta_i - mu_i)
1507        })
1508        .sum();
1509
1510    let grad_ll = fast_atv(&data.x, &residual);
1511    (ll, grad_ll)
1512}
1513
1514fn tweedie_log_quasilogp_and_grad(
1515    data: &SharedData,
1516    eta: &Array1<f64>,
1517    p: f64,
1518) -> (f64, Array1<f64>) {
1519    use gam_problem::dispersion_cov::DispersionExt as _;
1520    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1521    let n = data.n_samples;
1522    // Family mapping: Tweedie p is the variant payload; phi is data.dispersion.
1523    // Invalid payloads invalidate the target instead of falling back to p=1.5.
1524    if !is_valid_tweedie_power(p) {
1525        return (f64::NAN, Array1::from_elem(data.dim, f64::NAN));
1526    }
1527    if eta
1528        .iter()
1529        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1530    {
1531        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1532    }
1533    let inv_phi = data.dispersion.inv_phi();
1534    let mut residual = Array1::<f64>::zeros(n);
1535    let ll: f64 = residual
1536        .as_slice_mut()
1537        .unwrap()
1538        .par_iter_mut()
1539        .enumerate()
1540        .map(|(i, slot)| {
1541            let eta_i = eta[i];
1542            let mu_i = eta_i.exp().max(1e-300);
1543            let y_i = data.y[i];
1544            let w_i = data.weights[i] * inv_phi;
1545            *slot = w_i * (y_i - mu_i) * mu_i.powf(1.0 - p);
1546            let qll = y_i * mu_i.powf(1.0 - p) / (1.0 - p) - mu_i.powf(2.0 - p) / (2.0 - p);
1547            w_i * qll
1548        })
1549        .sum();
1550
1551    let grad_ll = fast_atv(&data.x, &residual);
1552    (ll, grad_ll)
1553}
1554
1555fn negative_binomial_log_logp_and_grad(
1556    data: &SharedData,
1557    eta: &Array1<f64>,
1558    theta: f64,
1559) -> (f64, Array1<f64>) {
1560    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1561    let n = data.n_samples;
1562    if !(theta.is_finite() && theta > 0.0)
1563        || eta
1564            .iter()
1565            .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1566        || data
1567            .y
1568            .iter()
1569            .zip(data.weights.iter())
1570            .any(|(&y_i, &w_i)| w_i > 0.0 && !valid_count_response(y_i))
1571    {
1572        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1573    }
1574    let mut residual = Array1::<f64>::zeros(n);
1575    let ll: f64 = residual
1576        .as_slice_mut()
1577        .unwrap()
1578        .par_iter_mut()
1579        .enumerate()
1580        .map(|(i, slot)| {
1581            let eta_i = eta[i];
1582            let mu_i = eta_i.exp().max(1e-12);
1583            let y_i = data.y[i];
1584            let w_i = data.weights[i];
1585            if w_i <= 0.0 {
1586                *slot = 0.0;
1587                return 0.0;
1588            }
1589            let log_mu_term = if y_i > 0.0 { y_i * mu_i.ln() } else { 0.0 };
1590            *slot = w_i * theta * (y_i - mu_i) / (theta + mu_i);
1591            w_i * (statrs::function::gamma::ln_gamma(y_i + theta)
1592                - statrs::function::gamma::ln_gamma(theta)
1593                - statrs::function::gamma::ln_gamma(y_i + 1.0)
1594                + theta * (theta.ln() - (theta + mu_i).ln())
1595                + log_mu_term
1596                - y_i * (theta + mu_i).ln())
1597        })
1598        .sum();
1599
1600    let grad_ll = fast_atv(&data.x, &residual);
1601    (ll, grad_ll)
1602}
1603
1604fn gamma_log_logp_and_grad(data: &SharedData, eta: &Array1<f64>) -> (f64, Array1<f64>) {
1605    use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1606    let n = data.n_samples;
1607    if eta
1608        .iter()
1609        .any(|&eta_i| !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)))
1610    {
1611        return (f64::NEG_INFINITY, Array1::zeros(data.dim));
1612    }
1613    let shape = data.gamma_shape.max(1e-10);
1614    // Hoist shape-only constants out of the per-sample loop: ln Γ(shape) and
1615    // shape · ln(shape) are independent of i, so previously each sample paid
1616    // an extra `ln_gamma` and `ln` plus a multiply. n is typically large-scale-
1617    // scale, so this collapses Θ(n) gamma-function evaluations to one.
1618    let shape_ln_shape = shape * shape.ln();
1619    let log_gamma_shape = statrs::function::gamma::ln_gamma(shape);
1620    let shape_minus_one = shape - 1.0;
1621    let mut residual = Array1::<f64>::zeros(n);
1622    let ll: f64 = residual
1623        .as_slice_mut()
1624        .unwrap()
1625        .par_iter_mut()
1626        .enumerate()
1627        .map(|(i, slot)| {
1628            let eta_i = eta[i];
1629            let mu_i = eta_i.exp();
1630            let y_i = data.y[i];
1631            let w_i = data.weights[i];
1632            let ll_i = w_i
1633                * (shape_ln_shape - log_gamma_shape - shape * eta_i
1634                    + shape_minus_one * y_i.max(1e-12).ln()
1635                    - shape * y_i / mu_i);
1636            *slot = w_i * shape * (y_i / mu_i - 1.0);
1637            ll_i
1638        })
1639        .sum();
1640
1641    let grad_ll = fast_atv(&data.x, &residual);
1642    (ll, grad_ll)
1643}
1644
1645#[cfg(test)]
1646mod tests {
1647    use super::{
1648        FamilyNutsInputs, GlmFlatInputs, JointBetaRhoInputs, JointBetaRhoPosterior,
1649        LinkWigglePosterior, LinkWiggleSplineArtifacts, NutsConfig, NutsFamily, NutsPosterior,
1650        SharedData, cloglog_bernoulli_logp_and_residual, firth_jeffreys_logp_and_grad,
1651        joint_family_logp_and_grad, laplace_directional_cubic_diagnostic,
1652        laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
1653        run_joint_beta_rho_sampling, run_logit_polya_gamma_gibbs,
1654        run_nuts_sampling_flattened_family,
1655    };
1656    use gam_terms::construction::CanonicalPenalty;
1657    use gam_solve::estimate::{
1658        BlockRole, FitGeometry, FitInference, FittedBlock, FittedLinkState, UnifiedFitResult,
1659        UnifiedFitResultParts,
1660    };
1661    use gam_models::survival::{PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec};
1662    use gam_linalg::matrix::DesignMatrix;
1663    use gam_problem::types::{
1664        InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LogLikelihoodNormalization,
1665        ResponseFamily, RhoPrior, StandardLink,
1666    };
1667    use general_mcmc::generic_hmc::HamiltonianTarget;
1668    use ndarray::{Array1, Array2, array};
1669    use std::sync::Arc;
1670
1671    impl NutsPosterior {
1672        /// Test-only allocation wrapper around `compute_logp_and_grad_nd_into`.
1673        pub(super) fn compute_logp_and_grad_nd(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1674            let mut residual = Array1::<f64>::zeros(self.data.n_samples);
1675            let mut grad = Array1::<f64>::zeros(z.len());
1676            let logp = self.compute_logp_and_grad_nd_into(z, &mut residual, &mut grad);
1677            (logp, grad)
1678        }
1679    }
1680
1681    impl LinkWigglePosterior {
1682        /// Test-only allocation wrapper around `compute_logp_and_grad_into`.
1683        pub(super) fn compute_logp_and_grad(&self, z: &Array1<f64>) -> (f64, Array1<f64>) {
1684            let dim = self.p_base + self.p_link;
1685            let mut grad = Array1::<f64>::zeros(dim);
1686            let logp = self.compute_logp_and_grad_into(z, &mut grad);
1687            (logp, grad)
1688        }
1689    }
1690
1691    impl JointBetaRhoPosterior {
1692        /// Test-only allocation wrapper around `compute_joint_logp_and_grad_into`.
1693        pub(super) fn compute_joint_logp_and_grad(
1694            &self,
1695            params: &Array1<f64>,
1696        ) -> (f64, Array1<f64>) {
1697            let total_dim = self.n_beta + self.n_rho + self.n_link_params;
1698            let mut grad = Array1::<f64>::zeros(total_dim);
1699            let logp = self.compute_joint_logp_and_grad_into(params, &mut grad);
1700            (logp, grad)
1701        }
1702    }
1703
    /// Builds a minimal Gaussian/identity `UnifiedFitResult` around the given
    /// coefficient blocks and the optional `inference` / `geometry` payloads.
    /// It has no smoothing parameters. The HMC Hessian-handoff tests use it to
    /// vary only where the penalized Hessian is stored.
    ///
    /// Panics if `UnifiedFitResult::try_from_parts` rejects the assembled parts.
    fn hmc_test_fit(
        blocks: Vec<FittedBlock>,
        inference: Option<FitInference>,
        geometry: Option<FitGeometry>,
    ) -> UnifiedFitResult {
        // No smoothing parameters: both λ and log λ are empty.
        let lambdas = Array1::zeros(0);
        UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
            blocks,
            log_lambdas: lambdas.clone(),
            lambdas,
            likelihood_family: Some(LikelihoodSpec::new(
                ResponseFamily::Gaussian,
                InverseLink::Standard(StandardLink::Identity),
            )),
            likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
            log_likelihood_normalization: LogLikelihoodNormalization::Full,
            log_likelihood: -1.0,
            deviance: 2.0,
            reml_score: 0.0,
            stable_penalty_term: 0.0,
            penalized_objective: 0.0,
            used_device: false,
            outer_iterations: 1,
            outer_converged: true,
            outer_gradient_norm: None,
            standard_deviation: 1.0,
            // Deliberately no covariance, so whitening must rely on the
            // Hessian carried in `inference` or `geometry`.
            covariance_conditional: None,
            covariance_corrected: None,
            inference,
            fitted_link: FittedLinkState::Standard(None),
            geometry,
            block_states: Vec::new(),
            pirls_status: gam_solve::pirls::PirlsStatus::Converged,
            max_abs_eta: 0.0,
            constraint_kkt: None,
            artifacts: Default::default(),
            inner_cycles: 0,
        })
        .expect("valid HMC handoff test fit")
    }
1744
    /// A standard single-block fit stores its penalized Hessian in
    /// `FitInference`. `explicit_fit_hessian_for_whitening` must return that
    /// exact matrix, and `NutsPosterior::new` must accept it for whitening.
    #[test]
    fn hmc_whitening_consumes_standard_fit_inference_hessian() {
        // SPD 2×2 Hessian stored only in `inference.penalized_hessian`.
        let hessian = array![[2.0, 0.1], [0.1, 1.6]];
        let fit = hmc_test_fit(
            vec![FittedBlock {
                beta: array![0.05, -0.1],
                role: BlockRole::Mean,
                edf: 2.0,
                lambdas: Array1::zeros(0),
            }],
            Some(FitInference {
                edf_by_block: vec![],
                penalty_block_trace: vec![],
                edf_total: 2.0,
                smoothing_correction: None,
                penalized_hessian: hessian.clone().into(),
                working_weights: array![1.0, 1.0, 1.0],
                working_response: array![0.0, 0.1, -0.2],
                reparam_qs: None,
                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
                beta_covariance: None,
                beta_standard_errors: None,
                beta_covariance_corrected: None,
                beta_standard_errors_corrected: None,
                beta_covariance_frequentist: None,
                coefficient_influence: None,
                weighted_gram: None,
                bias_correction_beta: None,
            }),
            None,
        );

        // The whitening lookup must hand back the stored Hessian verbatim.
        let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "standard fit")
            .expect("standard fit exports explicit Hessian");
        assert_eq!(explicit, &hessian);

        // End to end: the NUTS target must build from that Hessian.
        let x = array![[1.0, 0.0], [1.0, 0.5], [1.0, -0.5]];
        let y = array![0.0, 0.2, -0.1];
        let weights = Array1::ones(3);
        let penalty = Array2::eye(2);
        NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty.view(),
            fit.beta.view(),
            explicit.view(),
            NutsFamily::Gaussian,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            false,
        )
        .expect("HMC target whitens with upstream Hessian");
    }
1799
    /// A blockwise (location/scale) fit with no `FitInference` stores its
    /// penalized Hessian in `FitGeometry`. The whitening lookup must return
    /// that materialized matrix unchanged.
    #[test]
    fn hmc_whitening_consumes_blockwise_geometry_hessian() {
        let hessian = array![[3.0, 0.2], [0.2, 2.0]];
        let fit = hmc_test_fit(
            vec![
                FittedBlock {
                    beta: array![0.1],
                    role: BlockRole::Location,
                    edf: 1.0,
                    lambdas: Array1::zeros(0),
                },
                FittedBlock {
                    beta: array![-0.2],
                    role: BlockRole::Scale,
                    edf: 1.0,
                    lambdas: Array1::zeros(0),
                },
            ],
            // No inference payload: the Hessian lives only in `geometry`.
            None,
            Some(FitGeometry {
                penalized_hessian: hessian.clone().into(),
                working_weights: array![1.0, 0.8],
                working_response: array![0.0, 0.1],
            }),
        );

        let explicit = super::explicit_fit_hessian_for_whitening(&fit, 2, "blockwise fit")
            .expect("blockwise fit exports materialized Hessian");
        assert_eq!(explicit, &hessian);
    }
1830
    /// A fit that carries only a conditional covariance (no inference, no
    /// geometry) must be rejected for HMC whitening. The lookup must not invert
    /// the covariance to synthesize a Hessian.
    #[test]
    fn hmc_whitening_rejects_covariance_only_fit_without_synthesizing_hessian() {
        let fit = UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
            blocks: vec![FittedBlock {
                beta: array![0.0],
                role: BlockRole::Mean,
                edf: 1.0,
                lambdas: Array1::zeros(0),
            }],
            log_lambdas: Array1::zeros(0),
            lambdas: Array1::zeros(0),
            likelihood_family: Some(LikelihoodSpec::new(
                ResponseFamily::Gaussian,
                InverseLink::Standard(StandardLink::Identity),
            )),
            likelihood_scale: LikelihoodScaleMetadata::ProfiledGaussian,
            log_likelihood_normalization: LogLikelihoodNormalization::Full,
            log_likelihood: -1.0,
            deviance: 2.0,
            reml_score: 0.0,
            stable_penalty_term: 0.0,
            penalized_objective: 0.0,
            used_device: false,
            outer_iterations: 1,
            outer_converged: true,
            outer_gradient_norm: None,
            standard_deviation: 1.0,
            // The only curvature information available is a covariance.
            covariance_conditional: Some(array![[0.5]]),
            covariance_corrected: None,
            inference: None,
            fitted_link: FittedLinkState::Standard(None),
            geometry: None,
            block_states: Vec::new(),
            pirls_status: gam_solve::pirls::PirlsStatus::Converged,
            max_abs_eta: 0.0,
            constraint_kkt: None,
            artifacts: Default::default(),
            inner_cycles: 0,
        })
        .expect("covariance-only fit can exist for prediction");

        // The error must name the missing explicit Hessian instead of
        // silently falling back to inv(covariance).
        let err = super::explicit_fit_hessian_for_whitening(&fit, 1, "covariance-only fit")
            .expect_err("HMC must not invert covariance as a Hessian fallback");
        assert!(
            err.contains("missing an explicit penalized Hessian"),
            "unexpected error: {err}"
        );
    }
1879
1880    #[test]
1881    fn log1pexp_is_finite_for_extreme_eta() {
1882        assert!(gam_linalg::utils::stable_softplus(1000.0).is_finite());
1883        assert!(gam_linalg::utils::stable_softplus(-1000.0).is_finite());
1884        assert!((gam_linalg::utils::stable_softplus(-1000.0) - 0.0).abs() < 1e-12);
1885    }
1886
    /// The logistic function saturated at ±1000 must stay inside `[0, 1]` and
    /// within `1e-12` of the corresponding asymptote.
    #[test]
    fn sigmoid_stable_behaves_at_extremes() {
        let hi = gam_linalg::utils::stable_logistic(1000.0);
        let lo = gam_linalg::utils::stable_logistic(-1000.0);
        assert!((1.0 - 1e-12..=1.0).contains(&hi));
        assert!((0.0..=1e-12).contains(&lo));
    }
1894
    /// For y = 1 the cloglog Bernoulli log-likelihood must be
    /// `ln(1 − exp(−exp(η)))`, not the mistaken `ln(1 − exp(η))`. The returned
    /// residual must equal its η-derivative, checked by central differences.
    #[test]
    fn cloglog_log_mu_uses_complementary_loglog_inverse_link() {
        let eta = -1.0_f64;
        let (ll_y1, residual_y1) =
            cloglog_bernoulli_logp_and_residual(eta, 1.0).expect("valid eta");
        let expected = (1.0 - (-eta.exp()).exp()).ln();
        let wrong_log_one_minus_exp_eta = (1.0 - eta.exp()).ln();

        // Must match the correct form and be far from the mistaken one.
        assert!((ll_y1 - expected).abs() < 1e-14);
        assert!((ll_y1 - wrong_log_one_minus_exp_eta).abs() > 0.5);

        // Central-difference derivative of the log-likelihood in η.
        let eps = 1e-6;
        let (lp, _) = cloglog_bernoulli_logp_and_residual(eta + eps, 1.0).expect("valid eta");
        let (lm, _) = cloglog_bernoulli_logp_and_residual(eta - eps, 1.0).expect("valid eta");
        let fd = (lp - lm) / (2.0 * eps);
        assert!(
            (residual_y1 - fd).abs() < 1e-9,
            "cloglog residual is not the derivative of log μ: analytic={residual_y1}, fd={fd}"
        );
    }
1915
    /// `LinkWigglePosterior::new` must whiten with the joint Hessian it is
    /// given. The stored factor `L` must satisfy `H · (L Lᵀ) = I` to 1e-10.
    #[test]
    fn link_wiggle_posterior_whitening_uses_supplied_explicit_joint_hessian() {
        let x = array![[1.0], [1.0], [1.0]];
        let y = array![0.0, 1.0, 1.0];
        let weights = Array1::ones(3);
        let penalty_base = Array2::zeros((1, 1));
        let penalty_link = Array2::zeros((1, 1));
        let mode_beta = array![0.2];
        let mode_theta = array![0.05];
        // SPD joint Hessian over (β, θ), with deliberate off-diagonal coupling.
        let hessian = array![[4.0, 1.0], [1.0, 3.0]];
        let spline = LinkWiggleSplineArtifacts {
            knot_range: (-1.0, 1.0),
            knot_vector: Array1::from_vec(vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]),
            degree: 2,
        };

        let posterior = LinkWigglePosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty_base.view(),
            penalty_link.view(),
            mode_beta.view(),
            mode_theta.view(),
            hessian.view(),
            spline,
            NutsFamily::BinomialLogit,
            1.0,
        )
        .expect("link-wiggle posterior should accept explicit SPD joint Hessian");

        // L Lᵀ should be H⁻¹, so H · (L Lᵀ) must be the identity.
        let reconstructed_cov = posterior.chol().dot(&posterior.chol().t());
        let eye_from_hessian = hessian.dot(&reconstructed_cov);
        for r in 0..2 {
            for c in 0..2 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert!(
                    (eye_from_hessian[[r, c]] - expected).abs() < 1e-10,
                    "whitening did not use the supplied explicit joint Hessian at ({r},{c}): got {} expected {}",
                    eye_from_hessian[[r, c]],
                    expected
                );
            }
        }
    }
1961
    /// The analytic z-space gradient of the cloglog link-wiggle posterior must
    /// match central differences of its own log-density, to 1e-6 per
    /// coordinate.
    #[test]
    fn link_wiggle_cloglog_gradient_matches_its_log_likelihood() {
        let x = array![[1.0], [1.0], [1.0], [1.0]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let weights = array![1.0, 1.2, 0.8, 1.4];
        let penalty_base = Array2::zeros((1, 1));
        let penalty_link = Array2::zeros((1, 1));
        let mode_beta = array![-0.8];
        let mode_theta = array![0.04];
        // Identity Hessian, so whitening is a pure shift by the mode.
        let hessian = Array2::eye(2);
        let spline = LinkWiggleSplineArtifacts {
            knot_range: (-1.5, 0.5),
            knot_vector: Array1::from_vec(vec![-1.5, -1.5, -1.5, 0.5, 0.5, 0.5]),
            degree: 2,
        };

        let posterior = LinkWigglePosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            penalty_base.view(),
            penalty_link.view(),
            mode_beta.view(),
            mode_theta.view(),
            hessian.view(),
            spline,
            NutsFamily::BinomialCLogLog,
            1.0,
        )
        .expect("cloglog link-wiggle posterior");

        // Off-mode evaluation point in whitened coordinates.
        let z = array![0.2, -0.03];
        let (_, grad) = posterior.compute_logp_and_grad(&z);
        let eps = 1e-6;
        for j in 0..z.len() {
            let mut z_plus = z.clone();
            let mut z_minus = z.clone();
            z_plus[j] += eps;
            z_minus[j] -= eps;
            let (lp, _) = posterior.compute_logp_and_grad(&z_plus);
            let (lm, _) = posterior.compute_logp_and_grad(&z_minus);
            let fd = (lp - lm) / (2.0 * eps);
            assert!(
                (grad[j] - fd).abs() < 1e-6,
                "link-wiggle cloglog gradient mismatch at {j}: analytic={}, fd={}",
                grad[j],
                fd
            );
        }
    }
2012
    /// The analytic z-space gradient of the Binomial-logit NUTS target must
    /// match central differences of its log-density. The check requires
    /// matching sign and agreement to 1e-5 per coordinate.
    #[test]
    fn nuts_logitgradient_matches_finite_difference() {
        let x = array![[1.0, -0.5], [0.2, 0.7], [-1.0, 0.3], [0.5, -1.2]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let w = array![1.0, 1.5, 0.8, 1.2];
        let penalty = array![[0.4, 0.0], [0.0, 0.6]];
        let mode = array![0.1, -0.2];
        let hessian = array![[2.0, 0.2], [0.2, 1.7]]; // SPD

        let posterior = NutsPosterior::new(
            x.view(),
            y.view(),
            w.view(),
            penalty.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::BinomialLogit,
            1.0,
            gam_solve::estimate::Dispersion::Known(1.0),
            true,
        )
        .expect("posterior");

        // Off-mode evaluation point in whitened coordinates.
        let z = array![0.15, -0.35];
        let (_, grad) = posterior.compute_logp_and_grad_nd(&z);

        let eps = 1e-6;
        for j in 0..z.len() {
            let mut z_plus = z.clone();
            let mut z_minus = z.clone();
            z_plus[j] += eps;
            z_minus[j] -= eps;
            let (lp, _) = posterior.compute_logp_and_grad_nd(&z_plus);
            let (lm, _) = posterior.compute_logp_and_grad_nd(&z_minus);
            let fd = (lp - lm) / (2.0 * eps);
            assert_eq!(
                grad[j].signum(),
                fd.signum(),
                "gradient sign mismatch at {}: analytic={}, fd={}",
                j,
                grad[j],
                fd
            );
            assert!(
                (grad[j] - fd).abs() < 1e-5,
                "gradient mismatch at {}: analytic={}, fd={}",
                j,
                grad[j],
                fd
            );
        }
    }
2065
    /// `gamma_log_logp_and_grad` must use `data.gamma_shape` (here 3.5) in both
    /// the log-likelihood and the score. The expected values are recomputed by
    /// a plain serial loop over the closed-form Gamma(log) density.
    #[test]
    fn gamma_log_logp_and_grad_uses_fitted_shape() {
        // Intercept-only design, so the gradient is the summed η-score.
        let x = array![[1.0_f64], [1.0_f64]];
        let y = array![1.5_f64, 2.5_f64];
        let weights = array![1.0_f64, 2.0_f64];
        let eta = array![0.2_f64, 0.4_f64];
        let shape = 3.5_f64;
        let data = SharedData {
            x: Arc::new(x.clone()),
            y: Arc::new(y.clone()),
            weights: Arc::new(weights.clone()),
            mode: Arc::new(Array1::zeros(1)),
            offset: None,
            gamma_shape: shape,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            n_samples: x.nrows(),
            dim: x.ncols(),
        };

        let (ll, grad) = super::gamma_log_logp_and_grad(&data, &eta);

        // Reference: w·[ν ln ν − lnΓ(ν) − ν η + (ν−1) ln y − ν y/μ] and
        // w·ν·(y/μ − 1), summed over rows.
        let mut expected_ll = 0.0;
        let mut expected_score = 0.0;
        for i in 0..eta.len() {
            let mu = eta[i].exp();
            expected_ll += weights[i]
                * (shape * shape.ln() - statrs::function::gamma::ln_gamma(shape) - shape * eta[i]
                    + (shape - 1.0) * y[i].ln()
                    - shape * y[i] / mu);
            expected_score += weights[i] * shape * (y[i] / mu - 1.0);
        }

        assert!((ll - expected_ll).abs() < 1e-12);
        assert_eq!(grad.len(), 1);
        assert!((grad[0] - expected_score).abs() < 1e-12);
    }
2102
    /// Gamma observed information at the mode, `Xᵀ diag(w·ν·y/μ) X` with
    /// `μ = exp(X·mode)`. The per-point curvature `w·ν·y/μ` is exactly `−∂/∂η`
    /// of the analytic score slot `w·ν·(y/μ − 1)` used by
    /// `gamma_log_logp_and_grad`.
    ///
    /// Returns the dense `p × p` matrix (`p = x.ncols()`), accumulated with a
    /// straightforward triple loop as an independent reference.
    fn gamma_log_observed_information(
        x: &Array2<f64>,
        mode: &Array1<f64>,
        y: &Array1<f64>,
        weights: &Array1<f64>,
        shape: f64,
    ) -> Array2<f64> {
        let p = x.ncols();
        let eta = x.dot(mode);
        let mut h = Array2::<f64>::zeros((p, p));
        for i in 0..x.nrows() {
            let mu = eta[i].exp();
            // Row curvature w_i·ν·y_i/μ_i, added as a rank-one update x_i x_iᵀ.
            let wt = weights[i] * shape * y[i] / mu;
            for a in 0..p {
                for b in 0..p {
                    h[[a, b]] += wt * x[[i, a]] * x[[i, b]];
                }
            }
        }
        h
    }
2127
    /// Regression for #680: the whitened GammaLog NUTS target must reproduce
    /// the #679 coefficient-covariance contract `Vb = H⁻¹` (scale `1.0`), NOT
    /// the dispersion-double-counted `(1/ν)(XᵀΛX + S)⁻¹`.
    ///
    /// We set the stored Hessian to the *true* penalized curvature of the
    /// target at the mode, `H = Xᵀ diag(w·ν·y/μ) X + S` (Gamma observed
    /// information + the penalty added **unscaled** — exactly the #679 `H`).
    /// The whitened target's curvature in z at the mode is `Lᵀ Hβ L`. The fix
    /// makes `L Lᵀ = H⁻¹` and `Hβ = H`, so this is the identity. The pre-fix
    /// code scaled the penalty by `ν` and the whitening by `√φ`, turning the
    /// z-curvature into `φ·(I + (ν−1)·L_H⁻¹ S L_H⁻ᵀ) ≠ I` (for ν=4 the
    /// diagonal collapses toward ~0.25, never 1).
    #[test]
    fn gamma_log_nuts_target_curvature_matches_unscaled_hessian_issue_680() {
        let x = array![[1.0, -0.7], [1.0, 0.3], [1.0, 1.1], [1.0, -0.2], [1.0, 0.8],];
        let mode = array![0.4_f64, -0.6_f64];
        let y = array![1.2_f64, 0.7, 2.3, 0.9, 1.6];
        let weights = array![1.0_f64, 1.5, 0.8, 1.2, 1.0];
        // ν = 1/φ = 4 ⇒ φ = 0.25: a large, easily-detectable double-count.
        let shape = 4.0_f64;
        let p = x.ncols();

        let h_data = gamma_log_observed_information(&x, &mode, &y, &weights, shape);
        // A genuine PD smoothing penalty so the ×ν double-count is detectable.
        let s = array![[0.5_f64, 0.1], [0.1, 0.9]];
        let hessian = &h_data + &s;

        let target = NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            s.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::GammaLog,
            shape,
            gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
            false,
        )
        .expect("GammaLog NUTS target builds");

        // z-space precision at the mode (z = 0) via central differences of the
        // analytic gradient: `−∂(∇_z logp)/∂z = Lᵀ Hβ L`. Correct value: I.
        let eps = 1e-6;
        let z0 = Array1::<f64>::zeros(p);
        let mut hz = Array2::<f64>::zeros((p, p));
        for j in 0..p {
            let mut zp = z0.clone();
            let mut zm = z0.clone();
            zp[j] += eps;
            zm[j] -= eps;
            let (_, gp) = target.compute_logp_and_grad_nd(&zp);
            let (_, gm) = target.compute_logp_and_grad_nd(&zm);
            // Column j of the negated Jacobian of the gradient.
            for a in 0..p {
                hz[[a, j]] = -(gp[a] - gm[a]) / (2.0 * eps);
            }
        }

        for a in 0..p {
            for b in 0..p {
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!(
                    (hz[[a, b]] - expected).abs() < 1e-4,
                    "z-curvature[{a},{b}] = {} (expected {expected}); a non-identity \
                     value means the GammaLog target re-introduced the #680 dispersion \
                     double-count (penalty ×ν and/or whitening ×√φ)",
                    hz[[a, b]]
                );
            }
        }
        // Trace = p (identity) rejects the φ-scaled `φ·tr(...)` signature.
        let trace: f64 = (0..p).map(|i| hz[[i, i]]).sum();
        assert!(
            (trace - p as f64).abs() < 1e-3,
            "z-curvature trace {trace} ≠ {p}: dispersion double-count signature"
        );
    }
2205
    /// Regression for #680 (whitening half, isolated): for a weight-carries-
    /// dispersion family the whitening must satisfy `L Lᵀ = H⁻¹` — i.e.
    /// `cov_scale = 1` — so the sampler whitens against the same `H⁻¹` it
    /// targets. The pre-fix Gamma path scaled `L` by `√φ`, giving
    /// `L Lᵀ = φ·H⁻¹` and `chol·cholᵀ·H = φ·I ≠ I`.
    #[test]
    fn gamma_log_nuts_whitening_targets_unscaled_inverse_hessian_issue_680() {
        let x = array![[1.0, -0.4], [1.0, 0.6], [1.0, 0.1], [1.0, 1.3]];
        let mode = array![0.2_f64, 0.3_f64];
        let y = array![0.8_f64, 1.7, 1.1, 2.2];
        let weights = array![1.0_f64, 1.0, 1.5, 0.7];
        let shape = 6.25_f64; // φ = 0.16
        let p = x.ncols();
        let s = array![[0.3_f64, 0.0], [0.0, 0.7]];
        // True penalized curvature at the mode: observed information + S.
        let hessian = &gamma_log_observed_information(&x, &mode, &y, &weights, shape) + &s;

        let target = NutsPosterior::new(
            x.view(),
            y.view(),
            weights.view(),
            s.view(),
            mode.view(),
            hessian.view(),
            NutsFamily::GammaLog,
            shape,
            gam_solve::estimate::Dispersion::Estimated(1.0 / shape),
            false,
        )
        .expect("GammaLog NUTS target builds");

        // chol = L with L Lᵀ = H⁻¹  ⇒  (L Lᵀ) H = I.
        let l = target.chol();
        let llt = l.dot(&l.t());
        let prod = llt.dot(&hessian);
        for a in 0..p {
            for b in 0..p {
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!(
                    (prod[[a, b]] - expected).abs() < 1e-8,
                    "L Lᵀ H[{a},{b}] = {} (expected {expected}); a φ·I result means \
                     the Gamma whitening still scales by √φ (#680)",
                    prod[[a, b]]
                );
            }
        }
    }
2252
    /// The Firth/Jeffreys term for Binomial-logit must stay finite when the
    /// design is rank-deficient. It must also return a finite gradient with one
    /// entry per design column.
    #[test]
    fn firth_jeffreys_logit_is_finite_for_rank_deficient_design() {
        // Columns 0 and 2 are identical, so XᵀWX is singular.
        let x = array![
            [1.0, -0.5, 1.0],
            [1.0, 0.3, 1.0],
            [1.0, 0.8, 1.0],
            [1.0, -1.2, 1.0],
        ];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let weights = array![1.0, 2.0, 0.5, 1.5];
        let eta = array![0.2, -0.1, 0.4, -0.3];

        let data = SharedData {
            x: Arc::new(x.clone()),
            y: Arc::new(y),
            weights: Arc::new(weights.clone()),
            mode: Arc::new(Array1::zeros(x.ncols())),
            offset: None,
            gamma_shape: 1.0,
            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
            n_samples: x.nrows(),
            dim: x.ncols(),
        };

        let (value, grad) =
            firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &data, &eta).expect("firth");

        assert!(value.is_finite());
        assert_eq!(grad.len(), x.ncols());
        assert!(grad.iter().all(|v| v.is_finite()));
    }
2284
    /// Smoke test for the Pólya-Gamma Gibbs sampler on a tiny logit problem.
    /// The draws must have shape `(n_samples · n_chains) × p`, and every draw
    /// and summary must be finite.
    #[test]
    fn logit_pg_gibbs_returns_finite_samples() {
        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let w = array![1.0, 1.0, 1.0, 1.0];
        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
        let mode = array![0.0, 0.0];
        // Short, fixed-seed run to keep the test fast and deterministic.
        let cfg = NutsConfig {
            n_samples: 30,
            nwarmup: 30,
            n_chains: 2,
            target_accept: 0.8,
            seed: 123,
        };
        let out = run_logit_polya_gamma_gibbs(
            x.view(),
            y.view(),
            w.view(),
            penalty.view(),
            mode.view(),
            &cfg,
        )
        .expect("pg gibbs should run");
        assert_eq!(out.samples.ncols(), 2);
        assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
        assert!(out.samples.iter().all(|v| v.is_finite()));
        assert!(out.posterior_mean.iter().all(|v| v.is_finite()));
        assert!(out.posterior_std.iter().all(|v| v.is_finite()));
    }
2314
2315    #[test]
2316    fn family_pg_dispatch_rejects_non_bernoulli_response() {
2317        let x = array![[1.0], [1.0]];
2318        let y = array![2.0, 0.0];
2319        let w = array![1.0, 1.0];
2320        let penalty = array![[0.1]];
2321        let mode = array![0.0];
2322        let non_spd_hessian = array![[0.0]];
2323        let cfg = NutsConfig {
2324            n_samples: 1,
2325            nwarmup: 1,
2326            n_chains: 1,
2327            target_accept: 0.8,
2328            seed: 321,
2329        };
2330
2331        let result = run_nuts_sampling_flattened_family(
2332            LikelihoodSpec::binomial_logit(),
2333            FamilyNutsInputs::Glm(GlmFlatInputs {
2334                x: x.view(),
2335                y: y.view(),
2336                weights: w.view(),
2337                penalty_matrix: penalty.view(),
2338                mode: mode.view(),
2339                hessian: non_spd_hessian.view(),
2340                gamma_shape: None,
2341                dispersion: gam_solve::model_types::Dispersion::Known(1.0),
2342                firth_bias_reduction: false,
2343                offset: None,
2344            }),
2345            &cfg,
2346        );
2347
2348        let err = result.err().expect("PG dispatch should reject count rows");
2349        assert!(
2350            err.contains("response must be exactly 0 or 1"),
2351            "unexpected error: {err}"
2352        );
2353    }
2354
2355    #[test]
2356    fn family_dispatch_uses_pg_gibbs_for_standard_logit() {
2357        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2358        let y = array![1.0, 0.0, 1.0, 0.0];
2359        let w = array![1.0, 1.0, 1.0, 1.0];
2360        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2361        let mode = array![0.0, 0.0];
2362        let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2363        let cfg = NutsConfig {
2364            n_samples: 20,
2365            nwarmup: 20,
2366            n_chains: 2,
2367            target_accept: 0.8,
2368            seed: 456,
2369        };
2370        let out = run_nuts_sampling_flattened_family(
2371            LikelihoodSpec {
2372                response: ResponseFamily::Binomial,
2373                link: InverseLink::Standard(StandardLink::Logit),
2374            },
2375            FamilyNutsInputs::Glm(GlmFlatInputs {
2376                x: x.view(),
2377                y: y.view(),
2378                weights: w.view(),
2379                penalty_matrix: penalty.view(),
2380                mode: mode.view(),
2381                hessian: non_spdhessian.view(),
2382                gamma_shape: None,
2383                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2384                firth_bias_reduction: false,
2385                offset: None,
2386            }),
2387            &cfg,
2388        )
2389        .expect("dispatch should use PG Gibbs and not require Hessian factorization");
2390        assert_eq!(out.samples.nrows(), cfg.n_samples * cfg.n_chains);
2391        assert!(out.samples.iter().all(|v| v.is_finite()));
2392    }
2393
2394    #[test]
2395    fn family_dispatch_routes_probit_to_nuts_path() {
2396        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2397        let y = array![1.0, 0.0, 1.0, 0.0];
2398        let w = array![1.0, 1.0, 1.0, 1.0];
2399        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2400        let mode = array![0.0, 0.0];
2401        let non_spdhessian = array![[0.0, 0.0], [0.0, 0.0]];
2402        let cfg = NutsConfig {
2403            n_samples: 20,
2404            nwarmup: 20,
2405            n_chains: 2,
2406            target_accept: 0.8,
2407            seed: 654,
2408        };
2409
2410        let err = match run_nuts_sampling_flattened_family(
2411            LikelihoodSpec {
2412                response: ResponseFamily::Binomial,
2413                link: InverseLink::Standard(StandardLink::Probit),
2414            },
2415            FamilyNutsInputs::Glm(GlmFlatInputs {
2416                x: x.view(),
2417                y: y.view(),
2418                weights: w.view(),
2419                penalty_matrix: penalty.view(),
2420                mode: mode.view(),
2421                hessian: non_spdhessian.view(),
2422                gamma_shape: None,
2423                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2424                firth_bias_reduction: false,
2425                offset: None,
2426            }),
2427            &cfg,
2428        ) {
2429            Ok(_) => panic!("non-SPD Hessian should fail after probit routes to the NUTS path"),
2430            Err(err) => err,
2431        };
2432
2433        assert!(
2434            err.contains("Hessian Cholesky decomposition failed"),
2435            "unexpected error: {err}"
2436        );
2437    }
2438
2439    #[test]
2440    fn family_dispatch_rejects_nonbinomial_firth_family() {
2441        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2442        let y = array![1.0, 2.0, 0.0, 3.0];
2443        let w = array![1.0, 1.0, 1.0, 1.0];
2444        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
2445        let mode = array![0.0, 0.0];
2446        let hessian = array![[1.5, 0.1], [0.1, 1.2]];
2447        let cfg = NutsConfig {
2448            n_samples: 20,
2449            nwarmup: 20,
2450            n_chains: 2,
2451            target_accept: 0.8,
2452            seed: 111,
2453        };
2454
2455        let err = match run_nuts_sampling_flattened_family(
2456            LikelihoodSpec {
2457                response: ResponseFamily::Poisson,
2458                link: InverseLink::Standard(StandardLink::Log),
2459            },
2460            FamilyNutsInputs::Glm(GlmFlatInputs {
2461                x: x.view(),
2462                y: y.view(),
2463                weights: w.view(),
2464                penalty_matrix: penalty.view(),
2465                mode: mode.view(),
2466                hessian: hessian.view(),
2467                gamma_shape: None,
2468                dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2469                firth_bias_reduction: true,
2470                offset: None,
2471            }),
2472            &cfg,
2473        ) {
2474            Ok(_) => panic!("Poisson Firth should be rejected explicitly"),
2475            Err(err) => err,
2476        };
2477
2478        assert!(
2479            err.contains(
2480                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet"
2481            ),
2482            "unexpected error: {err}"
2483        );
2484    }
2485
2486    #[test]
2487    fn run_nuts_sampling_rejects_invalid_target_accept() {
2488        let x = array![[1.0], [1.0], [1.0]];
2489        let y = array![0.5, -0.5, 1.0];
2490        let weights = array![1.0, 1.0, 1.0];
2491        let penalty = array![[0.25]];
2492        let mode = array![0.0];
2493        let hessian = array![[1.25]];
2494        let cfg = NutsConfig {
2495            n_samples: 10,
2496            nwarmup: 10,
2497            n_chains: 1,
2498            target_accept: 1.0,
2499            seed: 222,
2500        };
2501
2502        let err = super::run_nuts_sampling(
2503            x.view(),
2504            y.view(),
2505            weights.view(),
2506            penalty.view(),
2507            mode.view(),
2508            hessian.view(),
2509            NutsFamily::Gaussian,
2510            1.0,
2511            gam_solve::estimate::Dispersion::Known(1.0),
2512            false,
2513            None,
2514            &cfg,
2515        )
2516        .expect_err("invalid target_accept should be rejected before sampling");
2517
2518        assert!(
2519            err.contains("target_accept must be finite and lie in (0, 1)"),
2520            "unexpected error: {err}"
2521        );
2522    }
2523
2524    #[test]
2525    fn run_nuts_sampling_rejects_zero_or_too_few_samples() {
2526        // Issue #399: `samples=0` (and `samples` in {1, 2, 3}) reached the
2527        // engine and panicked across the FFI boundary in `general-mcmc`'s
2528        // `.expect(...)` (empty stack / "split R-hat and ESS require at least 2
2529        // split chains and 2 draws per split chain"). The up-front guard must
2530        // reject anything below the split-R-hat-defined minimum of 4 draws with
2531        // a clean typed error *before* the sampler is constructed.
2532        let x = array![[1.0], [1.0], [1.0]];
2533        let y = array![0.5, -0.5, 1.0];
2534        let weights = array![1.0, 1.0, 1.0];
2535        let penalty = array![[0.25]];
2536        let mode = array![0.0];
2537        let hessian = array![[1.25]];
2538
2539        for bad_samples in [0usize, 1, 2, 3] {
2540            let cfg = NutsConfig {
2541                n_samples: bad_samples,
2542                nwarmup: 10,
2543                n_chains: 2,
2544                target_accept: 0.8,
2545                seed: 222,
2546            };
2547
2548            let err = super::run_nuts_sampling(
2549                x.view(),
2550                y.view(),
2551                weights.view(),
2552                penalty.view(),
2553                mode.view(),
2554                hessian.view(),
2555                NutsFamily::Gaussian,
2556                1.0,
2557                gam_solve::estimate::Dispersion::Known(1.0),
2558                false,
2559                None,
2560                &cfg,
2561            )
2562            .expect_err("too-few samples must be rejected before sampling");
2563
2564            assert!(
2565                err.contains("n_samples must be >= 4"),
2566                "n_samples={bad_samples} gave unexpected error: {err}"
2567            );
2568        }
2569    }
2570
2571    #[test]
2572    fn polya_gamma_gibbs_rejects_degenerate_counts_but_accepts_single_chain() {
2573        // Issue #399 (missed path): the canonical unit-weight Bernoulli-logit
2574        // GAM auto-selects the hand-rolled Pólya-Gamma Gibbs sampler, NOT the
2575        // general-mcmc NUTS engine. Pre-fix that path never validated
2576        // n_samples/n_chains, so `chains=0` / `samples=0` silently returned a
2577        // degenerate empty `(0, p)` posterior instead of the typed error the
2578        // NUTS path raised — a divergent contract on one public API. Assert PG
2579        // now rejects the degenerate counts up front, and (mirroring NUTS)
2580        // still accepts a single chain.
2581        let x = array![[1.0], [1.0], [1.0], [1.0]];
2582        let y = array![1.0, 0.0, 1.0, 0.0];
2583        let weights = array![1.0, 1.0, 1.0, 1.0];
2584        let penalty = array![[0.25]];
2585        let mode = array![0.0];
2586
2587        let zero_chain_cfg = NutsConfig {
2588            n_samples: 20,
2589            nwarmup: 10,
2590            n_chains: 0,
2591            target_accept: 0.8,
2592            seed: 7,
2593        };
2594        let err = super::run_logit_polya_gamma_gibbs(
2595            x.view(),
2596            y.view(),
2597            weights.view(),
2598            penalty.view(),
2599            mode.view(),
2600            &zero_chain_cfg,
2601        )
2602        .expect_err("PG Gibbs must reject zero chains up front, not return an empty posterior");
2603        assert!(
2604            err.contains("n_chains must be >= 1"),
2605            "PG n_chains=0 gave unexpected error: {err}"
2606        );
2607
2608        let zero_sample_cfg = NutsConfig {
2609            n_samples: 0,
2610            nwarmup: 10,
2611            n_chains: 2,
2612            target_accept: 0.8,
2613            seed: 7,
2614        };
2615        let err = super::run_logit_polya_gamma_gibbs(
2616            x.view(),
2617            y.view(),
2618            weights.view(),
2619            penalty.view(),
2620            mode.view(),
2621            &zero_sample_cfg,
2622        )
2623        .expect_err("PG Gibbs must reject zero samples up front, not return an empty posterior");
2624        assert!(
2625            err.contains("n_samples must be >= 4"),
2626            "PG n_samples=0 gave unexpected error: {err}"
2627        );
2628
2629        let single_chain_cfg = NutsConfig {
2630            n_samples: 20,
2631            nwarmup: 10,
2632            n_chains: 1,
2633            target_accept: 0.8,
2634            seed: 7,
2635        };
2636        let result = super::run_logit_polya_gamma_gibbs(
2637            x.view(),
2638            y.view(),
2639            weights.view(),
2640            penalty.view(),
2641            mode.view(),
2642            &single_chain_cfg,
2643        )
2644        .expect("PG Gibbs must accept a single chain and return draws");
2645        assert_eq!(
2646            result.samples.nrows(),
2647            20,
2648            "single-chain PG run should return all 20 requested draws"
2649        );
2650    }
2651
2652    #[test]
2653    fn run_nuts_sampling_rejects_zero_chains_but_accepts_single_chain() {
2654        // Issue #399: only `chains=0` is degenerate — it produces an empty
2655        // initial-position vector and panics in `ndarray::stack`, so it must be
2656        // rejected up front with a typed error.
2657        //
2658        // A *single* chain, by contrast, is a supported, tested configuration
2659        // (`tests/test_sample_seed_is_reproducible.py`,
2660        // `tests/test_posterior_save_no_extension_roundtrip.py`,
2661        // `tests/test_penalty_sampling_survival_diagnostics_regressions.py` all
2662        // sample with `chains=1`): the engine splits each chain in half, so one
2663        // chain still yields the two split-chains the R-hat path needs, and
2664        // `compute_split_rhat_and_ess` early-returns gracefully for
2665        // `n_chains < 2`. The original #399 fix wrongly raised the floor to 2
2666        // and regressed those tests; this asserts `chains=1` *returns draws*.
2667        let x = array![[1.0], [1.0], [1.0]];
2668        let y = array![0.5, -0.5, 1.0];
2669        let weights = array![1.0, 1.0, 1.0];
2670        let penalty = array![[0.25]];
2671        let mode = array![0.0];
2672        let hessian = array![[1.25]];
2673
2674        let zero_chain_cfg = NutsConfig {
2675            n_samples: 50,
2676            nwarmup: 10,
2677            n_chains: 0,
2678            target_accept: 0.8,
2679            seed: 222,
2680        };
2681        let err = super::run_nuts_sampling(
2682            x.view(),
2683            y.view(),
2684            weights.view(),
2685            penalty.view(),
2686            mode.view(),
2687            hessian.view(),
2688            NutsFamily::Gaussian,
2689            1.0,
2690            gam_solve::estimate::Dispersion::Known(1.0),
2691            false,
2692            None,
2693            &zero_chain_cfg,
2694        )
2695        .expect_err("zero chains must be rejected before sampling");
2696        assert!(
2697            err.contains("n_chains must be >= 1"),
2698            "n_chains=0 gave unexpected error: {err}"
2699        );
2700
2701        let single_chain_cfg = NutsConfig {
2702            n_samples: 50,
2703            nwarmup: 10,
2704            n_chains: 1,
2705            target_accept: 0.8,
2706            seed: 222,
2707        };
2708        let result = super::run_nuts_sampling(
2709            x.view(),
2710            y.view(),
2711            weights.view(),
2712            penalty.view(),
2713            mode.view(),
2714            hessian.view(),
2715            NutsFamily::Gaussian,
2716            1.0,
2717            gam_solve::estimate::Dispersion::Known(1.0),
2718            false,
2719            None,
2720            &single_chain_cfg,
2721        )
2722        .expect("a single chain is a supported configuration and must return draws");
2723        assert_eq!(
2724            result.samples.nrows(),
2725            50,
2726            "single-chain run should return all 50 requested draws"
2727        );
2728    }
2729
2730    #[test]
2731    fn joint_hmc_boundary_rejects_nonbinomial_firth_family() {
2732        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
2733        let y = array![1.0, 2.0, 0.0, 3.0];
2734        let w = array![1.0, 1.0, 1.0, 1.0];
2735        let hessian = array![[1.5, 0.1], [0.1, 1.2]];
2736        let penalty_root = array![[0.4, 0.0], [0.0, 0.6]];
2737        let mode = array![0.0, 0.0];
2738        let rho_mode = array![0.0];
2739        let cfg = NutsConfig {
2740            n_samples: 20,
2741            nwarmup: 20,
2742            n_chains: 2,
2743            target_accept: 0.8,
2744            seed: 111,
2745        };
2746
2747        let inputs = JointBetaRhoInputs {
2748            x: x.view(),
2749            y: y.view(),
2750            weights: w.view(),
2751            likelihood: LikelihoodSpec {
2752                response: ResponseFamily::Poisson,
2753                link: InverseLink::Standard(StandardLink::Log),
2754            },
2755            gamma_shape: None,
2756            mode: mode.view(),
2757            hessian: hessian.view(),
2758            penalty_roots: vec![CanonicalPenalty::from_dense_root(
2759                penalty_root.clone(),
2760                penalty_root.ncols(),
2761            )],
2762            rho_mode: rho_mode.view(),
2763            rho_prior: RhoPrior::default(),
2764            firth_bias_reduction: true,
2765            trigger_skewness: 0.75,
2766        };
2767
2768        let err = match run_joint_beta_rho_sampling(&inputs, &cfg) {
2769            Ok(_) => panic!("Poisson joint HMC Firth should be rejected explicitly"),
2770            Err(err) => err,
2771        };
2772
2773        assert!(
2774            err.contains(
2775                "Joint HMC with Firth requires a Binomial inverse link with a Fisher-weight jet"
2776            ),
2777            "unexpected error: {err}"
2778        );
2779    }
2780
2781    #[test]
2782    fn joint_hmc_uses_combined_penalty_logdet_for_overlapping_penalties() {
2783        let x = array![[0.0, 0.0]];
2784        let y = array![0.0];
2785        let w = array![0.0];
2786        let mode = array![0.0, 0.0];
2787        let hessian = array![[1.0, 0.0], [0.0, 1.0]];
2788        let rho_mode = array![0.0, 0.0];
2789        let penalty_1 = array![[1.0, 0.0], [0.0, 1.0]];
2790        let penalty_2 = array![[2.0_f64.sqrt(), 0.0], [0.0, 1.0]];
2791        let target = JointBetaRhoPosterior::new(
2792            x.view(),
2793            y.view(),
2794            w.view(),
2795            mode.view(),
2796            hessian.view(),
2797            vec![
2798                CanonicalPenalty::from_dense_root(penalty_1, 2),
2799                CanonicalPenalty::from_dense_root(penalty_2, 2),
2800            ],
2801            rho_mode.view(),
2802            LikelihoodSpec {
2803                response: ResponseFamily::Gaussian,
2804                link: InverseLink::Standard(StandardLink::Identity),
2805            },
2806            None,
2807            RhoPrior::Flat,
2808            false,
2809        )
2810        .expect("joint target");
2811
2812        let params = array![0.0, 0.0, 0.0, 0.0];
2813        let (_, grad) = target.compute_joint_logp_and_grad(&params);
2814        assert!(
2815            (grad[2] - 5.0 / 12.0).abs() < 1.0e-10,
2816            "expected overlapping-penalty gradient 5/12, got {}",
2817            grad[2]
2818        );
2819        assert!(
2820            (grad[3] - 7.0 / 12.0).abs() < 1.0e-10,
2821            "expected overlapping-penalty gradient 7/12, got {}",
2822            grad[3]
2823        );
2824    }
2825
2826    #[test]
2827    fn joint_hmc_target_does_not_depend_on_rho_mode_when_prior_is_fixed() {
2828        let x = array![[0.0]];
2829        let y = array![0.0];
2830        let w = array![0.0];
2831        let mode = array![0.0];
2832        let hessian = array![[1.0]];
2833        let penalty = CanonicalPenalty::from_dense_root(array![[1.0]], 1);
2834        let prior = RhoPrior::Normal {
2835            mean: 0.25,
2836            sd: 1.7,
2837        };
2838
2839        let target_a = JointBetaRhoPosterior::new(
2840            x.view(),
2841            y.view(),
2842            w.view(),
2843            mode.view(),
2844            hessian.view(),
2845            vec![penalty.clone()],
2846            array![0.0].view(),
2847            LikelihoodSpec {
2848                response: ResponseFamily::Gaussian,
2849                link: InverseLink::Standard(StandardLink::Identity),
2850            },
2851            None,
2852            prior.clone(),
2853            false,
2854        )
2855        .expect("target a");
2856        let target_b = JointBetaRhoPosterior::new(
2857            x.view(),
2858            y.view(),
2859            w.view(),
2860            mode.view(),
2861            hessian.view(),
2862            vec![penalty],
2863            array![2.5].view(),
2864            LikelihoodSpec {
2865                response: ResponseFamily::Gaussian,
2866                link: InverseLink::Standard(StandardLink::Identity),
2867            },
2868            None,
2869            prior,
2870            false,
2871        )
2872        .expect("target b");
2873
2874        let params = array![0.0, -0.4];
2875        let (lp_a, grad_a) = target_a.compute_joint_logp_and_grad(&params);
2876        let (lp_b, grad_b) = target_b.compute_joint_logp_and_grad(&params);
2877        assert!((lp_a - lp_b).abs() < 1.0e-12);
2878        for i in 0..grad_a.len() {
2879            assert!(
2880                (grad_a[i] - grad_b[i]).abs() < 1.0e-12,
2881                "rho_mode leaked into target gradient at {}: {} vs {}",
2882                i,
2883                grad_a[i],
2884                grad_b[i]
2885            );
2886        }
2887    }
2888
2889    #[test]
2890    fn joint_hmc_binomial_sas_uses_runtime_link_state() {
2891        let x = array![[1.0], [1.0]];
2892        let y = array![1.0, 0.0];
2893        let weights = array![1.0, 1.0];
2894        let eta = array![0.3, -0.2];
2895        let sas_state = gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
2896            initial_epsilon: 0.4,
2897            initial_log_delta: -0.2,
2898        })
2899        .expect("sas state");
2900        let data = SharedData {
2901            x: Arc::new(x),
2902            y: Arc::new(y),
2903            weights: Arc::new(weights),
2904            mode: Arc::new(Array1::zeros(1)),
2905            offset: None,
2906            gamma_shape: 1.0,
2907            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
2908            n_samples: 2,
2909            dim: 1,
2910        };
2911
2912        let (ll_sas, _) = joint_family_logp_and_grad(
2913            &LikelihoodSpec {
2914                response: ResponseFamily::Binomial,
2915                link: InverseLink::Sas(sas_state),
2916            },
2917            &data,
2918            &eta,
2919        )
2920        .expect("sas joint logp");
2921        let (ll_logit, _) = joint_family_logp_and_grad(
2922            &LikelihoodSpec {
2923                response: ResponseFamily::Binomial,
2924                link: InverseLink::Standard(StandardLink::Logit),
2925            },
2926            &data,
2927            &eta,
2928        )
2929        .expect("logit joint logp");
2930
2931        assert!(
2932            (ll_sas - ll_logit).abs() > 1.0e-6,
2933            "adaptive SAS link should not collapse to the logit likelihood"
2934        );
2935    }
2936
2937    #[test]
2938    fn directional_cubic_diagnostic_is_rotation_invariant_for_hessian_eigenvectors() {
2939        let x = array![[1.0, 0.5], [-0.3, 1.4], [0.8, -1.1]];
2940        let c = array![0.7, -0.5, 0.2];
2941        let h = array![[4.0, 0.0], [0.0, 1.0]];
2942        let theta = std::f64::consts::FRAC_PI_4;
2943        let q = array![[theta.cos(), -theta.sin()], [theta.sin(), theta.cos()],];
2944        let x_rot = x.dot(&q);
2945        let h_rot = q.t().dot(&h).dot(&q);
2946
2947        let (base_max, base_vals) = laplace_directional_cubic_diagnostic(
2948            &h,
2949            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
2950            &c,
2951            true,
2952        )
2953        .expect("base diagnostic");
2954        let (rot_max, rot_vals) = laplace_directional_cubic_diagnostic(
2955            &h_rot,
2956            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x_rot)),
2957            &c,
2958            true,
2959        )
2960        .expect("rotated diagnostic");
2961
2962        let mut base_abs: Vec<f64> = base_vals.iter().map(|v| v.abs()).collect();
2963        let mut rot_abs: Vec<f64> = rot_vals.iter().map(|v| v.abs()).collect();
2964        base_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
2965        rot_abs.sort_by(|a, b| a.partial_cmp(b).expect("finite compare"));
2966
2967        assert!((base_max - rot_max).abs() < 1.0e-10);
2968        for i in 0..base_abs.len() {
2969            assert!(
2970                (base_abs[i] - rot_abs[i]).abs() < 1.0e-10,
2971                "directional diagnostic changed under rotation at {}: {} vs {}",
2972                i,
2973                base_abs[i],
2974                rot_abs[i]
2975            );
2976        }
2977    }
2978
2979    /// Verify that joint HMC and REML compute identical penalty logdet
2980    /// derivatives for the same penalty system. This catches any divergence
2981    /// between the two code paths.
2982    #[test]
2983    fn joint_hmc_penalty_logdet_agrees_with_reml_path() {
2984        use gam_solve::estimate::reml::penalty_logdet::PenaltyPseudologdet;
2985
2986        // Two overlapping 3x3 penalties with non-trivial lambdas.
2987        let root_1 = array![[1.0, 0.5, 0.0], [0.0, 0.8, 0.3]];
2988        let root_2 = array![[0.0, 0.7, 0.0], [0.0, 0.0, 1.2]];
2989        let cp1 = CanonicalPenalty::from_dense_root(root_1, 3);
2990        let cp2 = CanonicalPenalty::from_dense_root(root_2, 3);
2991        let lambdas = [2.5_f64, 0.8];
2992        let penalties = [cp1.clone(), cp2.clone()];
2993
2994        // REML path: PenaltyPseudologdet directly.
2995        let pld =
2996            PenaltyPseudologdet::from_penalties(&penalties, &lambdas, 0.0, 3).expect("reml pld");
2997        let reml_value = pld.value();
2998        let (reml_d1, reml_d2) = pld.rho_derivatives_from_penalties(&penalties, &lambdas);
2999
3000        // Joint HMC path: build a JointBetaRhoPosterior and extract the
3001        // penalty logdet contribution. We isolate it by using zero data
3002        // (so likelihood = 0, penalty quadratic = 0) and Flat rho prior.
3003        let x = Array2::<f64>::zeros((1, 3));
3004        let y = array![0.0];
3005        let w = array![0.0];
3006        let mode = Array1::<f64>::zeros(3);
3007        let hessian = Array2::<f64>::eye(3);
3008        let rho = Array1::from_vec(lambdas.iter().map(|l| l.ln()).collect());
3009        let target = JointBetaRhoPosterior::new(
3010            x.view(),
3011            y.view(),
3012            w.view(),
3013            mode.view(),
3014            hessian.view(),
3015            vec![cp1, cp2],
3016            rho.view(),
3017            LikelihoodSpec {
3018                response: ResponseFamily::Gaussian,
3019                link: InverseLink::Standard(StandardLink::Identity),
3020            },
3021            None,
3022            RhoPrior::Flat,
3023            false,
3024        )
3025        .expect("joint target");
3026
3027        // Evaluate at beta=0, rho=ln(lambdas).
3028        let mut params = Array1::<f64>::zeros(3 + 2);
3029        params[3] = rho[0];
3030        params[4] = rho[1];
3031        let (logp, grad) = target.compute_joint_logp_and_grad(&params);
3032
3033        // logp should be 0.5 * reml_value (likelihood=0, prior=0, quadratic=0).
3034        assert!(
3035            (logp - 0.5 * reml_value).abs() < 1.0e-8,
3036            "joint HMC logdet value {} vs REML 0.5*{} = {}",
3037            logp,
3038            reml_value,
3039            0.5 * reml_value,
3040        );
3041
3042        // grad[3..5] should be 0.5 * reml_d1.
3043        for k in 0..2 {
3044            assert!(
3045                (grad[3 + k] - 0.5 * reml_d1[k]).abs() < 1.0e-8,
3046                "joint HMC logdet gradient[{}] = {} vs REML 0.5*{} = {}",
3047                k,
3048                grad[3 + k],
3049                reml_d1[k],
3050                0.5 * reml_d1[k],
3051            );
3052        }
3053
3054        // Sanity: second derivatives are available from REML but not directly
3055        // from a single HMC gradient call; just verify they're symmetric.
3056        assert!(
3057            (reml_d2[[0, 1]] - reml_d2[[1, 0]]).abs() < 1.0e-12,
3058            "REML penalty logdet Hessian not symmetric"
3059        );
3060    }
3061
3062    /// Verify the family-gating invariant: every LikelihoodSpec that
3063    /// joint_family_logp_and_grad accepts produces a result (not an error
3064    /// about missing implementation). Every family it rejects returns an
3065    /// explicit error. No family is silently remapped to a different one.
3066    #[test]
3067    fn joint_hmc_family_gating_never_remaps() {
3068        let data = SharedData {
3069            x: Arc::new(array![[1.0], [1.0]]),
3070            y: Arc::new(array![1.0, 0.0]),
3071            weights: Arc::new(array![1.0, 1.0]),
3072            mode: Arc::new(Array1::zeros(1)),
3073            offset: None,
3074            gamma_shape: 1.0,
3075            dispersion: gam_solve::estimate::Dispersion::Known(1.0),
3076            n_samples: 2,
3077            dim: 1,
3078        };
3079        let eta = array![0.1, -0.1];
3080
3081        // These families must succeed with their own inverse link.
3082        let accepted = [
3083            LikelihoodSpec {
3084                response: ResponseFamily::Binomial,
3085                link: InverseLink::Standard(StandardLink::Logit),
3086            },
3087            LikelihoodSpec {
3088                response: ResponseFamily::Binomial,
3089                link: InverseLink::Standard(StandardLink::Probit),
3090            },
3091            LikelihoodSpec {
3092                response: ResponseFamily::Binomial,
3093                link: InverseLink::Standard(StandardLink::CLogLog),
3094            },
3095            LikelihoodSpec {
3096                response: ResponseFamily::Gaussian,
3097                link: InverseLink::Standard(StandardLink::Identity),
3098            },
3099            LikelihoodSpec {
3100                response: ResponseFamily::Poisson,
3101                link: InverseLink::Standard(StandardLink::Log),
3102            },
3103            LikelihoodSpec {
3104                response: ResponseFamily::Gamma,
3105                link: InverseLink::Standard(StandardLink::Log),
3106            },
3107        ];
3108        for spec in &accepted {
3109            let result = joint_family_logp_and_grad(spec, &data, &eta);
3110            assert!(
3111                result.is_ok(),
3112                "spec {:?} should be accepted but got error: {:?}",
3113                spec,
3114                result.err(),
3115            );
3116        }
3117
3118        // SAS/BetaLogistic/Mixture must succeed with their real link state,
3119        // NOT be remapped to logit.
3120        let sas_state = gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
3121            initial_epsilon: 0.0,
3122            initial_log_delta: 0.0,
3123        })
3124        .expect("sas state");
3125        let adaptive = [
3126            LikelihoodSpec {
3127                response: ResponseFamily::Binomial,
3128                link: InverseLink::Sas(sas_state),
3129            },
3130            LikelihoodSpec {
3131                response: ResponseFamily::Binomial,
3132                link: InverseLink::BetaLogistic(
3133                    gam_solve::mixture_link::state_from_sasspec(gam_problem::types::SasLinkSpec {
3134                        initial_epsilon: 0.0,
3135                        initial_log_delta: 0.0,
3136                    })
3137                    .expect("bl state"),
3138                ),
3139            },
3140        ];
3141        for spec in &adaptive {
3142            let result = joint_family_logp_and_grad(spec, &data, &eta);
3143            assert!(
3144                result.is_ok(),
3145                "adaptive spec {:?} should be accepted with its real link",
3146                spec,
3147            );
3148        }
3149
3150        // RoystonParmar must be explicitly rejected (not silently remapped).
3151        let rp_result = joint_family_logp_and_grad(
3152            &LikelihoodSpec {
3153                response: ResponseFamily::RoystonParmar,
3154                link: InverseLink::Standard(StandardLink::Logit),
3155            },
3156            &data,
3157            &eta,
3158        );
3159        assert!(
3160            rp_result.is_err(),
3161            "RoystonParmar should be rejected, not silently accepted"
3162        );
3163    }
3164
3165    /// The power-iteration refinement should find non-Gaussianity at least
3166    /// as large as the eigenvector-only pass (it's a supremum search).
3167    #[test]
3168    fn directional_cubic_power_iteration_finds_larger_or_equal_skewness() {
3169        // Construct a design where the maximum |gamma| occurs off-axis.
3170        // A single row with asymmetric structure makes the cubic form
3171        // peak between eigenvectors.
3172        let x = array![
3173            [2.0, 1.0],
3174            [-1.0, 2.0],
3175            [0.5, -0.5],
3176            [1.5, 0.3],
3177            [-0.8, 1.7],
3178        ];
3179        let c = array![1.0, -0.5, 0.3, -0.7, 0.4];
3180        let h = array![[3.0, 1.0], [1.0, 2.0]];
3181
3182        let (max_val, eigenvector_vals) = laplace_directional_cubic_diagnostic(
3183            &h,
3184            &DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(x)),
3185            &c,
3186            true,
3187        )
3188        .expect("diagnostic");
3189
3190        // max_val should be >= max of eigenvector-only values.
3191        let eig_max = eigenvector_vals
3192            .iter()
3193            .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
3194        assert!(
3195            max_val >= eig_max - 1.0e-12,
3196            "power iteration result {} should be >= eigenvector max {}",
3197            max_val,
3198            eig_max,
3199        );
3200    }
3201
3202    #[test]
3203    fn laplace_trustworthiness_is_block_local_and_threshold_shrinks_with_n() {
3204        // Two directions: one nearly Gaussian (tiny skewness), one strongly
3205        // skewed. The adaptive verdict must flag ONLY the skewed direction —
3206        // this is the block-local behavior #784 requires (keep cheap Laplace
3207        // where the Gaussian summary holds, correct only the curvature-heavy
3208        // block).
3209        let skew = array![0.01, 0.9];
3210
3211        // At a modest effective sample size the skewed direction dominates the
3212        // Laplace floor and must be flagged; the near-Gaussian one must not.
3213        let verdict = laplace_trustworthiness_from_skewness(&skew, 100.0);
3214        assert_eq!(
3215            verdict.untrustworthy_directions,
3216            vec![1],
3217            "only the strongly-skewed direction should be flagged (block-local)",
3218        );
3219        assert!(verdict.fallback_required());
3220        assert!((verdict.max_abs_skewness - 0.9).abs() < 1e-12);
3221
3222        // The threshold must SHRINK as n grows (Laplace gets stricter): a
3223        // direction tolerated at small n becomes untrustworthy at large n,
3224        // because the Gaussian floor it must beat is O(1/n).
3225        let t_small = laplace_skewness_threshold(25.0);
3226        let t_large = laplace_skewness_threshold(10_000.0);
3227        assert!(
3228            t_large < t_small,
3229            "validity threshold must tighten with sample size: {t_large} !< {t_small}",
3230        );
3231
3232        // Degenerate / empty curvature support => everything trustworthy
3233        // (nothing for the Gaussian summary to be wrong about).
3234        let none = laplace_trustworthiness_from_skewness(&skew, 0.0);
3235        assert!(!none.fallback_required());
3236        assert!(none.threshold.is_infinite());
3237    }
3238
3239    /// Synthetic block-excess oracle: an anharmonicity `ΔF(t) = a·Σ_k t_k⁴`
3240    /// whose per-direction strength carries unit ρ-sensitivity, so
3241    /// `∂ΔF/∂ρ_k = a·t_k⁴`. `a = 0` is a pure Gaussian block (exactly zero
3242    /// excess and zero ρ-gradient — the consistency anchor); `a > 0` is the
3243    /// quartic correction oracle the importance sampler is checked against.
3244    struct AnharmonicBlock {
3245        lambdas: Array1<f64>,
3246        a: f64,
3247    }
3248    impl super::BlockExcessTarget for AnharmonicBlock {
3249        fn block_dim(&self) -> usize {
3250            self.lambdas.len()
3251        }
3252        fn rho_dim(&self) -> usize {
3253            self.lambdas.len()
3254        }
3255        fn block_curvatures(&self) -> &Array1<f64> {
3256            &self.lambdas
3257        }
3258        fn excess(&self, t: &Array1<f64>) -> f64 {
3259            self.a * t.iter().map(|&x| x.powi(4)).sum::<f64>()
3260        }
3261        fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3262            t.mapv(|x| self.a * x.powi(4))
3263        }
3264        fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3265            // The synthetic oracle has no observation rows: its ΔF carries no
3266            // deviance channel, so the per-row score moment is empty and the
3267            // (b)–(d) channel assembly contracts against nothing.
3268            assert_eq!(t.len(), self.block_dim(), "displacement dim mismatch");
3269            Array1::zeros(0)
3270        }
3271        fn base_neg_score(&self) -> Array1<f64> {
3272            Array1::zeros(0)
3273        }
3274    }
3275
3276    #[test]
3277    fn block_sampled_marginal_is_zero_for_gaussian_block() {
3278        // A purely Gaussian block has ΔF ≡ 0, so the sampled correction (the
3279        // log-ratio of true to Laplace block free energy) must be exactly 0,
3280        // with a zero ρ-gradient. This is the consistency anchor: where the
3281        // Gaussian summary holds, the fallback is a no-op.
3282        let target = AnharmonicBlock {
3283            lambdas: array![2.0, 0.5],
3284            a: 0.0,
3285        };
3286        let out = super::block_sampled_marginal_correction(&target).expect("correction");
3287        assert!(
3288            out.value.abs() < 1e-12,
3289            "Gaussian block value {}",
3290            out.value
3291        );
3292        assert!(out.rho_gradient.iter().all(|&g| g.abs() < 1e-12));
3293        assert!(out.n_draws > 0);
3294    }
3295
3296    #[test]
3297    fn block_sampled_marginal_recovers_analytic_quartic_correction() {
3298        // 1-D block with a quartic excess ΔF(t) = a t⁴ (a small positive
3299        // anharmonicity). Then exp(Δ_b) = E_{t~N(0,1/λ)}[exp(−a t⁴)], a known
3300        // 1-D integral the IS estimator must recover. We check the sampled Δ_b
3301        // matches a high-accuracy deterministic quadrature of the same
3302        // expectation, and that Δ_b < 0 (an added quartic penalty makes the
3303        // true block mass *smaller* than the Gaussian's).
3304        let lambda = 3.0_f64;
3305        let a = 0.05_f64;
3306        let target = AnharmonicBlock {
3307            lambdas: array![lambda],
3308            a,
3309        };
3310        let out = super::block_sampled_marginal_correction(&target).expect("correction");
3311
3312        // Deterministic reference: Δ_b = log E_{t~N(0,1/λ)}[exp(−a t⁴)] via a
3313        // fine trapezoid rule over the Gaussian density.
3314        let sigma = (1.0 / lambda).sqrt();
3315        let steps = 20_001;
3316        let lo = -8.0 * sigma;
3317        let hi = 8.0 * sigma;
3318        let h = (hi - lo) / (steps as f64 - 1.0);
3319        let mut integral = 0.0_f64;
3320        for i in 0..steps {
3321            let tt = lo + h * i as f64;
3322            let gauss = (-(tt * tt) / (2.0 * sigma * sigma)).exp()
3323                / (sigma * (2.0 * std::f64::consts::PI).sqrt());
3324            let w = if i == 0 || i == steps - 1 { 0.5 } else { 1.0 };
3325            integral += w * gauss * (-a * tt.powi(4)).exp() * h;
3326        }
3327        let reference = integral.ln();
3328        assert!(
3329            (out.value - reference).abs() < 5e-3,
3330            "sampled Δ_b {} vs reference {}",
3331            out.value,
3332            reference,
3333        );
3334        assert!(out.value < 0.0, "quartic penalty must shrink block mass");
3335    }
3336
3337    /// A block target whose excess and per-row score are driven by real design
3338    /// matvecs `s = X·(V_b·t)` — the SAME structure as the production
3339    /// `Gam784BlockTarget` — so it can compute those matvecs either serially
3340    /// (one `fast_av` per draw) or batched (one GEMM over all draws), toggled by
3341    /// `batched`. The two must yield a bit-for-bit (to FP-reassociation
3342    /// tolerance) identical correction: that is exactly the #1082 batching
3343    /// contract — GEMM changes HOW the matvec is computed, never WHAT.
3344    struct MatvecBlock {
3345        lambdas: Array1<f64>,
3346        x: Array2<f64>,
3347        v_b: Array2<f64>,
3348        y: Array1<f64>,
3349        batched: bool,
3350    }
3351    impl MatvecBlock {
3352        fn s_of(&self, t: &Array1<f64>) -> Array1<f64> {
3353            let delta = self.v_b.dot(t);
3354            gam_linalg::faer_ndarray::fast_av(&self.x, &delta)
3355        }
3356        // A smooth, finite, family-like excess + per-row score built from `s`.
3357        fn excess_and_ngs(&self, s: &Array1<f64>) -> (f64, Array1<f64>) {
3358            let mut excess = 0.0;
3359            let mut ngs = Array1::<f64>::zeros(s.len());
3360            for i in 0..s.len() {
3361                let mu = (self.y[i] + s[i]).tanh();
3362                excess += 0.5 * s[i] * s[i] - 0.1 * mu;
3363                ngs[i] = mu - self.y[i];
3364            }
3365            (excess, ngs)
3366        }
3367    }
3368    impl super::BlockExcessTarget for MatvecBlock {
3369        fn block_dim(&self) -> usize {
3370            self.lambdas.len()
3371        }
3372        fn rho_dim(&self) -> usize {
3373            self.lambdas.len()
3374        }
3375        fn block_curvatures(&self) -> &Array1<f64> {
3376            &self.lambdas
3377        }
3378        fn excess(&self, t: &Array1<f64>) -> f64 {
3379            self.excess_and_ngs(&self.s_of(t)).0
3380        }
3381        fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64> {
3382            t.mapv(|x| 0.01 * x)
3383        }
3384        fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
3385            self.excess_and_ngs(&self.s_of(t)).1
3386        }
3387        fn base_neg_score(&self) -> Array1<f64> {
3388            self.excess_and_ngs(&self.s_of(&Array1::zeros(self.block_dim())))
3389                .1
3390        }
3391        fn excess_with_displaced_neg_score_batch(
3392            &self,
3393            draws: &Array2<f64>,
3394        ) -> Vec<(f64, Option<Array1<f64>>)> {
3395            if !self.batched {
3396                // Serial reference: per-column, exactly the default path.
3397                let mut out = Vec::with_capacity(draws.ncols());
3398                let mut t = Array1::<f64>::zeros(draws.nrows());
3399                for s in 0..draws.ncols() {
3400                    t.assign(&draws.column(s));
3401                    out.push(self.excess_with_displaced_neg_score(&t));
3402                }
3403                return out;
3404            }
3405            // Batched: Δ = V_b·T then S = X·Δ as two GEMMs, then per-column.
3406            let delta_all = gam_linalg::faer_ndarray::fast_ab(&self.v_b, draws);
3407            let s_all = gam_linalg::faer_ndarray::fast_ab(&self.x, &delta_all);
3408            (0..draws.ncols())
3409                .map(|c| {
3410                    let (e, ngs) = self.excess_and_ngs(&s_all.column(c).to_owned());
3411                    if e.is_finite() {
3412                        (e, Some(ngs))
3413                    } else {
3414                        (e, None)
3415                    }
3416                })
3417                .collect()
3418        }
3419    }
3420
3421    #[test]
3422    fn block_sampled_marginal_batched_matches_serial_matvec() {
3423        // Real design / block-frame matvecs, large enough that the GEMM path is
3424        // actually taken (n, p ≥ faer threshold). The batched override must give
3425        // the same correction value, ρ-gradient, and moments as the serial path.
3426        let n = 80usize;
3427        let p = 40usize;
3428        let m = 3usize;
3429        let mut x = Array2::<f64>::zeros((n, p));
3430        for i in 0..n {
3431            for j in 0..p {
3432                x[(i, j)] = ((i * 7 + j * 13) % 11) as f64 * 0.05 - 0.25;
3433            }
3434        }
3435        let mut v_b = Array2::<f64>::zeros((p, m));
3436        for i in 0..p {
3437            for r in 0..m {
3438                v_b[(i, r)] = ((i * 3 + r * 5) % 7) as f64 * 0.1 - 0.3;
3439            }
3440        }
3441        let y: Array1<f64> = (0..n).map(|i| ((i % 5) as f64) * 0.2).collect();
3442        let lambdas = array![2.0, 1.0, 0.5];
3443
3444        let serial = super::block_sampled_marginal_correction(&MatvecBlock {
3445            lambdas: lambdas.clone(),
3446            x: x.clone(),
3447            v_b: v_b.clone(),
3448            y: y.clone(),
3449            batched: false,
3450        })
3451        .expect("serial");
3452        let batched = super::block_sampled_marginal_correction(&MatvecBlock {
3453            lambdas,
3454            x,
3455            v_b,
3456            y,
3457            batched: true,
3458        })
3459        .expect("batched");
3460
3461        assert_eq!(serial.n_draws, batched.n_draws);
3462        assert!(
3463            (serial.value - batched.value).abs() <= 1e-10 * (1.0 + serial.value.abs()),
3464            "value serial {} vs batched {}",
3465            serial.value,
3466            batched.value
3467        );
3468        for k in 0..serial.rho_gradient.len() {
3469            assert!(
3470                (serial.rho_gradient[k] - batched.rho_gradient[k]).abs()
3471                    <= 1e-10 * (1.0 + serial.rho_gradient[k].abs()),
3472                "rho_gradient[{k}] serial {} vs batched {}",
3473                serial.rho_gradient[k],
3474                batched.rho_gradient[k]
3475            );
3476        }
3477        let ms = serial.moments.expect("serial moments");
3478        let mb = batched.moments.expect("batched moments");
3479        for (a, b) in ms.e_t.iter().zip(mb.e_t.iter()) {
3480            assert!((a - b).abs() <= 1e-10 * (1.0 + a.abs()), "e_t {a} vs {b}");
3481        }
3482        for (a, b) in ms.e_neg_score.iter().zip(mb.e_neg_score.iter()) {
3483            assert!(
3484                (a - b).abs() <= 1e-10 * (1.0 + a.abs()),
3485                "e_neg_score {a} vs {b}"
3486            );
3487        }
3488        for (a, b) in ms.e_t_neg_score.iter().zip(mb.e_t_neg_score.iter()) {
3489            assert!(
3490                (a - b).abs() <= 1e-10 * (1.0 + a.abs()),
3491                "e_t_neg_score {a} vs {b}"
3492            );
3493        }
3494    }
3495
3496    #[test]
3497    fn logit_pg_rao_blackwell_returns_finite_terms() {
3498        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
3499        let y = array![1.0, 0.0, 1.0, 0.0];
3500        let w = array![1.0, 1.0, 1.0, 1.0];
3501        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
3502        let mode = array![0.0, 0.0];
3503        let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
3504        let cfg = NutsConfig {
3505            n_samples: 30,
3506            nwarmup: 30,
3507            n_chains: 2,
3508            target_accept: 0.8,
3509            seed: 789,
3510        };
3511
3512        let rb = super::estimate_logit_pg_rao_blackwell_terms(
3513            x.view(),
3514            y.view(),
3515            w.view(),
3516            penalty.view(),
3517            mode.view(),
3518            &roots,
3519            &cfg,
3520        )
3521        .expect("rao-blackwell PG should run");
3522
3523        assert_eq!(rb.len(), 1);
3524        assert!(rb[0].is_finite());
3525        assert!(rb[0] >= 0.0);
3526    }
3527
3528    #[test]
3529    fn logit_pg_rao_blackwell_rejects_non_bernoulli_response() {
3530        let x = array![[1.0], [1.0]];
3531        let y = array![0.25, 1.0];
3532        let w = array![1.0, 1.0];
3533        let penalty = array![[0.1]];
3534        let mode = array![0.0];
3535        let roots = vec![array![[0.1_f64.sqrt()]]];
3536        let cfg = NutsConfig {
3537            n_samples: 1,
3538            nwarmup: 1,
3539            n_chains: 1,
3540            target_accept: 0.8,
3541            seed: 654,
3542        };
3543
3544        let result = super::estimate_logit_pg_rao_blackwell_terms(
3545            x.view(),
3546            y.view(),
3547            w.view(),
3548            penalty.view(),
3549            mode.view(),
3550            &roots,
3551            &cfg,
3552        );
3553
3554        let err = result
3555            .err()
3556            .expect("PG Rao-Blackwell should reject proportion rows");
3557        assert!(
3558            err.contains("response must be exactly 0 or 1"),
3559            "unexpected error: {err}"
3560        );
3561    }
3562
3563    #[test]
3564    fn logit_pg_rao_blackwell_matches_beta_quadratic_moment_sanity() {
3565        let x = array![[1.0, 0.2], [1.0, -0.1], [1.0, 1.2], [1.0, -0.7]];
3566        let y = array![1.0, 0.0, 1.0, 0.0];
3567        let w = array![1.0, 1.0, 1.0, 1.0];
3568        let penalty = array![[0.2, 0.0], [0.0, 0.4]];
3569        let mode = array![0.0, 0.0];
3570        let roots = vec![array![[0.2_f64.sqrt(), 0.0], [0.0, 0.4_f64.sqrt()]]];
3571        let cfg = NutsConfig {
3572            n_samples: 120,
3573            nwarmup: 80,
3574            n_chains: 2,
3575            target_accept: 0.8,
3576            seed: 901,
3577        };
3578
3579        let gibbs = run_logit_polya_gamma_gibbs(
3580            x.view(),
3581            y.view(),
3582            w.view(),
3583            penalty.view(),
3584            mode.view(),
3585            &cfg,
3586        )
3587        .expect("pg gibbs should run");
3588        let mc_quad = gibbs
3589            .samples
3590            .rows()
3591            .into_iter()
3592            .map(|beta| {
3593                let sb = penalty.dot(&beta.to_owned());
3594                beta.dot(&sb)
3595            })
3596            .sum::<f64>()
3597            / (gibbs.samples.nrows() as f64);
3598
3599        let rb = super::estimate_logit_pg_rao_blackwell_terms(
3600            x.view(),
3601            y.view(),
3602            w.view(),
3603            penalty.view(),
3604            mode.view(),
3605            &roots,
3606            &cfg,
3607        )
3608        .expect("rao-blackwell PG should run");
3609
3610        let diff = (rb[0] - mc_quad).abs();
3611        assert!(
3612            diff < 0.35,
3613            "Rao-Blackwell vs beta-moment mismatch too large: rb={}, mc={}, diff={}",
3614            rb[0],
3615            mc_quad,
3616            diff
3617        );
3618    }
3619
3620    #[test]
3621    fn survival_hmc_structural_monotonic_returns_finitevalues() {
3622        let age_entry = array![1.0];
3623        let age_exit = array![2.0];
3624        let event_target = array![1u8];
3625        let event_competing = array![0u8];
3626        let sampleweight = array![1.0];
3627        let x_entry = array![[1.0, 0.2]];
3628        let x_exit = array![[1.0, 0.6]];
3629        let x_derivative = array![[0.0, 1.0]];
3630        let penalties = PenaltyBlocks::new(Vec::new());
3631        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3632        let mode = array![0.0, 0.0];
3633        let hessian = Array2::<f64>::eye(2);
3634
3635        let posterior = super::survival_hmc::SurvivalPosterior::new(
3636            age_entry.view(),
3637            age_exit.view(),
3638            event_target.view(),
3639            event_competing.view(),
3640            sampleweight.view(),
3641            x_entry.view(),
3642            x_exit.view(),
3643            x_derivative.view(),
3644            None,
3645            None,
3646            None,
3647            penalties,
3648            monotonicity,
3649            SurvivalSpec::Net,
3650            true,
3651            2,
3652            mode.view(),
3653            hessian.view(),
3654        )
3655        .expect("construct survival posterior");
3656
3657        let position = array![0.0, 0.0];
3658        let mut grad = Array1::<f64>::zeros(2);
3659        let logp = HamiltonianTarget::logp_and_grad(&posterior, &position, &mut grad);
3660        assert!(logp.is_finite());
3661        assert!(grad.iter().all(|v| v.is_finite()));
3662    }
3663
    /// Turning on the structural-monotonic geometry (`true, 2` in the
    /// constructor) must change the gradient relative to the linear geometry
    /// (`false, 0`) at the same position `z`, while both gradients stay finite.
    #[test]
    fn survival_hmc_structural_monotonic_differs_from_linear_geometry() {
        let age_entry = array![1.0];
        let age_exit = array![2.0];
        let event_target = array![1u8];
        let event_competing = array![0u8];
        let sampleweight = array![1.0];
        let x_entry = array![[0.2, 0.1]];
        let x_exit = array![[0.6, 0.3]];
        // The derivative row loads only on column 0.
        let x_derivative = array![[1.0, 0.0]];
        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
        let mode = array![0.0, 0.0];
        let hessian = Array2::<f64>::eye(2);
        // Evaluation point: ln 2 on the first coordinate, zero on the second.
        let z = array![std::f64::consts::LN_2, 0.0];

        // Linear geometry: structural flag off (`false, 0`).
        let posterior_linear = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            false,
            0,
            mode.view(),
            hessian.view(),
        )
        .expect("construct linear posterior");
        let mut grad_linear = Array1::<f64>::zeros(2);
        // The returned log-density is discarded; only the gradient is compared.
        HamiltonianTarget::logp_and_grad(&posterior_linear, &z, &mut grad_linear);

        // Structural geometry: identical inputs, structural flag on (`true, 2`).
        let posterior_struct = super::survival_hmc::SurvivalPosterior::new(
            age_entry.view(),
            age_exit.view(),
            event_target.view(),
            event_competing.view(),
            sampleweight.view(),
            x_entry.view(),
            x_exit.view(),
            x_derivative.view(),
            None,
            None,
            None,
            PenaltyBlocks::new(Vec::new()),
            monotonicity,
            SurvivalSpec::Net,
            true,
            2,
            mode.view(),
            hessian.view(),
        )
        .expect("construct structural posterior");
        let mut grad_struct = Array1::<f64>::zeros(2);
        HamiltonianTarget::logp_and_grad(&posterior_struct, &z, &mut grad_struct);

        // Only component 0 is compared. NOTE(review): presumably because the
        // derivative design loads only on column 0, so that is where the two
        // geometries are expected to diverge — confirm.
        assert!(
            (grad_struct[0] - grad_linear[0]).abs() > 1e-6,
            "expected structural and linear fallback gradients to differ"
        );
        assert!(grad_struct[0].is_finite());
        assert!(grad_linear[0].is_finite());
    }
3734
3735    #[test]
3736    fn survival_hmc_fallback_barrier_rejects_offsets_below_monotonicity_threshold() {
3737        let age_entry = array![1.0];
3738        let age_exit = array![2.0];
3739        let event_target = array![1u8];
3740        let event_competing = array![0u8];
3741        let sampleweight = array![1.0];
3742        let x_entry = array![[1.0, 0.0]];
3743        let x_exit = array![[1.0, 0.0]];
3744        // Zero derivative design so derivative_offset_exit drives d_eta/dt.
3745        let x_derivative = array![[0.0, 0.0]];
3746        let penalties = PenaltyBlocks::new(Vec::new());
3747        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3748        let mode = array![0.0, 0.0];
3749        let hessian = Array2::<f64>::eye(2);
3750        let z = array![0.0, 0.0];
3751
3752        let posterior_no_offset = super::survival_hmc::SurvivalPosterior::new(
3753            age_entry.view(),
3754            age_exit.view(),
3755            event_target.view(),
3756            event_competing.view(),
3757            sampleweight.view(),
3758            x_entry.view(),
3759            x_exit.view(),
3760            x_derivative.view(),
3761            None,
3762            None,
3763            Some(array![0.0].view()),
3764            penalties.clone(),
3765            monotonicity,
3766            SurvivalSpec::Net,
3767            false,
3768            0,
3769            mode.view(),
3770            hessian.view(),
3771        )
3772        .expect("construct posterior without derivative offset");
3773        let mut grad_no_offset = Array1::<f64>::zeros(2);
3774        let logp_no_offset =
3775            HamiltonianTarget::logp_and_grad(&posterior_no_offset, &z, &mut grad_no_offset);
3776
3777        let posteriorwith_offset = super::survival_hmc::SurvivalPosterior::new(
3778            age_entry.view(),
3779            age_exit.view(),
3780            event_target.view(),
3781            event_competing.view(),
3782            sampleweight.view(),
3783            x_entry.view(),
3784            x_exit.view(),
3785            x_derivative.view(),
3786            None,
3787            None,
3788            Some(array![2.0].view()),
3789            penalties,
3790            monotonicity,
3791            SurvivalSpec::Net,
3792            false,
3793            0,
3794            mode.view(),
3795            hessian.view(),
3796        )
3797        .expect("construct posterior with derivative offset");
3798        let mut gradwith_offset = Array1::<f64>::zeros(2);
3799        let logpwith_offset =
3800            HamiltonianTarget::logp_and_grad(&posteriorwith_offset, &z, &mut gradwith_offset);
3801
3802        assert!(!logp_no_offset.is_finite());
3803        assert!(!logpwith_offset.is_finite());
3804        assert!(grad_no_offset.iter().all(|v| *v == 0.0));
3805        assert!(gradwith_offset.iter().all(|v| *v == 0.0));
3806    }
3807
3808    #[test]
3809    fn survival_hmc_fallback_barrier_becomes_finite_once_offset_clears_guard() {
3810        let age_entry = array![1.0];
3811        let age_exit = array![2.0];
3812        let event_target = array![1u8];
3813        let event_competing = array![0u8];
3814        let sampleweight = array![1.0];
3815        let x_entry = array![[1.0, 0.0]];
3816        let x_exit = array![[1.0, 0.0]];
3817        let x_derivative = array![[0.0, 0.0]];
3818        let penalties = PenaltyBlocks::new(Vec::new());
3819        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3820        let mode = array![0.0, 0.0];
3821        let hessian = Array2::<f64>::eye(2);
3822        let z = array![0.0, 0.0];
3823
3824        let posterior_below_guard = super::survival_hmc::SurvivalPosterior::new(
3825            age_entry.view(),
3826            age_exit.view(),
3827            event_target.view(),
3828            event_competing.view(),
3829            sampleweight.view(),
3830            x_entry.view(),
3831            x_exit.view(),
3832            x_derivative.view(),
3833            None,
3834            None,
3835            Some(array![2.0].view()),
3836            penalties.clone(),
3837            monotonicity,
3838            SurvivalSpec::Net,
3839            false,
3840            0,
3841            mode.view(),
3842            hessian.view(),
3843        )
3844        .expect("construct posterior below derivative guard");
3845        let mut grad_below_guard = Array1::<f64>::zeros(2);
3846        let logp_below_guard =
3847            HamiltonianTarget::logp_and_grad(&posterior_below_guard, &z, &mut grad_below_guard);
3848
3849        let posterior_above_guard = super::survival_hmc::SurvivalPosterior::new(
3850            age_entry.view(),
3851            age_exit.view(),
3852            event_target.view(),
3853            event_competing.view(),
3854            sampleweight.view(),
3855            x_entry.view(),
3856            x_exit.view(),
3857            x_derivative.view(),
3858            None,
3859            None,
3860            Some(array![3.1].view()),
3861            penalties,
3862            monotonicity,
3863            SurvivalSpec::Net,
3864            false,
3865            0,
3866            mode.view(),
3867            hessian.view(),
3868        )
3869        .expect("construct posterior above derivative guard");
3870        let mut grad_above_guard = Array1::<f64>::zeros(2);
3871        let logp_above_guard =
3872            HamiltonianTarget::logp_and_grad(&posterior_above_guard, &z, &mut grad_above_guard);
3873
3874        assert!(!logp_below_guard.is_finite());
3875        assert!(logp_above_guard.is_finite());
3876        assert!(grad_below_guard.iter().all(|v| *v == 0.0));
3877        assert!(grad_above_guard.iter().all(|v| v.is_finite()));
3878    }
3879
3880    #[test]
3881    fn survival_hmc_structural_monotonic_handles_sparse_multirow_geometry() {
3882        let age_entry = array![1.0, 1.2];
3883        let age_exit = array![2.0, 2.4];
3884        let event_target = array![1u8, 1u8];
3885        let event_competing = array![0u8, 0u8];
3886        let sampleweight = array![1.0, 1.0];
3887        let x_entry = array![[0.1, 0.0, 0.2], [0.2, 0.1, 0.2]];
3888        let x_exit = array![[0.4, 0.2, 0.3], [0.6, 0.1, 0.3]];
3889        // First row constrains only column 0, second row constrains columns 0 and 1.
3890        let x_derivative = array![[1.0, 0.0, 0.0], [0.5, 1.0, 0.0]];
3891        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
3892        let mode = array![4.0, 2.0, 0.0];
3893        let hessian = Array2::<f64>::eye(3);
3894        let z = array![0.05, -0.1, 0.15];
3895
3896        let posterior = super::survival_hmc::SurvivalPosterior::new(
3897            age_entry.view(),
3898            age_exit.view(),
3899            event_target.view(),
3900            event_competing.view(),
3901            sampleweight.view(),
3902            x_entry.view(),
3903            x_exit.view(),
3904            x_derivative.view(),
3905            None,
3906            None,
3907            None,
3908            PenaltyBlocks::new(Vec::new()),
3909            monotonicity,
3910            SurvivalSpec::Net,
3911            true,
3912            2,
3913            mode.view(),
3914            hessian.view(),
3915        )
3916        .expect("construct structural posterior");
3917
3918        let mut grad = Array1::<f64>::zeros(3);
3919        let logp = HamiltonianTarget::logp_and_grad(&posterior, &z, &mut grad);
3920        assert!(logp.is_finite());
3921        assert!(grad.iter().all(|v| v.is_finite()));
3922    }
3923}
3924
3925/// Implement HamiltonianTarget for NUTS with analytical gradients.
3926impl HamiltonianTarget<Array1<f64>> for NutsPosterior {
3927    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
3928        NUTS_RESIDUAL_SCRATCH.with(|scratch| {
3929            let mut residual = scratch.borrow_mut();
3930            if residual.len() != self.data.n_samples {
3931                *residual = Array1::<f64>::zeros(self.data.n_samples);
3932            }
3933            self.compute_logp_and_grad_nd_into(position, &mut residual, grad)
3934        })
3935    }
3936}
3937
/// Configuration for NUTS sampling. The Pólya-Gamma Gibbs entry points
/// (`run_logit_polya_gamma_gibbs`, `estimate_logit_pg_rao_blackwell_terms`)
/// also consume it for their draw/chain counts and seed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NutsConfig {
    /// Number of retained post-warmup draws *per chain*. The pooled draw count
    /// is `n_chains * n_samples`. `validate_nuts_config` rejects values below
    /// `MIN_NUTS_SAMPLES`.
    pub n_samples: usize,
    /// Number of warmup iterations *per chain*; they are discarded before
    /// retention starts.
    pub nwarmup: usize,
    /// Number of independent chains. `validate_nuts_config` rejects values below
    /// `MIN_NUTS_CHAINS`.
    pub n_chains: usize,
    /// Requested target acceptance probability. It must be finite and lie
    /// strictly in (0, 1). The whitened NUTS runner does not use it verbatim:
    /// `robust_target_accept` raises it to a dimension-dependent floor and caps
    /// it at `MAX_TARGET_ACCEPT`, so values below the floor have no effect there.
    pub target_accept: f64,
    /// Base seed from which every per-chain / per-stream RNG seed is derived.
    /// When the field is absent from serialized input it falls back to
    /// `default_nuts_seed()`.
    #[serde(default = "default_nuts_seed")]
    pub seed: u64,
}
3948    pub target_accept: f64,
3949    /// Seed for deterministic chain initialization
3950    #[serde(default = "default_nuts_seed")]
3951    pub seed: u64,
3952}
3953
3954fn default_nuts_seed() -> u64 {
3955    42
3956}
3957
3958fn validate_nuts_target_accept(target_accept: f64) -> Result<(), HmcError> {
3959    if target_accept.is_finite() && target_accept > 0.0 && target_accept < 1.0 {
3960        Ok(())
3961    } else {
3962        Err(HmcError::InvalidConfig {
3963            reason: format!(
3964                "NUTS target_accept must be finite and lie in (0, 1), got {target_accept}"
3965            ),
3966        })
3967    }
3968}
3969
3970/// Minimum number of post-warmup draws per chain that keeps the split-R-hat /
3971/// ESS machinery well-defined. Each chain is split in half for the
3972/// Gelman-Rubin diagnostic (`compute_split_rhat_and_ess` and the engine's own
3973/// run-stats path), so both halves need at least two draws, i.e. four draws
3974/// total. Below this the engine `.expect(...)` calls (empty-stack / "split
3975/// R-hat and ESS require at least 2 split chains and 2 draws per split chain")
3976/// panic across the FFI boundary instead of returning a typed error.
3977const MIN_NUTS_SAMPLES: usize = 4;
3978
3979/// Minimum number of parallel chains. With zero chains the engine receives an
3980/// empty initial-position vector and panics in `ndarray::stack` (and the
3981/// Laplace fallback would produce an empty `(0, p)` posterior). A *single*
3982/// chain is well-defined and is a supported, tested configuration: the engine
3983/// splits each chain in half for the diagnostic, so one chain still yields the
3984/// two split-chains the R-hat path needs, and `compute_split_rhat_and_ess`
3985/// gracefully early-returns for `n_chains < 2`. We therefore only reject the
3986/// genuinely-degenerate `n_chains == 0`.
3987const MIN_NUTS_CHAINS: usize = 1;
3988
3989/// Validate the draw / chain counts of a NUTS configuration up front, mirroring
3990/// `validate_nuts_target_accept`, so that out-of-range values surface as a typed
3991/// `HmcError::InvalidConfig` *before* the sampling engine is constructed rather
3992/// than as a panic caught at the FFI boundary.
3993fn validate_nuts_draws(config: &NutsConfig) -> Result<(), HmcError> {
3994    if config.n_chains < MIN_NUTS_CHAINS {
3995        return Err(HmcError::InvalidConfig {
3996            reason: format!(
3997                "NUTS n_chains must be >= {MIN_NUTS_CHAINS}; with zero chains the \
3998                 sampler has no initial positions to run, got {}",
3999                config.n_chains
4000            ),
4001        });
4002    }
4003    if config.n_samples < MIN_NUTS_SAMPLES {
4004        return Err(HmcError::InvalidConfig {
4005            reason: format!(
4006                "NUTS n_samples must be >= {MIN_NUTS_SAMPLES} so split-R-hat / ESS \
4007                 diagnostics are defined, got {}",
4008                config.n_samples
4009            ),
4010        });
4011    }
4012    Ok(())
4013}
4014
4015/// Full up-front validation of a NUTS configuration shared by every sampling
4016/// entry point (dense NUTS, link-wiggle, joint (β, ρ), survival, the
4017/// auto-selected Pólya-Gamma Gibbs path, and the Laplace-Gaussian fallback).
4018pub(crate) fn validate_nuts_config(config: &NutsConfig) -> Result<(), HmcError> {
4019    validate_nuts_target_accept(config.target_accept)?;
4020    validate_nuts_draws(config)?;
4021    Ok(())
4022}
4023
4024#[inline]
4025fn splitmix64(x: u64) -> u64 {
4026    gam_linalg::utils::splitmix64_hash(x)
4027}
4028
4029#[inline]
4030fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
4031    splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
4032}
4033
4034#[inline]
4035fn nuts_transition_seed(seed: u64, stream: u64) -> u64 {
4036    splitmix64(seed ^ stream ^ 0xA24B_AED4_963E_E407)
4037}
4038
4039#[inline]
4040fn gibbs_pg_seed(seed: u64, chain: usize, stream: u64, iter: usize) -> u64 {
4041    chain_stream_seed(
4042        seed,
4043        chain,
4044        stream ^ ((iter as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
4045    )
4046}
4047
4048fn draw_logit_pg1_omega(
4049    shapes: ArrayView1<'_, u32>,
4050    tilts: ArrayView1<'_, f64>,
4051    seed: u64,
4052    out: &mut Array1<f64>,
4053) -> Result<(), String> {
4054    if out.len() != tilts.len() {
4055        return Err(HmcError::DimensionMismatch {
4056            reason: "draw_logit_pg1_omega: output length mismatch".to_string(),
4057        }
4058        .into());
4059    }
4060    let draws = crate::gpu_polya_gamma::draw_batch(PolyaGammaBatchInput {
4061        shapes,
4062        tilts,
4063        seed: PgSeed(seed),
4064    })?;
4065    out.assign(&draws);
4066    out.mapv_inplace(|v| v.max(1.0e-12));
4067    Ok(())
4068}
4069
4070/// Parameter dimension above which the posterior is treated as "high-dimensional"
4071/// for the purpose of the more conservative sampler heuristics below: a higher
4072/// target-acceptance floor (smaller leapfrog steps) and stronger mass-matrix
4073/// regularization. The boundary matches the `dense_max_dim` cap at which the
4074/// engine stops attempting dense mass-matrix adaptation.
4075const HIGH_DIM_THRESHOLD: usize = 50;
4076
4077/// Target-acceptance floor enforced for high-dimensional posteriors
4078/// (`dim > HIGH_DIM_THRESHOLD`). NUTS efficiency degrades faster with too-large
4079/// steps in high dimensions, so we refuse to honor a requested accept below this.
4080const HIGH_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.92;
4081/// Target-acceptance floor for low-dimensional posteriors.
4082const LOW_DIM_TARGET_ACCEPT_FLOOR: f64 = 0.90;
4083/// Upper bound on the effective target acceptance. Pushing target accept toward
4084/// 1 collapses the step size and stalls mixing, so we cap the requested value.
4085const MAX_TARGET_ACCEPT: f64 = 0.95;
4086
4087/// Minimum warmup length below which mass-matrix adaptation is disabled: the
4088/// windowed (Stan-style) adaptation schedule needs enough warmup iterations to
4089/// populate its initial / terminal buffers, otherwise the estimated metric is
4090/// noise. With fewer warmup steps the sampler runs on the identity metric.
4091const MIN_WARMUP_FOR_MASS_ADAPT: usize = 80;
4092
4093/// Largest parameter dimension for which the engine attempts *dense* mass-matrix
4094/// adaptation; above this it falls back to a diagonal metric (an `O(p²)` dense
4095/// metric is neither affordable nor reliably estimable from limited warmup).
4096const DENSE_MASS_MATRIX_MAX_DIM: usize = 75;
4097
4098/// Mass-matrix ridge (added to the diagonal of the estimated metric) for the
4099/// general (mean-family) sampler. The high-dimensional value is larger because
4100/// the warmup metric estimate is noisier relative to its scale as `p` grows.
4101const MASS_REGULARIZE_HIGH_DIM: f64 = 0.14;
4102const MASS_REGULARIZE_LOW_DIM: f64 = 0.10;
4103/// Mass-matrix ridge for survival posteriors, which are frequently skewed by
4104/// censoring / rare events and so warrant a heavier ridge than the mean family.
4105const SURVIVAL_MASS_REGULARIZE_HIGH_DIM: f64 = 0.18;
4106const SURVIVAL_MASS_REGULARIZE_LOW_DIM: f64 = 0.12;
4107
4108/// Jitter added during mass-matrix inversion to keep the metric strictly
4109/// positive-definite against round-off in the warmup covariance estimate.
4110const MASS_MATRIX_JITTER: f64 = 1e-5;
4111
4112#[inline]
4113fn robust_target_accept(requested: f64, dim: usize) -> f64 {
4114    let floor = if dim > HIGH_DIM_THRESHOLD {
4115        HIGH_DIM_TARGET_ACCEPT_FLOOR
4116    } else {
4117        LOW_DIM_TARGET_ACCEPT_FLOOR
4118    };
4119    requested.max(floor).min(MAX_TARGET_ACCEPT)
4120}
4121
4122fn jittered_initial_positions(
4123    config: &NutsConfig,
4124    dim: usize,
4125    scale: f64,
4126    stream: u64,
4127) -> Vec<Array1<f64>> {
4128    (0..config.n_chains)
4129        .map(|chain| {
4130            let mut rng = StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, stream));
4131            Array1::from_shape_fn(dim, |_| sample_standard_normal(&mut rng) * scale)
4132        })
4133        .collect()
4134}
4135
4136fn robust_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4137    if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4138        return NUTSMassMatrixConfig::disabled();
4139    }
4140    let start_buffer = (nwarmup / 8).clamp(35, 180);
4141    let end_buffer = (nwarmup / 5).clamp(50, 250);
4142    let initial_window = (nwarmup / 20).clamp(10, 60);
4143    NUTSMassMatrixConfig {
4144        adaptation: MassMatrixAdaptation::Diagonal,
4145        start_buffer,
4146        end_buffer,
4147        initial_window,
4148        regularize: if dim > HIGH_DIM_THRESHOLD {
4149            MASS_REGULARIZE_HIGH_DIM
4150        } else {
4151            MASS_REGULARIZE_LOW_DIM
4152        },
4153        jitter: MASS_MATRIX_JITTER,
4154        dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4155    }
4156}
4157
4158fn robust_survival_mass_matrix_config(dim: usize, nwarmup: usize) -> NUTSMassMatrixConfig {
4159    if nwarmup < MIN_WARMUP_FOR_MASS_ADAPT {
4160        return NUTSMassMatrixConfig::disabled();
4161    }
4162    // Survival posteriors with censoring/rare events are often skewed; this
4163    // configuration uses diagonal adaptation.
4164    let start_buffer = (nwarmup / 7).clamp(40, 200);
4165    let end_buffer = (nwarmup / 4).clamp(60, 280);
4166    let initial_window = (nwarmup / 20).clamp(10, 60);
4167    NUTSMassMatrixConfig {
4168        adaptation: MassMatrixAdaptation::Diagonal,
4169        start_buffer,
4170        end_buffer,
4171        initial_window,
4172        regularize: if dim > HIGH_DIM_THRESHOLD {
4173            SURVIVAL_MASS_REGULARIZE_HIGH_DIM
4174        } else {
4175            SURVIVAL_MASS_REGULARIZE_LOW_DIM
4176        },
4177        jitter: MASS_MATRIX_JITTER,
4178        dense_max_dim: DENSE_MASS_MATRIX_MAX_DIM,
4179    }
4180}
4181
4182impl Default for NutsConfig {
4183    fn default() -> Self {
4184        Self {
4185            n_samples: 1000,
4186            nwarmup: 500,
4187            n_chains: 4,
4188            target_accept: 0.9,
4189            seed: 42,
4190        }
4191    }
4192}
4193
4194impl NutsConfig {
4195    /// Create a config with sample counts tuned for the model dimension.
4196    ///
4197    /// Higher dimensions need more samples because:
4198    /// - ESS decreases with dimension (autocorrelation grows)
4199    /// - Split R-hat needs enough samples per chain to be meaningful
4200    ///
4201    /// Rule of thumb: target 100 effective samples per parameter.
4202    pub fn for_dimension(n_params: usize) -> Self {
4203        // ESS ≈ n_samples / (1 + 2τ) where τ ≈ sqrt(dim) for well-tuned NUTS
4204        let effective_autocorr = (n_params as f64).sqrt().max(1.0);
4205
4206        // Target: at least 100 effective samples per parameter
4207        let target_ess = 100 * n_params;
4208
4209        // Samples needed = ESS * (1 + 2τ), with 1.5x safety factor
4210        let raw_samples = (target_ess as f64 * (1.0 + 2.0 * effective_autocorr) * 1.5) as usize;
4211
4212        // Clamp to reasonable range [500, 10000]
4213        let n_samples = raw_samples.clamp(500, 10_000);
4214
4215        // Warmup ≈ samples (standard practice for adaptation)
4216        let nwarmup = n_samples;
4217
4218        // More chains for higher dims (better R-hat estimation)
4219        let n_chains = if n_params > 50 { 4 } else { 2 };
4220
4221        Self {
4222            n_samples,
4223            nwarmup,
4224            n_chains,
4225            target_accept: 0.9,
4226            seed: 42,
4227        }
4228    }
4229}
4230
4231/// Result of NUTS sampling.
4232#[derive(Clone, Debug)]
4233pub struct NutsResult {
4234    /// Coefficient samples in ORIGINAL space: shape (n_total_samples, n_coeffs)
4235    pub samples: Array2<f64>,
4236    /// Posterior mean
4237    pub posterior_mean: Array1<f64>,
4238    /// Posterior standard deviation
4239    pub posterior_std: Array1<f64>,
4240    /// R-hat convergence diagnostic
4241    pub rhat: f64,
4242    /// Effective sample size
4243    pub ess: f64,
4244    /// Whether sampling converged (R-hat < 1.1)
4245    pub converged: bool,
4246}
4247
4248#[derive(Clone, Copy)]
4249struct NutsConvergenceThresholds {
4250    max_rhat: f64,
4251    min_ess: Option<f64>,
4252}
4253
4254impl NutsConvergenceThresholds {
4255    #[inline]
4256    fn converged(self, rhat: f64, ess: f64) -> bool {
4257        let rhat_ok = rhat < self.max_rhat;
4258        match self.min_ess {
4259            Some(min_ess) => rhat_ok && ess > min_ess,
4260            None => rhat_ok,
4261        }
4262    }
4263}
4264
4265fn run_whitened_nuts_samples<Target>(
4266    target: Target,
4267    initial_positions: Vec<Array1<f64>>,
4268    config: &NutsConfig,
4269    dim: usize,
4270    mass_cfg: NUTSMassMatrixConfig,
4271    transition_seed_stream: u64,
4272    sampling_error_label: &str,
4273) -> Result<(Array3<f64>, String), String>
4274where
4275    Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4276{
4277    let mut sampler = GenericNUTS::new_with_mass_matrix(
4278        target,
4279        initial_positions,
4280        robust_target_accept(config.target_accept, dim),
4281        mass_cfg,
4282    )
4283    .set_seed(nuts_transition_seed(config.seed, transition_seed_stream));
4284
4285    let (samples_array, run_stats) = sampler
4286        .run_progress(config.n_samples, config.nwarmup)
4287        .map_err(|e| format!("{sampling_error_label}: {e}"))?;
4288    Ok((samples_array, run_stats.to_string()))
4289}
4290
4291fn unwhiten_samples(
4292    samples_array: &Array3<f64>,
4293    mode: &Array1<f64>,
4294    chol: &Array2<f64>,
4295    dim: usize,
4296    z_start: usize,
4297) -> Array2<f64> {
4298    let shape = samples_array.shape();
4299    let n_chains = shape[0];
4300    let n_samples_out = shape[1];
4301    let total_samples = n_chains * n_samples_out;
4302
4303    let mut samples = Array2::<f64>::zeros((total_samples, dim));
4304    let mut z_buffer = Array1::<f64>::zeros(dim);
4305    for chain in 0..n_chains {
4306        for sample_i in 0..n_samples_out {
4307            let zview = samples_array.slice(ndarray::s![chain, sample_i, z_start..z_start + dim]);
4308            z_buffer.assign(&zview);
4309            let beta = mode + &chol.dot(&z_buffer);
4310            let sample_idx = chain * n_samples_out + sample_i;
4311            samples.row_mut(sample_idx).assign(&beta);
4312        }
4313    }
4314
4315    samples
4316}
4317
4318fn summarize_unwhitened_nuts_samples(
4319    samples: Array2<f64>,
4320    samples_array: &Array3<f64>,
4321    empty_mean: Array1<f64>,
4322    convergence: NutsConvergenceThresholds,
4323) -> NutsResult {
4324    let posterior_mean = samples.mean_axis(Axis(0)).unwrap_or(empty_mean);
4325    let posterior_std = samples.std_axis(Axis(0), 0.0);
4326    let (rhat, ess) = compute_split_rhat_and_ess(samples_array);
4327    let converged = convergence.converged(rhat, ess);
4328
4329    NutsResult {
4330        samples,
4331        posterior_mean,
4332        posterior_std,
4333        rhat,
4334        ess,
4335        converged,
4336    }
4337}
4338
4339fn run_whitened_nuts_result<Target>(
4340    target: Target,
4341    mode: &Array1<f64>,
4342    chol: &Array2<f64>,
4343    initial_positions: Vec<Array1<f64>>,
4344    config: &NutsConfig,
4345    dim: usize,
4346    mass_cfg: NUTSMassMatrixConfig,
4347    transition_seed_stream: u64,
4348    sampling_error_label: &str,
4349    empty_mean: Array1<f64>,
4350    convergence: NutsConvergenceThresholds,
4351) -> Result<(NutsResult, String), String>
4352where
4353    Target: HamiltonianTarget<Array1<f64>> + Sync + Send,
4354{
4355    let (samples_array, run_stats) = run_whitened_nuts_samples(
4356        target,
4357        initial_positions,
4358        config,
4359        dim,
4360        mass_cfg,
4361        transition_seed_stream,
4362        sampling_error_label,
4363    )?;
4364    let samples = unwhiten_samples(&samples_array, mode, chol, dim, 0);
4365    let result =
4366        summarize_unwhitened_nuts_samples(samples, &samples_array, empty_mean, convergence);
4367    Ok((result, run_stats))
4368}
4369
4370impl NutsResult {
4371    /// Computes the posterior mean of a function applied to coefficients.
4372    /// Returns 0.0 if samples is empty to avoid divide-by-zero.
4373    pub fn posterior_mean_of<F>(&self, f: F) -> f64
4374    where
4375        F: Fn(ArrayView1<f64>) -> f64 + Sync,
4376    {
4377        let n = self.samples.nrows();
4378        if n == 0 {
4379            return 0.0;
4380        }
4381        // Posterior mean of a sample-function: parallel reduction over rows.
4382        // `f: Fn(ArrayView1) -> f64` is shared-access so safe across threads.
4383        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4384        let sum: f64 = (0..n).into_par_iter().map(|i| f(self.samples.row(i))).sum();
4385        sum / n as f64
4386    }
4387
4388    /// Computes percentiles of a function applied to coefficients.
4389    pub fn posterior_interval_of<F>(&self, f: F, lower_pct: f64, upper_pct: f64) -> (f64, f64)
4390    where
4391        F: Fn(ArrayView1<f64>) -> f64,
4392    {
4393        let n = self.samples.nrows();
4394        if n == 0 {
4395            return (0.0, 0.0);
4396        }
4397        let mut values: Vec<f64> = (0..n).map(|i| f(self.samples.row(i))).collect();
4398        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4399
4400        let lower_idx = ((lower_pct / 100.0) * n as f64).floor() as usize;
4401        let upper_idx = ((upper_pct / 100.0) * n as f64).ceil() as usize;
4402
4403        (
4404            values[lower_idx.min(n.saturating_sub(1))],
4405            values[upper_idx.min(n.saturating_sub(1))],
4406        )
4407    }
4408}
4409
4410#[inline]
4411fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
4412    let u1 = rng.random::<f64>().max(1e-16);
4413    let u2 = rng.random::<f64>();
4414    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
4415}
4416
4417/// Runs a Pólya-Gamma Gibbs sampler for Bernoulli-logit models.
4418///
4419/// This sampler is gradient-free: each iteration alternates
4420/// 1) ω_i | β, y ~ PG(1, x_i^T β), and
4421/// 2) β | ω, y ~ N(Q^{-1} b, Q^{-1}), with Q = S + X^T diag(ω) X, b = X^T(y - 1/2).
4422///
4423/// For weighted data, this implementation is defined for weights ≈ 1.0 because it
4424/// samples PG(1,·) latent variables.
4425pub fn run_logit_polya_gamma_gibbs(
4426    x: ArrayView2<f64>,
4427    y: ArrayView1<f64>,
4428    weights: ArrayView1<f64>,
4429    penalty_matrix: ArrayView2<f64>,
4430    mode: ArrayView1<f64>,
4431    config: &NutsConfig,
4432) -> Result<NutsResult, String> {
4433    let n = x.nrows();
4434    let p = x.ncols();
4435    if y.len() != n || weights.len() != n {
4436        return Err(HmcError::DimensionMismatch {
4437            reason: "run_logit_polya_gamma_gibbs: input length mismatch".to_string(),
4438        }
4439        .into());
4440    }
4441    if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4442        return Err(HmcError::DimensionMismatch {
4443            reason: "run_logit_polya_gamma_gibbs: coefficient/penalty dimension mismatch"
4444                .to_string(),
4445        }
4446        .into());
4447    }
4448    if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4449        return Err(HmcError::InvalidConfig {
4450            reason: "run_logit_polya_gamma_gibbs requires unit weights (PG(1,·)); use NUTS for non-unit weights".to_string(),
4451        }
4452        .into());
4453    }
4454    validate_binary_responses("run_logit_polya_gamma_gibbs", &y, &weights).map_err(String::from)?;
4455    // Issue #399: the auto-selected PG-Gibbs path is reached for the canonical
4456    // unit-weight Bernoulli-logit GAM. Without this guard, `n_chains == 0` /
4457    // `n_samples == 0` would not panic but silently return a degenerate empty
4458    // `(0, p)` posterior, diverging from the typed error the NUTS path raises
4459    // for the same inputs. Route it through the shared validator so every
4460    // `Model.sample` surface rejects degenerate draw/chain counts identically.
4461    validate_nuts_config(config).map_err(String::from)?;
4462
4463    let n_iter = config.nwarmup + config.n_samples;
4464
4465    // b = X^T (y - 1/2), constant across iterations.
4466    let kappa = y.mapv(|v| v - 0.5);
4467    let rhs_b = fast_atv(&x, &kappa);
4468
4469    let mut samples_array = Array3::<f64>::zeros((config.n_chains, config.n_samples, p));
4470    let mut eta = Array1::<f64>::zeros(n);
4471    let mut omega = Array1::<f64>::ones(n);
4472    let pg_shapes = Array1::<u32>::from_elem(n, 1);
4473    let mut xw = x.to_owned();
4474    let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4475    let penalty = penalty_matrix.to_owned();
4476    let mut q = Array2::<f64>::zeros((p, p));
4477    let mut mean = Array1::<f64>::zeros(p);
4478    let mut z = Array1::<f64>::zeros(p);
4479    let mut noise = Array1::<f64>::zeros(p);
4480
4481    for chain in 0..config.n_chains {
4482        let mut init_rng =
4483            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xB3C4_5A1F_8E9D_7632));
4484        let mut draw_rng =
4485            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x17A9_26D5_4C1B_E083));
4486        let mut beta = mode.to_owned();
4487        // Small jitter so chains are not perfectly coupled.
4488        for j in 0..p {
4489            beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4490        }
4491
4492        for iter in 0..n_iter {
4493            eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4494            draw_logit_pg1_omega(
4495                pg_shapes.view(),
4496                eta.view(),
4497                gibbs_pg_seed(config.seed, chain, 0x4D94_DF4E_5D72_81AB, iter),
4498                &mut omega,
4499            )?;
4500
4501            // Build Xweighted = diag(sqrt(ω)) X and compute X^T Ω X via faer GEMM.
4502            // Per-row scaling is fully independent across rows.
4503            ndarray::Zip::indexed(xw.rows_mut())
4504                .and(x.rows())
4505                .and(&omega)
4506                .par_for_each(|_idx, mut xw_row, x_row, omega_i| {
4507                    let s = omega_i.sqrt();
4508                    for j in 0..p {
4509                        xw_row[j] = x_row[j] * s;
4510                    }
4511                });
4512            fast_ata_into(&xw, &mut xt_omega_x);
4513
4514            q.assign(&penalty);
4515            q += &xt_omega_x;
4516
4517            // β | ω,y ~ N(Q^{-1} b, Q^{-1})
4518            let factor = q
4519                .cholesky(Side::Lower)
4520                .map_err(|e| format!("PG Gibbs failed to factor Q: {:?}", e))?;
4521            mean.assign(&factor.solvevec(&rhs_b));
4522
4523            for j in 0..p {
4524                z[j] = sample_standard_normal(&mut draw_rng);
4525            }
4526            let l = factor.lower_triangular();
4527            back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4528            beta.assign(&(&mean + &noise));
4529
4530            if iter >= config.nwarmup {
4531                let keep_idx = iter - config.nwarmup;
4532                samples_array
4533                    .slice_mut(ndarray::s![chain, keep_idx, ..])
4534                    .assign(&beta);
4535            }
4536        }
4537    }
4538
4539    let total_samples = config.n_chains * config.n_samples;
4540    let mut samples = Array2::<f64>::zeros((total_samples, p));
4541    for chain in 0..config.n_chains {
4542        for s in 0..config.n_samples {
4543            let idx = chain * config.n_samples + s;
4544            samples
4545                .row_mut(idx)
4546                .assign(&samples_array.slice(ndarray::s![chain, s, ..]));
4547        }
4548    }
4549
4550    let posterior_mean = samples
4551        .mean_axis(Axis(0))
4552        .unwrap_or_else(|| Array1::zeros(p));
4553    let posterior_std = samples.std_axis(Axis(0), 0.0);
4554    let (rhat, ess) = if config.n_chains >= 2 && config.n_samples >= 4 {
4555        compute_split_rhat_and_ess(&samples_array)
4556    } else {
4557        (1.0, (total_samples as f64) * 0.5)
4558    };
4559    let converged = rhat < 1.1 && ess > 100.0;
4560
4561    Ok(NutsResult {
4562        samples,
4563        posterior_mean,
4564        posterior_std,
4565        rhat,
4566        ess,
4567        converged,
4568    })
4569}
4570
4571/// Estimate E_{ω|y,ρ}[ tr(S_k Q^{-1}) + μᵀ S_k μ ] with PG Gibbs + Rao-Blackwellization.
4572///
4573/// For each retained Gibbs state ω:
4574///   Q = S + Xᵀ diag(ω) X,  μ = Q^{-1} Xᵀ(y-1/2),
4575/// and with S_k = R_kᵀ R_k:
4576///   tr(S_k Q^{-1}) + μᵀ S_k μ
4577/// = tr(R_k Q^{-1} R_kᵀ) + ||R_k μ||².
4578///
4579/// Returns one expectation per penalty block k, averaged over retained draws.
4580pub fn estimate_logit_pg_rao_blackwell_terms(
4581    x: ArrayView2<f64>,
4582    y: ArrayView1<f64>,
4583    weights: ArrayView1<f64>,
4584    penalty_matrix: ArrayView2<f64>,
4585    mode: ArrayView1<f64>,
4586    penalty_roots: &[Array2<f64>],
4587    config: &NutsConfig,
4588) -> Result<Array1<f64>, String> {
4589    let n = x.nrows();
4590    let p = x.ncols();
4591    if y.len() != n || weights.len() != n {
4592        return Err(HmcError::DimensionMismatch {
4593            reason: "estimate_logit_pg_rao_blackwell_terms: input length mismatch".to_string(),
4594        }
4595        .into());
4596    }
4597    if mode.len() != p || penalty_matrix.nrows() != p || penalty_matrix.ncols() != p {
4598        return Err(HmcError::DimensionMismatch {
4599            reason: "estimate_logit_pg_rao_blackwell_terms: coefficient/penalty dimension mismatch"
4600                .to_string(),
4601        }
4602        .into());
4603    }
4604    if !weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10) {
4605        return Err(HmcError::InvalidConfig {
4606            reason: "estimate_logit_pg_rao_blackwell_terms requires unit weights (PG(1,·))"
4607                .to_string(),
4608        }
4609        .into());
4610    }
4611    validate_binary_responses("estimate_logit_pg_rao_blackwell_terms", &y, &weights)
4612        .map_err(String::from)?;
4613    if penalty_roots.iter().any(|r| r.ncols() != p) {
4614        return Err(HmcError::DimensionMismatch {
4615            reason: "estimate_logit_pg_rao_blackwell_terms: root width mismatch".to_string(),
4616        }
4617        .into());
4618    }
4619    // Precompute transposed root blocks once:
4620    //   R_k^T is the RHS used for batched solves Q X = R_k^T.
4621    let penalty_roots_t: Vec<Array2<f64>> =
4622        penalty_roots.iter().map(|r| r.t().to_owned()).collect();
4623
4624    let n_iter = config.nwarmup + config.n_samples;
4625
4626    // Logistic PG identity uses kappa_i = y_i - 1/2 so that
4627    // b = X^T kappa in the Gaussian conditional for beta|omega.
4628    let kappa = y.mapv(|v| v - 0.5);
4629    let rhs_b = fast_atv(&x, &kappa);
4630
4631    let penalty = penalty_matrix.to_owned();
4632    let mut eta = Array1::<f64>::zeros(n);
4633    let mut omega = Array1::<f64>::ones(n);
4634    let pg_shapes = Array1::<u32>::from_elem(n, 1);
4635    let mut xw = x.to_owned();
4636    let mut xt_omega_x = Array2::<f64>::zeros((p, p));
4637    let mut q = Array2::<f64>::zeros((p, p));
4638    let mut mean = Array1::<f64>::zeros(p);
4639    let mut rb_sum = Array1::<f64>::zeros(penalty_roots.len());
4640    let mut z = Array1::<f64>::zeros(p);
4641    let mut noise = Array1::<f64>::zeros(p);
4642
4643    let mut kept = 0usize;
4644    for chain in 0..config.n_chains {
4645        let mut init_rng =
4646            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x28F0_7B65_1A4D_C93E));
4647        let mut draw_rng =
4648            StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0xC642_6E35_B5A9_1D80));
4649        let mut beta = mode.to_owned();
4650        for j in 0..p {
4651            beta[j] += 0.05 * sample_standard_normal(&mut init_rng);
4652        }
4653
4654        for iter in 0..n_iter {
4655            eta.assign(&gam_linalg::faer_ndarray::fast_av(&x, &beta));
4656            draw_logit_pg1_omega(
4657                pg_shapes.view(),
4658                eta.view(),
4659                gibbs_pg_seed(config.seed, chain, 0x83F1_56C9_A7E0_2D4B, iter),
4660                &mut omega,
4661            )?;
4662
4663            ndarray::Zip::from(xw.rows_mut())
4664                .and(x.rows())
4665                .and(&omega)
4666                .par_for_each(|mut xw_row, x_row, &omega_i| {
4667                    let s = omega_i.sqrt();
4668                    for j in 0..p {
4669                        xw_row[j] = x_row[j] * s;
4670                    }
4671                });
4672            fast_ata_into(&xw, &mut xt_omega_x);
4673
4674            // Conditional precision:
4675            //   Q = S + X^T diag(omega) X.
4676            q.assign(&penalty);
4677            q += &xt_omega_x;
4678
4679            let factor = q
4680                .cholesky(Side::Lower)
4681                .map_err(|e| format!("PG Rao-Blackwell failed to factor Q: {:?}", e))?;
4682            // Conditional mean:
4683            //   mu = Q^{-1} b,  b = X^T(y - 1/2).
4684            mean.assign(&factor.solvevec(&rhs_b));
4685
4686            // Draw beta for the next Gibbs state.
4687            for j in 0..p {
4688                z[j] = sample_standard_normal(&mut draw_rng);
4689            }
4690            let l = factor.lower_triangular();
4691            back_substitution_lower_transpose_guarded_into(&l, &z, &mut noise);
4692            beta.assign(&(&mean + &noise));
4693
4694            if iter < config.nwarmup {
4695                continue;
4696            }
4697            kept += 1;
4698
4699            for (k, r_k) in penalty_roots.iter().enumerate() {
4700                if r_k.nrows() == 0 {
4701                    continue;
4702                }
4703
4704                // mu^T S_k mu via root form S_k = R_k^T R_k.
4705                let rmu = r_k.dot(&mean);
4706                let mu_quad = rmu.dot(&rmu);
4707
4708                // Batched trace solve:
4709                //   V_k = Q^{-1} R_k^T  (single multi-RHS solve)
4710                // then tr(R_k Q^{-1} R_k^T) = <R_k, V_k^T>_F.
4711                let solved_mat = factor.solve_mat(&penalty_roots_t[k]); // (p, r_k)
4712                let solved_t = solved_mat.t();
4713                let mut trace_term = 0.0_f64;
4714                for (&a, &b) in r_k.iter().zip(solved_t.iter()) {
4715                    trace_term += a * b;
4716                }
4717
4718                rb_sum[k] += trace_term + mu_quad;
4719            }
4720        }
4721    }
4722
4723    if kept == 0 {
4724        return Err(HmcError::SamplingFailed {
4725            reason: "estimate_logit_pg_rao_blackwell_terms: no retained samples".to_string(),
4726        }
4727        .into());
4728    }
4729    let out = rb_sum.mapv(|v| v / (kept as f64));
4730    if !out.iter().all(|v| v.is_finite()) {
4731        return Err(HmcError::NonFiniteState {
4732            reason: "estimate_logit_pg_rao_blackwell_terms: non-finite expectation".to_string(),
4733        }
4734        .into());
4735    }
4736    Ok(out)
4737}
4738
4739/// Runs NUTS sampling using general-mcmc with whitened parameter space.
4740///
4741/// # Arguments
4742/// * `x` - Design matrix [n_samples, dim]
4743/// * `y` - Response vector [n_samples]
4744/// * `weights` - Observation/case weights [n_samples]
4745/// * `penalty_matrix` - Combined penalty S [dim, dim]
4746/// * `mode` - MAP estimate μ [dim]
4747/// * `hessian` - Penalized Hessian H [dim, dim] (NOT the inverse!)
4748/// * `nuts_family` - Family for log-likelihood computation
4749/// * `firth_bias_reduction` - Whether Firth bias reduction was used in training
4750/// * `config` - NUTS configuration
4751pub(crate) fn run_nuts_sampling(
4752    x: ArrayView2<f64>,
4753    y: ArrayView1<f64>,
4754    weights: ArrayView1<f64>,
4755    penalty_matrix: ArrayView2<f64>,
4756    mode: ArrayView1<f64>,
4757    hessian: ArrayView2<f64>,
4758    nuts_family: NutsFamily,
4759    gamma_shape: f64,
4760    dispersion: gam_solve::model_types::Dispersion,
4761    firth_bias_reduction: bool,
4762    offset: Option<ArrayView1<f64>>,
4763    config: &NutsConfig,
4764) -> Result<NutsResult, String> {
4765    validate_firth_support(nuts_family, firth_bias_reduction).map_err(String::from)?;
4766    validate_nuts_config(config).map_err(String::from)?;
4767    if nuts_family == NutsFamily::TweedieLog && !is_valid_tweedie_power(gamma_shape) {
4768        return Err(format!(
4769            "Tweedie variance power must be finite and strictly between 1 and 2; got {gamma_shape}"
4770        ));
4771    }
4772    let dim = mode.len();
4773
4774    // Create posterior target with analytical gradients. When Firth is enabled,
4775    // this target includes the identifiable-subspace Jeffreys term.
4776    let target = NutsPosterior::new(
4777        x,
4778        y,
4779        weights,
4780        penalty_matrix,
4781        mode,
4782        hessian,
4783        nuts_family,
4784        gamma_shape,
4785        dispersion,
4786        firth_bias_reduction,
4787    )?;
4788    let target = match offset {
4789        Some(offset) => target.with_offset(offset)?,
4790        None => target,
4791    };
4792
4793    // Get Cholesky factor for un-whitening samples later
4794    let chol = target.chol().clone();
4795    let mode_arr = target.mode().clone();
4796
4797    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x0F65_83B2_BC71_4D9E);
4798    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4799    let (result, run_stats) = run_whitened_nuts_result(
4800        target,
4801        &mode_arr,
4802        &chol,
4803        initial_positions,
4804        config,
4805        dim,
4806        mass_cfg,
4807        0xF1D3_C2B5_A697_804E,
4808        "NUTS sampling failed",
4809        Array1::zeros(dim),
4810        NutsConvergenceThresholds {
4811            max_rhat: 1.1,
4812            min_ess: Some(100.0),
4813        },
4814    )?;
4815    log::info!("NUTS sampling complete: {}", run_stats);
4816
4817    Ok(result)
4818}
4819
4820/// Terminal never-fail Gaussian-posterior sampling target.
4821///
4822/// This is the bottom rung of the solver's geometry-driven escalation ladder.
4823/// When the outer smoothing optimizer cannot certify convergence on a custom
4824/// (BMS / general) family — typically because Strong-Wolfe stalls on an
4825/// indefinite or non-smooth LAML objective — the driver no longer dead-ends
4826/// with an `Err`. Instead it lands here: the *same* penalized objective's
4827/// curvature (its penalized joint Hessian `H = −∇²log L + Σ_k λ_k S_k`,
4828/// augmented with the proper (unconditional) Jeffreys/PC term)
4829/// is used as the precision of a proper Gaussian posterior `N(β̂, H⁻¹)` about
4830/// the best mode `β̂` the inner solve reached. Sampling a multivariate normal
4831/// cannot fail: in the worst case (a poorly conditioned `H`) the intervals come
4832/// out honestly wider, which is the intended "magic for all users" behavior —
4833/// a finite point with calibrated SEs instead of a hard error.
4834///
4835/// The target is expressed in the whitened space `z` (`β = β̂ + L z`,
4836/// `L Lᵀ = H⁻¹`), where the posterior is the standard normal `N(0, I)`. Its
4837/// log-density and gradient are then exactly `logp(z) = −½ zᵀz`,
4838/// `∇ = −z` — a smooth, globally coercive target with no failure mode. The
4839/// `chol` factor un-whitens draws back to coefficient space, identically to
4840/// the `NutsPosterior` whitening contract above.
4841struct GaussianModeTarget;
4842
4843impl HamiltonianTarget<Array1<f64>> for GaussianModeTarget {
4844    #[inline]
4845    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4846        // Standard-normal target in whitened coordinates: logp = -0.5 zᵀz,
4847        // ∇ = -z. The whitening `L` (built from the penalized Hessian) carries
4848        // all of the posterior geometry, so the sampler itself only ever sees a
4849        // unit-covariance Gaussian — which is why this rung cannot stall.
4850        let mut quad = 0.0;
4851        for (g, &zi) in grad.iter_mut().zip(position.iter()) {
4852            *g = -zi;
4853            quad += zi * zi;
4854        }
4855        -0.5 * quad
4856    }
4857}
4858
4859/// Sample the proper Gaussian posterior `N(mode, H⁻¹)` defined by a mode and a
4860/// (penalized, Jeffreys-augmented) SPD precision `hessian`.
4861///
4862/// This is the terminal, never-fail rung of the outer-optimizer escalation:
4863/// it consumes the same penalized-objective curvature the inner machinery
4864/// already computed and returns an honest posterior summary. It returns `Err`
4865/// only for a *structurally* impossible request (dimension mismatch, a Hessian
4866/// that is not even positive-definite after symmetrization, a degenerate
4867/// config) — never for "did not converge", which is precisely the dead-end this
4868/// path exists to remove.
4869///
4870/// `hessian` must be the SPD penalized joint Hessian at `mode` (e.g. from
4871/// `compute_joint_geometry`). It is symmetrized defensively and Cholesky-
4872/// factored to build the whitening `L` with `L Lᵀ = H⁻¹`.
4873pub fn sample_gaussian_mode_posterior(
4874    mode: ArrayView1<f64>,
4875    hessian: ArrayView2<f64>,
4876    config: &NutsConfig,
4877) -> Result<GaussianModePosterior, String> {
4878    validate_nuts_config(config).map_err(String::from)?;
4879    let dim = mode.len();
4880    if hessian.nrows() != dim || hessian.ncols() != dim {
4881        return Err(format!(
4882            "Gaussian-posterior fallback: hessian shape {:?} does not match mode dim {dim}",
4883            hessian.dim()
4884        ));
4885    }
4886    if dim == 0 {
4887        return Err("Gaussian-posterior fallback: zero-dimensional posterior".to_string());
4888    }
4889
4890    // Symmetrize defensively (the assembled joint Hessian may carry
4891    // floating-point asymmetry from directional-callback construction) and add
4892    // a tiny jitter on the diagonal so a Hessian that is SPD-up-to-roundoff at a
4893    // boundary optimum still factors. The jitter only ever *widens* the
4894    // posterior, consistent with the honest-interval guarantee.
4895    let mut h = hessian.to_owned();
4896    for i in 0..dim {
4897        for j in (i + 1)..dim {
4898            let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
4899            h[[i, j]] = avg;
4900            h[[j, i]] = avg;
4901        }
4902    }
4903    let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
4904    let jitter = (diag_scale * 1e-10).max(1e-12);
4905    for i in 0..dim {
4906        h[[i, i]] += jitter;
4907    }
4908
4909    let mode_owned = mode.to_owned();
4910    let whitening = hessian_whitening_transform(
4911        h.view(),
4912        dim,
4913        1.0,
4914        "Gaussian-posterior fallback Cholesky failed",
4915    )?;
4916    let chol = whitening.chol;
4917    let target = GaussianModeTarget;
4918    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x51A6_2C73_90E4_1DBF);
4919    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
4920    let (result, run_stats) = run_whitened_nuts_result(
4921        target,
4922        &mode_owned,
4923        &chol,
4924        initial_positions,
4925        config,
4926        dim,
4927        mass_cfg,
4928        0x7C19_5A3E_82D6_44B1,
4929        "Gaussian-posterior fallback NUTS sampling failed",
4930        mode_owned.clone(),
4931        NutsConvergenceThresholds {
4932            max_rhat: 1.1,
4933            min_ess: None,
4934        },
4935    )?;
4936    log::info!(
4937        "never-fail Gaussian-posterior fallback: sampling complete dim={dim} {}",
4938        run_stats
4939    );
4940
4941    Ok(GaussianModePosterior {
4942        samples: result.samples,
4943        posterior_mean: result.posterior_mean,
4944        posterior_std: result.posterior_std,
4945        rhat: result.rhat,
4946        ess: result.ess,
4947    })
4948}
4949
/// Log-density offset applied to infeasible points of the Tier-2
/// `ρ`-posterior NUTS target (#938).
///
/// When the `ρ`-criterion closure reports an infeasible or non-finite point,
/// the target falls back to the whitened standard-normal log-density minus
/// this constant. Both the value and the gradient stay finite (the gradient
/// pulls back toward `ρ̂`), so the integrator never has to handle `-inf` or
/// NaN. The large offset makes such states overwhelmingly disfavored.
const RHO_NUTS_INFEASIBLE_LOGP_PENALTY: f64 = 1e8;
4956
4957/// Tier-2 of the exact marginal-smoothing inference stack (#938): the whitened
4958/// `ρ`-criterion Hamiltonian target.
4959///
4960/// This reuses the module's β-level whitening design ONE LEVEL UP: the target
4961/// log-density is `logp(ρ) = −(criterion(ρ) − criterion(ρ̂))` — i.e.
4962/// `π(ρ|y) ∝ exp(−LAML(ρ))`, the exact profiled criterion the outer optimizer
4963/// minimizes — expressed in the whitened coordinates `ρ = ρ̂ + L z` with
4964/// `L Lᵀ = H_ρ⁻¹` built from the exact outer Hessian at `ρ̂`. The gradient is
4965/// the caller's exact profiled `ρ`-gradient pushed through the chain rule:
4966/// `∇_z logp = −Lᵀ ∇_ρ criterion`.
4967///
4968/// The criterion closure is `FnMut` (each evaluation is one warm inner profile
4969/// solve with interior caches), so it is serialized behind a `Mutex`; chains
4970/// take turns evaluating, which also keeps the inner warm-start trajectory
4971/// coherent.
4972struct WhitenedRhoCriterionTarget<F> {
4973    /// `ρ ↦ (criterion(ρ), ∇_ρ criterion(ρ))`; `None` marks an infeasible point.
4974    criterion_and_grad: Mutex<F>,
4975    /// `ρ̂`, the converged smoothing parameters (the whitening center).
4976    mode: Array1<f64>,
4977    /// `L` with `L Lᵀ = H_ρ⁻¹`: maps whitened `z` to `ρ = ρ̂ + L z`.
4978    chol: Array2<f64>,
4979    /// `Lᵀ`, for the gradient chain rule.
4980    chol_t: Array2<f64>,
4981    /// `criterion(ρ̂)`, subtracted for numerical stability (cancels in MCMC).
4982    cost_hat: f64,
4983}
4984
4985impl<F> HamiltonianTarget<Array1<f64>> for WhitenedRhoCriterionTarget<F>
4986where
4987    F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
4988{
4989    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
4990        let rho = &self.mode + &self.chol.dot(position);
4991        let eval = {
4992            let mut criterion = self
4993                .criterion_and_grad
4994                .lock()
4995                .expect("rho-criterion mutex poisoned");
4996            (*criterion)(&rho)
4997        };
4998        match eval {
4999            Some((cost, g))
5000                if cost.is_finite()
5001                    && g.len() == position.len()
5002                    && g.iter().all(|v| v.is_finite()) =>
5003            {
5004                let grad_z = self.chol_t.dot(&g);
5005                for (gi, &v) in grad.iter_mut().zip(grad_z.iter()) {
5006                    *gi = -v;
5007                }
5008                -(cost - self.cost_hat)
5009            }
5010            _ => {
5011                // Infeasible criterion: smooth coercive fallback toward ρ̂.
5012                let mut quad = 0.0;
5013                for (gi, &zi) in grad.iter_mut().zip(position.iter()) {
5014                    *gi = -zi;
5015                    quad += zi * zi;
5016                }
5017                -0.5 * quad - RHO_NUTS_INFEASIBLE_LOGP_PENALTY
5018            }
5019        }
5020    }
5021}
5022
5023/// Run NUTS over the smoothing parameters `ρ` with the exact profiled criterion
5024/// and gradient (#938 Tier 2).
5025///
5026/// * `rho_hat` — converged `ρ̂` (the whitening center and chain seed).
5027/// * `outer_hessian` — exact outer Hessian `H_ρ` at `ρ̂` (symmetrized and
5028///   jittered defensively, then Cholesky-factored for the whitening).
5029/// * `criterion_and_grad` — `ρ ↦ (LAML(ρ), ∇_ρ LAML(ρ))`, both exact; `None`
5030///   for infeasible `ρ`. Each call is one warm inner profile solve.
5031/// * `config` — sampler configuration; determinism comes from `config.seed`
5032///   through the same splitmix64 chain/transition streams as every other NUTS
5033///   entry point (no clock, no global RNG).
5034///
5035/// Returns draws in the ORIGINAL `ρ` space (un-whitened), with split-R̂/ESS
5036/// diagnostics.
5037pub fn run_rho_criterion_nuts<F>(
5038    rho_hat: ArrayView1<f64>,
5039    outer_hessian: ArrayView2<f64>,
5040    mut criterion_and_grad: F,
5041    config: &NutsConfig,
5042) -> Result<NutsResult, String>
5043where
5044    F: FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send,
5045{
5046    validate_nuts_config(config).map_err(String::from)?;
5047    let dim = rho_hat.len();
5048    if dim == 0 {
5049        return Err("rho-posterior NUTS: zero-dimensional rho".to_string());
5050    }
5051    if outer_hessian.nrows() != dim || outer_hessian.ncols() != dim {
5052        return Err(format!(
5053            "rho-posterior NUTS: outer Hessian shape {:?} does not match rho dim {dim}",
5054            outer_hessian.dim()
5055        ));
5056    }
5057
5058    // Symmetrize + jitter the exact outer Hessian so a boundary optimum that is
5059    // SPD-up-to-roundoff still factors; jitter only widens the proposal metric.
5060    let mut h = outer_hessian.to_owned();
5061    for i in 0..dim {
5062        for j in (i + 1)..dim {
5063            let avg = 0.5 * (h[[i, j]] + h[[j, i]]);
5064            h[[i, j]] = avg;
5065            h[[j, i]] = avg;
5066        }
5067    }
5068    let diag_scale = (0..dim).map(|i| h[[i, i]].abs()).fold(0.0_f64, f64::max);
5069    let jitter = (diag_scale * 1e-10).max(1e-12);
5070    for i in 0..dim {
5071        h[[i, i]] += jitter;
5072    }
5073
5074    let mode = rho_hat.to_owned();
5075    let whitening = hessian_whitening_transform(
5076        h.view(),
5077        dim,
5078        1.0,
5079        "rho-posterior NUTS: outer-Hessian Cholesky failed",
5080    )?;
5081
5082    let cost_hat = match criterion_and_grad(&mode) {
5083        Some((cost, _)) if cost.is_finite() => cost,
5084        _ => {
5085            return Err(
5086                "rho-posterior NUTS: criterion is infeasible at rho_hat itself".to_string(),
5087            );
5088        }
5089    };
5090
5091    let chol = whitening.chol;
5092    let target = WhitenedRhoCriterionTarget {
5093        criterion_and_grad: Mutex::new(criterion_and_grad),
5094        mode: mode.clone(),
5095        chol: chol.clone(),
5096        chol_t: whitening.chol_t,
5097        cost_hat,
5098    };
5099    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x3D8A_91C4_E27B_5F60);
5100    // The rho target is already whitened by the exact outer Hessian at rho_hat,
5101    // so the local mass matrix in z-space is identity. Re-adapting a diagonal or
5102    // dense metric during warmup would spend expensive profile solves estimating
5103    // curvature we have already supplied analytically.
5104    let mass_cfg = NUTSMassMatrixConfig::disabled();
5105    let (result, run_stats) = run_whitened_nuts_result(
5106        target,
5107        &mode,
5108        &chol,
5109        initial_positions,
5110        config,
5111        dim,
5112        mass_cfg,
5113        0x6B42_E9A1_05D7_C83F,
5114        "rho-posterior NUTS sampling failed",
5115        mode.clone(),
5116        NutsConvergenceThresholds {
5117            max_rhat: 1.1,
5118            min_ess: None,
5119        },
5120    )?;
5121    log::info!("rho-posterior NUTS (#938 tier 2): sampling complete dim={dim} {run_stats}");
5122    Ok(result)
5123}
5124
/// Flattened numeric inputs for GLM-family NUTS sampling.
///
/// Consumed by `run_nuts_sampling_flattened_family`. It forwards these fields
/// to `run_nuts_sampling`, or to the Pólya-Gamma Gibbs sampler for eligible
/// Bernoulli-logit models (which does not use `hessian` or `dispersion`).
pub struct GlmFlatInputs<'a> {
    /// Design matrix `X` [n_samples, dim].
    pub x: ArrayView2<'a, f64>,
    /// Response vector [n_samples].
    pub y: ArrayView1<'a, f64>,
    /// Observation/case weights [n_samples]. All-unit weights are one of the
    /// conditions for the Pólya-Gamma logit path.
    pub weights: ArrayView1<'a, f64>,
    /// Combined penalty matrix `S` [dim, dim].
    pub penalty_matrix: ArrayView2<'a, f64>,
    /// MAP coefficient estimate [dim].
    pub mode: ArrayView1<'a, f64>,
    /// Penalized Hessian `H` at `mode` [dim, dim] (the precision, NOT its
    /// inverse).
    pub hessian: ArrayView2<'a, f64>,
    /// Gamma shape parameter. Only read for the Gamma family, where `None`
    /// falls back to `1.0`.
    pub gamma_shape: Option<f64>,
    /// Dispersion parameter φ used to scale the likelihood and the
    /// whitening Cholesky. For fixed-scale families (Binomial, Poisson)
    /// this is `Dispersion::Known(1.0)` and has no numerical effect;
    /// for Gaussian / Gamma it carries the estimated `phi` so that the
    /// sampler targets the φ-scaled posterior covariance `Vb = φ·H⁻¹`.
    /// See `inference::dispersion_cov` for the ownership invariants.
    pub dispersion: gam_solve::model_types::Dispersion,
    /// Whether the fit used Firth bias reduction. Rejected up front for
    /// likelihoods that do not support it; also disables the Pólya-Gamma
    /// logit path.
    pub firth_bias_reduction: bool,
    /// Fixed additive offset on the linear predictor (η = Xβ + offset), or
    /// `None` for an offset-free fit. Carried so posterior sampling targets the
    /// same η the model was fit and predicts on; omitting it sampled the wrong
    /// posterior for any `--offset-column` model (#882).
    pub offset: Option<ArrayView1<'a, f64>>,
}
5148
/// Flat survival inputs for engine-facing HMC APIs.
///
/// Every field is forwarded unchanged to
/// `survival_hmc::run_survival_nuts_sampling`.
/// NOTE(review): the per-field semantics below are inferred from field names
/// and should be confirmed against that function.
pub struct SurvivalFlatInputs<'a> {
    /// Entry ages (presumably left-truncation times), one per observation.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit ages (event or censoring times), one per observation.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-event indicator per observation (presumably 0/1 — confirm).
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-event indicator per observation (presumably 0/1 — confirm).
    pub event_competing: ArrayView1<'a, u8>,
    /// Observation/case weights.
    pub weights: ArrayView1<'a, f64>,
    /// Design matrix evaluated at entry ages.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design matrix evaluated at exit ages.
    pub x_exit: ArrayView2<'a, f64>,
    /// Time-derivative design matrix (presumably evaluated at exit — confirm).
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional linear-predictor offset at entry ages.
    pub eta_offset_entry: Option<ArrayView1<'a, f64>>,
    /// Optional linear-predictor offset at exit ages.
    pub eta_offset_exit: Option<ArrayView1<'a, f64>>,
    /// Optional offset on the time-derivative term at exit ages.
    pub derivative_offset_exit: Option<ArrayView1<'a, f64>>,
}
5163
/// Flattened numeric inputs for Royston-Parmar NUTS sampling.
///
/// Every field is forwarded to `survival_hmc::run_survival_nuts_sampling` by
/// `run_nuts_sampling_flattened_family`.
pub struct SurvivalNutsInputs<'a> {
    /// Flat data arrays: ages, event indicators, weights, designs, offsets.
    pub flat: SurvivalFlatInputs<'a>,
    /// Penalty blocks for the survival coefficients.
    pub penalties: gam_models::survival::PenaltyBlocks,
    /// Monotonicity penalty configuration.
    /// NOTE(review): exact role is defined in `gam_models::survival` — confirm.
    pub monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
    /// Survival model specification.
    pub spec: gam_models::survival::SurvivalSpec,
    /// Whether the time basis is monotone by construction.
    /// NOTE(review): presumably lets the sampler skip a soft monotonicity
    /// penalty — confirm.
    pub structurally_monotonic: bool,
    /// Number of structural time columns.
    /// NOTE(review): presumably the time-basis columns of the design — confirm.
    pub structural_time_columns: usize,
    /// MAP coefficient estimate (presumably the whitening center, as in the
    /// GLM path).
    pub mode: ArrayView1<'a, f64>,
    /// Penalized Hessian at `mode` (presumably used for whitening, as in the
    /// GLM path).
    pub hessian: ArrayView2<'a, f64>,
}
5175
/// Family-dispatched flattened NUTS inputs.
///
/// `run_nuts_sampling_flattened_family` requires the variant to match the
/// likelihood. `Survival` is accepted only for the Royston-Parmar response
/// family; every other supported family takes `Glm`.
pub enum FamilyNutsInputs<'a> {
    /// Inputs for GLM-style families (Gaussian, Binomial, Poisson, Tweedie,
    /// negative binomial, Gamma).
    Glm(GlmFlatInputs<'a>),
    /// Inputs for Royston-Parmar survival models. Boxed, presumably to keep
    /// the enum from growing to the size of the survival payload.
    Survival(Box<SurvivalNutsInputs<'a>>),
}
5181
5182/// Return the explicit fitted penalized Hessian used for HMC/NUTS whitening.
5183///
5184/// This is the only supported upstream-to-HMC curvature handoff: callers must
5185/// pass a dense Hessian (or an already materialized exact operator stored as a
5186/// dense Hessian) exported by the fitter. We deliberately do not synthesize a
5187/// numerical Hessian and do not invert `beta_covariance` as a compatibility
5188/// fallback, because either path can silently whiten against curvature that the
5189/// upstream fit never certified.
5190pub fn explicit_fit_hessian_for_whitening<'a>(
5191    fit: &'a UnifiedFitResult,
5192    expected_dim: usize,
5193    label: &str,
5194) -> Result<&'a Array2<f64>, String> {
5195    let hessian = fit.penalized_hessian().ok_or_else(|| {
5196        format!(
5197            "{label}: fit result is missing an explicit penalized Hessian for HMC/NUTS whitening"
5198        )
5199    })?;
5200    validate_explicit_dense_hessian_for_whitening(
5201        &format!("{label} penalized Hessian"),
5202        hessian,
5203        expected_dim,
5204    )
5205    .map_err(|err| err.to_string())?;
5206    Ok(hessian)
5207}
5208
5209/// Family-agnostic flattened NUTS entrypoint across all supported likelihood families.
5210pub fn run_nuts_sampling_flattened_family(
5211    likelihood: LikelihoodSpec,
5212    inputs: FamilyNutsInputs<'_>,
5213    config: &NutsConfig,
5214) -> Result<NutsResult, String> {
5215    if let FamilyNutsInputs::Glm(glm) = &inputs
5216        && glm.firth_bias_reduction
5217        && !likelihood_spec_supports_firth(&likelihood)
5218    {
5219        return Err(HmcError::FirthUnsupported {
5220            reason: format!(
5221                "NUTS with Firth requires a Binomial inverse link with a Fisher-weight jet; {} does not support it",
5222                likelihood.pretty_name()
5223            ),
5224        }
5225        .into());
5226    }
5227
5228    match (likelihood.response.clone(), likelihood.link.clone(), inputs) {
5229        (
5230            ResponseFamily::Gaussian,
5231            InverseLink::Standard(StandardLink::Identity),
5232            FamilyNutsInputs::Glm(glm),
5233        ) => run_nuts_sampling(
5234            glm.x,
5235            glm.y,
5236            glm.weights,
5237            glm.penalty_matrix,
5238            glm.mode,
5239            glm.hessian,
5240            NutsFamily::Gaussian,
5241            1.0,
5242            glm.dispersion,
5243            glm.firth_bias_reduction,
5244            glm.offset,
5245            config,
5246        ),
5247        (
5248            ResponseFamily::Binomial,
5249            InverseLink::Standard(StandardLink::Logit),
5250            FamilyNutsInputs::Glm(glm),
5251        ) => {
5252            // Auto-select PG Gibbs when assumptions hold; otherwise fall back to NUTS.
5253            // This gives gradient-free posterior draws for standard Bernoulli logit GAMs.
5254            // The Pólya-Gamma augmentation here assumes η = Xβ (no offset); an
5255            // offset model routes to NUTS, which carries the offset through
5256            // `glm.offset` (#882). PG-with-offset is a valid but separate scheme
5257            // we deliberately do not duplicate.
5258            if !glm.firth_bias_reduction
5259                && glm.offset.is_none()
5260                && glm.weights.iter().all(|w| (*w - 1.0).abs() <= 1e-10)
5261            {
5262                run_logit_polya_gamma_gibbs(
5263                    glm.x,
5264                    glm.y,
5265                    glm.weights,
5266                    glm.penalty_matrix,
5267                    glm.mode,
5268                    config,
5269                )
5270            } else {
5271                run_nuts_sampling(
5272                    glm.x,
5273                    glm.y,
5274                    glm.weights,
5275                    glm.penalty_matrix,
5276                    glm.mode,
5277                    glm.hessian,
5278                    NutsFamily::BinomialLogit,
5279                    1.0,
5280                    glm.dispersion,
5281                    glm.firth_bias_reduction,
5282                    glm.offset,
5283                    config,
5284                )
5285            }
5286        }
5287        (
5288            ResponseFamily::Binomial,
5289            InverseLink::Standard(StandardLink::Probit),
5290            FamilyNutsInputs::Glm(glm),
5291        ) => run_nuts_sampling(
5292            glm.x,
5293            glm.y,
5294            glm.weights,
5295            glm.penalty_matrix,
5296            glm.mode,
5297            glm.hessian,
5298            NutsFamily::BinomialProbit,
5299            1.0,
5300            glm.dispersion,
5301            glm.firth_bias_reduction,
5302            glm.offset,
5303            config,
5304        ),
5305        (
5306            ResponseFamily::Binomial,
5307            InverseLink::Standard(StandardLink::CLogLog),
5308            FamilyNutsInputs::Glm(glm),
5309        ) => run_nuts_sampling(
5310            glm.x,
5311            glm.y,
5312            glm.weights,
5313            glm.penalty_matrix,
5314            glm.mode,
5315            glm.hessian,
5316            NutsFamily::BinomialCLogLog,
5317            1.0,
5318            glm.dispersion,
5319            glm.firth_bias_reduction,
5320            glm.offset,
5321            config,
5322        ),
5323        (
5324            ResponseFamily::Binomial,
5325            InverseLink::LatentCLogLog(_),
5326            FamilyNutsInputs::Glm(glm),
5327        ) => run_nuts_sampling(
5328            glm.x,
5329            glm.y,
5330            glm.weights,
5331            glm.penalty_matrix,
5332            glm.mode,
5333            glm.hessian,
5334            NutsFamily::BinomialCLogLog,
5335            1.0,
5336            glm.dispersion,
5337            glm.firth_bias_reduction,
5338            glm.offset,
5339            config,
5340        ),
5341        (ResponseFamily::Binomial, InverseLink::Mixture(_), FamilyNutsInputs::Glm(_)) => Err(
5342            "BinomialMixture NUTS is not implemented yet; use fit_gam/predict_gam for blended inverse-link models"
5343                .to_string(),
5344        ),
5345        (ResponseFamily::Binomial, InverseLink::Sas(_), FamilyNutsInputs::Glm(_)) => Err(
5346            "BinomialSas NUTS is not implemented yet; use fit_gam/predict_gam for SAS-link models"
5347                .to_string(),
5348        ),
5349        (ResponseFamily::Binomial, InverseLink::BetaLogistic(_), FamilyNutsInputs::Glm(_)) => Err(
5350            "BinomialBetaLogistic NUTS is not implemented yet; use fit_gam/predict_gam for beta-logistic-link models"
5351                .to_string(),
5352        ),
5353        (ResponseFamily::Binomial, InverseLink::Standard(_), FamilyNutsInputs::Glm(_)) => Err(
5354            "NUTS sampling is not implemented for this binomial inverse link".to_string(),
5355        ),
5356        (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Survival(survival)) => {
5357            survival_hmc::run_survival_nuts_sampling(
5358                survival.flat.age_entry,
5359                survival.flat.age_exit,
5360                survival.flat.event_target,
5361                survival.flat.event_competing,
5362                survival.flat.weights,
5363                survival.flat.x_entry,
5364                survival.flat.x_exit,
5365                survival.flat.x_derivative,
5366                survival.flat.eta_offset_entry,
5367                survival.flat.eta_offset_exit,
5368                survival.flat.derivative_offset_exit,
5369                survival.penalties,
5370                survival.monotonicity,
5371                survival.spec,
5372                survival.structurally_monotonic,
5373                survival.structural_time_columns,
5374                survival.mode,
5375                survival.hessian,
5376                config,
5377            )
5378        }
5379        (ResponseFamily::RoystonParmar, _, FamilyNutsInputs::Glm(_)) => Err(
5380            "RoystonParmar family requires FamilyNutsInputs::Survival flattened inputs".to_string(),
5381        ),
5382        (_, _, FamilyNutsInputs::Survival(_)) => Err(
5383            "Survival flattened inputs are only valid for the Royston-Parmar response family"
5384                .to_string(),
5385        ),
5386        (ResponseFamily::Poisson, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5387            glm.x,
5388            glm.y,
5389            glm.weights,
5390            glm.penalty_matrix,
5391            glm.mode,
5392            glm.hessian,
5393            NutsFamily::PoissonLog,
5394            1.0,
5395            glm.dispersion,
5396            glm.firth_bias_reduction,
5397            glm.offset,
5398            config,
5399        ),
5400        (ResponseFamily::Tweedie { p }, _, FamilyNutsInputs::Glm(glm)) => {
5401            // Family mapping: Tweedie payload p is passed through the family-parameter slot.
5402            // The Tweedie dispersion phi remains in glm.dispersion, matching REML.
5403            if !is_valid_tweedie_power(p) {
5404                return Err(format!(
5405                    "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
5406                ));
5407            }
5408            run_nuts_sampling(
5409                glm.x,
5410                glm.y,
5411                glm.weights,
5412                glm.penalty_matrix,
5413                glm.mode,
5414                glm.hessian,
5415                NutsFamily::TweedieLog,
5416                p,
5417                glm.dispersion,
5418                glm.firth_bias_reduction,
5419                glm.offset,
5420                config,
5421            )
5422        }
5423        (ResponseFamily::NegativeBinomial { theta, .. }, _, FamilyNutsInputs::Glm(glm)) => {
5424            // Family mapping: NegativeBinomial payload theta is passed through the family slot.
5425            // NB dispersion scale is unit; theta is not derived from fixed_phi.
5426            run_nuts_sampling(
5427                glm.x,
5428                glm.y,
5429                glm.weights,
5430                glm.penalty_matrix,
5431                glm.mode,
5432                glm.hessian,
5433                NutsFamily::NegativeBinomialLog,
5434                theta,
5435                glm.dispersion,
5436                glm.firth_bias_reduction,
5437                glm.offset,
5438                config,
5439            )
5440        }
5441        (ResponseFamily::Beta { .. }, _, FamilyNutsInputs::Glm(_)) => Err(
5442            "NUTS sampling is not implemented for beta-regression logit".to_string(),
5443        ),
5444        (ResponseFamily::Gamma, _, FamilyNutsInputs::Glm(glm)) => run_nuts_sampling(
5445            glm.x,
5446            glm.y,
5447            glm.weights,
5448            glm.penalty_matrix,
5449            glm.mode,
5450            glm.hessian,
5451            NutsFamily::GammaLog,
5452            glm.gamma_shape.unwrap_or(1.0),
5453            glm.dispersion,
5454            glm.firth_bias_reduction,
5455            glm.offset,
5456            config,
5457        ),
5458        (ResponseFamily::Gaussian, _, FamilyNutsInputs::Glm(_)) => Err(
5459            "NUTS sampling is only implemented for Gaussian with identity link".to_string(),
5460        ),
5461    }
5462}
5463
5464// ============================================================================
5465// Joint (β, θ) Link-Wiggle HMC
5466// ============================================================================
5467//
5468// NUTS sampling over the joint parameter space [β_eta; β_wiggle] for models
5469// with a structurally monotone I-spline link wiggle. The wiggle introduces a
5470// nonlinear coupling:
5471//
5472//   η(β_eta, β_wiggle) = q₀(β_eta) + B(q₀(β_eta)) · β_wiggle
5473//
// where B is the shared monotone wiggle basis evaluated at the base linear
// predictor q₀ = X · β_eta after mapping it into the standardized [0, 1]
// coordinate z with the training knot range. The gradient of
// log p(y|β_eta, β_wiggle) w.r.t. β_eta picks up a chain-rule factor
// g'(q₀) = 1 + B'(z) · β_wiggle / range_width (B' is the derivative in z)
// from the dependence of B on q₀.
5478//
5479// Whitening uses the Cholesky of the joint Hessian at the mode, exactly as for
5480// the standard NutsPosterior. C^1 linear extension outside the training knot
5481// range prevents basis evaluation discontinuities.
5482
5483/// Fixed spline artifacts for link-wiggle posterior sampling.
5484#[derive(Clone)]
5485pub struct LinkWiggleSplineArtifacts {
5486    /// Knot range (min, max) from training (in standardized [0,1] space of q₀)
5487    pub knot_range: (f64, f64),
5488    /// Full knot vector for the shared monotone I-spline basis
5489    pub knot_vector: Array1<f64>,
5490    /// I-spline degree
5491    pub degree: usize,
5492}
5493
5494/// Whitened log-posterior target for joint (β_eta, β_wiggle) with analytical gradients.
5495#[derive(Clone)]
5496pub struct LinkWigglePosterior {
5497    /// Main design matrix X (n × p_main)
5498    x: Arc<Array2<f64>>,
5499    y: Arc<Array1<f64>>,
5500    weights: Arc<Array1<f64>>,
5501    /// Penalty for main coefficients (p_main × p_main)
5502    penalty_base: Arc<Array2<f64>>,
5503    /// Penalty for wiggle coefficients (p_wiggle × p_wiggle)
5504    penalty_link: Arc<Array2<f64>>,
5505    mode_beta: Arc<Array1<f64>>,
5506    mode_theta: Arc<Array1<f64>>,
5507    spline: LinkWiggleSplineArtifacts,
5508    /// L where LL^T = H^{-1} (joint Hessian)
5509    chol: Array2<f64>,
5510    /// L^T for gradient chain rule
5511    chol_t: Array2<f64>,
5512    p_base: usize,
5513    p_link: usize,
5514    n_samples: usize,
5515    nuts_family: NutsFamily,
5516    /// Family-specific noise parameter: Gaussian sigma or Gamma shape.
5517    scale: f64,
5518    /// Coefficient-covariance scale `cov_scale` (#679/#680 invariant): the
5519    /// `Vb = cov_scale·H⁻¹` multiplier driving both the whitening
5520    /// (`L Lᵀ = cov_scale·H⁻¹`) and the target penalty weight
5521    /// (`penalty_scale = 1/cov_scale`). `σ²` for profiled Gaussian, `1.0` for
5522    /// every weight-carries-dispersion family (Gamma/Tweedie/NB).
5523    cov_scale: f64,
5524}
5525
5526impl LinkWigglePosterior {
5527    /// Standardize q₀ values to [0,1] range using training knot bounds.
5528    #[inline]
5529    fn standardized_z(&self, u: &Array1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
5530        let (min_u, max_u) = self.spline.knot_range;
5531        let rw = (max_u - min_u).max(1e-6);
5532        let z_raw: Array1<f64> = u.mapv(|v| (v - min_u) / rw);
5533        let z_c: Array1<f64> = z_raw.mapv(|z| z.clamp(0.0, 1.0));
5534        (z_raw, z_c, rw)
5535    }
5536
5537    /// Creates a new link-wiggle posterior target.
5538    pub fn new(
5539        x: ArrayView2<f64>,
5540        y: ArrayView1<f64>,
5541        weights: ArrayView1<f64>,
5542        penalty_base: ArrayView2<f64>,
5543        penalty_link: ArrayView2<f64>,
5544        mode_beta: ArrayView1<f64>,
5545        mode_theta: ArrayView1<f64>,
5546        hessian: ArrayView2<f64>,
5547        spline: LinkWiggleSplineArtifacts,
5548        nuts_family: NutsFamily,
5549        scale: f64,
5550    ) -> Result<Self, String> {
5551        let n_samples = x.nrows();
5552        let p_base = x.ncols();
5553        let p_link = mode_theta.len();
5554        let dim = p_base + p_link;
5555        if hessian.nrows() != dim || hessian.ncols() != dim {
5556            return Err(HmcError::DimensionMismatch {
5557                reason: format!(
5558                    "LinkWigglePosterior: Hessian dim mismatch: {}x{} vs expected {}x{}",
5559                    hessian.nrows(),
5560                    hessian.ncols(),
5561                    dim,
5562                    dim,
5563                ),
5564            }
5565            .into());
5566        }
5567        if nuts_family.likelihood_spec().is_binomial() {
5568            validate_binary_responses("binomial link-wiggle NUTS", &y, &weights)
5569                .map_err(String::from)?;
5570        }
5571        if matches!(nuts_family, NutsFamily::NegativeBinomialLog) {
5572            validate_count_responses("negative-binomial link-wiggle NUTS", &y, &weights)
5573                .map_err(String::from)?;
5574        }
5575        // Whitening metric `L Lᵀ = cov_scale · H⁻¹` (#679/#680 invariant), so
5576        // scale `L` by `√cov_scale`. For the link-wiggle joint target `scale`
5577        // is σ (Gaussian), so the profiled-Gaussian covariance scale is
5578        // `cov_scale = σ²`. Every other family folds its dispersion into the
5579        // working weight / the `shape`/`theta` already inside its
5580        // log-likelihood, so `cov_scale = 1` and this is a no-op. The previous
5581        // Gamma branch scaled `L` by `1/√shape = √φ`, mis-preconditioning the
5582        // sampler against `φ·H⁻¹` instead of the correct `H⁻¹` (#680).
5583        let cov_scale = match nuts_family {
5584            NutsFamily::Gaussian => scale * scale,
5585            _ => 1.0,
5586        };
5587        let whitening = hessian_whitening_transform(
5588            hessian,
5589            dim,
5590            cov_scale,
5591            "LinkWigglePosterior Cholesky failed",
5592        )?;
5593        let chol = whitening.chol;
5594        let chol_t = whitening.chol_t;
5595        Ok(Self {
5596            x: Arc::new(x.to_owned()),
5597            y: Arc::new(y.to_owned()),
5598            weights: Arc::new(weights.to_owned()),
5599            penalty_base: Arc::new(penalty_base.to_owned()),
5600            penalty_link: Arc::new(penalty_link.to_owned()),
5601            mode_beta: Arc::new(mode_beta.to_owned()),
5602            mode_theta: Arc::new(mode_theta.to_owned()),
5603            spline,
5604            chol,
5605            chol_t,
5606            p_base,
5607            p_link,
5608            n_samples,
5609            nuts_family,
5610            scale,
5611            cov_scale,
5612        })
5613    }
5614
5615    /// Evaluate the wiggle basis and compute η = q₀ + B(q₀)·θ with C^1 linear extension.
5616    fn evaluate_link(&self, u: &Array1<f64>, theta: &Array1<f64>) -> (Array2<f64>, Array1<f64>) {
5617        let n = u.len();
5618        if theta.is_empty() {
5619            return (Array2::zeros((n, 0)), u.clone());
5620        }
5621
5622        let (z_raw, z_c, _) = self.standardized_z(u);
5623        let Ok(mut basis) = monotone_wiggle_basis_with_derivative_order(
5624            z_c.view(),
5625            &self.spline.knot_vector,
5626            self.spline.degree,
5627            0,
5628        ) else {
5629            return (Array2::zeros((n, theta.len())), u.clone());
5630        };
5631        if basis.ncols() != theta.len() {
5632            return (Array2::zeros((n, theta.len())), u.clone());
5633        }
5634
5635        // C^1 linear extension outside [0, 1]:
5636        // B_ext(z_raw) = B(z_c) + (z_raw - z_c) * B'(z_c)
5637        let mut needs_ext = false;
5638        for i in 0..n {
5639            if (z_raw[i] - z_c[i]).abs() > 1e-12 {
5640                needs_ext = true;
5641                break;
5642            }
5643        }
5644        if needs_ext
5645            && let Ok(b_prime) = monotone_wiggle_basis_with_derivative_order(
5646                z_c.view(),
5647                &self.spline.knot_vector,
5648                self.spline.degree,
5649                1,
5650            )
5651        {
5652            for i in 0..n {
5653                let dz = z_raw[i] - z_c[i];
5654                if dz.abs() <= 1e-12 {
5655                    continue;
5656                }
5657                for j in 0..basis.ncols().min(b_prime.ncols()) {
5658                    basis[[i, j]] += dz * b_prime[[i, j]];
5659                }
5660            }
5661        }
5662        (
5663            basis.clone(),
5664            u + &gam_linalg::faer_ndarray::fast_av(&basis, theta),
5665        )
5666    }
5667
5668    /// Compute dη/dq₀ = 1 + B'(q₀)·θ / range_width (chain-rule factor for β_eta gradient).
5669    fn compute_g_prime(&self, u: &Array1<f64>, theta: &Array1<f64>) -> Array1<f64> {
5670        let n = u.len();
5671        let mut g = Array1::<f64>::ones(n);
5672        let (_, z_c, rw) = self.standardized_z(u);
5673        if theta.is_empty() {
5674            return g;
5675        }
5676
5677        let Ok(b_prime_constrained) = monotone_wiggle_basis_with_derivative_order(
5678            z_c.view(),
5679            &self.spline.knot_vector,
5680            self.spline.degree,
5681            1,
5682        ) else {
5683            return g;
5684        };
5685        if b_prime_constrained.ncols() != theta.len() {
5686            return g;
5687        }
5688        let dwiggle_dz = gam_linalg::faer_ndarray::fast_av(&b_prime_constrained, theta);
5689        ndarray::Zip::from(&mut g)
5690            .and(&dwiggle_dz)
5691            .par_for_each(|gi, &dw| *gi = 1.0 + dw / rw);
5692        g
5693    }
5694
5695    fn compute_logp_and_grad_into(&self, z: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5696        let dim = self.p_base + self.p_link;
5697
5698        // Un-whiten: q = mode + L·z
5699        let mut mode = Array1::<f64>::zeros(dim);
5700        mode.slice_mut(ndarray::s![0..self.p_base])
5701            .assign(&self.mode_beta);
5702        mode.slice_mut(ndarray::s![self.p_base..])
5703            .assign(&self.mode_theta);
5704        let q = &mode + &self.chol.dot(z);
5705        let beta = q.slice(ndarray::s![0..self.p_base]).to_owned();
5706        let theta = q.slice(ndarray::s![self.p_base..]).to_owned();
5707
5708        // Compute η = q₀ + B(q₀)·θ where q₀ = X·β
5709        let u = gam_linalg::faer_ndarray::fast_av(self.x.as_ref(), &beta);
5710        let (bwiggle, eta) = self.evaluate_link(&u, &theta);
5711
5712        // Log-likelihood and residuals via family dispatch
5713        let ll;
5714        let mut residual = Array1::<f64>::zeros(self.n_samples);
5715        match self.nuts_family {
5716            NutsFamily::Gaussian => {
5717                let inv_scale_sq = 1.0 / (self.scale * self.scale).max(1e-10);
5718                let mut ll_acc = 0.0;
5719                for i in 0..self.n_samples {
5720                    let r = self.y[i] - eta[i];
5721                    let w = self.weights[i];
5722                    ll_acc -= 0.5 * w * r * r * inv_scale_sq;
5723                    residual[i] = w * r * inv_scale_sq;
5724                }
5725                ll = ll_acc;
5726            }
5727            NutsFamily::BinomialLogit => {
5728                let mut ll_acc = 0.0;
5729                for i in 0..self.n_samples {
5730                    let eta_i = eta[i];
5731                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5732                    ll_acc += w_i * (y_i * eta_i - gam_linalg::utils::stable_softplus(eta_i));
5733                    let mu = gam_linalg::utils::stable_logistic(eta_i);
5734                    residual[i] = w_i * (y_i - mu);
5735                }
5736                ll = ll_acc;
5737            }
5738            NutsFamily::BinomialProbit => {
5739                let mut ll_acc = 0.0;
5740                for i in 0..self.n_samples {
5741                    let eta_i = eta[i];
5742                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5743                    let log_phi_pos = log_ndtr(eta_i);
5744                    let log_phi_neg = log_ndtr(-eta_i);
5745                    ll_acc += w_i * (y_i * log_phi_pos + (1.0 - y_i) * log_phi_neg);
5746                    let log_phi = standard_normal_log_pdf(eta_i);
5747                    let ratio_pos = (log_phi - log_phi_pos).exp();
5748                    let ratio_neg = (log_phi - log_phi_neg).exp();
5749                    residual[i] = w_i * (y_i * ratio_pos - (1.0 - y_i) * ratio_neg);
5750                }
5751                ll = ll_acc;
5752            }
5753            NutsFamily::BinomialCLogLog => {
5754                let mut ll_acc = 0.0;
5755                for i in 0..self.n_samples {
5756                    let eta_i = eta[i];
5757                    if !(eta_i.is_finite() && (-700.0..=700.0).contains(&eta_i)) {
5758                        grad.fill(0.0);
5759                        return f64::NEG_INFINITY;
5760                    }
5761                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5762                    let (ll_i, residual_i) = match cloglog_bernoulli_logp_and_residual(eta_i, y_i) {
5763                        Ok(values) => values,
5764                        Err(_) => {
5765                            grad.fill(0.0);
5766                            return f64::NEG_INFINITY;
5767                        }
5768                    };
5769                    ll_acc += w_i * ll_i;
5770                    residual[i] = w_i * residual_i;
5771                }
5772                ll = ll_acc;
5773            }
5774            NutsFamily::PoissonLog => {
5775                let mut ll_acc = 0.0;
5776                for i in 0..self.n_samples {
5777                    let eta_i = eta[i];
5778                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5779                        grad.fill(0.0);
5780                        return f64::NEG_INFINITY;
5781                    }
5782                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5783                    let mu = eta_i.exp();
5784                    ll_acc += w_i * (y_i * eta_i - mu);
5785                    residual[i] = w_i * (y_i - mu);
5786                }
5787                ll = ll_acc;
5788            }
5789            NutsFamily::TweedieLog => {
5790                let mut ll_acc = 0.0;
5791                // Family mapping: Tweedie scale carries payload p; phi is not stored here.
5792                // Invalid p makes the link-wiggle target invalid instead of defaulting.
5793                if !is_valid_tweedie_power(self.scale) {
5794                    grad.fill(0.0);
5795                    return f64::NEG_INFINITY;
5796                }
5797                let p = self.scale;
5798                for i in 0..self.n_samples {
5799                    let eta_i = eta[i];
5800                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5801                        grad.fill(0.0);
5802                        return f64::NEG_INFINITY;
5803                    }
5804                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5805                    let mu = eta_i.exp().max(1e-300);
5806                    ll_acc +=
5807                        w_i * (y_i * mu.powf(1.0 - p) / (1.0 - p) - mu.powf(2.0 - p) / (2.0 - p));
5808                    residual[i] = w_i * (y_i - mu) * mu.powf(1.0 - p);
5809                }
5810                ll = ll_acc;
5811            }
5812            NutsFamily::NegativeBinomialLog => {
5813                let mut ll_acc = 0.0;
5814                // Family mapping: NegativeBinomial scale carries payload theta.
5815                // Invalid theta makes the link-wiggle target invalid instead of clamping.
5816                if !(self.scale.is_finite() && self.scale > 0.0) {
5817                    grad.fill(0.0);
5818                    return f64::NEG_INFINITY;
5819                }
5820                let theta = self.scale;
5821                for i in 0..self.n_samples {
5822                    let eta_i = eta[i];
5823                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5824                        grad.fill(0.0);
5825                        return f64::NEG_INFINITY;
5826                    }
5827                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5828                    if w_i <= 0.0 {
5829                        residual[i] = 0.0;
5830                        continue;
5831                    }
5832                    let mu = eta_i.exp().max(1e-12);
5833                    let log_mu_term = if y_i > 0.0 { y_i * mu.ln() } else { 0.0 };
5834                    ll_acc += w_i
5835                        * (statrs::function::gamma::ln_gamma(y_i + theta)
5836                            - statrs::function::gamma::ln_gamma(theta)
5837                            - statrs::function::gamma::ln_gamma(y_i + 1.0)
5838                            + theta * (theta.ln() - (theta + mu).ln())
5839                            + log_mu_term
5840                            - y_i * (theta + mu).ln());
5841                    residual[i] = w_i * theta * (y_i - mu) / (theta + mu);
5842                }
5843                ll = ll_acc;
5844            }
5845            NutsFamily::GammaLog => {
5846                let mut ll_acc = 0.0;
5847                let shape = self.scale.max(1e-10);
5848                for i in 0..self.n_samples {
5849                    let eta_i = eta[i];
5850                    if !(eta_i.is_finite() && (-30.0..=30.0).contains(&eta_i)) {
5851                        grad.fill(0.0);
5852                        return f64::NEG_INFINITY;
5853                    }
5854                    let (y_i, w_i) = (self.y[i], self.weights[i]);
5855                    let mu = eta_i.exp();
5856                    ll_acc += w_i * shape * (-y_i / mu - eta_i);
5857                    residual[i] = w_i * shape * (y_i / mu - 1.0);
5858                }
5859                ll = ll_acc;
5860            }
5861        }
5862
5863        // Penalty weight = 1/cov_scale (#679/#680 invariant), matching the
5864        // factor the likelihood already carries so the prior and likelihood
5865        // live on the same scale and the MAP-anchored target curvature equals
5866        // `Vb⁻¹ = H/cov_scale`. The Gaussian block above multiplies through by
5867        // `1/σ²`, so `penalty_scale = 1/σ²`. The Gamma block carries an explicit
5868        // `shape = 1/φ` factor in its score (`w_i·shape·(y/μ − 1)`) — that is
5869        // the *data* Fisher information, already folded into the working
5870        // weight, so the penalty must stay UNSCALED (`cov_scale = 1`,
5871        // `penalty_scale = 1`). The previous code used `penalty_scale = shape`
5872        // for Gamma, double-counting the dispersion in the sampled posterior
5873        // and shrinking every posterior SD by `√φ` (#680). Tweedie/NB/Poisson/
5874        // Binomial are unit-scale and unchanged.
5875        let penalty_scale = 1.0 / self.cov_scale.max(1e-300);
5876
5877        // Gradient w.r.t. θ (wiggle): ∂ℓ/∂θ = B(q₀)^T · residual − S_link · θ
5878        let s_link_theta = self.penalty_link.dot(&theta);
5879        let grad_theta = &fast_atv(&bwiggle, &residual) - &(&s_link_theta * penalty_scale);
5880
5881        // Gradient w.r.t. β_eta: ∂ℓ/∂β = X^T · (residual ⊙ g'(q₀)) − S_base · β
5882        // where g'(q₀) = dη/dq₀ is the chain-rule factor
5883        let g_prime = self.compute_g_prime(&u, &theta);
5884        let r_scaled: Array1<f64> = residual
5885            .iter()
5886            .zip(g_prime.iter())
5887            .map(|(&r, &g)| r * g)
5888            .collect();
5889        let s_base_beta = self.penalty_base.dot(&beta);
5890        let grad_beta = &fast_atv(&self.x, &r_scaled) - &(&s_base_beta * penalty_scale);
5891
5892        // Penalty (also φ-scaled for Gaussian; see `penalty_scale` above).
5893        let penalty =
5894            penalty_scale * (0.5 * beta.dot(&s_base_beta) + 0.5 * theta.dot(&s_link_theta));
5895
5896        // Assemble joint gradient and transform to whitened space
5897        let mut grad_q = Array1::<f64>::zeros(dim);
5898        grad_q
5899            .slice_mut(ndarray::s![0..self.p_base])
5900            .assign(&grad_beta);
5901        grad_q
5902            .slice_mut(ndarray::s![self.p_base..])
5903            .assign(&grad_theta);
5904        fast_av_into(&self.chol_t, &grad_q, grad);
5905        ll - penalty
5906    }
5907
5908    /// Get the Cholesky factor L for un-whitening samples.
5909    pub fn chol(&self) -> &Array2<f64> {
5910        &self.chol
5911    }
5912
5913    /// Get the mode [β_eta; β_wiggle].
5914    pub fn mode_joint(&self) -> Array1<f64> {
5915        let dim = self.p_base + self.p_link;
5916        let mut mode = Array1::<f64>::zeros(dim);
5917        mode.slice_mut(ndarray::s![0..self.p_base])
5918            .assign(&self.mode_beta);
5919        mode.slice_mut(ndarray::s![self.p_base..])
5920            .assign(&self.mode_theta);
5921        mode
5922    }
5923}
5924
5925impl HamiltonianTarget<Array1<f64>> for LinkWigglePosterior {
5926    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
5927        self.compute_logp_and_grad_into(position, grad)
5928    }
5929}
5930
5931/// Runs NUTS sampling for joint (β_eta, β_wiggle) in a link-wiggle model.
5932pub fn run_link_wiggle_nuts_sampling(
5933    x: ArrayView2<f64>,
5934    y: ArrayView1<f64>,
5935    weights: ArrayView1<f64>,
5936    penalty_base: ArrayView2<f64>,
5937    penalty_link: ArrayView2<f64>,
5938    mode_beta: ArrayView1<f64>,
5939    mode_theta: ArrayView1<f64>,
5940    hessian: ArrayView2<f64>,
5941    spline: LinkWiggleSplineArtifacts,
5942    nuts_family: NutsFamily,
5943    scale: f64,
5944    config: &NutsConfig,
5945) -> Result<NutsResult, String> {
5946    validate_nuts_config(config).map_err(String::from)?;
5947    let dim = mode_beta.len() + mode_theta.len();
5948    let target = LinkWigglePosterior::new(
5949        x,
5950        y,
5951        weights,
5952        penalty_base,
5953        penalty_link,
5954        mode_beta,
5955        mode_theta,
5956        hessian,
5957        spline,
5958        nuts_family,
5959        scale,
5960    )?;
5961    let chol = target.chol().clone();
5962    let mode_arr = target.mode_joint();
5963
5964    let initial_positions = jittered_initial_positions(config, dim, 0.1, 0x8C48_0F65_3A2B_D917);
5965
5966    let mass_cfg = robust_mass_matrix_config(dim, config.nwarmup);
5967    let (result, run_stats) = run_whitened_nuts_result(
5968        target,
5969        &mode_arr,
5970        &chol,
5971        initial_positions,
5972        config,
5973        dim,
5974        mass_cfg,
5975        0x2E31_A4B6_C908_F57D,
5976        "Link-wiggle NUTS sampling failed",
5977        Array1::zeros(dim),
5978        NutsConvergenceThresholds {
5979            max_rhat: 1.1,
5980            min_ess: Some(100.0),
5981        },
5982    )?;
5983    log::info!("Link-wiggle NUTS sampling complete: {}", run_stats);
5984
5985    Ok(result)
5986}
5987
5988// ============================================================================
5989// Joint (β, ρ) HMC for Skewed Posteriors
5990// ============================================================================
5991//
5992// When the Laplace approximation to the marginal likelihood is unreliable
5993// (high posterior skewness), we bypass LAML entirely and sample from the
5994// joint posterior p(β, ρ | y) ∝ p(y|β) p(β|ρ) p(ρ).
5995//
5996// The joint log-posterior is:
5997//   log p(β, ρ | y) = ℓ(y|β) + Φ(β) [if Firth]
5998//                    - 0.5 β'S(ρ)β + 0.5 log|S(ρ)|_+ + log p(ρ) + const
5999//
6000// Gradients:
6001//   ∇_β: ∇_β ℓ + ∇_β Φ(β) [if Firth] - S(ρ) β
6002//   ∂/∂ρ_k: -0.5 λ_k β'S_k β + 0.5 tr(S_+⁻¹ A_k) + ∂log p(ρ)/∂ρ_k
6003//
6004// This completely avoids the Laplace approximation. When Firth bias reduction
6005// is active, the sampled target also includes the Jeffreys term Φ(β) in
6006// addition to the smoothing-parameter prior.
6007
6008/// Directional cubic non-Gaussianity diagnostic for the Laplace approximation.
6009///
6010/// For each positive-curvature Hessian eigenpair `(lambda_r, v_r)`, this computes
6011///
6012///   gamma_r = T[v_r, v_r, v_r] / lambda_r^(3/2)
6013///            = Σ_i c_i (x_i^T v_r)^3 / lambda_r^(3/2),
6014///
6015/// and reports `max_r |gamma_r|`. This is invariant to arbitrary coordinate
6016/// relabeling and uses the full directional cubic contraction rather than only
6017/// diagonal tensor entries.
6018/// `refine_supremum` controls Phase 2, the cubic power-iteration that sharpens
6019/// the returned scalar `max_abs` toward the true supremum of `|γ(u)|` over the
6020/// H-unit sphere (which can exceed the per-eigenvector maximum). That scalar is
6021/// the ONLY thing Phase 2 affects — the per-direction `directional` vector,
6022/// which drives [`laplace_trustworthiness_from_skewness`]'s direction selection
6023/// AND its own internally-recomputed `max_abs_skewness`, comes entirely from
6024/// Phase 1. The #784 block-local REML correction
6025/// (`block_local_sampled_correction`) consumes `directional` and uses `max_abs`
6026/// only for a `> 0` finiteness guard that Phase 1 already satisfies, so it
6027/// passes `false` and skips Phase 2's multi-probe O(probes·iters·np) refinement
6028/// on every inner evaluation. Diagnostic callers that report the true supremum
6029/// pass `true`.
6030pub fn laplace_directional_cubic_diagnostic(
6031    hessian: &Array2<f64>,
6032    design: &DesignMatrix,
6033    c_weights: &Array1<f64>,
6034    refine_supremum: bool,
6035) -> Result<(f64, Array1<f64>), String> {
6036    let p = hessian.nrows();
6037    if p == 0 || hessian.ncols() != p {
6038        return Ok((0.0, Array1::zeros(0)));
6039    }
6040
6041    let sym_h = (hessian + &hessian.t()) * 0.5;
6042    let (evals, evecs) = sym_h
6043        .eigh(Side::Lower)
6044        .map_err(|e| format!("directional cubic diagnostic eigendecomposition failed: {e}"))?;
6045    let max_eval = evals.iter().fold(0.0_f64, |acc, &ev| acc.max(ev.abs()));
6046    let tol = (max_eval * 1.0e-12).max(1.0e-14);
6047    let mut directional = Array1::<f64>::zeros(p);
6048    let mut max_abs = 0.0_f64;
6049
6050    // Build the whitening transform L^{-1} where H = L L^T, so that
6051    // the standardized cubic along whitened direction u is:
6052    //   gamma(u) = T[L^{-T}u, L^{-T}u, L^{-T}u]  for ||u||=1
6053    // Eigenvector directions v_r satisfy u_r = lambda_r^{1/2} v_r (after
6054    // appropriate normalization), so gamma_r = T[v_r,v_r,v_r] / lambda_r^{3/2}.
6055
6056    // Phase 1: evaluate gamma_r for all positive-curvature eigenvectors.
6057    for r in 0..p {
6058        let lambda = evals[r];
6059        if lambda <= tol {
6060            continue;
6061        }
6062        let v = evecs.column(r);
6063        let gamma = directional_cubic_contraction(design, c_weights, &v) / lambda.powf(1.5);
6064        directional[r] = if gamma.is_finite() { gamma } else { 0.0 };
6065        max_abs = max_abs.max(directional[r].abs());
6066    }
6067
6068    // Phase 2: power-iteration refinement in whitened space.
6069    //
6070    // The supremum of |gamma(u)| over ||u||_H=1 can exceed the max over
6071    // eigenvectors. We approximate it with a few rounds of cubic power
6072    // iteration: given current direction v, the gradient of T[v,v,v] w.r.t.
6073    // v on the H-unit sphere is 3 T[·,v,v] projected onto the tangent space.
6074    // Since T[·,v,v] = X^T diag(c_i (x_i^T v)^2) which is a matrix-vector
6075    // product, each iteration is O(np).
6076    //
6077    // We seed from the eigenvector with largest |gamma_r| and also from a
6078    // few random probe directions.
6079    if refine_supremum && p >= 2 {
6080        // Build H^{-1/2} columns for whitening: H^{-1/2} = V diag(1/sqrt(lam)) V^T
6081        // We need it to map whitened u -> original v = H^{-1/2} u, and
6082        // H^{1/2} to project back: H^{1/2} v = V diag(sqrt(lam)) V^T v.
6083        let positive_mask: Vec<bool> = evals.iter().map(|&ev| ev > tol).collect();
6084        let n_pos = positive_mask.iter().filter(|&&m| m).count();
6085        if n_pos >= 2 {
6086            let max_abs_from_probes = cubic_power_iteration_refinement(
6087                design,
6088                c_weights,
6089                &evals,
6090                &evecs,
6091                &positive_mask,
6092                n_pos,
6093            );
6094            if max_abs_from_probes > max_abs {
6095                max_abs = max_abs_from_probes;
6096            }
6097        }
6098    }
6099
6100    Ok((max_abs, directional))
6101}
6102
6103/// Compute T[v,v,v] = Σ_i c_i (x_i^T v)^3 for a given direction v.
6104fn directional_cubic_contraction(
6105    design: &DesignMatrix,
6106    c_weights: &Array1<f64>,
6107    v: &ArrayView1<f64>,
6108) -> f64 {
6109    match design.as_sparse() {
6110        Some(x_sparse) => {
6111            let (symbolic, values) = x_sparse.as_ref().parts();
6112            let col_ptr = symbolic.col_ptr();
6113            let row_idx = symbolic.row_idx();
6114            let mut row_scores = vec![0.0_f64; x_sparse.nrows()];
6115            for col in 0..x_sparse.ncols() {
6116                let coeff = v[col];
6117                for ptr in col_ptr[col]..col_ptr[col + 1] {
6118                    row_scores[row_idx[ptr]] += values[ptr] * coeff;
6119                }
6120            }
6121            let mut cubic = 0.0_f64;
6122            for i in 0..row_scores.len().min(c_weights.len()) {
6123                cubic += c_weights[i] * row_scores[i].powi(3);
6124            }
6125            cubic
6126        }
6127        None => {
6128            let x_dense = design.to_dense_cow();
6129            let x_dense = x_dense.as_ref();
6130            let mut cubic = 0.0_f64;
6131            for i in 0..x_dense.nrows().min(c_weights.len()) {
6132                let proj = x_dense.row(i).dot(v);
6133                cubic += c_weights[i] * proj.powi(3);
6134            }
6135            cubic
6136        }
6137    }
6138}
6139
6140/// Compute the gradient of T[v,v,v] w.r.t. v:  3 X^T diag(c_i (x_i^T v)^2) 1.
6141/// More precisely: ∂/∂v T[v,v,v] = 3 Σ_i c_i (x_i^T v)^2 x_i.
6142fn directional_cubic_gradient(
6143    design: &DesignMatrix,
6144    c_weights: &Array1<f64>,
6145    v: &Array1<f64>,
6146) -> Array1<f64> {
6147    let p = v.len();
6148    match design.as_sparse() {
6149        Some(x_sparse) => {
6150            let (symbolic, values) = x_sparse.as_ref().parts();
6151            let col_ptr = symbolic.col_ptr();
6152            let row_idx = symbolic.row_idx();
6153            let n = x_sparse.nrows();
6154            let mut row_scores = vec![0.0_f64; n];
6155            for col in 0..x_sparse.ncols() {
6156                let coeff = v[col];
6157                for ptr in col_ptr[col]..col_ptr[col + 1] {
6158                    row_scores[row_idx[ptr]] += values[ptr] * coeff;
6159                }
6160            }
6161            // quadratic weights: 3 c_i (x_i^T v)^2
6162            let mut quad_weights = vec![0.0_f64; n];
6163            for i in 0..n.min(c_weights.len()) {
6164                quad_weights[i] = 3.0 * c_weights[i] * row_scores[i] * row_scores[i];
6165            }
6166            // X^T quad_weights
6167            let mut grad = Array1::<f64>::zeros(p);
6168            for col in 0..x_sparse.ncols() {
6169                let mut acc = 0.0_f64;
6170                for ptr in col_ptr[col]..col_ptr[col + 1] {
6171                    acc += values[ptr] * quad_weights[row_idx[ptr]];
6172                }
6173                grad[col] = acc;
6174            }
6175            grad
6176        }
6177        None => {
6178            let x_dense = design.to_dense_cow();
6179            let x_dense = x_dense.as_ref();
6180            let n = x_dense.nrows();
6181            let mut grad = Array1::<f64>::zeros(p);
6182            for i in 0..n.min(c_weights.len()) {
6183                let proj = x_dense.row(i).dot(v);
6184                let w = 3.0 * c_weights[i] * proj * proj;
6185                // scaled_add works with any ArrayBase reference.
6186                let row = x_dense.row(i);
6187                for j in 0..p {
6188                    grad[j] += w * row[j];
6189                }
6190            }
6191            grad
6192        }
6193    }
6194}
6195
6196/// Power-iteration refinement for the supremum of |gamma(u)| over ||u||_H = 1.
6197///
6198/// Seeds from the best eigenvector direction plus deterministic probe
6199/// directions constructed from pairs of eigenvectors. Runs a few Riemannian
6200/// gradient ascent steps on the whitened unit sphere.
6201fn cubic_power_iteration_refinement(
6202    design: &DesignMatrix,
6203    c_weights: &Array1<f64>,
6204    evals: &Array1<f64>,
6205    evecs: &Array2<f64>,
6206    positive_mask: &[bool],
6207    n_pos: usize,
6208) -> f64 {
6209    let p = evals.len();
6210    let max_probes = 8;
6211    let max_iters = 5;
6212
6213    // Helper: convert whitened u -> original v = Σ_r (u_r / sqrt(lam_r)) * evec_r
6214    // (only over positive eigenspace).
6215    let to_original = |u: &Array1<f64>| -> Array1<f64> {
6216        let mut v = Array1::<f64>::zeros(p);
6217        let mut idx = 0;
6218        for r in 0..p {
6219            if positive_mask[r] {
6220                let scale = u[idx] / evals[r].sqrt();
6221                let col = evecs.column(r);
6222                for j in 0..p {
6223                    v[j] += scale * col[j];
6224                }
6225                idx += 1;
6226            }
6227        }
6228        v
6229    };
6230
6231    // Helper: project original-space vector to whitened: u_j = sqrt(lam_r) (evec_r^T g)
6232    let to_whitened = |g: &Array1<f64>| -> Array1<f64> {
6233        let mut u = Array1::<f64>::zeros(n_pos);
6234        let mut idx = 0;
6235        for r in 0..p {
6236            if positive_mask[r] {
6237                u[idx] = evals[r].sqrt() * evecs.column(r).dot(g);
6238                idx += 1;
6239            }
6240        }
6241        u
6242    };
6243
6244    // Evaluate |gamma(u)| for whitened direction u.
6245    let eval_gamma = |u: &Array1<f64>| -> f64 {
6246        let norm = u.dot(u).sqrt();
6247        if norm < 1e-30 {
6248            return 0.0;
6249        }
6250        let u_normed: Array1<f64> = u / norm;
6251        let v = to_original(&u_normed);
6252        // gamma = T[v,v,v] since v already has ||v||_H = 1
6253        let cubic = directional_cubic_contraction(design, c_weights, &v.view());
6254        if cubic.is_finite() { cubic.abs() } else { 0.0 }
6255    };
6256
6257    // One step of Riemannian gradient ascent on the whitened sphere for |T[v,v,v]|.
6258    let refine_step = |u: &Array1<f64>| -> Array1<f64> {
6259        let norm = u.dot(u).sqrt();
6260        if norm < 1e-30 {
6261            return u.clone();
6262        }
6263        let u_normed: Array1<f64> = u / norm;
6264        let v = to_original(&u_normed);
6265        // Gradient of T[v,v,v] w.r.t. v in original space
6266        let grad_v = directional_cubic_gradient(design, c_weights, &v);
6267        // Map to whitened space
6268        let mut grad_u = to_whitened(&grad_v);
6269        // Project onto tangent plane of sphere: grad - (grad . u) u
6270        let dot = grad_u.dot(&u_normed);
6271        grad_u.scaled_add(-dot, &u_normed);
6272        // Sign: we want to maximize |T|, so follow sign(T) * grad
6273        let cubic_val = directional_cubic_contraction(design, c_weights, &v.view());
6274        let sign = if cubic_val >= 0.0 { 1.0 } else { -1.0 };
6275        let step_size = 0.3;
6276        let mut u_new = &u_normed + &(&grad_u * (sign * step_size));
6277        let new_norm = u_new.dot(&u_new).sqrt();
6278        if new_norm > 1e-30 {
6279            u_new /= new_norm;
6280        }
6281        u_new
6282    };
6283
6284    let mut best = 0.0_f64;
6285
6286    // Build seed directions:
6287    // (a) The eigenvector with largest |gamma_r| (already computed by caller,
6288    //     but we re-derive the whitened form here).
6289    // (b) Deterministic probe directions from pairs of top eigenvectors:
6290    //     (e_i + e_j) / sqrt(2) and (e_i - e_j) / sqrt(2) in whitened space.
6291    let mut seeds: Vec<Array1<f64>> = Vec::with_capacity(max_probes);
6292
6293    // Seed (a): each eigenvector is a standard basis vector in whitened space.
6294    // Find the one with largest |gamma|.
6295    let mut best_eig_idx = 0;
6296    let mut best_eig_gamma = 0.0_f64;
6297    for j in 0..n_pos {
6298        let mut u = Array1::<f64>::zeros(n_pos);
6299        u[j] = 1.0;
6300        let g = eval_gamma(&u);
6301        if g > best_eig_gamma {
6302            best_eig_gamma = g;
6303            best_eig_idx = j;
6304        }
6305    }
6306    best = best.max(best_eig_gamma);
6307    let mut u_best = Array1::<f64>::zeros(n_pos);
6308    u_best[best_eig_idx] = 1.0;
6309    seeds.push(u_best);
6310
6311    // Seed (b): pairwise combinations of the top few eigenvectors.
6312    let n_top = n_pos.min(4);
6313    for i in 0..n_top {
6314        for j in (i + 1)..n_top {
6315            if seeds.len() >= max_probes {
6316                break;
6317            }
6318            let inv_sqrt2 = std::f64::consts::FRAC_1_SQRT_2;
6319            let mut u_plus = Array1::<f64>::zeros(n_pos);
6320            u_plus[i] = inv_sqrt2;
6321            u_plus[j] = inv_sqrt2;
6322            seeds.push(u_plus);
6323            if seeds.len() < max_probes {
6324                let mut u_minus = Array1::<f64>::zeros(n_pos);
6325                u_minus[i] = inv_sqrt2;
6326                u_minus[j] = -inv_sqrt2;
6327                seeds.push(u_minus);
6328            }
6329        }
6330    }
6331
6332    // Run power iteration from each seed.
6333    for seed in &seeds {
6334        let mut u = seed.clone();
6335        for _ in 0..max_iters {
6336            u = refine_step(&u);
6337        }
6338        let g = eval_gamma(&u);
6339        best = best.max(g);
6340    }
6341
6342    best
6343}
6344
6345// ───────────────── #1521 laplace-sampler contract re-exports ─────────────────
6346//
6347// The neutral DATA carriers + the caller-supplied [`BlockExcessTarget`]
6348// evaluator + the pure threshold math were contract-downed to the neutral
6349// `gam-problem` crate (#1521) so gam-solve (whose `Gam784BlockTarget`
6350// IMPLEMENTS `BlockExcessTarget`) and this gam-inference-tier sampler share one
6351// set of types without an SCC edge. The COMPUTATION (NUTS, importance sampling,
6352// the directional-cubic eigen diagnostic) stays UP in this module and
6353// constructs these types under their original names via this re-export.
6354pub use gam_problem::laplace_sampler_contract::{
6355    BlockExcessTarget, BlockSampledMarginal, BlockSampledMoments, GaussianModePosterior,
6356    LaplaceTrustworthiness, laplace_skewness_threshold, laplace_trustworthiness_from_skewness,
6357};
6358
6359/// Monolith (gam-inference-tier) implementor of the contract-downed
6360/// [`LaplaceMarginalSampler`](gam_problem::laplace_sampler_contract::LaplaceMarginalSampler):
6361/// wraps the `hmc_io` directional-cubic eigen diagnostic and the
6362/// importance-sampled #784 block correction. Registered at process init via
6363/// `gam_problem::laplace_sampler_contract::set_laplace_marginal_sampler`.
6364pub struct HmcIoLaplaceMarginalSampler;
6365
6366impl gam_problem::laplace_sampler_contract::LaplaceMarginalSampler for HmcIoLaplaceMarginalSampler {
6367    fn directional_cubic_diagnostic(
6368        &self,
6369        hessian: &Array2<f64>,
6370        design: &DesignMatrix,
6371        c_weights: &Array1<f64>,
6372        refine_supremum: bool,
6373    ) -> Result<(f64, Array1<f64>), String> {
6374        laplace_directional_cubic_diagnostic(hessian, design, c_weights, refine_supremum)
6375    }
6376
6377    fn block_sampled_marginal_correction(
6378        &self,
6379        target: &dyn BlockExcessTarget,
6380    ) -> Result<BlockSampledMarginal, String> {
6381        block_sampled_marginal_correction(target)
6382    }
6383}
6384
6385/// Monolith (gam-inference-tier) implementor of the contract-downed
6386/// [`GaussianModePosteriorSampler`](gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler):
6387/// the never-fail Gaussian mode-posterior rung. Builds the NUTS config from the
6388/// problem dimension internally (so `NutsConfig` never crosses the contract)
6389/// and wraps `hmc_io::sample_gaussian_mode_posterior`. Registered at process
6390/// init via `gam_problem::laplace_sampler_contract::set_gaussian_mode_posterior_sampler`.
6391pub struct HmcIoGaussianModePosteriorSampler;
6392
6393impl gam_problem::laplace_sampler_contract::GaussianModePosteriorSampler
6394    for HmcIoGaussianModePosteriorSampler
6395{
6396    fn sample_gaussian_mode_posterior(
6397        &self,
6398        mode: ArrayView1<f64>,
6399        precision: ArrayView2<f64>,
6400    ) -> Result<GaussianModePosterior, String> {
6401        let config = NutsConfig::for_dimension(mode.len());
6402        sample_gaussian_mode_posterior(mode, precision, &config)
6403    }
6404}
6405
/// Auto-derive the number of importance draws for the block-local sampled
/// marginalization from the block dimension.
///
/// MAGIC: more directions need more draws to control the importance-weight
/// variance. The block is small by construction (only the curvature-heavy
/// directions), so this stays cheap. There is no CLI flag.
///
/// Returns `BASE + PER_DIM · block_dim`, capped at `CAP`. The arithmetic
/// saturates, so even an absurd `block_dim` yields exactly `CAP`. Plain
/// arithmetic would overflow instead: a panic in debug builds, or in release
/// builds a wrapped-around, uncapped-looking tiny budget.
fn block_sampling_draws(block_dim: usize) -> usize {
    // Base budget plus a per-direction allowance; capped so a pathological
    // block can never make a single inner evaluation explode.
    const BASE: usize = 256;
    const PER_DIM: usize = 256;
    const CAP: usize = 4096;
    PER_DIM
        .saturating_mul(block_dim)
        .saturating_add(BASE)
        .min(CAP)
}
6419
/// Estimate the block-local sampled marginal correction `Δ_b` and its
/// ρ-gradient by importance sampling against the local Laplace Gaussian
/// (issue #784).
///
/// # Math
///
/// Draw `t_s ~ q = N(0, diag(1/λ_r))` (the local Laplace Gaussian in the block
/// subspace; whitened draws `z_s ~ N(0, I)` give `t_{s,r} = z_{s,r}/√λ_r`).
/// With the non-Gaussian remainder `ΔF` defined on [`BlockExcessTarget`],
///
///   exp(Δ_b) = E_q[ exp(−ΔF(t)) ]  ⇒  Δ_b = log mean_s exp(−ΔF(t_s)),
///
/// computed via a numerically-stable log-mean-exp.  The ρ-gradient follows
/// from differentiating `Δ_b = log E_q[e^{−ΔF}]` (the `q`-Gaussian normalizer
/// `½Σ log(2π/λ_r)` cancels against `A_Lap`, leaving only the `ΔF` channel):
///
///   ∂Δ_b/∂ρ_k = E_p[ −∂ΔF/∂ρ_k ],   p ∝ q·e^{−ΔF},
///
/// i.e. the self-normalized importance-weighted average of `−∂ΔF/∂ρ_k` over the
/// same draws.  Because value and gradient come from one set of draws and one
/// target, they are mutually consistent — the contract the outer REML needs.
///
/// Determinism: draws come from a fixed-seed RNG so the inner evaluation is a
/// pure function of `(β̂, H, ρ)` and the outer optimizer sees a smooth,
/// reproducible objective rather than Monte-Carlo jitter across evaluations.
///
/// # Returns
///
/// The log-mean-exp value, the self-normalized ρ-gradient, the Kish effective
/// sample size of the importance weights, the number of draws used, and (when
/// any draw carried weight) the gradient-channel moments. A zero-dimensional
/// block short-circuits to a zero correction with `n_draws = 0` and no moments.
///
/// # Errors
///
/// Returns `Err` when:
/// - `block_curvatures()` does not have `block_dim()` entries;
/// - any block curvature is non-positive or NaN;
/// - every draw is infeasible;
/// - a displaced negative score has a length other than
///   `base_neg_score().len()`;
/// - the resulting value, gradient, or moments are non-finite.
pub fn block_sampled_marginal_correction<T: BlockExcessTarget + ?Sized>(
    target: &T,
) -> Result<BlockSampledMarginal, String> {
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    let m = target.block_dim();
    let k = target.rho_dim();
    // Empty block: nothing to integrate, so the correction and its gradient are zero.
    if m == 0 {
        return Ok(BlockSampledMarginal {
            value: 0.0,
            rho_gradient: Array1::zeros(k),
            importance_ess: 0.0,
            n_draws: 0,
            moments: None,
        });
    }
    let lambdas = target.block_curvatures();
    if lambdas.len() != m {
        return Err(format!(
            "block_sampled_marginal_correction: block_curvatures len {} != block_dim {m}",
            lambdas.len()
        ));
    }
    // Per-direction standard deviation 1/√λ_r of the proposal q. NaN marks a
    // rejected direction (a NaN curvature also fails `l > 0.0` and lands here).
    let inv_sqrt_lambda: Array1<f64> = lambdas.mapv(|l| {
        if l > 0.0 {
            1.0 / l.sqrt()
        } else {
            // A non-positive block curvature means the mode is not a strict
            // minimum in this direction; the Laplace Gaussian is undefined
            // there. Reject rather than fabricate a correction.
            f64::NAN
        }
    });
    if inv_sqrt_lambda.iter().any(|v| !v.is_finite()) {
        return Err(
            "block_sampled_marginal_correction: non-positive block curvature (mode is not a \
             strict local minimum in a sampled direction)"
                .to_string(),
        );
    }

    let n_draws = block_sampling_draws(m);
    // ρ-invariant fixed seed → deterministic AND smooth-in-ρ objective.
    //
    // The doc comment above promises "the outer optimizer sees a smooth,
    // reproducible objective rather than Monte-Carlo jitter across
    // evaluations." That smoothness holds only if the importance draws
    // `z_s` themselves do NOT depend on ρ — ρ may enter the estimator
    // only through the per-sample importance weights `exp(−ΔF(t_s))` and
    // the rescaling `t_s = z_s / √λ_r`, both of which are continuous in
    // ρ for fixed `z_s`. A seed mixed from `λ_r = exp(ρ_k)` (or any
    // other ρ-dependent quantity such as the H-eigenvalues) permutes
    // `z_s` for every ρ probe, so the FD `(F(ρ+h) − F(ρ−h))/2h`
    // identity fails by O(MC_stdev/h) — exactly the order-10²–10³ FD
    // blow-up observed in the iso-κ Duchon binomial FD probes — and
    // every outer trust-region step lands on a different random face of
    // the objective. Mix only the (ρ-invariant) block / outer dimensions
    // so different problems still get independent streams.
    let mut seed_bits: u64 = 0x9E37_79B9_7F4A_7C15;
    seed_bits ^= (m as u64).rotate_left(17);
    seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
    seed_bits ^= (k as u64).rotate_left(31);
    seed_bits = seed_bits.wrapping_mul(0x1000_0000_01B3);
    let mut rng = StdRng::seed_from_u64(seed_bits);

    // Streaming, numerically-stable accumulation of the log-mean-exp value,
    // the explicit gradient channel `E_p[−∂ΔF/∂ρ]`, AND the gradient-channel
    // moments `E_p[t]`, `E_p[t tᵀ]`, `E_p[ngs]`, `E_p[t ⊗ ngs]`. These moments
    // are returned, not consumed here: the exact (b)–(d) channel assembly
    // happens in whoever reads `BlockSampledMoments`. NOTE(review): earlier
    // wording cited a "gradient exactness contract above" that is not in this
    // doc comment — confirm where that contract is stated.
    // Weights are kept relative to a running maximum log-weight: whenever a
    // new maximum arrives, every accumulator is rescaled by
    // `exp(max_old − max_new) ≤ 1`, so each per-draw relative weight is ≤ 1
    // and the sums never overflow. Infeasible / divergent draws contribute
    // zero weight rather than poisoning the estimate.
    let n_obs = target.base_neg_score().len();
    let mut max_lw = f64::NEG_INFINITY;
    let mut sum_w = 0.0_f64;
    let mut sum_w2 = 0.0_f64;
    let mut grad_acc = Array1::<f64>::zeros(k);
    let mut e_t_acc = Array1::<f64>::zeros(m);
    let mut e_tt_acc = Array2::<f64>::zeros((m, m));
    let mut e_ngs_acc = Array1::<f64>::zeros(n_obs);
    let mut e_t_ngs_acc = Array2::<f64>::zeros((n_obs, m));

    // Pre-generate ALL whitened draws into the columns of `draws` (m × n_draws)
    // in the EXACT same RNG order as the serial loop (draw 0: r=0..m, draw 1:
    // r=0..m, …). The per-draw design matvec `s = X_t·(V_b·t_s)` is then batched
    // into two BLAS-3 products over all columns at once (the #1082 hot path),
    // instead of n_draws separate BLAS-2 matvecs — the draws, seed, budget, and
    // importance weights are byte-for-byte unchanged; only the matvecs are
    // reassociated into a GEMM.
    let mut draws = Array2::<f64>::zeros((m, n_draws));
    for s in 0..n_draws {
        let mut col = draws.column_mut(s);
        for r in 0..m {
            let z = sample_standard_normal(&mut rng);
            col[r] = z * inv_sqrt_lambda[r];
        }
    }
    // NOTE(review): `batched` is assumed to hold one `(excess, Option<neg_score>)`
    // entry per column of `draws`, in column order. A longer result would panic
    // at `draws.column(sidx)`; a shorter one silently counts the missing draws
    // as zero weight in the `n_draws` denominator below. Confirm the target
    // guarantees this length.
    let batched = target.excess_with_displaced_neg_score_batch(&draws);

    let mut t = Array1::<f64>::zeros(m);
    for (sidx, (excess, displaced_ngs)) in batched.into_iter().enumerate() {
        t.assign(&draws.column(sidx));
        if !excess.is_finite() {
            continue;
        }
        let Some(ngs) = displaced_ngs else {
            // A finite excess always carries a score; absence means infeasible.
            continue;
        };
        // Log importance weight of this draw: log e^{−ΔF(t_s)}.
        let lw = -excess;
        if lw > max_lw {
            // exp(−∞ − lw) = 0 zeroes the (empty) accumulators on the first
            // feasible draw, so no special-casing is needed.
            let rescale = (max_lw - lw).exp();
            sum_w *= rescale;
            sum_w2 *= rescale * rescale;
            grad_acc *= rescale;
            e_t_acc *= rescale;
            e_tt_acc *= rescale;
            e_ngs_acc *= rescale;
            e_t_ngs_acc *= rescale;
            max_lw = lw;
        }
        let w = (lw - max_lw).exp();
        sum_w += w;
        sum_w2 += w * w;
        // Explicit channel: −∂ΔF/∂ρ.
        grad_acc.scaled_add(-w, &target.excess_rho_gradient(&t));
        // Moment channels (score already computed in the fused call above).
        if ngs.len() != n_obs {
            return Err(format!(
                "block_sampled_marginal_correction: displaced_neg_score len {} != {n_obs}",
                ngs.len()
            ));
        }
        e_t_acc.scaled_add(w, &t);
        e_ngs_acc.scaled_add(w, &ngs);
        // Rank-1 updates: e_tt += w·t tᵀ and e_t_ngs[:, r] += w·t_r·ngs.
        for r in 0..m {
            let wt_r = w * t[r];
            for q in 0..m {
                e_tt_acc[(q, r)] += wt_r * t[q];
            }
            e_t_ngs_acc.column_mut(r).scaled_add(wt_r, &ngs);
        }
    }
    if !max_lw.is_finite() {
        return Err(
            "block_sampled_marginal_correction: all importance draws were infeasible".to_string(),
        );
    }
    // log-mean-exp over ALL n_draws: infeasible draws count as zero weight in
    // the mean (they were skipped above but still sit in the denominator).
    let value = max_lw + (sum_w / n_draws as f64).ln();
    // Self-normalized importance-weighted gradient E_p[−∂ΔF/∂ρ] and moments.
    let (rho_gradient, moments) = if sum_w > 0.0 {
        (
            grad_acc / sum_w,
            Some(BlockSampledMoments {
                e_t: e_t_acc / sum_w,
                e_tt: e_tt_acc / sum_w,
                e_neg_score: e_ngs_acc / sum_w,
                e_t_neg_score: e_t_ngs_acc / sum_w,
            }),
        )
    } else {
        (Array1::zeros(k), None)
    };
    // Kish effective sample size of the importance weights.
    let importance_ess = if sum_w2 > 0.0 {
        (sum_w * sum_w) / sum_w2
    } else {
        0.0
    };

    if !value.is_finite() || rho_gradient.iter().any(|v| !v.is_finite()) {
        return Err(
            "block_sampled_marginal_correction: produced a non-finite correction or gradient"
                .to_string(),
        );
    }
    if let Some(mo) = moments.as_ref()
        && (mo.e_t.iter().any(|v| !v.is_finite())
            || mo.e_tt.iter().any(|v| !v.is_finite())
            || mo.e_neg_score.iter().any(|v| !v.is_finite())
            || mo.e_t_neg_score.iter().any(|v| !v.is_finite()))
    {
        return Err(
            "block_sampled_marginal_correction: produced non-finite gradient-channel moments"
                .to_string(),
        );
    }

    Ok(BlockSampledMarginal {
        value,
        rho_gradient,
        importance_ess,
        n_draws,
        moments,
    })
}
6646
/// Result of joint (β, ρ) posterior sampling.
///
/// The three sample matrices share their row count `n_total_samples`.
/// NOTE(review): this is presumably all retained draws pooled across chains;
/// confirm against the sampler that fills this struct.
#[derive(Clone, Debug)]
pub struct JointBetaRhoResult {
    /// Coefficient samples: shape (n_total_samples, n_beta)
    pub beta_samples: Array2<f64>,
    /// Log-smoothing parameter samples: shape (n_total_samples, n_rho)
    pub rho_samples: Array2<f64>,
    /// Posterior mean of β (length n_beta)
    pub beta_mean: Array1<f64>,
    /// Adaptive inverse-link parameter samples: shape (n_total_samples, n_link_params).
    /// Has zero columns when the link has no adaptive parameters.
    pub link_param_samples: Array2<f64>,
    /// Posterior mean of adaptive inverse-link parameters (length n_link_params)
    pub link_param_mean: Array1<f64>,
    /// Posterior mean of ρ (length n_rho)
    pub rho_mean: Array1<f64>,
    /// R-hat (potential scale reduction) diagnostic.
    /// NOTE(review): a single scalar, presumably the worst value across
    /// parameters; confirm how the producer aggregates it.
    pub rhat: f64,
    /// Effective sample size.
    /// NOTE(review): a single scalar; confirm the aggregation across parameters.
    pub ess: f64,
    /// Whether sampling converged (criterion decided by the producer,
    /// presumably from `rhat`/`ess`)
    pub converged: bool,
    /// Max skewness that triggered this sampling, i.e. the diagnostic value
    /// that caused the joint sampler to be used
    pub trigger_skewness: f64,
}
6671
/// Joint (β, ρ) posterior target for NUTS.
///
/// Samples from p(β, ρ | y) ∝ p(y|β) p(β|ρ) p(ρ) directly,
/// completely bypassing the Laplace approximation.
///
/// The parameter vector is `[z_β; ρ; θ_link]`, where:
/// - `z_β = L⁻¹(β − μ)` is the whitened β, so `β = μ + L z_β`;
/// - `ρ` holds the raw log-smoothing parameters (`λ = exp ρ`);
/// - the adaptive inverse-link parameters `θ_link` follow only when the
///   (binomial) link has fitted shape/mixing parameters (SAS, Beta-Logistic,
///   Mixture). Otherwise that tail is empty.
struct JointBetaRhoPosterior {
    /// Shared (`Arc`-backed) design, response, weights and β mode μ.
    data: SharedData,
    /// L where LL' = H⁻¹ (whitening for β block)
    chol: Array2<f64>,
    /// L', kept for the chain rule ∂/∂z = L' ∂/∂β
    chol_t: Array2<f64>,
    /// Joint likelihood specification (response + parameterized link).
    likelihood: LikelihoodSpec,
    /// Dimension of β
    n_beta: usize,
    /// Dimension of ρ
    n_rho: usize,
    /// Dimension of adaptive inverse-link parameters
    n_link_params: usize,
    /// LAML-converged adaptive inverse-link parameters (used only to initialize chains)
    link_param_mode: Array1<f64>,
    /// Canonical penalties in the transformed basis (one per ρ entry).
    penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
    /// Fixed prior on rho used by the sampled target.
    rho_prior: RhoPrior,
    /// LAML-converged ρ (used only to initialize chains)
    rho_mode: Array1<f64>,
    /// Whether to add the identifiable-subspace Jeffreys/Firth term to the
    /// target
    firth_enabled: bool,
    /// One-deep cache for the structural penalty pseudo-logdet and its
    /// ρ-gradient. NUTS tree-doubling and U-turn checks repeatedly evaluate
    /// the joint log-posterior at the same `rho` bytes. A single-slot cache
    /// therefore avoids redundant SVD/eigendecompositions inside
    /// `PenaltyPseudologdet::from_penalties`.
    ///
    /// The key is a 64-bit hash of the exact f64 bit patterns of `rho`
    /// (presumably `hash_rho`), not the bit patterns themselves.
    /// NOTE(review): a hash collision between two distinct `rho` would return
    /// a stale entry. Storing the bits would make the key exact. The tuple is
    /// presumably (key, pseudo-logdet, ρ-gradient); confirm against the reader.
    /// `Mutex` (not `RefCell`) because chains share the target via
    /// `Arc<Target>` and run in parallel via rayon.
    penalty_logdet_cache: Mutex<Option<(u64, f64, Array1<f64>)>>,
}
6714
6715impl JointBetaRhoPosterior {
6716    fn new(
6717        x: ArrayView2<f64>,
6718        y: ArrayView1<f64>,
6719        weights: ArrayView1<f64>,
6720        mode: ArrayView1<f64>,
6721        hessian: ArrayView2<f64>,
6722        penalty_canonical: Vec<gam_terms::construction::CanonicalPenalty>,
6723        rho_mode: ArrayView1<f64>,
6724        likelihood: LikelihoodSpec,
6725        gamma_shape: Option<f64>,
6726        rho_prior: RhoPrior,
6727        firth_enabled: bool,
6728    ) -> Result<Self, String> {
6729        let n_samples = x.nrows();
6730        let n_beta = x.ncols();
6731        let n_rho = penalty_canonical.len();
6732
6733        if rho_mode.len() != n_rho {
6734            return Err(HmcError::DimensionMismatch {
6735                reason: format!(
6736                    "rho_mode length {} != penalty count {}",
6737                    rho_mode.len(),
6738                    n_rho
6739                ),
6740            }
6741            .into());
6742        }
6743
6744        match (&likelihood.response, &likelihood.link) {
6745            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {}
6746            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {}
6747            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {}
6748            (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {}
6749            (ResponseFamily::Binomial, InverseLink::Sas(_)) => {}
6750            (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {}
6751            (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {}
6752            (ResponseFamily::Binomial, InverseLink::Standard(other)) => {
6753                return Err(HmcError::LinkMismatch {
6754                    reason: format!(
6755                        "Joint HMC binomial response requires a binomial-compatible inverse link; got {:?}",
6756                        other
6757                    ),
6758                }
6759                .into());
6760            }
6761            (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {}
6762            (ResponseFamily::Gaussian, _) => {
6763                return Err(HmcError::LinkMismatch {
6764                    reason: "Joint HMC Gaussian requires an identity inverse link".to_string(),
6765                }
6766                .into());
6767            }
6768            (
6769                ResponseFamily::Poisson
6770                | ResponseFamily::Tweedie { .. }
6771                | ResponseFamily::NegativeBinomial { .. }
6772                | ResponseFamily::Gamma,
6773                InverseLink::Standard(StandardLink::Log),
6774            ) => {}
6775            (
6776                ResponseFamily::Poisson
6777                | ResponseFamily::Tweedie { .. }
6778                | ResponseFamily::NegativeBinomial { .. }
6779                | ResponseFamily::Gamma,
6780                _,
6781            ) => {
6782                return Err(HmcError::LinkMismatch {
6783                    reason: "Joint HMC log-link family requires a log inverse link".to_string(),
6784                }
6785                .into());
6786            }
6787            (ResponseFamily::Beta { .. }, InverseLink::Standard(StandardLink::Logit)) => {}
6788            (ResponseFamily::Beta { .. }, _) => {
6789                return Err(HmcError::LinkMismatch {
6790                    reason: "Joint HMC Beta requires a logit inverse link".to_string(),
6791                }
6792                .into());
6793            }
6794            (ResponseFamily::RoystonParmar, _) => {
6795                return Err(HmcError::UnsupportedFamily {
6796                    reason: "Joint HMC fallback is not implemented for RoystonParmar".to_string(),
6797                }
6798                .into());
6799            }
6800        }
6801
6802        validate_firth_likelihood_support(&likelihood, firth_enabled).map_err(String::from)?;
6803        if matches!(likelihood.response, ResponseFamily::NegativeBinomial { .. }) {
6804            validate_count_responses("negative-binomial joint HMC", &y, &weights)
6805                .map_err(String::from)?;
6806        }
6807        if likelihood.is_binomial() {
6808            validate_binary_responses("binomial joint HMC", &y, &weights).map_err(String::from)?;
6809        }
6810
6811        let whitening = hessian_whitening_transform(
6812            hessian,
6813            n_beta,
6814            1.0,
6815            "Joint HMC: Hessian Cholesky failed",
6816        )?;
6817        let chol = whitening.chol;
6818        let chol_t = whitening.chol_t;
6819
6820        let data = SharedData {
6821            x: Arc::new(x.to_owned()),
6822            y: Arc::new(y.to_owned()),
6823            weights: Arc::new(weights.to_owned()),
6824            mode: Arc::new(mode.to_owned()),
6825            offset: None,
6826            gamma_shape: gamma_shape.unwrap_or(1.0),
6827            // Joint (β, ρ) HMC keeps the likelihood on its native scale;
6828            // dispersion enters via the per-family scale parameter, not
6829            // via the whitening transform here. `Known(1.0)` matches the
6830            // pre-refactor behaviour for this code path.
6831            dispersion: gam_solve::model_types::Dispersion::Known(1.0),
6832            n_samples,
6833            dim: n_beta,
6834        };
6835        let link_param_mode = Self::link_param_mode(&likelihood.link);
6836
6837        Ok(Self {
6838            data,
6839            chol,
6840            chol_t,
6841            likelihood,
6842            n_beta,
6843            n_rho,
6844            n_link_params: link_param_mode.len(),
6845            link_param_mode,
6846            penalty_canonical,
6847            rho_prior,
6848            rho_mode: rho_mode.to_owned(),
6849            firth_enabled,
6850            penalty_logdet_cache: Mutex::new(None),
6851        })
6852    }
6853
6854    /// FNV-1a hash over the raw f64 bit pattern of `rho`.
6855    ///
6856    /// NUTS leapfrog / tree-doubling / U-turn checks revisit identical
6857    /// position vectors byte-for-byte, so exact-equality on `to_bits()`
6858    /// captures the dominant repetition pattern without any tolerance.
6859    #[inline]
6860    fn hash_rho(rho: ndarray::ArrayView1<f64>) -> u64 {
6861        let mut h: u64 = 0xcbf2_9ce4_8422_2325;
6862        for &x in rho.iter() {
6863            h ^= x.to_bits();
6864            h = h.wrapping_mul(0x0000_0100_0000_01b3);
6865        }
6866        h
6867    }
6868
6869    fn link_param_mode(inverse_link: &InverseLink) -> Array1<f64> {
6870        match inverse_link {
6871            InverseLink::Sas(state) | InverseLink::BetaLogistic(state) => {
6872                Array1::from_vec(vec![state.epsilon, state.log_delta])
6873            }
6874            InverseLink::Mixture(state) => state.rho.clone(),
6875            InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => Array1::zeros(0),
6876        }
6877    }
6878
6879    fn inverse_link_with_params(
6880        &self,
6881        link_params: ndarray::ArrayView1<'_, f64>,
6882    ) -> Result<InverseLink, String> {
6883        match &self.likelihood.link {
6884            InverseLink::Sas(_) => {
6885                if link_params.len() != 2 {
6886                    return Err(format!(
6887                        "SAS link parameter length must be 2, got {}",
6888                        link_params.len()
6889                    ));
6890                }
6891                Ok(InverseLink::Sas(
6892                    gam_solve::mixture_link::sas_link_state_from_raw(
6893                        link_params[0],
6894                        link_params[1],
6895                    )?,
6896                ))
6897            }
6898            InverseLink::BetaLogistic(_) => {
6899                if link_params.len() != 2 || !link_params.iter().all(|v| v.is_finite()) {
6900                    return Err(
6901                        "Beta-Logistic link parameters must be finite with length 2".to_string()
6902                    );
6903                }
6904                Ok(InverseLink::BetaLogistic(gam_problem::types::SasLinkState {
6905                    epsilon: link_params[0],
6906                    log_delta: link_params[1],
6907                    delta: link_params[1].exp(),
6908                }))
6909            }
6910            InverseLink::Mixture(state) => {
6911                let rho = link_params.to_owned();
6912                Ok(InverseLink::Mixture(gam_problem::types::MixtureLinkState {
6913                    components: state.components.clone(),
6914                    pi: softmax_last_fixedzero(&rho),
6915                    rho,
6916                }))
6917            }
6918            InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => {
6919                Ok(self.likelihood.link.clone())
6920            }
6921        }
6922    }
6923
6924    /// Compute the joint log-posterior and gradient.
6925    ///
6926    /// The joint log-posterior is:
6927    ///   log p(β, ρ | y) = ℓ(y|β) + ½ log|I(β)| [if Firth]
6928    ///                    − ½β'S(ρ)β + ½ log|S(ρ)|₊ + log p(ρ) + const
6929    ///
6930    /// This is NOT the REML/LAML objective (which integrates out β). Here β is
6931    /// an explicit parameter being sampled, evaluated at arbitrary values — not
6932    /// just at the mode β̂(ρ).
6933    ///
6934    /// Parameter vector layout: [z_β (whitened, length n_beta); ρ (length n_rho);
6935    /// adaptive inverse-link params (length n_link_params)]
6936    fn compute_joint_logp_and_grad_into(
6937        &self,
6938        params: &Array1<f64>,
6939        out_grad: &mut Array1<f64>,
6940    ) -> f64 {
6941        let n_beta = self.n_beta;
6942        let n_rho = self.n_rho;
6943        let n_link_params = self.n_link_params;
6944
6945        // Split parameter vector — keep as views to avoid two per-step
6946        // `to_owned()` allocations of size n_beta and n_rho.
6947        let z = params.slice(ndarray::s![..n_beta]);
6948        let rho = params.slice(ndarray::s![n_beta..n_beta + n_rho]);
6949        let link_params = params.slice(ndarray::s![n_beta + n_rho..]);
6950        let lambdas: Array1<f64> = rho.mapv(f64::exp);
6951
6952        let inverse_link = match self.inverse_link_with_params(link_params) {
6953            Ok(link) => link,
6954            Err(err) => {
6955                log::warn!(
6956                    "[Joint HMC] adaptive inverse-link parameters are invalid: {}",
6957                    err
6958                );
6959                out_grad.fill(0.0);
6960                return f64::NEG_INFINITY;
6961            }
6962        };
6963
6964        // Un-whiten: β = μ + L z
6965        let beta = self.data.mode.as_ref() + &self.chol.dot(&z);
6966
6967        // η = X β
6968        let eta = gam_linalg::faer_ndarray::fast_av(self.data.x.as_ref(), &beta);
6969
6970        // ---- Log-likelihood ℓ(y|β) and ∇_β ℓ ----
6971        let step_likelihood = LikelihoodSpec {
6972            response: self.likelihood.response.clone(),
6973            link: inverse_link,
6974        };
6975        let (ll, mut grad_ll_beta, grad_link) = match joint_family_logp_grad_and_link_grad(
6976            &step_likelihood,
6977            &self.data,
6978            &eta,
6979            n_link_params,
6980        ) {
6981            Ok(value) => value,
6982            Err(err) => {
6983                log::warn!(
6984                    "[Joint HMC] likelihood target became invalid at the current state: {}",
6985                    err
6986                );
6987                out_grad.fill(0.0);
6988                return f64::NEG_INFINITY;
6989            }
6990        };
6991
6992        let mut firth_logdet = 0.0;
6993        if self.firth_enabled {
6994            match firth_jeffreys_logp_and_grad(NutsFamily::BinomialLogit, &self.data, &eta) {
6995                Ok((value, grad_beta_firth)) => {
6996                    firth_logdet = value;
6997                    grad_ll_beta += &grad_beta_firth;
6998                }
6999                Err(err) => {
7000                    log::warn!(
7001                        "[Joint HMC/Firth] Jeffreys target became invalid at the current state: {}",
7002                        err
7003                    );
7004                    out_grad.fill(0.0);
7005                    return f64::NEG_INFINITY;
7006                }
7007            }
7008        }
7009
7010        // ---- Penalty: -0.5 β'S(ρ)β ----
7011        // S(ρ) = Σ_k λ_k S_k where S_k = R_k'R_k (precomputed in penalty_matrices).
7012        // Uses penalty_roots for the efficient ||R_k β||² form.
7013        let mut penalty_val = 0.0;
7014        let mut s_beta = Array1::<f64>::zeros(n_beta);
7015        let mut grad_rho = Array1::<f64>::zeros(n_rho);
7016
7017        // Reuse one max-rank scratch buffer for r_beta = R_k · β_block across
7018        // all penalty blocks instead of allocating a fresh Array1 per block
7019        // per HMC step.
7020        let max_rank = self
7021            .penalty_canonical
7022            .iter()
7023            .map(|cp| cp.rank())
7024            .max()
7025            .unwrap_or(0);
7026        let mut r_beta_scratch = Array1::<f64>::zeros(max_rank);
7027
7028        for (k, cp) in self.penalty_canonical.iter().enumerate() {
7029            // Block-local quadratic: β'S_k β via root
7030            let r = &cp.col_range;
7031            let beta_block = beta.slice(ndarray::s![r.start..r.end]);
7032            let rank_k = cp.rank();
7033            gam_linalg::faer_ndarray::fast_av_view_into(
7034                &cp.root,
7035                &beta_block,
7036                r_beta_scratch.slice_mut(ndarray::s![..rank_k]),
7037            );
7038            let r_beta = r_beta_scratch.slice(ndarray::s![..rank_k]);
7039            let quad_k = r_beta.dot(&r_beta);
7040            penalty_val += 0.5 * lambdas[k] * quad_k;
7041
7042            // Accumulate S(ρ)β for β-gradient — block-local
7043            for a in 0..cp.block_dim() {
7044                let val: f64 = (0..rank_k).map(|row| cp.root[[row, a]] * r_beta[row]).sum();
7045                s_beta[r.start + a] += lambdas[k] * val;
7046            }
7047
7048            // ρ_k gradient from penalty
7049            grad_rho[k] = -0.5 * lambdas[k] * quad_k;
7050        }
7051
7052        // ---- Structural penalty log-determinant: +0.5 log|S(ρ)|₊ and ρ-derivatives ----
7053        //
7054        // One-deep cache keyed on the exact f64 bits of `rho`: NUTS tree
7055        // doubling revisits identical positions byte-for-byte, so an
7056        // exact-equality cache eliminates the dominant SVD/eigendecomp
7057        // cost in `PenaltyPseudologdet::from_penalties` across leapfrog
7058        // half-steps.
7059        let log_det_s = if self.penalty_canonical.is_empty() {
7060            0.0
7061        } else {
7062            let rho_hash = Self::hash_rho(rho);
7063            let cached = self.penalty_logdet_cache.lock().ok().and_then(|guard| {
7064                guard.as_ref().and_then(|(h, v, g)| {
7065                    if *h == rho_hash && g.len() == n_rho {
7066                        for k in 0..n_rho {
7067                            grad_rho[k] += 0.5 * g[k];
7068                        }
7069                        Some(*v)
7070                    } else {
7071                        None
7072                    }
7073                })
7074            });
7075            if let Some(hit) = cached {
7076                hit
7077            } else {
7078                match PenaltyPseudologdet::from_penalties(
7079                    &self.penalty_canonical,
7080                    lambdas.as_slice().unwrap_or(&[]),
7081                    0.0,
7082                    n_beta,
7083                ) {
7084                    Ok(pld) => {
7085                        let (det1, _) = pld.rho_derivatives_from_penalties(
7086                            &self.penalty_canonical,
7087                            lambdas.as_slice().unwrap_or(&[]),
7088                        );
7089                        let value = pld.value();
7090                        if let Ok(mut guard) = self.penalty_logdet_cache.lock() {
7091                            *guard = Some((rho_hash, value, det1.clone()));
7092                        }
7093                        for k in 0..n_rho {
7094                            grad_rho[k] += 0.5 * det1[k];
7095                        }
7096                        value
7097                    }
7098                    Err(err) => {
7099                        log::warn!(
7100                            "[Joint HMC] structural penalty logdet became invalid at the current state: {}",
7101                            err
7102                        );
7103                        out_grad.fill(0.0);
7104                        return f64::NEG_INFINITY;
7105                    }
7106                }
7107            }
7108        };
7109
7110        // ---- Prior on ρ ----
7111        let mut rho_prior = 0.0;
7112        match &self.rho_prior {
7113            RhoPrior::Flat => {}
7114            RhoPrior::Normal { mean, sd } => {
7115                let inv_var = 1.0 / (*sd * *sd);
7116                for k in 0..n_rho {
7117                    let d = rho[k] - *mean;
7118                    rho_prior -= 0.5 * inv_var * d * d;
7119                    grad_rho[k] -= inv_var * d;
7120                }
7121            }
7122            RhoPrior::GammaPrecision { shape, rate } => {
7123                for k in 0..n_rho {
7124                    let lambda = rho[k].exp();
7125                    // Density over sampled rho includes the e^rho Jacobian (Gamma is on lambda = e^rho).
7126                    rho_prior += *shape * rho[k] - *rate * lambda;
7127                    grad_rho[k] += *shape - *rate * lambda;
7128                }
7129            }
7130            RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7131                if !pc_prior_params_valid(*upper, *tail_prob) {
7132                    out_grad.fill(0.0);
7133                    return f64::NEG_INFINITY;
7134                }
7135                let theta = -tail_prob.ln() / *upper;
7136                for k in 0..n_rho {
7137                    // log p(ρ) = const − ρ/2 − θ exp(−ρ/2).
7138                    let e = (-0.5 * rho[k]).exp();
7139                    rho_prior += -0.5 * rho[k] - theta * e;
7140                    grad_rho[k] += -0.5 + 0.5 * theta * e;
7141                }
7142            }
7143            RhoPrior::Independent(priors) => {
7144                if priors.len() != n_rho {
7145                    out_grad.fill(0.0);
7146                    return f64::NEG_INFINITY;
7147                }
7148                for k in 0..n_rho {
7149                    match &priors[k] {
7150                        RhoPrior::Flat => {}
7151                        RhoPrior::Normal { mean, sd } => {
7152                            let inv_var = 1.0 / (*sd * *sd);
7153                            let d = rho[k] - *mean;
7154                            rho_prior -= 0.5 * inv_var * d * d;
7155                            grad_rho[k] -= inv_var * d;
7156                        }
7157                        RhoPrior::GammaPrecision { shape, rate } => {
7158                            let lambda = rho[k].exp();
7159                            // Density over sampled rho includes the e^rho Jacobian (Gamma is on lambda = e^rho).
7160                            rho_prior += *shape * rho[k] - *rate * lambda;
7161                            grad_rho[k] += *shape - *rate * lambda;
7162                        }
7163                        RhoPrior::PenalizedComplexity { upper, tail_prob } => {
7164                            if !pc_prior_params_valid(*upper, *tail_prob) {
7165                                out_grad.fill(0.0);
7166                                return f64::NEG_INFINITY;
7167                            }
7168                            let theta = -tail_prob.ln() / *upper;
7169                            let e = (-0.5 * rho[k]).exp();
7170                            rho_prior += -0.5 * rho[k] - theta * e;
7171                            grad_rho[k] += -0.5 + 0.5 * theta * e;
7172                        }
7173                        RhoPrior::Independent(_) => {
7174                            out_grad.fill(0.0);
7175                            return f64::NEG_INFINITY;
7176                        }
7177                    }
7178                }
7179            }
7180        }
7181
7182        // ---- Assemble ----
7183        let logp = ll + firth_logdet - penalty_val + 0.5 * log_det_s + rho_prior;
7184
7185        // β-gradient in original space: ∇_β ℓ - S(ρ)β
7186        let grad_beta = &grad_ll_beta - &s_beta;
7187
7188        // Combined gradient: [∇_z; ∇_ρ; ∇_link]
7189        gam_linalg::faer_ndarray::fast_av_view_into(
7190            &self.chol_t,
7191            &grad_beta,
7192            out_grad.slice_mut(ndarray::s![..n_beta]),
7193        );
7194        out_grad
7195            .slice_mut(ndarray::s![n_beta..n_beta + n_rho])
7196            .assign(&grad_rho);
7197        out_grad
7198            .slice_mut(ndarray::s![n_beta + n_rho..])
7199            .assign(&grad_link);
7200
7201        logp
7202    }
7203}
7204
/// Returns `true` when penalized-complexity hyperparameters can be used.
///
/// `upper` must be a finite, strictly positive threshold, and `tail_prob`
/// must be a probability strictly inside `(0, 1)`. This mirrors the validation
/// in the shared `rho_prior_eval` engine. Callers treat an invalid
/// configuration as a `-∞` potential that repels the sampler, instead of
/// letting it produce a non-finite gradient.
fn pc_prior_params_valid(upper: f64, tail_prob: f64) -> bool {
    let upper_is_usable = upper.is_finite() && upper > 0.0;
    let tail_is_open_probability = tail_prob.is_finite() && tail_prob > 0.0 && tail_prob < 1.0;
    upper_is_usable && tail_is_open_probability
}
7213
7214impl HamiltonianTarget<Array1<f64>> for JointBetaRhoPosterior {
7215    fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7216        self.compute_joint_logp_and_grad_into(position, grad)
7217    }
7218}
7219
/// Inputs for joint (β, ρ) sampling.
///
/// Consumed by `run_joint_beta_rho_sampling`. The array views are borrowed
/// for `'a`, and the owned fields are cloned into the posterior target.
pub struct JointBetaRhoInputs<'a> {
    /// Design matrix; forwarded to `JointBetaRhoPosterior::new`.
    pub x: ArrayView2<'a, f64>,
    /// Response values; forwarded to `JointBetaRhoPosterior::new`.
    pub y: ArrayView1<'a, f64>,
    /// Observation weights; forwarded to `JointBetaRhoPosterior::new`.
    pub weights: ArrayView1<'a, f64>,
    /// Response family and inverse link. This field is also checked for Firth
    /// support, and its link determines how many adaptive link parameters are
    /// sampled alongside (β, ρ).
    pub likelihood: LikelihoodSpec,
    /// Optional shape parameter forwarded to the posterior constructor
    /// (presumably only meaningful for Gamma responses — confirm).
    pub gamma_shape: Option<f64>,
    /// Fitted coefficient vector. Its length defines the number of β
    /// parameters.
    pub mode: ArrayView1<'a, f64>,
    /// Hessian at the mode, forwarded to the posterior constructor
    /// (presumably used to build the whitening factor — confirm).
    pub hessian: ArrayView2<'a, f64>,
    /// Canonical penalty roots, one per smoothing parameter. Their count
    /// defines the number of ρ parameters.
    pub penalty_roots: Vec<CanonicalPenalty>,
    /// Fitted log-smoothing parameters; forwarded to the constructor, whose
    /// stored copy seeds the jittered chain starts.
    pub rho_mode: ArrayView1<'a, f64>,
    /// Prior on ρ; cloned into the posterior target.
    pub rho_prior: RhoPrior,
    /// Whether the Firth/Jeffreys term is included in the target.
    pub firth_bias_reduction: bool,
    /// Max posterior skewness that triggered this sampling; logged and copied
    /// into the result.
    pub trigger_skewness: f64,
}
7236
7237/// Run joint (β, ρ) NUTS sampling.
7238///
7239/// This is the automatic fallback when the Laplace approximation has high
7240/// skewness. It samples from the true joint posterior, completely bypassing
7241/// the Laplace approximation for smoothing parameter selection.
7242pub fn run_joint_beta_rho_sampling(
7243    inputs: &JointBetaRhoInputs<'_>,
7244    config: &NutsConfig,
7245) -> Result<JointBetaRhoResult, String> {
7246    validate_firth_likelihood_support(&inputs.likelihood, inputs.firth_bias_reduction)
7247        .map_err(String::from)?;
7248    validate_nuts_config(config).map_err(String::from)?;
7249    let n_beta = inputs.mode.len();
7250    let n_rho = inputs.penalty_roots.len();
7251    let n_link_params = JointBetaRhoPosterior::link_param_mode(&inputs.likelihood.link).len();
7252    let total_dim = n_beta + n_rho + n_link_params;
7253
7254    log::info!(
7255        "[Joint HMC] Sampling (β, ρ, link) jointly: {} β-params + {} ρ-params + {} link-params = {} total (triggered by skewness {:.3})",
7256        n_beta,
7257        n_rho,
7258        n_link_params,
7259        total_dim,
7260        inputs.trigger_skewness,
7261    );
7262
7263    let target = JointBetaRhoPosterior::new(
7264        inputs.x,
7265        inputs.y,
7266        inputs.weights,
7267        inputs.mode,
7268        inputs.hessian,
7269        inputs.penalty_roots.clone(),
7270        inputs.rho_mode,
7271        inputs.likelihood.clone(),
7272        inputs.gamma_shape,
7273        inputs.rho_prior.clone(),
7274        inputs.firth_bias_reduction,
7275    )?;
7276
7277    let chol = target.chol.clone();
7278    let mode_arr = target.data.mode.clone();
7279    let rho_mode = target.rho_mode.clone();
7280    let link_param_mode = target.link_param_mode.clone();
7281
7282    // Initialize chains: z_β at 0 (= mode), ρ at rho_mode, link params at fitted state.
7283    let initial_positions: Vec<Array1<f64>> = (0..config.n_chains)
7284        .map(|chain| {
7285            let mut rng =
7286                StdRng::seed_from_u64(chain_stream_seed(config.seed, chain, 0x9B51_6E37_F2D0_A48C));
7287            let mut pos = Array1::<f64>::zeros(total_dim);
7288            // Small jitter for β (whitened space)
7289            for j in 0..n_beta {
7290                pos[j] = sample_standard_normal(&mut rng) * 0.1;
7291            }
7292            // Small jitter for ρ around mode
7293            for k in 0..n_rho {
7294                pos[n_beta + k] = rho_mode[k] + sample_standard_normal(&mut rng) * 0.2;
7295            }
7296            // Small jitter for adaptive link parameters around fitted state
7297            for k in 0..n_link_params {
7298                pos[n_beta + n_rho + k] =
7299                    link_param_mode[k] + sample_standard_normal(&mut rng) * 0.05;
7300            }
7301            pos
7302        })
7303        .collect();
7304
7305    // Keep warmup covariance phase-local: diagonal windows are less likely to
7306    // encode cross-block covariance from a transient mode switch.
7307    let mass_cfg = robust_mass_matrix_config(total_dim, config.nwarmup);
7308
7309    let (samples_array, run_stats) = run_whitened_nuts_samples(
7310        target,
7311        initial_positions,
7312        config,
7313        total_dim,
7314        mass_cfg,
7315        0x63AF_175B_D820_C94E,
7316        "Joint (β,ρ) NUTS sampling failed",
7317    )?;
7318    log::info!("[Joint HMC] Sampling complete: {}", run_stats);
7319
7320    // Unpack samples
7321    let shape = samples_array.shape();
7322    let n_chains = shape[0];
7323    let n_samples_out = shape[1];
7324    let total_samples = n_chains * n_samples_out;
7325
7326    let beta_samples = unwhiten_samples(&samples_array, mode_arr.as_ref(), &chol, n_beta, 0);
7327    let mut rho_samples = Array2::<f64>::zeros((total_samples, n_rho));
7328    let mut link_param_samples = Array2::<f64>::zeros((total_samples, n_link_params));
7329
7330    for chain in 0..n_chains {
7331        for sample_i in 0..n_samples_out {
7332            let sample_idx = chain * n_samples_out + sample_i;
7333            let zview = samples_array.slice(ndarray::s![chain, sample_i, ..]);
7334
7335            // ρ and adaptive link parameters are stored directly
7336            let rho_slice = zview.slice(ndarray::s![n_beta..n_beta + n_rho]);
7337            rho_samples.row_mut(sample_idx).assign(&rho_slice);
7338            let link_slice = zview.slice(ndarray::s![n_beta + n_rho..]);
7339            link_param_samples.row_mut(sample_idx).assign(&link_slice);
7340        }
7341    }
7342
7343    let beta_mean = beta_samples
7344        .mean_axis(Axis(0))
7345        .unwrap_or_else(|| Array1::zeros(n_beta));
7346    let rho_mean = rho_samples
7347        .mean_axis(Axis(0))
7348        .unwrap_or_else(|| Array1::zeros(n_rho));
7349    let link_param_mean = link_param_samples
7350        .mean_axis(Axis(0))
7351        .unwrap_or_else(|| Array1::zeros(n_link_params));
7352
7353    let (rhat, ess) = compute_split_rhat_and_ess(&samples_array);
7354
7355    let converged = NutsConvergenceThresholds {
7356        max_rhat: 1.1,
7357        min_ess: Some(50.0),
7358    }
7359    .converged(rhat, ess);
7360    if !converged {
7361        log::warn!(
7362            "[Joint HMC] Convergence warning: R-hat={:.3}, ESS={:.1}",
7363            rhat,
7364            ess,
7365        );
7366    }
7367
7368    Ok(JointBetaRhoResult {
7369        beta_samples,
7370        rho_samples,
7371        beta_mean,
7372        link_param_samples,
7373        link_param_mean,
7374        rho_mean,
7375        rhat,
7376        ess,
7377        converged,
7378        trigger_skewness: inputs.trigger_skewness,
7379    })
7380}
7381
7382// ============================================================================
7383// Survival Model HMC Support
7384// ============================================================================
7385
7386mod survival_hmc {
7387    use super::*;
7388    use gam_models::survival::{
7389        PenaltyBlocks, SurvivalEngineInputs, SurvivalMonotonicityPenalty, SurvivalSpec,
7390        WorkingModelSurvival,
7391    };
7392
7393    /// Shared data for survival NUTS posterior (wrapped in Arc to prevent cloning).
7394    #[derive(Clone)]
7395    struct SharedSurvivalData {
7396        /// Exact survival model in original spline coordinates.
7397        base_model: Arc<WorkingModelSurvival>,
7398        /// MAP estimate in coefficient coordinates.
7399        mode: Arc<Array1<f64>>,
7400    }
7401
7402    /// Whitened log-posterior target for survival models with analytical gradients.
7403    #[derive(Clone)]
7404    pub struct SurvivalPosterior {
7405        /// Shared read-only data (Arc prevents duplication)
7406        data: SharedSurvivalData,
7407        /// Transform: L where L L^T = H^{-1}
7408        chol: Array2<f64>,
7409        /// L^T for gradient chain rule: ∇z = L^T @ ∇_β
7410        chol_t: Array2<f64>,
7411    }
7412
7413    impl SurvivalPosterior {
7414        /// Creates a new survival posterior target.
7415        pub fn new(
7416            age_entry: ArrayView1<'_, f64>,
7417            age_exit: ArrayView1<'_, f64>,
7418            event_target: ArrayView1<'_, u8>,
7419            event_competing: ArrayView1<'_, u8>,
7420            sampleweight: ArrayView1<'_, f64>,
7421            x_entry: ArrayView2<'_, f64>,
7422            x_exit: ArrayView2<'_, f64>,
7423            x_derivative: ArrayView2<'_, f64>,
7424            offset_eta_entry: Option<ArrayView1<'_, f64>>,
7425            offset_eta_exit: Option<ArrayView1<'_, f64>>,
7426            offset_derivative_exit: Option<ArrayView1<'_, f64>>,
7427            penalties: PenaltyBlocks,
7428            monotonicity: SurvivalMonotonicityPenalty,
7429            spec: SurvivalSpec,
7430            structurally_monotonic: bool,
7431            structural_time_columns: usize,
7432            mode: ArrayView1<f64>,
7433            hessian: ArrayView2<f64>,
7434        ) -> Result<Self, String> {
7435            let n = age_entry.len();
7436            let off_eta_entry = offset_eta_entry
7437                .map(|v| v.to_owned())
7438                .unwrap_or_else(|| Array1::zeros(n));
7439            let off_eta_exit = offset_eta_exit
7440                .map(|v| v.to_owned())
7441                .unwrap_or_else(|| Array1::zeros(n));
7442            let off_deriv_exit = offset_derivative_exit
7443                .map(|v| v.to_owned())
7444                .unwrap_or_else(|| Array1::zeros(n));
7445
7446            let mut base_model = WorkingModelSurvival::from_engine_inputswith_offsets(
7447                SurvivalEngineInputs {
7448                    age_entry,
7449                    age_exit,
7450                    event_target,
7451                    event_competing,
7452                    sampleweight,
7453                    x_entry,
7454                    x_exit,
7455                    x_derivative,
7456                    monotonicity_constraint_rows: None,
7457                    monotonicity_constraint_offsets: None,
7458                },
7459                Some(gam_models::survival::SurvivalBaselineOffsets {
7460                    eta_entry: off_eta_entry.view(),
7461                    eta_exit: off_eta_exit.view(),
7462                    derivative_exit: off_deriv_exit.view(),
7463                }),
7464                penalties,
7465                monotonicity,
7466                spec,
7467            )
7468            .map_err(|e| format!("Survival state construction failed: {:?}", e))?;
7469            if structurally_monotonic {
7470                base_model
7471                    .set_structural_monotonicity(true, structural_time_columns)
7472                    .map_err(|e| {
7473                        format!("Failed to enable structural monotonicity in survival HMC: {e}")
7474                    })?;
7475            }
7476
7477            let sampler_mode = mode.to_owned();
7478            let dim = sampler_mode.len();
7479
7480            let whitening = hessian_whitening_transform(
7481                hessian,
7482                dim,
7483                1.0,
7484                "Hessian Cholesky decomposition failed",
7485            )?;
7486            let chol = whitening.chol;
7487            let chol_t = whitening.chol_t;
7488
7489            let data = SharedSurvivalData {
7490                base_model: Arc::new(base_model),
7491                mode: Arc::new(sampler_mode),
7492            };
7493
7494            Ok(Self { data, chol, chol_t })
7495        }
7496
7497        fn compute_logp_and_grad_into(
7498            &self,
7499            z: &Array1<f64>,
7500            grad: &mut Array1<f64>,
7501        ) -> Result<f64, String> {
7502            let sampler_position = self.data.mode.as_ref() + &self.chol.dot(z);
7503            let state = self
7504                .data
7505                .base_model
7506                .update_state(&sampler_position)
7507                .map_err(|e| format!("Survival state update failed: {:?}", e))?;
7508            let logp = state.log_likelihood - state.penalty_term;
7509            let grad_beta = state.gradient.mapv(|g| -g);
7510            fast_av_into(&self.chol_t, &grad_beta, grad);
7511            Ok(logp)
7512        }
7513
7514        /// Get the Cholesky factor L for un-whitening samples
7515        pub fn chol(&self) -> &Array2<f64> {
7516            &self.chol
7517        }
7518
7519        /// Get the mode
7520        pub fn mode(&self) -> &Array1<f64> {
7521            &self.data.mode
7522        }
7523    }
7524
7525    impl HamiltonianTarget<Array1<f64>> for SurvivalPosterior {
7526        fn logp_and_grad(&self, position: &Array1<f64>, grad: &mut Array1<f64>) -> f64 {
7527            match self.compute_logp_and_grad_into(position, grad) {
7528                Ok(logp) => logp,
7529                Err(e) => {
7530                    log::warn!("Survival posterior evaluation failed: {}", e);
7531                    grad.fill(0.0);
7532                    f64::NEG_INFINITY
7533                }
7534            }
7535        }
7536    }
7537
7538    /// Runs NUTS sampling for survival models with whitened parameter space.
7539    pub(crate) fn run_survival_nuts_sampling(
7540        age_entry: ArrayView1<'_, f64>,
7541        age_exit: ArrayView1<'_, f64>,
7542        event_target: ArrayView1<'_, u8>,
7543        event_competing: ArrayView1<'_, u8>,
7544        sampleweight: ArrayView1<'_, f64>,
7545        x_entry: ArrayView2<'_, f64>,
7546        x_exit: ArrayView2<'_, f64>,
7547        x_derivative: ArrayView2<'_, f64>,
7548        eta_offset_entry: Option<ArrayView1<'_, f64>>,
7549        eta_offset_exit: Option<ArrayView1<'_, f64>>,
7550        derivative_offset_exit: Option<ArrayView1<'_, f64>>,
7551        penalties: PenaltyBlocks,
7552        monotonicity: SurvivalMonotonicityPenalty,
7553        spec: SurvivalSpec,
7554        structurally_monotonic: bool,
7555        structural_time_columns: usize,
7556        mode: ArrayView1<f64>,
7557        hessian: ArrayView2<f64>,
7558        config: &NutsConfig,
7559    ) -> Result<NutsResult, String> {
7560        validate_nuts_config(config).map_err(String::from)?;
7561        // Create posterior target
7562        let target = SurvivalPosterior::new(
7563            age_entry,
7564            age_exit,
7565            event_target,
7566            event_competing,
7567            sampleweight,
7568            x_entry,
7569            x_exit,
7570            x_derivative,
7571            eta_offset_entry,
7572            eta_offset_exit,
7573            derivative_offset_exit,
7574            penalties,
7575            monotonicity,
7576            spec,
7577            structurally_monotonic,
7578            structural_time_columns,
7579            mode,
7580            hessian,
7581        )?;
7582
7583        // Get Cholesky factor for un-whitening samples later
7584        let chol = target.chol().clone();
7585        let mode_arr = target.mode().clone();
7586        let dim = mode_arr.len();
7587
7588        let initial_positions = jittered_initial_positions(config, dim, 0.1, 0xEC2D_7A9B_4051_F638);
7589
7590        let mass_cfg = robust_survival_mass_matrix_config(dim, config.nwarmup);
7591        let (result, run_stats) = run_whitened_nuts_result(
7592            target,
7593            &mode_arr,
7594            &chol,
7595            initial_positions,
7596            config,
7597            dim,
7598            mass_cfg,
7599            0x731B_60D4_AE52_9C8F,
7600            "NUTS sampling failed",
7601            Array1::zeros(dim),
7602            NutsConvergenceThresholds {
7603                max_rhat: 1.1,
7604                min_ess: None,
7605            },
7606        )?;
7607
7608        log::info!("Survival NUTS sampling complete: {}", run_stats);
7609
7610        Ok(result)
7611    }
7612}
7613
7614/// Engine-facing flattened survival NUTS entrypoint.
7615pub fn run_survival_nuts_sampling_flattened<'a>(
7616    flat: SurvivalFlatInputs<'a>,
7617    penalties: gam_models::survival::PenaltyBlocks,
7618    monotonicity: gam_models::survival::SurvivalMonotonicityPenalty,
7619    spec: gam_models::survival::SurvivalSpec,
7620    structurally_monotonic: bool,
7621    structural_time_columns: usize,
7622    mode: ArrayView1<'a, f64>,
7623    hessian: ArrayView2<'a, f64>,
7624    config: &NutsConfig,
7625) -> Result<NutsResult, String> {
7626    run_nuts_sampling_flattened_family(
7627        LikelihoodSpec {
7628            response: ResponseFamily::RoystonParmar,
7629            link: InverseLink::Standard(StandardLink::Identity),
7630        },
7631        FamilyNutsInputs::Survival(Box::new(SurvivalNutsInputs {
7632            flat,
7633            penalties,
7634            monotonicity,
7635            spec,
7636            structurally_monotonic,
7637            structural_time_columns,
7638            mode,
7639            hessian,
7640        })),
7641        config,
7642    )
7643}