1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4 LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7 SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12 monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15 inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16 parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21 InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22 SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Payload schema version this build reads and writes. `validate_payload_version`
/// rejects any payload whose `version` field differs from this value.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;
59
/// Group metadata stored on a payload: string keys mapped to arbitrary JSON values,
/// held in a `BTreeMap` (so iteration is ordered by key).
pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
/// Serializable pairing of a spline-scan state with the name of one feature column.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedSplineScan {
 /// Name of the feature column (presumably the scanned covariate; confirm against the writer).
 pub feature_column: String,
 /// Persisted spline-scan state.
 pub state: gam_solve::spline_scan::SplineScanState,
}
76
/// Serializable pairing of a residual-cascade state with a list of feature column names.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedResidualCascade {
 /// Names of the feature columns (presumably in the order the cascade consumes them; confirm).
 pub feature_columns: Vec<String>,
 /// Persisted residual-cascade state.
 pub state: gam_solve::residual_cascade::ResidualCascadeState,
}
87
/// Error raised while validating or using a saved model payload.
///
/// Every variant carries a human-readable `reason` string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
 /// Saved content disagrees with what the reader expects (used in this file for
 /// payload-version, shape, and coefficient disagreements).
 SchemaMismatch { reason: String },
 /// A saved value is unusable (used in this file for non-finite normalization values).
 PayloadCorrupt { reason: String },
 /// A required saved field or fit block is absent.
 MissingField { reason: String },
 /// A saved configuration value is not supported by this reader.
 IncompatibleConfig { reason: String },
 /// NOTE(review): not constructed in the visible part of the file; presumably for
 /// invalid caller-supplied input. Confirm at the use sites.
 InvalidInput { reason: String },
}
118
// Expands the project macro over every `reason`-carrying variant of `FittedModelError`.
// NOTE(review): the macro is defined elsewhere; it presumably supplies the Display/Error
// impls (`to_string()` is called on this type in the `From` conversions). Confirm there.
impl_reason_error_boilerplate! {
 FittedModelError {
 SchemaMismatch,
 PayloadCorrupt,
 MissingField,
 IncompatibleConfig,
 InvalidInput,
 }
}
128
129impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134 fn from(err: FittedModelError) -> Self {
135 gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136 }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140 fn from(err: FittedModelError) -> Self {
141 crate::survival::predict::SurvivalPredictError::ModelPayload {
142 context: "saved-model survival prediction payload",
143 source: err,
144 }
145 }
146}
147
/// Saved affine normalization for a latent `z` score: `apply` maps each value to
/// `(z - mean) / sd`.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct SavedLatentZNormalization {
 /// Value subtracted from each score; `validate` requires it to be finite.
 pub mean: f64,
 /// Divisor applied after centering; `validate` requires it to be finite and > 1e-12.
 pub sd: f64,
}
153
154impl SavedLatentZNormalization {
155 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156 if !self.mean.is_finite() {
157 return Err(FittedModelError::PayloadCorrupt {
158 reason: format!("{context} latent z mean must be finite"),
159 });
160 }
161 if !(self.sd.is_finite() && self.sd > 1e-12) {
162 return Err(FittedModelError::PayloadCorrupt {
163 reason: format!(
164 "{context} latent z sd must be finite and > 1e-12; got {}",
165 self.sd
166 ),
167 });
168 }
169 Ok::<(), _>(())
170 }
171
172 pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173 self.validate(context)?;
174 if z.iter().any(|value| !value.is_finite()) {
175 return Err(FittedModelError::PayloadCorrupt {
176 reason: format!("{context} requires finite z values"),
177 });
178 }
179 Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180 }
181}
182
/// Default PIT clip epsilon: used by `TransformationScoreCalibration::finite_support_pit`
/// and as the serde default for `TransformationScoreCalibration::clip_eps`.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
/// Score semantics recorded in a `TransformationScoreCalibration`.
///
/// Serialized with kebab-case variant names; `FiniteSupportPit` is the `Default`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
#[derive(Default)]
pub enum TransformationScoreKind {
 /// The only kind accepted by `TransformationScoreCalibration::validate`.
 #[default]
 FiniteSupportPit,
}
192
/// Saved calibration settings for transformation scores. Both fields fall back to
/// defaults when absent from the serialized form.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TransformationScoreCalibration {
 /// Score semantics; defaults to `TransformationScoreKind::default()` when missing.
 #[serde(default)]
 pub score_kind: TransformationScoreKind,
 /// PIT clip epsilon; defaults to `TRANSFORMATION_SCORE_PIT_CLIP_EPS` when missing.
 /// `validate` requires it to lie strictly inside (0, 0.5).
 #[serde(default = "default_transformation_score_pit_clip_eps")]
 pub clip_eps: f64,
}
200
/// Serde default provider for `TransformationScoreCalibration::clip_eps`; returns
/// `TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
const fn default_transformation_score_pit_clip_eps() -> f64 {
 TRANSFORMATION_SCORE_PIT_CLIP_EPS
}
204
205impl TransformationScoreCalibration {
206 pub fn finite_support_pit() -> Self {
207 Self {
208 score_kind: TransformationScoreKind::FiniteSupportPit,
209 clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210 }
211 }
212
213 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214 if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215 return Err(FittedModelError::IncompatibleConfig {
216 reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217 });
218 }
219 if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220 return Err(FittedModelError::IncompatibleConfig {
221 reason: format!(
222 "{context} requires PIT clip_eps in (0, 0.5), got {}",
223 self.clip_eps
224 ),
225 });
226 }
227 Ok::<(), _>(())
228 }
229}
230
/// Serialized body of a saved fitted model.
///
/// Only `version`, `formula`, `model_kind`, `family_state`, and `family` are supplied to
/// `FittedModelPayload::new`; every other field starts as `None`, `false`, or empty and
/// (via `#[serde(default)]` or being an `Option`) may be absent from the serialized form.
#[derive(Clone, Serialize, Deserialize)]
pub struct FittedModelPayload {
 // Payload schema version; compared against `MODEL_PAYLOAD_VERSION` by
 // `validate_payload_version`.
 pub version: u32,
 pub formula: String,
 pub model_kind: ModelKind,
 pub family_state: FittedFamily,
 pub family: String,
 #[serde(default)]
 pub inference_notes: Vec<String>,
 #[serde(default)]
 pub used_device: bool,
 // Fit result read by the saved-fit validators in this file.
 #[serde(default)]
 pub fit_result: Option<UnifiedFitResult>,
 #[serde(default)]
 pub unified: Option<UnifiedFitResult>,
 #[serde(default)]
 pub spline_scan: Option<SavedSplineScan>,
 #[serde(default)]
 pub residual_cascade: Option<SavedResidualCascade>,
 #[serde(default)]
 pub data_schema: Option<DataSchema>,
 pub link: Option<InverseLink>,
 #[serde(default)]
 pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub sas_param_covariance: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub formula_noise: Option<String>,
 #[serde(default)]
 pub formula_logslope: Option<String>,
 #[serde(default)]
 pub formula_logslopes: Option<Vec<String>>,
 #[serde(default)]
 pub offset_column: Option<String>,
 #[serde(default)]
 pub noise_offset_column: Option<String>,
 #[serde(default)]
 pub beta_noise: Option<Vec<f64>>,
 #[serde(default)]
 pub noise_projection: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub noise_center: Option<Vec<f64>>,
 #[serde(default)]
 pub noise_scale: Option<Vec<f64>>,
 #[serde(default)]
 pub noise_non_intercept_start: Option<usize>,
 #[serde(default)]
 pub noise_projection_ridge_alpha: Option<f64>,
 #[serde(default)]
 pub gaussian_response_scale: Option<f64>,
 #[serde(default)]
 pub linkwiggle_knots: Option<Vec<f64>>,
 #[serde(default)]
 pub linkwiggle_degree: Option<usize>,
 #[serde(default)]
 pub beta_link_wiggle: Option<Vec<f64>>,
 #[serde(default)]
 pub baseline_timewiggle_knots: Option<Vec<f64>>,
 #[serde(default)]
 pub baseline_timewiggle_degree: Option<usize>,
 #[serde(default)]
 pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
 #[serde(default)]
 pub baseline_timewiggle_double_penalty: Option<bool>,
 #[serde(default)]
 pub beta_baseline_timewiggle: Option<Vec<f64>>,
 #[serde(default)]
 pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub z_column: Option<String>,
 #[serde(default)]
 pub z_columns: Option<Vec<String>>,
 #[serde(default)]
 pub latent_z_normalization: Option<SavedLatentZNormalization>,
 #[serde(default)]
 pub latent_score_contract: Option<SavedLatentScoreContract>,
 #[serde(default)]
 pub latent_measure: Option<LatentMeasureKind>,
 #[serde(default)]
 pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
 #[serde(default)]
 pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
 #[serde(default)]
 pub marginal_baseline: Option<f64>,
 #[serde(default)]
 pub logslope_baseline: Option<f64>,
 #[serde(default)]
 pub logslope_baselines: Option<Vec<f64>>,
 #[serde(default)]
 pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
 #[serde(default)]
 pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
 #[serde(default)]
 pub influence_absorber_width: Option<usize>,
 #[serde(default)]
 pub survival_entry: Option<String>,
 #[serde(default)]
 pub survival_exit: Option<String>,
 #[serde(default)]
 pub survival_event: Option<String>,
 #[serde(default)]
 pub survivalspec: Option<String>,
 #[serde(default)]
 pub survival_cause_count: Option<usize>,
 #[serde(default)]
 pub survival_endpoint_names: Option<Vec<String>>,
 #[serde(default)]
 pub survival_baseline_target: Option<String>,
 #[serde(default)]
 pub survival_baseline_scale: Option<f64>,
 #[serde(default)]
 pub survival_baseline_shape: Option<f64>,
 #[serde(default)]
 pub survival_baseline_rate: Option<f64>,
 #[serde(default)]
 pub survival_baseline_makeham: Option<f64>,
 // The six `survival_time_*` fields below are overwritten together by
 // `apply_survival_time_basis`.
 #[serde(default)]
 pub survival_time_basis: Option<String>,
 #[serde(default)]
 pub survival_time_degree: Option<usize>,
 #[serde(default)]
 pub survival_time_knots: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_time_keep_cols: Option<Vec<usize>>,
 #[serde(default)]
 pub survival_time_smooth_lambda: Option<f64>,
 #[serde(default)]
 pub survival_time_anchor: Option<f64>,
 #[serde(default)]
 pub survivalridge_lambda: Option<f64>,
 #[serde(default)]
 pub survival_likelihood: Option<String>,
 // When present, the three `survival_beta_*` vectors below are required to equal the
 // Time / Threshold / Scale blocks of `fit_result` by
 // `validate_survival_location_scale_saved_fit`.
 #[serde(default)]
 pub survival_beta_time: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_beta_threshold: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_beta_log_sigma: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_noise_projection: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub survival_noise_center: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_noise_scale: Option<Vec<f64>>,
 #[serde(default)]
 pub survival_noise_non_intercept_start: Option<usize>,
 #[serde(default)]
 pub survival_noise_projection_ridge_alpha: Option<f64>,
 #[serde(default)]
 pub survival_distribution: Option<ResidualDistribution>,
 // Set together with `training_feature_ranges` by `set_training_feature_metadata`.
 #[serde(default)]
 pub training_headers: Option<Vec<String>>,
 // Omitted from the serialized form when `None`.
 #[serde(default, skip_serializing_if = "Option::is_none")]
 pub training_table_kind: Option<String>,
 #[serde(default)]
 pub training_feature_ranges: Option<Vec<(f64, f64)>>,
 // Omitted from the serialized form when `None`.
 #[serde(default, skip_serializing_if = "Option::is_none")]
 pub group_metadata: Option<GroupMetadata>,
 // Omitted from the serialized form when empty; consumed by
 // `append_deployment_extension_columns`.
 #[serde(default, skip_serializing_if = "Vec::is_empty")]
 pub deployment_extensions: Vec<SavedDeploymentExtension>,
 #[serde(default)]
 pub transformation_response_knots: Option<Vec<f64>>,
 #[serde(default)]
 pub transformation_response_transform: Option<Vec<Vec<f64>>>,
 #[serde(default)]
 pub transformation_response_degree: Option<usize>,
 #[serde(default)]
 pub transformation_response_median: Option<f64>,
 #[serde(default)]
 pub transformation_score_calibration: Option<TransformationScoreCalibration>,
 // Required by `append_deployment_extension_columns` whenever deployment extensions
 // are present.
 #[serde(default)]
 pub resolved_termspec: Option<TermCollectionSpec>,
 #[serde(default)]
 pub resolved_termspec_noise: Option<TermCollectionSpec>,
 #[serde(default)]
 pub resolved_termspec_logslope: Option<TermCollectionSpec>,
 #[serde(default)]
 pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
 #[serde(default)]
 pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
 #[serde(default)]
 pub gaussian_jackknife_plus:
 Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
 #[serde(default)]
 pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
}
521
/// One design column appended after fitting, as consumed by
/// `append_deployment_extension_columns`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedDeploymentExtension {
 /// Identifier used in error messages.
 pub name: String,
 /// Extension kind; `append_deployment_extension_columns` accepts only
 /// `"random-effect-level"`.
 pub kind: String,
 /// Name of the random-effect term (looked up in the saved `resolved_termspec`).
 pub term: String,
 /// JSON form of the level (not read by `append_deployment_extension_columns`).
 pub level: JsonValue,
 /// Bit pattern compared with `f64::to_bits` of each prediction cell to decide
 /// whether a row belongs to this level.
 pub level_bits: u64,
 /// Design column index; must continue the base design's columns without gaps.
 pub coefficient_index: usize,
 /// NOTE(review): presumably the coefficient's posterior mean; not read in the visible code.
 pub coefficient_mean: f64,
 /// NOTE(review): presumably the coefficient's variance; not read in the visible code.
 pub coefficient_variance: f64,
 /// Optional free-form metadata; omitted from the serialized form when `None`.
 #[serde(default, skip_serializing_if = "Option::is_none")]
 pub metadata: Option<JsonValue>,
 /// Optional free-form prior description; omitted from the serialized form when `None`.
 #[serde(default, skip_serializing_if = "Option::is_none")]
 pub prior: Option<JsonValue>,
}
537
538pub fn append_deployment_extension_columns(
555 model: &FittedModelPayload,
556 data: ndarray::ArrayView2<'_, f64>,
557 col_map: &HashMap<String, usize>,
558 training_headers: Option<&Vec<String>>,
559 base_design: Array2<f64>,
560) -> Result<Array2<f64>, FittedModelError> {
561 if model.deployment_extensions.is_empty() {
562 return Ok(base_design);
563 }
564 if base_design.nrows() != data.nrows() {
565 return Err(FittedModelError::SchemaMismatch {
566 reason: format!(
567 "deployment extension design row mismatch: base design has {} rows but data has {}",
568 base_design.nrows(),
569 data.nrows()
570 ),
571 });
572 }
573 let spec = model
574 .resolved_termspec
575 .as_ref()
576 .ok_or_else(|| FittedModelError::MissingField {
577 reason: "deployment extension prediction requires saved resolved_termspec; refit"
578 .to_string(),
579 })?;
580 let n = base_design.nrows();
581 let p_old = base_design.ncols();
582 let mut extensions: Vec<&SavedDeploymentExtension> =
583 model.deployment_extensions.iter().collect();
584 extensions.sort_by_key(|extension| extension.coefficient_index);
585 for (tail_idx, extension) in extensions.iter().enumerate() {
586 let expected = p_old + tail_idx;
587 if extension.coefficient_index != expected {
588 return Err(FittedModelError::SchemaMismatch {
589 reason: format!(
590 "deployment extension '{}' has coefficient index {}, expected append-only index {}",
591 extension.name, extension.coefficient_index, expected
592 ),
593 });
594 }
595 }
596
597 let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
598 out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
599 for (tail_idx, extension) in extensions.into_iter().enumerate() {
600 if extension.kind != "random-effect-level" {
601 return Err(FittedModelError::IncompatibleConfig {
602 reason: format!(
603 "unsupported deployment extension kind '{}' for '{}'",
604 extension.kind, extension.name
605 ),
606 });
607 }
608 let term = spec
609 .random_effect_terms
610 .iter()
611 .find(|term| term.name == extension.term)
612 .ok_or_else(|| FittedModelError::MissingField {
613 reason: format!(
614 "deployment extension '{}' references unknown random-effect term '{}'",
615 extension.name, extension.term
616 ),
617 })?;
618 let prediction_col = training_headers
619 .and_then(|headers| headers.get(term.feature_col))
620 .and_then(|name| col_map.get(name))
621 .copied()
622 .unwrap_or(term.feature_col);
623 if prediction_col >= data.ncols() {
624 return Err(FittedModelError::SchemaMismatch {
625 reason: format!(
626 "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
627 extension.name,
628 prediction_col,
629 data.ncols()
630 ),
631 });
632 }
633 let col = p_old + tail_idx;
634 for row in 0..n {
635 if data[[row, prediction_col]].to_bits() == extension.level_bits {
636 out[[row, col]] = 1.0;
637 }
638 }
639 }
640 Ok(out)
641}
642
/// Serializable description of a latent score.
///
/// NOTE(review): none of these fields are read in the visible part of the file; the
/// per-field notes below are inferred from names and types only. Confirm with the
/// writer and reader code.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedLatentScoreContract {
 /// Free-form label for the score semantics.
 pub semantics: String,
 /// Optional identifier of the source transform.
 pub source_transform_id: Option<String>,
 /// Presumably the mean used to normalize the score; TODO confirm.
 pub normalization_mean: f64,
 /// Presumably the standard deviation used to normalize the score; TODO confirm.
 pub normalization_sd: f64,
 /// Optional clip epsilon.
 pub clip_eps: Option<f64>,
 /// Names of conditioning columns.
 pub conditioning_columns: Vec<String>,
}
652
impl FittedModelPayload {
 /// Creates a payload from the five required fields; every other field starts as
 /// `None`, `false`, or an empty `Vec`.
 pub fn new(
 version: u32,
 formula: String,
 model_kind: ModelKind,
 family_state: FittedFamily,
 family: String,
 ) -> Self {
 Self {
 version,
 formula,
 model_kind,
 family_state,
 family,
 inference_notes: Vec::new(),
 used_device: false,
 fit_result: None,
 unified: None,
 spline_scan: None,
 residual_cascade: None,
 data_schema: None,
 link: None,
 mixture_link_param_covariance: None,
 sas_param_covariance: None,
 formula_noise: None,
 formula_logslope: None,
 formula_logslopes: None,
 offset_column: None,
 noise_offset_column: None,
 beta_noise: None,
 noise_projection: None,
 noise_center: None,
 noise_scale: None,
 noise_non_intercept_start: None,
 noise_projection_ridge_alpha: None,
 gaussian_response_scale: None,
 linkwiggle_knots: None,
 linkwiggle_degree: None,
 beta_link_wiggle: None,
 baseline_timewiggle_knots: None,
 baseline_timewiggle_degree: None,
 baseline_timewiggle_penalty_orders: None,
 baseline_timewiggle_double_penalty: None,
 beta_baseline_timewiggle: None,
 beta_baseline_timewiggle_by_cause: None,
 z_column: None,
 z_columns: None,
 latent_z_normalization: None,
 latent_score_contract: None,
 latent_measure: None,
 latent_z_rank_int_calibration: None,
 latent_z_conditional_calibration: None,
 marginal_baseline: None,
 logslope_baseline: None,
 logslope_baselines: None,
 score_warp_runtime: None,
 link_deviation_runtime: None,
 influence_absorber_width: None,
 survival_entry: None,
 survival_exit: None,
 survival_event: None,
 survivalspec: None,
 survival_cause_count: None,
 survival_endpoint_names: None,
 survival_baseline_target: None,
 survival_baseline_scale: None,
 survival_baseline_shape: None,
 survival_baseline_rate: None,
 survival_baseline_makeham: None,
 survival_time_basis: None,
 survival_time_degree: None,
 survival_time_knots: None,
 survival_time_keep_cols: None,
 survival_time_smooth_lambda: None,
 survival_time_anchor: None,
 survivalridge_lambda: None,
 survival_likelihood: None,
 survival_beta_time: None,
 survival_beta_threshold: None,
 survival_beta_log_sigma: None,
 survival_noise_projection: None,
 survival_noise_center: None,
 survival_noise_scale: None,
 survival_noise_non_intercept_start: None,
 survival_noise_projection_ridge_alpha: None,
 survival_distribution: None,
 training_headers: None,
 training_table_kind: None,
 training_feature_ranges: None,
 group_metadata: None,
 deployment_extensions: Vec::new(),
 transformation_response_knots: None,
 transformation_response_transform: None,
 transformation_response_degree: None,
 transformation_response_median: None,
 transformation_score_calibration: None,
 resolved_termspec: None,
 resolved_termspec_noise: None,
 resolved_termspec_logslope: None,
 resolved_termspec_logslopes: None,
 adaptive_regularization_diagnostics: None,
 gaussian_jackknife_plus: None,
 full_conformal: None,
 }
 }

 /// Stores the training headers and feature ranges, replacing any previous values.
 /// The two vectors are stored as given; their lengths are not compared here.
 pub fn set_training_feature_metadata(
 &mut self,
 headers: Vec<String>,
 feature_ranges: Vec<(f64, f64)>,
 ) {
 self.training_headers = Some(headers);
 self.training_feature_ranges = Some(feature_ranges);
 }

 /// For a payload that has a `fit_result` and a `data_schema` with zero columns,
 /// fills in an empty `training_headers` list and an empty `resolved_termspec` when
 /// either is missing. Existing values are left alone; in every other case this is
 /// a no-op.
 fn synchronize_empty_feature_contract(&mut self) {
 if self.fit_result.is_none() {
 return;
 }
 let Some(schema) = self.data_schema.as_ref() else {
 return;
 };
 if !schema.columns.is_empty() {
 return;
 }
 // `get_or_insert_with` only writes when the slot is currently `None`.
 self.training_headers.get_or_insert_with(Vec::new);
 self.resolved_termspec
 .get_or_insert_with(|| TermCollectionSpec {
 linear_terms: Vec::new(),
 smooth_terms: Vec::new(),
 random_effect_terms: Vec::new(),
 });
 }

 /// Copies a saved survival time-basis snapshot into the six `survival_time_*`
 /// fields, overwriting whatever was there. The basis name and anchor are always
 /// stored as `Some`; the other four fields take the snapshot's values as they are.
 pub fn apply_survival_time_basis(
 &mut self,
 snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
 ) {
 self.survival_time_basis = Some(snapshot.basisname.clone());
 self.survival_time_degree = snapshot.degree;
 self.survival_time_knots = snapshot.knots.clone();
 self.survival_time_keep_cols = snapshot.keep_cols.clone();
 self.survival_time_smooth_lambda = snapshot.smooth_lambda;
 self.survival_time_anchor = Some(snapshot.anchor);
 }

 /// Accepts the payload only when `version` equals `MODEL_PAYLOAD_VERSION`.
 ///
 /// # Errors
 /// Returns `FittedModelError::SchemaMismatch` naming both versions otherwise.
 fn validate_payload_version(&self) -> Result<(), FittedModelError> {
 if self.version != MODEL_PAYLOAD_VERSION {
 return Err(FittedModelError::SchemaMismatch {
 reason: format!(
 "saved model payload schema mismatch: file has version={}, \
 this binary expects MODEL_PAYLOAD_VERSION={}. \
 Refit with the current CLI, or rebuild the reader at the same \
 version the model was written with.",
 self.version, MODEL_PAYLOAD_VERSION
 ),
 });
 }
 Ok(())
 }
}
821
/// Top-level saved model. Serialized as an internally tagged enum: the `model_type`
/// field holds the kebab-case variant name, and every variant wraps one
/// `FittedModelPayload`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
 Standard { payload: FittedModelPayload },
 LocationScale { payload: FittedModelPayload },
 MarginalSlope { payload: FittedModelPayload },
 Survival { payload: FittedModelPayload },
 TransformationNormal { payload: FittedModelPayload },
}
831
/// Model kind recorded in `FittedModelPayload::model_kind`; serialized as a kebab-case
/// string. The variant names mirror those of `FittedModel`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
 Standard,
 LocationScale,
 MarginalSlope,
 Survival,
 TransformationNormal,
}
841
/// Family-specific state stored in `FittedModelPayload::family_state`.
///
/// Serialized as an internally tagged enum: the `family_kind` field holds the
/// kebab-case variant name. Fields marked `#[serde(default)]` may be absent from the
/// serialized form and then deserialize as `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
 /// Standard family: a likelihood plus optional link and link-state values.
 Standard {
 likelihood: LikelihoodSpec,
 #[serde(default)]
 link: Option<StandardLink>,
 #[serde(default)]
 latent_cloglog_state: Option<LatentCLogLogState>,
 #[serde(default)]
 mixture_state: Option<MixtureLinkState>,
 #[serde(default)]
 sas_state: Option<SasLinkState>,
 },
 /// Location-scale family: a likelihood and an optional base inverse link.
 LocationScale {
 likelihood: LikelihoodSpec,
 #[serde(default)]
 base_link: Option<InverseLink>,
 },
 /// Marginal-slope family: likelihood, base inverse link, and frailty spec (all required).
 MarginalSlope {
 likelihood: LikelihoodSpec,
 base_link: InverseLink,
 frailty: FrailtySpec,
 },
 /// Survival family: likelihood and frailty spec, plus optional survival likelihood
 /// name and residual distribution.
 Survival {
 likelihood: LikelihoodSpec,
 #[serde(default)]
 survival_likelihood: Option<String>,
 #[serde(default)]
 survival_distribution: Option<ResidualDistribution>,
 frailty: FrailtySpec,
 },
 /// Latent survival family: carries only a frailty spec.
 LatentSurvival {
 frailty: FrailtySpec,
 },
 /// Latent binary family: carries only a frailty spec.
 LatentBinary {
 frailty: FrailtySpec,
 },
 /// Transformation-normal family: carries only a likelihood.
 TransformationNormal {
 likelihood: LikelihoodSpec,
 },
}
884
/// Prediction-time model class. Not serialized (no serde derives); `name` gives the
/// human-readable label for each variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
 Standard,
 GaussianLocationScale,
 BinomialLocationScale,
 DispersionLocationScale,
 BernoulliMarginalSlope,
 Survival,
 TransformationNormal,
}
899
900impl PredictModelClass {
901 #[inline]
902 pub const fn name(self) -> &'static str {
903 match self {
904 Self::Standard => "standard",
905 Self::GaussianLocationScale => "gaussian location-scale",
906 Self::BinomialLocationScale => "binomial location-scale",
907 Self::DispersionLocationScale => "dispersion location-scale",
908 Self::BernoulliMarginalSlope => "bernoulli marginal-slope",
909 Self::Survival => "survival",
910 Self::TransformationNormal => "transformation-normal",
911 }
912 }
913}
914
/// In-memory link-wiggle settings used at prediction time (not serialized).
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
 /// Knot values (presumably of the wiggle spline basis; confirm).
 pub knots: Vec<f64>,
 /// Degree (presumably of the wiggle spline basis; confirm).
 pub degree: usize,
 /// Wiggle coefficients. Their count is added to the expected covariance dimension
 /// by the location-scale saved-fit validators.
 pub beta: Vec<f64>,
}
921
/// In-memory baseline time-wiggle settings used at prediction time (not serialized).
///
/// NOTE(review): the fields are not read in the visible part of the file; the
/// descriptions below follow the field names only.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
 /// Knot values.
 pub knots: Vec<f64>,
 /// Degree.
 pub degree: usize,
 /// Penalty orders.
 pub penalty_orders: Vec<usize>,
 /// Double-penalty flag.
 pub double_penalty: bool,
 /// Wiggle coefficients.
 pub beta: Vec<f64>,
}
930
931pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
934
/// Serializable compiled flex block, stored on the payload as `score_warp_runtime` and
/// `link_deviation_runtime`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
 /// Kernel identifier.
 pub kernel: String,
 /// Breakpoint values.
 pub breakpoints: Vec<f64>,
 /// Number of basis coefficients; the marginal-slope saved-fit validator compares
 /// it with the length of the matching coefficient block.
 pub basis_dim: usize,
 // NOTE(review): `span_c0`..`span_c3` presumably hold per-span polynomial
 // coefficients of order 0..3; not read in the visible code, confirm at the evaluator.
 pub span_c0: Vec<Vec<f64>>,
 pub span_c1: Vec<Vec<f64>>,
 pub span_c2: Vec<Vec<f64>>,
 pub span_c3: Vec<Vec<f64>>,
 /// Optional anchor correction matrix; `None` when absent from the serialized form.
 #[serde(default)]
 pub anchor_correction: Option<Vec<Vec<f64>>>,
 /// Anchor components; empty when absent from the serialized form.
 #[serde(default)]
 pub anchor_components: Vec<SavedAnchorComponent>,
}
957
/// One entry of `SavedCompiledFlexBlock::anchor_components`; wraps a single anchor kind.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
 /// Which kind of anchor this component is.
 pub kind: SavedAnchorKind,
}
962
/// Kind of a saved anchor component. Both variants record a column count `ncols`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
 /// Anchor tied to a parametric block, together with its column count.
 Parametric {
 block: ParametricAnchorBlock,
 ncols: usize,
 },
 /// Anchor identified only by its column count.
 /// NOTE(review): presumably evaluated from a flex block; confirm at the evaluator.
 FlexEvaluation { ncols: usize },
}
975
/// In-memory bundle of what prediction needs from a saved model (not serialized).
/// All fields other than `model_class` and `likelihood` are optional.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
 /// Prediction-time model class.
 pub model_class: PredictModelClass,
 /// Likelihood specification.
 pub likelihood: LikelihoodSpec,
 /// Inverse link, when one applies.
 pub inverse_link: Option<InverseLink>,
 /// Link-wiggle settings, when present.
 pub link_wiggle: Option<SavedLinkWiggleRuntime>,
 /// Baseline time-wiggle settings, when present.
 pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
 /// Compiled score-warp block, when present.
 pub score_warp: Option<SavedCompiledFlexBlock>,
 /// Compiled link-deviation block, when present.
 pub link_deviation: Option<SavedCompiledFlexBlock>,
 /// Latent-z rank-INT calibration, when present.
 pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
 /// Latent-z conditional calibration, when present.
 pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
 /// Influence absorber width, when present.
 pub influence_absorber_width: Option<usize>,
}
1006
1007pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1008 fit.block_by_role(BlockRole::Location)
1009 .or_else(|| fit.block_by_role(BlockRole::Mean))
1010 .map(|block| block.beta.clone())
1011}
1012
1013pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1014 fit.block_by_role(BlockRole::Threshold)
1015 .or_else(|| fit.block_by_role(BlockRole::Location))
1016 .or_else(|| fit.block_by_role(BlockRole::Mean))
1017 .map(|block| block.beta.clone())
1018}
1019
1020pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021 fit.block_by_role(BlockRole::Scale)
1022 .map(|block| block.beta.clone())
1023}
1024
1025fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1034 use gam_problem::types::ResponseFamily;
1035 matches!(
1036 response,
1037 ResponseFamily::NegativeBinomial { .. }
1038 | ResponseFamily::Gamma
1039 | ResponseFamily::Beta { .. }
1040 | ResponseFamily::Tweedie { .. }
1041 )
1042}
1043
1044fn validate_location_scale_saved_fit(
1045 fit: &UnifiedFitResult,
1046 model_class: PredictModelClass,
1047 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1048) -> Result<(), FittedModelError> {
1049 let primary = match model_class {
1050 PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1054 gaussian_location_scale_mean_beta(fit)
1055 }
1056 PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1057 _ => None,
1058 }
1059 .ok_or_else(|| FittedModelError::MissingField {
1060 reason: match model_class {
1061 PredictModelClass::GaussianLocationScale => {
1062 "gaussian-location-scale saved fit is missing mean/location block".to_string()
1063 }
1064 PredictModelClass::DispersionLocationScale => {
1065 "dispersion-location-scale saved fit is missing mean/location block".to_string()
1066 }
1067 PredictModelClass::BinomialLocationScale => {
1068 "binomial-location-scale saved fit is missing threshold/location block".to_string()
1069 }
1070 _ => "location-scale saved fit is missing primary block".to_string(),
1071 },
1072 })?;
1073
1074 let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1075 reason: "location-scale saved fit is missing scale block".to_string(),
1076 })?;
1077 let expected =
1078 primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1079
1080 if let Some(cov) = fit.beta_covariance()
1081 && (cov.nrows() != expected || cov.ncols() != expected)
1082 {
1083 return Err(FittedModelError::SchemaMismatch {
1084 reason: format!(
1085 "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1086 cov.nrows(),
1087 cov.ncols(),
1088 expected,
1089 expected
1090 ),
1091 });
1092 }
1093 if let Some(cov) = fit.beta_covariance_corrected()
1094 && (cov.nrows() != expected || cov.ncols() != expected)
1095 {
1096 return Err(FittedModelError::SchemaMismatch {
1097 reason: format!(
1098 "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1099 cov.nrows(),
1100 cov.ncols(),
1101 expected,
1102 expected
1103 ),
1104 });
1105 }
1106 Ok(())
1107}
1108
1109fn validate_survival_saved_block_matches_payload(
1110 fit: &UnifiedFitResult,
1111 role: BlockRole,
1112 payload_beta: Option<&Vec<f64>>,
1113 label: &str,
1114) -> Result<usize, FittedModelError> {
1115 let block = fit
1116 .block_by_role(role)
1117 .ok_or_else(|| FittedModelError::MissingField {
1118 reason: format!("location-scale survival saved fit is missing {label} block"),
1119 })?;
1120 if let Some(saved) = payload_beta
1121 && block.beta.to_vec() != *saved
1122 {
1123 return Err(FittedModelError::SchemaMismatch {
1124 reason: format!(
1125 "location-scale survival saved {label} coefficients disagree with fit_result"
1126 ),
1127 });
1128 }
1129 Ok(block.beta.len())
1130}
1131
1132fn validate_survival_location_scale_saved_fit(
1133 payload: &FittedModelPayload,
1134 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1135) -> Result<(), FittedModelError> {
1136 let fit = payload
1137 .fit_result
1138 .as_ref()
1139 .ok_or_else(|| FittedModelError::MissingField {
1140 reason: "location-scale survival model is missing canonical fit_result payload"
1141 .to_string(),
1142 })?;
1143 let p_time = validate_survival_saved_block_matches_payload(
1144 fit,
1145 BlockRole::Time,
1146 payload.survival_beta_time.as_ref(),
1147 "time",
1148 )?;
1149 let p_threshold = validate_survival_saved_block_matches_payload(
1150 fit,
1151 BlockRole::Threshold,
1152 payload.survival_beta_threshold.as_ref(),
1153 "threshold",
1154 )?;
1155 let p_log_sigma = validate_survival_saved_block_matches_payload(
1156 fit,
1157 BlockRole::Scale,
1158 payload.survival_beta_log_sigma.as_ref(),
1159 "log-sigma",
1160 )?;
1161 let p_wiggle = match link_wiggle {
1162 Some(runtime) => {
1163 let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1164 FittedModelError::MissingField {
1165 reason: "location-scale survival saved fit is missing link-wiggle block"
1166 .to_string(),
1167 }
1168 })?;
1169 if block.beta.to_vec() != runtime.beta {
1170 return Err(FittedModelError::SchemaMismatch {
1171 reason:
1172 "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1173 .to_string(),
1174 });
1175 }
1176 runtime.beta.len()
1177 }
1178 None => {
1179 if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1180 return Err(FittedModelError::SchemaMismatch {
1181 reason:
1182 "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1183 .to_string(),
1184 });
1185 }
1186 0
1187 }
1188 };
1189 let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1190
1191 if let Some(cov) = fit.beta_covariance()
1192 && (cov.nrows() != expected || cov.ncols() != expected)
1193 {
1194 return Err(FittedModelError::SchemaMismatch {
1195 reason: format!(
1196 "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1197 cov.nrows(),
1198 cov.ncols(),
1199 expected,
1200 expected
1201 ),
1202 });
1203 }
1204 if let Some(cov) = fit.beta_covariance_corrected()
1205 && (cov.nrows() != expected || cov.ncols() != expected)
1206 {
1207 return Err(FittedModelError::SchemaMismatch {
1208 reason: format!(
1209 "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1210 cov.nrows(),
1211 cov.ncols(),
1212 expected,
1213 expected
1214 ),
1215 });
1216 }
1217 Ok(())
1218}
1219
/// Validates a saved Bernoulli marginal-slope fit.
///
/// Delegates to `validate_marginal_slope_saved_fit_impl` with family label
/// `"bernoulli"`, a base block count of 2, and the base role list
/// `"marginal, logslope"`; the other arguments are passed through unchanged.
fn validate_marginal_slope_saved_fit(
 fit: &UnifiedFitResult,
 score_warp: Option<&SavedCompiledFlexBlock>,
 link_deviation: Option<&SavedCompiledFlexBlock>,
 fit_label: &str,
) -> Result<(), FittedModelError> {
 validate_marginal_slope_saved_fit_impl(
 fit,
 score_warp,
 link_deviation,
 fit_label,
 "bernoulli",
 2,
 "marginal, logslope",
 )
}
1236
1237fn validate_survival_marginal_slope_saved_fit(
1238 fit: &UnifiedFitResult,
1239 score_warp: Option<&SavedCompiledFlexBlock>,
1240 link_deviation: Option<&SavedCompiledFlexBlock>,
1241 fit_label: &str,
1242) -> Result<(), FittedModelError> {
1243 validate_marginal_slope_saved_fit_impl(
1244 fit,
1245 score_warp,
1246 link_deviation,
1247 fit_label,
1248 "survival",
1249 3,
1250 "time, marginal, slope",
1251 )
1252}
1253
1254fn validate_marginal_slope_saved_fit_impl(
1262 fit: &UnifiedFitResult,
1263 score_warp: Option<&SavedCompiledFlexBlock>,
1264 link_deviation: Option<&SavedCompiledFlexBlock>,
1265 fit_label: &str,
1266 family_kind: &str,
1267 base_block_count: usize,
1268 base_block_role_list: &str,
1269) -> Result<(), FittedModelError> {
1270 let expected_blocks = base_block_count
1271 + usize::from(score_warp.is_some())
1272 + usize::from(link_deviation.is_some());
1273 if fit.blocks.len() != expected_blocks {
1274 let score_warp_suffix = if score_warp.is_some() {
1275 ", score-warp"
1276 } else {
1277 ""
1278 };
1279 let link_deviation_suffix = if link_deviation.is_some() {
1280 ", link-deviation"
1281 } else {
1282 ""
1283 };
1284 return Err(FittedModelError::SchemaMismatch {
1285 reason: format!(
1286 "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1287 fit.blocks.len(),
1288 ),
1289 });
1290 }
1291 if let Some(runtime) = score_warp {
1292 let beta = &fit.blocks[base_block_count].beta;
1293 if beta.len() != runtime.basis_dim {
1294 return Err(FittedModelError::SchemaMismatch {
1295 reason: format!(
1296 "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1297 beta.len(),
1298 runtime.basis_dim
1299 ),
1300 });
1301 }
1302 }
1303 if let Some(runtime) = link_deviation {
1304 let idx = base_block_count + usize::from(score_warp.is_some());
1305 let beta = &fit.blocks[idx].beta;
1306 if beta.len() != runtime.basis_dim {
1307 return Err(FittedModelError::SchemaMismatch {
1308 reason: format!(
1309 "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1310 beta.len(),
1311 runtime.basis_dim
1312 ),
1313 });
1314 }
1315 }
1316 Ok(())
1317}
1318
1319impl SavedLinkWiggleRuntime {
1320 fn validate_monotone_derivative(
1321 &self,
1322 q0: &Array1<f64>,
1323 ) -> Result<Array1<f64>, FittedModelError> {
1324 let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1331 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1332 let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1333 if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1334 return Err(FittedModelError::PayloadCorrupt {
1335 reason: format!(
1336 "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1337 ),
1338 });
1339 }
1340 Ok(dq_dq0)
1341 }
1342
1343 pub fn constrained_basis(
1344 &self,
1345 q0: &Array1<f64>,
1346 basis_options: BasisOptions,
1347 ) -> Result<Array2<f64>, FittedModelError> {
1348 let knot_arr = Array1::from_vec(self.knots.clone());
1349 let constrained = monotone_wiggle_basis_with_derivative_order(
1350 q0.view(),
1351 &knot_arr,
1352 self.degree,
1353 basis_options.derivative_order,
1354 )
1355 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1356 if constrained.ncols() != self.beta.len() {
1357 return Err(FittedModelError::SchemaMismatch {
1358 reason: format!(
1359 "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1360 self.beta.len(),
1361 constrained.ncols()
1362 ),
1363 });
1364 }
1365 Ok(constrained)
1366 }
1367
1368 pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1369 self.validate_monotone_derivative(q0)?;
1370 self.constrained_basis(q0, BasisOptions::value())
1371 }
1372
1373 pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1374 let q = Array1::from_vec(vec![q0]);
1375 let x = self.design(&q)?;
1376 if x.nrows() != 1 {
1377 return Err(FittedModelError::SchemaMismatch {
1378 reason: format!(
1379 "saved link-wiggle scalar evaluation expected 1 row, got {}",
1380 x.nrows()
1381 ),
1382 });
1383 }
1384 Ok(x.row(0).to_owned())
1385 }
1386
1387 pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1388 self.validate_monotone_derivative(q0)?;
1389 let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1390 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1391 Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1392 }
1393
1394 pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395 self.validate_monotone_derivative(q0)
1396 }
1397}
1398
1399impl SavedBaselineTimeWiggleRuntime {
1400 pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1401 validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1402 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1403 }
1404}
1405
1406impl SavedCompiledFlexBlock {
1407 pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1408 if self.kernel.is_empty() {
1409 return Err(FittedModelError::SchemaMismatch {
1410 reason: "saved anchored deviation runtime is missing the exact kernel marker"
1411 .to_string(),
1412 });
1413 }
1414 if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1415 return Err(FittedModelError::IncompatibleConfig {
1416 reason: format!(
1417 "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1418 self.kernel,
1419 crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1420 ),
1421 });
1422 }
1423 if self.basis_dim == 0 {
1424 return Err(FittedModelError::SchemaMismatch {
1425 reason: format!(
1426 "saved anchored deviation runtime basis_dim must be positive, got {}",
1427 self.basis_dim
1428 ),
1429 });
1430 }
1431 if self.breakpoints.len() < 2 {
1432 return Err(FittedModelError::SchemaMismatch {
1433 reason: format!(
1434 "saved anchored deviation runtime requires at least two breakpoints, got {}",
1435 self.breakpoints.len()
1436 ),
1437 });
1438 }
1439 for window in self.breakpoints.windows(2) {
1440 let left = window[0];
1441 let right = window[1];
1442 if !left.is_finite() || !right.is_finite() || right <= left {
1443 return Err(FittedModelError::PayloadCorrupt {
1444 reason: format!(
1445 "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1446 ),
1447 });
1448 }
1449 }
1450 let span_count = self.breakpoints.len() - 1;
1451 self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1452 self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1453 self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1454 self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1455 self.validate_c2_span_continuity()?;
1456 self.validate_anchor_residual_shape()?;
1457 Ok(())
1458 }
1459
1460 fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1461 let coeffs = match self.anchor_correction.as_ref() {
1462 Some(c) => c,
1463 None => {
1464 if !self.anchor_components.is_empty() {
1465 return Err(FittedModelError::SchemaMismatch {
1466 reason:
1467 "saved anchored deviation runtime has anchor_components but no anchor_correction"
1468 .to_string(),
1469 });
1470 }
1471 return Ok(());
1472 }
1473 };
1474 let d: usize = self
1475 .anchor_components
1476 .iter()
1477 .map(|c| match &c.kind {
1478 SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1479 SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1480 })
1481 .sum();
1482 if coeffs.len() != d {
1483 return Err(FittedModelError::SchemaMismatch {
1484 reason: format!(
1485 "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1486 coeffs.len(),
1487 d,
1488 ),
1489 });
1490 }
1491 for (i, row) in coeffs.iter().enumerate() {
1492 if row.len() != self.basis_dim {
1493 return Err(FittedModelError::SchemaMismatch {
1494 reason: format!(
1495 "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1496 i,
1497 row.len(),
1498 self.basis_dim,
1499 ),
1500 });
1501 }
1502 for (j, &v) in row.iter().enumerate() {
1503 if !v.is_finite() {
1504 return Err(FittedModelError::PayloadCorrupt {
1505 reason: format!(
1506 "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1507 ),
1508 });
1509 }
1510 }
1511 }
1512 Ok(())
1513 }
1514
1515 fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1516 const TOL: f64 = 1e-8;
1517 for span_idx in 1..self.breakpoints.len() - 1 {
1518 let left_span = span_idx - 1;
1519 let right_span = span_idx;
1520 let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1521 for basis_idx in 0..self.basis_dim {
1522 let left_value = self.span_c0[left_span][basis_idx]
1523 + self.span_c1[left_span][basis_idx] * width
1524 + self.span_c2[left_span][basis_idx] * width * width
1525 + self.span_c3[left_span][basis_idx] * width * width * width;
1526 let left_d1 = self.span_c1[left_span][basis_idx]
1527 + 2.0 * self.span_c2[left_span][basis_idx] * width
1528 + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1529 let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1530 + 6.0 * self.span_c3[left_span][basis_idx] * width;
1531 let right_value = self.span_c0[right_span][basis_idx];
1532 let right_d1 = self.span_c1[right_span][basis_idx];
1533 let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1534 if (left_value - right_value).abs() > TOL
1535 || (left_d1 - right_d1).abs() > TOL
1536 || (left_d2 - right_d2).abs() > TOL
1537 {
1538 return Err(FittedModelError::SchemaMismatch {
1539 reason: format!(
1540 "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1541 left_value - right_value,
1542 left_d1 - right_d1,
1543 left_d2 - right_d2
1544 ),
1545 });
1546 }
1547 }
1548 }
1549 Ok(())
1550 }
1551
1552 fn validate_coefficient_matrix(
1553 &self,
1554 matrix: &[Vec<f64>],
1555 label: &str,
1556 expected_rows: usize,
1557 ) -> Result<(), FittedModelError> {
1558 if matrix.len() != expected_rows {
1559 return Err(FittedModelError::SchemaMismatch {
1560 reason: format!(
1561 "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1562 matrix.len(),
1563 expected_rows
1564 ),
1565 });
1566 }
1567 for (row_idx, row) in matrix.iter().enumerate() {
1568 if row.len() != self.basis_dim {
1569 return Err(FittedModelError::SchemaMismatch {
1570 reason: format!(
1571 "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1572 row_idx,
1573 row.len(),
1574 self.basis_dim
1575 ),
1576 });
1577 }
1578 for (j, &value) in row.iter().enumerate() {
1579 if !value.is_finite() {
1580 return Err(FittedModelError::PayloadCorrupt {
1581 reason: format!(
1582 "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1583 ),
1584 });
1585 }
1586 }
1587 }
1588 Ok(())
1589 }
1590
1591 fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1592 let last_span = self.breakpoints.len() - 2;
1593 let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1594 self.span_c0[last_span][basis_idx]
1595 + self.span_c1[last_span][basis_idx] * width
1596 + self.span_c2[last_span][basis_idx] * width * width
1597 + self.span_c3[last_span][basis_idx] * width * width * width
1598 }
1599
1600 fn evaluate_span_polynomial_design(
1601 &self,
1602 values: &Array1<f64>,
1603 derivative_order: usize,
1604 ) -> Result<Array2<f64>, FittedModelError> {
1605 self.validate_exact_replay_contract()?;
1606 let (left_ep, right_ep) = self.support_interval()?;
1607 let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1608 for (row_idx, &value) in values.iter().enumerate() {
1609 if !value.is_finite() {
1610 return Err(FittedModelError::PayloadCorrupt {
1611 reason: format!(
1612 "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1613 ),
1614 });
1615 }
1616 if value < left_ep {
1617 if derivative_order == 0 {
1618 for basis_idx in 0..self.basis_dim {
1619 out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1620 }
1621 }
1622 continue;
1623 }
1624 if value > right_ep {
1625 if derivative_order == 0 {
1626 for basis_idx in 0..self.basis_dim {
1627 out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1628 }
1629 }
1630 continue;
1631 }
1632 let span_idx = self.left_biased_span_index_for(value)?;
1633 let t = value - self.breakpoints[span_idx];
1634 for basis_idx in 0..self.basis_dim {
1635 let c0 = self.span_c0[span_idx][basis_idx];
1636 let c1 = self.span_c1[span_idx][basis_idx];
1637 let c2 = self.span_c2[span_idx][basis_idx];
1638 let c3 = self.span_c3[span_idx][basis_idx];
1639 out[[row_idx, basis_idx]] = match derivative_order {
1640 0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1641 1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1642 2 => 2.0 * c2 + 6.0 * c3 * t,
1643 3 => 6.0 * c3,
1644 4 => 0.0,
1645 other => {
1646 return Err(FittedModelError::IncompatibleConfig {
1647 reason: format!(
1648 "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1649 ),
1650 });
1651 }
1652 };
1653 }
1654 }
1655 Ok(out)
1656 }
1657
1658 pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1659 self.validate_exact_replay_contract()?;
1660 Ok(self.breakpoints.clone())
1661 }
1662
1663 pub fn span_count(&self) -> Result<usize, FittedModelError> {
1664 Ok(self.breakpoints()?.windows(2).count())
1665 }
1666
1667 pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1668 let points = self.breakpoints()?;
1669 span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1670 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1671 }
1672
1673 fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1674 let mut span_idx = span_index_for_breakpoints(
1675 &self.breakpoints,
1676 value,
1677 "saved anchored deviation span lookup",
1678 )
1679 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1680 if span_idx > 0 && value == self.breakpoints[span_idx] {
1683 span_idx -= 1;
1684 }
1685 Ok(span_idx)
1686 }
1687
1688 pub fn local_cubic_on_span(
1689 &self,
1690 beta: &Array1<f64>,
1691 span_idx: usize,
1692 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1693 self.validate_exact_replay_contract()?;
1694 if beta.len() != self.basis_dim {
1695 return Err(FittedModelError::SchemaMismatch {
1696 reason: format!(
1697 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1698 beta.len(),
1699 self.basis_dim
1700 ),
1701 });
1702 }
1703 self.local_cubic_on_span_validated(beta, span_idx)
1704 }
1705
1706 fn local_cubic_on_span_validated(
1707 &self,
1708 beta: &Array1<f64>,
1709 span_idx: usize,
1710 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1711 let points = &self.breakpoints;
1712 if span_idx + 1 >= points.len() {
1713 return Err(FittedModelError::SchemaMismatch {
1714 reason: format!(
1715 "saved anchored deviation span index {} out of range for {} spans",
1716 span_idx,
1717 points.len() - 1
1718 ),
1719 });
1720 }
1721 let left = points[span_idx];
1722 let right = points[span_idx + 1];
1723 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1724 left,
1725 right,
1726 c0: self.span_c0[span_idx]
1727 .iter()
1728 .zip(beta.iter())
1729 .map(|(coeff, weight)| coeff * weight)
1730 .sum(),
1731 c1: self.span_c1[span_idx]
1732 .iter()
1733 .zip(beta.iter())
1734 .map(|(coeff, weight)| coeff * weight)
1735 .sum(),
1736 c2: self.span_c2[span_idx]
1737 .iter()
1738 .zip(beta.iter())
1739 .map(|(coeff, weight)| coeff * weight)
1740 .sum(),
1741 c3: self.span_c3[span_idx]
1742 .iter()
1743 .zip(beta.iter())
1744 .map(|(coeff, weight)| coeff * weight)
1745 .sum(),
1746 })
1747 }
1748
1749 pub fn basis_span_cubic(
1750 &self,
1751 span_idx: usize,
1752 basis_idx: usize,
1753 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1754 self.validate_exact_replay_contract()?;
1755 if basis_idx >= self.basis_dim {
1756 return Err(FittedModelError::SchemaMismatch {
1757 reason: format!(
1758 "saved anchored deviation basis index {} out of range for {} coefficients",
1759 basis_idx, self.basis_dim
1760 ),
1761 });
1762 }
1763 self.basis_span_cubic_validated(span_idx, basis_idx)
1764 }
1765
1766 fn basis_span_cubic_validated(
1767 &self,
1768 span_idx: usize,
1769 basis_idx: usize,
1770 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1771 let points = &self.breakpoints;
1772 if span_idx + 1 >= points.len() {
1773 return Err(FittedModelError::SchemaMismatch {
1774 reason: format!(
1775 "saved anchored deviation span index {} out of range for {} spans",
1776 span_idx,
1777 points.len() - 1
1778 ),
1779 });
1780 }
1781 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1782 left: points[span_idx],
1783 right: points[span_idx + 1],
1784 c0: self.span_c0[span_idx][basis_idx],
1785 c1: self.span_c1[span_idx][basis_idx],
1786 c2: self.span_c2[span_idx][basis_idx],
1787 c3: self.span_c3[span_idx][basis_idx],
1788 })
1789 }
1790
1791 pub fn basis_cubic_at(
1792 &self,
1793 basis_idx: usize,
1794 value: f64,
1795 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1796 self.validate_exact_replay_contract()?;
1797 if basis_idx >= self.basis_dim {
1798 return Err(FittedModelError::SchemaMismatch {
1799 reason: format!(
1800 "saved anchored deviation basis index {} out of range for {} coefficients",
1801 basis_idx, self.basis_dim
1802 ),
1803 });
1804 }
1805 let (left_ep, right_ep) = self.support_interval()?;
1806 if value < left_ep {
1807 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1808 left: left_ep,
1809 right: left_ep + 1.0,
1810 c0: self.span_c0[0][basis_idx],
1811 c1: 0.0,
1812 c2: 0.0,
1813 c3: 0.0,
1814 });
1815 }
1816 if value > right_ep {
1817 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1818 left: right_ep,
1819 right: right_ep + 1.0,
1820 c0: self.right_boundary_basis_value(basis_idx),
1821 c1: 0.0,
1822 c2: 0.0,
1823 c3: 0.0,
1824 });
1825 }
1826 let span_idx = self.left_biased_span_index_for(value)?;
1827 self.basis_span_cubic_validated(span_idx, basis_idx)
1828 }
1829
1830 pub fn local_cubic_at(
1831 &self,
1832 beta: &Array1<f64>,
1833 value: f64,
1834 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1835 self.validate_exact_replay_contract()?;
1836 if beta.len() != self.basis_dim {
1837 return Err(FittedModelError::SchemaMismatch {
1838 reason: format!(
1839 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1840 beta.len(),
1841 self.basis_dim
1842 ),
1843 });
1844 }
1845 let (left_ep, right_ep) = self.support_interval()?;
1846 if value < left_ep {
1847 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1848 left: left_ep,
1849 right: left_ep + 1.0,
1850 c0: self.span_c0[0]
1851 .iter()
1852 .zip(beta.iter())
1853 .map(|(coeff, weight)| coeff * weight)
1854 .sum(),
1855 c1: 0.0,
1856 c2: 0.0,
1857 c3: 0.0,
1858 });
1859 }
1860 if value > right_ep {
1861 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1862 left: right_ep,
1863 right: right_ep + 1.0,
1864 c0: (0..self.basis_dim)
1865 .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1866 .sum(),
1867 c1: 0.0,
1868 c2: 0.0,
1869 c3: 0.0,
1870 });
1871 }
1872 let span_idx = self.left_biased_span_index_for(value)?;
1873 self.local_cubic_on_span_validated(beta, span_idx)
1874 }
1875
1876 fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1877 let points = self.breakpoints()?;
1878 match (points.first(), points.last()) {
1879 (Some(&left), Some(&right)) => Ok((left, right)),
1880 _ => Err(FittedModelError::MissingField {
1881 reason: "saved anchored deviation runtime is missing support breakpoints"
1882 .to_string(),
1883 }),
1884 }
1885 }
1886
1887 pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1888 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1897 }
1898
1899 pub fn design_uncorrected(
1907 &self,
1908 values: &Array1<f64>,
1909 ) -> Result<Array2<f64>, FittedModelError> {
1910 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1911 }
1912
1913 pub fn design_with_anchor_rows(
1922 &self,
1923 values: &Array1<f64>,
1924 anchor_rows: ndarray::ArrayView2<f64>,
1925 ) -> Result<Array2<f64>, FittedModelError> {
1926 let mut out =
1927 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1928 if let Some(m_rows) = self.anchor_correction.as_ref() {
1929 let d = m_rows.len();
1930 if anchor_rows.nrows() != values.len() {
1931 return Err(FittedModelError::SchemaMismatch {
1932 reason: format!(
1933 "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1934 anchor_rows.nrows(),
1935 values.len(),
1936 ),
1937 });
1938 }
1939 if anchor_rows.ncols() != d {
1940 return Err(FittedModelError::SchemaMismatch {
1941 reason: format!(
1942 "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1943 anchor_rows.ncols(),
1944 d,
1945 ),
1946 });
1947 }
1948 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1950 for (i, row) in m_rows.iter().enumerate() {
1951 if row.len() != self.basis_dim {
1952 return Err(FittedModelError::SchemaMismatch {
1953 reason: format!(
1954 "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1955 i,
1956 row.len(),
1957 self.basis_dim,
1958 ),
1959 });
1960 }
1961 for (j, &v) in row.iter().enumerate() {
1962 m_dense[[i, j]] = v;
1963 }
1964 }
1965 let subtract = anchor_rows.dot(&m_dense);
1968 out = out - subtract;
1969 } else if anchor_rows.ncols() != 0 {
1970 return Err(FittedModelError::SchemaMismatch {
1971 reason: format!(
1972 "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1973 anchor_rows.ncols(),
1974 ),
1975 });
1976 }
1977 Ok(out)
1978 }
1979
1980 pub fn anchor_correction_matrix(
1988 &self,
1989 n_anchor_rows: ndarray::ArrayView2<f64>,
1990 ) -> Result<Option<Array2<f64>>, FittedModelError> {
1991 let Some(m_rows) = self.anchor_correction.as_ref() else {
1992 return Ok(None);
1993 };
1994 let d = m_rows.len();
1995 if n_anchor_rows.ncols() != d {
1996 return Err(FittedModelError::SchemaMismatch {
1997 reason: format!(
1998 "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
1999 n_anchor_rows.ncols(),
2000 d,
2001 ),
2002 });
2003 }
2004 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2005 for (i, row) in m_rows.iter().enumerate() {
2006 if row.len() != self.basis_dim {
2007 return Err(FittedModelError::SchemaMismatch {
2008 reason: format!(
2009 "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2010 i,
2011 row.len(),
2012 self.basis_dim,
2013 ),
2014 });
2015 }
2016 for (j, &v) in row.iter().enumerate() {
2017 m_dense[[i, j]] = v;
2018 }
2019 }
2020 Ok(Some(n_anchor_rows.dot(&m_dense)))
2023 }
2024
2025 pub fn first_derivative_design(
2026 &self,
2027 values: &Array1<f64>,
2028 ) -> Result<Array2<f64>, FittedModelError> {
2029 self.evaluate_span_polynomial_design(
2030 values,
2031 BasisOptions::first_derivative().derivative_order,
2032 )
2033 }
2034
2035 pub fn second_derivative_design(
2036 &self,
2037 values: &Array1<f64>,
2038 ) -> Result<Array2<f64>, FittedModelError> {
2039 self.evaluate_span_polynomial_design(
2040 values,
2041 BasisOptions::second_derivative().derivative_order,
2042 )
2043 }
2044}
2045
2046impl FittedFamily {
2047 #[inline]
2048 pub fn likelihood(&self) -> LikelihoodSpec {
2049 let spec = match self {
2050 Self::Standard { likelihood, .. }
2051 | Self::LocationScale { likelihood, .. }
2052 | Self::MarginalSlope { likelihood, .. }
2053 | Self::Survival { likelihood, .. }
2054 | Self::TransformationNormal { likelihood, .. } => likelihood,
2055 Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2056 return LikelihoodSpec::royston_parmar();
2057 }
2058 };
2059 spec.clone()
2060 }
2061
2062 #[inline]
2063 pub fn frailty(&self) -> Option<&FrailtySpec> {
2064 match self {
2065 Self::MarginalSlope { frailty, .. }
2066 | Self::Survival { frailty, .. }
2067 | Self::LatentSurvival { frailty }
2068 | Self::LatentBinary { frailty } => Some(frailty),
2069 _ => None,
2070 }
2071 }
2072}
2073
2074fn collect_smooth_extrapolation_axes(
2080 basis: &gam_terms::smooth::SmoothBasisSpec,
2081 n_training_headers: usize,
2082 out: &mut std::collections::HashSet<usize>,
2083) {
2084 use gam_terms::smooth::SmoothBasisSpec;
2085 let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2086 if col < n_training_headers {
2087 out.insert(col);
2088 }
2089 };
2090 match basis {
2091 SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2093 SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2096 for &c in feature_cols {
2097 push(c, out);
2098 }
2099 }
2100 SmoothBasisSpec::ThinPlate { feature_cols, .. }
2107 | SmoothBasisSpec::Matern { feature_cols, .. }
2108 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2109 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2110 for &c in feature_cols {
2111 push(c, out);
2112 }
2113 }
2114 SmoothBasisSpec::FactorSmooth { spec } => {
2118 for &c in &spec.continuous_cols {
2119 push(c, out);
2120 }
2121 }
2122 SmoothBasisSpec::ByVariable { inner, .. }
2124 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2125 collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2126 }
2127 SmoothBasisSpec::BySmooth { smooth, .. } => {
2128 collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2129 }
2130 SmoothBasisSpec::Sphere { .. }
2136 | SmoothBasisSpec::ConstantCurvature { .. }
2137 | SmoothBasisSpec::Pca { .. } => {}
2138 }
2139}
2140
2141fn collect_by_variable_numeric_axes(
2160 basis: &gam_terms::smooth::SmoothBasisSpec,
2161 n_training_headers: usize,
2162 out: &mut std::collections::HashSet<usize>,
2163) {
2164 use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2165 match basis {
2166 SmoothBasisSpec::ByVariable {
2167 inner,
2168 by_col,
2169 kind,
2170 ..
2171 } => {
2172 if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2173 out.insert(*by_col);
2174 }
2175 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2176 }
2177 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2178 if let ByVarKind::Numeric { feature_col } = by_kind
2179 && *feature_col < n_training_headers
2180 {
2181 out.insert(*feature_col);
2182 }
2183 collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2184 }
2185 SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2186 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2187 }
2188 _ => {}
2189 }
2190}
2191
2192impl FittedModel {
2193 pub fn axis_clip_to_training_ranges(
2202 &self,
2203 data: ndarray::ArrayView2<'_, f64>,
2204 col_map: &std::collections::HashMap<String, usize>,
2205 ) -> Option<ndarray::Array2<f64>> {
2206 let training_headers = self.training_headers.as_ref()?;
2207 let ranges = self.training_feature_ranges.as_ref()?;
2208 if training_headers.len() != ranges.len() {
2209 return None;
2210 }
2211 let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2212 std::collections::HashMap::new();
2213 if let Some(schema) = self.data_schema.as_ref() {
2214 for col in &schema.columns {
2215 kind_by_header.insert(col.name.as_str(), col.kind);
2216 }
2217 }
2218 let periodic_axes = self.training_periodic_axes(training_headers);
2224 let linear_axes = self.training_linear_axes(training_headers.len());
2232 let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2237 let smooth_extrapolation_axes =
2245 self.training_smooth_extrapolation_axes(training_headers.len());
2246 let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2252 let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2257 let mut clipped = data.to_owned();
2258 let mut any_clipped = false;
2259 for (col_in_training, (header, &(lo, hi))) in
2260 training_headers.iter().zip(ranges.iter()).enumerate()
2261 {
2262 let (lo, hi) = sphere_lat_bounds
2263 .get(&col_in_training)
2264 .copied()
2265 .unwrap_or((lo, hi));
2266 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2267 continue;
2268 }
2269 if !matches!(
2270 kind_by_header.get(header.as_str()).copied(),
2271 Some(ColumnKindTag::Continuous)
2272 ) {
2273 continue;
2274 }
2275 if periodic_axes.contains(&col_in_training) {
2276 continue;
2277 }
2278 if linear_axes.contains(&col_in_training) {
2279 continue;
2280 }
2281 if random_effect_axes.contains(&col_in_training) {
2282 continue;
2283 }
2284 if smooth_extrapolation_axes.contains(&col_in_training) {
2285 continue;
2286 }
2287 if by_variable_axes.contains(&col_in_training) {
2288 continue;
2289 }
2290 let Some(&col_idx) = col_map.get(header) else {
2291 continue;
2292 };
2293 if col_idx >= clipped.ncols() {
2294 continue;
2295 }
2296 let mut col = clipped.column_mut(col_idx);
2297 for v in col.iter_mut() {
2298 if v.is_finite() {
2299 if *v < lo {
2300 *v = lo;
2301 any_clipped = true;
2302 } else if *v > hi {
2303 *v = hi;
2304 any_clipped = true;
2305 }
2306 }
2307 }
2308 }
2309 if any_clipped { Some(clipped) } else { None }
2310 }
2311
2312 fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2313 let mut specs: Vec<&TermCollectionSpec> = [
2314 self.resolved_termspec.as_ref(),
2315 self.resolved_termspec_noise.as_ref(),
2316 self.resolved_termspec_logslope.as_ref(),
2317 ]
2318 .into_iter()
2319 .flatten()
2320 .collect();
2321 if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2322 specs.extend(logslopes.iter());
2323 }
2324 specs
2325 }
2326
2327 fn training_periodic_axes(
2334 &self,
2335 training_headers: &[String],
2336 ) -> std::collections::HashSet<usize> {
2337 use gam_terms::basis::BSplineKnotSpec;
2338 use gam_terms::smooth::SmoothBasisSpec;
2339 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2340 let Some(spec) = self.resolved_termspec.as_ref() else {
2341 return out;
2342 };
2343 for term in &spec.smooth_terms {
2344 match &term.basis {
2345 SmoothBasisSpec::Sphere { feature_cols, .. } => {
2351 if let Some(&lon_col) = feature_cols.get(1)
2352 && lon_col < training_headers.len()
2353 {
2354 out.insert(lon_col);
2355 }
2356 }
2357 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2359 if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2360 && *feature_col < training_headers.len()
2361 {
2362 out.insert(*feature_col);
2363 }
2364 }
2365 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2368 for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2369 if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2370 && let Some(&col) = feature_cols.get(i)
2371 && col < training_headers.len()
2372 {
2373 out.insert(col);
2374 }
2375 }
2376 }
2377 _ => {}
2378 }
2379 }
2380 out
2381 }
2382
2383 fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2394 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2395 for spec in self.saved_term_specs() {
2396 for term in &spec.linear_terms {
2397 for col in term.effective_feature_cols() {
2398 if col < n_training_headers {
2399 out.insert(col);
2400 }
2401 }
2402 }
2403 }
2404 out
2405 }
2406
2407 fn training_random_effect_axes(
2413 &self,
2414 n_training_headers: usize,
2415 ) -> std::collections::HashSet<usize> {
2416 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2417 for spec in self.saved_term_specs() {
2418 for term in &spec.random_effect_terms {
2419 if term.feature_col < n_training_headers {
2420 out.insert(term.feature_col);
2421 }
2422 }
2423 }
2424 out
2425 }
2426
2427 fn training_smooth_extrapolation_axes(
2460 &self,
2461 n_training_headers: usize,
2462 ) -> std::collections::HashSet<usize> {
2463 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2464 for spec in self.saved_term_specs() {
2465 for term in &spec.smooth_terms {
2466 collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2467 }
2468 }
2469 out
2470 }
2471
2472 fn training_by_variable_numeric_axes(
2478 &self,
2479 n_training_headers: usize,
2480 ) -> std::collections::HashSet<usize> {
2481 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2482 for spec in self.saved_term_specs() {
2483 for term in &spec.smooth_terms {
2484 collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2485 }
2486 }
2487 out
2488 }
2489
2490 fn training_sphere_latitude_bounds(
2512 &self,
2513 training_headers: &[String],
2514 ) -> std::collections::HashMap<usize, (f64, f64)> {
2515 use gam_terms::smooth::SmoothBasisSpec;
2516 let mut out: std::collections::HashMap<usize, (f64, f64)> =
2517 std::collections::HashMap::new();
2518 let Some(spec) = self.resolved_termspec.as_ref() else {
2519 return out;
2520 };
2521 for term in &spec.smooth_terms {
2522 if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2523 && let Some(&lat_col) = feature_cols.first()
2524 && lat_col < training_headers.len()
2525 {
2526 let bound = if spec.radians {
2527 std::f64::consts::FRAC_PI_2
2528 } else {
2529 90.0
2530 };
2531 out.insert(lat_col, (-bound, bound));
2532 }
2533 }
2534 out
2535 }
2536
2537 pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2538 let likelihood = payload.family_state.likelihood();
2539 let class = match payload.model_kind {
2540 ModelKind::Survival => PredictModelClass::Survival,
2541 ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2542 ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2543 ModelKind::LocationScale => {
2544 if likelihood == LikelihoodSpec::gaussian_identity() {
2545 PredictModelClass::GaussianLocationScale
2546 } else if is_dispersion_location_scale_response(&likelihood.response) {
2547 PredictModelClass::DispersionLocationScale
2548 } else {
2549 PredictModelClass::BinomialLocationScale
2550 }
2551 }
2552 ModelKind::Standard => PredictModelClass::Standard,
2553 };
2554 match class {
2555 PredictModelClass::Survival => {
2556 payload.model_kind = ModelKind::Survival;
2557 Self::Survival { payload }
2558 }
2559 PredictModelClass::BernoulliMarginalSlope => {
2560 payload.model_kind = ModelKind::MarginalSlope;
2561 Self::MarginalSlope { payload }
2562 }
2563 PredictModelClass::TransformationNormal => {
2564 payload.model_kind = ModelKind::TransformationNormal;
2565 Self::TransformationNormal { payload }
2566 }
2567 PredictModelClass::GaussianLocationScale
2568 | PredictModelClass::BinomialLocationScale
2569 | PredictModelClass::DispersionLocationScale => {
2570 payload.model_kind = ModelKind::LocationScale;
2571 Self::LocationScale { payload }
2572 }
2573 PredictModelClass::Standard => {
2574 payload.model_kind = ModelKind::Standard;
2575 Self::Standard { payload }
2576 }
2577 }
2578 .with_synchronized_stateful_link_metadata()
2579 }
2580
2581 #[inline]
2582 pub fn payload(&self) -> &FittedModelPayload {
2583 match self {
2584 Self::Standard { payload }
2585 | Self::LocationScale { payload }
2586 | Self::MarginalSlope { payload }
2587 | Self::Survival { payload }
2588 | Self::TransformationNormal { payload } => payload,
2589 }
2590 }
2591
2592 #[inline]
2593 fn payload_mut(&mut self) -> &mut FittedModelPayload {
2594 match self {
2595 Self::Standard { payload }
2596 | Self::LocationScale { payload }
2597 | Self::MarginalSlope { payload }
2598 | Self::Survival { payload }
2599 | Self::TransformationNormal { payload } => payload,
2600 }
2601 }
2602
2603 fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2604 self.synchronize_stateful_link_metadata();
2605 self
2606 }
2607
/// Refreshes payload fields that are derived from the stored fit.
///
/// Steps, in order:
/// 1. `used_device` is recomputed from `fit_result` when present, otherwise
///    from `unified`; it is `false` when neither is stored.
/// 2. `synchronize_empty_feature_contract` is invoked on the payload.
/// 3. When a `fit_result` exists and the family is `Standard`, the fitted
///    link state is copied into the matching `family_state` slot (and, for
///    SAS / beta-logistic / mixture links, the link-parameter covariance is
///    copied into the payload). Nothing is copied when the likelihood does
///    not match the fitted link kind.
fn synchronize_stateful_link_metadata(&mut self) {
    let payload = self.payload_mut();
    // Prefer the canonical fit_result; fall back to the unified fit.
    payload.used_device = payload
        .fit_result
        .as_ref()
        .or(payload.unified.as_ref())
        .is_some_and(|fit| fit.used_device);
    payload.synchronize_empty_feature_contract();
    // Without a fit_result there is no fitted link state to copy.
    let Some(fit) = payload.fit_result.as_ref() else {
        return;
    };
    match (&mut payload.family_state, &fit.fitted_link) {
        (
            FittedFamily::Standard {
                likelihood,
                latent_cloglog_state,
                ..
            },
            FittedLinkState::LatentCLogLog { state },
        ) if likelihood.is_latent_cloglog() => {
            *latent_cloglog_state = Some(*state);
        }
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::Sas { state, covariance },
        ) if likelihood.is_binomial_sas() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::BetaLogistic { state, covariance },
        ) if likelihood.is_binomial_beta_logistic() => {
            // Beta-logistic reuses the `sas_state` / `sas_param_covariance` slots.
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        (
            FittedFamily::Standard {
                likelihood,
                mixture_state,
                ..
            },
            FittedLinkState::Mixture { state, covariance },
        ) if likelihood.is_binomial_mixture() => {
            *mixture_state = Some(state.clone());
            payload.mixture_link_param_covariance =
                covariance.as_ref().map(array2_to_nestedvec);
        }
        // Any other family / fitted-link pairing leaves the payload untouched.
        _ => {}
    }
}
2667
2668 #[inline]
2669 pub fn likelihood(&self) -> LikelihoodSpec {
2670 self.payload().family_state.likelihood()
2671 }
2672
2673 pub fn prediction_required_columns(
2691 &self,
2692 ) -> Result<std::collections::BTreeSet<String>, String> {
2693 let payload = self.payload();
2694 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2695 let mut required = std::collections::BTreeSet::<String>::new();
2696 parsed_term_column_names(&parsed.terms, &mut required);
2697
2698 if let Some((entry, exit, _event)) =
2699 parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2700 {
2701 if let Some(entry) = entry {
2702 required.insert(entry);
2703 }
2704 required.insert(exit);
2705 } else if let Some((left, right, _event)) =
2706 parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707 {
2708 required.insert(left);
2709 required.insert(right);
2710 }
2711 if let Some(offset) = payload.offset_column.as_ref() {
2719 required.insert(offset.clone());
2720 }
2721 if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2722 required.insert(noise_offset.clone());
2723 }
2724 if matches!(
2725 self.predict_model_class(),
2726 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2727 ) {
2728 if let Some(z_column) = payload.z_column.as_ref() {
2729 required.remove("z");
2730 required.insert(z_column.clone());
2731 }
2732 }
2733 if let Some(noise_formula) = payload.formula_noise.as_ref() {
2734 self.add_auxiliary_formula_columns(
2735 &mut required,
2736 noise_formula,
2737 parsed.response.as_str(),
2738 )?;
2739 }
2740 if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2741 if logslope_formula != "same-as-main" {
2742 self.add_auxiliary_formula_columns(
2743 &mut required,
2744 logslope_formula,
2745 parsed.response.as_str(),
2746 )?;
2747 }
2748 }
2749 Ok(required)
2750 }
2751
2752 pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2769 let payload = self.payload();
2770 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2771 if parse_surv_response(parsed.response.as_str())
2774 .map_err(|e| e.to_string())?
2775 .is_some()
2776 || parse_surv_interval_response(parsed.response.as_str())
2777 .map_err(|e| e.to_string())?
2778 .is_some()
2779 {
2780 return Ok(Vec::new());
2781 }
2782 let response = parsed.response.trim();
2783 if response.is_empty() || response.contains('(') {
2786 return Ok(Vec::new());
2787 }
2788 if self.prediction_required_columns()?.contains(response) {
2791 return Ok(Vec::new());
2792 }
2793 Ok(vec![response.to_string()])
2794 }
2795
2796 fn add_auxiliary_formula_columns(
2799 &self,
2800 required: &mut std::collections::BTreeSet<String>,
2801 formula_or_rhs: &str,
2802 response: &str,
2803 ) -> Result<(), String> {
2804 let trimmed = formula_or_rhs.trim();
2805 if trimmed.is_empty() || trimmed == "1" {
2806 return Ok(());
2807 }
2808 let formula = if trimmed.contains('~') {
2809 trimmed.to_string()
2810 } else {
2811 format!("{response} ~ {trimmed}")
2812 };
2813 let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2814 parsed_term_column_names(&parsed.terms, required);
2815 Ok(())
2816 }
2817
2818 #[inline]
2819 pub fn predict_model_class(&self) -> PredictModelClass {
2820 match &self.payload().family_state {
2821 FittedFamily::Survival { .. }
2822 | FittedFamily::LatentSurvival { .. }
2823 | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2824 FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2825 FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2826 FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2827 PredictModelClass::GaussianLocationScale
2828 }
2829 FittedFamily::LocationScale { likelihood, .. }
2830 if is_dispersion_location_scale_response(&likelihood.response) =>
2831 {
2832 PredictModelClass::DispersionLocationScale
2833 }
2834 FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2835 FittedFamily::Standard { .. } => PredictModelClass::Standard,
2836 }
2837 }
2838
2839 pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2840 let payload = self.payload();
2841 let (knots, degree) = match (
2842 payload.linkwiggle_knots.as_ref(),
2843 payload.linkwiggle_degree,
2844 ) {
2845 (None, None) => return Ok(None),
2846 (Some(knots), Some(degree)) => (knots.clone(), degree),
2847 _ => {
2848 return Err(FittedModelError::SchemaMismatch {
2849 reason:
2850 "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2851 .to_string(),
2852 })
2853 }
2854 };
2855 let resolved_link = self.resolved_inverse_link()?;
2856 let saved_link_disallows_wiggle = resolved_link
2857 .as_ref()
2858 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2859 || payload
2860 .link
2861 .as_ref()
2862 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2863 if saved_link_disallows_wiggle {
2864 return Err(FittedModelError::IncompatibleConfig {
2865 reason: joint_wiggle_unsupported_link_message("link wiggle"),
2866 });
2867 }
2868 let beta = match self.predict_model_class() {
2869 PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2877 payload.beta_link_wiggle.clone().expect("checked is_some")
2878 }
2879 PredictModelClass::Standard => {
2880 let fit = payload.fit_result.as_ref().ok_or_else(|| {
2881 FittedModelError::MissingField {
2882 reason:
2883 "standard link-wiggle model is missing canonical fit_result payload"
2884 .to_string(),
2885 }
2886 })?;
2887 if fit.blocks.len() != 2
2888 || fit.blocks[0].role != BlockRole::Mean
2889 || fit.blocks[1].role != BlockRole::LinkWiggle
2890 {
2891 return Err(FittedModelError::SchemaMismatch {
2892 reason:
2893 "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2894 .to_string(),
2895 });
2896 }
2897 fit.block_by_role(BlockRole::LinkWiggle)
2898 .ok_or_else(|| FittedModelError::MissingField {
2899 reason:
2900 "standard link-wiggle model is missing LinkWiggle coefficient block"
2901 .to_string(),
2902 })?
2903 .beta
2904 .to_vec()
2905 }
2906 _ => payload
2907 .beta_link_wiggle
2908 .clone()
2909 .ok_or_else(|| FittedModelError::MissingField {
2910 reason:
2911 "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2912 .to_string(),
2913 })?,
2914 };
2915 Ok(Some(SavedLinkWiggleRuntime {
2916 knots,
2917 degree,
2918 beta,
2919 }))
2920 }
2921
2922 pub fn saved_baseline_time_wiggle(
2923 &self,
2924 ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2925 let payload = self.payload();
2926 if payload
2927 .survival_cause_count
2928 .is_some_and(|cause_count| cause_count > 1)
2929 && payload.beta_baseline_timewiggle.is_none()
2930 && payload.beta_baseline_timewiggle_by_cause.is_some()
2931 {
2932 return Err(FittedModelError::SchemaMismatch {
2933 reason:
2934 "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2935 .to_string(),
2936 });
2937 }
2938 match (
2939 payload.baseline_timewiggle_knots.as_ref(),
2940 payload.baseline_timewiggle_degree,
2941 payload.baseline_timewiggle_penalty_orders.as_ref(),
2942 payload.baseline_timewiggle_double_penalty,
2943 payload.beta_baseline_timewiggle.as_ref(),
2944 ) {
2945 (None, None, None, None, None) => Ok(None),
2946 (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2947 Ok(Some(SavedBaselineTimeWiggleRuntime {
2948 knots: knots.clone(),
2949 degree,
2950 penalty_orders: penalty_orders.clone(),
2951 double_penalty,
2952 beta: beta.clone(),
2953 }))
2954 }
2955 _ => Err(FittedModelError::SchemaMismatch {
2956 reason:
2957 "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2958 .to_string(),
2959 }),
2960 }
2961 }
2962
2963 #[inline]
2965 pub fn has_link_wiggle(&self) -> bool {
2966 self.saved_link_wiggle()
2967 .map(|runtime| runtime.is_some())
2968 .unwrap_or(false)
2969 }
2970
2971 #[inline]
2973 pub fn has_baseline_time_wiggle(&self) -> bool {
2974 let payload = self.payload();
2975 if payload
2976 .survival_cause_count
2977 .is_some_and(|cause_count| cause_count > 1)
2978 {
2979 return payload.baseline_timewiggle_knots.is_some()
2980 && payload.baseline_timewiggle_degree.is_some()
2981 && payload.baseline_timewiggle_penalty_orders.is_some()
2982 && payload.baseline_timewiggle_double_penalty.is_some()
2983 && payload.beta_baseline_timewiggle_by_cause.is_some();
2984 }
2985 self.saved_baseline_time_wiggle()
2986 .map(|runtime| runtime.is_some())
2987 .unwrap_or(false)
2988 }
2989
2990 #[inline]
3015 pub fn prediction_uses_posterior_mean(&self) -> bool {
3016 let family = self.likelihood();
3017 let curved_family = match &family.response {
3018 ResponseFamily::Gaussian => false,
3021 ResponseFamily::Poisson
3023 | ResponseFamily::Gamma
3024 | ResponseFamily::Tweedie { .. }
3025 | ResponseFamily::NegativeBinomial { .. } => true,
3026 ResponseFamily::Beta { .. } => true,
3028 ResponseFamily::RoystonParmar => true,
3030 ResponseFamily::Binomial => matches!(
3033 &family.link,
3034 InverseLink::Standard(_)
3035 | InverseLink::Sas(_)
3036 | InverseLink::BetaLogistic(_)
3037 | InverseLink::Mixture(_)
3038 | InverseLink::LatentCLogLog(_)
3039 ),
3040 };
3041 curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3042 }
3043
3044 pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3045 self.payload().validate_payload_version()?;
3046 if matches!(
3047 self.predict_model_class(),
3048 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3049 ) {
3050 if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3051 runtime.validate_exact_replay_contract().map_err(|err| {
3052 FittedModelError::PayloadCorrupt {
3053 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3054 }
3055 })?;
3056 }
3057 if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3058 runtime.validate_exact_replay_contract().map_err(|err| {
3059 FittedModelError::PayloadCorrupt {
3060 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3061 }
3062 })?;
3063 }
3064 }
3065 let runtime = SavedPredictionRuntime {
3066 model_class: self.predict_model_class(),
3067 likelihood: self.likelihood(),
3068 inverse_link: self.resolved_inverse_link()?,
3069 link_wiggle: self.saved_link_wiggle()?,
3070 baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3071 score_warp: self.payload().score_warp_runtime.clone(),
3072 link_deviation: self.payload().link_deviation_runtime.clone(),
3073 latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3074 latent_z_conditional_calibration: self
3075 .payload()
3076 .latent_z_conditional_calibration
3077 .clone(),
3078 influence_absorber_width: self.payload().influence_absorber_width,
3079 };
3080 if matches!(
3081 runtime.model_class,
3082 PredictModelClass::GaussianLocationScale
3083 | PredictModelClass::BinomialLocationScale
3084 | PredictModelClass::DispersionLocationScale
3085 ) {
3086 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3087 FittedModelError::MissingField {
3088 reason: "location-scale model is missing canonical fit_result payload"
3089 .to_string(),
3090 }
3091 })?;
3092 validate_location_scale_saved_fit(
3093 fit,
3094 runtime.model_class,
3095 runtime.link_wiggle.as_ref(),
3096 )?;
3097 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3098 && self
3099 .payload()
3100 .survival_likelihood
3101 .as_deref()
3102 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3103 {
3104 validate_survival_location_scale_saved_fit(
3105 self.payload(),
3106 runtime.link_wiggle.as_ref(),
3107 )?;
3108 } else if matches!(
3109 runtime.model_class,
3110 PredictModelClass::BernoulliMarginalSlope
3111 ) {
3112 let unified =
3113 self.payload()
3114 .unified
3115 .as_ref()
3116 .ok_or_else(|| FittedModelError::MissingField {
3117 reason: "marginal-slope model is missing unified fit payload; refit"
3118 .to_string(),
3119 })?;
3120 validate_marginal_slope_saved_fit(
3121 unified,
3122 runtime.score_warp.as_ref(),
3123 runtime.link_deviation.as_ref(),
3124 "unified",
3125 )?;
3126 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3127 && self
3128 .payload()
3129 .survival_likelihood
3130 .as_deref()
3131 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3132 {
3133 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3134 FittedModelError::MissingField {
3135 reason: "survival marginal-slope model is missing canonical fit_result payload"
3136 .to_string(),
3137 }
3138 })?;
3139 validate_survival_marginal_slope_saved_fit(
3140 fit,
3141 runtime.score_warp.as_ref(),
3142 runtime.link_deviation.as_ref(),
3143 "fit_result",
3144 )?;
3145 }
3146 Ok(runtime)
3147 }
3148
3149 pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3150 let payload = self.payload();
3151 let raw = match &payload.family_state {
3152 FittedFamily::Standard {
3153 likelihood,
3154 sas_state,
3155 ..
3156 } if likelihood.is_binomial_sas() => {
3157 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3158 reason: "binomial-sas model is missing state in family_state.sas_state"
3159 .to_string(),
3160 })?
3161 }
3162 FittedFamily::LocationScale {
3163 likelihood,
3164 base_link,
3165 } if likelihood.is_binomial_sas() => match base_link {
3166 Some(InverseLink::Sas(state)) => *state,
3167 _ => {
3168 return Err(FittedModelError::MissingField {
3169 reason: "binomial-sas location-scale model is missing SAS base_link state"
3170 .to_string(),
3171 });
3172 }
3173 },
3174 _ => return Ok(None),
3175 };
3176 state_from_sasspec(SasLinkSpec {
3177 initial_epsilon: raw.epsilon,
3178 initial_log_delta: raw.log_delta,
3179 })
3180 .map(Some)
3181 .map_err(|e| FittedModelError::PayloadCorrupt {
3182 reason: format!("invalid saved SAS link state: {e}"),
3183 })
3184 }
3185
3186 pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3187 let payload = self.payload();
3188 let raw = match &payload.family_state {
3189 FittedFamily::Standard {
3190 likelihood,
3191 sas_state,
3192 ..
3193 } if likelihood.is_binomial_beta_logistic() => {
3194 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3195 reason:
3196 "binomial-beta-logistic model is missing state in family_state.sas_state"
3197 .to_string(),
3198 })?
3199 }
3200 FittedFamily::LocationScale {
3201 likelihood,
3202 base_link,
3203 } if likelihood.is_binomial_beta_logistic() => match base_link {
3204 Some(InverseLink::BetaLogistic(state)) => *state,
3205 _ => {
3206 return Err(FittedModelError::MissingField {
3207 reason:
3208 "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3209 .to_string(),
3210 });
3211 }
3212 },
3213 _ => return Ok(None),
3214 };
3215 state_from_beta_logisticspec(SasLinkSpec {
3216 initial_epsilon: raw.epsilon,
3217 initial_log_delta: raw.log_delta,
3218 })
3219 .map(Some)
3220 .map_err(|e| FittedModelError::PayloadCorrupt {
3221 reason: format!("invalid saved Beta-Logistic link state: {e}"),
3222 })
3223 }
3224
3225 pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3226 let payload = self.payload();
3227 match &payload.family_state {
3228 FittedFamily::Standard {
3229 likelihood,
3230 mixture_state,
3231 ..
3232 } if likelihood.is_binomial_mixture() => mixture_state
3233 .clone()
3234 .ok_or_else(|| FittedModelError::MissingField {
3235 reason: "binomial-mixture model is missing state in family_state.mixture_state"
3236 .to_string(),
3237 })
3238 .map(Some),
3239 FittedFamily::LocationScale {
3240 likelihood,
3241 base_link,
3242 } if likelihood.is_binomial_mixture() => match base_link {
3243 Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3244 _ => Err(FittedModelError::MissingField {
3245 reason:
3246 "binomial-mixture location-scale model is missing mixture base_link state"
3247 .to_string(),
3248 }),
3249 },
3250 _ => Ok(None),
3251 }
3252 }
3253
3254 pub fn saved_latent_cloglog_state(
3255 &self,
3256 ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3257 let payload = self.payload();
3258 match &payload.family_state {
3259 FittedFamily::Standard {
3260 likelihood,
3261 latent_cloglog_state,
3262 ..
3263 } if likelihood.is_latent_cloglog() => latent_cloglog_state
3264 .ok_or_else(|| FittedModelError::MissingField {
3265 reason:
3266 "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3267 .to_string(),
3268 })
3269 .map(Some),
3270 _ => Ok(None),
3271 }
3272 }
3273
3274 pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3275 let stateful = if let Some(state) = self.saved_mixture_state()? {
3276 Some(InverseLink::Mixture(state))
3277 } else if let Some(state) = self.saved_latent_cloglog_state()? {
3278 Some(InverseLink::LatentCLogLog(state))
3279 } else if let Some(state) = self.saved_beta_logistic_state()? {
3280 Some(InverseLink::BetaLogistic(state))
3281 } else {
3282 self.saved_sas_state()?.map(InverseLink::Sas)
3283 };
3284 match &self.payload().family_state {
3285 FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3286 FittedFamily::Standard { link, .. } => {
3287 Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3288 }
3289 FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3290 FittedFamily::Survival { .. }
3291 | FittedFamily::LatentSurvival { .. }
3292 | FittedFamily::LatentBinary { .. } => Ok(None),
3293 FittedFamily::TransformationNormal { .. } => Ok(None),
3294 }
3295 }
3296
/// Coverage floor handed to `gam_terms::basis::measure_jet_extrapolation_variance`
/// as its final argument by `Self::measure_jet_extrapolation_variance`.
// NOTE(review): how the floor is applied is defined by the callee — confirm there.
const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3306
/// Computes a per-row extrapolation variance summed over the model's frozen
/// measure-jet smooth terms.
///
/// # Parameters
/// - `data`: prediction rows; feature columns are read by the indices stored
///   in the term spec after it has been resolved against `col_map`.
/// - `col_map`: column-name to column-index map forwarded to
///   `resolve_termspec_for_prediction`.
///
/// # Returns
/// - `Ok(None)` when there is no `resolved_termspec`, `data` has no rows, the
///   spec has no `MeasureJet` smooth, or every measure-jet term was skipped.
/// - `Ok(Some(v))` with one entry per row of `data` otherwise; each term adds
///   `fit.coefficient_covariance_scale()` times the value returned by
///   `gam_terms::basis::measure_jet_extrapolation_variance` for that row.
///
/// Terms are skipped with a `log::warn!` (not an error) when they are not
/// frozen, have no fitted amplitude, miss a normalization scale, or have an
/// incomplete set of per-scale amplitudes.
///
/// # Errors
/// - `MissingField` when `fit_result` is absent.
/// - `SchemaMismatch` when spec resolution, design replay, a lambda lookup, a
///   penalty-label parse, a column or input-scale check, the support curve,
///   or the variance evaluation fails.
pub fn measure_jet_extrapolation_variance(
    &self,
    data: ndarray::ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
) -> Result<Option<Array1<f64>>, FittedModelError> {
    use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
    use gam_terms::smooth::build_term_collection_design;
    use gam_terms::smooth::SmoothBasisSpec;
    let Some(saved_spec) = self.resolved_termspec.as_ref() else {
        return Ok(None);
    };
    // Nothing to do for empty data or a spec without any measure-jet smooth.
    if data.nrows() == 0
        || !saved_spec
            .smooth_terms
            .iter()
            .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
    {
        return Ok(None);
    }
    let fit = self
        .fit_result
        .as_ref()
        .ok_or_else(|| FittedModelError::MissingField {
            reason: "measure-jet extrapolation variance requires the canonical \
                     fit_result payload; refit"
                .to_string(),
        })?;
    // Re-resolve the saved spec against the prediction-time column map.
    let spec = crate::survival::predict::resolve_termspec_for_prediction(
        &self.resolved_termspec,
        self.training_headers.as_ref(),
        col_map,
        "resolved_termspec",
    )
    .map_err(|e| FittedModelError::SchemaMismatch {
        reason: format!("measure-jet extrapolation variance: {e}"),
    })?;
    // A single-row probe is enough here: the design is only consulted for its
    // `penaltyinfo` (term names, penalty sources, global lambda indices).
    let probe = data.slice(ndarray::s![0..1, ..]);
    let design = build_term_collection_design(probe, &spec).map_err(|e| {
        FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
            ),
        }
    })?;
    let lambdas = &fit.lambdas;
    let phi_scale = fit.coefficient_covariance_scale();
    let mut total = Array1::<f64>::zeros(data.nrows());
    // Tracks whether at least one term actually added to `total`.
    let mut contributed = false;
    for term in &spec.smooth_terms {
        let SmoothBasisSpec::MeasureJet {
            feature_cols,
            spec: mj,
            input_scales,
        } = &term.basis
        else {
            continue;
        };
        // Only terms with frozen quadrature and user-provided centers are handled.
        let (Some(frozen), CenterStrategy::UserProvided(centers)) =
            (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
        else {
            log::warn!(
                "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
                 quadrature); skipping its extrapolation variance",
                term.name
            );
            continue;
        };
        let n_levels = frozen.eps_band.len();
        // Bounds-checked lookup of a fitted lambda by its global penalty index.
        let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
            lambdas
                .get(global_index)
                .copied()
                .ok_or_else(|| FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': penalty global index {global_index} out \
                         of bounds for {} fitted lambdas",
                        term.name,
                        lambdas.len()
                    ),
                })
        };
        // Per-scale lambdas come from penalties labelled `measure_jet_scale_<level>`;
        // the fused lambda comes from the term's `Primary` penalty.
        let mut per_scale: Vec<(usize, f64)> = Vec::new();
        let mut fused: Option<f64> = None;
        for info in &design.penaltyinfo {
            if info.termname.as_deref() != Some(term.name.as_str()) {
                continue;
            }
            match &info.penalty.source {
                PenaltySource::Other(label) => {
                    if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
                        let level: usize = level_txt.parse().map_err(|_| {
                            FittedModelError::SchemaMismatch {
                                reason: format!(
                                    "measure-jet term '{}': unparseable penalty label \
                                     '{label}'",
                                    term.name
                                ),
                            }
                        })?;
                        per_scale.push((level, read_lambda(info.global_index)?));
                    }
                }
                PenaltySource::Primary => {
                    fused = Some(read_lambda(info.global_index)?);
                }
                _ => {}
            }
        }
        // Declared outside the branch so `PerLevel` can borrow it below.
        let mut lambda_phys = Vec::with_capacity(n_levels);
        let spectrum = if per_scale.is_empty() {
            let Some(lam) = fused else {
                log::warn!(
                    "measure-jet term '{}' has no fitted amplitude in the penalty \
                     layout; skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            let Some(c) = frozen.fused_penalty_normalization_scale else {
                log::warn!(
                    "measure-jet term '{}' is missing the fused penalty normalization scale; \
                     skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            // NOTE(review): `c` is not checked for zero before dividing — confirm upstream guarantees.
            MeasureJetExtrapolationSpectrum::Fused(lam / c)
        } else {
            // Levels must be exactly 0..n_levels, each appearing once.
            per_scale.sort_by_key(|&(level, _)| level);
            let levels_complete = per_scale.len() == n_levels
                && per_scale
                    .iter()
                    .enumerate()
                    .all(|(i, &(level, _))| level == i);
            if !levels_complete {
                log::warn!(
                    "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
                     scales; skipping its extrapolation variance",
                    term.name,
                    per_scale.len(),
                    n_levels
                );
                continue;
            }
            if frozen.penalty_normalization_scales.len() != n_levels {
                log::warn!(
                    "measure-jet term '{}': {} frozen penalty normalization scales for {} \
                     band scales; skipping its extrapolation variance",
                    term.name,
                    frozen.penalty_normalization_scales.len(),
                    n_levels
                );
                continue;
            }
            // Indexing by `level` is in bounds: levels are 0..n_levels and the
            // normalization vector has n_levels entries (both checked above).
            lambda_phys.extend(
                per_scale
                    .iter()
                    .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
            );
            MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
        };
        // Gather this term's feature columns into a dense (rows x axes) query matrix.
        let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
        for (j, &col) in feature_cols.iter().enumerate() {
            if col >= data.ncols() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': prediction column {col} out of bounds \
                         for {} data columns",
                        term.name,
                        data.ncols()
                    ),
                });
            }
            queries.column_mut(j).assign(&data.column(col));
        }
        // When input scales are stored, each query axis is divided by its scale.
        if let Some(scales) = input_scales {
            if scales.len() != feature_cols.len() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': {} input scales for {} axes",
                        term.name,
                        scales.len(),
                        feature_cols.len()
                    ),
                });
            }
            // NOTE(review): a zero scale would yield inf/NaN queries — confirm scales are nonzero.
            for (j, &scale) in scales.iter().enumerate() {
                queries.column_mut(j).mapv_inplace(|v| v / scale);
            }
        }
        let support = gam_terms::basis::measure_jet_support_curve(
            queries.view(),
            centers.view(),
            frozen.masses.view(),
            &frozen.eps_band,
        )
        .map_err(|e| FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet term '{}': support curve failed: {e}",
                term.name
            ),
        })?;
        // One variance per row, scaled by the fit's covariance scale and accumulated.
        for i in 0..data.nrows() {
            let v = gam_terms::basis::measure_jet_extrapolation_variance(
                support.row(i),
                &frozen.eps_band,
                &frozen.support_means,
                spectrum,
                Self::MEASURE_JET_COVERAGE_FLOOR,
            )
            .map_err(|e| FittedModelError::SchemaMismatch {
                reason: format!(
                    "measure-jet term '{}': extrapolation variance failed: {e}",
                    term.name
                ),
            })?;
            total[i] += phi_scale * v;
        }
        contributed = true;
    }
    Ok(contributed.then_some(total))
}
3575
3576 pub fn unified(&self) -> Option<&UnifiedFitResult> {
3578 self.payload().unified.as_ref()
3579 }
3580
3581 pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3582 let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3583 reason: format!("failed to read model '{}': {e}", path.display()),
3584 })?;
3585 let model: Self =
3586 serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3587 reason: format!("failed to parse model json: {e}"),
3588 })?;
3589 let model = model.with_synchronized_stateful_link_metadata();
3590 model.validate_for_persistence()?;
3591 model.validate_numeric_finiteness()?;
3592 Ok(model)
3593 }
3594
3595 pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3596 let normalized = self.clone().with_synchronized_stateful_link_metadata();
3597 normalized.validate_for_persistence()?;
3598 normalized.validate_numeric_finiteness()?;
3599 let parent = path.parent().unwrap_or_else(|| Path::new("."));
3606 let file_name = path
3607 .file_name()
3608 .and_then(|s| s.to_str())
3609 .unwrap_or("model.json");
3610 let pid = std::process::id();
3611 let nanos = std::time::SystemTime::now()
3612 .duration_since(std::time::UNIX_EPOCH)
3613 .map(|d| d.as_nanos())
3614 .unwrap_or(0);
3615 let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3616 let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3617 reason: format!("failed to write model '{}': {e}", tmp.display()),
3618 })?;
3619 let mut writer = std::io::BufWriter::new(file);
3620 let ser_result = serde_json::to_writer(&mut writer, &normalized);
3621 if let Err(e) = ser_result {
3622 std::io::Write::flush(&mut writer).ok();
3625 drop(writer);
3626 fs::remove_file(&tmp).ok();
3627 return Err(FittedModelError::PayloadCorrupt {
3628 reason: format!("failed to serialize model: {e}"),
3629 });
3630 }
3631 std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3632 reason: format!("failed to write model '{}': {e}", tmp.display()),
3633 })?;
3634 let inner = writer
3636 .into_inner()
3637 .map_err(|e| FittedModelError::PayloadCorrupt {
3638 reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3639 })?;
3640 inner.sync_all().ok();
3641 drop(inner);
3642 if let Err(e) = fs::rename(&tmp, path) {
3643 fs::remove_file(&tmp).ok();
3644 return Err(FittedModelError::PayloadCorrupt {
3645 reason: format!("failed to publish model '{}': {e}", path.display()),
3646 });
3647 }
3648 if let Ok(d) = fs::File::open(parent) {
3653 d.sync_all().ok();
3654 }
3655 Ok(())
3656 }
3657
3658 pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3659 self.data_schema
3660 .as_ref()
3661 .ok_or_else(|| FittedModelError::MissingField {
3662 reason: "model is missing data_schema; refit".to_string(),
3663 })
3664 }
3665
3666 pub fn saved_spline_scan(
3670 &self,
3671 ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3672 let Some(saved) = self.spline_scan.as_ref() else {
3673 return Ok(None);
3674 };
3675 let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3676 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3677 Ok(Some((saved.feature_column.as_str(), fit)))
3678 }
3679
3680 pub fn saved_residual_cascade(
3685 &self,
3686 ) -> Result<
3687 Option<(
3688 &[String],
3689 gam_solve::residual_cascade::ResidualCascadeFit,
3690 )>,
3691 FittedModelError,
3692 > {
3693 let Some(saved) = self.residual_cascade.as_ref() else {
3694 return Ok(None);
3695 };
3696 let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3697 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3698 Ok(Some((saved.feature_columns.as_slice(), fit)))
3699 }
3700
3701 pub fn random_effect_group_columns(&self) -> HashSet<String> {
3702 let Some(training_headers) = self.training_headers.as_ref() else {
3703 return HashSet::new();
3704 };
3705 let mut out = HashSet::<String>::new();
3706 for spec in self.saved_term_specs() {
3707 for term in &spec.random_effect_terms {
3708 if let Some(name) = training_headers.get(term.feature_col) {
3709 out.insert(name.clone());
3710 }
3711 }
3712 }
3713 out
3714 }
3715
3716 pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3717 self.validate_payload_version()?;
3731 if let Some(scan) = self.spline_scan.as_ref() {
3732 if self.fit_result.is_some() || self.unified.is_some() {
3737 return Err(FittedModelError::SchemaMismatch {
3738 reason: "spline-scan model must not also carry a dense fit_result/unified \
3739 payload; the representations are mutually exclusive"
3740 .to_string(),
3741 });
3742 }
3743 if self.model_kind != ModelKind::Standard
3744 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3745 {
3746 return Err(FittedModelError::SchemaMismatch {
3747 reason: format!(
3748 "spline-scan representation requires a standard Gaussian-identity model; \
3749 got model_kind={:?}, likelihood={:?}",
3750 self.model_kind,
3751 self.family_state.likelihood()
3752 ),
3753 });
3754 }
3755 if scan.feature_column.is_empty() {
3756 return Err(FittedModelError::MissingField {
3757 reason: "spline-scan model is missing its feature column name; refit"
3758 .to_string(),
3759 });
3760 }
3761 gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3762 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3763 if self.data_schema.is_none() {
3769 return Err(FittedModelError::MissingField {
3770 reason: "spline-scan model is missing data_schema; refit".to_string(),
3771 });
3772 }
3773 if self.training_headers.is_none() {
3774 return Err(FittedModelError::MissingField {
3775 reason: "spline-scan model is missing training_headers; refit".to_string(),
3776 });
3777 }
3778 return Ok(());
3779 } else if let Some(cascade) = self.residual_cascade.as_ref() {
3780 if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3784 return Err(FittedModelError::SchemaMismatch {
3785 reason: "residual-cascade model must not also carry spline_scan / \
3786 fit_result / unified payloads; the representations are \
3787 mutually exclusive"
3788 .to_string(),
3789 });
3790 }
3791 if self.model_kind != ModelKind::Standard
3792 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3793 {
3794 return Err(FittedModelError::SchemaMismatch {
3795 reason: format!(
3796 "residual-cascade representation requires a standard Gaussian-identity \
3797 model; got model_kind={:?}, likelihood={:?}",
3798 self.model_kind,
3799 self.family_state.likelihood()
3800 ),
3801 });
3802 }
3803 if cascade.feature_columns.is_empty()
3804 || !(2..=3).contains(&cascade.feature_columns.len())
3805 {
3806 return Err(FittedModelError::MissingField {
3807 reason: format!(
3808 "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3809 cascade.feature_columns.len()
3810 ),
3811 });
3812 }
3813 gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3814 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3815 if self.data_schema.is_none() {
3816 return Err(FittedModelError::MissingField {
3817 reason: "residual-cascade model is missing data_schema; refit".to_string(),
3818 });
3819 }
3820 if self.training_headers.is_none() {
3821 return Err(FittedModelError::MissingField {
3822 reason: "residual-cascade model is missing training_headers; refit".to_string(),
3823 });
3824 }
3825 return Ok(());
3826 } else if self.fit_result.is_none() {
3827 return Err(FittedModelError::MissingField {
3828 reason: "model is missing canonical fit_result payload; refit".to_string(),
3829 });
3830 }
3831 if self.data_schema.is_none() {
3832 return Err(FittedModelError::MissingField {
3833 reason: "model is missing data_schema; refit".to_string(),
3834 });
3835 }
3836 if self.training_headers.is_none() {
3837 return Err(FittedModelError::MissingField {
3838 reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3839 .to_string(),
3840 });
3841 }
3842 let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3843 FittedModelError::MissingField {
3844 reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3845 .to_string(),
3846 }
3847 })?;
3848 validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3849
3850 if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3851 return Err(FittedModelError::MissingField {
3852 reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3853 .to_string(),
3854 });
3855 }
3856 if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3857 validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3858 }
3859 if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3860 let score = self.transformation_score_calibration.ok_or_else(|| {
3861 FittedModelError::MissingField {
3862 reason: "transformation-normal model is missing transformation_score_calibration; refit"
3863 .to_string(),
3864 }
3865 })?;
3866 score.validate("transformation-normal model")?;
3867 }
3868 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3869 if self.formula_logslope.is_none() {
3870 return Err(FittedModelError::MissingField {
3871 reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3872 });
3873 }
3874 if self.z_column.is_none() {
3875 return Err(FittedModelError::MissingField {
3876 reason: "marginal-slope model is missing z_column; refit".to_string(),
3877 });
3878 }
3879 let z_normalization =
3880 self.latent_z_normalization
3881 .ok_or_else(|| FittedModelError::MissingField {
3882 reason: "marginal-slope model is missing latent_z_normalization; refit"
3883 .to_string(),
3884 })?;
3885 z_normalization.validate("marginal-slope model")?;
3886 let latent_measure =
3887 self.latent_measure
3888 .as_ref()
3889 .ok_or_else(|| FittedModelError::MissingField {
3890 reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3891 })?;
3892 latent_measure
3893 .validate("marginal-slope model latent_measure")
3894 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3895 if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3896 return Err(FittedModelError::MissingField {
3897 reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3898 });
3899 }
3900 if self.resolved_termspec_logslope.as_ref().is_none() {
3901 return Err(FittedModelError::MissingField {
3902 reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3903 .to_string(),
3904 });
3905 }
3906 match self.family_state.frailty() {
3907 Some(FrailtySpec::None)
3908 | Some(FrailtySpec::GaussianShift {
3909 sigma_fixed: Some(_),
3910 }) => {}
3911 Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3912 return Err(FittedModelError::IncompatibleConfig {
3913 reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3914 .to_string(),
3915 });
3916 }
3917 Some(FrailtySpec::HazardMultiplier { .. }) => {
3918 return Err(FittedModelError::IncompatibleConfig {
3919 reason: "marginal-slope model does not support HazardMultiplier frailty"
3920 .to_string(),
3921 });
3922 }
3923 None => {
3924 return Err(FittedModelError::MissingField {
3925 reason: "marginal-slope model is missing family_state.frailty; refit"
3926 .to_string(),
3927 });
3928 }
3929 }
3930 }
3931
3932 if let FittedFamily::Survival {
3933 survival_likelihood,
3934 frailty,
3935 ..
3936 } = &self.family_state
3937 {
3938 if matches!(
3939 survival_likelihood.as_deref(),
3940 Some("latent") | Some("latent-binary")
3941 ) {
3942 return Err(FittedModelError::SchemaMismatch {
3943 reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3944 .to_string(),
3945 });
3946 }
3947 if survival_likelihood.as_deref() == Some("marginal-slope") {
3948 if self.formula_logslope.is_none() {
3949 return Err(FittedModelError::MissingField {
3950 reason: "survival marginal-slope model is missing formula_logslope; refit"
3951 .to_string(),
3952 });
3953 }
3954 if self.z_column.is_none() {
3955 return Err(FittedModelError::MissingField {
3956 reason: "survival marginal-slope model is missing z_column; refit"
3957 .to_string(),
3958 });
3959 }
3960 let z_normalization =
3961 self.latent_z_normalization
3962 .ok_or_else(|| {
3963 FittedModelError::MissingField {
3964 reason:
3965 "survival marginal-slope model is missing latent_z_normalization; refit"
3966 .to_string(),
3967 }
3968 })?;
3969 z_normalization.validate("survival marginal-slope model")?;
3970 let latent_measure =
3971 self.latent_measure
3972 .as_ref()
3973 .ok_or_else(|| FittedModelError::MissingField {
3974 reason:
3975 "survival marginal-slope model is missing latent_measure; refit"
3976 .to_string(),
3977 })?;
3978 latent_measure
3979 .validate("survival marginal-slope model latent_measure")
3980 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3981 if self.logslope_baseline.is_none() {
3982 return Err(FittedModelError::MissingField {
3983 reason: "survival marginal-slope model is missing logslope_baseline; refit"
3984 .to_string(),
3985 });
3986 }
3987 if self.resolved_termspec_logslope.as_ref().is_none() {
3988 return Err(FittedModelError::MissingField {
3989 reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3990 .to_string(),
3991 });
3992 }
3993 match frailty {
3994 FrailtySpec::None
3995 | FrailtySpec::GaussianShift {
3996 sigma_fixed: Some(_),
3997 } => {}
3998 FrailtySpec::GaussianShift { sigma_fixed: None } => {
3999 return Err(FittedModelError::IncompatibleConfig {
4000 reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4001 .to_string(),
4002 });
4003 }
4004 FrailtySpec::HazardMultiplier { .. } => {
4005 return Err(FittedModelError::IncompatibleConfig {
4006 reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4007 .to_string(),
4008 });
4009 }
4010 }
4011 } else if !matches!(frailty, FrailtySpec::None) {
4012 return Err(FittedModelError::IncompatibleConfig {
4013 reason:
4014 "non-marginal survival models do not currently persist a frailty modifier"
4015 .to_string(),
4016 });
4017 }
4018 if self.survival_time_basis.is_none() {
4026 return Err(FittedModelError::MissingField {
4027 reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4028 });
4029 }
4030 if self.survival_time_anchor.is_none() {
4031 return Err(FittedModelError::MissingField {
4032 reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4033 });
4034 }
4035 }
4036 if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4037 match frailty {
4038 FrailtySpec::HazardMultiplier {
4039 sigma_fixed: Some(_),
4040 ..
4041 } => {}
4042 FrailtySpec::HazardMultiplier {
4043 sigma_fixed: None, ..
4044 } => {
4045 return Err(FittedModelError::IncompatibleConfig {
4046 reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4047 .to_string(),
4048 });
4049 }
4050 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4051 return Err(FittedModelError::IncompatibleConfig {
4052 reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4053 .to_string(),
4054 });
4055 }
4056 }
4057 if self.survival_likelihood.as_deref() != Some("latent") {
4058 return Err(FittedModelError::SchemaMismatch {
4059 reason: "latent survival model must persist survival_likelihood=latent"
4060 .to_string(),
4061 });
4062 }
4063 }
4064 if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4065 match frailty {
4066 FrailtySpec::HazardMultiplier {
4067 sigma_fixed: Some(_),
4068 ..
4069 } => {}
4070 FrailtySpec::HazardMultiplier {
4071 sigma_fixed: None, ..
4072 } => {
4073 return Err(FittedModelError::IncompatibleConfig {
4074 reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4075 .to_string(),
4076 });
4077 }
4078 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4079 return Err(FittedModelError::IncompatibleConfig {
4080 reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4081 .to_string(),
4082 });
4083 }
4084 }
4085 if self.survival_likelihood.as_deref() != Some("latent-binary") {
4086 return Err(FittedModelError::SchemaMismatch {
4087 reason: "latent binary model must persist survival_likelihood=latent-binary"
4088 .to_string(),
4089 });
4090 }
4091 }
4092
4093 let family_likelihood = match &self.family_state {
4094 FittedFamily::Standard { likelihood, .. }
4095 | FittedFamily::LocationScale { likelihood, .. }
4096 | FittedFamily::MarginalSlope { likelihood, .. }
4097 | FittedFamily::Survival { likelihood, .. }
4098 | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4099 FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4100 };
4101 let is_standard_or_location_scale = matches!(
4102 self.family_state,
4103 FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4104 );
4105 if is_standard_or_location_scale
4106 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4107 {
4108 self.saved_sas_state()?;
4109 }
4110 if is_standard_or_location_scale
4111 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4112 {
4113 self.saved_beta_logistic_state()?;
4114 }
4115 if is_standard_or_location_scale
4116 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4117 {
4118 self.saved_mixture_state()?;
4119 }
4120 if matches!(self.family_state, FittedFamily::Standard { .. })
4121 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4122 {
4123 self.saved_latent_cloglog_state()?;
4124 }
4125 if matches!(self.family_state, FittedFamily::LocationScale { .. })
4126 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4127 {
4128 return Err(FittedModelError::IncompatibleConfig {
4129 reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4130 .to_string(),
4131 });
4132 }
4133 if matches!(self.family_state, FittedFamily::Survival { .. })
4134 && self.survival_likelihood.is_none()
4135 {
4136 return Err(FittedModelError::MissingField {
4137 reason: "saved survival model is missing survival_likelihood metadata; refit"
4138 .to_string(),
4139 });
4140 }
4141 let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4142 || self.linkwiggle_degree.is_some()
4143 || self.beta_link_wiggle.is_some()
4144 || self
4145 .fit_result
4146 .as_ref()
4147 .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4148 .is_some();
4149 let saved_link_wiggle = self.saved_link_wiggle()?;
4150 if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4151 return Err(FittedModelError::SchemaMismatch {
4152 reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4153 .to_string(),
4154 });
4155 }
4156 let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4157 || self.baseline_timewiggle_degree.is_some()
4158 || self.baseline_timewiggle_penalty_orders.is_some()
4159 || self.baseline_timewiggle_double_penalty.is_some()
4160 || self.beta_baseline_timewiggle.is_some()
4161 || self.beta_baseline_timewiggle_by_cause.is_some();
4162 let is_joint_cause_specific = self
4163 .survival_cause_count
4164 .is_some_and(|cause_count| cause_count > 1);
4165 if has_any_saved_baseline_time_wiggle {
4166 if is_joint_cause_specific {
4167 let complete = self.baseline_timewiggle_knots.is_some()
4168 && self.baseline_timewiggle_degree.is_some()
4169 && self.baseline_timewiggle_penalty_orders.is_some()
4170 && self.baseline_timewiggle_double_penalty.is_some()
4171 && self.beta_baseline_timewiggle_by_cause.is_some();
4172 if !complete {
4173 return Err(FittedModelError::SchemaMismatch {
4174 reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4175 .to_string(),
4176 });
4177 }
4178 } else if self.saved_baseline_time_wiggle()?.is_none() {
4179 return Err(FittedModelError::SchemaMismatch {
4180 reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4181 .to_string(),
4182 });
4183 }
4184 }
4185 if self
4186 .survival_likelihood
4187 .as_deref()
4188 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4189 {
4190 validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4191 }
4192
4193 if let Some(runtime) = self.score_warp_runtime.as_ref() {
4203 runtime.validate_exact_replay_contract().map_err(|err| {
4204 FittedModelError::PayloadCorrupt {
4205 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4206 }
4207 })?;
4208 }
4209 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4210 runtime.validate_exact_replay_contract().map_err(|err| {
4211 FittedModelError::PayloadCorrupt {
4212 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4213 }
4214 })?;
4215 }
4216 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4217 validate_marginal_slope_saved_fit(
4218 self.fit_result.as_ref().expect("checked above"),
4219 self.score_warp_runtime.as_ref(),
4220 self.link_deviation_runtime.as_ref(),
4221 "fit_result",
4222 )?;
4223 let unified = self
4224 .unified
4225 .as_ref()
4226 .ok_or_else(|| FittedModelError::MissingField {
4227 reason: "marginal-slope model is missing unified fit payload; refit"
4228 .to_string(),
4229 })?;
4230 validate_marginal_slope_saved_fit(
4231 unified,
4232 self.score_warp_runtime.as_ref(),
4233 self.link_deviation_runtime.as_ref(),
4234 "unified",
4235 )?;
4236 }
4237 if self
4238 .survival_likelihood
4239 .as_deref()
4240 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4241 {
4242 validate_survival_marginal_slope_saved_fit(
4243 self.fit_result.as_ref().expect("checked above"),
4244 self.score_warp_runtime.as_ref(),
4245 self.link_deviation_runtime.as_ref(),
4246 "fit_result",
4247 )?;
4248 if let Some(unified) = self.unified.as_ref() {
4249 validate_survival_marginal_slope_saved_fit(
4250 unified,
4251 self.score_warp_runtime.as_ref(),
4252 self.link_deviation_runtime.as_ref(),
4253 "unified",
4254 )?;
4255 }
4256 }
4257
4258 Ok(())
4269 }
4270
4271 pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4272 let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4273 if let Some(fit) = self.fit_result.as_ref() {
4274 fit.validate_numeric_finiteness()
4275 .map_err(|e| corrupt(e.to_string()))?;
4276 }
4277
4278 for (name, opt) in [
4279 ("survival_baseline_scale", self.survival_baseline_scale),
4280 ("survival_baseline_shape", self.survival_baseline_shape),
4281 ("survival_baseline_rate", self.survival_baseline_rate),
4282 ("survival_baseline_makeham", self.survival_baseline_makeham),
4283 (
4284 "survival_time_smooth_lambda",
4285 self.survival_time_smooth_lambda,
4286 ),
4287 ("survival_time_anchor", self.survival_time_anchor),
4288 ("survivalridge_lambda", self.survivalridge_lambda),
4289 ] {
4290 if let Some(v) = opt {
4291 ensure_finite_scalar(name, v).map_err(corrupt)?;
4292 }
4293 }
4294
4295 if let Some(v) = self.beta_noise.as_ref() {
4296 validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4297 }
4298 if let Some(v) = self.noise_projection.as_ref() {
4299 validate_all_finite("noise_projection", v.iter().flatten().copied())
4300 .map_err(corrupt)?;
4301 if self.noise_projection_ridge_alpha.is_none() {
4302 return Err(FittedModelError::MissingField {
4303 reason:
4304 "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4305 .to_string(),
4306 });
4307 }
4308 }
4309 if let Some(v) = self.noise_center.as_ref() {
4310 validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4311 }
4312 if let Some(v) = self.noise_scale.as_ref() {
4313 validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4314 }
4315 if let Some(v) = self.noise_projection_ridge_alpha {
4316 ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4317 if v < 0.0 {
4318 return Err(FittedModelError::InvalidInput {
4319 reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4320 });
4321 }
4322 }
4323 if let Some(v) = self.gaussian_response_scale {
4324 ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4325 }
4326 if let Some(v) = self.beta_link_wiggle.as_ref() {
4327 validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4328 }
4329 if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4330 validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4331 }
4332 if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4333 validate_all_finite(
4334 "beta_baseline_timewiggle_by_cause",
4335 v.iter().flatten().copied(),
4336 )
4337 .map_err(corrupt)?;
4338 }
4339 if let Some(v) = self.latent_z_normalization {
4340 v.validate("latent_z_normalization")?;
4341 }
4342 if let Some(v) = self.latent_measure.as_ref() {
4343 v.validate("latent_measure").map_err(corrupt)?;
4344 }
4345 if let Some(v) = self.survival_beta_time.as_ref() {
4346 validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4347 }
4348 if let Some(v) = self.survival_beta_threshold.as_ref() {
4349 validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4350 }
4351 if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4352 validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4353 }
4354 if let Some(v) = self.survival_noise_projection.as_ref() {
4355 validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4356 .map_err(corrupt)?;
4357 if self.survival_noise_projection_ridge_alpha.is_none() {
4358 return Err(FittedModelError::MissingField {
4359 reason:
4360 "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4361 .to_string(),
4362 });
4363 }
4364 }
4365 if let Some(v) = self.survival_noise_center.as_ref() {
4366 validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4367 }
4368 if let Some(v) = self.survival_noise_projection_ridge_alpha {
4369 ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4370 if v < 0.0 {
4371 return Err(FittedModelError::InvalidInput {
4372 reason: format!(
4373 "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4374 ),
4375 });
4376 }
4377 }
4378 if let Some(v) = self.survival_noise_scale.as_ref() {
4379 validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4380 }
4381 if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4382 validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4383 .map_err(corrupt)?;
4384 }
4385 if let Some(v) = self.sas_param_covariance.as_ref() {
4386 validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4387 .map_err(corrupt)?;
4388 }
4389 Ok(())
4390 }
4391}
4392
4393fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4394 a.rows().into_iter().map(|row| row.to_vec()).collect()
4395}
4396
4397use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4398
4399fn validate_frozen_term_collectionspec(
4400 spec: &TermCollectionSpec,
4401 label: &str,
4402) -> Result<(), FittedModelError> {
4403 spec.validate_frozen(label)
4404 .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4405}
4406
4407impl Deref for FittedModel {
4408 type Target = FittedModelPayload;
4409
4410 fn deref(&self) -> &Self::Target {
4411 self.payload()
4412 }
4413}
4414
4415impl DerefMut for FittedModel {
4416 fn deref_mut(&mut self) -> &mut Self::Target {
4417 self.payload_mut()
4418 }
4419}
4420
4421pub fn survival_baseline_config_from_model(
4426 model: &FittedModel,
4427) -> Result<SurvivalBaselineConfig, FittedModelError> {
4428 let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4429 FittedModelError::MissingField {
4430 reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4431 }
4432 })?;
4433 parse_survival_baseline_config(
4434 target,
4435 model.survival_baseline_scale,
4436 model.survival_baseline_shape,
4437 model.survival_baseline_rate,
4438 model.survival_baseline_makeham,
4439 )
4440 .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4441}
4442
4443pub fn load_survival_time_basis_config_from_model(
4444 model: &FittedModel,
4445) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4446 match model
4447 .survival_time_basis
4448 .as_deref()
4449 .ok_or_else(|| FittedModelError::MissingField {
4450 reason: "saved survival model missing survival_time_basis".to_string(),
4451 })?
4452 .to_ascii_lowercase()
4453 .as_str()
4454 {
4455 "none" => Ok(SurvivalTimeBasisConfig::None),
4456 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4457 "bspline" => {
4458 let degree =
4459 model
4460 .survival_time_degree
4461 .ok_or_else(|| FittedModelError::MissingField {
4462 reason: "saved survival bspline model missing survival_time_degree"
4463 .to_string(),
4464 })?;
4465 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4466 FittedModelError::MissingField {
4467 reason: "saved survival bspline model missing survival_time_knots".to_string(),
4468 }
4469 })?;
4470 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4471 if degree < 1 || knots.is_empty() {
4472 return Err(FittedModelError::SchemaMismatch {
4473 reason: "saved survival bspline time basis metadata is invalid".to_string(),
4474 });
4475 }
4476 Ok(SurvivalTimeBasisConfig::BSpline {
4477 degree,
4478 knots: Array1::from_vec(knots),
4479 smooth_lambda,
4480 })
4481 }
4482 "ispline" => {
4483 let degree =
4484 model
4485 .survival_time_degree
4486 .ok_or_else(|| FittedModelError::MissingField {
4487 reason: "saved survival ispline model missing survival_time_degree"
4488 .to_string(),
4489 })?;
4490 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4491 FittedModelError::MissingField {
4492 reason: "saved survival ispline model missing survival_time_knots".to_string(),
4493 }
4494 })?;
4495 let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4496 FittedModelError::MissingField {
4497 reason: "saved survival ispline model missing survival_time_keep_cols"
4498 .to_string(),
4499 }
4500 })?;
4501 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4502 if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4503 return Err(FittedModelError::SchemaMismatch {
4504 reason: "saved survival ispline time basis metadata is invalid".to_string(),
4505 });
4506 }
4507 Ok(SurvivalTimeBasisConfig::ISpline {
4508 degree,
4509 knots: Array1::from_vec(knots),
4510 keep_cols,
4511 smooth_lambda,
4512 })
4513 }
4514 other => Err(FittedModelError::IncompatibleConfig {
4515 reason: format!("unsupported saved survival_time_basis '{other}'"),
4516 }),
4517 }
4518}
4519
4520#[cfg(test)]
4521mod tests {
4522 use super::*;
4523 use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4524 use crate::survival::lognormal_kernel::FrailtySpec;
4525 use gam_solve::pirls::PirlsStatus;
4526 use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4527 use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4528 use gam_data::SchemaColumn;
4529 use ndarray::{Array1, Array2, array};
4530
4531 fn empty_termspec() -> TermCollectionSpec {
4532 TermCollectionSpec {
4533 linear_terms: vec![],
4534 random_effect_terms: vec![],
4535 smooth_terms: vec![],
4536 }
4537 }
4538
4539 #[test]
4543 fn spline_scan_payload_round_trips_and_validates() {
4544 let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4545 let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4546 let w = vec![1.0_f64; x.len()];
4547 let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4548 let make_payload = || {
4549 crate::inference::model_payload_builders::assemble_spline_scan_payload(
4550 "y ~ s(x)".to_string(),
4551 "x".to_string(),
4552 &fit,
4553 DataSchema {
4554 columns: vec![
4555 SchemaColumn {
4556 name: "y".to_string(),
4557 kind: ColumnKindTag::Continuous,
4558 levels: vec![],
4559 },
4560 SchemaColumn {
4561 name: "x".to_string(),
4562 kind: ColumnKindTag::Continuous,
4563 levels: vec![],
4564 },
4565 ],
4566 },
4567 vec!["x".to_string()],
4568 vec![(0.0, 1.0)],
4569 )
4570 };
4571 let model = FittedModel::from_payload(make_payload());
4574 model
4575 .validate_for_persistence()
4576 .expect("scan model validates");
4577 model
4578 .validate_numeric_finiteness()
4579 .expect("scan model is finite");
4580
4581 let json = serde_json::to_string(&model).expect("serialize model");
4582 let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4583 restored
4584 .validate_for_persistence()
4585 .expect("restored scan model validates");
4586 let (column, replay) = restored
4587 .saved_spline_scan()
4588 .expect("restore scan fit")
4589 .expect("payload carries the scan representation");
4590 assert_eq!(column, "x");
4591 for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4592 let (m0, v0) = fit.predict(xq).expect("predict original");
4593 let (m1, v1) = replay.predict(xq).expect("predict replayed");
4594 assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4595 assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4596 }
4597
4598 let mut dense = make_payload();
4600 dense.spline_scan = None;
4601 let err = FittedModel::from_payload(dense)
4602 .validate_for_persistence()
4603 .expect_err("dense payload without fit_result must be rejected");
4604 assert!(err.to_string().contains("fit_result"));
4605
4606 let mut corrupt = make_payload();
4608 corrupt
4609 .spline_scan
4610 .as_mut()
4611 .expect("scan channel present")
4612 .state
4613 .knots
4614 .truncate(2);
4615 FittedModel::from_payload(corrupt)
4616 .validate_for_persistence()
4617 .expect_err("corrupt scan state must be rejected");
4618 let mut unnamed = make_payload();
4619 unnamed
4620 .spline_scan
4621 .as_mut()
4622 .expect("scan channel present")
4623 .feature_column
4624 .clear();
4625 FittedModel::from_payload(unnamed)
4626 .validate_for_persistence()
4627 .expect_err("missing feature column must be rejected");
4628 }
4629
4630 fn standard_gaussian_payload() -> FittedModelPayload {
4631 FittedModelPayload::new(
4632 MODEL_PAYLOAD_VERSION,
4633 "y ~ 1".to_string(),
4634 ModelKind::Standard,
4635 FittedFamily::Standard {
4636 likelihood: LikelihoodSpec::gaussian_identity(),
4637 link: Some(StandardLink::Identity),
4638 latent_cloglog_state: None,
4639 mixture_state: None,
4640 sas_state: None,
4641 },
4642 "gaussian".to_string(),
4643 )
4644 }
4645
4646 fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4647 SavedCompiledFlexBlock {
4648 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4649 breakpoints: vec![-1.0, 1.0],
4650 basis_dim,
4651 span_c0: vec![vec![0.0; basis_dim]],
4652 span_c1: vec![vec![0.0; basis_dim]],
4653 span_c2: vec![vec![0.0; basis_dim]],
4654 span_c3: vec![vec![0.0; basis_dim]],
4655 anchor_correction: None,
4656 anchor_components: Vec::new(),
4657 }
4658 }
4659
4660 fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4661 let beta = Array1::from_vec(
4662 blocks
4663 .iter()
4664 .flat_map(|block| block.beta.iter().copied())
4665 .collect(),
4666 );
4667 let p = beta.len();
4668 UnifiedFitResult {
4669 blocks,
4670 log_lambdas: Array1::zeros(0),
4671 lambdas: Array1::zeros(0),
4672 likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4673 likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4674 log_likelihood_normalization: LogLikelihoodNormalization::Full,
4675 log_likelihood: 0.0,
4676 deviance: 0.0,
4677 reml_score: 0.0,
4678 stable_penalty_term: 0.0,
4679 penalized_objective: 0.0,
4680 used_device: false,
4681 outer_iterations: 0,
4682 outer_cost_evals: 0,
4683 inner_pirls_solves: 0,
4684 outer_converged: true,
4685 outer_gradient_norm: None,
4686 standard_deviation: 1.0,
4687 covariance_conditional: Some(Array2::zeros((p, p))),
4688 covariance_corrected: Some(Array2::zeros((p, p))),
4689 inference: None,
4690 fitted_link: FittedLinkState::Standard(None),
4691 geometry: None,
4692 block_states: vec![],
4693 beta,
4694 pirls_status: PirlsStatus::Converged,
4695 max_abs_eta: 0.0,
4696 constraint_kkt: None,
4697 artifacts: FitArtifacts {
4698 pirls: None,
4699 null_space_logdet: None,
4700 null_space_dim: None,
4701 survival_link_wiggle_knots: None,
4702 survival_link_wiggle_degree: None,
4703 criterion_certificate: None,
4704 rho_posterior_certificate: None,
4705 rho_posterior_escalation: None,
4706 rho_covariance: None,
4707 },
4708 inner_cycles: 0,
4709 }
4710 }
4711
4712 fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4713 let mut payload = FittedModelPayload::new(
4714 version,
4715 "y ~ 1".to_string(),
4716 ModelKind::MarginalSlope,
4717 FittedFamily::MarginalSlope {
4718 likelihood: LikelihoodSpec::binomial_probit(),
4719 base_link: InverseLink::Standard(StandardLink::Probit),
4720 frailty: FrailtySpec::None,
4721 },
4722 "bernoulli-marginal-slope".to_string(),
4723 );
4724 payload.fit_result = Some(fit.clone());
4725 payload.unified = Some(fit);
4726 payload.data_schema = Some(DataSchema {
4727 columns: vec![SchemaColumn {
4728 name: "z".to_string(),
4729 kind: ColumnKindTag::Continuous,
4730 levels: vec![],
4731 }],
4732 });
4733 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4734 payload.resolved_termspec = Some(empty_termspec());
4735 payload.resolved_termspec_logslope = Some(empty_termspec());
4736 payload.formula_logslope = Some("1".to_string());
4737 payload.z_column = Some("z".to_string());
4738 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4739 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4740 payload.marginal_baseline = Some(0.0);
4741 payload.logslope_baseline = Some(0.0);
4742 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4743 payload
4744 }
4745
4746 #[test]
4747 fn from_payload_synchronizes_used_device_from_saved_fit() {
4748 let mut fit = saved_fit(vec![
4749 FittedBlock {
4750 beta: Array1::from_vec(vec![0.25]),
4751 role: BlockRole::Mean,
4752 edf: 1.0,
4753 lambdas: Array1::zeros(0),
4754 },
4755 FittedBlock {
4756 beta: Array1::from_vec(vec![0.5]),
4757 role: BlockRole::Scale,
4758 edf: 1.0,
4759 lambdas: Array1::zeros(0),
4760 },
4761 ]);
4762 fit.used_device = true;
4763 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4764 payload.used_device = false;
4765
4766 let model = FittedModel::from_payload(payload);
4767
4768 assert!(model.payload().used_device);
4769 }
4770
4771 fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4772 let mut payload = FittedModelPayload::new(
4773 version,
4774 "Surv(entry, exit, event) ~ 1".to_string(),
4775 ModelKind::Survival,
4776 FittedFamily::Survival {
4777 likelihood: LikelihoodSpec::royston_parmar(),
4778 survival_likelihood: Some("marginal-slope".to_string()),
4779 survival_distribution: Some(ResidualDistribution::Gaussian),
4780 frailty: FrailtySpec::None,
4781 },
4782 "survival".to_string(),
4783 );
4784 payload.fit_result = Some(fit.clone());
4785 payload.unified = Some(fit);
4786 payload.survival_likelihood = Some("marginal-slope".to_string());
4787 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4788 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4789 payload.data_schema = Some(DataSchema {
4790 columns: vec![SchemaColumn {
4791 name: "z".to_string(),
4792 kind: ColumnKindTag::Continuous,
4793 levels: vec![],
4794 }],
4795 });
4796 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4797 payload.resolved_termspec = Some(empty_termspec());
4798 payload.resolved_termspec_logslope = Some(empty_termspec());
4799 payload.formula_logslope = Some("1".to_string());
4800 payload.z_column = Some("z".to_string());
4801 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4802 payload.logslope_baseline = Some(0.0);
4803 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4804 payload
4805 }
4806
4807 fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4808 let mut payload = FittedModelPayload::new(
4814 MODEL_PAYLOAD_VERSION,
4815 "y ~ x".to_string(),
4816 ModelKind::LocationScale,
4817 FittedFamily::LocationScale {
4818 likelihood: LikelihoodSpec::gamma_log(),
4819 base_link: Some(InverseLink::Standard(StandardLink::Log)),
4820 },
4821 "gamma-location-scale".to_string(),
4822 );
4823 payload.data_schema = Some(DataSchema {
4824 columns: vec![
4825 SchemaColumn {
4826 name: "y".to_string(),
4827 kind: ColumnKindTag::Continuous,
4828 levels: vec![],
4829 },
4830 SchemaColumn {
4831 name: "x".to_string(),
4832 kind: ColumnKindTag::Continuous,
4833 levels: vec![],
4834 },
4835 ],
4836 });
4837 payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4838 payload.resolved_termspec = Some(empty_termspec());
4839 payload.resolved_termspec_noise = Some(empty_termspec());
4840 payload.formula_noise = Some("x".to_string());
4841 payload.beta_noise = Some(vec![0.0]);
4842 payload.link = Some(InverseLink::Standard(StandardLink::Log));
4843 payload
4844 }
4845
4846 #[test]
4853 fn dispersion_location_scale_payload_is_not_classified_binomial() {
4854 let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4855 assert_eq!(
4856 model.predict_model_class(),
4857 PredictModelClass::DispersionLocationScale,
4858 "Gamma dispersion location-scale must route through the dispersion \
4859 predictor, not the binomial threshold-scale class",
4860 );
4861 assert!(
4862 !matches!(
4863 model.predict_model_class(),
4864 PredictModelClass::BinomialLocationScale
4865 ),
4866 "dispersion location-scale must never be classified as binomial",
4867 );
4868
4869 for likelihood in [
4871 LikelihoodSpec::gamma_log(),
4872 LikelihoodSpec::new(
4873 ResponseFamily::NegativeBinomial {
4874 theta: 1.0,
4875 theta_fixed: false,
4876 },
4877 InverseLink::Standard(StandardLink::Log),
4878 ),
4879 LikelihoodSpec::new(
4880 ResponseFamily::Beta { phi: 1.0 },
4881 InverseLink::Standard(StandardLink::Logit),
4882 ),
4883 LikelihoodSpec::new(
4884 ResponseFamily::Tweedie { p: 1.5 },
4885 InverseLink::Standard(StandardLink::Log),
4886 ),
4887 ] {
4888 let mut payload = gamma_dispersion_location_scale_payload();
4889 payload.family_state = FittedFamily::LocationScale {
4890 base_link: Some(likelihood.link.clone()),
4891 likelihood: likelihood.clone(),
4892 };
4893 let model = FittedModel::from_payload(payload);
4894 assert_eq!(
4895 model.predict_model_class(),
4896 PredictModelClass::DispersionLocationScale,
4897 "dispersion family {:?} mis-classified",
4898 likelihood.response,
4899 );
4900 }
4901 }
4902
4903 #[test]
4904 fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4905 let data = array![[100.0], [-100.0]];
4906 let col_map = HashMap::from([("g".to_string(), 0usize)]);
4907
4908 let mut plain_payload = standard_gaussian_payload();
4909 plain_payload.data_schema = Some(DataSchema {
4910 columns: vec![SchemaColumn {
4911 name: "g".to_string(),
4912 kind: ColumnKindTag::Continuous,
4913 levels: vec![],
4914 }],
4915 });
4916 plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4917 plain_payload.resolved_termspec = Some(empty_termspec());
4918 let plain = FittedModel::from_payload(plain_payload.clone());
4919 let clipped = plain
4920 .axis_clip_to_training_ranges(data.view(), &col_map)
4921 .expect("ordinary continuous axis should clip outside the training range");
4922 assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4923
4924 let mut group_payload = plain_payload;
4925 let mut group_spec = empty_termspec();
4926 group_spec
4927 .random_effect_terms
4928 .push(gam_terms::smooth::RandomEffectTermSpec {
4929 name: "g".to_string(),
4930 feature_col: 0,
4931 drop_first_level: false,
4932 penalized: true,
4933 frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4934 });
4935 group_payload.resolved_termspec = Some(group_spec);
4936 let group_model = FittedModel::from_payload(group_payload);
4937
4938 assert_eq!(
4939 group_model.random_effect_group_columns(),
4940 HashSet::from(["g".to_string()])
4941 );
4942
4943 assert_eq!(
4944 group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4945 None,
4946 "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4947 );
4948 }
4949
4950 #[test]
4951 fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4952 let fit = saved_fit(vec![
4953 FittedBlock {
4954 beta: array![0.1],
4955 role: BlockRole::Mean,
4956 edf: 1.0,
4957 lambdas: Array1::zeros(0),
4958 },
4959 FittedBlock {
4960 beta: array![0.2],
4961 role: BlockRole::Scale,
4962 edf: 1.0,
4963 lambdas: Array1::zeros(0),
4964 },
4965 FittedBlock {
4966 beta: array![0.3],
4967 role: BlockRole::Mean,
4968 edf: 1.0,
4969 lambdas: Array1::zeros(0),
4970 },
4971 ]);
4972 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4973 payload.score_warp_runtime = Some(anchored_runtime(2));
4974
4975 let err = FittedModel::from_payload(payload)
4976 .validate_for_persistence()
4977 .expect_err("marginal-slope score-warp basis mismatch should fail validation");
4978 assert!(err.to_string().contains("score-warp coefficient mismatch"));
4979 }
4980
4981 #[test]
4982 fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
4983 let fit = saved_fit(vec![
4984 FittedBlock {
4985 beta: array![0.1],
4986 role: BlockRole::Time,
4987 edf: 1.0,
4988 lambdas: Array1::zeros(0),
4989 },
4990 FittedBlock {
4991 beta: array![0.2],
4992 role: BlockRole::Mean,
4993 edf: 1.0,
4994 lambdas: Array1::zeros(0),
4995 },
4996 FittedBlock {
4997 beta: array![0.3],
4998 role: BlockRole::Scale,
4999 edf: 1.0,
5000 lambdas: Array1::zeros(0),
5001 },
5002 FittedBlock {
5003 beta: array![0.4],
5004 role: BlockRole::LinkWiggle,
5005 edf: 1.0,
5006 lambdas: Array1::zeros(0),
5007 },
5008 ]);
5009 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5010 payload.link_deviation_runtime = Some(anchored_runtime(2));
5011
5012 let err = FittedModel::from_payload(payload)
5013 .saved_prediction_runtime()
5014 .expect_err(
5015 "survival marginal-slope link basis mismatch should fail runtime validation",
5016 );
5017 assert!(
5018 err.to_string()
5019 .contains("link-deviation coefficient mismatch")
5020 );
5021 }
5022
5023 #[test]
5024 fn apply_survival_time_basis_writes_all_required_fields() {
5025 use crate::survival::construction::SavedSurvivalTimeBasis;
5026
5027 let fit = saved_fit(vec![
5028 FittedBlock {
5029 beta: array![0.1],
5030 role: BlockRole::Time,
5031 edf: 1.0,
5032 lambdas: Array1::zeros(0),
5033 },
5034 FittedBlock {
5035 beta: array![0.2],
5036 role: BlockRole::Mean,
5037 edf: 1.0,
5038 lambdas: Array1::zeros(0),
5039 },
5040 FittedBlock {
5041 beta: array![0.3],
5042 role: BlockRole::Scale,
5043 edf: 1.0,
5044 lambdas: Array1::zeros(0),
5045 },
5046 ]);
5047 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5048
5049 let snapshot = SavedSurvivalTimeBasis {
5054 basisname: "royston-parmar".to_string(),
5055 degree: Some(3),
5056 knots: Some(vec![0.0, 1.0, 2.0]),
5057 keep_cols: Some(vec![0, 2]),
5058 smooth_lambda: Some(0.5),
5059 anchor: 0.25,
5060 };
5061 payload.apply_survival_time_basis(&snapshot);
5062
5063 assert_eq!(
5064 payload.survival_time_basis.as_deref(),
5065 Some("royston-parmar")
5066 );
5067 assert_eq!(payload.survival_time_degree, Some(3));
5068 assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5069 assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5070 assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5071 assert_eq!(payload.survival_time_anchor, Some(0.25));
5072 }
5073
5074 #[test]
5075 fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5076 let fit = saved_fit(vec![
5077 FittedBlock {
5078 beta: array![0.1],
5079 role: BlockRole::Time,
5080 edf: 1.0,
5081 lambdas: Array1::zeros(0),
5082 },
5083 FittedBlock {
5084 beta: array![0.2],
5085 role: BlockRole::Mean,
5086 edf: 1.0,
5087 lambdas: Array1::zeros(0),
5088 },
5089 FittedBlock {
5090 beta: array![0.3],
5091 role: BlockRole::Scale,
5092 edf: 1.0,
5093 lambdas: Array1::zeros(0),
5094 },
5095 ]);
5096 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5097 payload.survival_time_basis = Some("ispline".to_string());
5103
5104 let err = FittedModel::from_payload(payload)
5105 .validate_for_persistence()
5106 .expect_err("survival model without time-anchor metadata should fail validation");
5107 assert!(err.to_string().contains("missing survival_time_anchor"));
5108 }
5109
5110 #[test]
5111 fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5112 let fit = saved_fit(vec![
5113 FittedBlock {
5114 beta: array![0.1],
5115 role: BlockRole::Time,
5116 edf: 1.0,
5117 lambdas: Array1::zeros(0),
5118 },
5119 FittedBlock {
5120 beta: array![0.2],
5121 role: BlockRole::Mean,
5122 edf: 1.0,
5123 lambdas: Array1::zeros(0),
5124 },
5125 FittedBlock {
5126 beta: array![0.3],
5127 role: BlockRole::Scale,
5128 edf: 1.0,
5129 lambdas: Array1::zeros(0),
5130 },
5131 ]);
5132 let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5133
5134 let err = FittedModel::from_payload(payload)
5135 .validate_for_persistence()
5136 .expect_err("survival model without time-basis metadata should fail validation");
5137 assert!(err.to_string().contains("missing survival_time_basis"));
5138 }
5139
5140 #[test]
5141 fn saved_prediction_runtime_rejects_stale_payload_version() {
5142 let fit = saved_fit(vec![
5143 FittedBlock {
5144 beta: array![0.1],
5145 role: BlockRole::Mean,
5146 edf: 1.0,
5147 lambdas: Array1::zeros(0),
5148 },
5149 FittedBlock {
5150 beta: array![0.2],
5151 role: BlockRole::Scale,
5152 edf: 1.0,
5153 lambdas: Array1::zeros(0),
5154 },
5155 ]);
5156 let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5157
5158 let err = FittedModel::from_payload(payload)
5159 .saved_prediction_runtime()
5160 .expect_err("stale payload version should fail before runtime assembly");
5161 assert!(err.to_string().contains("payload schema mismatch"));
5162 }
5163}