Skip to main content

gam_inference/
sample.rs

//! Library-side orchestration for posterior sampling (NUTS, with a
//! Laplace-Gaussian fallback) from a saved model.
//!
//! The CLI's `gam sample` subcommand and the Python `Model.sample(...)` API
//! both call into [`sample_saved_model`], which dispatches on the saved
//! model's class (standard GLM, standard with link-wiggle, survival, or one of
//! the location-scale / marginal-slope / transformation classes) and returns a
//! [`NutsResult`] over the original coefficient space. Classes and families
//! without an exact NUTS implementation are drawn from the saved Laplace-Gaussian
//! approximation instead. Gaussian identity standard models are sampled from the
//! saved closed-form posterior, conditioning on the training fit rather than any
//! prediction rows supplied by the caller.
10
11use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18    FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19    explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20    run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27    SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28    build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29    evaluate_survival_time_basis_row, normalize_survival_time_pair,
30    resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33    fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34    resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35    saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39    PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42    append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43    split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47    FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55/// Reconstruct the `LinkWiggleFormulaSpec` from a saved model's
56/// baseline-time-wiggle runtime, returning `None` when the model has no
57/// time-wiggle component. Re-exported because the survival fitter's tests
58/// exercise the spec independently of running NUTS.
59pub fn saved_baseline_timewiggle_spec(
60    model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62    model
63        .saved_baseline_time_wiggle()
64        .map_err(|e| e.to_string())
65        .map(|runtime| {
66            runtime.map(|saved| LinkWiggleFormulaSpec {
67                degree: saved.degree,
68                num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69                penalty_orders: saved.penalty_orders,
70                double_penalty: saved.double_penalty,
71            })
72        })
73}
74
75fn weighted_penalty_matrix(
76    penalties: &[Array2<f64>],
77    lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79    if penalties.len() != lambdas.len() {
80        return Err(format!(
81            "penalty/lambda mismatch: {} penalties vs {} lambdas",
82            penalties.len(),
83            lambdas.len()
84        ));
85    }
86    if penalties.is_empty() {
87        return Err("cannot sample without at least one penalty block".to_string());
88    }
89    let p = penalties[0].nrows();
90    let mut out = Array2::<f64>::zeros((p, p));
91    for (k, s) in penalties.iter().enumerate() {
92        if s.nrows() != p || s.ncols() != p {
93            return Err(format!(
94                "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95                s.nrows(),
96                s.ncols(),
97                p,
98                p
99            ));
100        }
101        let lam = lambdas[k];
102        out += &(s * lam);
103    }
104    Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108    hessian: &Array2<f64>,
109    expected_dim: usize,
110) -> Result<(), String> {
111    if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112        return Err(format!(
113            "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114            hessian.nrows(),
115            hessian.ncols(),
116            expected_dim,
117            expected_dim,
118        ));
119    }
120    validate_all_finite(
121        "link-wiggle explicit joint Hessian",
122        hessian.iter().copied(),
123    )?;
124    let mut max_abs = 0.0_f64;
125    for r in 0..expected_dim {
126        for c in 0..expected_dim {
127            max_abs = max_abs.max(hessian[[r, c]].abs());
128            let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129            if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130                return Err(format!(
131                    "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132                ));
133            }
134        }
135    }
136    if max_abs == 0.0 {
137        return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138                    .to_string());
139    }
140    Ok(())
141}
142
143/// Resolve the scalar generative dispersion for a fitted model.
144///
145/// Thin adapter over the single canonical
146/// [`crate::generative::family_noise_parameter`]: the replicate-sampling path
147/// here and the CLI `gam generate` path both route through that one helper, so
148/// the fitted dispersion (NB θ̂, Beta/Tweedie φ̂, Gamma k̂) can never be read
149/// inconsistently between them. A divergent second copy of this logic was the
150/// root cause of #1124.
151fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
152    crate::generative::family_noise_parameter(
153        fit.likelihood_scale,
154        fit.standard_deviation,
155        likelihood,
156    )
157}
158
159/// Refresh the Negative-Binomial overdispersion `theta` on the sampling
160/// likelihood spec from the fit's jointly-estimated `theta_hat` before the NUTS
161/// dispatch reads it (#1463).
162///
163/// The construction seed stored on the family spec (`theta: 1.0`) only seeds the
164/// inner solve. NB carries unit REML scale and records its fitted overdispersion
165/// in `likelihood_scale` (`EstimatedNegBinTheta` / `FixedNegBinTheta`), *not* in
166/// the REML dispersion. The NUTS NB log-likelihood / score
167/// (`src/inference/hmc.rs`) reads `theta` straight off this spec, so leaving the
168/// seed in place over-states `Var(y) = μ + μ²/θ` and inflates every
169/// coefficient's posterior SD ~1.4–1.5× (the HMC sibling of the replicate-path
170/// bug #1124). This mirrors the canonical replicate picker
171/// [`crate::generative::family_noise_parameter`]'s `negbin_theta().or(seed)`:
172/// when the scale records a fitted `theta_hat`, use it; otherwise keep the
173/// existing seed. `theta_fixed` NB carries the user's exact value in both the
174/// spec and the scale metadata, so this refresh is a no-op there. Non-NB
175/// families are left untouched.
176fn refresh_negbin_theta_for_sampling(
177    likelihood: &mut LikelihoodSpec,
178    scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180    if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181        if let Some(theta_hat) = scale.negbin_theta() {
182            *theta = theta_hat;
183        }
184    }
185}
186
187/// Build a `LikelihoodSpec` for a saved model. Saved models already carry the
188/// response distribution and parameterized link state together, so sampling can
189/// dispatch directly on the cloned spec.
190fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
191    Ok(model.likelihood())
192}
193
/// Fallback smoothing strength `λ` for a reconstructed penalty block when the
/// saved model records no fitted `smooth_lambda`.
///
/// Deliberately mild (`0.01`). It regularises the design reconstructed for
/// prediction without materially reshaping the saved fit. A fitted lambda, when
/// present, always takes precedence over this default.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 0.01;
199
200#[inline]
201const fn splitmix64(x: u64) -> u64 {
202    gam_linalg::utils::splitmix64_hash(x)
203}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207    splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210/// Run NUTS posterior sampling over a saved model.
211///
212/// Dispatches on `model.predict_model_class()`:
213///
214/// * `Standard`: Gaussian identity models use the exact saved
215///   `N(mode, φ·H⁻¹)` posterior, where `mode`, `φ`, and `H` all come from the
216///   training fit. Other standard GLMs run NUTS from the saved mode,
217///   smoothing parameters, dispersion, and whitening curvature rather than
218///   refitting/reselecting them on the caller-supplied rows. Link-wiggle
219///   models take a specialised joint-space path that preserves the basis
220///   chain rule.
221/// * `Survival`: rebuilds the survival design (Royston-Parmar baseline +
222///   wiggle + covariate blocks) on the supplied data, evaluates the mode,
223///   and runs the survival-flat NUTS path. Latent and location-scale modes
224///   are explicitly rejected here.
225/// * Other model classes (location-scale GLM, bernoulli marginal-slope,
226///   transformation-normal) return a "not implemented" error matching the
227///   CLI surface.
228pub fn sample_saved_model(
229    model: &SavedModel,
230    data: ArrayView2<'_, f64>,
231    col_map: &HashMap<String, usize>,
232    training_headers: Option<&Vec<String>>,
233    cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235    // Issue #399: degenerate draw/chain counts (`samples=0` / `chains=0`, and
236    // the `samples < 4` counts the split-R-hat engine path cannot handle) must
237    // surface as one typed `InvalidConfig` error before any sampler runs —
238    // identically across *every* model class. Validating here, at the single
239    // public dispatch point, guarantees that the NUTS path, the auto-selected
240    // Pólya-Gamma Gibbs path, and the Laplace-Gaussian fallback all reject the
241    // same inputs the same way (previously the fallback silently accepted them
242    // via `.max(1)` while NUTS errored — a divergent contract on one API).
243    validate_nuts_config(cfg).map_err(String::from)?;
244    let likelihood = likelihood_spec_for_saved_model(model)?;
245    match model.predict_model_class() {
246        PredictModelClass::Survival => {
247            // Latent / latent-binary / location-scale survival likelihoods
248            // have no exact NUTS implementation in the engine yet; fall
249            // through to the Laplace-Gaussian fallback so callers still
250            // get a posterior they can predict with. Royston-Parmar /
251            // Weibull / marginal-slope survival use the exact path.
252            let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253            if matches!(
254                saved_likelihood_mode,
255                SurvivalLikelihoodMode::Latent
256                    | SurvivalLikelihoodMode::LatentBinary
257                    | SurvivalLikelihoodMode::LocationScale
258            ) {
259                laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260            } else {
261                sample_survival(model, data, col_map, training_headers, cfg)
262            }
263        }
264        PredictModelClass::Standard => {
265            // Most `Standard` GLM families (Gaussian, Poisson, Gamma, Tweedie,
266            // Negative-Binomial, binomial logit/probit/cloglog) have an exact
267            // NUTS implementation and run through `sample_standard`. Beta
268            // regression is the one `Standard` family the engine cannot sample
269            // with NUTS (`hmc_io.rs` returns a hard error for it). Rather than
270            // aborting the whole `sample` command, route it to the same
271            // Laplace-Gaussian fallback every other NUTS-unsupported model
272            // class already uses, so callers still get a usable posterior.
273            if matches!(likelihood.response, ResponseFamily::Beta { .. }) {
274                laplace_gaussian_fallback(model, cfg, "beta-regression posterior fallback")
275            } else {
276                sample_standard(model, data, col_map, training_headers, likelihood, cfg)
277            }
278        }
279        // For classes where the Rust core doesn't yet have an exact NUTS
280        // implementation we fall back to drawing from the Laplace
281        // (Gaussian) approximation of the posterior around the fitted
282        // joint mode, using the saved penalised Hessian. This is the
283        // standard "Bayesian credible interval" surface used by mgcv
284        // and similar packages: it drops higher-order posterior shape
285        // but lets every downstream consumer (credible intervals,
286        // posterior predictive, etc.) keep working uniformly across
287        // model classes.
288        PredictModelClass::GaussianLocationScale => {
289            laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
290        }
291        PredictModelClass::BinomialLocationScale => {
292            laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
293        }
294        PredictModelClass::DispersionLocationScale => {
295            laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
296        }
297        PredictModelClass::BernoulliMarginalSlope => {
298            laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
299        }
300        PredictModelClass::TransformationNormal => {
301            laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
302        }
303    }
304}
305
306/// Draw iid samples from `N(mode, H^{-1})` using the saved penalised
307/// Hessian `H = L L^T`.
308///
309/// We solve `L^T δ = ε` for each iid `ε ~ N(0, I)` and report
310/// `β = mode + δ`. The resulting draws are unbiased samples of the
311/// Laplace-Gaussian approximation: their finite-sample mean / std
312/// converge to `(mode, diag(H^{-1})^{1/2})` and the implied credible
313/// bands match the surface that closed-form posterior tooling in
314/// `mgcv` and `gam` itself uses for prediction intervals.
315///
316/// `rationale` is a short label appearing in error messages so callers
317/// can tell which class fell back to this path. We mark `rhat = 1.0`
318/// and `ess = n_total` because the draws are iid by construction.
319pub fn laplace_gaussian_fallback(
320    model: &SavedModel,
321    cfg: &NutsConfig,
322    rationale: &'static str,
323) -> Result<NutsResult, String> {
324    use gam_problem::dispersion_cov::DispersionExt as _;
325    // Defense in depth: this is `pub`, so guard the same degenerate
326    // draw/chain counts the NUTS / PG paths reject (issue #399) rather than
327    // papering over `n_chains == 0` / `n_samples == 0` with `.max(1)`, which
328    // would silently fabricate draws the caller never asked for.
329    validate_nuts_config(cfg).map_err(String::from)?;
330    let fit = fit_result_from_saved_model_for_prediction(model)?;
331    let mode = fit.beta.clone();
332    let p = mode.len();
333    if p == 0 {
334        return Err(format!(
335            "{rationale}: cannot sample from an empty coefficient vector"
336        ));
337    }
338    let h = fit.penalized_hessian().ok_or_else(|| {
339        format!(
340            "{rationale}: posterior fallback requires the explicit penalised Hessian; \
341             refit with exact geometry export to enable posterior sampling for this class."
342        )
343    })?;
344    // `penalized_hessian` is stored unscaled (no φ). To draw Laplace
345    // approximations of `N(mode, φ·H⁻¹)` we solve `Lᵀ δ = ε` (so
346    // `Var(δ) = H⁻¹`) and then rescale by √φ. For families with
347    // `Dispersion::Known(1.0)` (Binomial / Poisson) this is a no-op;
348    // for Gaussian / Gamma it restores the φ-scaled posterior
349    // covariance that the Wald-style intervals downstream assume.
350    let dispersion = fit.dispersion().unwrap_or_default();
351    let sqrt_phi = dispersion.sqrt_phi();
352    if h.nrows() != p || h.ncols() != p {
353        return Err(format!(
354            "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
355            h.nrows(),
356            h.ncols(),
357            p,
358            p
359        ));
360    }
361    let chol = h.cholesky(Side::Lower).map_err(|err| {
362        format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
363    })?;
364    let l = chol.lower_triangular();
365
366    // `validate_nuts_config` above guarantees `n_chains >= 1` and
367    // `n_samples >= 4`, so the draw grid is always non-empty and densely
368    // filled — no `.max(1)` clamping or bounds guard is needed.
369    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
370    let mut samples = Array2::<f64>::zeros((n_total, p));
371    let mut eps = Array1::<f64>::zeros(p);
372    let mut delta = Array1::<f64>::zeros(p);
373    for chain in 0..cfg.n_chains {
374        let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
375            cfg.seed,
376            chain,
377            0xA0B7_6C5D_E431_298F,
378        ));
379        for draw in 0..cfg.n_samples {
380            let k = chain * cfg.n_samples + draw;
381            for i in 0..p {
382                eps[i] = sample_standard_normal(&mut rng);
383            }
384            back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
385            for i in 0..p {
386                // `delta` has covariance H⁻¹; multiplying by √φ produces a
387                // draw with covariance φ·H⁻¹, matching the φ-scaled
388                // posterior covariance `Vb` the rest of inference assumes.
389                samples[(k, i)] = mode[i] + sqrt_phi * delta[i];
390            }
391        }
392    }
393
394    let posterior_mean = samples
395        .mean_axis(ndarray::Axis(0))
396        .unwrap_or_else(|| Array1::<f64>::zeros(p));
397    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
398
399    Ok(NutsResult {
400        samples,
401        posterior_mean,
402        posterior_std,
403        rhat: 1.0,
404        ess: n_total as f64,
405        converged: true,
406    })
407}
408
409#[inline]
410fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
411    // Box-Muller transform — sufficient for posterior-mean-style sampling.
412    // The same construction is used by the NUTS warmup; keeping it in
413    // sync avoids two divergent gaussian RNG paths inside the engine.
414    let u1 = rng.random::<f64>().max(1e-16);
415    let u2 = rng.random::<f64>();
416    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
417}
418
419fn sample_standard(
420    model: &SavedModel,
421    data: ArrayView2<'_, f64>,
422    col_map: &HashMap<String, usize>,
423    training_headers: Option<&Vec<String>>,
424    mut likelihood: LikelihoodSpec,
425    cfg: &NutsConfig,
426) -> Result<NutsResult, String> {
427    // A coefficient that needs a *constraint-aware* posterior sampler must not
428    // take the gaussian-identity closed-form Laplace shortcut: that shortcut
429    // draws an unconstrained `N(mode, φ·H⁻¹)`, which for an active bound puts
430    // ~half its mass on the forbidden side of the boundary. Three geometries
431    // qualify, and all reproduce on a default gaussian model:
432    //   * a `bounded(x, min, max)` interval transform (#1508) — sampled on its
433    //     latent logit scale by `sample_standard_bounded`;
434    //   * a `nonnegative()`/`nonpositive()`/`linear(min,max)`/`constrain()` box
435    //     bound on a parametric coefficient (#1507); and
436    //   * a monotone/convex/concave shape cone on a spline (#1509).
437    // The latter two are sampled from the truncated Gaussian below. Detect all
438    // three cheaply from the saved termspec so the common, fully-unconstrained
439    // gaussian path keeps its fast exact fallback without building the design;
440    // the precise dispatch (and the authoritative `design.linear_constraints`
441    // check) happens after the design is assembled.
442    let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
443        ts.linear_terms.iter().any(|term| {
444            !matches!(
445                term.coefficient_geometry,
446                LinearCoefficientGeometry::Unconstrained
447            ) || term.coefficient_min.is_some()
448                || term.coefficient_max.is_some()
449        }) || ts
450            .smooth_terms
451            .iter()
452            .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
453    });
454    if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
455        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
456    }
457    if model.has_link_wiggle() {
458        // A Gaussian-identity link-wiggle model is sampled from its saved
459        // closed-form joint Laplace posterior (the mean and wiggle coefficients
460        // are jointly Gaussian); only the non-Gaussian wiggle posterior needs
461        // the dedicated link-wiggle NUTS path. Preserved from the original
462        // dispatch, where the Gaussian-identity shortcut ran ahead of the
463        // wiggle branch and so claimed Gaussian wiggle models for the
464        // closed-form path.
465        if likelihood.is_gaussian_identity() {
466            return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
467        }
468        return sample_standard_link_wiggle(
469            model,
470            data,
471            col_map,
472            training_headers,
473            likelihood,
474            cfg,
475        );
476    }
477    let parsed = parse_formula(&model.formula)?;
478    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
479    let y = data.column(y_col).to_owned();
480    let spec = resolve_termspec_for_prediction(
481        &model.resolved_termspec,
482        training_headers,
483        col_map,
484        "resolved_termspec",
485    )?;
486    let design = build_term_collection_design(data, &spec)
487        .map_err(|e| format!("failed to build term collection design: {e}"))?;
488
489    // ---- Constraint-aware posterior dispatch -------------------------------
490    //
491    // A coefficient subject to an *active* constraint sits on the boundary of
492    // its feasible region, so a plain unconstrained draw `N(mode, φ·H⁻¹)`
493    // places ~half its mass on the forbidden side. The constrained geometry
494    // must therefore be reconstructed *here*, ahead of both the
495    // Gaussian-identity closed-form shortcut and the GLM-NUTS fallback —
496    // neither of which is aware of the feasible region (#1507/#1508/#1509).
497    // The fit pins the point estimate correctly; only the posterior was blind.
498
499    // (1) bounded() interval coefficients are not sampled by the GLM-NUTS path.
500    // That path runs the Hamiltonian over the *raw* linear design with the
501    // saved user-scale mode, treating every coefficient as an unconstrained,
502    // Gaussian-penalized parameter. Bounded terms are fit through a custom
503    // family that drives eta via an interval transform `beta = min + (max-min)·
504    // sigmoid(theta)` of an unconstrained latent `theta`. The posterior is
505    // Gaussian on that *latent* scale (which is exactly where the fit treats the
506    // coefficient as a locally-quadratic, unconstrained parameter), so the
507    // correct draws are `theta ~ N(theta_mode, H_latent^{-1})` pushed forward
508    // through the interval map — never a Gaussian on the user scale, which can
509    // place mass outside [min,max] and discards the boundary-induced skew. The
510    // saved fit exports the user-scale mode and user-scale penalized Hessian;
511    // `sample_bounded_latent_posterior_internal` reconstructs the latent
512    // geometry via the exact inverse delta-method (`H_latent = J H_user J`) and
513    // returns user-scale draws that always lie strictly inside the interval.
514    // This must precede the Gaussian-identity shortcut: a Gaussian `bounded()`
515    // model would otherwise take the closed-form path and emit a user-scale
516    // Gaussian that spills outside the interval (#1508).
517    let has_bounded = spec.linear_terms.iter().any(|term| {
518        matches!(
519            term.coefficient_geometry,
520            LinearCoefficientGeometry::Bounded { .. }
521        )
522    });
523    if has_bounded {
524        // Mirror the fit-time layout: linear coefficient `j` lives at column
525        // `intercept_range.end + j` of the model's coefficient vector. Bounds
526        // are on the original (user/data) scale, which is also the scale the
527        // saved beta and penalized Hessian live on.
528        let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
529            .linear_terms
530            .iter()
531            .enumerate()
532            .filter_map(|(j, term)| match term.coefficient_geometry {
533                LinearCoefficientGeometry::Bounded { min, max, .. } => {
534                    Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
535                        col_idx: design.intercept_range.end + j,
536                        min,
537                        max,
538                    })
539                }
540                LinearCoefficientGeometry::Unconstrained => None,
541            })
542            .collect();
543        return sample_standard_bounded(model, cfg, &bounded_columns);
544    }
545
546    // (2) box / shape *inequality* constraints — `nonnegative()` /
547    // `linear(min,max)` / `constrain()` box bounds on a parametric coefficient
548    // (#1507) and the monotone/convex/concave shape cone `γ_j ≥ 0` on a spline
549    // (#1509). Both are reconstructed by `build_term_collection_design` into a
550    // single `A β ≥ b` polytope in the saved coefficient coordinate system, so
551    // one truncated-Gaussian sampler covers them uniformly. Like `bounded()`,
552    // this must precede the Gaussian-identity shortcut so a constrained
553    // Gaussian model is sampled inside its feasible region rather than from the
554    // boundary-centred unconstrained Gaussian.
555    if let Some(constraints) = design
556        .linear_constraints
557        .as_ref()
558        .filter(|c| c.a.nrows() > 0)
559    {
560        return sample_standard_truncated(model, cfg, constraints);
561    }
562
563    // (3) unconstrained Gaussian identity — saved closed-form Laplace posterior.
564    if likelihood.is_gaussian_identity() {
565        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
566    }
567
568    // (4) unconstrained non-Gaussian GLM — exact NUTS over the raw design.
569    let weights = Array1::ones(data.nrows());
570    let dense_design_hmc = design.design.to_dense();
571    let p = dense_design_hmc.ncols();
572    let fit = fit_result_from_saved_model_for_prediction(model)?;
573    // Refresh the NB overdispersion `theta` from the fit's jointly-estimated
574    // `theta_hat` before sampling. The construction seed stored on the family
575    // spec (`theta: 1.0`) only seeds the inner solve; the NUTS NB log-likelihood
576    // / score (`src/inference/hmc.rs`) reads `theta` straight off this spec, so
577    // leaving the seed in place over-states `Var(y) = μ + μ²/θ` and inflates
578    // every coefficient's posterior SD (#1463 — the HMC sibling of the
579    // replicate-path bug #1124). `theta_fixed` NB carries the user's exact value
580    // in both the spec and the scale metadata, so this refresh is a no-op there.
581    // Mirrors how the replicate path reads `theta_hat` via the canonical
582    // `family_noise_parameter` helper (`negbin_theta().or(seed)`).
583    refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
584    if fit.beta.len() != p {
585        return Err(format!(
586            "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
587            fit.beta.len(),
588            p,
589        ));
590    }
591    if fit.lambdas.len() != design.penalties.len() {
592        return Err(format!(
593            "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
594            fit.lambdas.len(),
595            design.penalties.len(),
596        ));
597    }
598    let penalty =
599        weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
600
601    // Re-apply the offset the model was fit with so the posterior targets the
602    // same η = Xβ + offset as the fit and predict paths. The diagnostic loader
603    // keeps the saved offset column in the frame; dropping the offset silently
604    // sampled the wrong target for any `--offset-column` GLM (#882).
605    let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
606        Some(name) => {
607            let idx = resolve_role_col(col_map, name, "offset")?;
608            Some(data.column(idx).to_owned())
609        }
610        None => None,
611    };
612
613    run_nuts_sampling_flattened_family(
614        likelihood,
615        FamilyNutsInputs::Glm(GlmFlatInputs {
616            x: dense_design_hmc.view(),
617            y: y.view(),
618            weights: weights.view(),
619            penalty_matrix: penalty.view(),
620            mode: fit.beta.view(),
621            hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
622            gamma_shape: fit.likelihood_scale.gamma_shape(),
623            // Forward the saved training dispersion so NUTS whitening uses the
624            // posterior scale selected at fit time; fixed-scale families remain
625            // a no-op.
626            dispersion: fit.dispersion().unwrap_or_default(),
627            firth_bias_reduction: false,
628            offset: offset_vec.as_ref().map(|o| o.view()),
629        }),
630        cfg,
631    )
632    .map_err(|e| format!("NUTS sampling failed: {e}"))
633}
634
635/// Exact posterior draws for a standard GLM with `bounded()` coefficients.
636///
637/// The bounded coefficients are sampled on their natural latent (logit) scale —
638/// where the Laplace approximation is Gaussian — and every draw is pushed
639/// through the exact interval map so user-scale draws always lie strictly inside
640/// `[min, max]` and carry the boundary-induced skew. Non-bounded coefficients
641/// are drawn as the ordinary Gaussian Laplace component of the same joint
642/// posterior, so cross-coefficient correlations with the bounded columns are
643/// preserved (the latent precision is the full `H_latent = J H_user J`).
644fn sample_standard_bounded(
645    model: &SavedModel,
646    cfg: &NutsConfig,
647    bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
648) -> Result<NutsResult, String> {
649    validate_nuts_config(cfg).map_err(String::from)?;
650    let fit = fit_result_from_saved_model_for_prediction(model)?;
651    let mode = fit.beta.clone();
652    let p = mode.len();
653    if p == 0 {
654        return Err(
655            "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
656                .to_string(),
657        );
658    }
659    // The bounded fit exports the UNSCALED user-scale penalized Hessian; the
660    // latent sampler reconstructs the latent precision from it via the exact
661    // inverse delta-method. (`explicit_fit_hessian_for_whitening` returns this
662    // same user-scale penalized Hessian for a saved standard fit.)
663    let user_hessian =
664        explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
665    // The exported Hessian carries unit implicit dispersion, so the latent
666    // posterior covariance is `cov_scale·H_latent⁻¹` with `cov_scale` the
667    // coefficient-covariance scale the fit used for `Vb` (`σ̂²` for a profiled
668    // Gaussian, `1` for fixed-scale Binomial). Re-applying `√cov_scale` here
669    // keeps the draw spread identical to the reported `summary().std_error`
670    // (gam#1514); the truncated-constraint path does the analogous √φ lift.
671    let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
672    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
673    let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
674        &mode,
675        user_hessian,
676        bounded_columns,
677        n_total,
678        sqrt_cov_scale,
679        chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
680    )
681    .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
682
683    let posterior_mean = samples
684        .mean_axis(ndarray::Axis(0))
685        .unwrap_or_else(|| Array1::<f64>::zeros(p));
686    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
687
688    Ok(NutsResult {
689        samples,
690        posterior_mean,
691        posterior_std,
692        rhat: 1.0,
693        ess: n_total as f64,
694        converged: true,
695    })
696}
697
698/// Exact posterior draws for a standard GLM whose coefficients carry linear
699/// *inequality* constraints `A β ≥ b` — `nonnegative()` / `linear(min,max)` /
700/// `constrain()` box bounds on a parametric term (#1507) and the
701/// monotone/convex/concave shape cone `γ_j ≥ 0` on a spline (#1509).
702///
703/// The posterior is the Laplace Gaussian `N(mode, φ·H⁻¹)` *truncated* to the
704/// feasible polytope. For a Gaussian-identity model this is the exact
705/// posterior; for a non-Gaussian GLM it is the constraint-respecting Laplace
706/// approximation — the same modelling choice the `bounded()` term makes. The
707/// draws are produced by exact reflective Hamiltonian Monte Carlo
708/// ([`crate::truncated_gaussian`]), so every draw is feasible and
709/// successive draws are essentially independent (`rhat ≈ 1`, matching the other
710/// Laplace posterior paths).
711fn sample_standard_truncated(
712    model: &SavedModel,
713    cfg: &NutsConfig,
714    constraints: &gam_solve::pirls::LinearInequalityConstraints,
715) -> Result<NutsResult, String> {
716    validate_nuts_config(cfg).map_err(String::from)?;
717    let fit = fit_result_from_saved_model_for_prediction(model)?;
718    let mode = fit.beta.clone();
719    let p = mode.len();
720    if p == 0 {
721        return Err(
722            "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
723             vector"
724                .to_string(),
725        );
726    }
727    // The saved standard fit exports the unscaled user-scale penalised Hessian
728    // `H`; the truncated sampler whitens with its Cholesky and re-applies √φ so
729    // the posterior covariance is `φ·H⁻¹`, identical to the unconstrained
730    // Gaussian/bounded paths. Fixed-scale families (Binomial / Poisson) have
731    // φ = 1.
732    let penalized_hessian =
733        explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
734    let sqrt_phi = {
735        use gam_problem::dispersion_cov::DispersionExt as _;
736        fit.dispersion().unwrap_or_default().sqrt_phi()
737    };
738    let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
739        &mode,
740        &penalized_hessian,
741        sqrt_phi,
742        constraints,
743        cfg.n_samples,
744        cfg.n_chains,
745        chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
746    )?;
747    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
748
749    let posterior_mean = samples
750        .mean_axis(ndarray::Axis(0))
751        .unwrap_or_else(|| Array1::<f64>::zeros(p));
752    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
753
754    Ok(NutsResult {
755        samples,
756        posterior_mean,
757        posterior_std,
758        rhat: 1.0,
759        ess: n_total as f64,
760        converged: true,
761    })
762}
763
764fn sample_standard_link_wiggle(
765    model: &SavedModel,
766    data: ArrayView2<'_, f64>,
767    col_map: &HashMap<String, usize>,
768    training_headers: Option<&Vec<String>>,
769    likelihood: LikelihoodSpec,
770    cfg: &NutsConfig,
771) -> Result<NutsResult, String> {
772    let parsed = parse_formula(&model.formula)?;
773    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
774    let y = data.column(y_col).to_owned();
775
776    let spec = resolve_termspec_for_prediction(
777        &model.resolved_termspec,
778        training_headers,
779        col_map,
780        "resolved_termspec",
781    )?;
782    let design = build_term_collection_design(data, &spec)
783        .map_err(|e| format!("failed to build term collection design: {e}"))?;
784    let p_main = design.design.ncols();
785
786    let fit = fit_result_from_saved_model_for_prediction(model)?;
787    let wiggle_runtime = model
788        .saved_prediction_runtime()?
789        .link_wiggle
790        .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
791    let mode_beta = fit
792        .block_by_role(BlockRole::Mean)
793        .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
794        .beta
795        .clone();
796    let mode_theta = fit
797        .block_by_role(BlockRole::LinkWiggle)
798        .ok_or_else(|| {
799            "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
800        })?
801        .beta
802        .clone();
803    let p_wiggle = mode_theta.len();
804    let p_total = mode_beta.len() + p_wiggle;
805
806    if mode_beta.len() != p_main {
807        return Err(format!(
808            "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
809            mode_beta.len(),
810            p_main,
811        ));
812    }
813    if fit.beta.len() != p_total {
814        return Err(format!(
815            "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
816            fit.beta.len(),
817            p_main,
818            p_wiggle,
819            p_total,
820        ));
821    }
822
823    let hessian = &fit
824        .geometry
825        .as_ref()
826        .ok_or_else(|| {
827            "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
828                .to_string()
829        })?
830        .penalized_hessian;
831    validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;
832
833    let n_base_penalties = design.penalties.len();
834    let base_lambdas = fit
835        .block_by_role(BlockRole::Mean)
836        .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
837        .lambdas
838        .view();
839    if base_lambdas.len() != n_base_penalties {
840        return Err(format!(
841            "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
842            base_lambdas.len(),
843            n_base_penalties,
844        ));
845    }
846
847    let penalty_base =
848        weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);
849
850    let wiggle_lambdas_owned = fit
851        .lambdas_linkwiggle()
852        .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
853    let wiggle_lambdas = wiggle_lambdas_owned.view();
854    let degree = wiggle_runtime.degree;
855    let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());
856
857    let mut wiggle_penalties = Vec::new();
858    let default_orders = [2usize];
859    let n_wiggle_lambdas = wiggle_lambdas.len();
860    for k in 0..n_wiggle_lambdas {
861        let order = if k < default_orders.len() {
862            default_orders[k]
863        } else {
864            k + 1
865        };
866        if order >= p_wiggle {
867            continue;
868        }
869        let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
870            .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
871        wiggle_penalties.push(penalty);
872    }
873    while wiggle_penalties.len() < n_wiggle_lambdas {
874        wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
875    }
876
877    let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;
878
879    let q0 = design.design.dot(&mode_beta);
880    let (q0_min, q0_max) = q0
881        .iter()
882        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
883            (lo.min(v), hi.max(v))
884        });
885
886    let spline = LinkWiggleSplineArtifacts {
887        knot_range: (q0_min, q0_max),
888        knot_vector: knot_arr,
889        degree,
890    };
891
892    let nuts_family = match (&likelihood.response, &likelihood.link) {
893        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
894            NutsFamily::BinomialLogit
895        }
896        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
897            NutsFamily::BinomialProbit
898        }
899        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
900            NutsFamily::BinomialCLogLog
901        }
902        (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
903        (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
904        (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
905        (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
906        (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
907        _ => {
908            return Err(format!(
909                "NUTS sampling with link wiggle is not supported for family {}",
910                likelihood.pretty_name()
911            ));
912        }
913    };
914
915    let weights = Array1::ones(data.nrows());
916    let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);
917
918    let wiggle_nuts_dense = design.design.as_dense_cow();
919    run_link_wiggle_nuts_sampling(
920        wiggle_nuts_dense.view(),
921        y.view(),
922        weights.view(),
923        penalty_base.view(),
924        penalty_link.view(),
925        mode_beta.view(),
926        mode_theta.view(),
927        hessian.view(),
928        spline,
929        nuts_family,
930        scale,
931        cfg,
932    )
933    .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
934}
935
936fn sample_survival(
937    model: &SavedModel,
938    data: ArrayView2<'_, f64>,
939    col_map: &HashMap<String, usize>,
940    training_headers: Option<&Vec<String>>,
941    cfg: &NutsConfig,
942) -> Result<NutsResult, String> {
943    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
944    if matches!(
945        saved_likelihood_mode,
946        SurvivalLikelihoodMode::Latent
947            | SurvivalLikelihoodMode::LatentBinary
948            | SurvivalLikelihoodMode::LocationScale
949    ) {
950        return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
951    }
952    // `survival_entry == None` is the right-censored shorthand
953    // `Surv(time, event)`: training synthesized a zero entry column,
954    // and posterior sampling must do the same so artifacts fit with
955    // the shorthand are first-class through `gam sample` /
956    // `model.sample` just like `gam predict` already handles them in
957    // `run_predict_survival`. The resolution flows through the shared
958    // `resolve_saved_survival_time_columns` helper so every consumer
959    // of saved survival metadata applies the same fallback contract.
960    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
961    let exit_col = time_cols.exit_col;
962    let eventname = model
963        .survival_event
964        .as_ref()
965        .ok_or_else(|| "survival model missing event column metadata".to_string())?;
966    let event_col = resolve_role_col(col_map, eventname, "event")?;
967    let termspec = resolve_termspec_for_prediction(
968        &model.resolved_termspec,
969        training_headers,
970        col_map,
971        "resolved_termspec",
972    )?;
973    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
974    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
975    let cov_design = build_term_collection_design(cov_input, &termspec)
976        .map_err(|e| format!("failed to build survival design: {e}"))?;
977    let n = data.nrows();
978    let p_cov = cov_design.design.ncols();
979    let mut age_entry = Array1::<f64>::zeros(n);
980    let mut age_exit = Array1::<f64>::zeros(n);
981    let mut event_target = Array1::<u8>::zeros(n);
982    let event_competing = Array1::<u8>::zeros(n);
983    let weights = Array1::<f64>::ones(n);
984    for i in 0..n {
985        let (t0, t1) = normalize_survival_time_pair(
986            time_cols.row_entry_time(data, i),
987            data[[i, exit_col]],
988            i,
989        )?;
990        age_entry[i] = t0;
991        age_exit[i] = t1;
992        event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
993    }
994    let time_cfg = load_survival_time_basis_config_from_model(model)?;
995    let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
996    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
997        &time_build.basisname,
998        time_build.degree,
999        time_build.knots.as_ref(),
1000        time_build.keep_cols.as_ref(),
1001        time_build.smooth_lambda,
1002    )?;
1003    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1004        let time_anchor = model
1005            .survival_time_anchor
1006            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1007        let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
1008        center_survival_time_designs_at_anchor(
1009            &mut time_build.x_entry_time,
1010            &mut time_build.x_exit_time,
1011            &time_anchor_row,
1012        )?;
1013    }
1014    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
1015    let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
1016        build_survival_time_offsets_for_likelihood(
1017            &age_entry,
1018            &age_exit,
1019            &baseline_cfg,
1020            saved_likelihood_mode,
1021            None,
1022        )?;
1023    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1024        let time_anchor = model
1025            .survival_time_anchor
1026            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1027        add_survival_time_derivative_guard_offset(
1028            &age_entry,
1029            &age_exit,
1030            time_anchor,
1031            survival_derivative_guard_for_likelihood(saved_likelihood_mode),
1032            &mut eta_offset_entry,
1033            &mut eta_offset_exit,
1034            &mut derivative_offset_exit,
1035        )?;
1036    }
1037    let saved_timewiggle = saved_baseline_timewiggle_components(
1038        &eta_offset_entry,
1039        &eta_offset_exit,
1040        &derivative_offset_exit,
1041        model,
1042    )?;
1043    let p_time = time_build.x_exit_time.ncols();
1044    let p_timewiggle = saved_timewiggle
1045        .as_ref()
1046        .map(|(_, exit, _)| exit.ncols())
1047        .unwrap_or(0);
1048    let p = p_time + p_timewiggle + p_cov;
1049    let tb_entry_dense = time_build.x_entry_time.to_dense();
1050    let tb_exit_dense = time_build.x_exit_time.to_dense();
1051    let tb_deriv_dense = time_build.x_derivative_time.to_dense();
1052    let mut x_entry = Array2::<f64>::zeros((n, p));
1053    let mut x_exit = Array2::<f64>::zeros((n, p));
1054    let mut x_derivative = Array2::<f64>::zeros((n, p));
1055    if p_time > 0 {
1056        x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
1057        x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
1058        x_derivative
1059            .slice_mut(s![.., ..p_time])
1060            .assign(&tb_deriv_dense);
1061    }
1062    if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
1063        && p_timewiggle > 0
1064    {
1065        x_entry
1066            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1067            .assign(entry_w);
1068        x_exit
1069            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1070            .assign(exit_w);
1071        x_derivative
1072            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1073            .assign(deriv_w);
1074    }
1075    if p_cov > 0 {
1076        let cov_dense = cov_design.design.to_dense();
1077        let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
1078        x_entry
1079            .slice_mut(s![.., cov_range.clone()])
1080            .assign(&cov_dense);
1081        x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
1082    }
1083    let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
1084    for (idx, s) in time_build.penalties.iter().enumerate() {
1085        if s.nrows() == p_time && s.ncols() == p_time {
1086            penalty_blocks.push(PenaltyBlock {
1087                matrix: s.clone(),
1088                lambda: time_build
1089                    .smooth_lambda
1090                    .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1091                range: 0..p_time,
1092                nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
1093            });
1094        }
1095    }
1096    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1097    if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
1098        let start = p_time;
1099        let end = start + exit_w.ncols();
1100        let wiggle_lambda_offset = penalty_blocks.len();
1101        let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
1102            "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
1103        })?;
1104        let wiggle_degree = wiggle_cfg.degree;
1105        let wiggle_knots =
1106            Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
1107                "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
1108            })?);
1109        let mut seed = Array1::<f64>::zeros(2 * n);
1110        for i in 0..n {
1111            seed[i] = eta_offset_entry[i];
1112            seed[n + i] = eta_offset_exit[i];
1113        }
1114        let (primary_order, extra_orders) =
1115            split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
1116        let mut block = buildwiggle_block_input_from_knots(
1117            seed.view(),
1118            &wiggle_knots,
1119            wiggle_degree,
1120            primary_order,
1121            wiggle_cfg.double_penalty,
1122        )?;
1123        append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
1124            .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
1125        for (widx, s) in block.penalties.iter().enumerate() {
1126            let s = match s {
1127                gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
1128                gam_solve::estimate::PenaltySpec::Dense(m)
1129                | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
1130            };
1131            if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
1132                penalty_blocks.push(PenaltyBlock {
1133                    matrix: s.clone(),
1134                    lambda: time_build
1135                        .smooth_lambda
1136                        .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1137                    range: start..end,
1138                    nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
1139                });
1140            }
1141        }
1142        for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
1143            .iter_mut()
1144            .enumerate()
1145        {
1146            if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
1147                block_penalty.lambda = lam;
1148            }
1149        }
1150    }
1151    let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
1152        "saved survival model is missing survivalridge_lambda; refusing to \
1153         pick a load-time default (the historical 1e-4 fallback silently \
1154         disagreed with the 1e-6 fit-time default). Refit."
1155            .to_string()
1156    })?;
1157    let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
1158    {
1159        1
1160    } else {
1161        0
1162    };
1163    if ridge_lambda > 0.0 && p > ridge_range_start {
1164        let dim = p - ridge_range_start;
1165        let mut ridge = Array2::<f64>::zeros((dim, dim));
1166        for d in 0..dim {
1167            ridge[[d, d]] = 1.0;
1168        }
1169        penalty_blocks.push(PenaltyBlock {
1170            matrix: ridge,
1171            lambda: ridge_lambda,
1172            range: ridge_range_start..p,
1173            nullspace_dim: 0,
1174        });
1175    }
1176    for (idx, block) in penalty_blocks.iter_mut().enumerate() {
1177        if let Some(&lam) = fit_saved.lambdas.get(idx) {
1178            block.lambda = lam;
1179        }
1180    }
1181    let penalties = PenaltyBlocks::new(penalty_blocks);
1182    let survivalspec = match model
1183        .survivalspec
1184        .as_deref()
1185        .unwrap_or("net")
1186        .to_ascii_lowercase()
1187        .as_str()
1188    {
1189        "net" => SurvivalSpec::Net,
1190        "crude" => {
1191            return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
1192                        .to_string());
1193        }
1194        other => {
1195            return Err(format!("unsupported saved survival spec '{other}'"));
1196        }
1197    };
1198    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
1199    let mut model_surv = royston_parmar::working_model_from_flattened(
1200        penalties.clone(),
1201        monotonicity,
1202        survivalspec,
1203        RoystonParmarInputs {
1204            age_entry: age_entry.view(),
1205            age_exit: age_exit.view(),
1206            event_target: event_target.view(),
1207            event_competing: event_competing.view(),
1208            weights: weights.view(),
1209            x_entry: x_entry.view(),
1210            x_exit: x_exit.view(),
1211            x_derivative: x_derivative.view(),
1212            monotonicity_constraint_rows: None,
1213            monotonicity_constraint_offsets: None,
1214            eta_offset_entry: Some(eta_offset_entry.view()),
1215            eta_offset_exit: Some(eta_offset_exit.view()),
1216            derivative_offset_exit: Some(derivative_offset_exit.view()),
1217        },
1218    )
1219    .map_err(|e| format!("failed to construct survival model: {e}"))?;
1220    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
1221        model_surv
1222            .set_structural_monotonicity(true, p_time + p_timewiggle)
1223            .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
1224    }
1225    let beta0 = fit_saved.beta.clone();
1226    let state = model_surv
1227        .update_state(&beta0)
1228        .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
1229    let hessian = state.hessian.to_dense();
1230    run_survival_nuts_sampling_flattened(
1231        SurvivalFlatInputs {
1232            age_entry: age_entry.view(),
1233            age_exit: age_exit.view(),
1234            event_target: event_target.view(),
1235            event_competing: event_competing.view(),
1236            weights: weights.view(),
1237            x_entry: x_entry.view(),
1238            x_exit: x_exit.view(),
1239            x_derivative: x_derivative.view(),
1240            eta_offset_entry: Some(eta_offset_entry.view()),
1241            eta_offset_exit: Some(eta_offset_exit.view()),
1242            derivative_offset_exit: Some(derivative_offset_exit.view()),
1243        },
1244        penalties,
1245        monotonicity,
1246        survivalspec,
1247        saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
1248        p_time + p_timewiggle,
1249        beta0.view(),
1250        hessian.view(),
1251        cfg,
1252    )
1253    .map_err(|e| format!("survival NUTS sampling failed: {e}"))
1254}
1255
#[cfg(test)]
mod tests {
    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    /// Applies the NB theta refresh to `likelihood` using `scale`, then returns
    /// the refreshed `(theta, theta_fixed)` pair. Panics when the response is
    /// not negative binomial.
    fn refreshed_negbin_params(
        mut likelihood: LikelihoodSpec,
        scale: LikelihoodScaleMetadata,
    ) -> (f64, bool) {
        refresh_negbin_theta_for_sampling(&mut likelihood, scale);
        match likelihood.response {
            ResponseFamily::NegativeBinomial { theta, theta_fixed } => (theta, theta_fixed),
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        }
    }

    /// #1463: NUTS must sample the NB family at the fit's jointly-estimated
    /// `theta_hat`, not at the construction seed `theta = 1.0`. The NUTS NB
    /// log-likelihood/score reads `theta` from the sampling `LikelihoodSpec`.
    /// Without a refresh from the scale metadata, every coefficient's
    /// posterior SD is inflated by ~1.4–1.5×.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        // Seed theta = 1.0 in the spec; the fit recorded theta_hat = 2.97.
        let (theta, _) = refreshed_negbin_params(
            LikelihoodSpec::negative_binomial_log(1.0),
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
        );
        assert_eq!(
            theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    /// A fixed-theta NB fit stores the user's `theta` in both the spec and the
    /// scale metadata. The refresh must keep that value and its fixed flag.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let (theta, theta_fixed) = refreshed_negbin_params(
            LikelihoodSpec::negative_binomial_log_fixed(4.25),
            LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 },
        );
        assert_eq!(theta, 4.25, "fixed NB theta must survive the refresh");
        assert!(theta_fixed, "theta_fixed flag must be preserved");
    }

    /// Scale metadata without an NB theta (here `ProfiledGaussian`) leaves the
    /// spec's seed as-is, matching the replicate picker's
    /// `negbin_theta().or(seed)`.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let (theta, _) = refreshed_negbin_params(
            LikelihoodSpec::negative_binomial_log(3.5),
            LikelihoodScaleMetadata::ProfiledGaussian,
        );
        assert_eq!(
            theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    /// The refresh matches on the response family. A non-NB family such as
    /// Poisson stays untouched even when the metadata carries an NB theta.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut poisson = LikelihoodSpec::poisson_log();
        let response_before = poisson.response.clone();
        let nb_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 };
        refresh_negbin_theta_for_sampling(&mut poisson, nb_scale);
        assert_eq!(
            poisson.response, response_before,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}