Skip to main content

gam_inference/
sample.rs

1//! Library-side orchestration for NUTS posterior sampling from a saved model.
2//!
3//! The CLI's `gam sample` subcommand and the Python `Model.sample(...)` API
4//! both call into [`sample_saved_model`], which dispatches on the saved
//! model's class (standard GLM, standard with link-wiggle, survival, or a
//! class served by the Laplace-Gaussian fallback) and returns a
//! [`NutsResult`] over the original coefficient space (check its `converged`
//! flag rather than assuming convergence). Gaussian identity standard models are sampled from the saved
8//! closed-form posterior, conditioning on the training fit rather than any
9//! prediction rows supplied by the caller.
10
11use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18    FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19    explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20    run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27    SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28    build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29    evaluate_survival_time_basis_row, normalize_survival_time_pair,
30    resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33    fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34    resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35    saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39    PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42    append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43    split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47    FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55/// Reconstruct the `LinkWiggleFormulaSpec` from a saved model's
56/// baseline-time-wiggle runtime, returning `None` when the model has no
57/// time-wiggle component. Re-exported because the survival fitter's tests
58/// exercise the spec independently of running NUTS.
59pub fn saved_baseline_timewiggle_spec(
60    model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62    model
63        .saved_baseline_time_wiggle()
64        .map_err(|e| e.to_string())
65        .map(|runtime| {
66            runtime.map(|saved| LinkWiggleFormulaSpec {
67                degree: saved.degree,
68                num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69                penalty_orders: saved.penalty_orders,
70                double_penalty: saved.double_penalty,
71            })
72        })
73}
74
75fn weighted_penalty_matrix(
76    penalties: &[Array2<f64>],
77    lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79    if penalties.len() != lambdas.len() {
80        return Err(format!(
81            "penalty/lambda mismatch: {} penalties vs {} lambdas",
82            penalties.len(),
83            lambdas.len()
84        ));
85    }
86    if penalties.is_empty() {
87        return Err("cannot sample without at least one penalty block".to_string());
88    }
89    let p = penalties[0].nrows();
90    let mut out = Array2::<f64>::zeros((p, p));
91    for (k, s) in penalties.iter().enumerate() {
92        if s.nrows() != p || s.ncols() != p {
93            return Err(format!(
94                "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95                s.nrows(),
96                s.ncols(),
97                p,
98                p
99            ));
100        }
101        let lam = lambdas[k];
102        out += &(s * lam);
103    }
104    Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108    hessian: &Array2<f64>,
109    expected_dim: usize,
110) -> Result<(), String> {
111    if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112        return Err(format!(
113            "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114            hessian.nrows(),
115            hessian.ncols(),
116            expected_dim,
117            expected_dim,
118        ));
119    }
120    validate_all_finite(
121        "link-wiggle explicit joint Hessian",
122        hessian.iter().copied(),
123    )?;
124    let mut max_abs = 0.0_f64;
125    for r in 0..expected_dim {
126        for c in 0..expected_dim {
127            max_abs = max_abs.max(hessian[[r, c]].abs());
128            let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129            if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130                return Err(format!(
131                    "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132                ));
133            }
134        }
135    }
136    if max_abs == 0.0 {
137        return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138                    .to_string());
139    }
140    Ok(())
141}
142
143/// Resolve the scalar generative dispersion for a fitted model.
144///
145/// Thin adapter over the single canonical
146/// [`crate::generative::family_noise_parameter`]: the replicate-sampling path
147/// here and the CLI `gam generate` path both route through that one helper, so
148/// the fitted dispersion (NB θ̂, Beta/Tweedie φ̂, Gamma k̂) can never be read
149/// inconsistently between them. A divergent second copy of this logic was the
150/// root cause of #1124.
151fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
152    crate::generative::family_noise_parameter(
153        fit.likelihood_scale,
154        fit.standard_deviation,
155        likelihood,
156    )
157}
158
159/// Refresh the Negative-Binomial overdispersion `theta` on the sampling
160/// likelihood spec from the fit's jointly-estimated `theta_hat` before the NUTS
161/// dispatch reads it (#1463).
162///
163/// The construction seed stored on the family spec (`theta: 1.0`) only seeds the
164/// inner solve. NB carries unit REML scale and records its fitted overdispersion
165/// in `likelihood_scale` (`EstimatedNegBinTheta` / `FixedNegBinTheta`), *not* in
166/// the REML dispersion. The NUTS NB log-likelihood / score
167/// (`src/inference/hmc.rs`) reads `theta` straight off this spec, so leaving the
168/// seed in place over-states `Var(y) = μ + μ²/θ` and inflates every
169/// coefficient's posterior SD ~1.4–1.5× (the HMC sibling of the replicate-path
170/// bug #1124). This mirrors the canonical replicate picker
171/// [`crate::generative::family_noise_parameter`]'s `negbin_theta().or(seed)`:
172/// when the scale records a fitted `theta_hat`, use it; otherwise keep the
173/// existing seed. `theta_fixed` NB carries the user's exact value in both the
174/// spec and the scale metadata, so this refresh is a no-op there. Non-NB
175/// families are left untouched.
176fn refresh_negbin_theta_for_sampling(
177    likelihood: &mut LikelihoodSpec,
178    scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180    if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181        if let Some(theta_hat) = scale.negbin_theta() {
182            *theta = theta_hat;
183        }
184    }
185}
186
187/// Build a `LikelihoodSpec` for a saved model. Saved models already carry the
188/// response distribution and parameterized link state together, so sampling can
189/// dispatch directly on the cloned spec.
190fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
191    Ok(model.likelihood())
192}
193
/// Fallback smoothing strength `λ` for a reconstructed penalty block when the
/// saved model provides no fitted `smooth_lambda`.
///
/// The value is deliberately mild: it regularizes the reconstructed design
/// without meant-to-be-material reshaping of the saved fit. Fitted lambdas
/// are intended to take precedence whenever they are present.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 0.01;
199
200#[inline]
201const fn splitmix64(x: u64) -> u64 {
202    gam_linalg::utils::splitmix64_hash(x)
203}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207    splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210/// Run NUTS posterior sampling over a saved model.
211///
212/// Dispatches on `model.predict_model_class()`:
213///
214/// * `Standard`: Gaussian identity models use the exact saved
215///   `N(mode, φ·H⁻¹)` posterior, where `mode`, `φ`, and `H` all come from the
216///   training fit. Other standard GLMs run NUTS from the saved mode,
217///   smoothing parameters, dispersion, and whitening curvature rather than
218///   refitting/reselecting them on the caller-supplied rows. Link-wiggle
219///   models take a specialised joint-space path that preserves the basis
220///   chain rule.
221/// * `Survival`: rebuilds the survival design (Royston-Parmar baseline +
222///   wiggle + covariate blocks) on the supplied data, evaluates the mode,
223///   and runs the survival-flat NUTS path. Latent and location-scale modes
224///   are explicitly rejected here.
225/// * Other model classes (location-scale GLM, bernoulli marginal-slope,
226///   transformation-normal) return a "not implemented" error matching the
227///   CLI surface.
228pub fn sample_saved_model(
229    model: &SavedModel,
230    data: ArrayView2<'_, f64>,
231    col_map: &HashMap<String, usize>,
232    training_headers: Option<&Vec<String>>,
233    cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235    // Issue #399: degenerate draw/chain counts (`samples=0` / `chains=0`, and
236    // the `samples < 4` counts the split-R-hat engine path cannot handle) must
237    // surface as one typed `InvalidConfig` error before any sampler runs —
238    // identically across *every* model class. Validating here, at the single
239    // public dispatch point, guarantees that the NUTS path, the auto-selected
240    // Pólya-Gamma Gibbs path, and the Laplace-Gaussian fallback all reject the
241    // same inputs the same way (previously the fallback silently accepted them
242    // via `.max(1)` while NUTS errored — a divergent contract on one API).
243    validate_nuts_config(cfg).map_err(String::from)?;
244    let likelihood = likelihood_spec_for_saved_model(model)?;
245    match model.predict_model_class() {
246        PredictModelClass::Survival => {
247            // Latent / latent-binary / location-scale survival likelihoods
248            // have no exact NUTS implementation in the engine yet; fall
249            // through to the Laplace-Gaussian fallback so callers still
250            // get a posterior they can predict with. Royston-Parmar /
251            // Weibull / marginal-slope survival use the exact path.
252            let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253            if matches!(
254                saved_likelihood_mode,
255                SurvivalLikelihoodMode::Latent
256                    | SurvivalLikelihoodMode::LatentBinary
257                    | SurvivalLikelihoodMode::LocationScale
258            ) {
259                laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260            } else {
261                sample_survival(model, data, col_map, training_headers, cfg)
262            }
263        }
264        PredictModelClass::Standard => {
265            // Most `Standard` GLM families (Gaussian, Poisson, Gamma, Tweedie,
266            // Negative-Binomial, binomial logit/probit/cloglog) have an exact
267            // NUTS implementation and run through `sample_standard`. Beta
268            // regression is the one `Standard` family the engine cannot sample
269            // with NUTS (`hmc_io.rs` returns a hard error for it). Rather than
270            // aborting the whole `sample` command, route it to the same
271            // Laplace-Gaussian fallback every other NUTS-unsupported model
272            // class already uses, so callers still get a usable posterior.
273            if matches!(likelihood.response, ResponseFamily::Beta { .. }) {
274                laplace_gaussian_fallback(model, cfg, "beta-regression posterior fallback")
275            } else {
276                sample_standard(model, data, col_map, training_headers, likelihood, cfg)
277            }
278        }
279        // For classes where the Rust core doesn't yet have an exact NUTS
280        // implementation we fall back to drawing from the Laplace
281        // (Gaussian) approximation of the posterior around the fitted
282        // joint mode, using the saved penalised Hessian. This is the
283        // standard "Bayesian credible interval" surface used by mgcv
284        // and similar packages: it drops higher-order posterior shape
285        // but lets every downstream consumer (credible intervals,
286        // posterior predictive, etc.) keep working uniformly across
287        // model classes.
288        PredictModelClass::GaussianLocationScale => {
289            laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
290        }
291        PredictModelClass::BinomialLocationScale => {
292            laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
293        }
294        PredictModelClass::DispersionLocationScale => {
295            laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
296        }
297        PredictModelClass::BernoulliMarginalSlope => {
298            laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
299        }
300        PredictModelClass::TransformationNormal => {
301            laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
302        }
303    }
304}
305
306/// Draw iid samples from `N(mode, H^{-1})` using the saved penalised
307/// Hessian `H = L L^T`.
308///
309/// We solve `L^T δ = ε` for each iid `ε ~ N(0, I)` and report
310/// `β = mode + δ`. The resulting draws are unbiased samples of the
311/// Laplace-Gaussian approximation: their finite-sample mean / std
312/// converge to `(mode, diag(H^{-1})^{1/2})` and the implied credible
313/// bands match the surface that closed-form posterior tooling in
314/// `mgcv` and `gam` itself uses for prediction intervals.
315///
316/// `rationale` is a short label appearing in error messages so callers
317/// can tell which class fell back to this path. We mark `rhat = 1.0`
318/// and `ess = n_total` because the draws are iid by construction.
319pub fn laplace_gaussian_fallback(
320    model: &SavedModel,
321    cfg: &NutsConfig,
322    rationale: &'static str,
323) -> Result<NutsResult, String> {
324    // Defense in depth: this is `pub`, so guard the same degenerate
325    // draw/chain counts the NUTS / PG paths reject (issue #399) rather than
326    // papering over `n_chains == 0` / `n_samples == 0` with `.max(1)`, which
327    // would silently fabricate draws the caller never asked for.
328    validate_nuts_config(cfg).map_err(String::from)?;
329    let fit = fit_result_from_saved_model_for_prediction(model)?;
330    let mode = fit.beta.clone();
331    let p = mode.len();
332    if p == 0 {
333        return Err(format!(
334            "{rationale}: cannot sample from an empty coefficient vector"
335        ));
336    }
337    let h = fit.penalized_hessian().ok_or_else(|| {
338        format!(
339            "{rationale}: posterior fallback requires the explicit penalised Hessian; \
340             refit with exact geometry export to enable posterior sampling for this class."
341        )
342    })?;
343    // `penalized_hessian` is stored unscaled. To draw Laplace
344    // approximations of `N(mode, cov_scale·H⁻¹)` we solve `Lᵀ δ = ε` (so
345    // `Var(δ) = H⁻¹`) and then rescale by `√cov_scale`, where `cov_scale`
346    // is the *coefficient-covariance* scale the fit uses for `Vb` — exactly
347    // the quantity `summary()`'s Wald SE is built from. This is `σ̂²` for a
348    // profiled Gaussian and `1.0` for every family whose IRLS working weight
349    // already folds the dispersion / full Fisher information into the stored
350    // `H` (Binomial / Poisson / Gamma / Beta / Negative-Binomial / Tweedie),
351    // so `Vb = H⁻¹` needs no extra dispersion factor. Using the dispersion's
352    // `√φ` here instead would double-count the dispersion for Beta, whose
353    // `dispersion()` is `Known(1/(1+φ))` even though its `cov_scale` is `1.0`,
354    // shrinking every posterior SD by `√(1/(1+φ))` (gam#1722). For the
355    // profiled Gaussian `cov_scale == σ̂² == φ`, so this matches the previous
356    // `√φ` behaviour exactly; it only changes (fixes) Beta. This keeps the
357    // draw spread identical to the reported `summary().std_error`, like the
358    // sibling bounded-coefficient path (gam#1514).
359    let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
360    if h.nrows() != p || h.ncols() != p {
361        return Err(format!(
362            "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
363            h.nrows(),
364            h.ncols(),
365            p,
366            p
367        ));
368    }
369    let chol = h.cholesky(Side::Lower).map_err(|err| {
370        format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
371    })?;
372    let l = chol.lower_triangular();
373
374    // `validate_nuts_config` above guarantees `n_chains >= 1` and
375    // `n_samples >= 4`, so the draw grid is always non-empty and densely
376    // filled — no `.max(1)` clamping or bounds guard is needed.
377    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
378    let mut samples = Array2::<f64>::zeros((n_total, p));
379    let mut eps = Array1::<f64>::zeros(p);
380    let mut delta = Array1::<f64>::zeros(p);
381    for chain in 0..cfg.n_chains {
382        let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
383            cfg.seed,
384            chain,
385            0xA0B7_6C5D_E431_298F,
386        ));
387        for draw in 0..cfg.n_samples {
388            let k = chain * cfg.n_samples + draw;
389            for i in 0..p {
390                eps[i] = sample_standard_normal(&mut rng);
391            }
392            back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
393            for i in 0..p {
394                // `delta` has covariance H⁻¹; multiplying by `√cov_scale`
395                // produces a draw with covariance `cov_scale·H⁻¹`, matching
396                // the coefficient covariance `Vb` the rest of inference (and
397                // `summary()`'s Wald SE) assumes.
398                samples[(k, i)] = mode[i] + sqrt_cov_scale * delta[i];
399            }
400        }
401    }
402
403    let posterior_mean = samples
404        .mean_axis(ndarray::Axis(0))
405        .unwrap_or_else(|| Array1::<f64>::zeros(p));
406    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
407
408    Ok(NutsResult {
409        samples,
410        posterior_mean,
411        posterior_std,
412        rhat: 1.0,
413        ess: n_total as f64,
414        converged: true,
415    })
416}
417
418#[inline]
419fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
420    // Box-Muller transform — sufficient for posterior-mean-style sampling.
421    // The same construction is used by the NUTS warmup; keeping it in
422    // sync avoids two divergent gaussian RNG paths inside the engine.
423    let u1 = rng.random::<f64>().max(1e-16);
424    let u2 = rng.random::<f64>();
425    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
426}
427
428fn sample_standard(
429    model: &SavedModel,
430    data: ArrayView2<'_, f64>,
431    col_map: &HashMap<String, usize>,
432    training_headers: Option<&Vec<String>>,
433    mut likelihood: LikelihoodSpec,
434    cfg: &NutsConfig,
435) -> Result<NutsResult, String> {
436    // A coefficient that needs a *constraint-aware* posterior sampler must not
437    // take the gaussian-identity closed-form Laplace shortcut: that shortcut
438    // draws an unconstrained `N(mode, φ·H⁻¹)`, which for an active bound puts
439    // ~half its mass on the forbidden side of the boundary. Three geometries
440    // qualify, and all reproduce on a default gaussian model:
441    //   * a `bounded(x, min, max)` interval transform (#1508) — sampled on its
442    //     latent logit scale by `sample_standard_bounded`;
443    //   * a `nonnegative()`/`nonpositive()`/`linear(min,max)`/`constrain()` box
444    //     bound on a parametric coefficient (#1507); and
445    //   * a monotone/convex/concave shape cone on a spline (#1509).
446    // The latter two are sampled from the truncated Gaussian below. Detect all
447    // three cheaply from the saved termspec so the common, fully-unconstrained
448    // gaussian path keeps its fast exact fallback without building the design;
449    // the precise dispatch (and the authoritative `design.linear_constraints`
450    // check) happens after the design is assembled.
451    let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
452        ts.linear_terms.iter().any(|term| {
453            !matches!(
454                term.coefficient_geometry,
455                LinearCoefficientGeometry::Unconstrained
456            ) || term.coefficient_min.is_some()
457                || term.coefficient_max.is_some()
458        }) || ts
459            .smooth_terms
460            .iter()
461            .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
462    });
463    if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
464        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
465    }
466    if model.has_link_wiggle() {
467        // A Gaussian-identity link-wiggle model is sampled from its saved
468        // closed-form joint Laplace posterior (the mean and wiggle coefficients
469        // are jointly Gaussian); only the non-Gaussian wiggle posterior needs
470        // the dedicated link-wiggle NUTS path. Preserved from the original
471        // dispatch, where the Gaussian-identity shortcut ran ahead of the
472        // wiggle branch and so claimed Gaussian wiggle models for the
473        // closed-form path.
474        if likelihood.is_gaussian_identity() {
475            return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
476        }
477        return sample_standard_link_wiggle(
478            model,
479            data,
480            col_map,
481            training_headers,
482            likelihood,
483            cfg,
484        );
485    }
486    let parsed = parse_formula(&model.formula)?;
487    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
488    let y = data.column(y_col).to_owned();
489    let spec = resolve_termspec_for_prediction(
490        &model.resolved_termspec,
491        training_headers,
492        col_map,
493        "resolved_termspec",
494    )?;
495    let design = build_term_collection_design(data, &spec)
496        .map_err(|e| format!("failed to build term collection design: {e}"))?;
497
498    // ---- Constraint-aware posterior dispatch -------------------------------
499    //
500    // A coefficient subject to an *active* constraint sits on the boundary of
501    // its feasible region, so a plain unconstrained draw `N(mode, φ·H⁻¹)`
502    // places ~half its mass on the forbidden side. The constrained geometry
503    // must therefore be reconstructed *here*, ahead of both the
504    // Gaussian-identity closed-form shortcut and the GLM-NUTS fallback —
505    // neither of which is aware of the feasible region (#1507/#1508/#1509).
506    // The fit pins the point estimate correctly; only the posterior was blind.
507
508    // (1) bounded() interval coefficients are not sampled by the GLM-NUTS path.
509    // That path runs the Hamiltonian over the *raw* linear design with the
510    // saved user-scale mode, treating every coefficient as an unconstrained,
511    // Gaussian-penalized parameter. Bounded terms are fit through a custom
512    // family that drives eta via an interval transform `beta = min + (max-min)·
513    // sigmoid(theta)` of an unconstrained latent `theta`. The posterior is
514    // Gaussian on that *latent* scale (which is exactly where the fit treats the
515    // coefficient as a locally-quadratic, unconstrained parameter), so the
516    // correct draws are `theta ~ N(theta_mode, H_latent^{-1})` pushed forward
517    // through the interval map — never a Gaussian on the user scale, which can
518    // place mass outside [min,max] and discards the boundary-induced skew. The
519    // saved fit exports the user-scale mode and user-scale penalized Hessian;
520    // `sample_bounded_latent_posterior_internal` reconstructs the latent
521    // geometry via the exact inverse delta-method (`H_latent = J H_user J`) and
522    // returns user-scale draws that always lie strictly inside the interval.
523    // This must precede the Gaussian-identity shortcut: a Gaussian `bounded()`
524    // model would otherwise take the closed-form path and emit a user-scale
525    // Gaussian that spills outside the interval (#1508).
526    let has_bounded = spec.linear_terms.iter().any(|term| {
527        matches!(
528            term.coefficient_geometry,
529            LinearCoefficientGeometry::Bounded { .. }
530        )
531    });
532    if has_bounded {
533        // Mirror the fit-time layout: linear coefficient `j` lives at column
534        // `intercept_range.end + j` of the model's coefficient vector. Bounds
535        // are on the original (user/data) scale, which is also the scale the
536        // saved beta and penalized Hessian live on.
537        let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
538            .linear_terms
539            .iter()
540            .enumerate()
541            .filter_map(|(j, term)| match term.coefficient_geometry {
542                LinearCoefficientGeometry::Bounded { min, max, .. } => {
543                    Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
544                        col_idx: design.intercept_range.end + j,
545                        min,
546                        max,
547                    })
548                }
549                LinearCoefficientGeometry::Unconstrained => None,
550            })
551            .collect();
552        return sample_standard_bounded(model, cfg, &bounded_columns);
553    }
554
555    // (2) box / shape *inequality* constraints — `nonnegative()` /
556    // `linear(min,max)` / `constrain()` box bounds on a parametric coefficient
557    // (#1507) and the monotone/convex/concave shape cone `γ_j ≥ 0` on a spline
558    // (#1509). Both are reconstructed by `build_term_collection_design` into a
559    // single `A β ≥ b` polytope in the saved coefficient coordinate system, so
560    // one truncated-Gaussian sampler covers them uniformly. Like `bounded()`,
561    // this must precede the Gaussian-identity shortcut so a constrained
562    // Gaussian model is sampled inside its feasible region rather than from the
563    // boundary-centred unconstrained Gaussian.
564    if let Some(constraints) = design
565        .linear_constraints
566        .as_ref()
567        .filter(|c| c.a.nrows() > 0)
568    {
569        return sample_standard_truncated(model, cfg, constraints);
570    }
571
572    // (3) unconstrained Gaussian identity — saved closed-form Laplace posterior.
573    if likelihood.is_gaussian_identity() {
574        return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
575    }
576
577    // (4) unconstrained non-Gaussian GLM — exact NUTS over the raw design.
578    let weights = Array1::ones(data.nrows());
579    let dense_design_hmc = design.design.to_dense();
580    let p = dense_design_hmc.ncols();
581    let fit = fit_result_from_saved_model_for_prediction(model)?;
582    // Refresh the NB overdispersion `theta` from the fit's jointly-estimated
583    // `theta_hat` before sampling. The construction seed stored on the family
584    // spec (`theta: 1.0`) only seeds the inner solve; the NUTS NB log-likelihood
585    // / score (`src/inference/hmc.rs`) reads `theta` straight off this spec, so
586    // leaving the seed in place over-states `Var(y) = μ + μ²/θ` and inflates
587    // every coefficient's posterior SD (#1463 — the HMC sibling of the
588    // replicate-path bug #1124). `theta_fixed` NB carries the user's exact value
589    // in both the spec and the scale metadata, so this refresh is a no-op there.
590    // Mirrors how the replicate path reads `theta_hat` via the canonical
591    // `family_noise_parameter` helper (`negbin_theta().or(seed)`).
592    refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
593    if fit.beta.len() != p {
594        return Err(format!(
595            "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
596            fit.beta.len(),
597            p,
598        ));
599    }
600    if fit.lambdas.len() != design.penalties.len() {
601        return Err(format!(
602            "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
603            fit.lambdas.len(),
604            design.penalties.len(),
605        ));
606    }
607    let penalty =
608        weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
609
610    // Re-apply the offset the model was fit with so the posterior targets the
611    // same η = Xβ + offset as the fit and predict paths. The diagnostic loader
612    // keeps the saved offset column in the frame; dropping the offset silently
613    // sampled the wrong target for any `--offset-column` GLM (#882).
614    let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
615        Some(name) => {
616            let idx = resolve_role_col(col_map, name, "offset")?;
617            Some(data.column(idx).to_owned())
618        }
619        None => None,
620    };
621
622    run_nuts_sampling_flattened_family(
623        likelihood,
624        FamilyNutsInputs::Glm(GlmFlatInputs {
625            x: dense_design_hmc.view(),
626            y: y.view(),
627            weights: weights.view(),
628            penalty_matrix: penalty.view(),
629            mode: fit.beta.view(),
630            hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
631            gamma_shape: fit.likelihood_scale.gamma_shape(),
632            // Forward the saved training dispersion so NUTS whitening uses the
633            // posterior scale selected at fit time; fixed-scale families remain
634            // a no-op.
635            dispersion: fit.dispersion().unwrap_or_default(),
636            firth_bias_reduction: false,
637            offset: offset_vec.as_ref().map(|o| o.view()),
638        }),
639        cfg,
640    )
641    .map_err(|e| format!("NUTS sampling failed: {e}"))
642}
643
644/// Exact posterior draws for a standard GLM with `bounded()` coefficients.
645///
646/// The bounded coefficients are sampled on their natural latent (logit) scale —
647/// where the Laplace approximation is Gaussian — and every draw is pushed
648/// through the exact interval map so user-scale draws always lie strictly inside
649/// `[min, max]` and carry the boundary-induced skew. Non-bounded coefficients
650/// are drawn as the ordinary Gaussian Laplace component of the same joint
651/// posterior, so cross-coefficient correlations with the bounded columns are
652/// preserved (the latent precision is the full `H_latent = J H_user J`).
653fn sample_standard_bounded(
654    model: &SavedModel,
655    cfg: &NutsConfig,
656    bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
657) -> Result<NutsResult, String> {
658    validate_nuts_config(cfg).map_err(String::from)?;
659    let fit = fit_result_from_saved_model_for_prediction(model)?;
660    let mode = fit.beta.clone();
661    let p = mode.len();
662    if p == 0 {
663        return Err(
664            "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
665                .to_string(),
666        );
667    }
668    // The bounded fit exports the UNSCALED user-scale penalized Hessian; the
669    // latent sampler reconstructs the latent precision from it via the exact
670    // inverse delta-method. (`explicit_fit_hessian_for_whitening` returns this
671    // same user-scale penalized Hessian for a saved standard fit.)
672    let user_hessian =
673        explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
674    // The exported Hessian carries unit implicit dispersion, so the latent
675    // posterior covariance is `cov_scale·H_latent⁻¹` with `cov_scale` the
676    // coefficient-covariance scale the fit used for `Vb` (`σ̂²` for a profiled
677    // Gaussian, `1` for fixed-scale Binomial). Re-applying `√cov_scale` here
678    // keeps the draw spread identical to the reported `summary().std_error`
679    // (gam#1514); the truncated-constraint path does the analogous √φ lift.
680    let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
681    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
682    let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
683        &mode,
684        user_hessian,
685        bounded_columns,
686        n_total,
687        sqrt_cov_scale,
688        chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
689    )
690    .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
691
692    let posterior_mean = samples
693        .mean_axis(ndarray::Axis(0))
694        .unwrap_or_else(|| Array1::<f64>::zeros(p));
695    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
696
697    Ok(NutsResult {
698        samples,
699        posterior_mean,
700        posterior_std,
701        rhat: 1.0,
702        ess: n_total as f64,
703        converged: true,
704    })
705}
706
707/// Exact posterior draws for a standard GLM whose coefficients carry linear
708/// *inequality* constraints `A β ≥ b` — `nonnegative()` / `linear(min,max)` /
709/// `constrain()` box bounds on a parametric term (#1507) and the
710/// monotone/convex/concave shape cone `γ_j ≥ 0` on a spline (#1509).
711///
712/// The posterior is the Laplace Gaussian `N(mode, φ·H⁻¹)` *truncated* to the
713/// feasible polytope. For a Gaussian-identity model this is the exact
714/// posterior; for a non-Gaussian GLM it is the constraint-respecting Laplace
715/// approximation — the same modelling choice the `bounded()` term makes. The
716/// draws are produced by exact reflective Hamiltonian Monte Carlo
717/// ([`crate::truncated_gaussian`]), so every draw is feasible and
718/// successive draws are essentially independent (`rhat ≈ 1`, matching the other
719/// Laplace posterior paths).
720fn sample_standard_truncated(
721    model: &SavedModel,
722    cfg: &NutsConfig,
723    constraints: &gam_solve::pirls::LinearInequalityConstraints,
724) -> Result<NutsResult, String> {
725    validate_nuts_config(cfg).map_err(String::from)?;
726    let fit = fit_result_from_saved_model_for_prediction(model)?;
727    let mode = fit.beta.clone();
728    let p = mode.len();
729    if p == 0 {
730        return Err(
731            "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
732             vector"
733                .to_string(),
734        );
735    }
736    // The saved standard fit exports the unscaled user-scale penalised Hessian
737    // `H`; the truncated sampler whitens with its Cholesky and re-applies √φ so
738    // the posterior covariance is `φ·H⁻¹`, identical to the unconstrained
739    // Gaussian/bounded paths. Fixed-scale families (Binomial / Poisson) have
740    // φ = 1.
741    let penalized_hessian =
742        explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
743    let sqrt_phi = {
744        use gam_problem::dispersion_cov::DispersionExt as _;
745        fit.dispersion().unwrap_or_default().sqrt_phi()
746    };
747    let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
748        &mode,
749        &penalized_hessian,
750        sqrt_phi,
751        constraints,
752        cfg.n_samples,
753        cfg.n_chains,
754        chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
755    )?;
756    let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
757
758    let posterior_mean = samples
759        .mean_axis(ndarray::Axis(0))
760        .unwrap_or_else(|| Array1::<f64>::zeros(p));
761    let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
762
763    Ok(NutsResult {
764        samples,
765        posterior_mean,
766        posterior_std,
767        rhat: 1.0,
768        ess: n_total as f64,
769        converged: true,
770    })
771}
772
/// Posterior sampling for a standard GLM fitted with a link-wiggle block.
///
/// Rebuilds the main design from the saved resolved term spec. Reads the `Mean`
/// and `LinkWiggle` coefficient blocks and the saved joint penalized Hessian
/// from the fit. Reconstructs the base and wiggle penalty matrices from the
/// saved lambdas. Hands everything to `run_link_wiggle_nuts_sampling`.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - the formula, response or term spec cannot be resolved;
/// - the rebuilt design disagrees with the saved coefficient or lambda counts;
/// - the wiggle runtime metadata or the joint Hessian geometry is missing;
/// - the family/link pair has no link-wiggle NUTS family;
/// - the sampler itself fails.
fn sample_standard_link_wiggle(
    model: &SavedModel,
    data: ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
    training_headers: Option<&Vec<String>>,
    likelihood: LikelihoodSpec,
    cfg: &NutsConfig,
) -> Result<NutsResult, String> {
    let parsed = parse_formula(&model.formula)?;
    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
    let y = data.column(y_col).to_owned();

    // Rebuild the main (pre-wiggle) design from the saved resolved term spec.
    let spec = resolve_termspec_for_prediction(
        &model.resolved_termspec,
        training_headers,
        col_map,
        "resolved_termspec",
    )?;
    let design = build_term_collection_design(data, &spec)
        .map_err(|e| format!("failed to build term collection design: {e}"))?;
    let p_main = design.design.ncols();

    let fit = fit_result_from_saved_model_for_prediction(model)?;
    let wiggle_runtime = model
        .saved_prediction_runtime()?
        .link_wiggle
        .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
    // Sampling starts from the saved mode, split into the main-effect block
    // (`mode_beta`) and the link-wiggle block (`mode_theta`).
    let mode_beta = fit
        .block_by_role(BlockRole::Mean)
        .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
        .beta
        .clone();
    let mode_theta = fit
        .block_by_role(BlockRole::LinkWiggle)
        .ok_or_else(|| {
            "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
        })?
        .beta
        .clone();
    let p_wiggle = mode_theta.len();
    let p_total = mode_beta.len() + p_wiggle;

    // The rebuilt design must line up with the saved blocks before any matrix
    // algebra below indexes into them.
    if mode_beta.len() != p_main {
        return Err(format!(
            "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
            mode_beta.len(),
            p_main,
        ));
    }
    if fit.beta.len() != p_total {
        return Err(format!(
            "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
            fit.beta.len(),
            p_main,
            p_wiggle,
            p_total,
        ));
    }

    // Joint (main + wiggle) penalized Hessian exported at fit time; used by
    // the sampler for whitening.
    let hessian = &fit
        .geometry
        .as_ref()
        .ok_or_else(|| {
            "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
                .to_string()
        })?
        .penalized_hessian;
    validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;

    let n_base_penalties = design.penalties.len();
    let base_lambdas = fit
        .block_by_role(BlockRole::Mean)
        .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
        .lambdas
        .view();
    if base_lambdas.len() != n_base_penalties {
        return Err(format!(
            "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
            base_lambdas.len(),
            n_base_penalties,
        ));
    }

    // `as_slice()` is on a view of an owned `Array1`, which is contiguous.
    let penalty_base =
        weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);

    let wiggle_lambdas_owned = fit
        .lambdas_linkwiggle()
        .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
    let wiggle_lambdas = wiggle_lambdas_owned.view();
    let degree = wiggle_runtime.degree;
    let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());

    // Reconstruct one difference penalty per saved wiggle lambda. Lambda 0
    // uses order 2 and lambda k >= 1 uses order k + 1, so k = 1 is also
    // order 2.
    // NOTE(review): confirm this order scheme matches the fit-time construction.
    // An order >= p_wiggle cannot form a difference penalty and is skipped.
    // The orders never decrease in k, so skipped entries are always a trailing
    // run. Padding with zero matrices at the end therefore keeps penalty k
    // paired with lambda k.
    let mut wiggle_penalties = Vec::new();
    let default_orders = [2usize];
    let n_wiggle_lambdas = wiggle_lambdas.len();
    for k in 0..n_wiggle_lambdas {
        let order = if k < default_orders.len() {
            default_orders[k]
        } else {
            k + 1
        };
        if order >= p_wiggle {
            continue;
        }
        let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
            .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
        wiggle_penalties.push(penalty);
    }
    while wiggle_penalties.len() < n_wiggle_lambdas {
        wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
    }

    let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;

    // The wiggle spline's knot range is the span of the main linear predictor
    // `X β` at the saved mode over the supplied rows.
    // NOTE(review): with zero data rows this fold yields (INF, -INF).
    let q0 = design.design.dot(&mode_beta);
    let (q0_min, q0_max) = q0
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });

    let spline = LinkWiggleSplineArtifacts {
        knot_range: (q0_min, q0_max),
        knot_vector: knot_arr,
        degree,
    };

    // Map the saved family/link to a NUTS family. Only the Binomial family is
    // discriminated by link; other families match any link; everything else is
    // rejected.
    let nuts_family = match (&likelihood.response, &likelihood.link) {
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
            NutsFamily::BinomialLogit
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
            NutsFamily::BinomialProbit
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
            NutsFamily::BinomialCLogLog
        }
        (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
        (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
        (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
        (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
        (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
        _ => {
            return Err(format!(
                "NUTS sampling with link wiggle is not supported for family {}",
                likelihood.pretty_name()
            ));
        }
    };

    // NOTE(review): two things differ from the plain standard path.
    // - Every row is given unit weight.
    // - No offset column is re-applied here, although the standard path does
    //   re-apply it (#882).
    // Confirm the link-wiggle fit used neither.
    let weights = Array1::ones(data.nrows());
    // Family noise parameter from the fit; if none, fall back to the fit's
    // standard deviation.
    let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);

    let wiggle_nuts_dense = design.design.as_dense_cow();
    run_link_wiggle_nuts_sampling(
        wiggle_nuts_dense.view(),
        y.view(),
        weights.view(),
        penalty_base.view(),
        penalty_link.view(),
        mode_beta.view(),
        mode_theta.view(),
        hessian.view(),
        spline,
        nuts_family,
        scale,
        cfg,
    )
    .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
}
944
/// Posterior sampling for a saved survival model.
///
/// Latent, latent-binary and location-scale likelihood modes are routed to
/// `laplace_gaussian_fallback`. For the remaining modes, the saved metadata is
/// used to rebuild the flattened survival problem:
/// - entry/exit times and event indicators;
/// - the time basis, anchor-centred for `MarginalSlope`;
/// - the baseline offsets and optional baseline time-wiggle columns;
/// - the covariate design;
/// - the penalty blocks, carrying the saved smoothing parameters.
///
/// The Hessian is then evaluated at the saved coefficients, and survival NUTS
/// runs from that mode.
///
/// Design column layout: `[time basis | baseline time-wiggle | covariates]`.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - required survival metadata is missing: the event column, time anchor,
///   time-wiggle spec/knots, or `survivalridge_lambda`;
/// - the saved survival spec is `crude` or unrecognised;
/// - design or penalty reconstruction fails;
/// - state evaluation fails;
/// - sampling fails.
fn sample_survival(
    model: &SavedModel,
    data: ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
    training_headers: Option<&Vec<String>>,
    cfg: &NutsConfig,
) -> Result<NutsResult, String> {
    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
    // These modes are not sampled with NUTS here; they use the Laplace
    // Gaussian fallback instead.
    if matches!(
        saved_likelihood_mode,
        SurvivalLikelihoodMode::Latent
            | SurvivalLikelihoodMode::LatentBinary
            | SurvivalLikelihoodMode::LocationScale
    ) {
        return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
    }
    // `survival_entry == None` is the right-censored shorthand
    // `Surv(time, event)`: training synthesized a zero entry column,
    // and posterior sampling must do the same so artifacts fit with
    // the shorthand are first-class through `gam sample` /
    // `model.sample` just like `gam predict` already handles them in
    // `run_predict_survival`. The resolution flows through the shared
    // `resolve_saved_survival_time_columns` helper so every consumer
    // of saved survival metadata applies the same fallback contract.
    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
    let exit_col = time_cols.exit_col;
    let eventname = model
        .survival_event
        .as_ref()
        .ok_or_else(|| "survival model missing event column metadata".to_string())?;
    let event_col = resolve_role_col(col_map, eventname, "event")?;
    let termspec = resolve_termspec_for_prediction(
        &model.resolved_termspec,
        training_headers,
        col_map,
        "resolved_termspec",
    )?;
    // Covariates are clipped to the saved training ranges before the design is
    // built. The time columns below are read from the unclipped `data`.
    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
    let cov_design = build_term_collection_design(cov_input, &termspec)
        .map_err(|e| format!("failed to build survival design: {e}"))?;
    let n = data.nrows();
    let p_cov = cov_design.design.ncols();
    let mut age_entry = Array1::<f64>::zeros(n);
    let mut age_exit = Array1::<f64>::zeros(n);
    let mut event_target = Array1::<u8>::zeros(n);
    // No competing events, and every row has unit weight.
    // NOTE(review): training prior weights are not forwarded — confirm the
    // survival fit was unweighted.
    let event_competing = Array1::<u8>::zeros(n);
    let weights = Array1::<f64>::ones(n);
    for i in 0..n {
        let (t0, t1) = normalize_survival_time_pair(
            time_cols.row_entry_time(data, i),
            data[[i, exit_col]],
            i,
        )?;
        age_entry[i] = t0;
        age_exit[i] = t1;
        // Event indicator: any value >= 0.5 counts as an event.
        event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
    }
    let time_cfg = load_survival_time_basis_config_from_model(model)?;
    let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
        &time_build.basisname,
        time_build.degree,
        time_build.knots.as_ref(),
        time_build.keep_cols.as_ref(),
        time_build.smooth_lambda,
    )?;
    // Marginal-slope models centre the entry/exit time designs at the basis
    // row evaluated at the saved time anchor.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
        center_survival_time_designs_at_anchor(
            &mut time_build.x_entry_time,
            &mut time_build.x_exit_time,
            &time_anchor_row,
        )?;
    }
    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
    let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
        build_survival_time_offsets_for_likelihood(
            &age_entry,
            &age_exit,
            &baseline_cfg,
            saved_likelihood_mode,
            None,
        )?;
    // Marginal-slope models also add the time-derivative guard to the offsets.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        add_survival_time_derivative_guard_offset(
            &age_entry,
            &age_exit,
            time_anchor,
            survival_derivative_guard_for_likelihood(saved_likelihood_mode),
            &mut eta_offset_entry,
            &mut eta_offset_exit,
            &mut derivative_offset_exit,
        )?;
    }
    // Optional baseline time-wiggle design columns, given as
    // (entry, exit, derivative), built from the offsets above.
    let saved_timewiggle = saved_baseline_timewiggle_components(
        &eta_offset_entry,
        &eta_offset_exit,
        &derivative_offset_exit,
        model,
    )?;
    let p_time = time_build.x_exit_time.ncols();
    let p_timewiggle = saved_timewiggle
        .as_ref()
        .map(|(_, exit, _)| exit.ncols())
        .unwrap_or(0);
    let p = p_time + p_timewiggle + p_cov;
    // Assemble the full designs, with columns laid out as
    // [time | time-wiggle | covariates].
    let tb_entry_dense = time_build.x_entry_time.to_dense();
    let tb_exit_dense = time_build.x_exit_time.to_dense();
    let tb_deriv_dense = time_build.x_derivative_time.to_dense();
    let mut x_entry = Array2::<f64>::zeros((n, p));
    let mut x_exit = Array2::<f64>::zeros((n, p));
    let mut x_derivative = Array2::<f64>::zeros((n, p));
    if p_time > 0 {
        x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
        x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
        x_derivative
            .slice_mut(s![.., ..p_time])
            .assign(&tb_deriv_dense);
    }
    if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
        && p_timewiggle > 0
    {
        x_entry
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(entry_w);
        x_exit
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(exit_w);
        x_derivative
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(deriv_w);
    }
    // Covariate columns are written to `x_entry` / `x_exit` only. The
    // matching `x_derivative` columns stay zero.
    if p_cov > 0 {
        let cov_dense = cov_design.design.to_dense();
        let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
        x_entry
            .slice_mut(s![.., cov_range.clone()])
            .assign(&cov_dense);
        x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
    }
    // Time-basis penalties first. They start at the basis' smooth lambda (or
    // the reconstruction default) and are overwritten by the saved lambdas
    // further down.
    let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
    for (idx, s) in time_build.penalties.iter().enumerate() {
        if s.nrows() == p_time && s.ncols() == p_time {
            penalty_blocks.push(PenaltyBlock {
                matrix: s.clone(),
                lambda: time_build
                    .smooth_lambda
                    .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                range: 0..p_time,
                nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
            });
        }
    }
    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
    // Baseline time-wiggle penalties are rebuilt from the saved knots, degree
    // and penalty orders. The seed is the stacked [entry; exit] eta offsets
    // (length 2n).
    if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
        let start = p_time;
        let end = start + exit_w.ncols();
        let wiggle_lambda_offset = penalty_blocks.len();
        let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
            "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
        })?;
        let wiggle_degree = wiggle_cfg.degree;
        let wiggle_knots =
            Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
                "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
            })?);
        let mut seed = Array1::<f64>::zeros(2 * n);
        for i in 0..n {
            seed[i] = eta_offset_entry[i];
            seed[n + i] = eta_offset_exit[i];
        }
        let (primary_order, extra_orders) =
            split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
        let mut block = buildwiggle_block_input_from_knots(
            seed.view(),
            &wiggle_knots,
            wiggle_degree,
            primary_order,
            wiggle_cfg.double_penalty,
        )?;
        append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
            .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
        // Keep only the square penalties that match the wiggle width; they
        // apply to the wiggle column range.
        for (widx, s) in block.penalties.iter().enumerate() {
            let s = match s {
                gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
                gam_solve::estimate::PenaltySpec::Dense(m)
                | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
            };
            if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
                penalty_blocks.push(PenaltyBlock {
                    matrix: s.clone(),
                    lambda: time_build
                        .smooth_lambda
                        .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                    range: start..end,
                    nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
                });
            }
        }
        // NOTE(review): this applies the same positional saved lambdas that
        // the final overwrite loop below applies to every block. This pass is
        // therefore redundant, but harmless.
        for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
            .iter_mut()
            .enumerate()
        {
            if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
                block_penalty.lambda = lam;
            }
        }
    }
    let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
        "saved survival model is missing survivalridge_lambda; refusing to \
         pick a load-time default (the historical 1e-4 fallback silently \
         disagreed with the 1e-6 fit-time default). Refit."
            .to_string()
    })?;
    // Identity ridge over columns `ridge_range_start..p`. It starts at column 1
    // only for a linear time basis without baseline time-wiggle.
    // NOTE(review): presumably this leaves the first linear-time coefficient
    // unpenalized — confirm.
    let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
    {
        1
    } else {
        0
    };
    if ridge_lambda > 0.0 && p > ridge_range_start {
        let dim = p - ridge_range_start;
        let mut ridge = Array2::<f64>::zeros((dim, dim));
        for d in 0..dim {
            ridge[[d, d]] = 1.0;
        }
        penalty_blocks.push(PenaltyBlock {
            matrix: ridge,
            lambda: ridge_lambda,
            range: ridge_range_start..p,
            nullspace_dim: 0,
        });
    }
    // Overwrite each block's lambda with the saved lambda at the same
    // position. Blocks without a saved lambda keep the value set above.
    // NOTE(review): this assumes `fit_saved.lambdas` follows the build order
    // [time, time-wiggle, ridge].
    for (idx, block) in penalty_blocks.iter_mut().enumerate() {
        if let Some(&lam) = fit_saved.lambdas.get(idx) {
            block.lambda = lam;
        }
    }
    let penalties = PenaltyBlocks::new(penalty_blocks);
    // A missing survival spec defaults to "net"; "crude" is rejected explicitly.
    let survivalspec = match model
        .survivalspec
        .as_deref()
        .unwrap_or("net")
        .to_ascii_lowercase()
        .as_str()
    {
        "net" => SurvivalSpec::Net,
        "crude" => {
            return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
                        .to_string());
        }
        other => {
            return Err(format!("unsupported saved survival spec '{other}'"));
        }
    };
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    let mut model_surv = royston_parmar::working_model_from_flattened(
        penalties.clone(),
        monotonicity,
        survivalspec,
        RoystonParmarInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            monotonicity_constraint_rows: None,
            monotonicity_constraint_offsets: None,
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
    )
    .map_err(|e| format!("failed to construct survival model: {e}"))?;
    // Structural monotonicity covers the time + time-wiggle columns. It is
    // enabled for every mode except Weibull; the same flag is passed to the
    // sampler below.
    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
        model_surv
            .set_structural_monotonicity(true, p_time + p_timewiggle)
            .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
    }
    // The Hessian is evaluated at the saved coefficients; it seeds the
    // sampler's whitening.
    let beta0 = fit_saved.beta.clone();
    let state = model_surv
        .update_state(&beta0)
        .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
    let hessian = state.hessian.to_dense();
    run_survival_nuts_sampling_flattened(
        SurvivalFlatInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
        penalties,
        monotonicity,
        survivalspec,
        saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
        p_time + p_timewiggle,
        beta0.view(),
        hessian.view(),
        cfg,
    )
    .map_err(|e| format!("survival NUTS sampling failed: {e}"))
}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    /// Extracts `(theta, theta_fixed)` from a negative-binomial response and
    /// fails the calling test for any other response family.
    fn negbin_params(likelihood: &LikelihoodSpec) -> (f64, bool) {
        match &likelihood.response {
            ResponseFamily::NegativeBinomial { theta, theta_fixed } => (*theta, *theta_fixed),
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        }
    }

    /// #1463: the NB NUTS path must sample at the fit's jointly estimated
    /// `theta_hat`, not at the construction seed `theta = 1.0`.
    ///
    /// The seed only initialises the inner solve. The NUTS NB log-likelihood
    /// and score read `theta` directly from the sampling `LikelihoodSpec`. If
    /// the spec is not refreshed from the scale metadata, the posterior is
    /// drawn at the wrong overdispersion, and every coefficient's posterior SD
    /// inflates by roughly 1.4–1.5×.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        // The spec carries the seed theta = 1.0. Only the scale metadata
        // knows the estimated theta_hat = 2.97.
        let mut likelihood = LikelihoodSpec::negative_binomial_log(1.0);
        refresh_negbin_theta_for_sampling(
            &mut likelihood,
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
        );

        let (theta, _) = negbin_params(&likelihood);
        assert_eq!(
            theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    /// For a fixed-theta NB fit, the user's theta appears in both the spec and
    /// the metadata. The refresh must keep that value and the `theta_fixed`
    /// flag.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let mut likelihood = LikelihoodSpec::negative_binomial_log_fixed(4.25);
        refresh_negbin_theta_for_sampling(
            &mut likelihood,
            LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 },
        );

        let (theta, theta_fixed) = negbin_params(&likelihood);
        assert_eq!(theta, 4.25, "fixed NB theta must survive the refresh");
        assert!(theta_fixed, "theta_fixed flag must be preserved");
    }

    /// Scale metadata without an NB theta leaves the seed in place. This
    /// mirrors `negbin_theta().or(seed)` in the canonical replicate picker.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let mut likelihood = LikelihoodSpec::negative_binomial_log(3.5);
        // `ProfiledGaussian` exposes no negbin theta.
        refresh_negbin_theta_for_sampling(
            &mut likelihood,
            LikelihoodScaleMetadata::ProfiledGaussian,
        );

        let (theta, _) = negbin_params(&likelihood);
        assert_eq!(
            theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    /// Non-NB responses are left untouched even when the metadata carries an
    /// NB theta, because the refresh only matches the NB response family.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut poisson = LikelihoodSpec::poisson_log();
        let snapshot = poisson.response.clone();
        refresh_negbin_theta_for_sampling(
            &mut poisson,
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 },
        );
        assert_eq!(
            poisson.response, snapshot,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}