1use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18 FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19 explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20 run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27 SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28 build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29 evaluate_survival_time_basis_row, normalize_survival_time_pair,
30 resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33 fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34 resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35 saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39 PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42 append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43 split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47 FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55pub fn saved_baseline_timewiggle_spec(
60 model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62 model
63 .saved_baseline_time_wiggle()
64 .map_err(|e| e.to_string())
65 .map(|runtime| {
66 runtime.map(|saved| LinkWiggleFormulaSpec {
67 degree: saved.degree,
68 num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69 penalty_orders: saved.penalty_orders,
70 double_penalty: saved.double_penalty,
71 })
72 })
73}
74
75fn weighted_penalty_matrix(
76 penalties: &[Array2<f64>],
77 lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79 if penalties.len() != lambdas.len() {
80 return Err(format!(
81 "penalty/lambda mismatch: {} penalties vs {} lambdas",
82 penalties.len(),
83 lambdas.len()
84 ));
85 }
86 if penalties.is_empty() {
87 return Err("cannot sample without at least one penalty block".to_string());
88 }
89 let p = penalties[0].nrows();
90 let mut out = Array2::<f64>::zeros((p, p));
91 for (k, s) in penalties.iter().enumerate() {
92 if s.nrows() != p || s.ncols() != p {
93 return Err(format!(
94 "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95 s.nrows(),
96 s.ncols(),
97 p,
98 p
99 ));
100 }
101 let lam = lambdas[k];
102 out += &(s * lam);
103 }
104 Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108 hessian: &Array2<f64>,
109 expected_dim: usize,
110) -> Result<(), String> {
111 if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112 return Err(format!(
113 "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114 hessian.nrows(),
115 hessian.ncols(),
116 expected_dim,
117 expected_dim,
118 ));
119 }
120 validate_all_finite(
121 "link-wiggle explicit joint Hessian",
122 hessian.iter().copied(),
123 )?;
124 let mut max_abs = 0.0_f64;
125 for r in 0..expected_dim {
126 for c in 0..expected_dim {
127 max_abs = max_abs.max(hessian[[r, c]].abs());
128 let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129 if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130 return Err(format!(
131 "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132 ));
133 }
134 }
135 }
136 if max_abs == 0.0 {
137 return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138 .to_string());
139 }
140 Ok(())
141}
142
143fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
152 crate::generative::family_noise_parameter(
153 fit.likelihood_scale,
154 fit.standard_deviation,
155 likelihood,
156 )
157}
158
159fn refresh_negbin_theta_for_sampling(
177 likelihood: &mut LikelihoodSpec,
178 scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180 if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181 if let Some(theta_hat) = scale.negbin_theta() {
182 *theta = theta_hat;
183 }
184 }
185}
186
187fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
191 Ok(model.likelihood())
192}
193
/// Fallback smoothing parameter used by `sample_survival` for reconstructed
/// penalty blocks when the rebuilt time basis reports no `smooth_lambda`.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 1e-2;
199
/// Const wrapper that forwards `x` to `gam_linalg::utils::splitmix64_hash`.
///
/// Used by `chain_stream_seed` to turn a combined seed word into the final
/// per-chain seed.
#[inline]
const fn splitmix64(x: u64) -> u64 {
    gam_linalg::utils::splitmix64_hash(x)
}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207 splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210pub fn sample_saved_model(
229 model: &SavedModel,
230 data: ArrayView2<'_, f64>,
231 col_map: &HashMap<String, usize>,
232 training_headers: Option<&Vec<String>>,
233 cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235 validate_nuts_config(cfg).map_err(String::from)?;
244 let likelihood = likelihood_spec_for_saved_model(model)?;
245 match model.predict_model_class() {
246 PredictModelClass::Survival => {
247 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253 if matches!(
254 saved_likelihood_mode,
255 SurvivalLikelihoodMode::Latent
256 | SurvivalLikelihoodMode::LatentBinary
257 | SurvivalLikelihoodMode::LocationScale
258 ) {
259 laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260 } else {
261 sample_survival(model, data, col_map, training_headers, cfg)
262 }
263 }
264 PredictModelClass::Standard => {
265 sample_standard(model, data, col_map, training_headers, likelihood, cfg)
266 }
267 PredictModelClass::GaussianLocationScale => {
277 laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
278 }
279 PredictModelClass::BinomialLocationScale => {
280 laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
281 }
282 PredictModelClass::DispersionLocationScale => {
283 laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
284 }
285 PredictModelClass::BernoulliMarginalSlope => {
286 laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
287 }
288 PredictModelClass::TransformationNormal => {
289 laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
290 }
291 }
292}
293
294pub fn laplace_gaussian_fallback(
308 model: &SavedModel,
309 cfg: &NutsConfig,
310 rationale: &'static str,
311) -> Result<NutsResult, String> {
312 use gam_problem::dispersion_cov::DispersionExt as _;
313 validate_nuts_config(cfg).map_err(String::from)?;
318 let fit = fit_result_from_saved_model_for_prediction(model)?;
319 let mode = fit.beta.clone();
320 let p = mode.len();
321 if p == 0 {
322 return Err(format!(
323 "{rationale}: cannot sample from an empty coefficient vector"
324 ));
325 }
326 let h = fit.penalized_hessian().ok_or_else(|| {
327 format!(
328 "{rationale}: posterior fallback requires the explicit penalised Hessian; \
329 refit with exact geometry export to enable posterior sampling for this class."
330 )
331 })?;
332 let dispersion = fit.dispersion().unwrap_or_default();
339 let sqrt_phi = dispersion.sqrt_phi();
340 if h.nrows() != p || h.ncols() != p {
341 return Err(format!(
342 "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
343 h.nrows(),
344 h.ncols(),
345 p,
346 p
347 ));
348 }
349 let chol = h.cholesky(Side::Lower).map_err(|err| {
350 format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
351 })?;
352 let l = chol.lower_triangular();
353
354 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
358 let mut samples = Array2::<f64>::zeros((n_total, p));
359 let mut eps = Array1::<f64>::zeros(p);
360 let mut delta = Array1::<f64>::zeros(p);
361 for chain in 0..cfg.n_chains {
362 let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
363 cfg.seed,
364 chain,
365 0xA0B7_6C5D_E431_298F,
366 ));
367 for draw in 0..cfg.n_samples {
368 let k = chain * cfg.n_samples + draw;
369 for i in 0..p {
370 eps[i] = sample_standard_normal(&mut rng);
371 }
372 back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
373 for i in 0..p {
374 samples[(k, i)] = mode[i] + sqrt_phi * delta[i];
378 }
379 }
380 }
381
382 let posterior_mean = samples
383 .mean_axis(ndarray::Axis(0))
384 .unwrap_or_else(|| Array1::<f64>::zeros(p));
385 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
386
387 Ok(NutsResult {
388 samples,
389 posterior_mean,
390 posterior_std,
391 rhat: 1.0,
392 ess: n_total as f64,
393 converged: true,
394 })
395}
396
397#[inline]
398fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
399 let u1 = rng.random::<f64>().max(1e-16);
403 let u2 = rng.random::<f64>();
404 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
405}
406
407fn sample_standard(
408 model: &SavedModel,
409 data: ArrayView2<'_, f64>,
410 col_map: &HashMap<String, usize>,
411 training_headers: Option<&Vec<String>>,
412 mut likelihood: LikelihoodSpec,
413 cfg: &NutsConfig,
414) -> Result<NutsResult, String> {
415 let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
431 ts.linear_terms.iter().any(|term| {
432 !matches!(
433 term.coefficient_geometry,
434 LinearCoefficientGeometry::Unconstrained
435 ) || term.coefficient_min.is_some()
436 || term.coefficient_max.is_some()
437 }) || ts
438 .smooth_terms
439 .iter()
440 .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
441 });
442 if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
443 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
444 }
445 if model.has_link_wiggle() {
446 if likelihood.is_gaussian_identity() {
454 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
455 }
456 return sample_standard_link_wiggle(
457 model,
458 data,
459 col_map,
460 training_headers,
461 likelihood,
462 cfg,
463 );
464 }
465 let parsed = parse_formula(&model.formula)?;
466 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
467 let y = data.column(y_col).to_owned();
468 let spec = resolve_termspec_for_prediction(
469 &model.resolved_termspec,
470 training_headers,
471 col_map,
472 "resolved_termspec",
473 )?;
474 let design = build_term_collection_design(data, &spec)
475 .map_err(|e| format!("failed to build term collection design: {e}"))?;
476
477 let has_bounded = spec.linear_terms.iter().any(|term| {
506 matches!(
507 term.coefficient_geometry,
508 LinearCoefficientGeometry::Bounded { .. }
509 )
510 });
511 if has_bounded {
512 let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
517 .linear_terms
518 .iter()
519 .enumerate()
520 .filter_map(|(j, term)| match term.coefficient_geometry {
521 LinearCoefficientGeometry::Bounded { min, max, .. } => {
522 Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
523 col_idx: design.intercept_range.end + j,
524 min,
525 max,
526 })
527 }
528 LinearCoefficientGeometry::Unconstrained => None,
529 })
530 .collect();
531 return sample_standard_bounded(model, cfg, &bounded_columns);
532 }
533
534 if let Some(constraints) = design
544 .linear_constraints
545 .as_ref()
546 .filter(|c| c.a.nrows() > 0)
547 {
548 return sample_standard_truncated(model, cfg, constraints);
549 }
550
551 if likelihood.is_gaussian_identity() {
553 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
554 }
555
556 let weights = Array1::ones(data.nrows());
558 let dense_design_hmc = design.design.to_dense();
559 let p = dense_design_hmc.ncols();
560 let fit = fit_result_from_saved_model_for_prediction(model)?;
561 refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
572 if fit.beta.len() != p {
573 return Err(format!(
574 "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
575 fit.beta.len(),
576 p,
577 ));
578 }
579 if fit.lambdas.len() != design.penalties.len() {
580 return Err(format!(
581 "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
582 fit.lambdas.len(),
583 design.penalties.len(),
584 ));
585 }
586 let penalty =
587 weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
588
589 let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
594 Some(name) => {
595 let idx = resolve_role_col(col_map, name, "offset")?;
596 Some(data.column(idx).to_owned())
597 }
598 None => None,
599 };
600
601 run_nuts_sampling_flattened_family(
602 likelihood,
603 FamilyNutsInputs::Glm(GlmFlatInputs {
604 x: dense_design_hmc.view(),
605 y: y.view(),
606 weights: weights.view(),
607 penalty_matrix: penalty.view(),
608 mode: fit.beta.view(),
609 hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
610 gamma_shape: fit.likelihood_scale.gamma_shape(),
611 dispersion: fit.dispersion().unwrap_or_default(),
615 firth_bias_reduction: false,
616 offset: offset_vec.as_ref().map(|o| o.view()),
617 }),
618 cfg,
619 )
620 .map_err(|e| format!("NUTS sampling failed: {e}"))
621}
622
623fn sample_standard_bounded(
633 model: &SavedModel,
634 cfg: &NutsConfig,
635 bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
636) -> Result<NutsResult, String> {
637 validate_nuts_config(cfg).map_err(String::from)?;
638 let fit = fit_result_from_saved_model_for_prediction(model)?;
639 let mode = fit.beta.clone();
640 let p = mode.len();
641 if p == 0 {
642 return Err(
643 "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
644 .to_string(),
645 );
646 }
647 let user_hessian =
652 explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
653 let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
660 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
661 let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
662 &mode,
663 user_hessian,
664 bounded_columns,
665 n_total,
666 sqrt_cov_scale,
667 chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
668 )
669 .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
670
671 let posterior_mean = samples
672 .mean_axis(ndarray::Axis(0))
673 .unwrap_or_else(|| Array1::<f64>::zeros(p));
674 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
675
676 Ok(NutsResult {
677 samples,
678 posterior_mean,
679 posterior_std,
680 rhat: 1.0,
681 ess: n_total as f64,
682 converged: true,
683 })
684}
685
686fn sample_standard_truncated(
700 model: &SavedModel,
701 cfg: &NutsConfig,
702 constraints: &gam_solve::pirls::LinearInequalityConstraints,
703) -> Result<NutsResult, String> {
704 validate_nuts_config(cfg).map_err(String::from)?;
705 let fit = fit_result_from_saved_model_for_prediction(model)?;
706 let mode = fit.beta.clone();
707 let p = mode.len();
708 if p == 0 {
709 return Err(
710 "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
711 vector"
712 .to_string(),
713 );
714 }
715 let penalized_hessian =
721 explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
722 let sqrt_phi = {
723 use gam_problem::dispersion_cov::DispersionExt as _;
724 fit.dispersion().unwrap_or_default().sqrt_phi()
725 };
726 let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
727 &mode,
728 &penalized_hessian,
729 sqrt_phi,
730 constraints,
731 cfg.n_samples,
732 cfg.n_chains,
733 chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
734 )?;
735 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
736
737 let posterior_mean = samples
738 .mean_axis(ndarray::Axis(0))
739 .unwrap_or_else(|| Array1::<f64>::zeros(p));
740 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
741
742 Ok(NutsResult {
743 samples,
744 posterior_mean,
745 posterior_std,
746 rhat: 1.0,
747 ess: n_total as f64,
748 converged: true,
749 })
750}
751
752fn sample_standard_link_wiggle(
753 model: &SavedModel,
754 data: ArrayView2<'_, f64>,
755 col_map: &HashMap<String, usize>,
756 training_headers: Option<&Vec<String>>,
757 likelihood: LikelihoodSpec,
758 cfg: &NutsConfig,
759) -> Result<NutsResult, String> {
760 let parsed = parse_formula(&model.formula)?;
761 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
762 let y = data.column(y_col).to_owned();
763
764 let spec = resolve_termspec_for_prediction(
765 &model.resolved_termspec,
766 training_headers,
767 col_map,
768 "resolved_termspec",
769 )?;
770 let design = build_term_collection_design(data, &spec)
771 .map_err(|e| format!("failed to build term collection design: {e}"))?;
772 let p_main = design.design.ncols();
773
774 let fit = fit_result_from_saved_model_for_prediction(model)?;
775 let wiggle_runtime = model
776 .saved_prediction_runtime()?
777 .link_wiggle
778 .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
779 let mode_beta = fit
780 .block_by_role(BlockRole::Mean)
781 .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
782 .beta
783 .clone();
784 let mode_theta = fit
785 .block_by_role(BlockRole::LinkWiggle)
786 .ok_or_else(|| {
787 "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
788 })?
789 .beta
790 .clone();
791 let p_wiggle = mode_theta.len();
792 let p_total = mode_beta.len() + p_wiggle;
793
794 if mode_beta.len() != p_main {
795 return Err(format!(
796 "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
797 mode_beta.len(),
798 p_main,
799 ));
800 }
801 if fit.beta.len() != p_total {
802 return Err(format!(
803 "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
804 fit.beta.len(),
805 p_main,
806 p_wiggle,
807 p_total,
808 ));
809 }
810
811 let hessian = &fit
812 .geometry
813 .as_ref()
814 .ok_or_else(|| {
815 "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
816 .to_string()
817 })?
818 .penalized_hessian;
819 validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;
820
821 let n_base_penalties = design.penalties.len();
822 let base_lambdas = fit
823 .block_by_role(BlockRole::Mean)
824 .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
825 .lambdas
826 .view();
827 if base_lambdas.len() != n_base_penalties {
828 return Err(format!(
829 "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
830 base_lambdas.len(),
831 n_base_penalties,
832 ));
833 }
834
835 let penalty_base =
836 weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);
837
838 let wiggle_lambdas_owned = fit
839 .lambdas_linkwiggle()
840 .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
841 let wiggle_lambdas = wiggle_lambdas_owned.view();
842 let degree = wiggle_runtime.degree;
843 let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());
844
845 let mut wiggle_penalties = Vec::new();
846 let default_orders = [2usize];
847 let n_wiggle_lambdas = wiggle_lambdas.len();
848 for k in 0..n_wiggle_lambdas {
849 let order = if k < default_orders.len() {
850 default_orders[k]
851 } else {
852 k + 1
853 };
854 if order >= p_wiggle {
855 continue;
856 }
857 let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
858 .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
859 wiggle_penalties.push(penalty);
860 }
861 while wiggle_penalties.len() < n_wiggle_lambdas {
862 wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
863 }
864
865 let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;
866
867 let q0 = design.design.dot(&mode_beta);
868 let (q0_min, q0_max) = q0
869 .iter()
870 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
871 (lo.min(v), hi.max(v))
872 });
873
874 let spline = LinkWiggleSplineArtifacts {
875 knot_range: (q0_min, q0_max),
876 knot_vector: knot_arr,
877 degree,
878 };
879
880 let nuts_family = match (&likelihood.response, &likelihood.link) {
881 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
882 NutsFamily::BinomialLogit
883 }
884 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
885 NutsFamily::BinomialProbit
886 }
887 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
888 NutsFamily::BinomialCLogLog
889 }
890 (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
891 (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
892 (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
893 (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
894 (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
895 _ => {
896 return Err(format!(
897 "NUTS sampling with link wiggle is not supported for family {}",
898 likelihood.pretty_name()
899 ));
900 }
901 };
902
903 let weights = Array1::ones(data.nrows());
904 let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);
905
906 let wiggle_nuts_dense = design.design.as_dense_cow();
907 run_link_wiggle_nuts_sampling(
908 wiggle_nuts_dense.view(),
909 y.view(),
910 weights.view(),
911 penalty_base.view(),
912 penalty_link.view(),
913 mode_beta.view(),
914 mode_theta.view(),
915 hessian.view(),
916 spline,
917 nuts_family,
918 scale,
919 cfg,
920 )
921 .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
922}
923
/// Runs NUTS sampling for a saved survival model by rebuilding its flattened
/// working model from `data`.
///
/// Saved likelihood modes `Latent`, `LatentBinary` and `LocationScale` are
/// delegated to `laplace_gaussian_fallback`. For the remaining modes the
/// function rebuilds the entry/exit/derivative designs, the time offsets, the
/// penalty blocks and the Hessian at the saved coefficients, then calls
/// `run_survival_nuts_sampling_flattened`.
///
/// The coefficient columns are laid out as
/// `[time basis | baseline time-wiggle | covariates]`.
///
/// # Errors
///
/// Returns an error when required saved metadata is missing, when columns
/// cannot be resolved, when any of the reconstruction helpers fail, when the
/// saved survival spec is anything other than `"net"`, or when the sampler
/// fails.
fn sample_survival(
    model: &SavedModel,
    data: ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
    training_headers: Option<&Vec<String>>,
    cfg: &NutsConfig,
) -> Result<NutsResult, String> {
    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
    if matches!(
        saved_likelihood_mode,
        SurvivalLikelihoodMode::Latent
            | SurvivalLikelihoodMode::LatentBinary
            | SurvivalLikelihoodMode::LocationScale
    ) {
        return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
    }
    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
    let exit_col = time_cols.exit_col;
    let eventname = model
        .survival_event
        .as_ref()
        .ok_or_else(|| "survival model missing event column metadata".to_string())?;
    let event_col = resolve_role_col(col_map, eventname, "event")?;
    let termspec = resolve_termspec_for_prediction(
        &model.resolved_termspec,
        training_headers,
        col_map,
        "resolved_termspec",
    )?;
    // Use the clipped copy of the data for the covariate design when the model
    // returns one; otherwise use `data` as given.
    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
    let cov_design = build_term_collection_design(cov_input, &termspec)
        .map_err(|e| format!("failed to build survival design: {e}"))?;
    let n = data.nrows();
    let p_cov = cov_design.design.ncols();
    let mut age_entry = Array1::<f64>::zeros(n);
    let mut age_exit = Array1::<f64>::zeros(n);
    let mut event_target = Array1::<u8>::zeros(n);
    // No competing events and unit weights are supplied to the working model.
    let event_competing = Array1::<u8>::zeros(n);
    let weights = Array1::<f64>::ones(n);
    for i in 0..n {
        let (t0, t1) = normalize_survival_time_pair(
            time_cols.row_entry_time(data, i),
            data[[i, exit_col]],
            i,
        )?;
        age_entry[i] = t0;
        age_exit[i] = t1;
        // The event column is thresholded at 0.5 into a 0/1 indicator.
        event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
    }
    let time_cfg = load_survival_time_basis_config_from_model(model)?;
    let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
        &time_build.basisname,
        time_build.degree,
        time_build.knots.as_ref(),
        time_build.keep_cols.as_ref(),
        time_build.smooth_lambda,
    )?;
    // Marginal-slope models: centre the entry/exit time designs at the basis
    // row evaluated at the saved time anchor.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
        center_survival_time_designs_at_anchor(
            &mut time_build.x_entry_time,
            &mut time_build.x_exit_time,
            &time_anchor_row,
        )?;
    }
    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
    let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
        build_survival_time_offsets_for_likelihood(
            &age_entry,
            &age_exit,
            &baseline_cfg,
            saved_likelihood_mode,
            None,
        )?;
    // Marginal-slope models also add the derivative-guard offset to all three
    // offset vectors.
    if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
        let time_anchor = model
            .survival_time_anchor
            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
        add_survival_time_derivative_guard_offset(
            &age_entry,
            &age_exit,
            time_anchor,
            survival_derivative_guard_for_likelihood(saved_likelihood_mode),
            &mut eta_offset_entry,
            &mut eta_offset_exit,
            &mut derivative_offset_exit,
        )?;
    }
    let saved_timewiggle = saved_baseline_timewiggle_components(
        &eta_offset_entry,
        &eta_offset_exit,
        &derivative_offset_exit,
        model,
    )?;
    let p_time = time_build.x_exit_time.ncols();
    let p_timewiggle = saved_timewiggle
        .as_ref()
        .map(|(_, exit, _)| exit.ncols())
        .unwrap_or(0);
    let p = p_time + p_timewiggle + p_cov;
    let tb_entry_dense = time_build.x_entry_time.to_dense();
    let tb_exit_dense = time_build.x_exit_time.to_dense();
    let tb_deriv_dense = time_build.x_derivative_time.to_dense();
    // Assemble the joint designs column block by column block.
    let mut x_entry = Array2::<f64>::zeros((n, p));
    let mut x_exit = Array2::<f64>::zeros((n, p));
    let mut x_derivative = Array2::<f64>::zeros((n, p));
    if p_time > 0 {
        x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
        x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
        x_derivative
            .slice_mut(s![.., ..p_time])
            .assign(&tb_deriv_dense);
    }
    if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
        && p_timewiggle > 0
    {
        x_entry
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(entry_w);
        x_exit
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(exit_w);
        x_derivative
            .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
            .assign(deriv_w);
    }
    if p_cov > 0 {
        // The covariate block is written to the entry and exit designs only;
        // the matching columns of `x_derivative` stay zero.
        let cov_dense = cov_design.design.to_dense();
        let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
        x_entry
            .slice_mut(s![.., cov_range.clone()])
            .assign(&cov_dense);
        x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
    }
    // Time-basis penalties: only blocks whose shape matches the time basis are
    // kept. Lambdas set here are provisional and may be overwritten below.
    let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
    for (idx, s) in time_build.penalties.iter().enumerate() {
        if s.nrows() == p_time && s.ncols() == p_time {
            penalty_blocks.push(PenaltyBlock {
                matrix: s.clone(),
                lambda: time_build
                    .smooth_lambda
                    .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                range: 0..p_time,
                nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
            });
        }
    }
    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
    // Baseline time-wiggle penalties are rebuilt from the saved knots and the
    // stacked entry/exit offsets, then appended after the time-basis blocks.
    if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
        let start = p_time;
        let end = start + exit_w.ncols();
        let wiggle_lambda_offset = penalty_blocks.len();
        let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
            "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
        })?;
        let wiggle_degree = wiggle_cfg.degree;
        let wiggle_knots =
            Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
                "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
            })?);
        // Seed vector: entry offsets in the first `n` slots, exit offsets in
        // the last `n`.
        let mut seed = Array1::<f64>::zeros(2 * n);
        for i in 0..n {
            seed[i] = eta_offset_entry[i];
            seed[n + i] = eta_offset_exit[i];
        }
        let (primary_order, extra_orders) =
            split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
        let mut block = buildwiggle_block_input_from_knots(
            seed.view(),
            &wiggle_knots,
            wiggle_degree,
            primary_order,
            wiggle_cfg.double_penalty,
        )?;
        append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
            .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
        for (widx, s) in block.penalties.iter().enumerate() {
            let s = match s {
                gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
                gam_solve::estimate::PenaltySpec::Dense(m)
                | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
            };
            if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
                penalty_blocks.push(PenaltyBlock {
                    matrix: s.clone(),
                    lambda: time_build
                        .smooth_lambda
                        .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
                    range: start..end,
                    nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
                });
            }
        }
        // Overwrite the wiggle blocks' lambdas with the saved lambdas at the
        // matching positions, where such entries exist.
        for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
            .iter_mut()
            .enumerate()
        {
            if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
                block_penalty.lambda = lam;
            }
        }
    }
    let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
        "saved survival model is missing survivalridge_lambda; refusing to \
         pick a load-time default (the historical 1e-4 fallback silently \
         disagreed with the 1e-6 fit-time default). Refit."
            .to_string()
    })?;
    // The ridge skips the first coefficient for a "linear" time basis without
    // a baseline time-wiggle; otherwise it covers every coefficient.
    let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
    {
        1
    } else {
        0
    };
    if ridge_lambda > 0.0 && p > ridge_range_start {
        // Identity penalty over the ridge range.
        let dim = p - ridge_range_start;
        let mut ridge = Array2::<f64>::zeros((dim, dim));
        for d in 0..dim {
            ridge[[d, d]] = 1.0;
        }
        penalty_blocks.push(PenaltyBlock {
            matrix: ridge,
            lambda: ridge_lambda,
            range: ridge_range_start..p,
            nullspace_dim: 0,
        });
    }
    // Final positional pass: every block (including the ridge block) takes the
    // saved lambda at its index when the saved fit has one there.
    for (idx, block) in penalty_blocks.iter_mut().enumerate() {
        if let Some(&lam) = fit_saved.lambdas.get(idx) {
            block.lambda = lam;
        }
    }
    let penalties = PenaltyBlocks::new(penalty_blocks);
    // A missing saved spec defaults to "net"; matching is case-insensitive.
    let survivalspec = match model
        .survivalspec
        .as_deref()
        .unwrap_or("net")
        .to_ascii_lowercase()
        .as_str()
    {
        "net" => SurvivalSpec::Net,
        "crude" => {
            return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
                .to_string());
        }
        other => {
            return Err(format!("unsupported saved survival spec '{other}'"));
        }
    };
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    // Build the working model once to evaluate the Hessian at the saved
    // coefficients; the sampler below receives the same inputs.
    let mut model_surv = royston_parmar::working_model_from_flattened(
        penalties.clone(),
        monotonicity,
        survivalspec,
        RoystonParmarInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            monotonicity_constraint_rows: None,
            monotonicity_constraint_offsets: None,
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
    )
    .map_err(|e| format!("failed to construct survival model: {e}"))?;
    // Structural monotonicity is enabled for every mode except Weibull, over
    // the time and time-wiggle columns.
    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
        model_surv
            .set_structural_monotonicity(true, p_time + p_timewiggle)
            .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
    }
    let beta0 = fit_saved.beta.clone();
    let state = model_surv
        .update_state(&beta0)
        .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
    let hessian = state.hessian.to_dense();
    run_survival_nuts_sampling_flattened(
        SurvivalFlatInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            weights: weights.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            eta_offset_entry: Some(eta_offset_entry.view()),
            eta_offset_exit: Some(eta_offset_exit.view()),
            derivative_offset_exit: Some(derivative_offset_exit.view()),
        },
        penalties,
        monotonicity,
        survivalspec,
        saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
        p_time + p_timewiggle,
        beta0.view(),
        hessian.view(),
        cfg,
    )
    .map_err(|e| format!("survival NUTS sampling failed: {e}"))
}
1243
#[cfg(test)]
mod tests {
    //! Unit tests for `refresh_negbin_theta_for_sampling`: they check which
    //! negative-binomial theta ends up on the likelihood's response after the
    //! refresh, for each kind of scale metadata exercised below.

    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    // An NB likelihood seeded with theta = 1.0 and refreshed against an
    // estimated theta of 2.97 must carry 2.97 afterwards.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        let mut spec = LikelihoodSpec::negative_binomial_log(1.0);

        refresh_negbin_theta_for_sampling(
            &mut spec,
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
        );

        let ResponseFamily::NegativeBinomial { theta, .. } = &spec.response else {
            let other = &spec.response;
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(
            *theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    // A fixed-theta NB likelihood refreshed against matching fixed-theta
    // metadata must keep both its theta value and its `theta_fixed` flag.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let mut spec = LikelihoodSpec::negative_binomial_log_fixed(4.25);

        refresh_negbin_theta_for_sampling(
            &mut spec,
            LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 },
        );

        let ResponseFamily::NegativeBinomial { theta, theta_fixed } = &spec.response else {
            let other = &spec.response;
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(*theta, 4.25, "fixed NB theta must survive the refresh");
        assert!(*theta_fixed, "theta_fixed flag must be preserved");
    }

    // When the scale metadata is `ProfiledGaussian` (i.e. it supplies no NB
    // theta), the seed value of 3.5 must remain on the response.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let mut spec = LikelihoodSpec::negative_binomial_log(3.5);
        let unfitted_scale = LikelihoodScaleMetadata::ProfiledGaussian;

        refresh_negbin_theta_for_sampling(&mut spec, unfitted_scale);

        let ResponseFamily::NegativeBinomial { theta, .. } = &spec.response else {
            let other = &spec.response;
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(
            *theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    // A Poisson likelihood must come out of the refresh with a response equal
    // to the one it went in with, even when NB theta metadata is supplied.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut spec = LikelihoodSpec::poisson_log();
        let snapshot = spec.response.clone();
        let nb_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 };

        refresh_negbin_theta_for_sampling(&mut spec, nb_scale);

        assert_eq!(
            spec.response, snapshot,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}