1use std::collections::HashMap;
12
13use faer::Side;
14use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
15use rand::{RngExt, SeedableRng};
16
17use super::hmc_io::{
18 FamilyNutsInputs, GlmFlatInputs, LinkWiggleSplineArtifacts, NutsFamily, SurvivalFlatInputs,
19 explicit_fit_hessian_for_whitening, run_link_wiggle_nuts_sampling,
20 run_nuts_sampling_flattened_family, run_survival_nuts_sampling_flattened, validate_nuts_config,
21};
22pub use super::hmc_io::{NutsConfig, NutsResult};
23use gam_terms::basis::create_difference_penalty_matrix;
24use gam_solve::estimate::{BlockRole, UnifiedFitResult, validate_all_finite};
25use gam_linalg::faer_ndarray::FaerCholesky;
26use gam_models::survival::construction::{
27 SurvivalLikelihoodMode, add_survival_time_derivative_guard_offset, build_survival_time_basis,
28 build_survival_time_offsets_for_likelihood, center_survival_time_designs_at_anchor,
29 evaluate_survival_time_basis_row, normalize_survival_time_pair,
30 resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
31};
32use gam_models::survival::predict::{
33 fit_result_from_saved_model_for_prediction, require_saved_survival_likelihood_mode,
34 resolve_saved_survival_time_columns, resolve_termspec_for_prediction,
35 saved_baseline_timewiggle_components, saved_survival_runtime_baseline_config,
36};
37use gam_models::survival::royston_parmar::{self, RoystonParmarInputs};
38use gam_models::survival::{
39 PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
40};
41use gam_models::wiggle::{
42 append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_knots,
43 split_wiggle_penalty_orders,
44};
45use crate::formula_dsl::{LinkWiggleFormulaSpec, parse_formula};
46use crate::model::{
47 FittedModel as SavedModel, PredictModelClass, load_survival_time_basis_config_from_model,
48};
49use gam_linalg::triangular::back_substitution_lower_transpose_guarded_into;
50use gam_terms::smooth::build_term_collection_design;
51use gam_terms::smooth::{LinearCoefficientGeometry, weighted_blockwise_penalty_sum};
52use gam_terms::term_builder::resolve_role_col;
53use gam_problem::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
54
55pub fn saved_baseline_timewiggle_spec(
60 model: &SavedModel,
61) -> Result<Option<LinkWiggleFormulaSpec>, String> {
62 model
63 .saved_baseline_time_wiggle()
64 .map_err(|e| e.to_string())
65 .map(|runtime| {
66 runtime.map(|saved| LinkWiggleFormulaSpec {
67 degree: saved.degree,
68 num_internal_knots: saved.knots.len().saturating_sub(2 * (saved.degree + 1)),
69 penalty_orders: saved.penalty_orders,
70 double_penalty: saved.double_penalty,
71 })
72 })
73}
74
75fn weighted_penalty_matrix(
76 penalties: &[Array2<f64>],
77 lambdas: ArrayView1<'_, f64>,
78) -> Result<Array2<f64>, String> {
79 if penalties.len() != lambdas.len() {
80 return Err(format!(
81 "penalty/lambda mismatch: {} penalties vs {} lambdas",
82 penalties.len(),
83 lambdas.len()
84 ));
85 }
86 if penalties.is_empty() {
87 return Err("cannot sample without at least one penalty block".to_string());
88 }
89 let p = penalties[0].nrows();
90 let mut out = Array2::<f64>::zeros((p, p));
91 for (k, s) in penalties.iter().enumerate() {
92 if s.nrows() != p || s.ncols() != p {
93 return Err(format!(
94 "penalty block {k} shape mismatch: got {}x{}, expected {}x{}",
95 s.nrows(),
96 s.ncols(),
97 p,
98 p
99 ));
100 }
101 let lam = lambdas[k];
102 out += &(s * lam);
103 }
104 Ok(out)
105}
106
107fn validate_explicit_link_wiggle_joint_hessian(
108 hessian: &Array2<f64>,
109 expected_dim: usize,
110) -> Result<(), String> {
111 if hessian.nrows() != expected_dim || hessian.ncols() != expected_dim {
112 return Err(format!(
113 "link-wiggle sample: explicit joint Hessian is {}x{} but expected {}x{}",
114 hessian.nrows(),
115 hessian.ncols(),
116 expected_dim,
117 expected_dim,
118 ));
119 }
120 validate_all_finite(
121 "link-wiggle explicit joint Hessian",
122 hessian.iter().copied(),
123 )?;
124 let mut max_abs = 0.0_f64;
125 for r in 0..expected_dim {
126 for c in 0..expected_dim {
127 max_abs = max_abs.max(hessian[[r, c]].abs());
128 let scale = hessian[[r, c]].abs().max(hessian[[c, r]].abs()).max(1.0);
129 if (hessian[[r, c]] - hessian[[c, r]]).abs() > 1e-9 * scale {
130 return Err(format!(
131 "link-wiggle sample: explicit joint Hessian is not symmetric at ({r},{c})"
132 ));
133 }
134 }
135 }
136 if max_abs == 0.0 {
137 return Err("link-wiggle sample: explicit joint Hessian is all zeros; refit with exact Hessian export"
138 .to_string());
139 }
140 Ok(())
141}
142
/// Looks up the family-specific noise parameter for a fitted model.
///
/// Thin adapter: forwards `fit.likelihood_scale`, `fit.standard_deviation`
/// and `likelihood` to `crate::generative::family_noise_parameter` and returns
/// that helper's result unchanged.
fn family_noise_parameter(fit: &UnifiedFitResult, likelihood: &LikelihoodSpec) -> Option<f64> {
    crate::generative::family_noise_parameter(
        fit.likelihood_scale,
        fit.standard_deviation,
        likelihood,
    )
}
158
159fn refresh_negbin_theta_for_sampling(
177 likelihood: &mut LikelihoodSpec,
178 scale: gam_problem::types::LikelihoodScaleMetadata,
179) {
180 if let ResponseFamily::NegativeBinomial { theta, .. } = &mut likelihood.response {
181 if let Some(theta_hat) = scale.negbin_theta() {
182 *theta = theta_hat;
183 }
184 }
185}
186
/// Returns the likelihood specification stored on `model`.
///
/// Currently infallible: the value of `model.likelihood()` is always wrapped
/// in `Ok`. The `Result` return type lets callers use `?` uniformly.
fn likelihood_spec_for_saved_model(model: &SavedModel) -> Result<LikelihoodSpec, String> {
    Ok(model.likelihood())
}
193
/// Smoothing parameter assigned to a reconstructed penalty block when the
/// rebuilt survival time basis reports no `smooth_lambda` of its own.
const DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA: f64 = 1e-2;
199
/// Hashes `x` by delegating to `gam_linalg::utils::splitmix64_hash`.
///
/// Declared `const` so it can be called from the `const fn` seed derivation
/// in `chain_stream_seed`.
#[inline]
const fn splitmix64(x: u64) -> u64 {
    gam_linalg::utils::splitmix64_hash(x)
}
204
205#[inline]
206const fn chain_stream_seed(seed: u64, chain: usize, stream: u64) -> u64 {
207 splitmix64(seed ^ stream ^ ((chain as u64).wrapping_mul(0xD1B5_4A32_D192_ED03)))
208}
209
210pub fn sample_saved_model(
229 model: &SavedModel,
230 data: ArrayView2<'_, f64>,
231 col_map: &HashMap<String, usize>,
232 training_headers: Option<&Vec<String>>,
233 cfg: &NutsConfig,
234) -> Result<NutsResult, String> {
235 validate_nuts_config(cfg).map_err(String::from)?;
244 let likelihood = likelihood_spec_for_saved_model(model)?;
245 match model.predict_model_class() {
246 PredictModelClass::Survival => {
247 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
253 if matches!(
254 saved_likelihood_mode,
255 SurvivalLikelihoodMode::Latent
256 | SurvivalLikelihoodMode::LatentBinary
257 | SurvivalLikelihoodMode::LocationScale
258 ) {
259 laplace_gaussian_fallback(model, cfg, "survival posterior fallback")
260 } else {
261 sample_survival(model, data, col_map, training_headers, cfg)
262 }
263 }
264 PredictModelClass::Standard => {
265 if matches!(likelihood.response, ResponseFamily::Beta { .. }) {
274 laplace_gaussian_fallback(model, cfg, "beta-regression posterior fallback")
275 } else {
276 sample_standard(model, data, col_map, training_headers, likelihood, cfg)
277 }
278 }
279 PredictModelClass::GaussianLocationScale => {
289 laplace_gaussian_fallback(model, cfg, "gaussian location-scale posterior")
290 }
291 PredictModelClass::BinomialLocationScale => {
292 laplace_gaussian_fallback(model, cfg, "binomial location-scale posterior")
293 }
294 PredictModelClass::DispersionLocationScale => {
295 laplace_gaussian_fallback(model, cfg, "dispersion location-scale posterior")
296 }
297 PredictModelClass::BernoulliMarginalSlope => {
298 laplace_gaussian_fallback(model, cfg, "bernoulli marginal-slope posterior")
299 }
300 PredictModelClass::TransformationNormal => {
301 laplace_gaussian_fallback(model, cfg, "transformation-normal posterior")
302 }
303 }
304}
305
306pub fn laplace_gaussian_fallback(
320 model: &SavedModel,
321 cfg: &NutsConfig,
322 rationale: &'static str,
323) -> Result<NutsResult, String> {
324 use gam_problem::dispersion_cov::DispersionExt as _;
325 validate_nuts_config(cfg).map_err(String::from)?;
330 let fit = fit_result_from_saved_model_for_prediction(model)?;
331 let mode = fit.beta.clone();
332 let p = mode.len();
333 if p == 0 {
334 return Err(format!(
335 "{rationale}: cannot sample from an empty coefficient vector"
336 ));
337 }
338 let h = fit.penalized_hessian().ok_or_else(|| {
339 format!(
340 "{rationale}: posterior fallback requires the explicit penalised Hessian; \
341 refit with exact geometry export to enable posterior sampling for this class."
342 )
343 })?;
344 let dispersion = fit.dispersion().unwrap_or_default();
351 let sqrt_phi = dispersion.sqrt_phi();
352 if h.nrows() != p || h.ncols() != p {
353 return Err(format!(
354 "{rationale}: penalised Hessian is {}x{}, expected {}x{}",
355 h.nrows(),
356 h.ncols(),
357 p,
358 p
359 ));
360 }
361 let chol = h.cholesky(Side::Lower).map_err(|err| {
362 format!("{rationale}: Cholesky factorisation of the penalised Hessian failed: {err:?}")
363 })?;
364 let l = chol.lower_triangular();
365
366 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
370 let mut samples = Array2::<f64>::zeros((n_total, p));
371 let mut eps = Array1::<f64>::zeros(p);
372 let mut delta = Array1::<f64>::zeros(p);
373 for chain in 0..cfg.n_chains {
374 let mut rng = rand::rngs::StdRng::seed_from_u64(chain_stream_seed(
375 cfg.seed,
376 chain,
377 0xA0B7_6C5D_E431_298F,
378 ));
379 for draw in 0..cfg.n_samples {
380 let k = chain * cfg.n_samples + draw;
381 for i in 0..p {
382 eps[i] = sample_standard_normal(&mut rng);
383 }
384 back_substitution_lower_transpose_guarded_into(&l, &eps, &mut delta);
385 for i in 0..p {
386 samples[(k, i)] = mode[i] + sqrt_phi * delta[i];
390 }
391 }
392 }
393
394 let posterior_mean = samples
395 .mean_axis(ndarray::Axis(0))
396 .unwrap_or_else(|| Array1::<f64>::zeros(p));
397 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
398
399 Ok(NutsResult {
400 samples,
401 posterior_mean,
402 posterior_std,
403 rhat: 1.0,
404 ess: n_total as f64,
405 converged: true,
406 })
407}
408
409#[inline]
410fn sample_standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
411 let u1 = rng.random::<f64>().max(1e-16);
415 let u2 = rng.random::<f64>();
416 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
417}
418
419fn sample_standard(
420 model: &SavedModel,
421 data: ArrayView2<'_, f64>,
422 col_map: &HashMap<String, usize>,
423 training_headers: Option<&Vec<String>>,
424 mut likelihood: LikelihoodSpec,
425 cfg: &NutsConfig,
426) -> Result<NutsResult, String> {
427 let needs_constraint_aware_sampler = model.resolved_termspec.as_ref().is_some_and(|ts| {
443 ts.linear_terms.iter().any(|term| {
444 !matches!(
445 term.coefficient_geometry,
446 LinearCoefficientGeometry::Unconstrained
447 ) || term.coefficient_min.is_some()
448 || term.coefficient_max.is_some()
449 }) || ts
450 .smooth_terms
451 .iter()
452 .any(|term| !matches!(term.shape, gam_terms::smooth::ShapeConstraint::None))
453 });
454 if likelihood.is_gaussian_identity() && !needs_constraint_aware_sampler {
455 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
456 }
457 if model.has_link_wiggle() {
458 if likelihood.is_gaussian_identity() {
466 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
467 }
468 return sample_standard_link_wiggle(
469 model,
470 data,
471 col_map,
472 training_headers,
473 likelihood,
474 cfg,
475 );
476 }
477 let parsed = parse_formula(&model.formula)?;
478 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
479 let y = data.column(y_col).to_owned();
480 let spec = resolve_termspec_for_prediction(
481 &model.resolved_termspec,
482 training_headers,
483 col_map,
484 "resolved_termspec",
485 )?;
486 let design = build_term_collection_design(data, &spec)
487 .map_err(|e| format!("failed to build term collection design: {e}"))?;
488
489 let has_bounded = spec.linear_terms.iter().any(|term| {
518 matches!(
519 term.coefficient_geometry,
520 LinearCoefficientGeometry::Bounded { .. }
521 )
522 });
523 if has_bounded {
524 let bounded_columns: Vec<gam_models::fit_orchestration::drivers::BoundedSampleColumn> = spec
529 .linear_terms
530 .iter()
531 .enumerate()
532 .filter_map(|(j, term)| match term.coefficient_geometry {
533 LinearCoefficientGeometry::Bounded { min, max, .. } => {
534 Some(gam_models::fit_orchestration::drivers::BoundedSampleColumn {
535 col_idx: design.intercept_range.end + j,
536 min,
537 max,
538 })
539 }
540 LinearCoefficientGeometry::Unconstrained => None,
541 })
542 .collect();
543 return sample_standard_bounded(model, cfg, &bounded_columns);
544 }
545
546 if let Some(constraints) = design
556 .linear_constraints
557 .as_ref()
558 .filter(|c| c.a.nrows() > 0)
559 {
560 return sample_standard_truncated(model, cfg, constraints);
561 }
562
563 if likelihood.is_gaussian_identity() {
565 return laplace_gaussian_fallback(model, cfg, "standard gaussian posterior");
566 }
567
568 let weights = Array1::ones(data.nrows());
570 let dense_design_hmc = design.design.to_dense();
571 let p = dense_design_hmc.ncols();
572 let fit = fit_result_from_saved_model_for_prediction(model)?;
573 refresh_negbin_theta_for_sampling(&mut likelihood, fit.likelihood_scale);
584 if fit.beta.len() != p {
585 return Err(format!(
586 "standard sample: saved model has {} coefficients but rebuilt design has {} columns",
587 fit.beta.len(),
588 p,
589 ));
590 }
591 if fit.lambdas.len() != design.penalties.len() {
592 return Err(format!(
593 "standard sample: saved model has {} lambdas but rebuilt design has {} penalties",
594 fit.lambdas.len(),
595 design.penalties.len(),
596 ));
597 }
598 let penalty =
599 weighted_blockwise_penalty_sum(&design.penalties, fit.lambdas.as_slice().unwrap(), p);
600
601 let offset_vec: Option<Array1<f64>> = match model.offset_column.as_deref() {
606 Some(name) => {
607 let idx = resolve_role_col(col_map, name, "offset")?;
608 Some(data.column(idx).to_owned())
609 }
610 None => None,
611 };
612
613 run_nuts_sampling_flattened_family(
614 likelihood,
615 FamilyNutsInputs::Glm(GlmFlatInputs {
616 x: dense_design_hmc.view(),
617 y: y.view(),
618 weights: weights.view(),
619 penalty_matrix: penalty.view(),
620 mode: fit.beta.view(),
621 hessian: explicit_fit_hessian_for_whitening(&fit, p, "saved standard model")?.view(),
622 gamma_shape: fit.likelihood_scale.gamma_shape(),
623 dispersion: fit.dispersion().unwrap_or_default(),
627 firth_bias_reduction: false,
628 offset: offset_vec.as_ref().map(|o| o.view()),
629 }),
630 cfg,
631 )
632 .map_err(|e| format!("NUTS sampling failed: {e}"))
633}
634
635fn sample_standard_bounded(
645 model: &SavedModel,
646 cfg: &NutsConfig,
647 bounded_columns: &[gam_models::fit_orchestration::drivers::BoundedSampleColumn],
648) -> Result<NutsResult, String> {
649 validate_nuts_config(cfg).map_err(String::from)?;
650 let fit = fit_result_from_saved_model_for_prediction(model)?;
651 let mode = fit.beta.clone();
652 let p = mode.len();
653 if p == 0 {
654 return Err(
655 "standard bounded-coefficient posterior: cannot sample from an empty coefficient vector"
656 .to_string(),
657 );
658 }
659 let user_hessian =
664 explicit_fit_hessian_for_whitening(&fit, p, "saved standard bounded-coefficient model")?;
665 let sqrt_cov_scale = fit.coefficient_covariance_scale().max(0.0).sqrt();
672 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
673 let samples = gam_models::fit_orchestration::drivers::sample_bounded_latent_posterior_internal(
674 &mode,
675 user_hessian,
676 bounded_columns,
677 n_total,
678 sqrt_cov_scale,
679 chain_stream_seed(cfg.seed, 0, 0xB0DD_ED5E_ED90_1A7Cu64),
680 )
681 .map_err(|e| format!("standard bounded-coefficient posterior sampling failed: {e}"))?;
682
683 let posterior_mean = samples
684 .mean_axis(ndarray::Axis(0))
685 .unwrap_or_else(|| Array1::<f64>::zeros(p));
686 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
687
688 Ok(NutsResult {
689 samples,
690 posterior_mean,
691 posterior_std,
692 rhat: 1.0,
693 ess: n_total as f64,
694 converged: true,
695 })
696}
697
698fn sample_standard_truncated(
712 model: &SavedModel,
713 cfg: &NutsConfig,
714 constraints: &gam_solve::pirls::LinearInequalityConstraints,
715) -> Result<NutsResult, String> {
716 validate_nuts_config(cfg).map_err(String::from)?;
717 let fit = fit_result_from_saved_model_for_prediction(model)?;
718 let mode = fit.beta.clone();
719 let p = mode.len();
720 if p == 0 {
721 return Err(
722 "standard constrained-coefficient posterior: cannot sample from an empty coefficient \
723 vector"
724 .to_string(),
725 );
726 }
727 let penalized_hessian =
733 explicit_fit_hessian_for_whitening(&fit, p, "saved standard constrained model")?;
734 let sqrt_phi = {
735 use gam_problem::dispersion_cov::DispersionExt as _;
736 fit.dispersion().unwrap_or_default().sqrt_phi()
737 };
738 let samples = crate::truncated_gaussian::sample_truncated_gaussian_posterior(
739 &mode,
740 &penalized_hessian,
741 sqrt_phi,
742 constraints,
743 cfg.n_samples,
744 cfg.n_chains,
745 chain_stream_seed(cfg.seed, 0, 0x7290_C047_5D6E_B14Du64),
746 )?;
747 let n_total = cfg.n_samples.saturating_mul(cfg.n_chains);
748
749 let posterior_mean = samples
750 .mean_axis(ndarray::Axis(0))
751 .unwrap_or_else(|| Array1::<f64>::zeros(p));
752 let posterior_std = samples.std_axis(ndarray::Axis(0), 1.0);
753
754 Ok(NutsResult {
755 samples,
756 posterior_mean,
757 posterior_std,
758 rhat: 1.0,
759 ess: n_total as f64,
760 converged: true,
761 })
762}
763
764fn sample_standard_link_wiggle(
765 model: &SavedModel,
766 data: ArrayView2<'_, f64>,
767 col_map: &HashMap<String, usize>,
768 training_headers: Option<&Vec<String>>,
769 likelihood: LikelihoodSpec,
770 cfg: &NutsConfig,
771) -> Result<NutsResult, String> {
772 let parsed = parse_formula(&model.formula)?;
773 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
774 let y = data.column(y_col).to_owned();
775
776 let spec = resolve_termspec_for_prediction(
777 &model.resolved_termspec,
778 training_headers,
779 col_map,
780 "resolved_termspec",
781 )?;
782 let design = build_term_collection_design(data, &spec)
783 .map_err(|e| format!("failed to build term collection design: {e}"))?;
784 let p_main = design.design.ncols();
785
786 let fit = fit_result_from_saved_model_for_prediction(model)?;
787 let wiggle_runtime = model
788 .saved_prediction_runtime()?
789 .link_wiggle
790 .ok_or_else(|| "link-wiggle model is missing wiggle runtime metadata".to_string())?;
791 let mode_beta = fit
792 .block_by_role(BlockRole::Mean)
793 .ok_or_else(|| "standard link-wiggle model is missing Mean coefficient block".to_string())?
794 .beta
795 .clone();
796 let mode_theta = fit
797 .block_by_role(BlockRole::LinkWiggle)
798 .ok_or_else(|| {
799 "standard link-wiggle model is missing LinkWiggle coefficient block".to_string()
800 })?
801 .beta
802 .clone();
803 let p_wiggle = mode_theta.len();
804 let p_total = mode_beta.len() + p_wiggle;
805
806 if mode_beta.len() != p_main {
807 return Err(format!(
808 "link-wiggle sample: saved mean block has {} coefficients but rebuilt design has {} columns",
809 mode_beta.len(),
810 p_main,
811 ));
812 }
813 if fit.beta.len() != p_total {
814 return Err(format!(
815 "link-wiggle sample: saved beta has {} coefficients but design has {} main + {} wiggle = {} total",
816 fit.beta.len(),
817 p_main,
818 p_wiggle,
819 p_total,
820 ));
821 }
822
823 let hessian = &fit
824 .geometry
825 .as_ref()
826 .ok_or_else(|| {
827 "link-wiggle model is missing explicit joint Hessian geometry; refit with exact Hessian export"
828 .to_string()
829 })?
830 .penalized_hessian;
831 validate_explicit_link_wiggle_joint_hessian(hessian, p_total)?;
832
833 let n_base_penalties = design.penalties.len();
834 let base_lambdas = fit
835 .block_by_role(BlockRole::Mean)
836 .ok_or_else(|| "standard link-wiggle model is missing Mean block lambdas".to_string())?
837 .lambdas
838 .view();
839 if base_lambdas.len() != n_base_penalties {
840 return Err(format!(
841 "link-wiggle sample: mean block has {} lambdas but rebuilt design has {} base penalties",
842 base_lambdas.len(),
843 n_base_penalties,
844 ));
845 }
846
847 let penalty_base =
848 weighted_blockwise_penalty_sum(&design.penalties, base_lambdas.as_slice().unwrap(), p_main);
849
850 let wiggle_lambdas_owned = fit
851 .lambdas_linkwiggle()
852 .ok_or_else(|| "standard link-wiggle model is missing LinkWiggle lambdas".to_string())?;
853 let wiggle_lambdas = wiggle_lambdas_owned.view();
854 let degree = wiggle_runtime.degree;
855 let knot_arr = Array1::from_vec(wiggle_runtime.knots.clone());
856
857 let mut wiggle_penalties = Vec::new();
858 let default_orders = [2usize];
859 let n_wiggle_lambdas = wiggle_lambdas.len();
860 for k in 0..n_wiggle_lambdas {
861 let order = if k < default_orders.len() {
862 default_orders[k]
863 } else {
864 k + 1
865 };
866 if order >= p_wiggle {
867 continue;
868 }
869 let penalty = create_difference_penalty_matrix(p_wiggle, order, None)
870 .map_err(|e| format!("wiggle difference penalty failed: {e}"))?;
871 wiggle_penalties.push(penalty);
872 }
873 while wiggle_penalties.len() < n_wiggle_lambdas {
874 wiggle_penalties.push(Array2::zeros((p_wiggle, p_wiggle)));
875 }
876
877 let penalty_link = weighted_penalty_matrix(&wiggle_penalties, wiggle_lambdas)?;
878
879 let q0 = design.design.dot(&mode_beta);
880 let (q0_min, q0_max) = q0
881 .iter()
882 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
883 (lo.min(v), hi.max(v))
884 });
885
886 let spline = LinkWiggleSplineArtifacts {
887 knot_range: (q0_min, q0_max),
888 knot_vector: knot_arr,
889 degree,
890 };
891
892 let nuts_family = match (&likelihood.response, &likelihood.link) {
893 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
894 NutsFamily::BinomialLogit
895 }
896 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
897 NutsFamily::BinomialProbit
898 }
899 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
900 NutsFamily::BinomialCLogLog
901 }
902 (ResponseFamily::Gaussian, _) => NutsFamily::Gaussian,
903 (ResponseFamily::Poisson, _) => NutsFamily::PoissonLog,
904 (ResponseFamily::Tweedie { .. }, _) => NutsFamily::TweedieLog,
905 (ResponseFamily::NegativeBinomial { .. }, _) => NutsFamily::NegativeBinomialLog,
906 (ResponseFamily::Gamma, _) => NutsFamily::GammaLog,
907 _ => {
908 return Err(format!(
909 "NUTS sampling with link wiggle is not supported for family {}",
910 likelihood.pretty_name()
911 ));
912 }
913 };
914
915 let weights = Array1::ones(data.nrows());
916 let scale = family_noise_parameter(&fit, &likelihood).unwrap_or(fit.standard_deviation);
917
918 let wiggle_nuts_dense = design.design.as_dense_cow();
919 run_link_wiggle_nuts_sampling(
920 wiggle_nuts_dense.view(),
921 y.view(),
922 weights.view(),
923 penalty_base.view(),
924 penalty_link.view(),
925 mode_beta.view(),
926 mode_theta.view(),
927 hessian.view(),
928 spline,
929 nuts_family,
930 scale,
931 cfg,
932 )
933 .map_err(|e| format!("link-wiggle NUTS sampling failed: {e}"))
934}
935
936fn sample_survival(
937 model: &SavedModel,
938 data: ArrayView2<'_, f64>,
939 col_map: &HashMap<String, usize>,
940 training_headers: Option<&Vec<String>>,
941 cfg: &NutsConfig,
942) -> Result<NutsResult, String> {
943 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
944 if matches!(
945 saved_likelihood_mode,
946 SurvivalLikelihoodMode::Latent
947 | SurvivalLikelihoodMode::LatentBinary
948 | SurvivalLikelihoodMode::LocationScale
949 ) {
950 return laplace_gaussian_fallback(model, cfg, "survival posterior fallback");
951 }
952 let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
961 let exit_col = time_cols.exit_col;
962 let eventname = model
963 .survival_event
964 .as_ref()
965 .ok_or_else(|| "survival model missing event column metadata".to_string())?;
966 let event_col = resolve_role_col(col_map, eventname, "event")?;
967 let termspec = resolve_termspec_for_prediction(
968 &model.resolved_termspec,
969 training_headers,
970 col_map,
971 "resolved_termspec",
972 )?;
973 let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
974 let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
975 let cov_design = build_term_collection_design(cov_input, &termspec)
976 .map_err(|e| format!("failed to build survival design: {e}"))?;
977 let n = data.nrows();
978 let p_cov = cov_design.design.ncols();
979 let mut age_entry = Array1::<f64>::zeros(n);
980 let mut age_exit = Array1::<f64>::zeros(n);
981 let mut event_target = Array1::<u8>::zeros(n);
982 let event_competing = Array1::<u8>::zeros(n);
983 let weights = Array1::<f64>::ones(n);
984 for i in 0..n {
985 let (t0, t1) = normalize_survival_time_pair(
986 time_cols.row_entry_time(data, i),
987 data[[i, exit_col]],
988 i,
989 )?;
990 age_entry[i] = t0;
991 age_exit[i] = t1;
992 event_target[i] = if data[[i, event_col]] >= 0.5 { 1 } else { 0 };
993 }
994 let time_cfg = load_survival_time_basis_config_from_model(model)?;
995 let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
996 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
997 &time_build.basisname,
998 time_build.degree,
999 time_build.knots.as_ref(),
1000 time_build.keep_cols.as_ref(),
1001 time_build.smooth_lambda,
1002 )?;
1003 if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1004 let time_anchor = model
1005 .survival_time_anchor
1006 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1007 let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
1008 center_survival_time_designs_at_anchor(
1009 &mut time_build.x_entry_time,
1010 &mut time_build.x_exit_time,
1011 &time_anchor_row,
1012 )?;
1013 }
1014 let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
1015 let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
1016 build_survival_time_offsets_for_likelihood(
1017 &age_entry,
1018 &age_exit,
1019 &baseline_cfg,
1020 saved_likelihood_mode,
1021 None,
1022 )?;
1023 if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
1024 let time_anchor = model
1025 .survival_time_anchor
1026 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1027 add_survival_time_derivative_guard_offset(
1028 &age_entry,
1029 &age_exit,
1030 time_anchor,
1031 survival_derivative_guard_for_likelihood(saved_likelihood_mode),
1032 &mut eta_offset_entry,
1033 &mut eta_offset_exit,
1034 &mut derivative_offset_exit,
1035 )?;
1036 }
1037 let saved_timewiggle = saved_baseline_timewiggle_components(
1038 &eta_offset_entry,
1039 &eta_offset_exit,
1040 &derivative_offset_exit,
1041 model,
1042 )?;
1043 let p_time = time_build.x_exit_time.ncols();
1044 let p_timewiggle = saved_timewiggle
1045 .as_ref()
1046 .map(|(_, exit, _)| exit.ncols())
1047 .unwrap_or(0);
1048 let p = p_time + p_timewiggle + p_cov;
1049 let tb_entry_dense = time_build.x_entry_time.to_dense();
1050 let tb_exit_dense = time_build.x_exit_time.to_dense();
1051 let tb_deriv_dense = time_build.x_derivative_time.to_dense();
1052 let mut x_entry = Array2::<f64>::zeros((n, p));
1053 let mut x_exit = Array2::<f64>::zeros((n, p));
1054 let mut x_derivative = Array2::<f64>::zeros((n, p));
1055 if p_time > 0 {
1056 x_entry.slice_mut(s![.., ..p_time]).assign(&tb_entry_dense);
1057 x_exit.slice_mut(s![.., ..p_time]).assign(&tb_exit_dense);
1058 x_derivative
1059 .slice_mut(s![.., ..p_time])
1060 .assign(&tb_deriv_dense);
1061 }
1062 if let Some((entry_w, exit_w, deriv_w)) = saved_timewiggle.as_ref()
1063 && p_timewiggle > 0
1064 {
1065 x_entry
1066 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1067 .assign(entry_w);
1068 x_exit
1069 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1070 .assign(exit_w);
1071 x_derivative
1072 .slice_mut(s![.., p_time..(p_time + p_timewiggle)])
1073 .assign(deriv_w);
1074 }
1075 if p_cov > 0 {
1076 let cov_dense = cov_design.design.to_dense();
1077 let cov_range = (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov);
1078 x_entry
1079 .slice_mut(s![.., cov_range.clone()])
1080 .assign(&cov_dense);
1081 x_exit.slice_mut(s![.., cov_range]).assign(&cov_dense);
1082 }
1083 let mut penalty_blocks: Vec<PenaltyBlock> = Vec::new();
1084 for (idx, s) in time_build.penalties.iter().enumerate() {
1085 if s.nrows() == p_time && s.ncols() == p_time {
1086 penalty_blocks.push(PenaltyBlock {
1087 matrix: s.clone(),
1088 lambda: time_build
1089 .smooth_lambda
1090 .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1091 range: 0..p_time,
1092 nullspace_dim: time_build.nullspace_dims.get(idx).copied().unwrap_or(0),
1093 });
1094 }
1095 }
1096 let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1097 if let Some((_, exit_w, _)) = saved_timewiggle.as_ref() {
1098 let start = p_time;
1099 let end = start + exit_w.ncols();
1100 let wiggle_lambda_offset = penalty_blocks.len();
1101 let wiggle_cfg = saved_baseline_timewiggle_spec(model)?.ok_or_else(|| {
1102 "saved baseline-timewiggle model missing baseline-timewiggle metadata".to_string()
1103 })?;
1104 let wiggle_degree = wiggle_cfg.degree;
1105 let wiggle_knots =
1106 Array1::from_vec(model.baseline_timewiggle_knots.clone().ok_or_else(|| {
1107 "saved baseline-timewiggle model missing baseline_timewiggle_knots".to_string()
1108 })?);
1109 let mut seed = Array1::<f64>::zeros(2 * n);
1110 for i in 0..n {
1111 seed[i] = eta_offset_entry[i];
1112 seed[n + i] = eta_offset_exit[i];
1113 }
1114 let (primary_order, extra_orders) =
1115 split_wiggle_penalty_orders(2, &wiggle_cfg.penalty_orders);
1116 let mut block = buildwiggle_block_input_from_knots(
1117 seed.view(),
1118 &wiggle_knots,
1119 wiggle_degree,
1120 primary_order,
1121 wiggle_cfg.double_penalty,
1122 )?;
1123 append_selected_wiggle_penalty_orders(&mut block, &extra_orders)
1124 .map_err(|e| format!("baseline-timewiggle penalty reconstruction failed: {e}"))?;
1125 for (widx, s) in block.penalties.iter().enumerate() {
1126 let s = match s {
1127 gam_solve::estimate::PenaltySpec::Block { local, .. } => local,
1128 gam_solve::estimate::PenaltySpec::Dense(m)
1129 | gam_solve::estimate::PenaltySpec::DenseWithMean { matrix: m, .. } => m,
1130 };
1131 if s.nrows() == exit_w.ncols() && s.ncols() == exit_w.ncols() {
1132 penalty_blocks.push(PenaltyBlock {
1133 matrix: s.clone(),
1134 lambda: time_build
1135 .smooth_lambda
1136 .unwrap_or(DEFAULT_RECONSTRUCTED_SMOOTH_LAMBDA),
1137 range: start..end,
1138 nullspace_dim: block.nullspace_dims.get(widx).copied().unwrap_or(0),
1139 });
1140 }
1141 }
1142 for (local_idx, block_penalty) in penalty_blocks[wiggle_lambda_offset..]
1143 .iter_mut()
1144 .enumerate()
1145 {
1146 if let Some(&lam) = fit_saved.lambdas.get(wiggle_lambda_offset + local_idx) {
1147 block_penalty.lambda = lam;
1148 }
1149 }
1150 }
1151 let ridge_lambda = model.survivalridge_lambda.ok_or_else(|| {
1152 "saved survival model is missing survivalridge_lambda; refusing to \
1153 pick a load-time default (the historical 1e-4 fallback silently \
1154 disagreed with the 1e-6 fit-time default). Refit."
1155 .to_string()
1156 })?;
1157 let ridge_range_start = if time_build.basisname == "linear" && !model.has_baseline_time_wiggle()
1158 {
1159 1
1160 } else {
1161 0
1162 };
1163 if ridge_lambda > 0.0 && p > ridge_range_start {
1164 let dim = p - ridge_range_start;
1165 let mut ridge = Array2::<f64>::zeros((dim, dim));
1166 for d in 0..dim {
1167 ridge[[d, d]] = 1.0;
1168 }
1169 penalty_blocks.push(PenaltyBlock {
1170 matrix: ridge,
1171 lambda: ridge_lambda,
1172 range: ridge_range_start..p,
1173 nullspace_dim: 0,
1174 });
1175 }
1176 for (idx, block) in penalty_blocks.iter_mut().enumerate() {
1177 if let Some(&lam) = fit_saved.lambdas.get(idx) {
1178 block.lambda = lam;
1179 }
1180 }
1181 let penalties = PenaltyBlocks::new(penalty_blocks);
1182 let survivalspec = match model
1183 .survivalspec
1184 .as_deref()
1185 .unwrap_or("net")
1186 .to_ascii_lowercase()
1187 .as_str()
1188 {
1189 "net" => SurvivalSpec::Net,
1190 "crude" => {
1191 return Err("saved survival spec 'crude' is not supported by the one-hazard survival engine; refit or export a net survival model for this path"
1192 .to_string());
1193 }
1194 other => {
1195 return Err(format!("unsupported saved survival spec '{other}'"));
1196 }
1197 };
1198 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
1199 let mut model_surv = royston_parmar::working_model_from_flattened(
1200 penalties.clone(),
1201 monotonicity,
1202 survivalspec,
1203 RoystonParmarInputs {
1204 age_entry: age_entry.view(),
1205 age_exit: age_exit.view(),
1206 event_target: event_target.view(),
1207 event_competing: event_competing.view(),
1208 weights: weights.view(),
1209 x_entry: x_entry.view(),
1210 x_exit: x_exit.view(),
1211 x_derivative: x_derivative.view(),
1212 monotonicity_constraint_rows: None,
1213 monotonicity_constraint_offsets: None,
1214 eta_offset_entry: Some(eta_offset_entry.view()),
1215 eta_offset_exit: Some(eta_offset_exit.view()),
1216 derivative_offset_exit: Some(derivative_offset_exit.view()),
1217 },
1218 )
1219 .map_err(|e| format!("failed to construct survival model: {e}"))?;
1220 if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull {
1221 model_surv
1222 .set_structural_monotonicity(true, p_time + p_timewiggle)
1223 .map_err(|e| format!("failed to enable structural monotonicity: {e}"))?;
1224 }
1225 let beta0 = fit_saved.beta.clone();
1226 let state = model_surv
1227 .update_state(&beta0)
1228 .map_err(|e| format!("failed to evaluate survival state: {e}"))?;
1229 let hessian = state.hessian.to_dense();
1230 run_survival_nuts_sampling_flattened(
1231 SurvivalFlatInputs {
1232 age_entry: age_entry.view(),
1233 age_exit: age_exit.view(),
1234 event_target: event_target.view(),
1235 event_competing: event_competing.view(),
1236 weights: weights.view(),
1237 x_entry: x_entry.view(),
1238 x_exit: x_exit.view(),
1239 x_derivative: x_derivative.view(),
1240 eta_offset_entry: Some(eta_offset_entry.view()),
1241 eta_offset_exit: Some(eta_offset_exit.view()),
1242 derivative_offset_exit: Some(derivative_offset_exit.view()),
1243 },
1244 penalties,
1245 monotonicity,
1246 survivalspec,
1247 saved_likelihood_mode != SurvivalLikelihoodMode::Weibull,
1248 p_time + p_timewiggle,
1249 beta0.view(),
1250 hessian.view(),
1251 cfg,
1252 )
1253 .map_err(|e| format!("survival NUTS sampling failed: {e}"))
1254}
1255
#[cfg(test)]
mod tests {
    //! Unit tests for `refresh_negbin_theta_for_sampling`: each test builds a
    //! likelihood spec, applies the refresh with a particular scale metadata
    //! value, and inspects the resulting response family.

    use super::*;
    use gam_problem::types::LikelihoodScaleMetadata;

    /// A spec seeded with theta = 1.0 and refreshed with
    /// `EstimatedNegBinTheta { theta: 2.97 }` must end up carrying 2.97.
    #[test]
    fn refresh_negbin_theta_reads_theta_hat_not_seed() {
        let mut spec = LikelihoodSpec::negative_binomial_log(1.0);
        let fitted_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 };

        refresh_negbin_theta_for_sampling(&mut spec, fitted_scale);

        // Pull the theta out first so the assertion reads on its own line.
        let refreshed_theta = match spec.response {
            ResponseFamily::NegativeBinomial { theta, .. } => theta,
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            refreshed_theta, 2.97,
            "NB NUTS must sample at theta_hat (#1463), not the seed theta=1.0"
        );
    }

    /// A fixed-theta spec (4.25) refreshed with `FixedNegBinTheta { theta: 4.25 }`
    /// must keep both its theta value and its `theta_fixed` flag.
    #[test]
    fn refresh_negbin_theta_fixed_theta_is_preserved() {
        let mut spec = LikelihoodSpec::negative_binomial_log_fixed(4.25);
        let fixed_scale = LikelihoodScaleMetadata::FixedNegBinTheta { theta: 4.25 };

        refresh_negbin_theta_for_sampling(&mut spec, fixed_scale);

        let (refreshed_theta, still_fixed) = match spec.response {
            ResponseFamily::NegativeBinomial { theta, theta_fixed } => (theta, theta_fixed),
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            refreshed_theta, 4.25,
            "fixed NB theta must survive the refresh"
        );
        assert!(still_fixed, "theta_fixed flag must be preserved");
    }

    /// When the scale metadata is `ProfiledGaussian` (no NB theta in it), the
    /// seed theta of 3.5 must come back unchanged.
    #[test]
    fn refresh_negbin_theta_falls_back_to_seed_when_unfitted() {
        let mut spec = LikelihoodSpec::negative_binomial_log(3.5);
        let non_nb_scale = LikelihoodScaleMetadata::ProfiledGaussian;

        refresh_negbin_theta_for_sampling(&mut spec, non_nb_scale);

        let refreshed_theta = match spec.response {
            ResponseFamily::NegativeBinomial { theta, .. } => theta,
            other => panic!("expected NegativeBinomial response, got {other:?}"),
        };
        assert_eq!(
            refreshed_theta, 3.5,
            "with no fitted theta the NB seed must be kept verbatim"
        );
    }

    /// A Poisson spec must compare equal to its pre-refresh snapshot even when
    /// the scale metadata carries an estimated NB theta.
    #[test]
    fn refresh_negbin_theta_leaves_non_nb_families_untouched() {
        let mut spec = LikelihoodSpec::poisson_log();
        let snapshot = spec.response.clone();
        let nb_scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 9.0 };

        refresh_negbin_theta_for_sampling(&mut spec, nb_scale);

        assert_eq!(
            spec.response, snapshot,
            "Poisson response must be untouched by the NB theta refresh"
        );
    }
}