1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4 LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7 SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12 monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15 inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16 parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21 InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22 SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Schema version written into every saved model payload.
///
/// `FittedModelPayload::validate_payload_version` rejects any payload whose
/// `version` field differs from this value, so payloads written under a
/// different version fail that check.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;

/// JSON-valued metadata stored with a saved model under `String` keys
/// (kept in key order by `BTreeMap`); see `FittedModelPayload::group_metadata`.
pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
/// Spline-scan solver state persisted with a model, together with the name of
/// the feature column it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedSplineScan {
    /// Name of the input column associated with `state`.
    pub feature_column: String,
    /// Serialized state from `gam_solve::spline_scan`.
    pub state: gam_solve::spline_scan::SplineScanState,
}
76
/// Residual-cascade solver state persisted with a model, together with the
/// names of the feature columns it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedResidualCascade {
    /// Names of the input columns associated with `state`.
    pub feature_columns: Vec<String>,
    /// Serialized state from `gam_solve::residual_cascade`.
    pub state: gam_solve::residual_cascade::ResidualCascadeState,
}
87
/// Errors raised while validating or consuming a saved model payload.
///
/// Every variant carries a human-readable `reason` string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
    /// The payload does not line up with what this reader or the supplied data
    /// expects (e.g. payload version, design row counts, append-only coefficient
    /// indices, column bounds, covariance shapes, coefficients that disagree
    /// with the stored fit).
    SchemaMismatch { reason: String },
    /// A stored value (or a value fed through it) is unusable, e.g. non-finite
    /// latent-z normalization constants or non-finite z values.
    PayloadCorrupt { reason: String },
    /// A field, fitted block, or referenced term required for the operation is
    /// absent.
    MissingField { reason: String },
    /// The saved configuration asks for something this code does not support,
    /// e.g. an unknown deployment-extension kind or an out-of-range PIT clip.
    IncompatibleConfig { reason: String },
    /// Caller-supplied input was rejected.
    /// NOTE(review): not constructed in this section of the module — confirm
    /// where it is used.
    InvalidInput { reason: String },
}
118
// Generates the shared boilerplate for each reason-carrying variant of
// `FittedModelError` listed below.
// NOTE(review): presumably this provides the `Display` / `std::error::Error`
// impls that the `From` conversions rely on via `err.to_string()` — confirm
// against the macro definition.
impl_reason_error_boilerplate! {
    FittedModelError {
        SchemaMismatch,
        PayloadCorrupt,
        MissingField,
        IncompatibleConfig,
        InvalidInput,
    }
}
128
129impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134 fn from(err: FittedModelError) -> Self {
135 gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136 }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140 fn from(err: FittedModelError) -> Self {
141 crate::survival::predict::SurvivalPredictError::ModelPayload {
142 context: "saved-model survival prediction payload",
143 source: err,
144 }
145 }
146}
147
/// Affine normalization for latent z scores, applied as `(z - mean) / sd` by
/// `SavedLatentZNormalization::apply`.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct SavedLatentZNormalization {
    /// Value subtracted from each z; `validate` requires it to be finite.
    pub mean: f64,
    /// Divisor for each centered z; `validate` requires it to be finite and
    /// strictly greater than `1e-12`.
    pub sd: f64,
}
153
154impl SavedLatentZNormalization {
155 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156 if !self.mean.is_finite() {
157 return Err(FittedModelError::PayloadCorrupt {
158 reason: format!("{context} latent z mean must be finite"),
159 });
160 }
161 if !(self.sd.is_finite() && self.sd > 1e-12) {
162 return Err(FittedModelError::PayloadCorrupt {
163 reason: format!(
164 "{context} latent z sd must be finite and > 1e-12; got {}",
165 self.sd
166 ),
167 });
168 }
169 Ok::<(), _>(())
170 }
171
172 pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173 self.validate(context)?;
174 if z.iter().any(|value| !value.is_finite()) {
175 return Err(FittedModelError::PayloadCorrupt {
176 reason: format!("{context} requires finite z values"),
177 });
178 }
179 Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180 }
181}
182
/// Default PIT clipping epsilon for transformation-model scores. It is used by
/// `TransformationScoreCalibration::finite_support_pit` and as the serde
/// default for `TransformationScoreCalibration::clip_eps`.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
185#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
186#[serde(rename_all = "kebab-case")]
187#[derive(Default)]
188pub enum TransformationScoreKind {
189 #[default]
190 FiniteSupportPit,
191}
192
/// Score calibration for transformation models (stored in
/// `FittedModelPayload::transformation_score_calibration`).
///
/// Both fields have serde defaults, so a serialized object that omits them
/// deserializes to finite-support PIT scores with
/// `clip_eps = TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TransformationScoreCalibration {
    /// Score semantics; defaults to `TransformationScoreKind::FiniteSupportPit`.
    #[serde(default)]
    pub score_kind: TransformationScoreKind,
    /// PIT clipping epsilon; `validate` requires a finite value in `(0, 0.5)`.
    #[serde(default = "default_transformation_score_pit_clip_eps")]
    pub clip_eps: f64,
}
200
/// Serde default for `TransformationScoreCalibration::clip_eps`; returns
/// `TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
const fn default_transformation_score_pit_clip_eps() -> f64 {
    TRANSFORMATION_SCORE_PIT_CLIP_EPS
}
204
205impl TransformationScoreCalibration {
206 pub fn finite_support_pit() -> Self {
207 Self {
208 score_kind: TransformationScoreKind::FiniteSupportPit,
209 clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210 }
211 }
212
213 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214 if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215 return Err(FittedModelError::IncompatibleConfig {
216 reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217 });
218 }
219 if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220 return Err(FittedModelError::IncompatibleConfig {
221 reason: format!(
222 "{context} requires PIT clip_eps in (0, 0.5), got {}",
223 self.clip_eps
224 ),
225 });
226 }
227 Ok::<(), _>(())
228 }
229}
230
/// Everything persisted for a fitted model. It holds:
/// - identity: formula, model kind, family;
/// - the fitted result;
/// - per-model-type auxiliary settings: link, noise, wiggle, latent,
///   survival and transformation fields;
/// - resolved term specs;
/// - conformal statistics.
///
/// Nearly every optional field carries `#[serde(default)]`, so a serialized
/// payload that omits it deserializes to `None` or an empty collection. Use
/// [`FittedModelPayload::new`] to build one with only the required fields set.
#[derive(Clone, Serialize, Deserialize)]
pub struct FittedModelPayload {
    /// Payload schema version; `validate_payload_version` requires it to equal
    /// [`MODEL_PAYLOAD_VERSION`].
    pub version: u32,
    /// Model formula text.
    pub formula: String,
    /// Which model type this payload describes.
    pub model_kind: ModelKind,
    /// Structured family / link description.
    pub family_state: FittedFamily,
    /// Family name as a string.
    /// NOTE(review): overlaps `family_state`; the relationship is not visible here.
    pub family: String,
    #[serde(default)]
    pub inference_notes: Vec<String>,
    #[serde(default)]
    pub used_device: bool,
    /// Fitted coefficient blocks and covariances. Location-scale survival
    /// validation treats this as the canonical fit and fails when it is absent.
    #[serde(default)]
    pub fit_result: Option<UnifiedFitResult>,
    /// NOTE(review): a second `UnifiedFitResult` slot next to `fit_result`; how
    /// the two relate is not visible here.
    #[serde(default)]
    pub unified: Option<UnifiedFitResult>,
    #[serde(default)]
    pub spline_scan: Option<SavedSplineScan>,
    #[serde(default)]
    pub residual_cascade: Option<SavedResidualCascade>,
    /// Training data schema. When it has no columns and `fit_result` is set,
    /// `synchronize_empty_feature_contract` fills in empty `training_headers`
    /// and `resolved_termspec`.
    #[serde(default)]
    pub data_schema: Option<DataSchema>,
    /// NOTE(review): unlike its neighbours this field has no
    /// `#[serde(default)]`; serde still deserializes a missing `Option` field as
    /// `None`.
    pub link: Option<InverseLink>,
    #[serde(default)]
    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub sas_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub formula_noise: Option<String>,
    #[serde(default)]
    pub formula_logslope: Option<String>,
    #[serde(default)]
    pub formula_logslopes: Option<Vec<String>>,
    #[serde(default)]
    pub offset_column: Option<String>,
    #[serde(default)]
    pub noise_offset_column: Option<String>,
    #[serde(default)]
    pub weight_column: Option<String>,
    #[serde(default)]
    pub beta_noise: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub gaussian_response_scale: Option<f64>,
    #[serde(default)]
    pub linkwiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub linkwiggle_degree: Option<usize>,
    #[serde(default)]
    pub beta_link_wiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_degree: Option<usize>,
    #[serde(default)]
    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    pub baseline_timewiggle_double_penalty: Option<bool>,
    #[serde(default)]
    pub beta_baseline_timewiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub z_column: Option<String>,
    #[serde(default)]
    pub z_columns: Option<Vec<String>>,
    /// Affine normalization for latent z scores.
    #[serde(default)]
    pub latent_z_normalization: Option<SavedLatentZNormalization>,
    #[serde(default)]
    pub latent_score_contract: Option<SavedLatentScoreContract>,
    #[serde(default)]
    pub latent_measure: Option<LatentMeasureKind>,
    #[serde(default)]
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    #[serde(default)]
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    #[serde(default)]
    pub marginal_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baselines: Option<Vec<f64>>,
    #[serde(default)]
    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub influence_absorber_width: Option<usize>,
    #[serde(default)]
    pub survival_entry: Option<String>,
    #[serde(default)]
    pub survival_exit: Option<String>,
    #[serde(default)]
    pub survival_event: Option<String>,
    #[serde(default)]
    pub survivalspec: Option<String>,
    #[serde(default)]
    pub survival_cause_count: Option<usize>,
    #[serde(default)]
    pub survival_endpoint_names: Option<Vec<String>>,
    #[serde(default)]
    pub survival_baseline_target: Option<String>,
    #[serde(default)]
    pub survival_baseline_scale: Option<f64>,
    #[serde(default)]
    pub survival_baseline_shape: Option<f64>,
    #[serde(default)]
    pub survival_baseline_rate: Option<f64>,
    #[serde(default)]
    pub survival_baseline_makeham: Option<f64>,
    // The `survival_time_*` fields below are written together by
    // `apply_survival_time_basis` from a `SavedSurvivalTimeBasis` snapshot.
    #[serde(default)]
    pub survival_time_basis: Option<String>,
    #[serde(default)]
    pub survival_time_degree: Option<usize>,
    #[serde(default)]
    pub survival_time_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_time_keep_cols: Option<Vec<usize>>,
    #[serde(default)]
    pub survival_time_smooth_lambda: Option<f64>,
    #[serde(default)]
    pub survival_time_anchor: Option<f64>,
    #[serde(default)]
    pub survivalridge_lambda: Option<f64>,
    #[serde(default)]
    pub survival_likelihood: Option<String>,
    /// When present, must equal the `Time` block of `fit_result` for
    /// location-scale survival models.
    #[serde(default)]
    pub survival_beta_time: Option<Vec<f64>>,
    /// When present, must equal the `Threshold` block of `fit_result` for
    /// location-scale survival models.
    #[serde(default)]
    pub survival_beta_threshold: Option<Vec<f64>>,
    /// When present, must equal the `Scale` block of `fit_result` for
    /// location-scale survival models.
    #[serde(default)]
    pub survival_beta_log_sigma: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub survival_noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub survival_noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub survival_distribution: Option<ResidualDistribution>,
    /// Training column names; set together with `training_feature_ranges` by
    /// `set_training_feature_metadata`.
    #[serde(default)]
    pub training_headers: Option<Vec<String>>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_table_kind: Option<String>,
    /// Per-feature `(f64, f64)` pairs; set by `set_training_feature_metadata`.
    #[serde(default)]
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub group_metadata: Option<GroupMetadata>,
    /// Coefficients appended after the base design; consumed by
    /// `append_deployment_extension_columns`. Omitted from serialized output
    /// when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub deployment_extensions: Vec<SavedDeploymentExtension>,
    #[serde(default)]
    pub transformation_response_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub transformation_response_degree: Option<usize>,
    #[serde(default)]
    pub transformation_response_median: Option<f64>,
    #[serde(default)]
    pub transformation_score_calibration: Option<TransformationScoreCalibration>,
    /// Resolved main term specification. Required by
    /// `append_deployment_extension_columns` whenever deployment extensions
    /// exist.
    #[serde(default)]
    pub resolved_termspec: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_noise: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
    #[serde(default)]
    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    #[serde(default)]
    pub gaussian_jackknife_plus:
        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
    #[serde(default)]
    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
}
527
/// A coefficient appended after the base design columns, together with what
/// is needed to rebuild its design column (see
/// `append_deployment_extension_columns`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedDeploymentExtension {
    /// Identifier used in error messages.
    pub name: String,
    /// Extension kind; `append_deployment_extension_columns` supports only
    /// `"random-effect-level"`.
    pub kind: String,
    /// Name of the random-effect term in `resolved_termspec` whose feature
    /// column is matched against `level_bits`.
    pub term: String,
    /// JSON form of the level.
    /// NOTE(review): not read in this section of the module.
    pub level: JsonValue,
    /// `f64::to_bits` pattern of the level. Rows whose feature value has
    /// exactly this bit pattern get `1.0` in the extension column.
    pub level_bits: u64,
    /// Position in the extended coefficient vector. Across all extensions these
    /// must form the contiguous run immediately after the base design columns.
    pub coefficient_index: usize,
    /// NOTE(review): not read here — presumably the coefficient's estimate.
    pub coefficient_mean: f64,
    /// NOTE(review): not read here — presumably the coefficient's variance.
    pub coefficient_variance: f64,
    /// Optional extra metadata; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<JsonValue>,
    /// Optional prior description; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prior: Option<JsonValue>,
}
543
544pub fn append_deployment_extension_columns(
561 model: &FittedModelPayload,
562 data: ndarray::ArrayView2<'_, f64>,
563 col_map: &HashMap<String, usize>,
564 training_headers: Option<&Vec<String>>,
565 base_design: Array2<f64>,
566) -> Result<Array2<f64>, FittedModelError> {
567 if model.deployment_extensions.is_empty() {
568 return Ok(base_design);
569 }
570 if base_design.nrows() != data.nrows() {
571 return Err(FittedModelError::SchemaMismatch {
572 reason: format!(
573 "deployment extension design row mismatch: base design has {} rows but data has {}",
574 base_design.nrows(),
575 data.nrows()
576 ),
577 });
578 }
579 let spec = model
580 .resolved_termspec
581 .as_ref()
582 .ok_or_else(|| FittedModelError::MissingField {
583 reason: "deployment extension prediction requires saved resolved_termspec; refit"
584 .to_string(),
585 })?;
586 let n = base_design.nrows();
587 let p_old = base_design.ncols();
588 let mut extensions: Vec<&SavedDeploymentExtension> =
589 model.deployment_extensions.iter().collect();
590 extensions.sort_by_key(|extension| extension.coefficient_index);
591 for (tail_idx, extension) in extensions.iter().enumerate() {
592 let expected = p_old + tail_idx;
593 if extension.coefficient_index != expected {
594 return Err(FittedModelError::SchemaMismatch {
595 reason: format!(
596 "deployment extension '{}' has coefficient index {}, expected append-only index {}",
597 extension.name, extension.coefficient_index, expected
598 ),
599 });
600 }
601 }
602
603 let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
604 out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
605 for (tail_idx, extension) in extensions.into_iter().enumerate() {
606 if extension.kind != "random-effect-level" {
607 return Err(FittedModelError::IncompatibleConfig {
608 reason: format!(
609 "unsupported deployment extension kind '{}' for '{}'",
610 extension.kind, extension.name
611 ),
612 });
613 }
614 let term = spec
615 .random_effect_terms
616 .iter()
617 .find(|term| term.name == extension.term)
618 .ok_or_else(|| FittedModelError::MissingField {
619 reason: format!(
620 "deployment extension '{}' references unknown random-effect term '{}'",
621 extension.name, extension.term
622 ),
623 })?;
624 let prediction_col = training_headers
625 .and_then(|headers| headers.get(term.feature_col))
626 .and_then(|name| col_map.get(name))
627 .copied()
628 .unwrap_or(term.feature_col);
629 if prediction_col >= data.ncols() {
630 return Err(FittedModelError::SchemaMismatch {
631 reason: format!(
632 "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
633 extension.name,
634 prediction_col,
635 data.ncols()
636 ),
637 });
638 }
639 let col = p_old + tail_idx;
640 for row in 0..n {
641 if data[[row, prediction_col]].to_bits() == extension.level_bits {
642 out[[row, col]] = 1.0;
643 }
644 }
645 }
646 Ok(out)
647}
648
/// Description of how latent scores were produced for a model (stored in
/// `FittedModelPayload::latent_score_contract`).
///
/// NOTE(review): none of these fields are read in this section of the module.
/// The per-field notes below are inferred from names and should be confirmed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedLatentScoreContract {
    /// Label for the score semantics (free-form string).
    pub semantics: String,
    /// Identifier of the transform that produced the scores, if recorded.
    pub source_transform_id: Option<String>,
    /// Presumably the normalization mean applied to the scores — TODO confirm.
    pub normalization_mean: f64,
    /// Presumably the normalization sd applied to the scores — TODO confirm.
    pub normalization_sd: f64,
    /// Presumably a PIT clipping epsilon, when one was used — TODO confirm.
    pub clip_eps: Option<f64>,
    /// Names of columns the scores were conditioned on.
    pub conditioning_columns: Vec<String>,
}
658
659impl FittedModelPayload {
660 pub fn new(
661 version: u32,
662 formula: String,
663 model_kind: ModelKind,
664 family_state: FittedFamily,
665 family: String,
666 ) -> Self {
667 Self {
668 version,
669 formula,
670 model_kind,
671 family_state,
672 family,
673 inference_notes: Vec::new(),
674 used_device: false,
675 fit_result: None,
676 unified: None,
677 spline_scan: None,
678 residual_cascade: None,
679 data_schema: None,
680 link: None,
681 mixture_link_param_covariance: None,
682 sas_param_covariance: None,
683 formula_noise: None,
684 formula_logslope: None,
685 formula_logslopes: None,
686 offset_column: None,
687 noise_offset_column: None,
688 weight_column: None,
689 beta_noise: None,
690 noise_projection: None,
691 noise_center: None,
692 noise_scale: None,
693 noise_non_intercept_start: None,
694 noise_projection_ridge_alpha: None,
695 gaussian_response_scale: None,
696 linkwiggle_knots: None,
697 linkwiggle_degree: None,
698 beta_link_wiggle: None,
699 baseline_timewiggle_knots: None,
700 baseline_timewiggle_degree: None,
701 baseline_timewiggle_penalty_orders: None,
702 baseline_timewiggle_double_penalty: None,
703 beta_baseline_timewiggle: None,
704 beta_baseline_timewiggle_by_cause: None,
705 z_column: None,
706 z_columns: None,
707 latent_z_normalization: None,
708 latent_score_contract: None,
709 latent_measure: None,
710 latent_z_rank_int_calibration: None,
711 latent_z_conditional_calibration: None,
712 marginal_baseline: None,
713 logslope_baseline: None,
714 logslope_baselines: None,
715 score_warp_runtime: None,
716 link_deviation_runtime: None,
717 influence_absorber_width: None,
718 survival_entry: None,
719 survival_exit: None,
720 survival_event: None,
721 survivalspec: None,
722 survival_cause_count: None,
723 survival_endpoint_names: None,
724 survival_baseline_target: None,
725 survival_baseline_scale: None,
726 survival_baseline_shape: None,
727 survival_baseline_rate: None,
728 survival_baseline_makeham: None,
729 survival_time_basis: None,
730 survival_time_degree: None,
731 survival_time_knots: None,
732 survival_time_keep_cols: None,
733 survival_time_smooth_lambda: None,
734 survival_time_anchor: None,
735 survivalridge_lambda: None,
736 survival_likelihood: None,
737 survival_beta_time: None,
738 survival_beta_threshold: None,
739 survival_beta_log_sigma: None,
740 survival_noise_projection: None,
741 survival_noise_center: None,
742 survival_noise_scale: None,
743 survival_noise_non_intercept_start: None,
744 survival_noise_projection_ridge_alpha: None,
745 survival_distribution: None,
746 training_headers: None,
747 training_table_kind: None,
748 training_feature_ranges: None,
749 group_metadata: None,
750 deployment_extensions: Vec::new(),
751 transformation_response_knots: None,
752 transformation_response_transform: None,
753 transformation_response_degree: None,
754 transformation_response_median: None,
755 transformation_score_calibration: None,
756 resolved_termspec: None,
757 resolved_termspec_noise: None,
758 resolved_termspec_logslope: None,
759 resolved_termspec_logslopes: None,
760 adaptive_regularization_diagnostics: None,
761 gaussian_jackknife_plus: None,
762 full_conformal: None,
763 }
764 }
765
766 pub fn set_training_feature_metadata(
767 &mut self,
768 headers: Vec<String>,
769 feature_ranges: Vec<(f64, f64)>,
770 ) {
771 self.training_headers = Some(headers);
772 self.training_feature_ranges = Some(feature_ranges);
773 }
774
775 fn synchronize_empty_feature_contract(&mut self) {
776 if self.fit_result.is_none() {
777 return;
778 }
779 let Some(schema) = self.data_schema.as_ref() else {
780 return;
781 };
782 if !schema.columns.is_empty() {
783 return;
784 }
785 self.training_headers.get_or_insert_with(Vec::new);
786 self.resolved_termspec
787 .get_or_insert_with(|| TermCollectionSpec {
788 linear_terms: Vec::new(),
789 smooth_terms: Vec::new(),
790 random_effect_terms: Vec::new(),
791 });
792 }
793
794 pub fn apply_survival_time_basis(
802 &mut self,
803 snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
804 ) {
805 self.survival_time_basis = Some(snapshot.basisname.clone());
806 self.survival_time_degree = snapshot.degree;
807 self.survival_time_knots = snapshot.knots.clone();
808 self.survival_time_keep_cols = snapshot.keep_cols.clone();
809 self.survival_time_smooth_lambda = snapshot.smooth_lambda;
810 self.survival_time_anchor = Some(snapshot.anchor);
811 }
812
813 fn validate_payload_version(&self) -> Result<(), FittedModelError> {
814 if self.version != MODEL_PAYLOAD_VERSION {
815 return Err(FittedModelError::SchemaMismatch {
816 reason: format!(
817 "saved model payload schema mismatch: file has version={}, \
818 this binary expects MODEL_PAYLOAD_VERSION={}. \
819 Refit with the current CLI, or rebuild the reader at the same \
820 version the model was written with.",
821 self.version, MODEL_PAYLOAD_VERSION
822 ),
823 });
824 }
825 Ok(())
826 }
827}
828
/// Top-level saved model: a [`FittedModelPayload`] wrapped in a variant that
/// names the model type.
///
/// Serialized as an internally tagged object. Its `"model_type"` field holds the
/// kebab-case variant name (`"standard"`, `"location-scale"`,
/// `"marginal-slope"`, `"survival"`, `"transformation-normal"`), alongside the
/// `payload` field.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    Standard { payload: FittedModelPayload },
    LocationScale { payload: FittedModelPayload },
    MarginalSlope { payload: FittedModelPayload },
    Survival { payload: FittedModelPayload },
    TransformationNormal { payload: FittedModelPayload },
}
838
/// Model type recorded inside a payload (`FittedModelPayload::model_kind`).
///
/// Has the same variants as [`FittedModel`] and is serialized in kebab-case.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    Standard,
    LocationScale,
    MarginalSlope,
    Survival,
    TransformationNormal,
}
848
/// Structured family / link description saved with a model
/// (`FittedModelPayload::family_state`).
///
/// Serialized as an internally tagged object. Its `"family_kind"` field holds the
/// kebab-case variant name (e.g. `"standard"`, `"latent-survival"`). Field names
/// inside each variant keep their snake_case spelling. Fields marked
/// `#[serde(default)]` may be omitted and then deserialize as `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Single-predictor model with an optional link and optional link-specific
    /// states (latent cloglog, mixture, SAS).
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    /// Location-scale model with an optional base inverse link.
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    /// Marginal-slope model; base link and frailty are required.
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    /// Survival model with optional survival-likelihood name and residual
    /// distribution.
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Latent survival family, described only by its frailty spec.
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Latent binary family, described only by its frailty spec.
    LatentBinary {
        frailty: FrailtySpec,
    },
    /// Transformation-normal model.
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
891
/// Classification of a saved model for prediction; [`PredictModelClass::name`]
/// gives its human-readable label.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    Standard,
    GaussianLocationScale,
    BinomialLocationScale,
    DispersionLocationScale,
    BernoulliMarginalSlope,
    Survival,
    TransformationNormal,
}

impl PredictModelClass {
    /// Static display label for this class (usable in `const` contexts).
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::Standard => "standard",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
        }
    }
}
921
/// Link-wiggle settings used at prediction time: knot vector, spline degree
/// and coefficients.
///
/// `beta.len()` contributes to the expected coefficient count when
/// validating location-scale fits. For location-scale survival fits, `beta`
/// must equal the fit's `LinkWiggle` block.
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub beta: Vec<f64>,
}
927}
928
929#[derive(Clone, Debug)]
930pub struct SavedBaselineTimeWiggleRuntime {
931 pub knots: Vec<f64>,
932 pub degree: usize,
933 pub penalty_orders: Vec<usize>,
934 pub double_penalty: bool,
935 pub beta: Vec<f64>,
936}
937
938pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
941
/// Serialized, precompiled flexible block (used for the score-warp and
/// link-deviation runtimes of a payload).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
    /// Name of the kernel the block was compiled for.
    pub kernel: String,
    /// Span breakpoints.
    pub breakpoints: Vec<f64>,
    /// Number of basis functions. Marginal-slope validation requires the
    /// corresponding score-warp coefficient block to have exactly this length.
    pub basis_dim: usize,
    // NOTE(review): `span_c0`..`span_c3` look like per-span polynomial
    // coefficients of orders 0..3 — confirm against the evaluator.
    pub span_c0: Vec<Vec<f64>>,
    pub span_c1: Vec<Vec<f64>>,
    pub span_c2: Vec<Vec<f64>>,
    pub span_c3: Vec<Vec<f64>>,
    /// Optional anchor correction; `None` when omitted from the serialized form.
    #[serde(default)]
    pub anchor_correction: Option<Vec<Vec<f64>>>,
    /// Anchor components; empty when omitted from the serialized form.
    #[serde(default)]
    pub anchor_components: Vec<SavedAnchorComponent>,
}
964
/// One anchor component of a [`SavedCompiledFlexBlock`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
    pub kind: SavedAnchorKind,
}
969
/// Kind of an anchor component. Each variant records its column count
/// (`ncols`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
    /// Anchor defined by a parametric block.
    Parametric {
        block: ParametricAnchorBlock,
        ncols: usize,
    },
    /// Anchor defined by flex-block evaluations.
    FlexEvaluation { ncols: usize },
}
982
/// Runtime settings for predicting from a saved model.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
    /// Prediction classification of the model.
    pub model_class: PredictModelClass,
    pub likelihood: LikelihoodSpec,
    /// Optional inverse link.
    pub inverse_link: Option<InverseLink>,
    /// Link-wiggle settings, when the model has one.
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
    /// Baseline time-wiggle settings, when the model has one.
    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
    /// Compiled score-warp block, if any.
    pub score_warp: Option<SavedCompiledFlexBlock>,
    /// Compiled link-deviation block, if any.
    pub link_deviation: Option<SavedCompiledFlexBlock>,
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    pub influence_absorber_width: Option<usize>,
}
1013
1014pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1015 fit.block_by_role(BlockRole::Location)
1016 .or_else(|| fit.block_by_role(BlockRole::Mean))
1017 .map(|block| block.beta.clone())
1018}
1019
1020pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021 fit.block_by_role(BlockRole::Threshold)
1022 .or_else(|| fit.block_by_role(BlockRole::Location))
1023 .or_else(|| fit.block_by_role(BlockRole::Mean))
1024 .map(|block| block.beta.clone())
1025}
1026
1027pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1028 fit.block_by_role(BlockRole::Scale)
1029 .map(|block| block.beta.clone())
1030}
1031
1032fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1041 use gam_problem::types::ResponseFamily;
1042 matches!(
1043 response,
1044 ResponseFamily::NegativeBinomial { .. }
1045 | ResponseFamily::Gamma
1046 | ResponseFamily::Beta { .. }
1047 | ResponseFamily::Tweedie { .. }
1048 )
1049}
1050
1051fn validate_location_scale_saved_fit(
1052 fit: &UnifiedFitResult,
1053 model_class: PredictModelClass,
1054 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1055) -> Result<(), FittedModelError> {
1056 let primary = match model_class {
1057 PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1061 gaussian_location_scale_mean_beta(fit)
1062 }
1063 PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1064 _ => None,
1065 }
1066 .ok_or_else(|| FittedModelError::MissingField {
1067 reason: match model_class {
1068 PredictModelClass::GaussianLocationScale => {
1069 "gaussian-location-scale saved fit is missing mean/location block".to_string()
1070 }
1071 PredictModelClass::DispersionLocationScale => {
1072 "dispersion-location-scale saved fit is missing mean/location block".to_string()
1073 }
1074 PredictModelClass::BinomialLocationScale => {
1075 "binomial-location-scale saved fit is missing threshold/location block".to_string()
1076 }
1077 _ => "location-scale saved fit is missing primary block".to_string(),
1078 },
1079 })?;
1080
1081 let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1082 reason: "location-scale saved fit is missing scale block".to_string(),
1083 })?;
1084 let expected =
1085 primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1086
1087 if let Some(cov) = fit.beta_covariance()
1088 && (cov.nrows() != expected || cov.ncols() != expected)
1089 {
1090 return Err(FittedModelError::SchemaMismatch {
1091 reason: format!(
1092 "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1093 cov.nrows(),
1094 cov.ncols(),
1095 expected,
1096 expected
1097 ),
1098 });
1099 }
1100 if let Some(cov) = fit.beta_covariance_corrected()
1101 && (cov.nrows() != expected || cov.ncols() != expected)
1102 {
1103 return Err(FittedModelError::SchemaMismatch {
1104 reason: format!(
1105 "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1106 cov.nrows(),
1107 cov.ncols(),
1108 expected,
1109 expected
1110 ),
1111 });
1112 }
1113 Ok(())
1114}
1115
1116fn validate_survival_saved_block_matches_payload(
1117 fit: &UnifiedFitResult,
1118 role: BlockRole,
1119 payload_beta: Option<&Vec<f64>>,
1120 label: &str,
1121) -> Result<usize, FittedModelError> {
1122 let block = fit
1123 .block_by_role(role)
1124 .ok_or_else(|| FittedModelError::MissingField {
1125 reason: format!("location-scale survival saved fit is missing {label} block"),
1126 })?;
1127 if let Some(saved) = payload_beta
1128 && block.beta.to_vec() != *saved
1129 {
1130 return Err(FittedModelError::SchemaMismatch {
1131 reason: format!(
1132 "location-scale survival saved {label} coefficients disagree with fit_result"
1133 ),
1134 });
1135 }
1136 Ok(block.beta.len())
1137}
1138
1139fn validate_survival_location_scale_saved_fit(
1140 payload: &FittedModelPayload,
1141 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1142) -> Result<(), FittedModelError> {
1143 let fit = payload
1144 .fit_result
1145 .as_ref()
1146 .ok_or_else(|| FittedModelError::MissingField {
1147 reason: "location-scale survival model is missing canonical fit_result payload"
1148 .to_string(),
1149 })?;
1150 let p_time = validate_survival_saved_block_matches_payload(
1151 fit,
1152 BlockRole::Time,
1153 payload.survival_beta_time.as_ref(),
1154 "time",
1155 )?;
1156 let p_threshold = validate_survival_saved_block_matches_payload(
1157 fit,
1158 BlockRole::Threshold,
1159 payload.survival_beta_threshold.as_ref(),
1160 "threshold",
1161 )?;
1162 let p_log_sigma = validate_survival_saved_block_matches_payload(
1163 fit,
1164 BlockRole::Scale,
1165 payload.survival_beta_log_sigma.as_ref(),
1166 "log-sigma",
1167 )?;
1168 let p_wiggle = match link_wiggle {
1169 Some(runtime) => {
1170 let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1171 FittedModelError::MissingField {
1172 reason: "location-scale survival saved fit is missing link-wiggle block"
1173 .to_string(),
1174 }
1175 })?;
1176 if block.beta.to_vec() != runtime.beta {
1177 return Err(FittedModelError::SchemaMismatch {
1178 reason:
1179 "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1180 .to_string(),
1181 });
1182 }
1183 runtime.beta.len()
1184 }
1185 None => {
1186 if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1187 return Err(FittedModelError::SchemaMismatch {
1188 reason:
1189 "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1190 .to_string(),
1191 });
1192 }
1193 0
1194 }
1195 };
1196 let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1197
1198 if let Some(cov) = fit.beta_covariance()
1199 && (cov.nrows() != expected || cov.ncols() != expected)
1200 {
1201 return Err(FittedModelError::SchemaMismatch {
1202 reason: format!(
1203 "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1204 cov.nrows(),
1205 cov.ncols(),
1206 expected,
1207 expected
1208 ),
1209 });
1210 }
1211 if let Some(cov) = fit.beta_covariance_corrected()
1212 && (cov.nrows() != expected || cov.ncols() != expected)
1213 {
1214 return Err(FittedModelError::SchemaMismatch {
1215 reason: format!(
1216 "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1217 cov.nrows(),
1218 cov.ncols(),
1219 expected,
1220 expected
1221 ),
1222 });
1223 }
1224 Ok(())
1225}
1226
1227fn validate_marginal_slope_saved_fit(
1228 fit: &UnifiedFitResult,
1229 score_warp: Option<&SavedCompiledFlexBlock>,
1230 link_deviation: Option<&SavedCompiledFlexBlock>,
1231 fit_label: &str,
1232) -> Result<(), FittedModelError> {
1233 validate_marginal_slope_saved_fit_impl(
1234 fit,
1235 score_warp,
1236 link_deviation,
1237 fit_label,
1238 "bernoulli",
1239 2,
1240 "marginal, logslope",
1241 )
1242}
1243
1244fn validate_survival_marginal_slope_saved_fit(
1245 fit: &UnifiedFitResult,
1246 score_warp: Option<&SavedCompiledFlexBlock>,
1247 link_deviation: Option<&SavedCompiledFlexBlock>,
1248 fit_label: &str,
1249) -> Result<(), FittedModelError> {
1250 validate_marginal_slope_saved_fit_impl(
1251 fit,
1252 score_warp,
1253 link_deviation,
1254 fit_label,
1255 "survival",
1256 3,
1257 "time, marginal, slope",
1258 )
1259}
1260
1261fn validate_marginal_slope_saved_fit_impl(
1269 fit: &UnifiedFitResult,
1270 score_warp: Option<&SavedCompiledFlexBlock>,
1271 link_deviation: Option<&SavedCompiledFlexBlock>,
1272 fit_label: &str,
1273 family_kind: &str,
1274 base_block_count: usize,
1275 base_block_role_list: &str,
1276) -> Result<(), FittedModelError> {
1277 let expected_blocks = base_block_count
1278 + usize::from(score_warp.is_some())
1279 + usize::from(link_deviation.is_some());
1280 if fit.blocks.len() != expected_blocks {
1281 let score_warp_suffix = if score_warp.is_some() {
1282 ", score-warp"
1283 } else {
1284 ""
1285 };
1286 let link_deviation_suffix = if link_deviation.is_some() {
1287 ", link-deviation"
1288 } else {
1289 ""
1290 };
1291 return Err(FittedModelError::SchemaMismatch {
1292 reason: format!(
1293 "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1294 fit.blocks.len(),
1295 ),
1296 });
1297 }
1298 if let Some(runtime) = score_warp {
1299 let beta = &fit.blocks[base_block_count].beta;
1300 if beta.len() != runtime.basis_dim {
1301 return Err(FittedModelError::SchemaMismatch {
1302 reason: format!(
1303 "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1304 beta.len(),
1305 runtime.basis_dim
1306 ),
1307 });
1308 }
1309 }
1310 if let Some(runtime) = link_deviation {
1311 let idx = base_block_count + usize::from(score_warp.is_some());
1312 let beta = &fit.blocks[idx].beta;
1313 if beta.len() != runtime.basis_dim {
1314 return Err(FittedModelError::SchemaMismatch {
1315 reason: format!(
1316 "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1317 beta.len(),
1318 runtime.basis_dim
1319 ),
1320 });
1321 }
1322 }
1323 Ok(())
1324}
1325
1326impl SavedLinkWiggleRuntime {
1327 fn validate_monotone_derivative(
1328 &self,
1329 q0: &Array1<f64>,
1330 ) -> Result<Array1<f64>, FittedModelError> {
1331 let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1338 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1339 let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1340 if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1341 return Err(FittedModelError::PayloadCorrupt {
1342 reason: format!(
1343 "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1344 ),
1345 });
1346 }
1347 Ok(dq_dq0)
1348 }
1349
1350 pub fn constrained_basis(
1351 &self,
1352 q0: &Array1<f64>,
1353 basis_options: BasisOptions,
1354 ) -> Result<Array2<f64>, FittedModelError> {
1355 let knot_arr = Array1::from_vec(self.knots.clone());
1356 let constrained = monotone_wiggle_basis_with_derivative_order(
1357 q0.view(),
1358 &knot_arr,
1359 self.degree,
1360 basis_options.derivative_order,
1361 )
1362 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1363 if constrained.ncols() != self.beta.len() {
1364 return Err(FittedModelError::SchemaMismatch {
1365 reason: format!(
1366 "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1367 self.beta.len(),
1368 constrained.ncols()
1369 ),
1370 });
1371 }
1372 Ok(constrained)
1373 }
1374
1375 pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1376 self.validate_monotone_derivative(q0)?;
1377 self.constrained_basis(q0, BasisOptions::value())
1378 }
1379
1380 pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1381 let q = Array1::from_vec(vec![q0]);
1382 let x = self.design(&q)?;
1383 if x.nrows() != 1 {
1384 return Err(FittedModelError::SchemaMismatch {
1385 reason: format!(
1386 "saved link-wiggle scalar evaluation expected 1 row, got {}",
1387 x.nrows()
1388 ),
1389 });
1390 }
1391 Ok(x.row(0).to_owned())
1392 }
1393
1394 pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395 self.validate_monotone_derivative(q0)?;
1396 let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1397 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1398 Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1399 }
1400
1401 pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1402 self.validate_monotone_derivative(q0)
1403 }
1404}
1405
1406impl SavedBaselineTimeWiggleRuntime {
1407 pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1408 validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1409 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1410 }
1411}
1412
1413impl SavedCompiledFlexBlock {
1414 pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1415 if self.kernel.is_empty() {
1416 return Err(FittedModelError::SchemaMismatch {
1417 reason: "saved anchored deviation runtime is missing the exact kernel marker"
1418 .to_string(),
1419 });
1420 }
1421 if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1422 return Err(FittedModelError::IncompatibleConfig {
1423 reason: format!(
1424 "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1425 self.kernel,
1426 crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1427 ),
1428 });
1429 }
1430 if self.basis_dim == 0 {
1431 return Err(FittedModelError::SchemaMismatch {
1432 reason: format!(
1433 "saved anchored deviation runtime basis_dim must be positive, got {}",
1434 self.basis_dim
1435 ),
1436 });
1437 }
1438 if self.breakpoints.len() < 2 {
1439 return Err(FittedModelError::SchemaMismatch {
1440 reason: format!(
1441 "saved anchored deviation runtime requires at least two breakpoints, got {}",
1442 self.breakpoints.len()
1443 ),
1444 });
1445 }
1446 for window in self.breakpoints.windows(2) {
1447 let left = window[0];
1448 let right = window[1];
1449 if !left.is_finite() || !right.is_finite() || right <= left {
1450 return Err(FittedModelError::PayloadCorrupt {
1451 reason: format!(
1452 "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1453 ),
1454 });
1455 }
1456 }
1457 let span_count = self.breakpoints.len() - 1;
1458 self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1459 self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1460 self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1461 self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1462 self.validate_c2_span_continuity()?;
1463 self.validate_anchor_residual_shape()?;
1464 Ok(())
1465 }
1466
1467 fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1468 let coeffs = match self.anchor_correction.as_ref() {
1469 Some(c) => c,
1470 None => {
1471 if !self.anchor_components.is_empty() {
1472 return Err(FittedModelError::SchemaMismatch {
1473 reason:
1474 "saved anchored deviation runtime has anchor_components but no anchor_correction"
1475 .to_string(),
1476 });
1477 }
1478 return Ok(());
1479 }
1480 };
1481 let d: usize = self
1482 .anchor_components
1483 .iter()
1484 .map(|c| match &c.kind {
1485 SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1486 SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1487 })
1488 .sum();
1489 if coeffs.len() != d {
1490 return Err(FittedModelError::SchemaMismatch {
1491 reason: format!(
1492 "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1493 coeffs.len(),
1494 d,
1495 ),
1496 });
1497 }
1498 for (i, row) in coeffs.iter().enumerate() {
1499 if row.len() != self.basis_dim {
1500 return Err(FittedModelError::SchemaMismatch {
1501 reason: format!(
1502 "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1503 i,
1504 row.len(),
1505 self.basis_dim,
1506 ),
1507 });
1508 }
1509 for (j, &v) in row.iter().enumerate() {
1510 if !v.is_finite() {
1511 return Err(FittedModelError::PayloadCorrupt {
1512 reason: format!(
1513 "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1514 ),
1515 });
1516 }
1517 }
1518 }
1519 Ok(())
1520 }
1521
1522 fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1523 const TOL: f64 = 1e-8;
1524 for span_idx in 1..self.breakpoints.len() - 1 {
1525 let left_span = span_idx - 1;
1526 let right_span = span_idx;
1527 let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1528 for basis_idx in 0..self.basis_dim {
1529 let left_value = self.span_c0[left_span][basis_idx]
1530 + self.span_c1[left_span][basis_idx] * width
1531 + self.span_c2[left_span][basis_idx] * width * width
1532 + self.span_c3[left_span][basis_idx] * width * width * width;
1533 let left_d1 = self.span_c1[left_span][basis_idx]
1534 + 2.0 * self.span_c2[left_span][basis_idx] * width
1535 + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1536 let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1537 + 6.0 * self.span_c3[left_span][basis_idx] * width;
1538 let right_value = self.span_c0[right_span][basis_idx];
1539 let right_d1 = self.span_c1[right_span][basis_idx];
1540 let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1541 if (left_value - right_value).abs() > TOL
1542 || (left_d1 - right_d1).abs() > TOL
1543 || (left_d2 - right_d2).abs() > TOL
1544 {
1545 return Err(FittedModelError::SchemaMismatch {
1546 reason: format!(
1547 "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1548 left_value - right_value,
1549 left_d1 - right_d1,
1550 left_d2 - right_d2
1551 ),
1552 });
1553 }
1554 }
1555 }
1556 Ok(())
1557 }
1558
1559 fn validate_coefficient_matrix(
1560 &self,
1561 matrix: &[Vec<f64>],
1562 label: &str,
1563 expected_rows: usize,
1564 ) -> Result<(), FittedModelError> {
1565 if matrix.len() != expected_rows {
1566 return Err(FittedModelError::SchemaMismatch {
1567 reason: format!(
1568 "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1569 matrix.len(),
1570 expected_rows
1571 ),
1572 });
1573 }
1574 for (row_idx, row) in matrix.iter().enumerate() {
1575 if row.len() != self.basis_dim {
1576 return Err(FittedModelError::SchemaMismatch {
1577 reason: format!(
1578 "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1579 row_idx,
1580 row.len(),
1581 self.basis_dim
1582 ),
1583 });
1584 }
1585 for (j, &value) in row.iter().enumerate() {
1586 if !value.is_finite() {
1587 return Err(FittedModelError::PayloadCorrupt {
1588 reason: format!(
1589 "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1590 ),
1591 });
1592 }
1593 }
1594 }
1595 Ok(())
1596 }
1597
1598 fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1599 let last_span = self.breakpoints.len() - 2;
1600 let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1601 self.span_c0[last_span][basis_idx]
1602 + self.span_c1[last_span][basis_idx] * width
1603 + self.span_c2[last_span][basis_idx] * width * width
1604 + self.span_c3[last_span][basis_idx] * width * width * width
1605 }
1606
1607 fn evaluate_span_polynomial_design(
1608 &self,
1609 values: &Array1<f64>,
1610 derivative_order: usize,
1611 ) -> Result<Array2<f64>, FittedModelError> {
1612 self.validate_exact_replay_contract()?;
1613 let (left_ep, right_ep) = self.support_interval()?;
1614 let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1615 for (row_idx, &value) in values.iter().enumerate() {
1616 if !value.is_finite() {
1617 return Err(FittedModelError::PayloadCorrupt {
1618 reason: format!(
1619 "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1620 ),
1621 });
1622 }
1623 if value < left_ep {
1624 if derivative_order == 0 {
1625 for basis_idx in 0..self.basis_dim {
1626 out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1627 }
1628 }
1629 continue;
1630 }
1631 if value > right_ep {
1632 if derivative_order == 0 {
1633 for basis_idx in 0..self.basis_dim {
1634 out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1635 }
1636 }
1637 continue;
1638 }
1639 let span_idx = self.left_biased_span_index_for(value)?;
1640 let t = value - self.breakpoints[span_idx];
1641 for basis_idx in 0..self.basis_dim {
1642 let c0 = self.span_c0[span_idx][basis_idx];
1643 let c1 = self.span_c1[span_idx][basis_idx];
1644 let c2 = self.span_c2[span_idx][basis_idx];
1645 let c3 = self.span_c3[span_idx][basis_idx];
1646 out[[row_idx, basis_idx]] = match derivative_order {
1647 0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1648 1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1649 2 => 2.0 * c2 + 6.0 * c3 * t,
1650 3 => 6.0 * c3,
1651 4 => 0.0,
1652 other => {
1653 return Err(FittedModelError::IncompatibleConfig {
1654 reason: format!(
1655 "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1656 ),
1657 });
1658 }
1659 };
1660 }
1661 }
1662 Ok(out)
1663 }
1664
1665 pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1666 self.validate_exact_replay_contract()?;
1667 Ok(self.breakpoints.clone())
1668 }
1669
1670 pub fn span_count(&self) -> Result<usize, FittedModelError> {
1671 Ok(self.breakpoints()?.windows(2).count())
1672 }
1673
1674 pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1675 let points = self.breakpoints()?;
1676 span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1677 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1678 }
1679
1680 fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1681 let mut span_idx = span_index_for_breakpoints(
1682 &self.breakpoints,
1683 value,
1684 "saved anchored deviation span lookup",
1685 )
1686 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1687 if span_idx > 0 && value == self.breakpoints[span_idx] {
1690 span_idx -= 1;
1691 }
1692 Ok(span_idx)
1693 }
1694
1695 pub fn local_cubic_on_span(
1696 &self,
1697 beta: &Array1<f64>,
1698 span_idx: usize,
1699 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1700 self.validate_exact_replay_contract()?;
1701 if beta.len() != self.basis_dim {
1702 return Err(FittedModelError::SchemaMismatch {
1703 reason: format!(
1704 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1705 beta.len(),
1706 self.basis_dim
1707 ),
1708 });
1709 }
1710 self.local_cubic_on_span_validated(beta, span_idx)
1711 }
1712
1713 fn local_cubic_on_span_validated(
1714 &self,
1715 beta: &Array1<f64>,
1716 span_idx: usize,
1717 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1718 let points = &self.breakpoints;
1719 if span_idx + 1 >= points.len() {
1720 return Err(FittedModelError::SchemaMismatch {
1721 reason: format!(
1722 "saved anchored deviation span index {} out of range for {} spans",
1723 span_idx,
1724 points.len() - 1
1725 ),
1726 });
1727 }
1728 let left = points[span_idx];
1729 let right = points[span_idx + 1];
1730 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1731 left,
1732 right,
1733 c0: self.span_c0[span_idx]
1734 .iter()
1735 .zip(beta.iter())
1736 .map(|(coeff, weight)| coeff * weight)
1737 .sum(),
1738 c1: self.span_c1[span_idx]
1739 .iter()
1740 .zip(beta.iter())
1741 .map(|(coeff, weight)| coeff * weight)
1742 .sum(),
1743 c2: self.span_c2[span_idx]
1744 .iter()
1745 .zip(beta.iter())
1746 .map(|(coeff, weight)| coeff * weight)
1747 .sum(),
1748 c3: self.span_c3[span_idx]
1749 .iter()
1750 .zip(beta.iter())
1751 .map(|(coeff, weight)| coeff * weight)
1752 .sum(),
1753 })
1754 }
1755
1756 pub fn basis_span_cubic(
1757 &self,
1758 span_idx: usize,
1759 basis_idx: usize,
1760 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1761 self.validate_exact_replay_contract()?;
1762 if basis_idx >= self.basis_dim {
1763 return Err(FittedModelError::SchemaMismatch {
1764 reason: format!(
1765 "saved anchored deviation basis index {} out of range for {} coefficients",
1766 basis_idx, self.basis_dim
1767 ),
1768 });
1769 }
1770 self.basis_span_cubic_validated(span_idx, basis_idx)
1771 }
1772
1773 fn basis_span_cubic_validated(
1774 &self,
1775 span_idx: usize,
1776 basis_idx: usize,
1777 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1778 let points = &self.breakpoints;
1779 if span_idx + 1 >= points.len() {
1780 return Err(FittedModelError::SchemaMismatch {
1781 reason: format!(
1782 "saved anchored deviation span index {} out of range for {} spans",
1783 span_idx,
1784 points.len() - 1
1785 ),
1786 });
1787 }
1788 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1789 left: points[span_idx],
1790 right: points[span_idx + 1],
1791 c0: self.span_c0[span_idx][basis_idx],
1792 c1: self.span_c1[span_idx][basis_idx],
1793 c2: self.span_c2[span_idx][basis_idx],
1794 c3: self.span_c3[span_idx][basis_idx],
1795 })
1796 }
1797
1798 pub fn basis_cubic_at(
1799 &self,
1800 basis_idx: usize,
1801 value: f64,
1802 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1803 self.validate_exact_replay_contract()?;
1804 if basis_idx >= self.basis_dim {
1805 return Err(FittedModelError::SchemaMismatch {
1806 reason: format!(
1807 "saved anchored deviation basis index {} out of range for {} coefficients",
1808 basis_idx, self.basis_dim
1809 ),
1810 });
1811 }
1812 let (left_ep, right_ep) = self.support_interval()?;
1813 if value < left_ep {
1814 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1815 left: left_ep,
1816 right: left_ep + 1.0,
1817 c0: self.span_c0[0][basis_idx],
1818 c1: 0.0,
1819 c2: 0.0,
1820 c3: 0.0,
1821 });
1822 }
1823 if value > right_ep {
1824 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1825 left: right_ep,
1826 right: right_ep + 1.0,
1827 c0: self.right_boundary_basis_value(basis_idx),
1828 c1: 0.0,
1829 c2: 0.0,
1830 c3: 0.0,
1831 });
1832 }
1833 let span_idx = self.left_biased_span_index_for(value)?;
1834 self.basis_span_cubic_validated(span_idx, basis_idx)
1835 }
1836
1837 pub fn local_cubic_at(
1838 &self,
1839 beta: &Array1<f64>,
1840 value: f64,
1841 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1842 self.validate_exact_replay_contract()?;
1843 if beta.len() != self.basis_dim {
1844 return Err(FittedModelError::SchemaMismatch {
1845 reason: format!(
1846 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1847 beta.len(),
1848 self.basis_dim
1849 ),
1850 });
1851 }
1852 let (left_ep, right_ep) = self.support_interval()?;
1853 if value < left_ep {
1854 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1855 left: left_ep,
1856 right: left_ep + 1.0,
1857 c0: self.span_c0[0]
1858 .iter()
1859 .zip(beta.iter())
1860 .map(|(coeff, weight)| coeff * weight)
1861 .sum(),
1862 c1: 0.0,
1863 c2: 0.0,
1864 c3: 0.0,
1865 });
1866 }
1867 if value > right_ep {
1868 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1869 left: right_ep,
1870 right: right_ep + 1.0,
1871 c0: (0..self.basis_dim)
1872 .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1873 .sum(),
1874 c1: 0.0,
1875 c2: 0.0,
1876 c3: 0.0,
1877 });
1878 }
1879 let span_idx = self.left_biased_span_index_for(value)?;
1880 self.local_cubic_on_span_validated(beta, span_idx)
1881 }
1882
1883 fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1884 let points = self.breakpoints()?;
1885 match (points.first(), points.last()) {
1886 (Some(&left), Some(&right)) => Ok((left, right)),
1887 _ => Err(FittedModelError::MissingField {
1888 reason: "saved anchored deviation runtime is missing support breakpoints"
1889 .to_string(),
1890 }),
1891 }
1892 }
1893
1894 pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1895 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1904 }
1905
1906 pub fn design_uncorrected(
1914 &self,
1915 values: &Array1<f64>,
1916 ) -> Result<Array2<f64>, FittedModelError> {
1917 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1918 }
1919
1920 pub fn design_with_anchor_rows(
1929 &self,
1930 values: &Array1<f64>,
1931 anchor_rows: ndarray::ArrayView2<f64>,
1932 ) -> Result<Array2<f64>, FittedModelError> {
1933 let mut out =
1934 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1935 if let Some(m_rows) = self.anchor_correction.as_ref() {
1936 let d = m_rows.len();
1937 if anchor_rows.nrows() != values.len() {
1938 return Err(FittedModelError::SchemaMismatch {
1939 reason: format!(
1940 "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1941 anchor_rows.nrows(),
1942 values.len(),
1943 ),
1944 });
1945 }
1946 if anchor_rows.ncols() != d {
1947 return Err(FittedModelError::SchemaMismatch {
1948 reason: format!(
1949 "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1950 anchor_rows.ncols(),
1951 d,
1952 ),
1953 });
1954 }
1955 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1957 for (i, row) in m_rows.iter().enumerate() {
1958 if row.len() != self.basis_dim {
1959 return Err(FittedModelError::SchemaMismatch {
1960 reason: format!(
1961 "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1962 i,
1963 row.len(),
1964 self.basis_dim,
1965 ),
1966 });
1967 }
1968 for (j, &v) in row.iter().enumerate() {
1969 m_dense[[i, j]] = v;
1970 }
1971 }
1972 let subtract = anchor_rows.dot(&m_dense);
1975 out = out - subtract;
1976 } else if anchor_rows.ncols() != 0 {
1977 return Err(FittedModelError::SchemaMismatch {
1978 reason: format!(
1979 "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1980 anchor_rows.ncols(),
1981 ),
1982 });
1983 }
1984 Ok(out)
1985 }
1986
1987 pub fn anchor_correction_matrix(
1995 &self,
1996 n_anchor_rows: ndarray::ArrayView2<f64>,
1997 ) -> Result<Option<Array2<f64>>, FittedModelError> {
1998 let Some(m_rows) = self.anchor_correction.as_ref() else {
1999 return Ok(None);
2000 };
2001 let d = m_rows.len();
2002 if n_anchor_rows.ncols() != d {
2003 return Err(FittedModelError::SchemaMismatch {
2004 reason: format!(
2005 "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
2006 n_anchor_rows.ncols(),
2007 d,
2008 ),
2009 });
2010 }
2011 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2012 for (i, row) in m_rows.iter().enumerate() {
2013 if row.len() != self.basis_dim {
2014 return Err(FittedModelError::SchemaMismatch {
2015 reason: format!(
2016 "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2017 i,
2018 row.len(),
2019 self.basis_dim,
2020 ),
2021 });
2022 }
2023 for (j, &v) in row.iter().enumerate() {
2024 m_dense[[i, j]] = v;
2025 }
2026 }
2027 Ok(Some(n_anchor_rows.dot(&m_dense)))
2030 }
2031
2032 pub fn first_derivative_design(
2033 &self,
2034 values: &Array1<f64>,
2035 ) -> Result<Array2<f64>, FittedModelError> {
2036 self.evaluate_span_polynomial_design(
2037 values,
2038 BasisOptions::first_derivative().derivative_order,
2039 )
2040 }
2041
2042 pub fn second_derivative_design(
2043 &self,
2044 values: &Array1<f64>,
2045 ) -> Result<Array2<f64>, FittedModelError> {
2046 self.evaluate_span_polynomial_design(
2047 values,
2048 BasisOptions::second_derivative().derivative_order,
2049 )
2050 }
2051}
2052
2053impl FittedFamily {
2054 #[inline]
2055 pub fn likelihood(&self) -> LikelihoodSpec {
2056 let spec = match self {
2057 Self::Standard { likelihood, .. }
2058 | Self::LocationScale { likelihood, .. }
2059 | Self::MarginalSlope { likelihood, .. }
2060 | Self::Survival { likelihood, .. }
2061 | Self::TransformationNormal { likelihood, .. } => likelihood,
2062 Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2063 return LikelihoodSpec::royston_parmar();
2064 }
2065 };
2066 spec.clone()
2067 }
2068
2069 #[inline]
2070 pub fn frailty(&self) -> Option<&FrailtySpec> {
2071 match self {
2072 Self::MarginalSlope { frailty, .. }
2073 | Self::Survival { frailty, .. }
2074 | Self::LatentSurvival { frailty }
2075 | Self::LatentBinary { frailty } => Some(frailty),
2076 _ => None,
2077 }
2078 }
2079}
2080
2081fn collect_smooth_extrapolation_axes(
2087 basis: &gam_terms::smooth::SmoothBasisSpec,
2088 n_training_headers: usize,
2089 out: &mut std::collections::HashSet<usize>,
2090) {
2091 use gam_terms::smooth::SmoothBasisSpec;
2092 let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2093 if col < n_training_headers {
2094 out.insert(col);
2095 }
2096 };
2097 match basis {
2098 SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2100 SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2103 for &c in feature_cols {
2104 push(c, out);
2105 }
2106 }
2107 SmoothBasisSpec::ThinPlate { feature_cols, .. }
2114 | SmoothBasisSpec::Matern { feature_cols, .. }
2115 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2116 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2117 for &c in feature_cols {
2118 push(c, out);
2119 }
2120 }
2121 SmoothBasisSpec::FactorSmooth { spec } => {
2125 for &c in &spec.continuous_cols {
2126 push(c, out);
2127 }
2128 }
2129 SmoothBasisSpec::ByVariable { inner, .. }
2131 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2132 collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2133 }
2134 SmoothBasisSpec::BySmooth { smooth, .. } => {
2135 collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2136 }
2137 SmoothBasisSpec::Sphere { .. }
2143 | SmoothBasisSpec::ConstantCurvature { .. }
2144 | SmoothBasisSpec::Pca { .. } => {}
2145 }
2146}
2147
2148fn collect_by_variable_numeric_axes(
2167 basis: &gam_terms::smooth::SmoothBasisSpec,
2168 n_training_headers: usize,
2169 out: &mut std::collections::HashSet<usize>,
2170) {
2171 use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2172 match basis {
2173 SmoothBasisSpec::ByVariable {
2174 inner,
2175 by_col,
2176 kind,
2177 ..
2178 } => {
2179 if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2180 out.insert(*by_col);
2181 }
2182 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2183 }
2184 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2185 if let ByVarKind::Numeric { feature_col } = by_kind
2186 && *feature_col < n_training_headers
2187 {
2188 out.insert(*feature_col);
2189 }
2190 collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2191 }
2192 SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2193 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2194 }
2195 _ => {}
2196 }
2197}
2198
2199impl FittedModel {
2200 pub fn axis_clip_to_training_ranges(
2209 &self,
2210 data: ndarray::ArrayView2<'_, f64>,
2211 col_map: &std::collections::HashMap<String, usize>,
2212 ) -> Option<ndarray::Array2<f64>> {
2213 let training_headers = self.training_headers.as_ref()?;
2214 let ranges = self.training_feature_ranges.as_ref()?;
2215 if training_headers.len() != ranges.len() {
2216 return None;
2217 }
2218 let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2219 std::collections::HashMap::new();
2220 if let Some(schema) = self.data_schema.as_ref() {
2221 for col in &schema.columns {
2222 kind_by_header.insert(col.name.as_str(), col.kind);
2223 }
2224 }
2225 let periodic_axes = self.training_periodic_axes(training_headers);
2231 let linear_axes = self.training_linear_axes(training_headers.len());
2239 let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2244 let smooth_extrapolation_axes =
2252 self.training_smooth_extrapolation_axes(training_headers.len());
2253 let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2259 let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2264 let mut clipped = data.to_owned();
2265 let mut any_clipped = false;
2266 for (col_in_training, (header, &(lo, hi))) in
2267 training_headers.iter().zip(ranges.iter()).enumerate()
2268 {
2269 let (lo, hi) = sphere_lat_bounds
2270 .get(&col_in_training)
2271 .copied()
2272 .unwrap_or((lo, hi));
2273 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2274 continue;
2275 }
2276 if !matches!(
2277 kind_by_header.get(header.as_str()).copied(),
2278 Some(ColumnKindTag::Continuous)
2279 ) {
2280 continue;
2281 }
2282 if periodic_axes.contains(&col_in_training) {
2283 continue;
2284 }
2285 if linear_axes.contains(&col_in_training) {
2286 continue;
2287 }
2288 if random_effect_axes.contains(&col_in_training) {
2289 continue;
2290 }
2291 if smooth_extrapolation_axes.contains(&col_in_training) {
2292 continue;
2293 }
2294 if by_variable_axes.contains(&col_in_training) {
2295 continue;
2296 }
2297 let Some(&col_idx) = col_map.get(header) else {
2298 continue;
2299 };
2300 if col_idx >= clipped.ncols() {
2301 continue;
2302 }
2303 let mut col = clipped.column_mut(col_idx);
2304 for v in col.iter_mut() {
2305 if v.is_finite() {
2306 if *v < lo {
2307 *v = lo;
2308 any_clipped = true;
2309 } else if *v > hi {
2310 *v = hi;
2311 any_clipped = true;
2312 }
2313 }
2314 }
2315 }
2316 if any_clipped { Some(clipped) } else { None }
2317 }
2318
2319 fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2320 let mut specs: Vec<&TermCollectionSpec> = [
2321 self.resolved_termspec.as_ref(),
2322 self.resolved_termspec_noise.as_ref(),
2323 self.resolved_termspec_logslope.as_ref(),
2324 ]
2325 .into_iter()
2326 .flatten()
2327 .collect();
2328 if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2329 specs.extend(logslopes.iter());
2330 }
2331 specs
2332 }
2333
2334 fn training_periodic_axes(
2341 &self,
2342 training_headers: &[String],
2343 ) -> std::collections::HashSet<usize> {
2344 use gam_terms::basis::BSplineKnotSpec;
2345 use gam_terms::smooth::SmoothBasisSpec;
2346 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2347 let Some(spec) = self.resolved_termspec.as_ref() else {
2348 return out;
2349 };
2350 for term in &spec.smooth_terms {
2351 match &term.basis {
2352 SmoothBasisSpec::Sphere { feature_cols, .. } => {
2358 if let Some(&lon_col) = feature_cols.get(1)
2359 && lon_col < training_headers.len()
2360 {
2361 out.insert(lon_col);
2362 }
2363 }
2364 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2366 if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2367 && *feature_col < training_headers.len()
2368 {
2369 out.insert(*feature_col);
2370 }
2371 }
2372 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2375 for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2376 if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2377 && let Some(&col) = feature_cols.get(i)
2378 && col < training_headers.len()
2379 {
2380 out.insert(col);
2381 }
2382 }
2383 }
2384 _ => {}
2385 }
2386 }
2387 out
2388 }
2389
2390 fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2401 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2402 for spec in self.saved_term_specs() {
2403 for term in &spec.linear_terms {
2404 for col in term.effective_feature_cols() {
2405 if col < n_training_headers {
2406 out.insert(col);
2407 }
2408 }
2409 }
2410 }
2411 out
2412 }
2413
2414 fn training_random_effect_axes(
2420 &self,
2421 n_training_headers: usize,
2422 ) -> std::collections::HashSet<usize> {
2423 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2424 for spec in self.saved_term_specs() {
2425 for term in &spec.random_effect_terms {
2426 if term.feature_col < n_training_headers {
2427 out.insert(term.feature_col);
2428 }
2429 }
2430 }
2431 out
2432 }
2433
2434 fn training_smooth_extrapolation_axes(
2467 &self,
2468 n_training_headers: usize,
2469 ) -> std::collections::HashSet<usize> {
2470 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2471 for spec in self.saved_term_specs() {
2472 for term in &spec.smooth_terms {
2473 collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2474 }
2475 }
2476 out
2477 }
2478
2479 fn training_by_variable_numeric_axes(
2485 &self,
2486 n_training_headers: usize,
2487 ) -> std::collections::HashSet<usize> {
2488 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2489 for spec in self.saved_term_specs() {
2490 for term in &spec.smooth_terms {
2491 collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2492 }
2493 }
2494 out
2495 }
2496
2497 fn training_sphere_latitude_bounds(
2519 &self,
2520 training_headers: &[String],
2521 ) -> std::collections::HashMap<usize, (f64, f64)> {
2522 use gam_terms::smooth::SmoothBasisSpec;
2523 let mut out: std::collections::HashMap<usize, (f64, f64)> =
2524 std::collections::HashMap::new();
2525 let Some(spec) = self.resolved_termspec.as_ref() else {
2526 return out;
2527 };
2528 for term in &spec.smooth_terms {
2529 if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2530 && let Some(&lat_col) = feature_cols.first()
2531 && lat_col < training_headers.len()
2532 {
2533 let bound = if spec.radians {
2534 std::f64::consts::FRAC_PI_2
2535 } else {
2536 90.0
2537 };
2538 out.insert(lat_col, (-bound, bound));
2539 }
2540 }
2541 out
2542 }
2543
2544 pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2545 let likelihood = payload.family_state.likelihood();
2546 let class = match payload.model_kind {
2547 ModelKind::Survival => PredictModelClass::Survival,
2548 ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2549 ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2550 ModelKind::LocationScale => {
2551 if likelihood == LikelihoodSpec::gaussian_identity() {
2552 PredictModelClass::GaussianLocationScale
2553 } else if is_dispersion_location_scale_response(&likelihood.response) {
2554 PredictModelClass::DispersionLocationScale
2555 } else {
2556 PredictModelClass::BinomialLocationScale
2557 }
2558 }
2559 ModelKind::Standard => PredictModelClass::Standard,
2560 };
2561 match class {
2562 PredictModelClass::Survival => {
2563 payload.model_kind = ModelKind::Survival;
2564 Self::Survival { payload }
2565 }
2566 PredictModelClass::BernoulliMarginalSlope => {
2567 payload.model_kind = ModelKind::MarginalSlope;
2568 Self::MarginalSlope { payload }
2569 }
2570 PredictModelClass::TransformationNormal => {
2571 payload.model_kind = ModelKind::TransformationNormal;
2572 Self::TransformationNormal { payload }
2573 }
2574 PredictModelClass::GaussianLocationScale
2575 | PredictModelClass::BinomialLocationScale
2576 | PredictModelClass::DispersionLocationScale => {
2577 payload.model_kind = ModelKind::LocationScale;
2578 Self::LocationScale { payload }
2579 }
2580 PredictModelClass::Standard => {
2581 payload.model_kind = ModelKind::Standard;
2582 Self::Standard { payload }
2583 }
2584 }
2585 .with_synchronized_stateful_link_metadata()
2586 }
2587
2588 #[inline]
2589 pub fn payload(&self) -> &FittedModelPayload {
2590 match self {
2591 Self::Standard { payload }
2592 | Self::LocationScale { payload }
2593 | Self::MarginalSlope { payload }
2594 | Self::Survival { payload }
2595 | Self::TransformationNormal { payload } => payload,
2596 }
2597 }
2598
2599 #[inline]
2600 fn payload_mut(&mut self) -> &mut FittedModelPayload {
2601 match self {
2602 Self::Standard { payload }
2603 | Self::LocationScale { payload }
2604 | Self::MarginalSlope { payload }
2605 | Self::Survival { payload }
2606 | Self::TransformationNormal { payload } => payload,
2607 }
2608 }
2609
2610 fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2611 self.synchronize_stateful_link_metadata();
2612 self
2613 }
2614
/// Re-derives payload metadata that mirrors the saved fit result.
///
/// Recomputes `used_device` from the canonical `fit_result`, falling back to
/// `unified`, and calls `synchronize_empty_feature_contract`. It then copies
/// any fitted stateful-link state (latent cloglog, SAS, Beta-Logistic, mixture)
/// from the fit into the `Standard` family state, together with the matching
/// link-parameter covariance where the fit stores one. Family/link pairings
/// that do not match any arm are left untouched.
fn synchronize_stateful_link_metadata(&mut self) {
    let payload = self.payload_mut();
    // Defaults to false when neither fit payload is present.
    payload.used_device = payload
        .fit_result
        .as_ref()
        .or(payload.unified.as_ref())
        .is_some_and(|fit| fit.used_device);
    payload.synchronize_empty_feature_contract();
    let Some(fit) = payload.fit_result.as_ref().or(payload.unified.as_ref()) else {
        return;
    };
    // `fit` borrows fit_result/unified while the arms below mutate the disjoint
    // fields family_state, sas_param_covariance and mixture_link_param_covariance.
    match (&mut payload.family_state, &fit.fitted_link) {
        (
            FittedFamily::Standard {
                likelihood,
                latent_cloglog_state,
                ..
            },
            FittedLinkState::LatentCLogLog { state },
        ) if likelihood.is_latent_cloglog() => {
            *latent_cloglog_state = Some(*state);
        }
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::Sas { state, covariance },
        ) if likelihood.is_binomial_sas() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        // Beta-Logistic reuses the SAS state and covariance slots.
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::BetaLogistic { state, covariance },
        ) if likelihood.is_binomial_beta_logistic() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        (
            FittedFamily::Standard {
                likelihood,
                mixture_state,
                ..
            },
            FittedLinkState::Mixture { state, covariance },
        ) if likelihood.is_binomial_mixture() => {
            *mixture_state = Some(state.clone());
            payload.mixture_link_param_covariance =
                covariance.as_ref().map(array2_to_nestedvec);
        }
        _ => {}
    }
}
2674
2675 #[inline]
2676 pub fn likelihood(&self) -> LikelihoodSpec {
2677 self.payload().family_state.likelihood()
2678 }
2679
2680 pub fn prediction_required_columns(
2698 &self,
2699 ) -> Result<std::collections::BTreeSet<String>, String> {
2700 let payload = self.payload();
2701 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2702 let mut required = std::collections::BTreeSet::<String>::new();
2703 parsed_term_column_names(&parsed.terms, &mut required);
2704
2705 if let Some((entry, exit, _event)) =
2706 parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707 {
2708 if let Some(entry) = entry {
2709 required.insert(entry);
2710 }
2711 required.insert(exit);
2712 } else if let Some((left, right, _event)) =
2713 parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2714 {
2715 required.insert(left);
2716 required.insert(right);
2717 }
2718 if let Some(offset) = payload.offset_column.as_ref() {
2726 required.insert(offset.clone());
2727 }
2728 if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2729 required.insert(noise_offset.clone());
2730 }
2731 if matches!(
2732 self.predict_model_class(),
2733 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2734 ) {
2735 if let Some(z_column) = payload.z_column.as_ref() {
2736 required.remove("z");
2737 required.insert(z_column.clone());
2738 }
2739 }
2740 if let Some(noise_formula) = payload.formula_noise.as_ref() {
2741 self.add_auxiliary_formula_columns(
2742 &mut required,
2743 noise_formula,
2744 parsed.response.as_str(),
2745 )?;
2746 }
2747 if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2748 if logslope_formula != "same-as-main" {
2749 self.add_auxiliary_formula_columns(
2750 &mut required,
2751 logslope_formula,
2752 parsed.response.as_str(),
2753 )?;
2754 }
2755 }
2756 Ok(required)
2757 }
2758
2759 pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2776 let payload = self.payload();
2777 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2778 let mut extras: Vec<String> = Vec::new();
2788 if let Some(weight_column) = payload.weight_column.as_ref() {
2789 extras.push(weight_column.clone());
2790 }
2791 if parse_surv_response(parsed.response.as_str())
2794 .map_err(|e| e.to_string())?
2795 .is_some()
2796 || parse_surv_interval_response(parsed.response.as_str())
2797 .map_err(|e| e.to_string())?
2798 .is_some()
2799 {
2800 return Ok(extras);
2801 }
2802 let response = parsed.response.trim();
2803 if response.is_empty() || response.contains('(') {
2806 return Ok(extras);
2807 }
2808 if self.prediction_required_columns()?.contains(response) {
2811 return Ok(extras);
2812 }
2813 extras.push(response.to_string());
2814 Ok(extras)
2815 }
2816
2817 fn add_auxiliary_formula_columns(
2820 &self,
2821 required: &mut std::collections::BTreeSet<String>,
2822 formula_or_rhs: &str,
2823 response: &str,
2824 ) -> Result<(), String> {
2825 let trimmed = formula_or_rhs.trim();
2826 if trimmed.is_empty() || trimmed == "1" {
2827 return Ok(());
2828 }
2829 let formula = if trimmed.contains('~') {
2830 trimmed.to_string()
2831 } else {
2832 format!("{response} ~ {trimmed}")
2833 };
2834 let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2835 parsed_term_column_names(&parsed.terms, required);
2836 Ok(())
2837 }
2838
2839 #[inline]
2840 pub fn predict_model_class(&self) -> PredictModelClass {
2841 match &self.payload().family_state {
2842 FittedFamily::Survival { .. }
2843 | FittedFamily::LatentSurvival { .. }
2844 | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2845 FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2846 FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2847 FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2848 PredictModelClass::GaussianLocationScale
2849 }
2850 FittedFamily::LocationScale { likelihood, .. }
2851 if is_dispersion_location_scale_response(&likelihood.response) =>
2852 {
2853 PredictModelClass::DispersionLocationScale
2854 }
2855 FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2856 FittedFamily::Standard { .. } => PredictModelClass::Standard,
2857 }
2858 }
2859
2860 pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2861 let payload = self.payload();
2862 let (knots, degree) = match (
2863 payload.linkwiggle_knots.as_ref(),
2864 payload.linkwiggle_degree,
2865 ) {
2866 (None, None) => return Ok(None),
2867 (Some(knots), Some(degree)) => (knots.clone(), degree),
2868 _ => {
2869 return Err(FittedModelError::SchemaMismatch {
2870 reason:
2871 "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2872 .to_string(),
2873 })
2874 }
2875 };
2876 let resolved_link = self.resolved_inverse_link()?;
2877 let saved_link_disallows_wiggle = resolved_link
2878 .as_ref()
2879 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2880 || payload
2881 .link
2882 .as_ref()
2883 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2884 if saved_link_disallows_wiggle {
2885 return Err(FittedModelError::IncompatibleConfig {
2886 reason: joint_wiggle_unsupported_link_message("link wiggle"),
2887 });
2888 }
2889 let beta = match self.predict_model_class() {
2890 PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2898 payload.beta_link_wiggle.clone().expect("checked is_some")
2899 }
2900 PredictModelClass::Standard => {
2901 let fit = payload.fit_result.as_ref().ok_or_else(|| {
2902 FittedModelError::MissingField {
2903 reason:
2904 "standard link-wiggle model is missing canonical fit_result payload"
2905 .to_string(),
2906 }
2907 })?;
2908 if fit.blocks.len() != 2
2909 || fit.blocks[0].role != BlockRole::Mean
2910 || fit.blocks[1].role != BlockRole::LinkWiggle
2911 {
2912 return Err(FittedModelError::SchemaMismatch {
2913 reason:
2914 "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2915 .to_string(),
2916 });
2917 }
2918 fit.block_by_role(BlockRole::LinkWiggle)
2919 .ok_or_else(|| FittedModelError::MissingField {
2920 reason:
2921 "standard link-wiggle model is missing LinkWiggle coefficient block"
2922 .to_string(),
2923 })?
2924 .beta
2925 .to_vec()
2926 }
2927 _ => payload
2928 .beta_link_wiggle
2929 .clone()
2930 .ok_or_else(|| FittedModelError::MissingField {
2931 reason:
2932 "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2933 .to_string(),
2934 })?,
2935 };
2936 Ok(Some(SavedLinkWiggleRuntime {
2937 knots,
2938 degree,
2939 beta,
2940 }))
2941 }
2942
2943 pub fn saved_baseline_time_wiggle(
2944 &self,
2945 ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2946 let payload = self.payload();
2947 if payload
2948 .survival_cause_count
2949 .is_some_and(|cause_count| cause_count > 1)
2950 && payload.beta_baseline_timewiggle.is_none()
2951 && payload.beta_baseline_timewiggle_by_cause.is_some()
2952 {
2953 return Err(FittedModelError::SchemaMismatch {
2954 reason:
2955 "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2956 .to_string(),
2957 });
2958 }
2959 match (
2960 payload.baseline_timewiggle_knots.as_ref(),
2961 payload.baseline_timewiggle_degree,
2962 payload.baseline_timewiggle_penalty_orders.as_ref(),
2963 payload.baseline_timewiggle_double_penalty,
2964 payload.beta_baseline_timewiggle.as_ref(),
2965 ) {
2966 (None, None, None, None, None) => Ok(None),
2967 (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2968 Ok(Some(SavedBaselineTimeWiggleRuntime {
2969 knots: knots.clone(),
2970 degree,
2971 penalty_orders: penalty_orders.clone(),
2972 double_penalty,
2973 beta: beta.clone(),
2974 }))
2975 }
2976 _ => Err(FittedModelError::SchemaMismatch {
2977 reason:
2978 "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2979 .to_string(),
2980 }),
2981 }
2982 }
2983
2984 #[inline]
2986 pub fn has_link_wiggle(&self) -> bool {
2987 self.saved_link_wiggle()
2988 .map(|runtime| runtime.is_some())
2989 .unwrap_or(false)
2990 }
2991
2992 #[inline]
2994 pub fn has_baseline_time_wiggle(&self) -> bool {
2995 let payload = self.payload();
2996 if payload
2997 .survival_cause_count
2998 .is_some_and(|cause_count| cause_count > 1)
2999 {
3000 return payload.baseline_timewiggle_knots.is_some()
3001 && payload.baseline_timewiggle_degree.is_some()
3002 && payload.baseline_timewiggle_penalty_orders.is_some()
3003 && payload.baseline_timewiggle_double_penalty.is_some()
3004 && payload.beta_baseline_timewiggle_by_cause.is_some();
3005 }
3006 self.saved_baseline_time_wiggle()
3007 .map(|runtime| runtime.is_some())
3008 .unwrap_or(false)
3009 }
3010
3011 #[inline]
3036 pub fn prediction_uses_posterior_mean(&self) -> bool {
3037 let family = self.likelihood();
3038 let curved_family = match &family.response {
3039 ResponseFamily::Gaussian => false,
3042 ResponseFamily::Poisson
3044 | ResponseFamily::Gamma
3045 | ResponseFamily::Tweedie { .. }
3046 | ResponseFamily::NegativeBinomial { .. } => true,
3047 ResponseFamily::Beta { .. } => true,
3049 ResponseFamily::RoystonParmar => true,
3051 ResponseFamily::Binomial => matches!(
3054 &family.link,
3055 InverseLink::Standard(_)
3056 | InverseLink::Sas(_)
3057 | InverseLink::BetaLogistic(_)
3058 | InverseLink::Mixture(_)
3059 | InverseLink::LatentCLogLog(_)
3060 ),
3061 };
3062 curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3063 }
3064
3065 pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3066 self.payload().validate_payload_version()?;
3067 if matches!(
3068 self.predict_model_class(),
3069 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3070 ) {
3071 if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3072 runtime.validate_exact_replay_contract().map_err(|err| {
3073 FittedModelError::PayloadCorrupt {
3074 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3075 }
3076 })?;
3077 }
3078 if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3079 runtime.validate_exact_replay_contract().map_err(|err| {
3080 FittedModelError::PayloadCorrupt {
3081 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3082 }
3083 })?;
3084 }
3085 }
3086 let runtime = SavedPredictionRuntime {
3087 model_class: self.predict_model_class(),
3088 likelihood: self.likelihood(),
3089 inverse_link: self.resolved_inverse_link()?,
3090 link_wiggle: self.saved_link_wiggle()?,
3091 baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3092 score_warp: self.payload().score_warp_runtime.clone(),
3093 link_deviation: self.payload().link_deviation_runtime.clone(),
3094 latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3095 latent_z_conditional_calibration: self
3096 .payload()
3097 .latent_z_conditional_calibration
3098 .clone(),
3099 influence_absorber_width: self.payload().influence_absorber_width,
3100 };
3101 if matches!(
3102 runtime.model_class,
3103 PredictModelClass::GaussianLocationScale
3104 | PredictModelClass::BinomialLocationScale
3105 | PredictModelClass::DispersionLocationScale
3106 ) {
3107 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3108 FittedModelError::MissingField {
3109 reason: "location-scale model is missing canonical fit_result payload"
3110 .to_string(),
3111 }
3112 })?;
3113 validate_location_scale_saved_fit(
3114 fit,
3115 runtime.model_class,
3116 runtime.link_wiggle.as_ref(),
3117 )?;
3118 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3119 && self
3120 .payload()
3121 .survival_likelihood
3122 .as_deref()
3123 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3124 {
3125 validate_survival_location_scale_saved_fit(
3126 self.payload(),
3127 runtime.link_wiggle.as_ref(),
3128 )?;
3129 } else if matches!(
3130 runtime.model_class,
3131 PredictModelClass::BernoulliMarginalSlope
3132 ) {
3133 let unified =
3134 self.payload()
3135 .unified
3136 .as_ref()
3137 .ok_or_else(|| FittedModelError::MissingField {
3138 reason: "marginal-slope model is missing unified fit payload; refit"
3139 .to_string(),
3140 })?;
3141 validate_marginal_slope_saved_fit(
3142 unified,
3143 runtime.score_warp.as_ref(),
3144 runtime.link_deviation.as_ref(),
3145 "unified",
3146 )?;
3147 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3148 && self
3149 .payload()
3150 .survival_likelihood
3151 .as_deref()
3152 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3153 {
3154 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3155 FittedModelError::MissingField {
3156 reason: "survival marginal-slope model is missing canonical fit_result payload"
3157 .to_string(),
3158 }
3159 })?;
3160 validate_survival_marginal_slope_saved_fit(
3161 fit,
3162 runtime.score_warp.as_ref(),
3163 runtime.link_deviation.as_ref(),
3164 "fit_result",
3165 )?;
3166 }
3167 Ok(runtime)
3168 }
3169
3170 pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3171 let payload = self.payload();
3172 let raw = match &payload.family_state {
3173 FittedFamily::Standard {
3174 likelihood,
3175 sas_state,
3176 ..
3177 } if likelihood.is_binomial_sas() => {
3178 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3179 reason: "binomial-sas model is missing state in family_state.sas_state"
3180 .to_string(),
3181 })?
3182 }
3183 FittedFamily::LocationScale {
3184 likelihood,
3185 base_link,
3186 } if likelihood.is_binomial_sas() => match base_link {
3187 Some(InverseLink::Sas(state)) => *state,
3188 _ => {
3189 return Err(FittedModelError::MissingField {
3190 reason: "binomial-sas location-scale model is missing SAS base_link state"
3191 .to_string(),
3192 });
3193 }
3194 },
3195 _ => return Ok(None),
3196 };
3197 state_from_sasspec(SasLinkSpec {
3198 initial_epsilon: raw.epsilon,
3199 initial_log_delta: raw.log_delta,
3200 })
3201 .map(Some)
3202 .map_err(|e| FittedModelError::PayloadCorrupt {
3203 reason: format!("invalid saved SAS link state: {e}"),
3204 })
3205 }
3206
3207 pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3208 let payload = self.payload();
3209 let raw = match &payload.family_state {
3210 FittedFamily::Standard {
3211 likelihood,
3212 sas_state,
3213 ..
3214 } if likelihood.is_binomial_beta_logistic() => {
3215 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3216 reason:
3217 "binomial-beta-logistic model is missing state in family_state.sas_state"
3218 .to_string(),
3219 })?
3220 }
3221 FittedFamily::LocationScale {
3222 likelihood,
3223 base_link,
3224 } if likelihood.is_binomial_beta_logistic() => match base_link {
3225 Some(InverseLink::BetaLogistic(state)) => *state,
3226 _ => {
3227 return Err(FittedModelError::MissingField {
3228 reason:
3229 "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3230 .to_string(),
3231 });
3232 }
3233 },
3234 _ => return Ok(None),
3235 };
3236 state_from_beta_logisticspec(SasLinkSpec {
3237 initial_epsilon: raw.epsilon,
3238 initial_log_delta: raw.log_delta,
3239 })
3240 .map(Some)
3241 .map_err(|e| FittedModelError::PayloadCorrupt {
3242 reason: format!("invalid saved Beta-Logistic link state: {e}"),
3243 })
3244 }
3245
3246 pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3247 let payload = self.payload();
3248 match &payload.family_state {
3249 FittedFamily::Standard {
3250 likelihood,
3251 mixture_state,
3252 ..
3253 } if likelihood.is_binomial_mixture() => mixture_state
3254 .clone()
3255 .ok_or_else(|| FittedModelError::MissingField {
3256 reason: "binomial-mixture model is missing state in family_state.mixture_state"
3257 .to_string(),
3258 })
3259 .map(Some),
3260 FittedFamily::LocationScale {
3261 likelihood,
3262 base_link,
3263 } if likelihood.is_binomial_mixture() => match base_link {
3264 Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3265 _ => Err(FittedModelError::MissingField {
3266 reason:
3267 "binomial-mixture location-scale model is missing mixture base_link state"
3268 .to_string(),
3269 }),
3270 },
3271 _ => Ok(None),
3272 }
3273 }
3274
3275 pub fn saved_latent_cloglog_state(
3276 &self,
3277 ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3278 let payload = self.payload();
3279 match &payload.family_state {
3280 FittedFamily::Standard {
3281 likelihood,
3282 latent_cloglog_state,
3283 ..
3284 } if likelihood.is_latent_cloglog() => latent_cloglog_state
3285 .ok_or_else(|| FittedModelError::MissingField {
3286 reason:
3287 "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3288 .to_string(),
3289 })
3290 .map(Some),
3291 _ => Ok(None),
3292 }
3293 }
3294
3295 pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3296 let stateful = if let Some(state) = self.saved_mixture_state()? {
3297 Some(InverseLink::Mixture(state))
3298 } else if let Some(state) = self.saved_latent_cloglog_state()? {
3299 Some(InverseLink::LatentCLogLog(state))
3300 } else if let Some(state) = self.saved_beta_logistic_state()? {
3301 Some(InverseLink::BetaLogistic(state))
3302 } else {
3303 self.saved_sas_state()?.map(InverseLink::Sas)
3304 };
3305 match &self.payload().family_state {
3306 FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3307 FittedFamily::Standard { link, .. } => {
3308 Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3309 }
3310 FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3311 FittedFamily::Survival { .. }
3312 | FittedFamily::LatentSurvival { .. }
3313 | FittedFamily::LatentBinary { .. } => Ok(None),
3314 FittedFamily::TransformationNormal { .. } => Ok(None),
3315 }
3316 }
3317
/// Coverage floor passed as the last argument to
/// `gam_terms::basis::measure_jet_extrapolation_variance` by
/// `measure_jet_extrapolation_variance`.
// NOTE(review): the precise meaning of 0.05 is defined by that helper; confirm there.
const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3327
/// Per-row extrapolation variance contributed by frozen measure-jet smooth terms.
///
/// Returns `Ok(None)` when there is no resolved term spec, `data` has no rows,
/// the saved spec has no `MeasureJet` smooth term, or no such term ended up
/// contributing. Otherwise the result is the sum over contributing terms of
/// `coefficient_covariance_scale() * v(row)`.
///
/// Terms whose centers/quadrature are not frozen, or whose fitted amplitudes
/// cannot be matched to their penalty layout, are skipped with a `log::warn!`.
///
/// # Errors
/// `MissingField` when `fit_result` is absent. `SchemaMismatch` when the
/// prediction term spec cannot be resolved, the penalty layout replay fails,
/// lambda indices or penalty labels are invalid, columns or input scales do not
/// line up, or the support / variance helpers fail.
pub fn measure_jet_extrapolation_variance(
    &self,
    data: ndarray::ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
) -> Result<Option<Array1<f64>>, FittedModelError> {
    use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
    use gam_terms::smooth::build_term_collection_design;
    use gam_terms::smooth::SmoothBasisSpec;
    let Some(saved_spec) = self.resolved_termspec.as_ref() else {
        return Ok(None);
    };
    // Cheap early exit before touching the fit or rebuilding any design.
    if data.nrows() == 0
        || !saved_spec
            .smooth_terms
            .iter()
            .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
    {
        return Ok(None);
    }
    let fit = self
        .fit_result
        .as_ref()
        .ok_or_else(|| FittedModelError::MissingField {
            reason: "measure-jet extrapolation variance requires the canonical \
                     fit_result payload; refit"
                .to_string(),
        })?;
    let spec = crate::survival::predict::resolve_termspec_for_prediction(
        &self.resolved_termspec,
        self.training_headers.as_ref(),
        col_map,
        "resolved_termspec",
    )
    .map_err(|e| FittedModelError::SchemaMismatch {
        reason: format!("measure-jet extrapolation variance: {e}"),
    })?;
    // A one-row design is built only to recover the penalty layout
    // (`penaltyinfo`), which maps each term's penalties to indices in `fit.lambdas`.
    let probe = data.slice(ndarray::s![0..1, ..]);
    let design = build_term_collection_design(probe, &spec).map_err(|e| {
        FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
            ),
        }
    })?;
    let lambdas = &fit.lambdas;
    let phi_scale = fit.coefficient_covariance_scale();
    let mut total = Array1::<f64>::zeros(data.nrows());
    let mut contributed = false;
    for term in &spec.smooth_terms {
        let SmoothBasisSpec::MeasureJet {
            feature_cols,
            spec: mj,
            input_scales,
        } = &term.basis
        else {
            continue;
        };
        // Only terms with user-provided centers and frozen quadrature can be replayed.
        let (Some(frozen), CenterStrategy::UserProvided(centers)) =
            (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
        else {
            log::warn!(
                "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
                 quadrature); skipping its extrapolation variance",
                term.name
            );
            continue;
        };
        let n_levels = frozen.eps_band.len();
        // Bounds-checked lookup into the fitted lambda vector.
        let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
            lambdas
                .get(global_index)
                .copied()
                .ok_or_else(|| FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': penalty global index {global_index} out \
                         of bounds for {} fitted lambdas",
                        term.name,
                        lambdas.len()
                    ),
                })
        };
        // The term has either per-scale penalties labelled `measure_jet_scale_<level>`
        // or a single fused `Primary` penalty.
        let mut per_scale: Vec<(usize, f64)> = Vec::new();
        let mut fused: Option<f64> = None;
        for info in &design.penaltyinfo {
            if info.termname.as_deref() != Some(term.name.as_str()) {
                continue;
            }
            match &info.penalty.source {
                PenaltySource::Other(label) => {
                    if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
                        let level: usize = level_txt.parse().map_err(|_| {
                            FittedModelError::SchemaMismatch {
                                reason: format!(
                                    "measure-jet term '{}': unparseable penalty label \
                                     '{label}'",
                                    term.name
                                ),
                            }
                        })?;
                        per_scale.push((level, read_lambda(info.global_index)?));
                    }
                }
                PenaltySource::Primary => {
                    fused = Some(read_lambda(info.global_index)?);
                }
                _ => {}
            }
        }
        // Declared outside the `if` so `PerLevel` can borrow it for the rest of the iteration.
        let mut lambda_phys = Vec::with_capacity(n_levels);
        // Fitted lambdas are divided by the frozen penalty normalization scale(s).
        let spectrum = if per_scale.is_empty() {
            let Some(lam) = fused else {
                log::warn!(
                    "measure-jet term '{}' has no fitted amplitude in the penalty \
                     layout; skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            let Some(c) = frozen.fused_penalty_normalization_scale else {
                log::warn!(
                    "measure-jet term '{}' is missing the fused penalty normalization scale; \
                     skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            MeasureJetExtrapolationSpectrum::Fused(lam / c)
        } else {
            per_scale.sort_by_key(|&(level, _)| level);
            // Levels must be exactly 0..n_levels, one each, so indexing below is in bounds.
            let levels_complete = per_scale.len() == n_levels
                && per_scale
                    .iter()
                    .enumerate()
                    .all(|(i, &(level, _))| level == i);
            if !levels_complete {
                log::warn!(
                    "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
                     scales; skipping its extrapolation variance",
                    term.name,
                    per_scale.len(),
                    n_levels
                );
                continue;
            }
            if frozen.penalty_normalization_scales.len() != n_levels {
                log::warn!(
                    "measure-jet term '{}': {} frozen penalty normalization scales for {} \
                     band scales; skipping its extrapolation variance",
                    term.name,
                    frozen.penalty_normalization_scales.len(),
                    n_levels
                );
                continue;
            }
            lambda_phys.extend(
                per_scale
                    .iter()
                    .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
            );
            MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
        };
        // Gather this term's feature columns into an (n_rows x n_axes) query matrix.
        let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
        for (j, &col) in feature_cols.iter().enumerate() {
            if col >= data.ncols() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': prediction column {col} out of bounds \
                         for {} data columns",
                        term.name,
                        data.ncols()
                    ),
                });
            }
            queries.column_mut(j).assign(&data.column(col));
        }
        // Divide each axis by its saved input scale, when one was stored.
        if let Some(scales) = input_scales {
            if scales.len() != feature_cols.len() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': {} input scales for {} axes",
                        term.name,
                        scales.len(),
                        feature_cols.len()
                    ),
                });
            }
            for (j, &scale) in scales.iter().enumerate() {
                queries.column_mut(j).mapv_inplace(|v| v / scale);
            }
        }
        let support = gam_terms::basis::measure_jet_support_curve(
            queries.view(),
            centers.view(),
            frozen.masses.view(),
            &frozen.eps_band,
        )
        .map_err(|e| FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet term '{}': support curve failed: {e}",
                term.name
            ),
        })?;
        // One support row per data row; accumulate the scaled variance into the total.
        for i in 0..data.nrows() {
            let v = gam_terms::basis::measure_jet_extrapolation_variance(
                support.row(i),
                &frozen.eps_band,
                &frozen.support_means,
                spectrum,
                Self::MEASURE_JET_COVERAGE_FLOOR,
            )
            .map_err(|e| FittedModelError::SchemaMismatch {
                reason: format!(
                    "measure-jet term '{}': extrapolation variance failed: {e}",
                    term.name
                ),
            })?;
            total[i] += phi_scale * v;
        }
        contributed = true;
    }
    Ok(contributed.then_some(total))
}
3596
3597 pub fn unified(&self) -> Option<&UnifiedFitResult> {
3599 self.payload().unified.as_ref()
3600 }
3601
3602 pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3603 let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3604 reason: format!("failed to read model '{}': {e}", path.display()),
3605 })?;
3606 let model: Self =
3607 serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3608 reason: format!("failed to parse model json: {e}"),
3609 })?;
3610 let model = model.with_synchronized_stateful_link_metadata();
3611 model.validate_for_persistence()?;
3612 model.validate_numeric_finiteness()?;
3613 Ok(model)
3614 }
3615
3616 pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3617 let normalized = self.clone().with_synchronized_stateful_link_metadata();
3618 normalized.validate_for_persistence()?;
3619 normalized.validate_numeric_finiteness()?;
3620 let parent = path.parent().unwrap_or_else(|| Path::new("."));
3627 let file_name = path
3628 .file_name()
3629 .and_then(|s| s.to_str())
3630 .unwrap_or("model.json");
3631 let pid = std::process::id();
3632 let nanos = std::time::SystemTime::now()
3633 .duration_since(std::time::UNIX_EPOCH)
3634 .map(|d| d.as_nanos())
3635 .unwrap_or(0);
3636 let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3637 let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3638 reason: format!("failed to write model '{}': {e}", tmp.display()),
3639 })?;
3640 let mut writer = std::io::BufWriter::new(file);
3641 let ser_result = serde_json::to_writer(&mut writer, &normalized);
3642 if let Err(e) = ser_result {
3643 std::io::Write::flush(&mut writer).ok();
3646 drop(writer);
3647 fs::remove_file(&tmp).ok();
3648 return Err(FittedModelError::PayloadCorrupt {
3649 reason: format!("failed to serialize model: {e}"),
3650 });
3651 }
3652 std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3653 reason: format!("failed to write model '{}': {e}", tmp.display()),
3654 })?;
3655 let inner = writer
3657 .into_inner()
3658 .map_err(|e| FittedModelError::PayloadCorrupt {
3659 reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3660 })?;
3661 inner.sync_all().ok();
3662 drop(inner);
3663 if let Err(e) = fs::rename(&tmp, path) {
3664 fs::remove_file(&tmp).ok();
3665 return Err(FittedModelError::PayloadCorrupt {
3666 reason: format!("failed to publish model '{}': {e}", path.display()),
3667 });
3668 }
3669 if let Ok(d) = fs::File::open(parent) {
3674 d.sync_all().ok();
3675 }
3676 Ok(())
3677 }
3678
3679 pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3680 self.data_schema
3681 .as_ref()
3682 .ok_or_else(|| FittedModelError::MissingField {
3683 reason: "model is missing data_schema; refit".to_string(),
3684 })
3685 }
3686
3687 pub fn saved_spline_scan(
3691 &self,
3692 ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3693 let Some(saved) = self.spline_scan.as_ref() else {
3694 return Ok(None);
3695 };
3696 let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3697 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3698 Ok(Some((saved.feature_column.as_str(), fit)))
3699 }
3700
3701 pub fn saved_residual_cascade(
3706 &self,
3707 ) -> Result<
3708 Option<(
3709 &[String],
3710 gam_solve::residual_cascade::ResidualCascadeFit,
3711 )>,
3712 FittedModelError,
3713 > {
3714 let Some(saved) = self.residual_cascade.as_ref() else {
3715 return Ok(None);
3716 };
3717 let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3718 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3719 Ok(Some((saved.feature_columns.as_slice(), fit)))
3720 }
3721
3722 pub fn random_effect_group_columns(&self) -> HashSet<String> {
3723 let Some(training_headers) = self.training_headers.as_ref() else {
3724 return HashSet::new();
3725 };
3726 let mut out = HashSet::<String>::new();
3727 for spec in self.saved_term_specs() {
3728 for term in &spec.random_effect_terms {
3729 if let Some(name) = training_headers.get(term.feature_col) {
3730 out.insert(name.clone());
3731 }
3732 }
3733 }
3734 out
3735 }
3736
3737 pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3738 self.validate_payload_version()?;
3752 if let Some(scan) = self.spline_scan.as_ref() {
3753 if self.fit_result.is_some() || self.unified.is_some() {
3758 return Err(FittedModelError::SchemaMismatch {
3759 reason: "spline-scan model must not also carry a dense fit_result/unified \
3760 payload; the representations are mutually exclusive"
3761 .to_string(),
3762 });
3763 }
3764 if self.model_kind != ModelKind::Standard
3765 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3766 {
3767 return Err(FittedModelError::SchemaMismatch {
3768 reason: format!(
3769 "spline-scan representation requires a standard Gaussian-identity model; \
3770 got model_kind={:?}, likelihood={:?}",
3771 self.model_kind,
3772 self.family_state.likelihood()
3773 ),
3774 });
3775 }
3776 if scan.feature_column.is_empty() {
3777 return Err(FittedModelError::MissingField {
3778 reason: "spline-scan model is missing its feature column name; refit"
3779 .to_string(),
3780 });
3781 }
3782 gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3783 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3784 if self.data_schema.is_none() {
3790 return Err(FittedModelError::MissingField {
3791 reason: "spline-scan model is missing data_schema; refit".to_string(),
3792 });
3793 }
3794 if self.training_headers.is_none() {
3795 return Err(FittedModelError::MissingField {
3796 reason: "spline-scan model is missing training_headers; refit".to_string(),
3797 });
3798 }
3799 return Ok(());
3800 } else if let Some(cascade) = self.residual_cascade.as_ref() {
3801 if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3805 return Err(FittedModelError::SchemaMismatch {
3806 reason: "residual-cascade model must not also carry spline_scan / \
3807 fit_result / unified payloads; the representations are \
3808 mutually exclusive"
3809 .to_string(),
3810 });
3811 }
3812 if self.model_kind != ModelKind::Standard
3813 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3814 {
3815 return Err(FittedModelError::SchemaMismatch {
3816 reason: format!(
3817 "residual-cascade representation requires a standard Gaussian-identity \
3818 model; got model_kind={:?}, likelihood={:?}",
3819 self.model_kind,
3820 self.family_state.likelihood()
3821 ),
3822 });
3823 }
3824 if cascade.feature_columns.is_empty()
3825 || !(2..=3).contains(&cascade.feature_columns.len())
3826 {
3827 return Err(FittedModelError::MissingField {
3828 reason: format!(
3829 "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3830 cascade.feature_columns.len()
3831 ),
3832 });
3833 }
3834 gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3835 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3836 if self.data_schema.is_none() {
3837 return Err(FittedModelError::MissingField {
3838 reason: "residual-cascade model is missing data_schema; refit".to_string(),
3839 });
3840 }
3841 if self.training_headers.is_none() {
3842 return Err(FittedModelError::MissingField {
3843 reason: "residual-cascade model is missing training_headers; refit".to_string(),
3844 });
3845 }
3846 return Ok(());
3847 } else if self.fit_result.is_none() {
3848 return Err(FittedModelError::MissingField {
3849 reason: "model is missing canonical fit_result payload; refit".to_string(),
3850 });
3851 }
3852 if self.data_schema.is_none() {
3853 return Err(FittedModelError::MissingField {
3854 reason: "model is missing data_schema; refit".to_string(),
3855 });
3856 }
3857 if self.training_headers.is_none() {
3858 return Err(FittedModelError::MissingField {
3859 reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3860 .to_string(),
3861 });
3862 }
3863 let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3864 FittedModelError::MissingField {
3865 reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3866 .to_string(),
3867 }
3868 })?;
3869 validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3870
3871 if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3872 return Err(FittedModelError::MissingField {
3873 reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3874 .to_string(),
3875 });
3876 }
3877 if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3878 validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3879 }
3880 if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3881 let score = self.transformation_score_calibration.ok_or_else(|| {
3882 FittedModelError::MissingField {
3883 reason: "transformation-normal model is missing transformation_score_calibration; refit"
3884 .to_string(),
3885 }
3886 })?;
3887 score.validate("transformation-normal model")?;
3888 }
3889 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3890 if self.formula_logslope.is_none() {
3891 return Err(FittedModelError::MissingField {
3892 reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3893 });
3894 }
3895 if self.z_column.is_none() {
3896 return Err(FittedModelError::MissingField {
3897 reason: "marginal-slope model is missing z_column; refit".to_string(),
3898 });
3899 }
3900 let z_normalization =
3901 self.latent_z_normalization
3902 .ok_or_else(|| FittedModelError::MissingField {
3903 reason: "marginal-slope model is missing latent_z_normalization; refit"
3904 .to_string(),
3905 })?;
3906 z_normalization.validate("marginal-slope model")?;
3907 let latent_measure =
3908 self.latent_measure
3909 .as_ref()
3910 .ok_or_else(|| FittedModelError::MissingField {
3911 reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3912 })?;
3913 latent_measure
3914 .validate("marginal-slope model latent_measure")
3915 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3916 if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3917 return Err(FittedModelError::MissingField {
3918 reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3919 });
3920 }
3921 if self.resolved_termspec_logslope.as_ref().is_none() {
3922 return Err(FittedModelError::MissingField {
3923 reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3924 .to_string(),
3925 });
3926 }
3927 match self.family_state.frailty() {
3928 Some(FrailtySpec::None)
3929 | Some(FrailtySpec::GaussianShift {
3930 sigma_fixed: Some(_),
3931 }) => {}
3932 Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3933 return Err(FittedModelError::IncompatibleConfig {
3934 reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3935 .to_string(),
3936 });
3937 }
3938 Some(FrailtySpec::HazardMultiplier { .. }) => {
3939 return Err(FittedModelError::IncompatibleConfig {
3940 reason: "marginal-slope model does not support HazardMultiplier frailty"
3941 .to_string(),
3942 });
3943 }
3944 None => {
3945 return Err(FittedModelError::MissingField {
3946 reason: "marginal-slope model is missing family_state.frailty; refit"
3947 .to_string(),
3948 });
3949 }
3950 }
3951 }
3952
3953 if let FittedFamily::Survival {
3954 survival_likelihood,
3955 frailty,
3956 ..
3957 } = &self.family_state
3958 {
3959 if matches!(
3960 survival_likelihood.as_deref(),
3961 Some("latent") | Some("latent-binary")
3962 ) {
3963 return Err(FittedModelError::SchemaMismatch {
3964 reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3965 .to_string(),
3966 });
3967 }
3968 if survival_likelihood.as_deref() == Some("marginal-slope") {
3969 if self.formula_logslope.is_none() {
3970 return Err(FittedModelError::MissingField {
3971 reason: "survival marginal-slope model is missing formula_logslope; refit"
3972 .to_string(),
3973 });
3974 }
3975 if self.z_column.is_none() {
3976 return Err(FittedModelError::MissingField {
3977 reason: "survival marginal-slope model is missing z_column; refit"
3978 .to_string(),
3979 });
3980 }
3981 let z_normalization =
3982 self.latent_z_normalization
3983 .ok_or_else(|| {
3984 FittedModelError::MissingField {
3985 reason:
3986 "survival marginal-slope model is missing latent_z_normalization; refit"
3987 .to_string(),
3988 }
3989 })?;
3990 z_normalization.validate("survival marginal-slope model")?;
3991 let latent_measure =
3992 self.latent_measure
3993 .as_ref()
3994 .ok_or_else(|| FittedModelError::MissingField {
3995 reason:
3996 "survival marginal-slope model is missing latent_measure; refit"
3997 .to_string(),
3998 })?;
3999 latent_measure
4000 .validate("survival marginal-slope model latent_measure")
4001 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
4002 if self.logslope_baseline.is_none() {
4003 return Err(FittedModelError::MissingField {
4004 reason: "survival marginal-slope model is missing logslope_baseline; refit"
4005 .to_string(),
4006 });
4007 }
4008 if self.resolved_termspec_logslope.as_ref().is_none() {
4009 return Err(FittedModelError::MissingField {
4010 reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
4011 .to_string(),
4012 });
4013 }
4014 match frailty {
4015 FrailtySpec::None
4016 | FrailtySpec::GaussianShift {
4017 sigma_fixed: Some(_),
4018 } => {}
4019 FrailtySpec::GaussianShift { sigma_fixed: None } => {
4020 return Err(FittedModelError::IncompatibleConfig {
4021 reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4022 .to_string(),
4023 });
4024 }
4025 FrailtySpec::HazardMultiplier { .. } => {
4026 return Err(FittedModelError::IncompatibleConfig {
4027 reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4028 .to_string(),
4029 });
4030 }
4031 }
4032 } else if !matches!(frailty, FrailtySpec::None) {
4033 return Err(FittedModelError::IncompatibleConfig {
4034 reason:
4035 "non-marginal survival models do not currently persist a frailty modifier"
4036 .to_string(),
4037 });
4038 }
4039 if self.survival_time_basis.is_none() {
4047 return Err(FittedModelError::MissingField {
4048 reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4049 });
4050 }
4051 if self.survival_time_anchor.is_none() {
4052 return Err(FittedModelError::MissingField {
4053 reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4054 });
4055 }
4056 }
4057 if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4058 match frailty {
4059 FrailtySpec::HazardMultiplier {
4060 sigma_fixed: Some(_),
4061 ..
4062 } => {}
4063 FrailtySpec::HazardMultiplier {
4064 sigma_fixed: None, ..
4065 } => {
4066 return Err(FittedModelError::IncompatibleConfig {
4067 reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4068 .to_string(),
4069 });
4070 }
4071 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4072 return Err(FittedModelError::IncompatibleConfig {
4073 reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4074 .to_string(),
4075 });
4076 }
4077 }
4078 if self.survival_likelihood.as_deref() != Some("latent") {
4079 return Err(FittedModelError::SchemaMismatch {
4080 reason: "latent survival model must persist survival_likelihood=latent"
4081 .to_string(),
4082 });
4083 }
4084 }
4085 if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4086 match frailty {
4087 FrailtySpec::HazardMultiplier {
4088 sigma_fixed: Some(_),
4089 ..
4090 } => {}
4091 FrailtySpec::HazardMultiplier {
4092 sigma_fixed: None, ..
4093 } => {
4094 return Err(FittedModelError::IncompatibleConfig {
4095 reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4096 .to_string(),
4097 });
4098 }
4099 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4100 return Err(FittedModelError::IncompatibleConfig {
4101 reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4102 .to_string(),
4103 });
4104 }
4105 }
4106 if self.survival_likelihood.as_deref() != Some("latent-binary") {
4107 return Err(FittedModelError::SchemaMismatch {
4108 reason: "latent binary model must persist survival_likelihood=latent-binary"
4109 .to_string(),
4110 });
4111 }
4112 }
4113
4114 let family_likelihood = match &self.family_state {
4115 FittedFamily::Standard { likelihood, .. }
4116 | FittedFamily::LocationScale { likelihood, .. }
4117 | FittedFamily::MarginalSlope { likelihood, .. }
4118 | FittedFamily::Survival { likelihood, .. }
4119 | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4120 FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4121 };
4122 let is_standard_or_location_scale = matches!(
4123 self.family_state,
4124 FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4125 );
4126 if is_standard_or_location_scale
4127 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4128 {
4129 self.saved_sas_state()?;
4130 }
4131 if is_standard_or_location_scale
4132 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4133 {
4134 self.saved_beta_logistic_state()?;
4135 }
4136 if is_standard_or_location_scale
4137 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4138 {
4139 self.saved_mixture_state()?;
4140 }
4141 if matches!(self.family_state, FittedFamily::Standard { .. })
4142 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4143 {
4144 self.saved_latent_cloglog_state()?;
4145 }
4146 if matches!(self.family_state, FittedFamily::LocationScale { .. })
4147 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4148 {
4149 return Err(FittedModelError::IncompatibleConfig {
4150 reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4151 .to_string(),
4152 });
4153 }
4154 if matches!(self.family_state, FittedFamily::Survival { .. })
4155 && self.survival_likelihood.is_none()
4156 {
4157 return Err(FittedModelError::MissingField {
4158 reason: "saved survival model is missing survival_likelihood metadata; refit"
4159 .to_string(),
4160 });
4161 }
4162 let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4163 || self.linkwiggle_degree.is_some()
4164 || self.beta_link_wiggle.is_some()
4165 || self
4166 .fit_result
4167 .as_ref()
4168 .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4169 .is_some();
4170 let saved_link_wiggle = self.saved_link_wiggle()?;
4171 if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4172 return Err(FittedModelError::SchemaMismatch {
4173 reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4174 .to_string(),
4175 });
4176 }
4177 let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4178 || self.baseline_timewiggle_degree.is_some()
4179 || self.baseline_timewiggle_penalty_orders.is_some()
4180 || self.baseline_timewiggle_double_penalty.is_some()
4181 || self.beta_baseline_timewiggle.is_some()
4182 || self.beta_baseline_timewiggle_by_cause.is_some();
4183 let is_joint_cause_specific = self
4184 .survival_cause_count
4185 .is_some_and(|cause_count| cause_count > 1);
4186 if has_any_saved_baseline_time_wiggle {
4187 if is_joint_cause_specific {
4188 let complete = self.baseline_timewiggle_knots.is_some()
4189 && self.baseline_timewiggle_degree.is_some()
4190 && self.baseline_timewiggle_penalty_orders.is_some()
4191 && self.baseline_timewiggle_double_penalty.is_some()
4192 && self.beta_baseline_timewiggle_by_cause.is_some();
4193 if !complete {
4194 return Err(FittedModelError::SchemaMismatch {
4195 reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4196 .to_string(),
4197 });
4198 }
4199 } else if self.saved_baseline_time_wiggle()?.is_none() {
4200 return Err(FittedModelError::SchemaMismatch {
4201 reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4202 .to_string(),
4203 });
4204 }
4205 }
4206 if self
4207 .survival_likelihood
4208 .as_deref()
4209 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4210 {
4211 validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4212 }
4213
4214 if let Some(runtime) = self.score_warp_runtime.as_ref() {
4224 runtime.validate_exact_replay_contract().map_err(|err| {
4225 FittedModelError::PayloadCorrupt {
4226 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4227 }
4228 })?;
4229 }
4230 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4231 runtime.validate_exact_replay_contract().map_err(|err| {
4232 FittedModelError::PayloadCorrupt {
4233 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4234 }
4235 })?;
4236 }
4237 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4238 validate_marginal_slope_saved_fit(
4239 self.fit_result.as_ref().expect("checked above"),
4240 self.score_warp_runtime.as_ref(),
4241 self.link_deviation_runtime.as_ref(),
4242 "fit_result",
4243 )?;
4244 let unified = self
4245 .unified
4246 .as_ref()
4247 .ok_or_else(|| FittedModelError::MissingField {
4248 reason: "marginal-slope model is missing unified fit payload; refit"
4249 .to_string(),
4250 })?;
4251 validate_marginal_slope_saved_fit(
4252 unified,
4253 self.score_warp_runtime.as_ref(),
4254 self.link_deviation_runtime.as_ref(),
4255 "unified",
4256 )?;
4257 }
4258 if self
4259 .survival_likelihood
4260 .as_deref()
4261 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4262 {
4263 validate_survival_marginal_slope_saved_fit(
4264 self.fit_result.as_ref().expect("checked above"),
4265 self.score_warp_runtime.as_ref(),
4266 self.link_deviation_runtime.as_ref(),
4267 "fit_result",
4268 )?;
4269 if let Some(unified) = self.unified.as_ref() {
4270 validate_survival_marginal_slope_saved_fit(
4271 unified,
4272 self.score_warp_runtime.as_ref(),
4273 self.link_deviation_runtime.as_ref(),
4274 "unified",
4275 )?;
4276 }
4277 }
4278
4279 Ok(())
4290 }
4291
4292 pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4293 let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4294 if let Some(fit) = self.fit_result.as_ref() {
4295 fit.validate_numeric_finiteness()
4296 .map_err(|e| corrupt(e.to_string()))?;
4297 }
4298
4299 for (name, opt) in [
4300 ("survival_baseline_scale", self.survival_baseline_scale),
4301 ("survival_baseline_shape", self.survival_baseline_shape),
4302 ("survival_baseline_rate", self.survival_baseline_rate),
4303 ("survival_baseline_makeham", self.survival_baseline_makeham),
4304 (
4305 "survival_time_smooth_lambda",
4306 self.survival_time_smooth_lambda,
4307 ),
4308 ("survival_time_anchor", self.survival_time_anchor),
4309 ("survivalridge_lambda", self.survivalridge_lambda),
4310 ] {
4311 if let Some(v) = opt {
4312 ensure_finite_scalar(name, v).map_err(corrupt)?;
4313 }
4314 }
4315
4316 if let Some(v) = self.beta_noise.as_ref() {
4317 validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4318 }
4319 if let Some(v) = self.noise_projection.as_ref() {
4320 validate_all_finite("noise_projection", v.iter().flatten().copied())
4321 .map_err(corrupt)?;
4322 if self.noise_projection_ridge_alpha.is_none() {
4323 return Err(FittedModelError::MissingField {
4324 reason:
4325 "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4326 .to_string(),
4327 });
4328 }
4329 }
4330 if let Some(v) = self.noise_center.as_ref() {
4331 validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4332 }
4333 if let Some(v) = self.noise_scale.as_ref() {
4334 validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4335 }
4336 if let Some(v) = self.noise_projection_ridge_alpha {
4337 ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4338 if v < 0.0 {
4339 return Err(FittedModelError::InvalidInput {
4340 reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4341 });
4342 }
4343 }
4344 if let Some(v) = self.gaussian_response_scale {
4345 ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4346 }
4347 if let Some(v) = self.beta_link_wiggle.as_ref() {
4348 validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4349 }
4350 if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4351 validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4352 }
4353 if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4354 validate_all_finite(
4355 "beta_baseline_timewiggle_by_cause",
4356 v.iter().flatten().copied(),
4357 )
4358 .map_err(corrupt)?;
4359 }
4360 if let Some(v) = self.latent_z_normalization {
4361 v.validate("latent_z_normalization")?;
4362 }
4363 if let Some(v) = self.latent_measure.as_ref() {
4364 v.validate("latent_measure").map_err(corrupt)?;
4365 }
4366 if let Some(v) = self.survival_beta_time.as_ref() {
4367 validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4368 }
4369 if let Some(v) = self.survival_beta_threshold.as_ref() {
4370 validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4371 }
4372 if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4373 validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4374 }
4375 if let Some(v) = self.survival_noise_projection.as_ref() {
4376 validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4377 .map_err(corrupt)?;
4378 if self.survival_noise_projection_ridge_alpha.is_none() {
4379 return Err(FittedModelError::MissingField {
4380 reason:
4381 "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4382 .to_string(),
4383 });
4384 }
4385 }
4386 if let Some(v) = self.survival_noise_center.as_ref() {
4387 validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4388 }
4389 if let Some(v) = self.survival_noise_projection_ridge_alpha {
4390 ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4391 if v < 0.0 {
4392 return Err(FittedModelError::InvalidInput {
4393 reason: format!(
4394 "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4395 ),
4396 });
4397 }
4398 }
4399 if let Some(v) = self.survival_noise_scale.as_ref() {
4400 validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4401 }
4402 if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4403 validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4404 .map_err(corrupt)?;
4405 }
4406 if let Some(v) = self.sas_param_covariance.as_ref() {
4407 validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4408 .map_err(corrupt)?;
4409 }
4410 Ok(())
4411 }
4412}
4413
4414fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4415 a.rows().into_iter().map(|row| row.to_vec()).collect()
4416}
4417
4418use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4419
4420fn validate_frozen_term_collectionspec(
4421 spec: &TermCollectionSpec,
4422 label: &str,
4423) -> Result<(), FittedModelError> {
4424 spec.validate_frozen(label)
4425 .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4426}
4427
4428impl Deref for FittedModel {
4429 type Target = FittedModelPayload;
4430
4431 fn deref(&self) -> &Self::Target {
4432 self.payload()
4433 }
4434}
4435
4436impl DerefMut for FittedModel {
4437 fn deref_mut(&mut self) -> &mut Self::Target {
4438 self.payload_mut()
4439 }
4440}
4441
4442pub fn survival_baseline_config_from_model(
4447 model: &FittedModel,
4448) -> Result<SurvivalBaselineConfig, FittedModelError> {
4449 let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4450 FittedModelError::MissingField {
4451 reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4452 }
4453 })?;
4454 parse_survival_baseline_config(
4455 target,
4456 model.survival_baseline_scale,
4457 model.survival_baseline_shape,
4458 model.survival_baseline_rate,
4459 model.survival_baseline_makeham,
4460 )
4461 .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4462}
4463
4464pub fn load_survival_time_basis_config_from_model(
4465 model: &FittedModel,
4466) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4467 match model
4468 .survival_time_basis
4469 .as_deref()
4470 .ok_or_else(|| FittedModelError::MissingField {
4471 reason: "saved survival model missing survival_time_basis".to_string(),
4472 })?
4473 .to_ascii_lowercase()
4474 .as_str()
4475 {
4476 "none" => Ok(SurvivalTimeBasisConfig::None),
4477 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4478 "bspline" => {
4479 let degree =
4480 model
4481 .survival_time_degree
4482 .ok_or_else(|| FittedModelError::MissingField {
4483 reason: "saved survival bspline model missing survival_time_degree"
4484 .to_string(),
4485 })?;
4486 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4487 FittedModelError::MissingField {
4488 reason: "saved survival bspline model missing survival_time_knots".to_string(),
4489 }
4490 })?;
4491 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4492 if degree < 1 || knots.is_empty() {
4493 return Err(FittedModelError::SchemaMismatch {
4494 reason: "saved survival bspline time basis metadata is invalid".to_string(),
4495 });
4496 }
4497 Ok(SurvivalTimeBasisConfig::BSpline {
4498 degree,
4499 knots: Array1::from_vec(knots),
4500 smooth_lambda,
4501 })
4502 }
4503 "ispline" => {
4504 let degree =
4505 model
4506 .survival_time_degree
4507 .ok_or_else(|| FittedModelError::MissingField {
4508 reason: "saved survival ispline model missing survival_time_degree"
4509 .to_string(),
4510 })?;
4511 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4512 FittedModelError::MissingField {
4513 reason: "saved survival ispline model missing survival_time_knots".to_string(),
4514 }
4515 })?;
4516 let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4517 FittedModelError::MissingField {
4518 reason: "saved survival ispline model missing survival_time_keep_cols"
4519 .to_string(),
4520 }
4521 })?;
4522 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4523 if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4524 return Err(FittedModelError::SchemaMismatch {
4525 reason: "saved survival ispline time basis metadata is invalid".to_string(),
4526 });
4527 }
4528 Ok(SurvivalTimeBasisConfig::ISpline {
4529 degree,
4530 knots: Array1::from_vec(knots),
4531 keep_cols,
4532 smooth_lambda,
4533 })
4534 }
4535 other => Err(FittedModelError::IncompatibleConfig {
4536 reason: format!("unsupported saved survival_time_basis '{other}'"),
4537 }),
4538 }
4539}
4540
4541#[cfg(test)]
4542mod tests {
4543 use super::*;
4544 use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4545 use crate::survival::lognormal_kernel::FrailtySpec;
4546 use gam_solve::pirls::PirlsStatus;
4547 use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4548 use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4549 use gam_data::SchemaColumn;
4550 use ndarray::{Array1, Array2, array};
4551
4552 fn empty_termspec() -> TermCollectionSpec {
4553 TermCollectionSpec {
4554 linear_terms: vec![],
4555 random_effect_terms: vec![],
4556 smooth_terms: vec![],
4557 }
4558 }
4559
4560 #[test]
4564 fn spline_scan_payload_round_trips_and_validates() {
4565 let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4566 let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4567 let w = vec![1.0_f64; x.len()];
4568 let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4569 let make_payload = || {
4570 crate::inference::model_payload_builders::assemble_spline_scan_payload(
4571 "y ~ s(x)".to_string(),
4572 "x".to_string(),
4573 &fit,
4574 DataSchema {
4575 columns: vec![
4576 SchemaColumn {
4577 name: "y".to_string(),
4578 kind: ColumnKindTag::Continuous,
4579 levels: vec![],
4580 },
4581 SchemaColumn {
4582 name: "x".to_string(),
4583 kind: ColumnKindTag::Continuous,
4584 levels: vec![],
4585 },
4586 ],
4587 },
4588 vec!["x".to_string()],
4589 vec![(0.0, 1.0)],
4590 )
4591 };
4592 let model = FittedModel::from_payload(make_payload());
4595 model
4596 .validate_for_persistence()
4597 .expect("scan model validates");
4598 model
4599 .validate_numeric_finiteness()
4600 .expect("scan model is finite");
4601
4602 let json = serde_json::to_string(&model).expect("serialize model");
4603 let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4604 restored
4605 .validate_for_persistence()
4606 .expect("restored scan model validates");
4607 let (column, replay) = restored
4608 .saved_spline_scan()
4609 .expect("restore scan fit")
4610 .expect("payload carries the scan representation");
4611 assert_eq!(column, "x");
4612 for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4613 let (m0, v0) = fit.predict(xq).expect("predict original");
4614 let (m1, v1) = replay.predict(xq).expect("predict replayed");
4615 assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4616 assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4617 }
4618
4619 let mut dense = make_payload();
4621 dense.spline_scan = None;
4622 let err = FittedModel::from_payload(dense)
4623 .validate_for_persistence()
4624 .expect_err("dense payload without fit_result must be rejected");
4625 assert!(err.to_string().contains("fit_result"));
4626
4627 let mut corrupt = make_payload();
4629 corrupt
4630 .spline_scan
4631 .as_mut()
4632 .expect("scan channel present")
4633 .state
4634 .knots
4635 .truncate(2);
4636 FittedModel::from_payload(corrupt)
4637 .validate_for_persistence()
4638 .expect_err("corrupt scan state must be rejected");
4639 let mut unnamed = make_payload();
4640 unnamed
4641 .spline_scan
4642 .as_mut()
4643 .expect("scan channel present")
4644 .feature_column
4645 .clear();
4646 FittedModel::from_payload(unnamed)
4647 .validate_for_persistence()
4648 .expect_err("missing feature column must be rejected");
4649 }
4650
4651 fn standard_gaussian_payload() -> FittedModelPayload {
4652 FittedModelPayload::new(
4653 MODEL_PAYLOAD_VERSION,
4654 "y ~ 1".to_string(),
4655 ModelKind::Standard,
4656 FittedFamily::Standard {
4657 likelihood: LikelihoodSpec::gaussian_identity(),
4658 link: Some(StandardLink::Identity),
4659 latent_cloglog_state: None,
4660 mixture_state: None,
4661 sas_state: None,
4662 },
4663 "gaussian".to_string(),
4664 )
4665 }
4666
4667 fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4668 SavedCompiledFlexBlock {
4669 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4670 breakpoints: vec![-1.0, 1.0],
4671 basis_dim,
4672 span_c0: vec![vec![0.0; basis_dim]],
4673 span_c1: vec![vec![0.0; basis_dim]],
4674 span_c2: vec![vec![0.0; basis_dim]],
4675 span_c3: vec![vec![0.0; basis_dim]],
4676 anchor_correction: None,
4677 anchor_components: Vec::new(),
4678 }
4679 }
4680
4681 fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4682 let beta = Array1::from_vec(
4683 blocks
4684 .iter()
4685 .flat_map(|block| block.beta.iter().copied())
4686 .collect(),
4687 );
4688 let p = beta.len();
4689 UnifiedFitResult {
4690 blocks,
4691 log_lambdas: Array1::zeros(0),
4692 lambdas: Array1::zeros(0),
4693 likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4694 likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4695 log_likelihood_normalization: LogLikelihoodNormalization::Full,
4696 log_likelihood: 0.0,
4697 deviance: 0.0,
4698 reml_score: 0.0,
4699 stable_penalty_term: 0.0,
4700 penalized_objective: 0.0,
4701 used_device: false,
4702 outer_iterations: 0,
4703 outer_cost_evals: 0,
4704 inner_pirls_solves: 0,
4705 outer_converged: true,
4706 outer_gradient_norm: None,
4707 standard_deviation: 1.0,
4708 covariance_conditional: Some(Array2::zeros((p, p))),
4709 covariance_corrected: Some(Array2::zeros((p, p))),
4710 inference: None,
4711 fitted_link: FittedLinkState::Standard(None),
4712 geometry: None,
4713 block_states: vec![],
4714 beta,
4715 pirls_status: PirlsStatus::Converged,
4716 max_abs_eta: 0.0,
4717 constraint_kkt: None,
4718 artifacts: FitArtifacts {
4719 pirls: None,
4720 null_space_logdet: None,
4721 null_space_dim: None,
4722 survival_link_wiggle_knots: None,
4723 survival_link_wiggle_degree: None,
4724 criterion_certificate: None,
4725 rho_posterior_certificate: None,
4726 rho_posterior_escalation: None,
4727 rho_covariance: None,
4728 joint_log_lambdas: None,
4729 },
4730 inner_cycles: 0,
4731 }
4732 }
4733
4734 fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4735 let mut payload = FittedModelPayload::new(
4736 version,
4737 "y ~ 1".to_string(),
4738 ModelKind::MarginalSlope,
4739 FittedFamily::MarginalSlope {
4740 likelihood: LikelihoodSpec::binomial_probit(),
4741 base_link: InverseLink::Standard(StandardLink::Probit),
4742 frailty: FrailtySpec::None,
4743 },
4744 "bernoulli-marginal-slope".to_string(),
4745 );
4746 payload.fit_result = Some(fit.clone());
4747 payload.unified = Some(fit);
4748 payload.data_schema = Some(DataSchema {
4749 columns: vec![SchemaColumn {
4750 name: "z".to_string(),
4751 kind: ColumnKindTag::Continuous,
4752 levels: vec![],
4753 }],
4754 });
4755 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4756 payload.resolved_termspec = Some(empty_termspec());
4757 payload.resolved_termspec_logslope = Some(empty_termspec());
4758 payload.formula_logslope = Some("1".to_string());
4759 payload.z_column = Some("z".to_string());
4760 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4761 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4762 payload.marginal_baseline = Some(0.0);
4763 payload.logslope_baseline = Some(0.0);
4764 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4765 payload
4766 }
4767
4768 #[test]
4769 fn from_payload_synchronizes_used_device_from_saved_fit() {
4770 let mut fit = saved_fit(vec![
4771 FittedBlock {
4772 beta: Array1::from_vec(vec![0.25]),
4773 role: BlockRole::Mean,
4774 edf: 1.0,
4775 lambdas: Array1::zeros(0),
4776 },
4777 FittedBlock {
4778 beta: Array1::from_vec(vec![0.5]),
4779 role: BlockRole::Scale,
4780 edf: 1.0,
4781 lambdas: Array1::zeros(0),
4782 },
4783 ]);
4784 fit.used_device = true;
4785 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4786 payload.used_device = false;
4787
4788 let model = FittedModel::from_payload(payload);
4789
4790 assert!(model.payload().used_device);
4791 }
4792
4793 fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4794 let mut payload = FittedModelPayload::new(
4795 version,
4796 "Surv(entry, exit, event) ~ 1".to_string(),
4797 ModelKind::Survival,
4798 FittedFamily::Survival {
4799 likelihood: LikelihoodSpec::royston_parmar(),
4800 survival_likelihood: Some("marginal-slope".to_string()),
4801 survival_distribution: Some(ResidualDistribution::Gaussian),
4802 frailty: FrailtySpec::None,
4803 },
4804 "survival".to_string(),
4805 );
4806 payload.fit_result = Some(fit.clone());
4807 payload.unified = Some(fit);
4808 payload.survival_likelihood = Some("marginal-slope".to_string());
4809 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4810 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4811 payload.data_schema = Some(DataSchema {
4812 columns: vec![SchemaColumn {
4813 name: "z".to_string(),
4814 kind: ColumnKindTag::Continuous,
4815 levels: vec![],
4816 }],
4817 });
4818 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4819 payload.resolved_termspec = Some(empty_termspec());
4820 payload.resolved_termspec_logslope = Some(empty_termspec());
4821 payload.formula_logslope = Some("1".to_string());
4822 payload.z_column = Some("z".to_string());
4823 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4824 payload.logslope_baseline = Some(0.0);
4825 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4826 payload
4827 }
4828
4829 fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4830 let mut payload = FittedModelPayload::new(
4836 MODEL_PAYLOAD_VERSION,
4837 "y ~ x".to_string(),
4838 ModelKind::LocationScale,
4839 FittedFamily::LocationScale {
4840 likelihood: LikelihoodSpec::gamma_log(),
4841 base_link: Some(InverseLink::Standard(StandardLink::Log)),
4842 },
4843 "gamma-location-scale".to_string(),
4844 );
4845 payload.data_schema = Some(DataSchema {
4846 columns: vec![
4847 SchemaColumn {
4848 name: "y".to_string(),
4849 kind: ColumnKindTag::Continuous,
4850 levels: vec![],
4851 },
4852 SchemaColumn {
4853 name: "x".to_string(),
4854 kind: ColumnKindTag::Continuous,
4855 levels: vec![],
4856 },
4857 ],
4858 });
4859 payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4860 payload.resolved_termspec = Some(empty_termspec());
4861 payload.resolved_termspec_noise = Some(empty_termspec());
4862 payload.formula_noise = Some("x".to_string());
4863 payload.beta_noise = Some(vec![0.0]);
4864 payload.link = Some(InverseLink::Standard(StandardLink::Log));
4865 payload
4866 }
4867
4868 #[test]
4875 fn dispersion_location_scale_payload_is_not_classified_binomial() {
4876 let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4877 assert_eq!(
4878 model.predict_model_class(),
4879 PredictModelClass::DispersionLocationScale,
4880 "Gamma dispersion location-scale must route through the dispersion \
4881 predictor, not the binomial threshold-scale class",
4882 );
4883 assert!(
4884 !matches!(
4885 model.predict_model_class(),
4886 PredictModelClass::BinomialLocationScale
4887 ),
4888 "dispersion location-scale must never be classified as binomial",
4889 );
4890
4891 for likelihood in [
4893 LikelihoodSpec::gamma_log(),
4894 LikelihoodSpec::new(
4895 ResponseFamily::NegativeBinomial {
4896 theta: 1.0,
4897 theta_fixed: false,
4898 },
4899 InverseLink::Standard(StandardLink::Log),
4900 ),
4901 LikelihoodSpec::new(
4902 ResponseFamily::Beta { phi: 1.0 },
4903 InverseLink::Standard(StandardLink::Logit),
4904 ),
4905 LikelihoodSpec::new(
4906 ResponseFamily::Tweedie { p: 1.5 },
4907 InverseLink::Standard(StandardLink::Log),
4908 ),
4909 ] {
4910 let mut payload = gamma_dispersion_location_scale_payload();
4911 payload.family_state = FittedFamily::LocationScale {
4912 base_link: Some(likelihood.link.clone()),
4913 likelihood: likelihood.clone(),
4914 };
4915 let model = FittedModel::from_payload(payload);
4916 assert_eq!(
4917 model.predict_model_class(),
4918 PredictModelClass::DispersionLocationScale,
4919 "dispersion family {:?} mis-classified",
4920 likelihood.response,
4921 );
4922 }
4923 }
4924
4925 #[test]
4926 fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4927 let data = array![[100.0], [-100.0]];
4928 let col_map = HashMap::from([("g".to_string(), 0usize)]);
4929
4930 let mut plain_payload = standard_gaussian_payload();
4931 plain_payload.data_schema = Some(DataSchema {
4932 columns: vec![SchemaColumn {
4933 name: "g".to_string(),
4934 kind: ColumnKindTag::Continuous,
4935 levels: vec![],
4936 }],
4937 });
4938 plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4939 plain_payload.resolved_termspec = Some(empty_termspec());
4940 let plain = FittedModel::from_payload(plain_payload.clone());
4941 let clipped = plain
4942 .axis_clip_to_training_ranges(data.view(), &col_map)
4943 .expect("ordinary continuous axis should clip outside the training range");
4944 assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4945
4946 let mut group_payload = plain_payload;
4947 let mut group_spec = empty_termspec();
4948 group_spec
4949 .random_effect_terms
4950 .push(gam_terms::smooth::RandomEffectTermSpec {
4951 name: "g".to_string(),
4952 feature_col: 0,
4953 drop_first_level: false,
4954 penalized: true,
4955 frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4956 });
4957 group_payload.resolved_termspec = Some(group_spec);
4958 let group_model = FittedModel::from_payload(group_payload);
4959
4960 assert_eq!(
4961 group_model.random_effect_group_columns(),
4962 HashSet::from(["g".to_string()])
4963 );
4964
4965 assert_eq!(
4966 group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4967 None,
4968 "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4969 );
4970 }
4971
4972 #[test]
4973 fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4974 let fit = saved_fit(vec![
4975 FittedBlock {
4976 beta: array![0.1],
4977 role: BlockRole::Mean,
4978 edf: 1.0,
4979 lambdas: Array1::zeros(0),
4980 },
4981 FittedBlock {
4982 beta: array![0.2],
4983 role: BlockRole::Scale,
4984 edf: 1.0,
4985 lambdas: Array1::zeros(0),
4986 },
4987 FittedBlock {
4988 beta: array![0.3],
4989 role: BlockRole::Mean,
4990 edf: 1.0,
4991 lambdas: Array1::zeros(0),
4992 },
4993 ]);
4994 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4995 payload.score_warp_runtime = Some(anchored_runtime(2));
4996
4997 let err = FittedModel::from_payload(payload)
4998 .validate_for_persistence()
4999 .expect_err("marginal-slope score-warp basis mismatch should fail validation");
5000 assert!(err.to_string().contains("score-warp coefficient mismatch"));
5001 }
5002
5003 #[test]
5004 fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
5005 let fit = saved_fit(vec![
5006 FittedBlock {
5007 beta: array![0.1],
5008 role: BlockRole::Time,
5009 edf: 1.0,
5010 lambdas: Array1::zeros(0),
5011 },
5012 FittedBlock {
5013 beta: array![0.2],
5014 role: BlockRole::Mean,
5015 edf: 1.0,
5016 lambdas: Array1::zeros(0),
5017 },
5018 FittedBlock {
5019 beta: array![0.3],
5020 role: BlockRole::Scale,
5021 edf: 1.0,
5022 lambdas: Array1::zeros(0),
5023 },
5024 FittedBlock {
5025 beta: array![0.4],
5026 role: BlockRole::LinkWiggle,
5027 edf: 1.0,
5028 lambdas: Array1::zeros(0),
5029 },
5030 ]);
5031 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5032 payload.link_deviation_runtime = Some(anchored_runtime(2));
5033
5034 let err = FittedModel::from_payload(payload)
5035 .saved_prediction_runtime()
5036 .expect_err(
5037 "survival marginal-slope link basis mismatch should fail runtime validation",
5038 );
5039 assert!(
5040 err.to_string()
5041 .contains("link-deviation coefficient mismatch")
5042 );
5043 }
5044
5045 #[test]
5046 fn apply_survival_time_basis_writes_all_required_fields() {
5047 use crate::survival::construction::SavedSurvivalTimeBasis;
5048
5049 let fit = saved_fit(vec![
5050 FittedBlock {
5051 beta: array![0.1],
5052 role: BlockRole::Time,
5053 edf: 1.0,
5054 lambdas: Array1::zeros(0),
5055 },
5056 FittedBlock {
5057 beta: array![0.2],
5058 role: BlockRole::Mean,
5059 edf: 1.0,
5060 lambdas: Array1::zeros(0),
5061 },
5062 FittedBlock {
5063 beta: array![0.3],
5064 role: BlockRole::Scale,
5065 edf: 1.0,
5066 lambdas: Array1::zeros(0),
5067 },
5068 ]);
5069 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5070
5071 let snapshot = SavedSurvivalTimeBasis {
5076 basisname: "royston-parmar".to_string(),
5077 degree: Some(3),
5078 knots: Some(vec![0.0, 1.0, 2.0]),
5079 keep_cols: Some(vec![0, 2]),
5080 smooth_lambda: Some(0.5),
5081 anchor: 0.25,
5082 };
5083 payload.apply_survival_time_basis(&snapshot);
5084
5085 assert_eq!(
5086 payload.survival_time_basis.as_deref(),
5087 Some("royston-parmar")
5088 );
5089 assert_eq!(payload.survival_time_degree, Some(3));
5090 assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5091 assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5092 assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5093 assert_eq!(payload.survival_time_anchor, Some(0.25));
5094 }
5095
5096 #[test]
5097 fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5098 let fit = saved_fit(vec![
5099 FittedBlock {
5100 beta: array![0.1],
5101 role: BlockRole::Time,
5102 edf: 1.0,
5103 lambdas: Array1::zeros(0),
5104 },
5105 FittedBlock {
5106 beta: array![0.2],
5107 role: BlockRole::Mean,
5108 edf: 1.0,
5109 lambdas: Array1::zeros(0),
5110 },
5111 FittedBlock {
5112 beta: array![0.3],
5113 role: BlockRole::Scale,
5114 edf: 1.0,
5115 lambdas: Array1::zeros(0),
5116 },
5117 ]);
5118 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5119 payload.survival_time_basis = Some("ispline".to_string());
5125
5126 let err = FittedModel::from_payload(payload)
5127 .validate_for_persistence()
5128 .expect_err("survival model without time-anchor metadata should fail validation");
5129 assert!(err.to_string().contains("missing survival_time_anchor"));
5130 }
5131
5132 #[test]
5133 fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5134 let fit = saved_fit(vec![
5135 FittedBlock {
5136 beta: array![0.1],
5137 role: BlockRole::Time,
5138 edf: 1.0,
5139 lambdas: Array1::zeros(0),
5140 },
5141 FittedBlock {
5142 beta: array![0.2],
5143 role: BlockRole::Mean,
5144 edf: 1.0,
5145 lambdas: Array1::zeros(0),
5146 },
5147 FittedBlock {
5148 beta: array![0.3],
5149 role: BlockRole::Scale,
5150 edf: 1.0,
5151 lambdas: Array1::zeros(0),
5152 },
5153 ]);
5154 let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5155
5156 let err = FittedModel::from_payload(payload)
5157 .validate_for_persistence()
5158 .expect_err("survival model without time-basis metadata should fail validation");
5159 assert!(err.to_string().contains("missing survival_time_basis"));
5160 }
5161
5162 #[test]
5163 fn saved_prediction_runtime_rejects_stale_payload_version() {
5164 let fit = saved_fit(vec![
5165 FittedBlock {
5166 beta: array![0.1],
5167 role: BlockRole::Mean,
5168 edf: 1.0,
5169 lambdas: Array1::zeros(0),
5170 },
5171 FittedBlock {
5172 beta: array![0.2],
5173 role: BlockRole::Scale,
5174 edf: 1.0,
5175 lambdas: Array1::zeros(0),
5176 },
5177 ]);
5178 let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5179
5180 let err = FittedModel::from_payload(payload)
5181 .saved_prediction_runtime()
5182 .expect_err("stale payload version should fail before runtime assembly");
5183 assert!(err.to_string().contains("payload schema mismatch"));
5184 }
5185}