Skip to main content

gam_inference/
sample.rs

1//! Library-side orchestration for NUTS posterior sampling from a saved model.
2//!
3//! The CLI's `gam sample` subcommand and the Python `Model.sample(...)` API
4//! both call into [`sample_saved_model`], which dispatches on the saved
5//! model's class (standard GLM, standard with link-wiggle, or survival) and
6//! returns a fully-converged [`NutsResult`] over the original coefficient
7//! space. Gaussian identity standard models are sampled from the saved
8//! closed-form posterior, conditioning on the training fit rather than any
9//! prediction rows supplied by the caller.
10
11use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18    FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19    explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20    run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27    SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28    build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29    evaluate_survival_time_basis_row, normalize_survival_time_pair,
30    resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33    fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34    resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35    saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39    PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42    append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43    split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47    FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55/// Reconstruct the `LinkWiggleFormulaSpec` from a saved model's
56/// baseline-time-wiggle runtime, returning `None` when the model has no
57/// time-wiggle component. Re-exported because the survival fitter's tests
58/// exercise the spec independently of running NUTS.
59pub fn saved_baseline_timewiggle_spec(
60    model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62    model
63        .saved_baseline_time_wiggle()
64        .map_err(|e| e.to_string())
65        .map(|runtime| {
66            runtime.map(|saved| LinkWiggleFormulaSpec {
67                degree: saved.degree,
68                num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69                penalty_orders: saved.penalty_orders,
70                double_penalty: saved.double_penalty,
71            })
72        })
73}
74
75fn weighted_penalty_matrix(
76    penalties: &[Array2<f64>],
77    lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79    if penalties.len() != lambdas.len() {
80        return Err(format!(
81            "penalty/lambda mismatch: {} penalties vs {} lambdas",
82            penalties.len(),
83            lambdas.len()
84        ));
85    }
86    if penalties.is_empty() {
87        return Err("cannot sample without at least one penalty block".to_string());
88    }
89    let p = penalties[0].nrows();
90    let mut out = Array2::<f64>::zeros((p, p));
91    for (k, s) in penalties.iter().enumerate() {
92        if s.nrows() != p || s.ncols() != p {
93            return Err(format!(
94                "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95                s.nrows(),
96                s.ncols(),
97                p,
98                p
99            ));
100        }
101        let lam = lambdas[k];
102        out += &(s * lam);
103    }
104    Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108    hessian: &Array2<f64>,
109    expected_dim: usize,
110) -> Result<(), String> {
111    if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112        return Err(format!(
113            "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114            hessian.nrows(),
115            hessian.ncols(),
116            expected_dim,
117            expected_dim,
118        ));
119    }
120    validate_all_finite(
121        "link-wiggle explicit joint Hessian",
122        hessian.iter().copied(),
123    )?;
124    let mut max_abs = 0.0_f64;
125    for r in 0..expected_dim {
126        for c in 0..expected_dim {
127            max_abs = max_abs.max(hessian[[r, c]].abs());
128            let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129            if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130                return Err(format!(
131                    "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132                ));
133            }
134        }
135    }
136    if max_abs == 0.0 {
137        return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138                    .to_string());
139    }
140    Ok(())
141}
142
143/// Resolve the scalar generative dispersion for a fitted model.
144///
145/// Thin adapter over the single canonical
146/// [`crate::generative::family_noise_parameter`]: the replicate-sampling path
147/// here and the CLI `gam generate` path both route through that one helper, so
148/// the fitted dispersion (NB θ̂, Beta/Tweedie φ̂, Gamma k̂) can never be read
149/// inconsistently between them. A divergent second copy of this logic was the
150/// root cause of #1124.
151fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
152    crate::generative::family_noise_parameter(
153        fit.likelihood_scale,
154        fit.standard_deviation,
155        likelihood,
156    )
157}
158
159/// Refresh the Negative-Binomial overdispersion `theta` on the sampling
160/// likelihood spec from the fit's jointly-estimated `theta_hat` before the NUTS
161/// dispatch reads it (#1463).
162///
163/// The construction seed stored on the family spec (`theta: 1.0`) only seeds the
164/// inner solve. NB carries unit REML scale and records its fitted overdispersion
165/// in `likelihood_scale` (`EstimatedNegBinTheta` / `FixedNegBinTheta`), *not* in
166/// the REML dispersion. The NUTS NB log-likelihood / score
167/// (`src/inference/hmc.rs`) reads `theta` straight off this spec, so leaving the
168/// seed in place over-states `Var(y) = μ + μ²/θ` and inflates every
169/// coefficient's posterior SD ~1.4–1.5× (the HMC sibling of the replicate-path
170/// bug #1124). This mirrors the canonical replicate picker
171/// [`crate::generative::family_noise_parameter`]'s `negbin_theta().or(seed)`:
172/// when the scale records a fitted `theta_hat`, use it; otherwise keep the
173/// existing seed. `theta_fixed` NB carries the user's exact value in both the
174/// spec and the scale metadata, so this refresh is a no-op there. Non-NB
175/// families are left untouched.
176fn refresh_negbin_theta_for_sampling(
177    likelihood: &mut LikelihoodSpec,
178    scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180    if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181        if let Some(theta_hat) = scale.negbin_theta() {
182            *theta = theta_hat;
183        }
184    }
185}
186
187/// Build a `LikelihoodSpec` for a saved model. Saved models already carry the
188/// response distribution and parameterized link state together, so sampling can
189/// dispatch directly on the cloned spec.
190fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
191    Ok(model.likelihood())
192}
193
/// Fallback smoothing strength `λ` for a reconstructed penalty block whose
/// saved model carries no fitted `smooth_lambda`.
///
/// Deliberately mild: it regularizes the reconstructed-for-prediction design
/// without materially reshaping the saved fit. A fitted lambda, when present,
/// always takes precedence over this value.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 0.01;
199
200#[inline]
201const fn splitmix64(x: u64) -> u64 {
202    gam_linalg::utils::splitmix64_hash(x)
203}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207    splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210/// Run NUTS posterior sampling over a saved model.
211///
212/// Dispatches on `model.predict_model_class()`:
213///
214/// * `Standard`: Gaussian identity models use the exact saved
215///   `N(mode, φ·H⁻¹)` posterior, where `mode`, `φ`, and `H` all come from the
216///   training fit. Other standard GLMs run NUTS from the saved mode,
217///   smoothing parameters, dispersion, and whitening curvature rather than
218///   refitting/reselecting them on the caller-supplied rows. Link-wiggle
219///   models take a specialised joint-space path that preserves the basis
220///   chain rule.
221/// * `Survival`: rebuilds the survival design (Royston-Parmar baseline +
222///   wiggle + covariate blocks) on the supplied data, evaluates the mode,
223///   and runs the survival-flat NUTS path. Latent and location-scale modes
224///   are explicitly rejected here.
225/// * Other model classes (location-scale GLM, bernoulli marginal-slope,
226///   transformation-normal) return a "not implemented" error matching the
227///   CLI surface.
228pub fn sample_saved_model(
229    model: &SavedModel,
230    data: ArrayView2<'_, f64>,
231    col_map: &HashMap<String, usize>,
232    training_headers: Option<&Vec<String>>,
233    cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235    // Issue #399: degenerate draw/chain counts (`samples=0` / `chains=0`, and
236    // the `samples < 4` counts the split-R-hat engine path cannot handle) must
237    // surface as one typed `InvalidConfig` error before any sampler runs —
238    // identically across *every* model class. Validating here, at the single
239    // public dispatch point, guarantees that the NUTS path, the auto-selected
240    // Pólya-Gamma Gibbs path, and the Laplace-Gaussian fallback all reject the
241    // same inputs the same way (previously the fallback silently accepted them
242    // via `.max(1)` while NUTS errored — a divergent contract on one API).
243    validate_nuts_config(cfg).map_err(String::from)?;
244    let likelihood = likelihood_spec_for_saved_model(model)?;
245    match model.predict_model_class() {
246        PredictModelClass::Survival => {
247            // Latent / latent-binary / location-scale survival likelihoods
248            // have no exact NUTS implementation in the engine yet; fall
249            // through to the Laplace-Gaussian fallback so callers still
250            // get a posterior they can predict with. Royston-Parmar /
251            // Weibull / marginal-slope survival use the exact path.
252            let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253            if matches!(
254                saved_likelihood_mode,
255                SurvivalLikelihoodMode::Latent
256                    | SurvivalLikelihoodMode::LatentBinary
257                    | SurvivalLikelihoodMode::LocationScale
258            ) {
259                laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260            } else {
261                sample_survival(model, data, col_map, training_headers, cfg)
262            }
263        }
264        PredictModelClass::Standard => {
265            sample_standard(model, data, col_map, training_headers, likelihood, cfg)
266        }
267        // For classes where the Rust core doesn't yet have an exact NUTS
268        // implementation we fall back to drawing from the Laplace
269        // (Gaussian) approximation of the posterior around the fitted
270        // joint mode, using the saved penalised Hessian. This is the
271        // standard "Bayesian credible interval" surface used by mgcv
272        // and similar packages: it drops higher-order posterior shape
273        // but lets every downstream consumer (credible intervals,
274        // posterior predictive, etc.) keep working uniformly across
275        // model classes.
276        PredictModelClass::GaussianLocationScale => {
277            laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
278        }
279        PredictModelClass::BinomialLocationScale => {
280            laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
281        }
282        PredictModelClass::DispersionLocationScale => {
283            laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
284        }
285        PredictModelClass::BernoulliMarginalSlope => {
286            laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
287        }
288        PredictModelClass::TransformationNormal => {
289            laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
290        }
291    }
292}
293
294/// Draw iid samples from `N(mode, H^{-1})` using the saved penalised
295/// Hessian `H = L L^T`.
296///
297/// We solve `L^T δ = ε` for each iid `ε ~ N(0, I)` and report
298/// `β = mode + δ`. The resulting draws are unbiased samples of the
299/// Laplace-Gaussian approximation: their finite-sample mean / std
300/// converge to `(mode, diag(H^{-1})^{1/2})` and the implied credible
301/// bands match the surface that closed-form posterior tooling in
302/// `mgcv` and `gam` itself uses for prediction intervals.
303///
304/// `rationale` is a short label appearing in error messages so callers
305/// can tell which class fell back to this path. We mark `rhat = 1.0`
306/// and `ess = n_total` because the draws are iid by construction.
307pub fn laplace_gaussian_fallback(
308    model: &SavedModel,
309    cfg: &NutsConfig,
310    rationale: &'static str,
311) -> Result<NutsResult, String> {
312    use gam_problem::dispersion_cov::DispersionExt as _;
313    // Defense in depth: this is `pub`, so guard the same degenerate
314    // draw/chain counts the NUTS / PG paths reject (issue #399) rather than
315    // papering over `n_chains == 0` / `n_samples == 0` with `.max(1)`, which
316    // would silently fabricate draws the caller never asked for.
317    validate_nuts_config(cfg).map_err(String::from)?;
318    let fit = fit_result_from_saved_model_for_prediction(model)?;
319    let mode = fit.beta.clone();
320    let p = mode.len();
321    if p == 0 {
322        return Err(format!(
323            "{rationale}: cannot sample from an empty coefficient vector"
324        ));
325    }
326    let h = fit.penalized_hessian().ok_or_else(|| {
327        format!(
328            "{rationale}: posterior fallback requires the explicit penalised Hessian; \
329             refit with exact geometry export to enable posterior sampling for this class."
330        )
331    })?;
332    // `penalized_hessian` is stored unscaled (no φ). To draw Laplace
333    // approximations of `N(mode, φ·H⁻¹)` we solve `Lᵀ δ = ε` (so
334    // `Var(δ) = H⁻¹`) and then rescale by √φ. For families with
335    // `Dispersion::Known(1.0)` (Binomial / Poisson) this is a no-op;
336    // for Gaussian / Gamma it restores the φ-scaled posterior
337    // covariance that the Wald-style intervals downstream assume.
338    let dispersion = fit.dispersion().unwrap_or_default();
339    let sqrt_phi = dispersion.sqrt_phi();
340    if h.nrows() != p || h.ncols() != p {
341        return Err(format!(
342            "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
343            h.nrows(),
344            h.ncols(),
345            p,
346            p
347        ));
348    }
349    let chol = h.cholesky(Side::Lower).map_err(|err| {
350        format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
351    })?;
352    let l = chol.lower_triangular();
353
354    // `validate_nuts_config` above guarantees `n_chains >= 1` and
355    // `n_samples >= 4`, so the draw grid is always non-empty and densely
356    // filled — no `.max(1)` clamping or bounds guard is needed.
357    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
358    let mut samples = Array2::<f64>::zeros((n_total, p));
359    let mut eps = Array1::<f64>::zeros(p);
360    let mut delta = Array1::<f64>::zeros(p);
361    for chain in 0..cfg.n_chains {
362        let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
363            cfg.seed,
364            chain,
365            0xA0B7_6C5D_E431_298F,
366        ));
367        for draw in 0..cfg.n_samples {
368            let k = chain * cfg.n_samples + draw;
369            for i in 0..p {
370                eps[i] = sample_standard_normal(&mut rng);
371            }
372            back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
373            for i in 0..p {
374                // `delta` has covariance H⁻¹; multiplying by √φ produces a
375                // draw with covariance φ·H⁻¹, matching the φ-scaled
376                // posterior covariance `Vb` the rest of inference assumes.
377                samples[(k, i)] = mode[i] + sqrt_phi * delta[i];
378            }
379        }
380    }
381
382    let posterior_mean = samples
383        .mean_axis(ndarray::Axis(0))
384        .unwrap_or_else(|| Array1::<f64>::zeros(p));
385    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
386
387    Ok(NutsResult {
388        samples,
389        posterior_mean,
390        posterior_std,
391        rhat: 1.0,
392        ess: n_total as f64,
393        converged: true,
394    })
395}
396
397#[inline]
398fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
399    // Box-Muller transform — sufficient for posterior-mean-style sampling.
400    // The same construction is used by the NUTS warmup; keeping it in
401    // sync avoids two divergent gaussian RNG paths inside the engine.
402    let u1 = rng.random::<f64>().max(1e-16);
403    let u2 = rng.random::<f64>();
404    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
405}
406
407fn sample_standard(
408    model: &SavedModel,
409    data: ArrayView2<'_, f64>,
410    col_map: &HashMap<String, usize>,
411    training_headers: Option<&Vec<String>>,
412    mut likelihood: LikelihoodSpec,
413    cfg: &NutsConfig,
414) -> Result<NutsResult, String> {
415    // A coefficient that needs a *constraint-aware* posterior sampler must not
416    // take the gaussian-identity closed-form Laplace shortcut: that shortcut
417    // draws an unconstrained `N(mode, φ·H⁻¹)`, which for an active bound puts
418    // ~half its mass on the forbidden side of the boundary. Three geometries
419    // qualify, and all reproduce on a default gaussian model:
420    //   * a `bounded(x, min, max)` interval transform (#1508) — sampled on its
421    //     latent logit scale by `sample_standard_bounded`;
422    //   * a `nonnegative()`/`nonpositive()`/`linear(min,max)`/`constrain()` box
423    //     bound on a parametric coefficient (#1507); and
424    //   * a monotone/convex/concave shape cone on a spline (#1509).
425    // The latter two are sampled from the truncated Gaussian below. Detect all
426    // three cheaply from the saved termspec so the common, fully-unconstrained
427    // gaussian path keeps its fast exact fallback without building the design;
428    // the precise dispatch (and the authoritative `design.linear_constraints`
429    // check) happens after the design is assembled.
430    let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
431        ts.linear_terms.iter().any(|term| {
432            !matches!(
433                term.coefficient_geometry,
434                LinearCoefficientGeometry::Unconstrained
435            ) || term.coefficient_min.is_some()
436                || term.coefficient_max.is_some()
437        }) || ts
438            .smooth_terms
439            .iter()
440            .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
441    });
442    if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
443        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
444    }
445    if model.has_link_wiggle() {
446        // A Gaussian-identity link-wiggle model is sampled from its saved
447        // closed-form joint Laplace posterior (the mean and wiggle coefficients
448        // are jointly Gaussian); only the non-Gaussian wiggle posterior needs
449        // the dedicated link-wiggle NUTS path. Preserved from the original
450        // dispatch, where the Gaussian-identity shortcut ran ahead of the
451        // wiggle branch and so claimed Gaussian wiggle models for the
452        // closed-form path.
453        if likelihood.is_gaussian_identity() {
454            return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
455        }
456        return sample_standard_link_wiggle(
457            model,
458            data,
459            col_map,
460            training_headers,
461            likelihood,
462            cfg,
463        );
464    }
465    let parsed = parse_formula(&model.formula)?;
466    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
467    let y = data.column(y_col).to_owned();
468    let spec = resolve_termspec_for_prediction(
469        &model.resolved_termspec,
470        training_headers,
471        col_map,
472        "resolved_termspec",
473    )?;
474    let design = build_term_collection_design(data, &spec)
475        .map_err(|e| format!("failed to build term collection design: {e}"))?;
476
477    // ---- Constraint-aware posterior dispatch -------------------------------
478    //
479    // A coefficient subject to an *active* constraint sits on the boundary of
480    // its feasible region, so a plain unconstrained draw `N(mode, φ·H⁻¹)`
481    // places ~half its mass on the forbidden side. The constrained geometry
482    // must therefore be reconstructed *here*, ahead of both the
483    // Gaussian-identity closed-form shortcut and the GLM-NUTS fallback —
484    // neither of which is aware of the feasible region (#1507/#1508/#1509).
485    // The fit pins the point estimate correctly; only the posterior was blind.
486
487    // (1) bounded() interval coefficients are not sampled by the GLM-NUTS path.
488    // That path runs the Hamiltonian over the *raw* linear design with the
489    // saved user-scale mode, treating every coefficient as an unconstrained,
490    // Gaussian-penalized parameter. Bounded terms are fit through a custom
491    // family that drives eta via an interval transform `beta = min + (max-min)·
492    // sigmoid(theta)` of an unconstrained latent `theta`. The posterior is
493    // Gaussian on that *latent* scale (which is exactly where the fit treats the
494    // coefficient as a locally-quadratic, unconstrained parameter), so the
495    // correct draws are `theta ~ N(theta_mode, H_latent^{-1})` pushed forward
496    // through the interval map — never a Gaussian on the user scale, which can
497    // place mass outside [min,max] and discards the boundary-induced skew. The
498    // saved fit exports the user-scale mode and user-scale penalized Hessian;
499    // `sample_bounded_latent_posterior_internal` reconstructs the latent
500    // geometry via the exact inverse delta-method (`H_latent = J H_user J`) and
501    // returns user-scale draws that always lie strictly inside the interval.
502    // This must precede the Gaussian-identity shortcut: a Gaussian `bounded()`
503    // model would otherwise take the closed-form path and emit a user-scale
504    // Gaussian that spills outside the interval (#1508).
505    let has_bounded = spec.linear_terms.iter().any(|term| {
506        matches!(
507            term.coefficient_geometry,
508            LinearCoefficientGeometry::Bounded { .. }
509        )
510    });
511    if has_bounded {
512        // Mirror the fit-time layout: linear coefficient `j` lives at column
513        // `intercept_range.end + j` of the model's coefficient vector. Bounds
514        // are on the original (user/data) scale, which is also the scale the
515        // saved beta and penalized Hessian live on.
516        let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
517            .linear_terms
518            .iter()
519            .enumerate()
520            .filter_map(|(j, term)| match term.coefficient_geometry {
521                LinearCoefficientGeometry::Bounded { min, max, .. } => {
522                    Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
523                        col_idx: design.intercept_range.end + j,
524                        min,
525                        max,
526                    })
527                }
528                LinearCoefficientGeometry::Unconstrained => None,
529            })
530            .collect();
531        return sample_standard_bounded(model, cfg, &bounded_columns);
532    }
533
534    // (2) box / shape *inequality* constraints — `nonnegative()` /
535    // `linear(min,max)` / `constrain()` box bounds on a parametric coefficient
536    // (#1507) and the monotone/convex/concave shape cone `γ_j ≥ 0` on a spline
537    // (#1509). Both are reconstructed by `build_term_collection_design` into a
538    // single `A β ≥ b` polytope in the saved coefficient coordinate system, so
539    // one truncated-Gaussian sampler covers them uniformly. Like `bounded()`,
540    // this must precede the Gaussian-identity shortcut so a constrained
541    // Gaussian model is sampled inside its feasible region rather than from the
542    // boundary-centred unconstrained Gaussian.
543    if let Some(constraints) = design
544        .linear_constraints
545        .as_ref()
546        .filter(|c| c.a.nrows() > 0)
547    {
548        return sample_standard_truncated(model, cfg, constraints);
549    }
550
551    // (3) unconstrained Gaussian identity — saved closed-form Laplace posterior.
552    if likelihood.is_gaussian_identity() {
553        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
554    }
555
556    // (4) unconstrained non-Gaussian GLM — exact NUTS over the raw design.
557    let weights = Array1::ones(data.nrows());
558    let dense_design_hmc = design.design.to_dense();
559    let p = dense_design_hmc.ncols();
560    let fit = fit_result_from_saved_model_for_prediction(model)?;
561    // Refresh the NB overdispersion `theta` from the fit's jointly-estimated
562    // `theta_hat` before sampling. The construction seed stored on the family
563    // spec (`theta: 1.0`) only seeds the inner solve; the NUTS NB log-likelihood
564    // / score (`src/inference/hmc.rs`) reads `theta` straight off this spec, so
565    // leaving the seed in place over-states `Var(y) = μ + μ²/θ` and inflates
566    // every coefficient's posterior SD (#1463 — the HMC sibling of the
567    // replicate-path bug #1124). `theta_fixed` NB carries the user's exact value
568    // in both the spec and the scale metadata, so this refresh is a no-op there.
569    // Mirrors how the replicate path reads `theta_hat` via the canonical
570    // `family_noise_parameter` helper (`negbin_theta().or(seed)`).
571    refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
572    if fit.beta.len() != p {
573        return Err(format!(
574            "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
575            fit.beta.len(),
576            p,
577        ));
578    }
579    if fit.lambdas.len() != design.penalties.len() {
580        return Err(format!(
581            "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
582            fit.lambdas.len(),
583            design.penalties.len(),
584        ));
585    }
586    let penalty =
587        weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
588
589    // Re-apply the offset the model was fit with so the posterior targets the
590    // same η = Xβ + offset as the fit and predict paths. The diagnostic loader
591    // keeps the saved offset column in the frame; dropping the offset silently
592    // sampled the wrong target for any `--offset-column` GLM (#882).
593    let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
594        Some(name) => {
595            let idx = resolve_role_col(col_map, name, "offset")?;
596            Some(data.column(idx).to_owned())
597        }
598        None => None,
599    };
600
601    run_nuts_sampling_flattened_family(
602        likelihood,
603        FamilyNutsInputs::Glm(GlmFlatInputs {
604            x: dense_design_hmc.view(),
605            y: y.view(),
606            weights: weights.view(),
607            penalty_matrix: penalty.view(),
608            mode: fit.beta.view(),
609            hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
610            gamma_shape: fit.likelihood_scale.gamma_shape(),
611            // Forward the saved training dispersion so NUTS whitening uses the
612            // posterior scale selected at fit time; fixed-scale families remain
613            // a no-op.
614            dispersion: fit.dispersion().unwrap_or_default(),
615            firth_bias_reduction: false,
616            offset: offset_vec.as_ref().map(|o| o.view()),
617        }),
618        cfg,
619    )
620    .map_err(|e| format!("NUTS sampling failed: {e}"))
621}
622
623/// Exact posterior draws for a standard GLM with `bounded()` coefficients.
624///
625/// The bounded coefficients are sampled on their natural latent (logit) scale —
626/// where the Laplace approximation is Gaussian — and every draw is pushed
627/// through the exact interval map so user-scale draws always lie strictly inside
628/// `[min, max]` and carry the boundary-induced skew. Non-bounded coefficients
629/// are drawn as the ordinary Gaussian Laplace component of the same joint
630/// posterior, so cross-coefficient correlations with the bounded columns are
631/// preserved (the latent precision is the full `H_latent = J H_user J`).
632fn sample_standard_bounded(
633    model: &SavedModel,
634    cfg: &NutsConfig,
635    bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
636) -> Result<NutsResult, String> {
637    validate_nuts_config(cfg).map_err(String::from)?;
638    let fit = fit_result_from_saved_model_for_prediction(model)?;
639    let mode = fit.beta.clone();
640    let p = mode.len();
641    if p == 0 {
642        return Err(
643            "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
644                .to_string(),
645        );
646    }
647    // The bounded fit exports the UNSCALED user-scale penalized Hessian; the
648    // latent sampler reconstructs the latent precision from it via the exact
649    // inverse delta-method. (`explicit_fit_hessian_for_whitening` returns this
650    // same user-scale penalized Hessian for a saved standard fit.)
651    let user_hessian =
652        explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
653    // The exported Hessian carries unit implicit dispersion, so the latent
654    // posterior covariance is `cov_scale·H_latent⁻¹` with `cov_scale` the
655    // coefficient-covariance scale the fit used for `Vb` (`σ̂²` for a profiled
656    // Gaussian, `1` for fixed-scale Binomial). Re-applying `√cov_scale` here
657    // keeps the draw spread identical to the reported `summary().std_error`
658    // (gam#1514); the truncated-constraint path does the analogous √φ lift.
659    let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
660    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
661    let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
662        &mode,
663        user_hessian,
664        bounded_columns,
665        n_total,
666        sqrt_cov_scale,
667        chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
668    )
669    .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
670
671    let posterior_mean = samples
672        .mean_axis(ndarray::Axis(0))
673        .unwrap_or_else(|| Array1::<f64>::zeros(p));
674    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
675
676    Ok(NutsResult {
677        samples,
678        posterior_mean,
679        posterior_std,
680        rhat: 1.0,
681        ess: n_total as f64,
682        converged: true,
683    })
684}
685
686/// Exact posterior draws for a standard GLM whose coefficients carry linear
687/// *inequality* constraints `A β ≥ b` — `nonnegative()` / `linear(min,max)` /
688/// `constrain()` box bounds on a parametric term (#1507) and the
689/// monotone/convex/concave shape cone `γ_j ≥ 0` on a spline (#1509).
690///
691/// The posterior is the Laplace Gaussian `N(mode, φ·H⁻¹)` *truncated* to the
692/// feasible polytope. For a Gaussian-identity model this is the exact
693/// posterior; for a non-Gaussian GLM it is the constraint-respecting Laplace
694/// approximation — the same modelling choice the `bounded()` term makes. The
695/// draws are produced by exact reflective Hamiltonian Monte Carlo
696/// ([`crate::truncated_gaussian`]), so every draw is feasible and
697/// successive draws are essentially independent (`rhat ≈ 1`, matching the other
698/// Laplace posterior paths).
699fn sample_standard_truncated(
700    model: &SavedModel,
701    cfg: &NutsConfig,
702    constraints: &gam_solve::pirls::LinearInequalityConstraints,
703) -> Result<NutsResult, String> {
704    validate_nuts_config(cfg).map_err(String::from)?;
705    let fit = fit_result_from_saved_model_for_prediction(model)?;
706    let mode = fit.beta.clone();
707    let p = mode.len();
708    if p == 0 {
709        return Err(
710            "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
711             vector"
712                .to_string(),
713        );
714    }
715    // The saved standard fit exports the unscaled user-scale penalised Hessian
716    // `H`; the truncated sampler whitens with its Cholesky and re-applies √φ so
717    // the posterior covariance is `φ·H⁻¹`, identical to the unconstrained
718    // Gaussian/bounded paths. Fixed-scale families (Binomial / Poisson) have
719    // φ = 1.
720    let penalized_hessian =
721        explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
722    let sqrt_phi = {
723        use gam_problem::dispersion_cov::DispersionExt as _;
724        fit.dispersion().unwrap_or_default().sqrt_phi()
725    };
726    let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
727        &mode,
728        &penalized_hessian,
729        sqrt_phi,
730        constraints,
731        cfg.n_samples,
732        cfg.n_chains,
733        chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
734    )?;
735    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
736
737    let posterior_mean = samples
738        .mean_axis(ndarray::Axis(0))
739        .unwrap_or_else(|| Array1::<f64>::zeros(p));
740    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
741
742    Ok(NutsResult {
743        samples,
744        posterior_mean,
745        posterior_std,
746        rhat: 1.0,
747        ess: n_total as f64,
748        converged: true,
749    })
750}
751
752fn sample_standard_link_wiggle(
753    model: &SavedModel,
754    data: ArrayView2<'_, f64>,
755    col_map: &HashMap<String, usize>,
756    training_headers: Option<&Vec<String>>,
757    likelihood: LikelihoodSpec,
758    cfg: &NutsConfig,
759) -> Result<NutsResult, String> {
760    let parsed = parse_formula(&model.formula)?;
761    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
762    let y = data.column(y_col).to_owned();
763
764    let spec = resolve_termspec_for_prediction(
765        &model.resolved_termspec,
766        training_headers,
767        col_map,
768        "resolved_termspec",
769    )?;
770    let design = build_term_collection_design(data, &spec)
771        .map_err(|e| format!("failed to build term collection design: {e}"))?;
772    let p_main = design.design.ncols();
773
774    let fit = fit_result_from_saved_model_for_prediction(model)?;
775    let wiggle_runtime = model
776        .saved_prediction_runtime()?
777        .link_wiggle
778        .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
779    let mode_beta = fit
780        .block_by_role(BlockRole::Mean)
781        .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
782        .beta
783        .clone();
784    let mode_theta = fit
785        .block_by_role(BlockRole::LinkWiggle)
786        .ok_or_else(|| {
787            "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
788        })?
789        .beta
790        .clone();
791    let p_wiggle = mode_theta.len();
792    let p_total = mode_beta.len() + p_wiggle;
793
794    if mode_beta.len() != p_main {
795        return Err(format!(
796            "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
797            mode_beta.len(),
798            p_main,
799        ));
800    }
801    if fit.beta.len() != p_total {
802        return Err(format!(
803            "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
804            fit.beta.len(),
805            p_main,
806            p_wiggle,
807            p_total,
808        ));
809    }
810
811    let hessian = &fit
812        .geometry
813        .as_ref()
814        .ok_or_else(|| {
815            "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
816                .to_string()
817        })?
818        .penalized_hessian;
819    validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;
820
821    let n_base_penalties = design.penalties.len();
822    let base_lambdas = fit
823        .block_by_role(BlockRole::Mean)
824        .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
825        .lambdas
826        .view();
827    if base_lambdas.len() != n_base_penalties {
828        return Err(format!(
829            "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
830            base_lambdas.len(),
831            n_base_penalties,
832        ));
833    }
834
835    let penalty_base =
836        weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);
837
838    let wiggle_lambdas_owned = fit
839        .lambdas_linkwiggle()
840        .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
841    let wiggle_lambdas = wiggle_lambdas_owned.view();
842    let degree = wiggle_runtime.degree;
843    let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());
844
845    let mut wiggle_penalties = Vec::new();
846    let default_orders = [2usize];
847    let n_wiggle_lambdas = wiggle_lambdas.len();
848    for k in 0..n_wiggle_lambdas {
849        let order = if k < default_orders.len() {
850            default_orders[k]
851        } else {
852            k + 1
853        };
854        if order >= p_wiggle {
855            continue;
856        }
857        let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
858            .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
859        wiggle_penalties.push(penalty);
860    }
861    while wiggle_penalties.len() < n_wiggle_lambdas {
862        wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
863    }
864
865    let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;
866
867    let q0 = design.design.dot(&mode_beta);
868    let (q0_min, q0_max) = q0
869        .iter()
870        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
871            (lo.min(v), hi.max(v))
872        });
873
874    let spline = LinkWiggleSplineArtifacts {
875        knot_range: (q0_min, q0_max),
876        knot_vector: knot_arr,
877        degree,
878    };
879
880    let nuts_family = match (&likelihood.response, &likelihood.link) {
881        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
882            NutsFamily::BinomialLogit
883        }
884        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
885            NutsFamily::BinomialProbit
886        }
887        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
888            NutsFamily::BinomialCLogLog
889        }
890        (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
891        (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
892        (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
893        (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
894        (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
895        _ => {
896            return Err(format!(
897                "NUTS sampling with link wiggle is not supported for family {}",
898                likelihood.pretty_name()
899            ));
900        }
901    };
902
903    let weights = Array1::ones(data.nrows());
904    let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);
905
906    let wiggle_nuts_dense = design.design.as_dense_cow();
907    run_link_wiggle_nuts_sampling(
908        wiggle_nuts_dense.view(),
909        y.view(),
910        weights.view(),
911        penalty_base.view(),
912        penalty_link.view(),
913        mode_beta.view(),
914        mode_theta.view(),
915        hessian.view(),
916        spline,
917        nuts_family,
918        scale,
919        cfg,
920    )
921    .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
922}
923
/// Runs survival NUTS posterior sampling for a saved survival model.
///
/// The flattened survival problem is rebuilt from `data` and the saved model
/// metadata: entry/exit/derivative designs laid out as
/// `[time basis | baseline time-wiggle | covariates]`, the time offsets, and
/// the penalty blocks. A working model is evaluated at the saved coefficients
/// to obtain the Hessian handed to the sampler, which starts from those same
/// coefficients.
///
/// Saved likelihood modes `Latent`, `LatentBinary` and `LocationScale` do not
/// take this path; they return the result of `laplace_gaussian_fallback`.
///
/// # Errors
///
/// Returns `Err` when required saved metadata is missing (event column,
/// time anchor for `MarginalSlope`, time-wiggle knots/spec, ridge lambda),
/// when the saved survival spec is `crude` or unrecognised, or when any of the
/// design / basis / offset / model-construction / sampling steps fails.
fn sample_survival(
    model: &SavedModel,
    data: ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
    training_headers: Option<&Vec<String>>,
    cfg: &NutsConfig,
) -> Result<NutsResult, String> {
    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
    // These three modes are not sampled by the flattened engine below; they
    // are delegated to the Laplace Gaussian fallback.
    if matches!(
        saved_likelihood_mode,
        SurvivalLikelihoodMode::Latent
            | SurvivalLikelihoodMode::LatentBinary
            | SurvivalLikelihoodMode::LocationScale
    ) {
        return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
    }
    // `survival_entry == None` is the right-censored shorthand
    // `Surv(time, event)`: training synthesized a zero entry column,
    // and posterior sampling must do the same so artifacts fit with
    // the shorthand are first-class through `gam sample` /
    // `model.sample` just like `gam predict` already handles them in
    // `run_predict_survival`. The resolution flows through the shared
    // `resolve_saved_survival_time_columns` helper so every consumer
    // of saved survival metadata applies the same fallback contract.
    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
    let exit_col = time_cols.exit_col;
    let eventname = model
        .survival_event
        .as_ref()
        .ok_or_else(|| "survival model missing event column metadata".to_string())?;
    let event_col = resolve_role_col(col_map, eventname, "event")?;
    let termspec = resolve_termspec_for_prediction(
        &model.resolved_termspec,
        training_headers,
        col_map,
        "resolved_termspec",
    )?;
    // Build the covariate design from the clipped copy of `data` when
    // `axis_clip_to_training_ranges` returns one, otherwise from `data` itself.
    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
    let cov_design = build_term_collection_design(cov_input, &termspec)
        .map_err(|e| format!("failed to build survival design: {e}"))?;
    let n = data.nrows();
    let p_cov = cov_design.design.ncols();
    let mut age_entry = Array1::<f64>::zeros(n);
    let mut age_exit = Array1::<f64>::zeros(n);
    let mut event_target = Array1::<u8>::zeros(n);
    // No competing events and unit weights for every row.
    let event_competing = Array1::<u8>::zeros(n);
    let weights = Array1::<f64>::ones(n);
    // Per row: normalize the (entry, exit) time pair and threshold the event
    // column at 0.5 to get a 0/1 indicator. Times and events are read from the
    // unclipped `data`.
    for i in 0..n {
        let (t0, t1) = normalize_survival_time_pair(
            time_cols.row_entry_time(data, i),
            data[[i, exit_col]],
            i,
        )?;
        age_entry[i] = t0;
        age_exit[i] = t1;
        event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
    }
    // Rebuild the time basis from the saved configuration, then derive the
    // resolved configuration from what the build actually produced.
    let time_cfg = load_survival_time_basis_config_from_model(model)?;
    let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
        &time_build.basisname,
        time_build.degree,
        time_build.knots.as_ref(),
        time_build.keep_cols.as_ref(),
        time_build.smooth_lambda,
    )?;
    // MarginalSlope only: center the entry/exit time designs at the basis row
    // evaluated at the saved time anchor.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
        center_survival_time_designs_at_anchor(
            &mut time_build.x_entry_time,
            &mut time_build.x_exit_time,
            &time_anchor_row,
        )?;
    }
    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
    let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
        build_survival_time_offsets_for_likelihood(
            &age_entry,
            &age_exit,
            &baseline_cfg,
            saved_likelihood_mode,
            None,
        )?;
    // MarginalSlope only: add the derivative guard offset to all three offset
    // vectors in place.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        add_survival_time_derivative_guard_offset(
            &age_entry,
            &age_exit,
            time_anchor,
            survival_derivative_guard_for_likelihood(saved_likelihood_mode),
            &mut eta_offset_entry,
            &mut eta_offset_exit,
            &mut derivative_offset_exit,
        )?;
    }
    // Optional baseline time-wiggle designs, derived from the (possibly
    // guard-adjusted) offsets; `None` when the helper returns no components.
    let saved_timewiggle = saved_baseline_timewiggle_components(
        &eta_offset_entry,
        &eta_offset_exit,
        &derivative_offset_exit,
        model,
    )?;
    let p_time = time_build.x_exit_time.ncols();
    let p_timewiggle = saved_timewiggle
        .as_ref()
        .map(|(_, exit, _)| exit.ncols())
        .unwrap_or(0);
    // Flattened column layout: [time basis | time-wiggle | covariates].
    let p = p_time + p_timewiggle + p_cov;
    let tb_entry_dense = time_build.x_entry_time.to_dense();
    let tb_exit_dense = time_build.x_exit_time.to_dense();
    let tb_deriv_dense = time_build.x_derivative_time.to_dense();
    let mut x_entry = Array2::<f64>::zeros((n, p));
    let mut x_exit = Array2::<f64>::zeros((n, p));
    let mut x_derivative = Array2::<f64>::zeros((n, p));
    if p_time > 0 {
        x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
        x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
        x_derivative
            .slice_mut(s![.., ..p_time])
            .assign(&tb_deriv_dense);
    }
    if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
        && p_timewiggle > 0
    {
        x_entry
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(entry_w);
        x_exit
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(exit_w);
        x_derivative
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(deriv_w);
    }
    // Covariate columns are copied into the entry and exit designs only; the
    // matching columns of `x_derivative` stay zero.
    if p_cov > 0 {
        let cov_dense = cov_design.design.to_dense();
        let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
        x_entry
            .slice_mut(s![.., cov_range.clone()])
            .assign(&cov_dense);
        x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
    }
    // Time-basis penalties: only matrices whose shape matches the time block
    // are kept. Their lambda starts at the build's smooth lambda (or the
    // reconstruction default) and may be overwritten from the saved fit below.
    let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
    for (idx, s) in time_build.penalties.iter().enumerate() {
        if s.nrows() == p_time && s.ncols() == p_time {
            penalty_blocks.push(PenaltyBlock {
                matrix: s.clone(),
                lambda: time_build
                    .smooth_lambda
                    .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                range: 0..p_time,
                nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
            });
        }
    }
    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
    // Baseline time-wiggle penalties, reconstructed from the saved knots and
    // penalty orders and appended after the time-basis blocks.
    if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
        let start = p_time;
        let end = start + exit_w.ncols();
        // Index of the first wiggle block within `penalty_blocks`.
        let wiggle_lambda_offset = penalty_blocks.len();
        let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
            "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
        })?;
        let wiggle_degree = wiggle_cfg.degree;
        let wiggle_knots =
            Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
                "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
            })?);
        // Seed vector for the wiggle block: entry offsets followed by exit
        // offsets (length 2n).
        let mut seed = Array1::<f64>::zeros(2 * n);
        for i in 0..n {
            seed[i] = eta_offset_entry[i];
            seed[n + i] = eta_offset_exit[i];
        }
        let (primary_order, extra_orders) =
            split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
        let mut block = buildwiggle_block_input_from_knots(
            seed.view(),
            &wiggle_knots,
            wiggle_degree,
            primary_order,
            wiggle_cfg.double_penalty,
        )?;
        append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
            .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
        for (widx, s) in block.penalties.iter().enumerate() {
            // Every penalty spec variant carries a matrix; take it whichever
            // variant it is.
            let s = match s {
                gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
                gam_solve::estimate::PenaltySpec::Dense(m)
                | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
            };
            // Keep only penalties whose shape matches the wiggle block.
            if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
                penalty_blocks.push(PenaltyBlock {
                    matrix: s.clone(),
                    lambda: time_build
                        .smooth_lambda
                        .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                    range: start..end,
                    nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
                });
            }
        }
        // Replace each wiggle block's lambda with the saved lambda at the same
        // position in `penalty_blocks`, when the saved fit has one.
        for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
            .iter_mut()
            .enumerate()
        {
            if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
                block_penalty.lambda = lam;
            }
        }
    }
    let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
        "saved survival model is missing survivalridge_lambda; refusing to \
         pick a load-time default (the historical 1e-4 fallback silently \
         disagreed with the 1e-6 fit-time default). Refit."
            .to_string()
    })?;
    // The ridge skips the first column when the time basis is "linear" and the
    // model has no baseline time-wiggle; otherwise it covers every column.
    let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
    {
        1
    } else {
        0
    };
    // Identity ridge penalty over `ridge_range_start..p`, appended last.
    if ridge_lambda > 0.0 && p > ridge_range_start {
        let dim = p - ridge_range_start;
        let mut ridge = Array2::<f64>::zeros((dim, dim));
        for d in 0..dim {
            ridge[[d, d]] = 1.0;
        }
        penalty_blocks.push(PenaltyBlock {
            matrix: ridge,
            lambda: ridge_lambda,
            range: ridge_range_start..p,
            nullspace_dim: 0,
        });
    }
    // Positional override: block `idx` takes `fit_saved.lambdas[idx]` when that
    // entry exists; blocks past the end of the saved lambdas keep the value
    // assigned above.
    // NOTE(review): this also overwrites the ridge block's lambda if the saved
    // lambdas extend to its index — confirm that is intended.
    for (idx, block) in penalty_blocks.iter_mut().enumerate() {
        if let Some(&lam) = fit_saved.lambdas.get(idx) {
            block.lambda = lam;
        }
    }
    let penalties = PenaltyBlocks::new(penalty_blocks);
    // A missing saved spec is treated as "net"; matching is case-insensitive.
    let survivalspec = match model
        .survivalspec
        .as_deref()
        .unwrap_or("net")
        .to_ascii_lowercase()
        .as_str()
    {
        "net" => SurvivalSpec::Net,
        "crude" => {
            return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
                        .to_string());
        }
        other => {
            return Err(format!("unsupported saved survival spec '{other}'"));
        }
    };
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    // Working model used only to evaluate the state (and hence the Hessian) at
    // the saved coefficients.
    let mut model_surv = royston_parmar::working_model_from_flattened(
        penalties.clone(),
        monotonicity,
        survivalspec,
        RoystonParmarInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            monotonicity_constraint_rows: None,
            monotonicity_constraint_offsets: None,
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
    )
    .map_err(|e| format!("failed to construct survival model: {e}"))?;
    // Structural monotonicity is enabled for every mode except Weibull, over
    // the time + time-wiggle columns.
    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
        model_surv
            .set_structural_monotonicity(true, p_time + p_timewiggle)
            .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
    }
    let beta0 = fit_saved.beta.clone();
    let state = model_surv
        .update_state(&beta0)
        .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
    let hessian = state.hessian.to_dense();
    // The sampler receives the same data, penalties and monotonicity settings
    // as the working model, plus the saved coefficients and the Hessian just
    // evaluated there.
    run_survival_nuts_sampling_flattened(
        SurvivalFlatInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
        penalties,
        monotonicity,
        survivalspec,
        saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
        p_time + p_timewiggle,
        beta0.view(),
        hessian.view(),
        cfg,
    )
    .map_err(|e| format!("survival NUTS sampling failed: {e}"))
}
1243
#[cfg(test)]
mod tests {
    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    /// #1463: NB NUTS has to sample at the fit's jointly-estimated `theta_hat`
    /// rather than the construction seed `theta = 1.0`. The seed is only a
    /// starting value for the inner solve, while the NUTS NB
    /// log-likelihood/score reads `theta` directly from the sampling
    /// `LikelihoodSpec`; without a refresh from the scale metadata the
    /// posterior is drawn at the wrong overdispersion and every coefficient's
    /// posterior SD inflates ~1.4–1.5×.
    ///
    /// Before the fix `sample_standard` passed the seed through unchanged, so
    /// this test would have observed `theta == 1.0` and failed. The refresh
    /// rewrites the spec to `theta_hat`.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        // The fit estimated theta_hat = 2.97 and stored it in the scale
        // metadata, while the spec still holds the construction seed 1.0.
        let fitted_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 };
        let mut spec = LikelihoodSpec::negative_binomial_log(1.0);

        refresh_negbin_theta_for_sampling(&mut spec, fitted_scale);

        let refreshed_theta = match spec.response {
            ResponseFamily::NegativeBinomial { theta, .. } => theta,
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            refreshed_theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    /// For a fixed-theta NB fit the user's exact `theta` is recorded in both
    /// the spec and the scale metadata, so the refresh changes nothing and the
    /// fixed value (never an estimated fit's inner-solve seed) is what remains.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let fixed_scale = LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 };
        let mut spec = LikelihoodSpec::negative_binomial_log_fixed(4.25);

        refresh_negbin_theta_for_sampling(&mut spec, fixed_scale);

        let (refreshed_theta, still_fixed) = match spec.response {
            ResponseFamily::NegativeBinomial { theta, theta_fixed } => (theta, theta_fixed),
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            refreshed_theta, 4.25,
            "fixed NB theta must survive the refresh"
        );
        assert!(still_fixed, "theta_fixed flag must be preserved");
    }

    /// If the fit recorded no NB theta (non-NB scale metadata) the spec's seed
    /// has to be left alone — the same rule as the canonical replicate
    /// picker's `negbin_theta().or(seed)`.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let mut spec = LikelihoodSpec::negative_binomial_log(3.5);
        // ProfiledGaussian has no negbin_theta, so the accessor yields None.
        let non_nb_scale = LikelihoodScaleMetadata::ProfiledGaussian;

        refresh_negbin_theta_for_sampling(&mut spec, non_nb_scale);

        let kept_theta = match spec.response {
            ResponseFamily::NegativeBinomial { theta, .. } => theta,
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            kept_theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    /// The NB refresh must not touch any non-NB family, even if the scale
    /// metadata carries an NB theta: the match is guarded on the response
    /// family, so Poisson/Gamma/etc. pass through unchanged.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut spec = LikelihoodSpec::poisson_log();
        let response_before = spec.response.clone();
        let nb_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 };

        refresh_negbin_theta_for_sampling(&mut spec, nb_scale);

        assert_eq!(
            spec.response, response_before,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}