1use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18 FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19 explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20 run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27 SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28 build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29 evaluate_survival_time_basis_row, normalize_survival_time_pair,
30 resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33 fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34 resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35 saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39 PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42 append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43 split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47 FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55pub fn saved_baseline_timewiggle_spec(
60 model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62 model
63 .saved_baseline_time_wiggle()
64 .map_err(|e| e.to_string())
65 .map(|runtime| {
66 runtime.map(|saved| LinkWiggleFormulaSpec {
67 degree: saved.degree,
68 num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69 penalty_orders: saved.penalty_orders,
70 double_penalty: saved.double_penalty,
71 })
72 })
73}
74
75fn weighted_penalty_matrix(
76 penalties: &[Array2<f64>],
77 lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79 if penalties.len() != lambdas.len() {
80 return Err(format!(
81 "penalty/lambda mismatch: {} penalties vs {} lambdas",
82 penalties.len(),
83 lambdas.len()
84 ));
85 }
86 if penalties.is_empty() {
87 return Err("cannot sample without at least one penalty block".to_string());
88 }
89 let p = penalties[0].nrows();
90 let mut out = Array2::<f64>::zeros((p, p));
91 for (k, s) in penalties.iter().enumerate() {
92 if s.nrows() != p || s.ncols() != p {
93 return Err(format!(
94 "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95 s.nrows(),
96 s.ncols(),
97 p,
98 p
99 ));
100 }
101 let lam = lambdas[k];
102 out += &(s * lam);
103 }
104 Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108 hessian: &Array2<f64>,
109 expected_dim: usize,
110) -> Result<(), String> {
111 if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112 return Err(format!(
113 "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114 hessian.nrows(),
115 hessian.ncols(),
116 expected_dim,
117 expected_dim,
118 ));
119 }
120 validate_all_finite(
121 "link-wiggle explicit joint Hessian",
122 hessian.iter().copied(),
123 )?;
124 let mut max_abs = 0.0_f64;
125 for r in 0..expected_dim {
126 for c in 0..expected_dim {
127 max_abs = max_abs.max(hessian[[r, c]].abs());
128 let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129 if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130 return Err(format!(
131 "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132 ));
133 }
134 }
135 }
136 if max_abs == 0.0 {
137 return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138 .to_string());
139 }
140 Ok(())
141}
142
143fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
152 crate::generative::family_noise_parameter(
153 fit.likelihood_scale,
154 fit.standard_deviation,
155 likelihood,
156 )
157}
158
159fn refresh_negbin_theta_for_sampling(
177 likelihood: &mut LikelihoodSpec,
178 scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180 if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181 if let Some(theta_hat) = scale.negbin_theta() {
182 *theta = theta_hat;
183 }
184 }
185}
186
187fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
191 Ok(model.likelihood())
192}
193
/// Smoothing parameter assigned to a reconstructed penalty block when the
/// rebuilt survival time basis does not report a `smooth_lambda` of its own.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 0.01;
199
200#[inline]
201const fn splitmix64(x: u64) -> u64 {
202 gam_linalg::utils::splitmix64_hash(x)
203}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207 splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210pub fn sample_saved_model(
229 model: &SavedModel,
230 data: ArrayView2<'_, f64>,
231 col_map: &HashMap<String, usize>,
232 training_headers: Option<&Vec<String>>,
233 cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235 validate_nuts_config(cfg).map_err(String::from)?;
244 let likelihood = likelihood_spec_for_saved_model(model)?;
245 match model.predict_model_class() {
246 PredictModelClass::Survival => {
247 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253 if matches!(
254 saved_likelihood_mode,
255 SurvivalLikelihoodMode::Latent
256 | SurvivalLikelihoodMode::LatentBinary
257 | SurvivalLikelihoodMode::LocationScale
258 ) {
259 laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260 } else {
261 sample_survival(model, data, col_map, training_headers, cfg)
262 }
263 }
264 PredictModelClass::Standard => {
265 if matches!(likelihood.response, ResponseFamily::Beta { .. }) {
274 laplace_gaussian_fallback(model, cfg, "beta-regression posterior fallback")
275 } else {
276 sample_standard(model, data, col_map, training_headers, likelihood, cfg)
277 }
278 }
279 PredictModelClass::GaussianLocationScale => {
289 laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
290 }
291 PredictModelClass::BinomialLocationScale => {
292 laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
293 }
294 PredictModelClass::DispersionLocationScale => {
295 laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
296 }
297 PredictModelClass::BernoulliMarginalSlope => {
298 laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
299 }
300 PredictModelClass::TransformationNormal => {
301 laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
302 }
303 }
304}
305
306pub fn laplace_gaussian_fallback(
320 model: &SavedModel,
321 cfg: &NutsConfig,
322 rationale: &'static str,
323) -> Result<NutsResult, String> {
324 validate_nuts_config(cfg).map_err(String::from)?;
329 let fit = fit_result_from_saved_model_for_prediction(model)?;
330 let mode = fit.beta.clone();
331 let p = mode.len();
332 if p == 0 {
333 return Err(format!(
334 "{rationale}: cannot sample from an empty coefficient vector"
335 ));
336 }
337 let h = fit.penalized_hessian().ok_or_else(|| {
338 format!(
339 "{rationale}: posterior fallback requires the explicit penalised Hessian; \
340 refit with exact geometry export to enable posterior sampling for this class."
341 )
342 })?;
343 let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
360 if h.nrows() != p || h.ncols() != p {
361 return Err(format!(
362 "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
363 h.nrows(),
364 h.ncols(),
365 p,
366 p
367 ));
368 }
369 let chol = h.cholesky(Side::Lower).map_err(|err| {
370 format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
371 })?;
372 let l = chol.lower_triangular();
373
374 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
378 let mut samples = Array2::<f64>::zeros((n_total, p));
379 let mut eps = Array1::<f64>::zeros(p);
380 let mut delta = Array1::<f64>::zeros(p);
381 for chain in 0..cfg.n_chains {
382 let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
383 cfg.seed,
384 chain,
385 0xA0B7_6C5D_E431_298F,
386 ));
387 for draw in 0..cfg.n_samples {
388 let k = chain * cfg.n_samples + draw;
389 for i in 0..p {
390 eps[i] = sample_standard_normal(&mut rng);
391 }
392 back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
393 for i in 0..p {
394 samples[(k, i)] = mode[i] + sqrt_cov_scale * delta[i];
399 }
400 }
401 }
402
403 let posterior_mean = samples
404 .mean_axis(ndarray::Axis(0))
405 .unwrap_or_else(|| Array1::<f64>::zeros(p));
406 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
407
408 Ok(NutsResult {
409 samples,
410 posterior_mean,
411 posterior_std,
412 rhat: 1.0,
413 ess: n_total as f64,
414 converged: true,
415 })
416}
417
418#[inline]
419fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
420 let u1 = rng.random::<f64>().max(1e-16);
424 let u2 = rng.random::<f64>();
425 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
426}
427
428fn sample_standard(
429 model: &SavedModel,
430 data: ArrayView2<'_, f64>,
431 col_map: &HashMap<String, usize>,
432 training_headers: Option<&Vec<String>>,
433 mut likelihood: LikelihoodSpec,
434 cfg: &NutsConfig,
435) -> Result<NutsResult, String> {
436 let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
452 ts.linear_terms.iter().any(|term| {
453 !matches!(
454 term.coefficient_geometry,
455 LinearCoefficientGeometry::Unconstrained
456 ) || term.coefficient_min.is_some()
457 || term.coefficient_max.is_some()
458 }) || ts
459 .smooth_terms
460 .iter()
461 .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
462 });
463 if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
464 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
465 }
466 if model.has_link_wiggle() {
467 if likelihood.is_gaussian_identity() {
475 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
476 }
477 return sample_standard_link_wiggle(
478 model,
479 data,
480 col_map,
481 training_headers,
482 likelihood,
483 cfg,
484 );
485 }
486 let parsed = parse_formula(&model.formula)?;
487 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
488 let y = data.column(y_col).to_owned();
489 let spec = resolve_termspec_for_prediction(
490 &model.resolved_termspec,
491 training_headers,
492 col_map,
493 "resolved_termspec",
494 )?;
495 let design = build_term_collection_design(data, &spec)
496 .map_err(|e| format!("failed to build term collection design: {e}"))?;
497
498 let has_bounded = spec.linear_terms.iter().any(|term| {
527 matches!(
528 term.coefficient_geometry,
529 LinearCoefficientGeometry::Bounded { .. }
530 )
531 });
532 if has_bounded {
533 let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
538 .linear_terms
539 .iter()
540 .enumerate()
541 .filter_map(|(j, term)| match term.coefficient_geometry {
542 LinearCoefficientGeometry::Bounded { min, max, .. } => {
543 Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
544 col_idx: design.intercept_range.end + j,
545 min,
546 max,
547 })
548 }
549 LinearCoefficientGeometry::Unconstrained => None,
550 })
551 .collect();
552 return sample_standard_bounded(model, cfg, &bounded_columns);
553 }
554
555 if let Some(constraints) = design
565 .linear_constraints
566 .as_ref()
567 .filter(|c| c.a.nrows() > 0)
568 {
569 return sample_standard_truncated(model, cfg, constraints);
570 }
571
572 if likelihood.is_gaussian_identity() {
574 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
575 }
576
577 let weights = Array1::ones(data.nrows());
579 let dense_design_hmc = design.design.to_dense();
580 let p = dense_design_hmc.ncols();
581 let fit = fit_result_from_saved_model_for_prediction(model)?;
582 refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
593 if fit.beta.len() != p {
594 return Err(format!(
595 "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
596 fit.beta.len(),
597 p,
598 ));
599 }
600 if fit.lambdas.len() != design.penalties.len() {
601 return Err(format!(
602 "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
603 fit.lambdas.len(),
604 design.penalties.len(),
605 ));
606 }
607 let penalty =
608 weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
609
610 let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
615 Some(name) => {
616 let idx = resolve_role_col(col_map, name, "offset")?;
617 Some(data.column(idx).to_owned())
618 }
619 None => None,
620 };
621
622 run_nuts_sampling_flattened_family(
623 likelihood,
624 FamilyNutsInputs::Glm(GlmFlatInputs {
625 x: dense_design_hmc.view(),
626 y: y.view(),
627 weights: weights.view(),
628 penalty_matrix: penalty.view(),
629 mode: fit.beta.view(),
630 hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
631 gamma_shape: fit.likelihood_scale.gamma_shape(),
632 dispersion: fit.dispersion().unwrap_or_default(),
636 firth_bias_reduction: false,
637 offset: offset_vec.as_ref().map(|o| o.view()),
638 }),
639 cfg,
640 )
641 .map_err(|e| format!("NUTS sampling failed: {e}"))
642}
643
644fn sample_standard_bounded(
654 model: &SavedModel,
655 cfg: &NutsConfig,
656 bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
657) -> Result<NutsResult, String> {
658 validate_nuts_config(cfg).map_err(String::from)?;
659 let fit = fit_result_from_saved_model_for_prediction(model)?;
660 let mode = fit.beta.clone();
661 let p = mode.len();
662 if p == 0 {
663 return Err(
664 "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
665 .to_string(),
666 );
667 }
668 let user_hessian =
673 explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
674 let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
681 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
682 let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
683 &mode,
684 user_hessian,
685 bounded_columns,
686 n_total,
687 sqrt_cov_scale,
688 chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
689 )
690 .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
691
692 let posterior_mean = samples
693 .mean_axis(ndarray::Axis(0))
694 .unwrap_or_else(|| Array1::<f64>::zeros(p));
695 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
696
697 Ok(NutsResult {
698 samples,
699 posterior_mean,
700 posterior_std,
701 rhat: 1.0,
702 ess: n_total as f64,
703 converged: true,
704 })
705}
706
707fn sample_standard_truncated(
721 model: &SavedModel,
722 cfg: &NutsConfig,
723 constraints: &gam_solve::pirls::LinearInequalityConstraints,
724) -> Result<NutsResult, String> {
725 validate_nuts_config(cfg).map_err(String::from)?;
726 let fit = fit_result_from_saved_model_for_prediction(model)?;
727 let mode = fit.beta.clone();
728 let p = mode.len();
729 if p == 0 {
730 return Err(
731 "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
732 vector"
733 .to_string(),
734 );
735 }
736 let penalized_hessian =
742 explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
743 let sqrt_phi = {
744 use gam_problem::dispersion_cov::DispersionExt as _;
745 fit.dispersion().unwrap_or_default().sqrt_phi()
746 };
747 let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
748 &mode,
749 &penalized_hessian,
750 sqrt_phi,
751 constraints,
752 cfg.n_samples,
753 cfg.n_chains,
754 chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
755 )?;
756 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
757
758 let posterior_mean = samples
759 .mean_axis(ndarray::Axis(0))
760 .unwrap_or_else(|| Array1::<f64>::zeros(p));
761 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
762
763 Ok(NutsResult {
764 samples,
765 posterior_mean,
766 posterior_std,
767 rhat: 1.0,
768 ess: n_total as f64,
769 converged: true,
770 })
771}
772
773fn sample_standard_link_wiggle(
774 model: &SavedModel,
775 data: ArrayView2<'_, f64>,
776 col_map: &HashMap<String, usize>,
777 training_headers: Option<&Vec<String>>,
778 likelihood: LikelihoodSpec,
779 cfg: &NutsConfig,
780) -> Result<NutsResult, String> {
781 let parsed = parse_formula(&model.formula)?;
782 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
783 let y = data.column(y_col).to_owned();
784
785 let spec = resolve_termspec_for_prediction(
786 &model.resolved_termspec,
787 training_headers,
788 col_map,
789 "resolved_termspec",
790 )?;
791 let design = build_term_collection_design(data, &spec)
792 .map_err(|e| format!("failed to build term collection design: {e}"))?;
793 let p_main = design.design.ncols();
794
795 let fit = fit_result_from_saved_model_for_prediction(model)?;
796 let wiggle_runtime = model
797 .saved_prediction_runtime()?
798 .link_wiggle
799 .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
800 let mode_beta = fit
801 .block_by_role(BlockRole::Mean)
802 .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
803 .beta
804 .clone();
805 let mode_theta = fit
806 .block_by_role(BlockRole::LinkWiggle)
807 .ok_or_else(|| {
808 "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
809 })?
810 .beta
811 .clone();
812 let p_wiggle = mode_theta.len();
813 let p_total = mode_beta.len() + p_wiggle;
814
815 if mode_beta.len() != p_main {
816 return Err(format!(
817 "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
818 mode_beta.len(),
819 p_main,
820 ));
821 }
822 if fit.beta.len() != p_total {
823 return Err(format!(
824 "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
825 fit.beta.len(),
826 p_main,
827 p_wiggle,
828 p_total,
829 ));
830 }
831
832 let hessian = &fit
833 .geometry
834 .as_ref()
835 .ok_or_else(|| {
836 "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
837 .to_string()
838 })?
839 .penalized_hessian;
840 validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;
841
842 let n_base_penalties = design.penalties.len();
843 let base_lambdas = fit
844 .block_by_role(BlockRole::Mean)
845 .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
846 .lambdas
847 .view();
848 if base_lambdas.len() != n_base_penalties {
849 return Err(format!(
850 "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
851 base_lambdas.len(),
852 n_base_penalties,
853 ));
854 }
855
856 let penalty_base =
857 weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);
858
859 let wiggle_lambdas_owned = fit
860 .lambdas_linkwiggle()
861 .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
862 let wiggle_lambdas = wiggle_lambdas_owned.view();
863 let degree = wiggle_runtime.degree;
864 let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());
865
866 let mut wiggle_penalties = Vec::new();
867 let default_orders = [2usize];
868 let n_wiggle_lambdas = wiggle_lambdas.len();
869 for k in 0..n_wiggle_lambdas {
870 let order = if k < default_orders.len() {
871 default_orders[k]
872 } else {
873 k + 1
874 };
875 if order >= p_wiggle {
876 continue;
877 }
878 let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
879 .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
880 wiggle_penalties.push(penalty);
881 }
882 while wiggle_penalties.len() < n_wiggle_lambdas {
883 wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
884 }
885
886 let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;
887
888 let q0 = design.design.dot(&mode_beta);
889 let (q0_min, q0_max) = q0
890 .iter()
891 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
892 (lo.min(v), hi.max(v))
893 });
894
895 let spline = LinkWiggleSplineArtifacts {
896 knot_range: (q0_min, q0_max),
897 knot_vector: knot_arr,
898 degree,
899 };
900
901 let nuts_family = match (&likelihood.response, &likelihood.link) {
902 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
903 NutsFamily::BinomialLogit
904 }
905 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
906 NutsFamily::BinomialProbit
907 }
908 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
909 NutsFamily::BinomialCLogLog
910 }
911 (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
912 (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
913 (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
914 (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
915 (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
916 _ => {
917 return Err(format!(
918 "NUTS sampling with link wiggle is not supported for family {}",
919 likelihood.pretty_name()
920 ));
921 }
922 };
923
924 let weights = Array1::ones(data.nrows());
925 let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);
926
927 let wiggle_nuts_dense = design.design.as_dense_cow();
928 run_link_wiggle_nuts_sampling(
929 wiggle_nuts_dense.view(),
930 y.view(),
931 weights.view(),
932 penalty_base.view(),
933 penalty_link.view(),
934 mode_beta.view(),
935 mode_theta.view(),
936 hessian.view(),
937 spline,
938 nuts_family,
939 scale,
940 cfg,
941 )
942 .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
943}
944
945fn sample_survival(
946 model: &SavedModel,
947 data: ArrayView2<'_, f64>,
948 col_map: &HashMap<String, usize>,
949 training_headers: Option<&Vec<String>>,
950 cfg: &NutsConfig,
951) -> Result<NutsResult, String> {
952 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
953 if matches!(
954 saved_likelihood_mode,
955 SurvivalLikelihoodMode::Latent
956 | SurvivalLikelihoodMode::LatentBinary
957 | SurvivalLikelihoodMode::LocationScale
958 ) {
959 return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
960 }
961 let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
970 let exit_col = time_cols.exit_col;
971 let eventname = model
972 .survival_event
973 .as_ref()
974 .ok_or_else(|| "survival model missing event column metadata".to_string())?;
975 let event_col = resolve_role_col(col_map, eventname, "event")?;
976 let termspec = resolve_termspec_for_prediction(
977 &model.resolved_termspec,
978 training_headers,
979 col_map,
980 "resolved_termspec",
981 )?;
982 let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
983 let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
984 let cov_design = build_term_collection_design(cov_input, &termspec)
985 .map_err(|e| format!("failed to build survival design: {e}"))?;
986 let n = data.nrows();
987 let p_cov = cov_design.design.ncols();
988 let mut age_entry = Array1::<f64>::zeros(n);
989 let mut age_exit = Array1::<f64>::zeros(n);
990 let mut event_target = Array1::<u8>::zeros(n);
991 let event_competing = Array1::<u8>::zeros(n);
992 let weights = Array1::<f64>::ones(n);
993 for i in 0..n {
994 let (t0, t1) = normalize_survival_time_pair(
995 time_cols.row_entry_time(data, i),
996 data[[i, exit_col]],
997 i,
998 )?;
999 age_entry[i] = t0;
1000 age_exit[i] = t1;
1001 event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
1002 }
1003 let time_cfg = load_survival_time_basis_config_from_model(model)?;
1004 let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
1005 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
1006 &time_build.basisname,
1007 time_build.degree,
1008 time_build.knots.as_ref(),
1009 time_build.keep_cols.as_ref(),
1010 time_build.smooth_lambda,
1011 )?;
1012 if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1013 let time_anchor = model
1014 .survival_time_anchor
1015 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1016 let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
1017 center_survival_time_designs_at_anchor(
1018 &mut time_build.x_entry_time,
1019 &mut time_build.x_exit_time,
1020 &time_anchor_row,
1021 )?;
1022 }
1023 let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
1024 let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
1025 build_survival_time_offsets_for_likelihood(
1026 &age_entry,
1027 &age_exit,
1028 &baseline_cfg,
1029 saved_likelihood_mode,
1030 None,
1031 )?;
1032 if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1033 let time_anchor = model
1034 .survival_time_anchor
1035 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1036 add_survival_time_derivative_guard_offset(
1037 &age_entry,
1038 &age_exit,
1039 time_anchor,
1040 survival_derivative_guard_for_likelihood(saved_likelihood_mode),
1041 &mut eta_offset_entry,
1042 &mut eta_offset_exit,
1043 &mut derivative_offset_exit,
1044 )?;
1045 }
1046 let saved_timewiggle = saved_baseline_timewiggle_components(
1047 &eta_offset_entry,
1048 &eta_offset_exit,
1049 &derivative_offset_exit,
1050 model,
1051 )?;
1052 let p_time = time_build.x_exit_time.ncols();
1053 let p_timewiggle = saved_timewiggle
1054 .as_ref()
1055 .map(|(_, exit, _)| exit.ncols())
1056 .unwrap_or(0);
1057 let p = p_time + p_timewiggle + p_cov;
1058 let tb_entry_dense = time_build.x_entry_time.to_dense();
1059 let tb_exit_dense = time_build.x_exit_time.to_dense();
1060 let tb_deriv_dense = time_build.x_derivative_time.to_dense();
1061 let mut x_entry = Array2::<f64>::zeros((n, p));
1062 let mut x_exit = Array2::<f64>::zeros((n, p));
1063 let mut x_derivative = Array2::<f64>::zeros((n, p));
1064 if p_time > 0 {
1065 x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
1066 x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
1067 x_derivative
1068 .slice_mut(s![.., ..p_time])
1069 .assign(&tb_deriv_dense);
1070 }
1071 if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
1072 && p_timewiggle > 0
1073 {
1074 x_entry
1075 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1076 .assign(entry_w);
1077 x_exit
1078 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1079 .assign(exit_w);
1080 x_derivative
1081 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1082 .assign(deriv_w);
1083 }
1084 if p_cov > 0 {
1085 let cov_dense = cov_design.design.to_dense();
1086 let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
1087 x_entry
1088 .slice_mut(s![.., cov_range.clone()])
1089 .assign(&cov_dense);
1090 x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
1091 }
1092 let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
1093 for (idx, s) in time_build.penalties.iter().enumerate() {
1094 if s.nrows() == p_time && s.ncols() == p_time {
1095 penalty_blocks.push(PenaltyBlock {
1096 matrix: s.clone(),
1097 lambda: time_build
1098 .smooth_lambda
1099 .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1100 range: 0..p_time,
1101 nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
1102 });
1103 }
1104 }
1105 let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1106 if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
1107 let start = p_time;
1108 let end = start + exit_w.ncols();
1109 let wiggle_lambda_offset = penalty_blocks.len();
1110 let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
1111 "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
1112 })?;
1113 let wiggle_degree = wiggle_cfg.degree;
1114 let wiggle_knots =
1115 Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
1116 "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
1117 })?);
1118 let mut seed = Array1::<f64>::zeros(2 * n);
1119 for i in 0..n {
1120 seed[i] = eta_offset_entry[i];
1121 seed[n + i] = eta_offset_exit[i];
1122 }
1123 let (primary_order, extra_orders) =
1124 split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
1125 let mut block = buildwiggle_block_input_from_knots(
1126 seed.view(),
1127 &wiggle_knots,
1128 wiggle_degree,
1129 primary_order,
1130 wiggle_cfg.double_penalty,
1131 )?;
1132 append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
1133 .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
1134 for (widx, s) in block.penalties.iter().enumerate() {
1135 let s = match s {
1136 gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
1137 gam_solve::estimate::PenaltySpec::Dense(m)
1138 | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
1139 };
1140 if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
1141 penalty_blocks.push(PenaltyBlock {
1142 matrix: s.clone(),
1143 lambda: time_build
1144 .smooth_lambda
1145 .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1146 range: start..end,
1147 nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
1148 });
1149 }
1150 }
1151 for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
1152 .iter_mut()
1153 .enumerate()
1154 {
1155 if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
1156 block_penalty.lambda = lam;
1157 }
1158 }
1159 }
1160 let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
1161 "saved survival model is missing survivalridge_lambda; refusing to \
1162 pick a load-time default (the historical 1e-4 fallback silently \
1163 disagreed with the 1e-6 fit-time default). Refit."
1164 .to_string()
1165 })?;
1166 let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
1167 {
1168 1
1169 } else {
1170 0
1171 };
1172 if ridge_lambda > 0.0 && p > ridge_range_start {
1173 let dim = p - ridge_range_start;
1174 let mut ridge = Array2::<f64>::zeros((dim, dim));
1175 for d in 0..dim {
1176 ridge[[d, d]] = 1.0;
1177 }
1178 penalty_blocks.push(PenaltyBlock {
1179 matrix: ridge,
1180 lambda: ridge_lambda,
1181 range: ridge_range_start..p,
1182 nullspace_dim: 0,
1183 });
1184 }
1185 for (idx, block) in penalty_blocks.iter_mut().enumerate() {
1186 if let Some(&lam) = fit_saved.lambdas.get(idx) {
1187 block.lambda = lam;
1188 }
1189 }
1190 let penalties = PenaltyBlocks::new(penalty_blocks);
1191 let survivalspec = match model
1192 .survivalspec
1193 .as_deref()
1194 .unwrap_or("net")
1195 .to_ascii_lowercase()
1196 .as_str()
1197 {
1198 "net" => SurvivalSpec::Net,
1199 "crude" => {
1200 return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
1201 .to_string());
1202 }
1203 other => {
1204 return Err(format!("unsupported saved survival spec '{other}'"));
1205 }
1206 };
1207 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
1208 let mut model_surv = royston_parmar::working_model_from_flattened(
1209 penalties.clone(),
1210 monotonicity,
1211 survivalspec,
1212 RoystonParmarInputs {
1213 age_entry: age_entry.view(),
1214 age_exit: age_exit.view(),
1215 event_target: event_target.view(),
1216 event_competing: event_competing.view(),
1217 weights: weights.view(),
1218 x_entry: x_entry.view(),
1219 x_exit: x_exit.view(),
1220 x_derivative: x_derivative.view(),
1221 monotonicity_constraint_rows: None,
1222 monotonicity_constraint_offsets: None,
1223 eta_offset_entry: Some(eta_offset_entry.view()),
1224 eta_offset_exit: Some(eta_offset_exit.view()),
1225 derivative_offset_exit: Some(derivative_offset_exit.view()),
1226 },
1227 )
1228 .map_err(|e| format!("failed to construct survival model: {e}"))?;
1229 if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
1230 model_surv
1231 .set_structural_monotonicity(true, p_time + p_timewiggle)
1232 .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
1233 }
1234 let beta0 = fit_saved.beta.clone();
1235 let state = model_surv
1236 .update_state(&beta0)
1237 .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
1238 let hessian = state.hessian.to_dense();
1239 run_survival_nuts_sampling_flattened(
1240 SurvivalFlatInputs {
1241 age_entry: age_entry.view(),
1242 age_exit: age_exit.view(),
1243 event_target: event_target.view(),
1244 event_competing: event_competing.view(),
1245 weights: weights.view(),
1246 x_entry: x_entry.view(),
1247 x_exit: x_exit.view(),
1248 x_derivative: x_derivative.view(),
1249 eta_offset_entry: Some(eta_offset_entry.view()),
1250 eta_offset_exit: Some(eta_offset_exit.view()),
1251 derivative_offset_exit: Some(derivative_offset_exit.view()),
1252 },
1253 penalties,
1254 monotonicity,
1255 survivalspec,
1256 saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
1257 p_time + p_timewiggle,
1258 beta0.view(),
1259 hessian.view(),
1260 cfg,
1261 )
1262 .map_err(|e| format!("survival NUTS sampling failed: {e}"))
1263}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    /// When the scale metadata carries an estimated NB theta, the refresh must
    /// overwrite the seed theta the likelihood spec was constructed with.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        let mut spec = LikelihoodSpec::negative_binomial_log(1.0);

        refresh_negbin_theta_for_sampling(
            &mut spec,
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
        );

        // Borrow the response so the mismatch branch can still print it.
        let other = &spec.response;
        let ResponseFamily::NegativeBinomial { theta, .. } = other else {
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(
            *theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    /// A fixed-theta NB spec refreshed with matching fixed-theta metadata must
    /// keep both its theta value and its `theta_fixed` flag.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let mut spec = LikelihoodSpec::negative_binomial_log_fixed(4.25);

        refresh_negbin_theta_for_sampling(
            &mut spec,
            LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 },
        );

        let other = &spec.response;
        let ResponseFamily::NegativeBinomial { theta, theta_fixed } = other else {
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(*theta, 4.25, "fixed NB theta must survive the refresh");
        assert!(*theta_fixed, "theta_fixed flag must be preserved");
    }

    /// Scale metadata that carries no NB theta (here `ProfiledGaussian`) must
    /// leave the seed theta exactly as constructed.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let mut spec = LikelihoodSpec::negative_binomial_log(3.5);

        refresh_negbin_theta_for_sampling(&mut spec, LikelihoodScaleMetadata::ProfiledGaussian);

        let other = &spec.response;
        let ResponseFamily::NegativeBinomial { theta, .. } = other else {
            panic!("expected NegativeBinomial response, got {other:?}");
        };
        assert_eq!(
            *theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    /// A non-NB family (Poisson) must come out of the refresh unchanged even
    /// when the metadata carries an estimated NB theta.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut spec = LikelihoodSpec::poisson_log();
        // Snapshot the response before the refresh for the equality check.
        let snapshot = spec.response.clone();

        refresh_negbin_theta_for_sampling(
            &mut spec,
            LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 },
        );

        assert_eq!(
            spec.response, snapshot,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}