1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4 LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7 SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12 monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15 inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16 parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21 InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22 SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Schema version stamped into every saved model payload.
///
/// `FittedModelPayload::validate_payload_version` rejects any payload whose
/// `version` differs from this value.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;

/// String-keyed group metadata stored with a saved model
/// (`FittedModelPayload::group_metadata`); a `BTreeMap`, so keys iterate sorted.
pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
/// Saved spline-scan state together with the feature column it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedSplineScan {
    /// Name of the input column the scan state applies to
    /// (NOTE(review): presumably resolved by name at prediction time — confirm).
    pub feature_column: String,
    /// Solver-side spline-scan state, serialized as-is.
    pub state: gam_solve::spline_scan::SplineScanState,
}
76
/// Saved residual-cascade state together with the feature columns it uses.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedResidualCascade {
    /// Names of the input columns the cascade state applies to
    /// (NOTE(review): presumably in the order the state expects — confirm).
    pub feature_columns: Vec<String>,
    /// Solver-side residual-cascade state, serialized as-is.
    pub state: gam_solve::residual_cascade::ResidualCascadeState,
}
87
/// Errors raised while validating or using a saved model payload.
///
/// Every variant carries a human-readable `reason` string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
    /// The payload's layout or dimensions disagree with what this code expects
    /// (e.g. a version mismatch or a covariance shape mismatch).
    SchemaMismatch { reason: String },
    /// A saved value is present but invalid (e.g. a non-finite latent-z mean).
    PayloadCorrupt { reason: String },
    /// A field required for the requested operation is absent from the payload.
    MissingField { reason: String },
    /// The saved configuration is not supported by the requested code path.
    IncompatibleConfig { reason: String },
    /// NOTE(review): presumably invalid caller-supplied input; not constructed
    /// in this section.
    InvalidInput { reason: String },
}
118
// Generates the shared error boilerplate for `FittedModelError` from its
// reason-carrying variants. NOTE(review): the macro is defined elsewhere;
// it presumably provides `Display` (relied on via `err.to_string()` in the
// `From` impls below) and `std::error::Error` — confirm against its definition.
impl_reason_error_boilerplate! {
    FittedModelError {
        SchemaMismatch,
        PayloadCorrupt,
        MissingField,
        IncompatibleConfig,
        InvalidInput,
    }
}
128
129impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134 fn from(err: FittedModelError) -> Self {
135 gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136 }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140 fn from(err: FittedModelError) -> Self {
141 crate::survival::predict::SurvivalPredictError::ModelPayload {
142 context: "saved-model survival prediction payload",
143 source: err,
144 }
145 }
146}
147
/// Mean / standard-deviation pair used to standardize latent `z` values as
/// `(z - mean) / sd` (see `SavedLatentZNormalization::apply`).
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct SavedLatentZNormalization {
    /// Centering value subtracted from each `z`; must be finite.
    pub mean: f64,
    /// Scale divisor; must be finite and strictly greater than `1e-12`.
    pub sd: f64,
}
153
154impl SavedLatentZNormalization {
155 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156 if !self.mean.is_finite() {
157 return Err(FittedModelError::PayloadCorrupt {
158 reason: format!("{context} latent z mean must be finite"),
159 });
160 }
161 if !(self.sd.is_finite() && self.sd > 1e-12) {
162 return Err(FittedModelError::PayloadCorrupt {
163 reason: format!(
164 "{context} latent z sd must be finite and > 1e-12; got {}",
165 self.sd
166 ),
167 });
168 }
169 Ok::<(), _>(())
170 }
171
172 pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173 self.validate(context)?;
174 if z.iter().any(|value| !value.is_finite()) {
175 return Err(FittedModelError::PayloadCorrupt {
176 reason: format!("{context} requires finite z values"),
177 });
178 }
179 Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180 }
181}
182
/// Default PIT clipping epsilon for transformation-normal score calibration.
/// `TransformationScoreCalibration::validate` accepts values in `(0, 0.5)`.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
185#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
186#[serde(rename_all = "kebab-case")]
187#[derive(Default)]
188pub enum TransformationScoreKind {
189 #[default]
190 FiniteSupportPit,
191}
192
/// Saved calibration settings for transformation-normal PIT scores.
///
/// Both fields fall back to defaults when absent from a serialized payload.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TransformationScoreCalibration {
    /// Score semantics; defaults to `TransformationScoreKind::default()`.
    #[serde(default)]
    pub score_kind: TransformationScoreKind,
    /// PIT clipping epsilon; defaults to `TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
    #[serde(default = "default_transformation_score_pit_clip_eps")]
    pub clip_eps: f64,
}
200
/// Serde default for `TransformationScoreCalibration::clip_eps`.
const fn default_transformation_score_pit_clip_eps() -> f64 {
    TRANSFORMATION_SCORE_PIT_CLIP_EPS
}
204
205impl TransformationScoreCalibration {
206 pub fn finite_support_pit() -> Self {
207 Self {
208 score_kind: TransformationScoreKind::FiniteSupportPit,
209 clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210 }
211 }
212
213 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214 if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215 return Err(FittedModelError::IncompatibleConfig {
216 reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217 });
218 }
219 if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220 return Err(FittedModelError::IncompatibleConfig {
221 reason: format!(
222 "{context} requires PIT clip_eps in (0, 0.5), got {}",
223 self.clip_eps
224 ),
225 });
226 }
227 Ok::<(), _>(())
228 }
229}
230
/// Serializable body of a saved model.
///
/// Holds the formula, model kind and family, the fitted result, and the
/// auxiliary state (term specs, wiggle/time-basis settings, survival and
/// transformation metadata, deployment extensions, ...) read back at
/// prediction time. Apart from the five leading identity fields and `link`,
/// every field carries `#[serde(default)]` and may be absent from a payload.
#[derive(Clone, Serialize, Deserialize)]
pub struct FittedModelPayload {
    /// Payload schema version; `validate_payload_version` requires it to equal
    /// `MODEL_PAYLOAD_VERSION`.
    pub version: u32,
    pub formula: String,
    pub model_kind: ModelKind,
    pub family_state: FittedFamily,
    pub family: String,
    #[serde(default)]
    pub inference_notes: Vec<String>,
    #[serde(default)]
    pub used_device: bool,
    /// Canonical fitted result; location-scale survival validation requires it.
    #[serde(default)]
    pub fit_result: Option<UnifiedFitResult>,
    #[serde(default)]
    pub unified: Option<UnifiedFitResult>,
    #[serde(default)]
    pub spline_scan: Option<SavedSplineScan>,
    #[serde(default)]
    pub residual_cascade: Option<SavedResidualCascade>,
    #[serde(default)]
    pub data_schema: Option<DataSchema>,
    pub link: Option<InverseLink>,
    #[serde(default)]
    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub sas_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub formula_noise: Option<String>,
    #[serde(default)]
    pub formula_logslope: Option<String>,
    #[serde(default)]
    pub formula_logslopes: Option<Vec<String>>,
    #[serde(default)]
    pub offset_column: Option<String>,
    #[serde(default)]
    pub noise_offset_column: Option<String>,
    #[serde(default)]
    pub beta_noise: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub gaussian_response_scale: Option<f64>,
    #[serde(default)]
    pub linkwiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub linkwiggle_degree: Option<usize>,
    #[serde(default)]
    pub beta_link_wiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_degree: Option<usize>,
    #[serde(default)]
    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    pub baseline_timewiggle_double_penalty: Option<bool>,
    #[serde(default)]
    pub beta_baseline_timewiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub z_column: Option<String>,
    #[serde(default)]
    pub z_columns: Option<Vec<String>>,
    #[serde(default)]
    pub latent_z_normalization: Option<SavedLatentZNormalization>,
    #[serde(default)]
    pub latent_score_contract: Option<SavedLatentScoreContract>,
    #[serde(default)]
    pub latent_measure: Option<LatentMeasureKind>,
    #[serde(default)]
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    #[serde(default)]
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    #[serde(default)]
    pub marginal_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baselines: Option<Vec<f64>>,
    #[serde(default)]
    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub influence_absorber_width: Option<usize>,
    #[serde(default)]
    pub survival_entry: Option<String>,
    #[serde(default)]
    pub survival_exit: Option<String>,
    #[serde(default)]
    pub survival_event: Option<String>,
    #[serde(default)]
    pub survivalspec: Option<String>,
    #[serde(default)]
    pub survival_cause_count: Option<usize>,
    #[serde(default)]
    pub survival_endpoint_names: Option<Vec<String>>,
    #[serde(default)]
    pub survival_baseline_target: Option<String>,
    #[serde(default)]
    pub survival_baseline_scale: Option<f64>,
    #[serde(default)]
    pub survival_baseline_shape: Option<f64>,
    #[serde(default)]
    pub survival_baseline_rate: Option<f64>,
    #[serde(default)]
    pub survival_baseline_makeham: Option<f64>,
    /// Time-basis fields below are written together by `apply_survival_time_basis`.
    #[serde(default)]
    pub survival_time_basis: Option<String>,
    #[serde(default)]
    pub survival_time_degree: Option<usize>,
    #[serde(default)]
    pub survival_time_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_time_keep_cols: Option<Vec<usize>>,
    #[serde(default)]
    pub survival_time_smooth_lambda: Option<f64>,
    #[serde(default)]
    pub survival_time_anchor: Option<f64>,
    #[serde(default)]
    pub survivalridge_lambda: Option<f64>,
    #[serde(default)]
    pub survival_likelihood: Option<String>,
    /// When present, must equal the `Time` block coefficients in `fit_result`.
    #[serde(default)]
    pub survival_beta_time: Option<Vec<f64>>,
    /// When present, must equal the `Threshold` block coefficients in `fit_result`.
    #[serde(default)]
    pub survival_beta_threshold: Option<Vec<f64>>,
    /// When present, must equal the `Scale` block coefficients in `fit_result`.
    #[serde(default)]
    pub survival_beta_log_sigma: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub survival_noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub survival_noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub survival_distribution: Option<ResidualDistribution>,
    /// Training column names; used to translate a term's training feature index
    /// into a prediction column by name.
    #[serde(default)]
    pub training_headers: Option<Vec<String>>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_table_kind: Option<String>,
    #[serde(default)]
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub group_metadata: Option<GroupMetadata>,
    /// Append-only extra coefficients consumed by
    /// `append_deployment_extension_columns`; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub deployment_extensions: Vec<SavedDeploymentExtension>,
    #[serde(default)]
    pub transformation_response_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub transformation_response_degree: Option<usize>,
    #[serde(default)]
    pub transformation_response_median: Option<f64>,
    #[serde(default)]
    pub transformation_score_calibration: Option<TransformationScoreCalibration>,
    /// Resolved term specification; required for deployment-extension prediction.
    #[serde(default)]
    pub resolved_termspec: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_noise: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
    #[serde(default)]
    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    #[serde(default)]
    pub gaussian_jackknife_plus:
        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
    #[serde(default)]
    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
}
521
/// A coefficient appended to a fitted model, plus what
/// `append_deployment_extension_columns` needs to build its design column.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedDeploymentExtension {
    /// Extension name, used in error messages.
    pub name: String,
    /// Extension kind; only `"random-effect-level"` is accepted when building designs.
    pub kind: String,
    /// Name of the random-effect term in `resolved_termspec` this extension extends.
    pub term: String,
    /// NOTE(review): the level as JSON; not read in this section.
    pub level: JsonValue,
    /// `f64::to_bits` pattern of the level. Rows whose feature value has exactly
    /// these bits get `1.0` in this extension's design column.
    pub level_bits: u64,
    /// Coefficient position. Extensions must occupy contiguous indices directly
    /// after the base design columns.
    pub coefficient_index: usize,
    /// NOTE(review): presumably the coefficient estimate and its variance; not
    /// read in this section.
    pub coefficient_mean: f64,
    pub coefficient_variance: f64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<JsonValue>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prior: Option<JsonValue>,
}
537
538pub fn append_deployment_extension_columns(
555 model: &FittedModelPayload,
556 data: ndarray::ArrayView2<'_, f64>,
557 col_map: &HashMap<String, usize>,
558 training_headers: Option<&Vec<String>>,
559 base_design: Array2<f64>,
560) -> Result<Array2<f64>, FittedModelError> {
561 if model.deployment_extensions.is_empty() {
562 return Ok(base_design);
563 }
564 if base_design.nrows() != data.nrows() {
565 return Err(FittedModelError::SchemaMismatch {
566 reason: format!(
567 "deployment extension design row mismatch: base design has {} rows but data has {}",
568 base_design.nrows(),
569 data.nrows()
570 ),
571 });
572 }
573 let spec = model
574 .resolved_termspec
575 .as_ref()
576 .ok_or_else(|| FittedModelError::MissingField {
577 reason: "deployment extension prediction requires saved resolved_termspec; refit"
578 .to_string(),
579 })?;
580 let n = base_design.nrows();
581 let p_old = base_design.ncols();
582 let mut extensions: Vec<&SavedDeploymentExtension> =
583 model.deployment_extensions.iter().collect();
584 extensions.sort_by_key(|extension| extension.coefficient_index);
585 for (tail_idx, extension) in extensions.iter().enumerate() {
586 let expected = p_old + tail_idx;
587 if extension.coefficient_index != expected {
588 return Err(FittedModelError::SchemaMismatch {
589 reason: format!(
590 "deployment extension '{}' has coefficient index {}, expected append-only index {}",
591 extension.name, extension.coefficient_index, expected
592 ),
593 });
594 }
595 }
596
597 let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
598 out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
599 for (tail_idx, extension) in extensions.into_iter().enumerate() {
600 if extension.kind != "random-effect-level" {
601 return Err(FittedModelError::IncompatibleConfig {
602 reason: format!(
603 "unsupported deployment extension kind '{}' for '{}'",
604 extension.kind, extension.name
605 ),
606 });
607 }
608 let term = spec
609 .random_effect_terms
610 .iter()
611 .find(|term| term.name == extension.term)
612 .ok_or_else(|| FittedModelError::MissingField {
613 reason: format!(
614 "deployment extension '{}' references unknown random-effect term '{}'",
615 extension.name, extension.term
616 ),
617 })?;
618 let prediction_col = training_headers
619 .and_then(|headers| headers.get(term.feature_col))
620 .and_then(|name| col_map.get(name))
621 .copied()
622 .unwrap_or(term.feature_col);
623 if prediction_col >= data.ncols() {
624 return Err(FittedModelError::SchemaMismatch {
625 reason: format!(
626 "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
627 extension.name,
628 prediction_col,
629 data.ncols()
630 ),
631 });
632 }
633 let col = p_old + tail_idx;
634 for row in 0..n {
635 if data[[row, prediction_col]].to_bits() == extension.level_bits {
636 out[[row, col]] = 1.0;
637 }
638 }
639 }
640 Ok(out)
641}
642
/// Saved description of the latent score a model was trained against.
///
/// NOTE(review): no field here is read in this section. The field notes below
/// are inferred from names and should be confirmed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedLatentScoreContract {
    /// Presumably a tag naming the score semantics.
    pub semantics: String,
    /// Presumably identifies the transform that produced the scores.
    pub source_transform_id: Option<String>,
    /// Presumably mirror `SavedLatentZNormalization::{mean, sd}`.
    pub normalization_mean: f64,
    pub normalization_sd: f64,
    /// Presumably a PIT clipping epsilon, when one applies.
    pub clip_eps: Option<f64>,
    /// Presumably the columns the score was conditioned on.
    pub conditioning_columns: Vec<String>,
}
652
653impl FittedModelPayload {
654 pub fn new(
655 version: u32,
656 formula: String,
657 model_kind: ModelKind,
658 family_state: FittedFamily,
659 family: String,
660 ) -> Self {
661 Self {
662 version,
663 formula,
664 model_kind,
665 family_state,
666 family,
667 inference_notes: Vec::new(),
668 used_device: false,
669 fit_result: None,
670 unified: None,
671 spline_scan: None,
672 residual_cascade: None,
673 data_schema: None,
674 link: None,
675 mixture_link_param_covariance: None,
676 sas_param_covariance: None,
677 formula_noise: None,
678 formula_logslope: None,
679 formula_logslopes: None,
680 offset_column: None,
681 noise_offset_column: None,
682 beta_noise: None,
683 noise_projection: None,
684 noise_center: None,
685 noise_scale: None,
686 noise_non_intercept_start: None,
687 noise_projection_ridge_alpha: None,
688 gaussian_response_scale: None,
689 linkwiggle_knots: None,
690 linkwiggle_degree: None,
691 beta_link_wiggle: None,
692 baseline_timewiggle_knots: None,
693 baseline_timewiggle_degree: None,
694 baseline_timewiggle_penalty_orders: None,
695 baseline_timewiggle_double_penalty: None,
696 beta_baseline_timewiggle: None,
697 beta_baseline_timewiggle_by_cause: None,
698 z_column: None,
699 z_columns: None,
700 latent_z_normalization: None,
701 latent_score_contract: None,
702 latent_measure: None,
703 latent_z_rank_int_calibration: None,
704 latent_z_conditional_calibration: None,
705 marginal_baseline: None,
706 logslope_baseline: None,
707 logslope_baselines: None,
708 score_warp_runtime: None,
709 link_deviation_runtime: None,
710 influence_absorber_width: None,
711 survival_entry: None,
712 survival_exit: None,
713 survival_event: None,
714 survivalspec: None,
715 survival_cause_count: None,
716 survival_endpoint_names: None,
717 survival_baseline_target: None,
718 survival_baseline_scale: None,
719 survival_baseline_shape: None,
720 survival_baseline_rate: None,
721 survival_baseline_makeham: None,
722 survival_time_basis: None,
723 survival_time_degree: None,
724 survival_time_knots: None,
725 survival_time_keep_cols: None,
726 survival_time_smooth_lambda: None,
727 survival_time_anchor: None,
728 survivalridge_lambda: None,
729 survival_likelihood: None,
730 survival_beta_time: None,
731 survival_beta_threshold: None,
732 survival_beta_log_sigma: None,
733 survival_noise_projection: None,
734 survival_noise_center: None,
735 survival_noise_scale: None,
736 survival_noise_non_intercept_start: None,
737 survival_noise_projection_ridge_alpha: None,
738 survival_distribution: None,
739 training_headers: None,
740 training_table_kind: None,
741 training_feature_ranges: None,
742 group_metadata: None,
743 deployment_extensions: Vec::new(),
744 transformation_response_knots: None,
745 transformation_response_transform: None,
746 transformation_response_degree: None,
747 transformation_response_median: None,
748 transformation_score_calibration: None,
749 resolved_termspec: None,
750 resolved_termspec_noise: None,
751 resolved_termspec_logslope: None,
752 resolved_termspec_logslopes: None,
753 adaptive_regularization_diagnostics: None,
754 gaussian_jackknife_plus: None,
755 full_conformal: None,
756 }
757 }
758
759 pub fn set_training_feature_metadata(
760 &mut self,
761 headers: Vec<String>,
762 feature_ranges: Vec<(f64, f64)>,
763 ) {
764 self.training_headers = Some(headers);
765 self.training_feature_ranges = Some(feature_ranges);
766 }
767
768 fn synchronize_empty_feature_contract(&mut self) {
769 if self.fit_result.is_none() {
770 return;
771 }
772 let Some(schema) = self.data_schema.as_ref() else {
773 return;
774 };
775 if !schema.columns.is_empty() {
776 return;
777 }
778 self.training_headers.get_or_insert_with(Vec::new);
779 self.resolved_termspec
780 .get_or_insert_with(|| TermCollectionSpec {
781 linear_terms: Vec::new(),
782 smooth_terms: Vec::new(),
783 random_effect_terms: Vec::new(),
784 });
785 }
786
787 pub fn apply_survival_time_basis(
795 &mut self,
796 snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
797 ) {
798 self.survival_time_basis = Some(snapshot.basisname.clone());
799 self.survival_time_degree = snapshot.degree;
800 self.survival_time_knots = snapshot.knots.clone();
801 self.survival_time_keep_cols = snapshot.keep_cols.clone();
802 self.survival_time_smooth_lambda = snapshot.smooth_lambda;
803 self.survival_time_anchor = Some(snapshot.anchor);
804 }
805
806 fn validate_payload_version(&self) -> Result<(), FittedModelError> {
807 if self.version != MODEL_PAYLOAD_VERSION {
808 return Err(FittedModelError::SchemaMismatch {
809 reason: format!(
810 "saved model payload schema mismatch: file has version={}, \
811 this binary expects MODEL_PAYLOAD_VERSION={}. \
812 Refit with the current CLI, or rebuild the reader at the same \
813 version the model was written with.",
814 self.version, MODEL_PAYLOAD_VERSION
815 ),
816 });
817 }
818 Ok(())
819 }
820}
821
/// Top-level saved-model envelope: one variant per model family, each wrapping
/// the shared [`FittedModelPayload`].
///
/// Serialized with serde internal tagging. The variant name is stored in a
/// `"model_type"` field in kebab-case (e.g. `"location-scale"`).
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    Standard { payload: FittedModelPayload },
    LocationScale { payload: FittedModelPayload },
    MarginalSlope { payload: FittedModelPayload },
    Survival { payload: FittedModelPayload },
    TransformationNormal { payload: FittedModelPayload },
}
831
/// Model kind recorded in `FittedModelPayload::model_kind`; serialized in
/// kebab-case (e.g. `"marginal-slope"`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    Standard,
    LocationScale,
    MarginalSlope,
    Survival,
    TransformationNormal,
}
841
/// Saved family/likelihood state of a fitted model.
///
/// Serialized with serde internal tagging. The variant name is stored in a
/// `"family_kind"` field in kebab-case (e.g. `"latent-survival"`). Fields
/// marked `#[serde(default)]` may be absent from a payload.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Standard GLM-style family with optional link and link-parameter states.
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Carries only a frailty specification (no likelihood field).
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Carries only a frailty specification (no likelihood field).
    LatentBinary {
        frailty: FrailtySpec,
    },
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
884
/// Prediction-time model class, used to select the saved-fit layout and
/// validation applied to a model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    Standard,
    GaussianLocationScale,
    BinomialLocationScale,
    DispersionLocationScale,
    BernoulliMarginalSlope,
    Survival,
    TransformationNormal,
}

impl PredictModelClass {
    /// Short human-readable label for the class (usable in `const` contexts).
    #[inline]
    pub const fn name(self) -> &'static str {
        let label = match self {
            PredictModelClass::Standard => "standard",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
        };
        label
    }
}
914
/// Prediction-time link-wiggle settings and coefficients.
///
/// In location-scale validation, `beta` must match the fit's `LinkWiggle`
/// block, and its length counts toward the expected covariance dimension.
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub beta: Vec<f64>,
}
921
/// Prediction-time baseline time-wiggle settings and coefficients.
/// NOTE(review): presumably assembled from the payload's
/// `baseline_timewiggle_*` / `beta_baseline_timewiggle` fields — confirm.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub penalty_orders: Vec<usize>,
    pub double_penalty: bool,
    pub beta: Vec<f64>,
}
930
931pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
934
/// Compiled, serializable form of a flexible (score-warp or link-deviation)
/// block used at prediction time.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
    /// NOTE(review): kernel identifier; its accepted values are not visible here.
    pub kernel: String,
    /// NOTE(review): presumably the span boundaries (cf. `span_index_for_breakpoints`).
    pub breakpoints: Vec<f64>,
    /// Number of basis coefficients. Marginal-slope validation requires the
    /// score-warp block's `beta` length to equal this.
    pub basis_dim: usize,
    /// NOTE(review): presumably per-span polynomial coefficients of order 0..=3.
    pub span_c0: Vec<Vec<f64>>,
    pub span_c1: Vec<Vec<f64>>,
    pub span_c2: Vec<Vec<f64>>,
    pub span_c3: Vec<Vec<f64>>,
    #[serde(default)]
    pub anchor_correction: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub anchor_components: Vec<SavedAnchorComponent>,
}
957
/// One anchor component of a [`SavedCompiledFlexBlock`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
    pub kind: SavedAnchorKind,
}
962
/// Kind of a saved anchor component. Each variant records its column count `ncols`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
    /// Anchor backed by a parametric block (re-exported `ParametricAnchorBlock`).
    Parametric {
        block: ParametricAnchorBlock,
        ncols: usize,
    },
    /// NOTE(review): presumably an anchor evaluated from the flex block itself.
    FlexEvaluation { ncols: usize },
}
975
/// Prediction-ready view of a saved model: its class, likelihood, link, and
/// any optional wiggle / flex / calibration components.
/// NOTE(review): presumably assembled from a `FittedModelPayload` — confirm.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
    pub model_class: PredictModelClass,
    pub likelihood: LikelihoodSpec,
    pub inverse_link: Option<InverseLink>,
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
    pub score_warp: Option<SavedCompiledFlexBlock>,
    pub link_deviation: Option<SavedCompiledFlexBlock>,
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    pub influence_absorber_width: Option<usize>,
}
1006
1007pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1008 fit.block_by_role(BlockRole::Location)
1009 .or_else(|| fit.block_by_role(BlockRole::Mean))
1010 .map(|block| block.beta.clone())
1011}
1012
1013pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1014 fit.block_by_role(BlockRole::Threshold)
1015 .or_else(|| fit.block_by_role(BlockRole::Location))
1016 .or_else(|| fit.block_by_role(BlockRole::Mean))
1017 .map(|block| block.beta.clone())
1018}
1019
1020pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021 fit.block_by_role(BlockRole::Scale)
1022 .map(|block| block.beta.clone())
1023}
1024
1025fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1034 use gam_problem::types::ResponseFamily;
1035 matches!(
1036 response,
1037 ResponseFamily::NegativeBinomial { .. }
1038 | ResponseFamily::Gamma
1039 | ResponseFamily::Beta { .. }
1040 | ResponseFamily::Tweedie { .. }
1041 )
1042}
1043
1044fn validate_location_scale_saved_fit(
1045 fit: &UnifiedFitResult,
1046 model_class: PredictModelClass,
1047 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1048) -> Result<(), FittedModelError> {
1049 let primary = match model_class {
1050 PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1054 gaussian_location_scale_mean_beta(fit)
1055 }
1056 PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1057 _ => None,
1058 }
1059 .ok_or_else(|| FittedModelError::MissingField {
1060 reason: match model_class {
1061 PredictModelClass::GaussianLocationScale => {
1062 "gaussian-location-scale saved fit is missing mean/location block".to_string()
1063 }
1064 PredictModelClass::DispersionLocationScale => {
1065 "dispersion-location-scale saved fit is missing mean/location block".to_string()
1066 }
1067 PredictModelClass::BinomialLocationScale => {
1068 "binomial-location-scale saved fit is missing threshold/location block".to_string()
1069 }
1070 _ => "location-scale saved fit is missing primary block".to_string(),
1071 },
1072 })?;
1073
1074 let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1075 reason: "location-scale saved fit is missing scale block".to_string(),
1076 })?;
1077 let expected =
1078 primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1079
1080 if let Some(cov) = fit.beta_covariance()
1081 && (cov.nrows() != expected || cov.ncols() != expected)
1082 {
1083 return Err(FittedModelError::SchemaMismatch {
1084 reason: format!(
1085 "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1086 cov.nrows(),
1087 cov.ncols(),
1088 expected,
1089 expected
1090 ),
1091 });
1092 }
1093 if let Some(cov) = fit.beta_covariance_corrected()
1094 && (cov.nrows() != expected || cov.ncols() != expected)
1095 {
1096 return Err(FittedModelError::SchemaMismatch {
1097 reason: format!(
1098 "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1099 cov.nrows(),
1100 cov.ncols(),
1101 expected,
1102 expected
1103 ),
1104 });
1105 }
1106 Ok(())
1107}
1108
1109fn validate_survival_saved_block_matches_payload(
1110 fit: &UnifiedFitResult,
1111 role: BlockRole,
1112 payload_beta: Option<&Vec<f64>>,
1113 label: &str,
1114) -> Result<usize, FittedModelError> {
1115 let block = fit
1116 .block_by_role(role)
1117 .ok_or_else(|| FittedModelError::MissingField {
1118 reason: format!("location-scale survival saved fit is missing {label} block"),
1119 })?;
1120 if let Some(saved) = payload_beta
1121 && block.beta.to_vec() != *saved
1122 {
1123 return Err(FittedModelError::SchemaMismatch {
1124 reason: format!(
1125 "location-scale survival saved {label} coefficients disagree with fit_result"
1126 ),
1127 });
1128 }
1129 Ok(block.beta.len())
1130}
1131
1132fn validate_survival_location_scale_saved_fit(
1133 payload: &FittedModelPayload,
1134 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1135) -> Result<(), FittedModelError> {
1136 let fit = payload
1137 .fit_result
1138 .as_ref()
1139 .ok_or_else(|| FittedModelError::MissingField {
1140 reason: "location-scale survival model is missing canonical fit_result payload"
1141 .to_string(),
1142 })?;
1143 let p_time = validate_survival_saved_block_matches_payload(
1144 fit,
1145 BlockRole::Time,
1146 payload.survival_beta_time.as_ref(),
1147 "time",
1148 )?;
1149 let p_threshold = validate_survival_saved_block_matches_payload(
1150 fit,
1151 BlockRole::Threshold,
1152 payload.survival_beta_threshold.as_ref(),
1153 "threshold",
1154 )?;
1155 let p_log_sigma = validate_survival_saved_block_matches_payload(
1156 fit,
1157 BlockRole::Scale,
1158 payload.survival_beta_log_sigma.as_ref(),
1159 "log-sigma",
1160 )?;
1161 let p_wiggle = match link_wiggle {
1162 Some(runtime) => {
1163 let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1164 FittedModelError::MissingField {
1165 reason: "location-scale survival saved fit is missing link-wiggle block"
1166 .to_string(),
1167 }
1168 })?;
1169 if block.beta.to_vec() != runtime.beta {
1170 return Err(FittedModelError::SchemaMismatch {
1171 reason:
1172 "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1173 .to_string(),
1174 });
1175 }
1176 runtime.beta.len()
1177 }
1178 None => {
1179 if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1180 return Err(FittedModelError::SchemaMismatch {
1181 reason:
1182 "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1183 .to_string(),
1184 });
1185 }
1186 0
1187 }
1188 };
1189 let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1190
1191 if let Some(cov) = fit.beta_covariance()
1192 && (cov.nrows() != expected || cov.ncols() != expected)
1193 {
1194 return Err(FittedModelError::SchemaMismatch {
1195 reason: format!(
1196 "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1197 cov.nrows(),
1198 cov.ncols(),
1199 expected,
1200 expected
1201 ),
1202 });
1203 }
1204 if let Some(cov) = fit.beta_covariance_corrected()
1205 && (cov.nrows() != expected || cov.ncols() != expected)
1206 {
1207 return Err(FittedModelError::SchemaMismatch {
1208 reason: format!(
1209 "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1210 cov.nrows(),
1211 cov.ncols(),
1212 expected,
1213 expected
1214 ),
1215 });
1216 }
1217 Ok(())
1218}
1219
1220fn validate_marginal_slope_saved_fit(
1221 fit: &UnifiedFitResult,
1222 score_warp: Option<&SavedCompiledFlexBlock>,
1223 link_deviation: Option<&SavedCompiledFlexBlock>,
1224 fit_label: &str,
1225) -> Result<(), FittedModelError> {
1226 validate_marginal_slope_saved_fit_impl(
1227 fit,
1228 score_warp,
1229 link_deviation,
1230 fit_label,
1231 "bernoulli",
1232 2,
1233 "marginal, logslope",
1234 )
1235}
1236
1237fn validate_survival_marginal_slope_saved_fit(
1238 fit: &UnifiedFitResult,
1239 score_warp: Option<&SavedCompiledFlexBlock>,
1240 link_deviation: Option<&SavedCompiledFlexBlock>,
1241 fit_label: &str,
1242) -> Result<(), FittedModelError> {
1243 validate_marginal_slope_saved_fit_impl(
1244 fit,
1245 score_warp,
1246 link_deviation,
1247 fit_label,
1248 "survival",
1249 3,
1250 "time, marginal, slope",
1251 )
1252}
1253
1254fn validate_marginal_slope_saved_fit_impl(
1262 fit: &UnifiedFitResult,
1263 score_warp: Option<&SavedCompiledFlexBlock>,
1264 link_deviation: Option<&SavedCompiledFlexBlock>,
1265 fit_label: &str,
1266 family_kind: &str,
1267 base_block_count: usize,
1268 base_block_role_list: &str,
1269) -> Result<(), FittedModelError> {
1270 let expected_blocks = base_block_count
1271 + usize::from(score_warp.is_some())
1272 + usize::from(link_deviation.is_some());
1273 if fit.blocks.len() != expected_blocks {
1274 let score_warp_suffix = if score_warp.is_some() {
1275 ", score-warp"
1276 } else {
1277 ""
1278 };
1279 let link_deviation_suffix = if link_deviation.is_some() {
1280 ", link-deviation"
1281 } else {
1282 ""
1283 };
1284 return Err(FittedModelError::SchemaMismatch {
1285 reason: format!(
1286 "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1287 fit.blocks.len(),
1288 ),
1289 });
1290 }
1291 if let Some(runtime) = score_warp {
1292 let beta = &fit.blocks[base_block_count].beta;
1293 if beta.len() != runtime.basis_dim {
1294 return Err(FittedModelError::SchemaMismatch {
1295 reason: format!(
1296 "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1297 beta.len(),
1298 runtime.basis_dim
1299 ),
1300 });
1301 }
1302 }
1303 if let Some(runtime) = link_deviation {
1304 let idx = base_block_count + usize::from(score_warp.is_some());
1305 let beta = &fit.blocks[idx].beta;
1306 if beta.len() != runtime.basis_dim {
1307 return Err(FittedModelError::SchemaMismatch {
1308 reason: format!(
1309 "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1310 beta.len(),
1311 runtime.basis_dim
1312 ),
1313 });
1314 }
1315 }
1316 Ok(())
1317}
1318
1319impl SavedLinkWiggleRuntime {
1320 fn validate_monotone_derivative(
1321 &self,
1322 q0: &Array1<f64>,
1323 ) -> Result<Array1<f64>, FittedModelError> {
1324 let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1331 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1332 let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1333 if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1334 return Err(FittedModelError::PayloadCorrupt {
1335 reason: format!(
1336 "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1337 ),
1338 });
1339 }
1340 Ok(dq_dq0)
1341 }
1342
1343 pub fn constrained_basis(
1344 &self,
1345 q0: &Array1<f64>,
1346 basis_options: BasisOptions,
1347 ) -> Result<Array2<f64>, FittedModelError> {
1348 let knot_arr = Array1::from_vec(self.knots.clone());
1349 let constrained = monotone_wiggle_basis_with_derivative_order(
1350 q0.view(),
1351 &knot_arr,
1352 self.degree,
1353 basis_options.derivative_order,
1354 )
1355 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1356 if constrained.ncols() != self.beta.len() {
1357 return Err(FittedModelError::SchemaMismatch {
1358 reason: format!(
1359 "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1360 self.beta.len(),
1361 constrained.ncols()
1362 ),
1363 });
1364 }
1365 Ok(constrained)
1366 }
1367
1368 pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1369 self.validate_monotone_derivative(q0)?;
1370 self.constrained_basis(q0, BasisOptions::value())
1371 }
1372
1373 pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1374 let q = Array1::from_vec(vec![q0]);
1375 let x = self.design(&q)?;
1376 if x.nrows() != 1 {
1377 return Err(FittedModelError::SchemaMismatch {
1378 reason: format!(
1379 "saved link-wiggle scalar evaluation expected 1 row, got {}",
1380 x.nrows()
1381 ),
1382 });
1383 }
1384 Ok(x.row(0).to_owned())
1385 }
1386
1387 pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1388 self.validate_monotone_derivative(q0)?;
1389 let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1390 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1391 Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1392 }
1393
1394 pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395 self.validate_monotone_derivative(q0)
1396 }
1397}
1398
1399impl SavedBaselineTimeWiggleRuntime {
1400 pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1401 validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1402 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1403 }
1404}
1405
1406impl SavedCompiledFlexBlock {
1407 pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1408 if self.kernel.is_empty() {
1409 return Err(FittedModelError::SchemaMismatch {
1410 reason: "saved anchored deviation runtime is missing the exact kernel marker"
1411 .to_string(),
1412 });
1413 }
1414 if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1415 return Err(FittedModelError::IncompatibleConfig {
1416 reason: format!(
1417 "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1418 self.kernel,
1419 crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1420 ),
1421 });
1422 }
1423 if self.basis_dim == 0 {
1424 return Err(FittedModelError::SchemaMismatch {
1425 reason: format!(
1426 "saved anchored deviation runtime basis_dim must be positive, got {}",
1427 self.basis_dim
1428 ),
1429 });
1430 }
1431 if self.breakpoints.len() < 2 {
1432 return Err(FittedModelError::SchemaMismatch {
1433 reason: format!(
1434 "saved anchored deviation runtime requires at least two breakpoints, got {}",
1435 self.breakpoints.len()
1436 ),
1437 });
1438 }
1439 for window in self.breakpoints.windows(2) {
1440 let left = window[0];
1441 let right = window[1];
1442 if !left.is_finite() || !right.is_finite() || right <= left {
1443 return Err(FittedModelError::PayloadCorrupt {
1444 reason: format!(
1445 "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1446 ),
1447 });
1448 }
1449 }
1450 let span_count = self.breakpoints.len() - 1;
1451 self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1452 self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1453 self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1454 self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1455 self.validate_c2_span_continuity()?;
1456 self.validate_anchor_residual_shape()?;
1457 Ok(())
1458 }
1459
1460 fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1461 let coeffs = match self.anchor_correction.as_ref() {
1462 Some(c) => c,
1463 None => {
1464 if !self.anchor_components.is_empty() {
1465 return Err(FittedModelError::SchemaMismatch {
1466 reason:
1467 "saved anchored deviation runtime has anchor_components but no anchor_correction"
1468 .to_string(),
1469 });
1470 }
1471 return Ok(());
1472 }
1473 };
1474 let d: usize = self
1475 .anchor_components
1476 .iter()
1477 .map(|c| match &c.kind {
1478 SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1479 SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1480 })
1481 .sum();
1482 if coeffs.len() != d {
1483 return Err(FittedModelError::SchemaMismatch {
1484 reason: format!(
1485 "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1486 coeffs.len(),
1487 d,
1488 ),
1489 });
1490 }
1491 for (i, row) in coeffs.iter().enumerate() {
1492 if row.len() != self.basis_dim {
1493 return Err(FittedModelError::SchemaMismatch {
1494 reason: format!(
1495 "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1496 i,
1497 row.len(),
1498 self.basis_dim,
1499 ),
1500 });
1501 }
1502 for (j, &v) in row.iter().enumerate() {
1503 if !v.is_finite() {
1504 return Err(FittedModelError::PayloadCorrupt {
1505 reason: format!(
1506 "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1507 ),
1508 });
1509 }
1510 }
1511 }
1512 Ok(())
1513 }
1514
1515 fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1516 const TOL: f64 = 1e-8;
1517 for span_idx in 1..self.breakpoints.len() - 1 {
1518 let left_span = span_idx - 1;
1519 let right_span = span_idx;
1520 let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1521 for basis_idx in 0..self.basis_dim {
1522 let left_value = self.span_c0[left_span][basis_idx]
1523 + self.span_c1[left_span][basis_idx] * width
1524 + self.span_c2[left_span][basis_idx] * width * width
1525 + self.span_c3[left_span][basis_idx] * width * width * width;
1526 let left_d1 = self.span_c1[left_span][basis_idx]
1527 + 2.0 * self.span_c2[left_span][basis_idx] * width
1528 + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1529 let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1530 + 6.0 * self.span_c3[left_span][basis_idx] * width;
1531 let right_value = self.span_c0[right_span][basis_idx];
1532 let right_d1 = self.span_c1[right_span][basis_idx];
1533 let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1534 if (left_value - right_value).abs() > TOL
1535 || (left_d1 - right_d1).abs() > TOL
1536 || (left_d2 - right_d2).abs() > TOL
1537 {
1538 return Err(FittedModelError::SchemaMismatch {
1539 reason: format!(
1540 "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1541 left_value - right_value,
1542 left_d1 - right_d1,
1543 left_d2 - right_d2
1544 ),
1545 });
1546 }
1547 }
1548 }
1549 Ok(())
1550 }
1551
1552 fn validate_coefficient_matrix(
1553 &self,
1554 matrix: &[Vec<f64>],
1555 label: &str,
1556 expected_rows: usize,
1557 ) -> Result<(), FittedModelError> {
1558 if matrix.len() != expected_rows {
1559 return Err(FittedModelError::SchemaMismatch {
1560 reason: format!(
1561 "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1562 matrix.len(),
1563 expected_rows
1564 ),
1565 });
1566 }
1567 for (row_idx, row) in matrix.iter().enumerate() {
1568 if row.len() != self.basis_dim {
1569 return Err(FittedModelError::SchemaMismatch {
1570 reason: format!(
1571 "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1572 row_idx,
1573 row.len(),
1574 self.basis_dim
1575 ),
1576 });
1577 }
1578 for (j, &value) in row.iter().enumerate() {
1579 if !value.is_finite() {
1580 return Err(FittedModelError::PayloadCorrupt {
1581 reason: format!(
1582 "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1583 ),
1584 });
1585 }
1586 }
1587 }
1588 Ok(())
1589 }
1590
1591 fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1592 let last_span = self.breakpoints.len() - 2;
1593 let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1594 self.span_c0[last_span][basis_idx]
1595 + self.span_c1[last_span][basis_idx] * width
1596 + self.span_c2[last_span][basis_idx] * width * width
1597 + self.span_c3[last_span][basis_idx] * width * width * width
1598 }
1599
1600 fn evaluate_span_polynomial_design(
1601 &self,
1602 values: &Array1<f64>,
1603 derivative_order: usize,
1604 ) -> Result<Array2<f64>, FittedModelError> {
1605 self.validate_exact_replay_contract()?;
1606 let (left_ep, right_ep) = self.support_interval()?;
1607 let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1608 for (row_idx, &value) in values.iter().enumerate() {
1609 if !value.is_finite() {
1610 return Err(FittedModelError::PayloadCorrupt {
1611 reason: format!(
1612 "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1613 ),
1614 });
1615 }
1616 if value < left_ep {
1617 if derivative_order == 0 {
1618 for basis_idx in 0..self.basis_dim {
1619 out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1620 }
1621 }
1622 continue;
1623 }
1624 if value > right_ep {
1625 if derivative_order == 0 {
1626 for basis_idx in 0..self.basis_dim {
1627 out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1628 }
1629 }
1630 continue;
1631 }
1632 let span_idx = self.left_biased_span_index_for(value)?;
1633 let t = value - self.breakpoints[span_idx];
1634 for basis_idx in 0..self.basis_dim {
1635 let c0 = self.span_c0[span_idx][basis_idx];
1636 let c1 = self.span_c1[span_idx][basis_idx];
1637 let c2 = self.span_c2[span_idx][basis_idx];
1638 let c3 = self.span_c3[span_idx][basis_idx];
1639 out[[row_idx, basis_idx]] = match derivative_order {
1640 0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1641 1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1642 2 => 2.0 * c2 + 6.0 * c3 * t,
1643 3 => 6.0 * c3,
1644 4 => 0.0,
1645 other => {
1646 return Err(FittedModelError::IncompatibleConfig {
1647 reason: format!(
1648 "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1649 ),
1650 });
1651 }
1652 };
1653 }
1654 }
1655 Ok(out)
1656 }
1657
1658 pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1659 self.validate_exact_replay_contract()?;
1660 Ok(self.breakpoints.clone())
1661 }
1662
1663 pub fn span_count(&self) -> Result<usize, FittedModelError> {
1664 Ok(self.breakpoints()?.windows(2).count())
1665 }
1666
1667 pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1668 let points = self.breakpoints()?;
1669 span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1670 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1671 }
1672
1673 fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1674 let mut span_idx = span_index_for_breakpoints(
1675 &self.breakpoints,
1676 value,
1677 "saved anchored deviation span lookup",
1678 )
1679 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1680 if span_idx > 0 && value == self.breakpoints[span_idx] {
1683 span_idx -= 1;
1684 }
1685 Ok(span_idx)
1686 }
1687
1688 pub fn local_cubic_on_span(
1689 &self,
1690 beta: &Array1<f64>,
1691 span_idx: usize,
1692 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1693 self.validate_exact_replay_contract()?;
1694 if beta.len() != self.basis_dim {
1695 return Err(FittedModelError::SchemaMismatch {
1696 reason: format!(
1697 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1698 beta.len(),
1699 self.basis_dim
1700 ),
1701 });
1702 }
1703 self.local_cubic_on_span_validated(beta, span_idx)
1704 }
1705
1706 fn local_cubic_on_span_validated(
1707 &self,
1708 beta: &Array1<f64>,
1709 span_idx: usize,
1710 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1711 let points = &self.breakpoints;
1712 if span_idx + 1 >= points.len() {
1713 return Err(FittedModelError::SchemaMismatch {
1714 reason: format!(
1715 "saved anchored deviation span index {} out of range for {} spans",
1716 span_idx,
1717 points.len() - 1
1718 ),
1719 });
1720 }
1721 let left = points[span_idx];
1722 let right = points[span_idx + 1];
1723 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1724 left,
1725 right,
1726 c0: self.span_c0[span_idx]
1727 .iter()
1728 .zip(beta.iter())
1729 .map(|(coeff, weight)| coeff * weight)
1730 .sum(),
1731 c1: self.span_c1[span_idx]
1732 .iter()
1733 .zip(beta.iter())
1734 .map(|(coeff, weight)| coeff * weight)
1735 .sum(),
1736 c2: self.span_c2[span_idx]
1737 .iter()
1738 .zip(beta.iter())
1739 .map(|(coeff, weight)| coeff * weight)
1740 .sum(),
1741 c3: self.span_c3[span_idx]
1742 .iter()
1743 .zip(beta.iter())
1744 .map(|(coeff, weight)| coeff * weight)
1745 .sum(),
1746 })
1747 }
1748
1749 pub fn basis_span_cubic(
1750 &self,
1751 span_idx: usize,
1752 basis_idx: usize,
1753 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1754 self.validate_exact_replay_contract()?;
1755 if basis_idx >= self.basis_dim {
1756 return Err(FittedModelError::SchemaMismatch {
1757 reason: format!(
1758 "saved anchored deviation basis index {} out of range for {} coefficients",
1759 basis_idx, self.basis_dim
1760 ),
1761 });
1762 }
1763 self.basis_span_cubic_validated(span_idx, basis_idx)
1764 }
1765
1766 fn basis_span_cubic_validated(
1767 &self,
1768 span_idx: usize,
1769 basis_idx: usize,
1770 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1771 let points = &self.breakpoints;
1772 if span_idx + 1 >= points.len() {
1773 return Err(FittedModelError::SchemaMismatch {
1774 reason: format!(
1775 "saved anchored deviation span index {} out of range for {} spans",
1776 span_idx,
1777 points.len() - 1
1778 ),
1779 });
1780 }
1781 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1782 left: points[span_idx],
1783 right: points[span_idx + 1],
1784 c0: self.span_c0[span_idx][basis_idx],
1785 c1: self.span_c1[span_idx][basis_idx],
1786 c2: self.span_c2[span_idx][basis_idx],
1787 c3: self.span_c3[span_idx][basis_idx],
1788 })
1789 }
1790
1791 pub fn basis_cubic_at(
1792 &self,
1793 basis_idx: usize,
1794 value: f64,
1795 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1796 self.validate_exact_replay_contract()?;
1797 if basis_idx >= self.basis_dim {
1798 return Err(FittedModelError::SchemaMismatch {
1799 reason: format!(
1800 "saved anchored deviation basis index {} out of range for {} coefficients",
1801 basis_idx, self.basis_dim
1802 ),
1803 });
1804 }
1805 let (left_ep, right_ep) = self.support_interval()?;
1806 if value < left_ep {
1807 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1808 left: left_ep,
1809 right: left_ep + 1.0,
1810 c0: self.span_c0[0][basis_idx],
1811 c1: 0.0,
1812 c2: 0.0,
1813 c3: 0.0,
1814 });
1815 }
1816 if value > right_ep {
1817 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1818 left: right_ep,
1819 right: right_ep + 1.0,
1820 c0: self.right_boundary_basis_value(basis_idx),
1821 c1: 0.0,
1822 c2: 0.0,
1823 c3: 0.0,
1824 });
1825 }
1826 let span_idx = self.left_biased_span_index_for(value)?;
1827 self.basis_span_cubic_validated(span_idx, basis_idx)
1828 }
1829
1830 pub fn local_cubic_at(
1831 &self,
1832 beta: &Array1<f64>,
1833 value: f64,
1834 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1835 self.validate_exact_replay_contract()?;
1836 if beta.len() != self.basis_dim {
1837 return Err(FittedModelError::SchemaMismatch {
1838 reason: format!(
1839 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1840 beta.len(),
1841 self.basis_dim
1842 ),
1843 });
1844 }
1845 let (left_ep, right_ep) = self.support_interval()?;
1846 if value < left_ep {
1847 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1848 left: left_ep,
1849 right: left_ep + 1.0,
1850 c0: self.span_c0[0]
1851 .iter()
1852 .zip(beta.iter())
1853 .map(|(coeff, weight)| coeff * weight)
1854 .sum(),
1855 c1: 0.0,
1856 c2: 0.0,
1857 c3: 0.0,
1858 });
1859 }
1860 if value > right_ep {
1861 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1862 left: right_ep,
1863 right: right_ep + 1.0,
1864 c0: (0..self.basis_dim)
1865 .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1866 .sum(),
1867 c1: 0.0,
1868 c2: 0.0,
1869 c3: 0.0,
1870 });
1871 }
1872 let span_idx = self.left_biased_span_index_for(value)?;
1873 self.local_cubic_on_span_validated(beta, span_idx)
1874 }
1875
1876 fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1877 let points = self.breakpoints()?;
1878 match (points.first(), points.last()) {
1879 (Some(&left), Some(&right)) => Ok((left, right)),
1880 _ => Err(FittedModelError::MissingField {
1881 reason: "saved anchored deviation runtime is missing support breakpoints"
1882 .to_string(),
1883 }),
1884 }
1885 }
1886
1887 pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1888 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1897 }
1898
1899 pub fn design_uncorrected(
1907 &self,
1908 values: &Array1<f64>,
1909 ) -> Result<Array2<f64>, FittedModelError> {
1910 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1911 }
1912
1913 pub fn design_with_anchor_rows(
1922 &self,
1923 values: &Array1<f64>,
1924 anchor_rows: ndarray::ArrayView2<f64>,
1925 ) -> Result<Array2<f64>, FittedModelError> {
1926 let mut out =
1927 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1928 if let Some(m_rows) = self.anchor_correction.as_ref() {
1929 let d = m_rows.len();
1930 if anchor_rows.nrows() != values.len() {
1931 return Err(FittedModelError::SchemaMismatch {
1932 reason: format!(
1933 "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1934 anchor_rows.nrows(),
1935 values.len(),
1936 ),
1937 });
1938 }
1939 if anchor_rows.ncols() != d {
1940 return Err(FittedModelError::SchemaMismatch {
1941 reason: format!(
1942 "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1943 anchor_rows.ncols(),
1944 d,
1945 ),
1946 });
1947 }
1948 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1950 for (i, row) in m_rows.iter().enumerate() {
1951 if row.len() != self.basis_dim {
1952 return Err(FittedModelError::SchemaMismatch {
1953 reason: format!(
1954 "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1955 i,
1956 row.len(),
1957 self.basis_dim,
1958 ),
1959 });
1960 }
1961 for (j, &v) in row.iter().enumerate() {
1962 m_dense[[i, j]] = v;
1963 }
1964 }
1965 let subtract = anchor_rows.dot(&m_dense);
1968 out = out - subtract;
1969 } else if anchor_rows.ncols() != 0 {
1970 return Err(FittedModelError::SchemaMismatch {
1971 reason: format!(
1972 "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1973 anchor_rows.ncols(),
1974 ),
1975 });
1976 }
1977 Ok(out)
1978 }
1979
1980 pub fn anchor_correction_matrix(
1988 &self,
1989 n_anchor_rows: ndarray::ArrayView2<f64>,
1990 ) -> Result<Option<Array2<f64>>, FittedModelError> {
1991 let Some(m_rows) = self.anchor_correction.as_ref() else {
1992 return Ok(None);
1993 };
1994 let d = m_rows.len();
1995 if n_anchor_rows.ncols() != d {
1996 return Err(FittedModelError::SchemaMismatch {
1997 reason: format!(
1998 "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
1999 n_anchor_rows.ncols(),
2000 d,
2001 ),
2002 });
2003 }
2004 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2005 for (i, row) in m_rows.iter().enumerate() {
2006 if row.len() != self.basis_dim {
2007 return Err(FittedModelError::SchemaMismatch {
2008 reason: format!(
2009 "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2010 i,
2011 row.len(),
2012 self.basis_dim,
2013 ),
2014 });
2015 }
2016 for (j, &v) in row.iter().enumerate() {
2017 m_dense[[i, j]] = v;
2018 }
2019 }
2020 Ok(Some(n_anchor_rows.dot(&m_dense)))
2023 }
2024
2025 pub fn first_derivative_design(
2026 &self,
2027 values: &Array1<f64>,
2028 ) -> Result<Array2<f64>, FittedModelError> {
2029 self.evaluate_span_polynomial_design(
2030 values,
2031 BasisOptions::first_derivative().derivative_order,
2032 )
2033 }
2034
2035 pub fn second_derivative_design(
2036 &self,
2037 values: &Array1<f64>,
2038 ) -> Result<Array2<f64>, FittedModelError> {
2039 self.evaluate_span_polynomial_design(
2040 values,
2041 BasisOptions::second_derivative().derivative_order,
2042 )
2043 }
2044}
2045
2046impl FittedFamily {
2047 #[inline]
2048 pub fn likelihood(&self) -> LikelihoodSpec {
2049 let spec = match self {
2050 Self::Standard { likelihood, .. }
2051 | Self::LocationScale { likelihood, .. }
2052 | Self::MarginalSlope { likelihood, .. }
2053 | Self::Survival { likelihood, .. }
2054 | Self::TransformationNormal { likelihood, .. } => likelihood,
2055 Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2056 return LikelihoodSpec::royston_parmar();
2057 }
2058 };
2059 spec.clone()
2060 }
2061
2062 #[inline]
2063 pub fn frailty(&self) -> Option<&FrailtySpec> {
2064 match self {
2065 Self::MarginalSlope { frailty, .. }
2066 | Self::Survival { frailty, .. }
2067 | Self::LatentSurvival { frailty }
2068 | Self::LatentBinary { frailty } => Some(frailty),
2069 _ => None,
2070 }
2071 }
2072}
2073
2074fn collect_smooth_extrapolation_axes(
2080 basis: &gam_terms::smooth::SmoothBasisSpec,
2081 n_training_headers: usize,
2082 out: &mut std::collections::HashSet<usize>,
2083) {
2084 use gam_terms::smooth::SmoothBasisSpec;
2085 let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2086 if col < n_training_headers {
2087 out.insert(col);
2088 }
2089 };
2090 match basis {
2091 SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2093 SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2096 for &c in feature_cols {
2097 push(c, out);
2098 }
2099 }
2100 SmoothBasisSpec::ThinPlate { feature_cols, .. }
2107 | SmoothBasisSpec::Matern { feature_cols, .. }
2108 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2109 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2110 for &c in feature_cols {
2111 push(c, out);
2112 }
2113 }
2114 SmoothBasisSpec::FactorSmooth { spec } => {
2118 for &c in &spec.continuous_cols {
2119 push(c, out);
2120 }
2121 }
2122 SmoothBasisSpec::ByVariable { inner, .. }
2124 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2125 collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2126 }
2127 SmoothBasisSpec::BySmooth { smooth, .. } => {
2128 collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2129 }
2130 SmoothBasisSpec::Sphere { .. }
2136 | SmoothBasisSpec::ConstantCurvature { .. }
2137 | SmoothBasisSpec::Pca { .. } => {}
2138 }
2139}
2140
2141fn collect_by_variable_numeric_axes(
2160 basis: &gam_terms::smooth::SmoothBasisSpec,
2161 n_training_headers: usize,
2162 out: &mut std::collections::HashSet<usize>,
2163) {
2164 use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2165 match basis {
2166 SmoothBasisSpec::ByVariable {
2167 inner,
2168 by_col,
2169 kind,
2170 ..
2171 } => {
2172 if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2173 out.insert(*by_col);
2174 }
2175 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2176 }
2177 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2178 if let ByVarKind::Numeric { feature_col } = by_kind
2179 && *feature_col < n_training_headers
2180 {
2181 out.insert(*feature_col);
2182 }
2183 collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2184 }
2185 SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2186 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2187 }
2188 _ => {}
2189 }
2190}
2191
2192impl FittedModel {
2193 pub fn axis_clip_to_training_ranges(
2202 &self,
2203 data: ndarray::ArrayView2<'_, f64>,
2204 col_map: &std::collections::HashMap<String, usize>,
2205 ) -> Option<ndarray::Array2<f64>> {
2206 let training_headers = self.training_headers.as_ref()?;
2207 let ranges = self.training_feature_ranges.as_ref()?;
2208 if training_headers.len() != ranges.len() {
2209 return None;
2210 }
2211 let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2212 std::collections::HashMap::new();
2213 if let Some(schema) = self.data_schema.as_ref() {
2214 for col in &schema.columns {
2215 kind_by_header.insert(col.name.as_str(), col.kind);
2216 }
2217 }
2218 let periodic_axes = self.training_periodic_axes(training_headers);
2224 let linear_axes = self.training_linear_axes(training_headers.len());
2232 let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2237 let smooth_extrapolation_axes =
2245 self.training_smooth_extrapolation_axes(training_headers.len());
2246 let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2252 let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2257 let mut clipped = data.to_owned();
2258 let mut any_clipped = false;
2259 for (col_in_training, (header, &(lo, hi))) in
2260 training_headers.iter().zip(ranges.iter()).enumerate()
2261 {
2262 let (lo, hi) = sphere_lat_bounds
2263 .get(&col_in_training)
2264 .copied()
2265 .unwrap_or((lo, hi));
2266 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2267 continue;
2268 }
2269 if !matches!(
2270 kind_by_header.get(header.as_str()).copied(),
2271 Some(ColumnKindTag::Continuous)
2272 ) {
2273 continue;
2274 }
2275 if periodic_axes.contains(&col_in_training) {
2276 continue;
2277 }
2278 if linear_axes.contains(&col_in_training) {
2279 continue;
2280 }
2281 if random_effect_axes.contains(&col_in_training) {
2282 continue;
2283 }
2284 if smooth_extrapolation_axes.contains(&col_in_training) {
2285 continue;
2286 }
2287 if by_variable_axes.contains(&col_in_training) {
2288 continue;
2289 }
2290 let Some(&col_idx) = col_map.get(header) else {
2291 continue;
2292 };
2293 if col_idx >= clipped.ncols() {
2294 continue;
2295 }
2296 let mut col = clipped.column_mut(col_idx);
2297 for v in col.iter_mut() {
2298 if v.is_finite() {
2299 if *v < lo {
2300 *v = lo;
2301 any_clipped = true;
2302 } else if *v > hi {
2303 *v = hi;
2304 any_clipped = true;
2305 }
2306 }
2307 }
2308 }
2309 if any_clipped { Some(clipped) } else { None }
2310 }
2311
2312 fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2313 let mut specs: Vec<&TermCollectionSpec> = [
2314 self.resolved_termspec.as_ref(),
2315 self.resolved_termspec_noise.as_ref(),
2316 self.resolved_termspec_logslope.as_ref(),
2317 ]
2318 .into_iter()
2319 .flatten()
2320 .collect();
2321 if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2322 specs.extend(logslopes.iter());
2323 }
2324 specs
2325 }
2326
2327 fn training_periodic_axes(
2334 &self,
2335 training_headers: &[String],
2336 ) -> std::collections::HashSet<usize> {
2337 use gam_terms::basis::BSplineKnotSpec;
2338 use gam_terms::smooth::SmoothBasisSpec;
2339 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2340 let Some(spec) = self.resolved_termspec.as_ref() else {
2341 return out;
2342 };
2343 for term in &spec.smooth_terms {
2344 match &term.basis {
2345 SmoothBasisSpec::Sphere { feature_cols, .. } => {
2351 if let Some(&lon_col) = feature_cols.get(1)
2352 && lon_col < training_headers.len()
2353 {
2354 out.insert(lon_col);
2355 }
2356 }
2357 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2359 if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2360 && *feature_col < training_headers.len()
2361 {
2362 out.insert(*feature_col);
2363 }
2364 }
2365 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2368 for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2369 if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2370 && let Some(&col) = feature_cols.get(i)
2371 && col < training_headers.len()
2372 {
2373 out.insert(col);
2374 }
2375 }
2376 }
2377 _ => {}
2378 }
2379 }
2380 out
2381 }
2382
2383 fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2394 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2395 for spec in self.saved_term_specs() {
2396 for term in &spec.linear_terms {
2397 for col in term.effective_feature_cols() {
2398 if col < n_training_headers {
2399 out.insert(col);
2400 }
2401 }
2402 }
2403 }
2404 out
2405 }
2406
2407 fn training_random_effect_axes(
2413 &self,
2414 n_training_headers: usize,
2415 ) -> std::collections::HashSet<usize> {
2416 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2417 for spec in self.saved_term_specs() {
2418 for term in &spec.random_effect_terms {
2419 if term.feature_col < n_training_headers {
2420 out.insert(term.feature_col);
2421 }
2422 }
2423 }
2424 out
2425 }
2426
2427 fn training_smooth_extrapolation_axes(
2460 &self,
2461 n_training_headers: usize,
2462 ) -> std::collections::HashSet<usize> {
2463 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2464 for spec in self.saved_term_specs() {
2465 for term in &spec.smooth_terms {
2466 collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2467 }
2468 }
2469 out
2470 }
2471
2472 fn training_by_variable_numeric_axes(
2478 &self,
2479 n_training_headers: usize,
2480 ) -> std::collections::HashSet<usize> {
2481 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2482 for spec in self.saved_term_specs() {
2483 for term in &spec.smooth_terms {
2484 collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2485 }
2486 }
2487 out
2488 }
2489
2490 fn training_sphere_latitude_bounds(
2512 &self,
2513 training_headers: &[String],
2514 ) -> std::collections::HashMap<usize, (f64, f64)> {
2515 use gam_terms::smooth::SmoothBasisSpec;
2516 let mut out: std::collections::HashMap<usize, (f64, f64)> =
2517 std::collections::HashMap::new();
2518 let Some(spec) = self.resolved_termspec.as_ref() else {
2519 return out;
2520 };
2521 for term in &spec.smooth_terms {
2522 if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2523 && let Some(&lat_col) = feature_cols.first()
2524 && lat_col < training_headers.len()
2525 {
2526 let bound = if spec.radians {
2527 std::f64::consts::FRAC_PI_2
2528 } else {
2529 90.0
2530 };
2531 out.insert(lat_col, (-bound, bound));
2532 }
2533 }
2534 out
2535 }
2536
2537 pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2538 let likelihood = payload.family_state.likelihood();
2539 let class = match payload.model_kind {
2540 ModelKind::Survival => PredictModelClass::Survival,
2541 ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2542 ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2543 ModelKind::LocationScale => {
2544 if likelihood == LikelihoodSpec::gaussian_identity() {
2545 PredictModelClass::GaussianLocationScale
2546 } else if is_dispersion_location_scale_response(&likelihood.response) {
2547 PredictModelClass::DispersionLocationScale
2548 } else {
2549 PredictModelClass::BinomialLocationScale
2550 }
2551 }
2552 ModelKind::Standard => PredictModelClass::Standard,
2553 };
2554 match class {
2555 PredictModelClass::Survival => {
2556 payload.model_kind = ModelKind::Survival;
2557 Self::Survival { payload }
2558 }
2559 PredictModelClass::BernoulliMarginalSlope => {
2560 payload.model_kind = ModelKind::MarginalSlope;
2561 Self::MarginalSlope { payload }
2562 }
2563 PredictModelClass::TransformationNormal => {
2564 payload.model_kind = ModelKind::TransformationNormal;
2565 Self::TransformationNormal { payload }
2566 }
2567 PredictModelClass::GaussianLocationScale
2568 | PredictModelClass::BinomialLocationScale
2569 | PredictModelClass::DispersionLocationScale => {
2570 payload.model_kind = ModelKind::LocationScale;
2571 Self::LocationScale { payload }
2572 }
2573 PredictModelClass::Standard => {
2574 payload.model_kind = ModelKind::Standard;
2575 Self::Standard { payload }
2576 }
2577 }
2578 .with_synchronized_stateful_link_metadata()
2579 }
2580
2581 #[inline]
2582 pub fn payload(&self) -> &FittedModelPayload {
2583 match self {
2584 Self::Standard { payload }
2585 | Self::LocationScale { payload }
2586 | Self::MarginalSlope { payload }
2587 | Self::Survival { payload }
2588 | Self::TransformationNormal { payload } => payload,
2589 }
2590 }
2591
2592 #[inline]
2593 fn payload_mut(&mut self) -> &mut FittedModelPayload {
2594 match self {
2595 Self::Standard { payload }
2596 | Self::LocationScale { payload }
2597 | Self::MarginalSlope { payload }
2598 | Self::Survival { payload }
2599 | Self::TransformationNormal { payload } => payload,
2600 }
2601 }
2602
2603 fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2604 self.synchronize_stateful_link_metadata();
2605 self
2606 }
2607
/// Re-derives payload metadata that mirrors the fitted result.
///
/// - `used_device` is copied from `fit_result`, falling back to `unified`.
///   It is `false` when neither is present.
/// - `synchronize_empty_feature_contract` is always invoked on the payload.
/// - When `fit_result` is present and the family is `Standard` with a stateful
///   link whose likelihood matches the fitted link state, that state is copied
///   into the family state. For SAS, Beta-Logistic and mixture links, the
///   fitted covariance is also copied into the payload covariance field.
///   Every other family/link combination is left untouched.
fn synchronize_stateful_link_metadata(&mut self) {
    let payload = self.payload_mut();
    payload.used_device = payload
        .fit_result
        .as_ref()
        .or(payload.unified.as_ref())
        .is_some_and(|fit| fit.used_device);
    payload.synchronize_empty_feature_contract();
    // Link-state synchronization below reads only `fit_result`, not `unified`.
    let Some(fit) = payload.fit_result.as_ref() else {
        return;
    };
    match (&mut payload.family_state, &fit.fitted_link) {
        // Latent-cloglog: copy the fitted state only; no covariance is stored.
        (
            FittedFamily::Standard {
                likelihood,
                latent_cloglog_state,
                ..
            },
            FittedLinkState::LatentCLogLog { state },
        ) if likelihood.is_latent_cloglog() => {
            *latent_cloglog_state = Some(*state);
        }
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::Sas { state, covariance },
        ) if likelihood.is_binomial_sas() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        // Beta-Logistic reuses the `sas_state` / `sas_param_covariance` slots.
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::BetaLogistic { state, covariance },
        ) if likelihood.is_binomial_beta_logistic() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        (
            FittedFamily::Standard {
                likelihood,
                mixture_state,
                ..
            },
            FittedLinkState::Mixture { state, covariance },
        ) if likelihood.is_binomial_mixture() => {
            *mixture_state = Some(state.clone());
            payload.mixture_link_param_covariance =
                covariance.as_ref().map(array2_to_nestedvec);
        }
        // Non-standard families and mismatched link states are not synchronized here.
        _ => {}
    }
}
2667
2668 #[inline]
2669 pub fn likelihood(&self) -> LikelihoodSpec {
2670 self.payload().family_state.likelihood()
2671 }
2672
2673 pub fn prediction_required_columns(
2691 &self,
2692 ) -> Result<std::collections::BTreeSet<String>, String> {
2693 let payload = self.payload();
2694 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2695 let mut required = std::collections::BTreeSet::<String>::new();
2696 parsed_term_column_names(&parsed.terms, &mut required);
2697
2698 if let Some((entry, exit, _event)) =
2699 parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2700 {
2701 if let Some(entry) = entry {
2702 required.insert(entry);
2703 }
2704 required.insert(exit);
2705 } else if let Some((left, right, _event)) =
2706 parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707 {
2708 required.insert(left);
2709 required.insert(right);
2710 } else if matches!(
2711 self.predict_model_class(),
2712 PredictModelClass::TransformationNormal
2713 ) {
2714 let response = parsed.response.trim();
2715 if !response.is_empty() && !response.starts_with("Surv(") {
2716 required.insert(response.to_string());
2717 }
2718 }
2719
2720 if let Some(offset) = payload.offset_column.as_ref() {
2721 required.insert(offset.clone());
2722 }
2723 if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2724 required.insert(noise_offset.clone());
2725 }
2726 if matches!(
2727 self.predict_model_class(),
2728 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2729 ) {
2730 if let Some(z_column) = payload.z_column.as_ref() {
2731 required.remove("z");
2732 required.insert(z_column.clone());
2733 }
2734 }
2735 if let Some(noise_formula) = payload.formula_noise.as_ref() {
2736 self.add_auxiliary_formula_columns(
2737 &mut required,
2738 noise_formula,
2739 parsed.response.as_str(),
2740 )?;
2741 }
2742 if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2743 if logslope_formula != "same-as-main" {
2744 self.add_auxiliary_formula_columns(
2745 &mut required,
2746 logslope_formula,
2747 parsed.response.as_str(),
2748 )?;
2749 }
2750 }
2751 Ok(required)
2752 }
2753
2754 pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2771 let payload = self.payload();
2772 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2773 if parse_surv_response(parsed.response.as_str())
2776 .map_err(|e| e.to_string())?
2777 .is_some()
2778 || parse_surv_interval_response(parsed.response.as_str())
2779 .map_err(|e| e.to_string())?
2780 .is_some()
2781 {
2782 return Ok(Vec::new());
2783 }
2784 let response = parsed.response.trim();
2785 if response.is_empty() || response.contains('(') {
2788 return Ok(Vec::new());
2789 }
2790 if self.prediction_required_columns()?.contains(response) {
2793 return Ok(Vec::new());
2794 }
2795 Ok(vec![response.to_string()])
2796 }
2797
2798 fn add_auxiliary_formula_columns(
2801 &self,
2802 required: &mut std::collections::BTreeSet<String>,
2803 formula_or_rhs: &str,
2804 response: &str,
2805 ) -> Result<(), String> {
2806 let trimmed = formula_or_rhs.trim();
2807 if trimmed.is_empty() || trimmed == "1" {
2808 return Ok(());
2809 }
2810 let formula = if trimmed.contains('~') {
2811 trimmed.to_string()
2812 } else {
2813 format!("{response} ~ {trimmed}")
2814 };
2815 let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2816 parsed_term_column_names(&parsed.terms, required);
2817 Ok(())
2818 }
2819
2820 #[inline]
2821 pub fn predict_model_class(&self) -> PredictModelClass {
2822 match &self.payload().family_state {
2823 FittedFamily::Survival { .. }
2824 | FittedFamily::LatentSurvival { .. }
2825 | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2826 FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2827 FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2828 FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2829 PredictModelClass::GaussianLocationScale
2830 }
2831 FittedFamily::LocationScale { likelihood, .. }
2832 if is_dispersion_location_scale_response(&likelihood.response) =>
2833 {
2834 PredictModelClass::DispersionLocationScale
2835 }
2836 FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2837 FittedFamily::Standard { .. } => PredictModelClass::Standard,
2838 }
2839 }
2840
2841 pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2842 let payload = self.payload();
2843 let (knots, degree) = match (
2844 payload.linkwiggle_knots.as_ref(),
2845 payload.linkwiggle_degree,
2846 ) {
2847 (None, None) => return Ok(None),
2848 (Some(knots), Some(degree)) => (knots.clone(), degree),
2849 _ => {
2850 return Err(FittedModelError::SchemaMismatch {
2851 reason:
2852 "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2853 .to_string(),
2854 })
2855 }
2856 };
2857 let resolved_link = self.resolved_inverse_link()?;
2858 let saved_link_disallows_wiggle = resolved_link
2859 .as_ref()
2860 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2861 || payload
2862 .link
2863 .as_ref()
2864 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2865 if saved_link_disallows_wiggle {
2866 return Err(FittedModelError::IncompatibleConfig {
2867 reason: joint_wiggle_unsupported_link_message("link wiggle"),
2868 });
2869 }
2870 let beta = match self.predict_model_class() {
2871 PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2879 payload.beta_link_wiggle.clone().expect("checked is_some")
2880 }
2881 PredictModelClass::Standard => {
2882 let fit = payload.fit_result.as_ref().ok_or_else(|| {
2883 FittedModelError::MissingField {
2884 reason:
2885 "standard link-wiggle model is missing canonical fit_result payload"
2886 .to_string(),
2887 }
2888 })?;
2889 if fit.blocks.len() != 2
2890 || fit.blocks[0].role != BlockRole::Mean
2891 || fit.blocks[1].role != BlockRole::LinkWiggle
2892 {
2893 return Err(FittedModelError::SchemaMismatch {
2894 reason:
2895 "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2896 .to_string(),
2897 });
2898 }
2899 fit.block_by_role(BlockRole::LinkWiggle)
2900 .ok_or_else(|| FittedModelError::MissingField {
2901 reason:
2902 "standard link-wiggle model is missing LinkWiggle coefficient block"
2903 .to_string(),
2904 })?
2905 .beta
2906 .to_vec()
2907 }
2908 _ => payload
2909 .beta_link_wiggle
2910 .clone()
2911 .ok_or_else(|| FittedModelError::MissingField {
2912 reason:
2913 "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2914 .to_string(),
2915 })?,
2916 };
2917 Ok(Some(SavedLinkWiggleRuntime {
2918 knots,
2919 degree,
2920 beta,
2921 }))
2922 }
2923
2924 pub fn saved_baseline_time_wiggle(
2925 &self,
2926 ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2927 let payload = self.payload();
2928 if payload
2929 .survival_cause_count
2930 .is_some_and(|cause_count| cause_count > 1)
2931 && payload.beta_baseline_timewiggle.is_none()
2932 && payload.beta_baseline_timewiggle_by_cause.is_some()
2933 {
2934 return Err(FittedModelError::SchemaMismatch {
2935 reason:
2936 "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2937 .to_string(),
2938 });
2939 }
2940 match (
2941 payload.baseline_timewiggle_knots.as_ref(),
2942 payload.baseline_timewiggle_degree,
2943 payload.baseline_timewiggle_penalty_orders.as_ref(),
2944 payload.baseline_timewiggle_double_penalty,
2945 payload.beta_baseline_timewiggle.as_ref(),
2946 ) {
2947 (None, None, None, None, None) => Ok(None),
2948 (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2949 Ok(Some(SavedBaselineTimeWiggleRuntime {
2950 knots: knots.clone(),
2951 degree,
2952 penalty_orders: penalty_orders.clone(),
2953 double_penalty,
2954 beta: beta.clone(),
2955 }))
2956 }
2957 _ => Err(FittedModelError::SchemaMismatch {
2958 reason:
2959 "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2960 .to_string(),
2961 }),
2962 }
2963 }
2964
2965 #[inline]
2967 pub fn has_link_wiggle(&self) -> bool {
2968 self.saved_link_wiggle()
2969 .map(|runtime| runtime.is_some())
2970 .unwrap_or(false)
2971 }
2972
2973 #[inline]
2975 pub fn has_baseline_time_wiggle(&self) -> bool {
2976 let payload = self.payload();
2977 if payload
2978 .survival_cause_count
2979 .is_some_and(|cause_count| cause_count > 1)
2980 {
2981 return payload.baseline_timewiggle_knots.is_some()
2982 && payload.baseline_timewiggle_degree.is_some()
2983 && payload.baseline_timewiggle_penalty_orders.is_some()
2984 && payload.baseline_timewiggle_double_penalty.is_some()
2985 && payload.beta_baseline_timewiggle_by_cause.is_some();
2986 }
2987 self.saved_baseline_time_wiggle()
2988 .map(|runtime| runtime.is_some())
2989 .unwrap_or(false)
2990 }
2991
2992 #[inline]
3017 pub fn prediction_uses_posterior_mean(&self) -> bool {
3018 let family = self.likelihood();
3019 let curved_family = match &family.response {
3020 ResponseFamily::Gaussian => false,
3023 ResponseFamily::Poisson
3025 | ResponseFamily::Gamma
3026 | ResponseFamily::Tweedie { .. }
3027 | ResponseFamily::NegativeBinomial { .. } => true,
3028 ResponseFamily::Beta { .. } => true,
3030 ResponseFamily::RoystonParmar => true,
3032 ResponseFamily::Binomial => matches!(
3035 &family.link,
3036 InverseLink::Standard(_)
3037 | InverseLink::Sas(_)
3038 | InverseLink::BetaLogistic(_)
3039 | InverseLink::Mixture(_)
3040 | InverseLink::LatentCLogLog(_)
3041 ),
3042 };
3043 curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3044 }
3045
3046 pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3047 self.payload().validate_payload_version()?;
3048 if matches!(
3049 self.predict_model_class(),
3050 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3051 ) {
3052 if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3053 runtime.validate_exact_replay_contract().map_err(|err| {
3054 FittedModelError::PayloadCorrupt {
3055 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3056 }
3057 })?;
3058 }
3059 if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3060 runtime.validate_exact_replay_contract().map_err(|err| {
3061 FittedModelError::PayloadCorrupt {
3062 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3063 }
3064 })?;
3065 }
3066 }
3067 let runtime = SavedPredictionRuntime {
3068 model_class: self.predict_model_class(),
3069 likelihood: self.likelihood(),
3070 inverse_link: self.resolved_inverse_link()?,
3071 link_wiggle: self.saved_link_wiggle()?,
3072 baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3073 score_warp: self.payload().score_warp_runtime.clone(),
3074 link_deviation: self.payload().link_deviation_runtime.clone(),
3075 latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3076 latent_z_conditional_calibration: self
3077 .payload()
3078 .latent_z_conditional_calibration
3079 .clone(),
3080 influence_absorber_width: self.payload().influence_absorber_width,
3081 };
3082 if matches!(
3083 runtime.model_class,
3084 PredictModelClass::GaussianLocationScale
3085 | PredictModelClass::BinomialLocationScale
3086 | PredictModelClass::DispersionLocationScale
3087 ) {
3088 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3089 FittedModelError::MissingField {
3090 reason: "location-scale model is missing canonical fit_result payload"
3091 .to_string(),
3092 }
3093 })?;
3094 validate_location_scale_saved_fit(
3095 fit,
3096 runtime.model_class,
3097 runtime.link_wiggle.as_ref(),
3098 )?;
3099 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3100 && self
3101 .payload()
3102 .survival_likelihood
3103 .as_deref()
3104 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3105 {
3106 validate_survival_location_scale_saved_fit(
3107 self.payload(),
3108 runtime.link_wiggle.as_ref(),
3109 )?;
3110 } else if matches!(
3111 runtime.model_class,
3112 PredictModelClass::BernoulliMarginalSlope
3113 ) {
3114 let unified =
3115 self.payload()
3116 .unified
3117 .as_ref()
3118 .ok_or_else(|| FittedModelError::MissingField {
3119 reason: "marginal-slope model is missing unified fit payload; refit"
3120 .to_string(),
3121 })?;
3122 validate_marginal_slope_saved_fit(
3123 unified,
3124 runtime.score_warp.as_ref(),
3125 runtime.link_deviation.as_ref(),
3126 "unified",
3127 )?;
3128 } else if matches!(runtime.model_class, PredictModelClass::Survival)
3129 && self
3130 .payload()
3131 .survival_likelihood
3132 .as_deref()
3133 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3134 {
3135 let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3136 FittedModelError::MissingField {
3137 reason: "survival marginal-slope model is missing canonical fit_result payload"
3138 .to_string(),
3139 }
3140 })?;
3141 validate_survival_marginal_slope_saved_fit(
3142 fit,
3143 runtime.score_warp.as_ref(),
3144 runtime.link_deviation.as_ref(),
3145 "fit_result",
3146 )?;
3147 }
3148 Ok(runtime)
3149 }
3150
3151 pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3152 let payload = self.payload();
3153 let raw = match &payload.family_state {
3154 FittedFamily::Standard {
3155 likelihood,
3156 sas_state,
3157 ..
3158 } if likelihood.is_binomial_sas() => {
3159 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3160 reason: "binomial-sas model is missing state in family_state.sas_state"
3161 .to_string(),
3162 })?
3163 }
3164 FittedFamily::LocationScale {
3165 likelihood,
3166 base_link,
3167 } if likelihood.is_binomial_sas() => match base_link {
3168 Some(InverseLink::Sas(state)) => *state,
3169 _ => {
3170 return Err(FittedModelError::MissingField {
3171 reason: "binomial-sas location-scale model is missing SAS base_link state"
3172 .to_string(),
3173 });
3174 }
3175 },
3176 _ => return Ok(None),
3177 };
3178 state_from_sasspec(SasLinkSpec {
3179 initial_epsilon: raw.epsilon,
3180 initial_log_delta: raw.log_delta,
3181 })
3182 .map(Some)
3183 .map_err(|e| FittedModelError::PayloadCorrupt {
3184 reason: format!("invalid saved SAS link state: {e}"),
3185 })
3186 }
3187
3188 pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3189 let payload = self.payload();
3190 let raw = match &payload.family_state {
3191 FittedFamily::Standard {
3192 likelihood,
3193 sas_state,
3194 ..
3195 } if likelihood.is_binomial_beta_logistic() => {
3196 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3197 reason:
3198 "binomial-beta-logistic model is missing state in family_state.sas_state"
3199 .to_string(),
3200 })?
3201 }
3202 FittedFamily::LocationScale {
3203 likelihood,
3204 base_link,
3205 } if likelihood.is_binomial_beta_logistic() => match base_link {
3206 Some(InverseLink::BetaLogistic(state)) => *state,
3207 _ => {
3208 return Err(FittedModelError::MissingField {
3209 reason:
3210 "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3211 .to_string(),
3212 });
3213 }
3214 },
3215 _ => return Ok(None),
3216 };
3217 state_from_beta_logisticspec(SasLinkSpec {
3218 initial_epsilon: raw.epsilon,
3219 initial_log_delta: raw.log_delta,
3220 })
3221 .map(Some)
3222 .map_err(|e| FittedModelError::PayloadCorrupt {
3223 reason: format!("invalid saved Beta-Logistic link state: {e}"),
3224 })
3225 }
3226
3227 pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3228 let payload = self.payload();
3229 match &payload.family_state {
3230 FittedFamily::Standard {
3231 likelihood,
3232 mixture_state,
3233 ..
3234 } if likelihood.is_binomial_mixture() => mixture_state
3235 .clone()
3236 .ok_or_else(|| FittedModelError::MissingField {
3237 reason: "binomial-mixture model is missing state in family_state.mixture_state"
3238 .to_string(),
3239 })
3240 .map(Some),
3241 FittedFamily::LocationScale {
3242 likelihood,
3243 base_link,
3244 } if likelihood.is_binomial_mixture() => match base_link {
3245 Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3246 _ => Err(FittedModelError::MissingField {
3247 reason:
3248 "binomial-mixture location-scale model is missing mixture base_link state"
3249 .to_string(),
3250 }),
3251 },
3252 _ => Ok(None),
3253 }
3254 }
3255
3256 pub fn saved_latent_cloglog_state(
3257 &self,
3258 ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3259 let payload = self.payload();
3260 match &payload.family_state {
3261 FittedFamily::Standard {
3262 likelihood,
3263 latent_cloglog_state,
3264 ..
3265 } if likelihood.is_latent_cloglog() => latent_cloglog_state
3266 .ok_or_else(|| FittedModelError::MissingField {
3267 reason:
3268 "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3269 .to_string(),
3270 })
3271 .map(Some),
3272 _ => Ok(None),
3273 }
3274 }
3275
3276 pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3277 let stateful = if let Some(state) = self.saved_mixture_state()? {
3278 Some(InverseLink::Mixture(state))
3279 } else if let Some(state) = self.saved_latent_cloglog_state()? {
3280 Some(InverseLink::LatentCLogLog(state))
3281 } else if let Some(state) = self.saved_beta_logistic_state()? {
3282 Some(InverseLink::BetaLogistic(state))
3283 } else {
3284 self.saved_sas_state()?.map(InverseLink::Sas)
3285 };
3286 match &self.payload().family_state {
3287 FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3288 FittedFamily::Standard { link, .. } => {
3289 Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3290 }
3291 FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3292 FittedFamily::Survival { .. }
3293 | FittedFamily::LatentSurvival { .. }
3294 | FittedFamily::LatentBinary { .. } => Ok(None),
3295 FittedFamily::TransformationNormal { .. } => Ok(None),
3296 }
3297 }
3298
/// Coverage floor passed as the last argument to
/// `gam_terms::basis::measure_jet_extrapolation_variance` by
/// `measure_jet_extrapolation_variance`.
// NOTE(review): its exact meaning is defined inside `gam_terms`; presumably
// a lower bound on relative support coverage — confirm there.
const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3308
/// Per-row extrapolation variance contributed by the model's frozen
/// measure-jet smooth terms, evaluated at the rows of `data`.
///
/// Returns `Ok(None)` when any of the following holds:
/// - there is no `resolved_termspec`;
/// - `data` has no rows;
/// - the saved spec has no `MeasureJet` smooth;
/// - every measure-jet term was skipped with a `log::warn!`.
///
/// Skipped terms are those that are not frozen, or whose fitted amplitudes /
/// normalization scales are incomplete. Otherwise the result has one entry
/// per row. Each entry sums, over contributing terms,
/// `fit.coefficient_covariance_scale() * v`, where `v` comes from
/// `gam_terms::basis::measure_jet_extrapolation_variance`.
///
/// # Arguments
/// * `data` - prediction rows. Columns are addressed through the term spec
///   after it is resolved against `col_map`.
/// * `col_map` - column-name to column-index map for `data`.
///
/// # Errors
/// * `MissingField` if the canonical `fit_result` payload is absent.
/// * `SchemaMismatch` when one of the following fails:
///   - spec resolution or the one-row design replay;
///   - a penalty global index is out of range for the fitted lambdas;
///   - a `measure_jet_scale_*` penalty label cannot be parsed;
///   - a feature column is out of range for `data`;
///   - `input_scales` has the wrong length;
///   - the `gam_terms` support-curve or variance routines.
pub fn measure_jet_extrapolation_variance(
    &self,
    data: ndarray::ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
) -> Result<Option<Array1<f64>>, FittedModelError> {
    use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
    use gam_terms::smooth::build_term_collection_design;
    use gam_terms::smooth::SmoothBasisSpec;
    let Some(saved_spec) = self.resolved_termspec.as_ref() else {
        return Ok(None);
    };
    // Cheap early exit before touching the fit or replaying the design.
    if data.nrows() == 0
        || !saved_spec
            .smooth_terms
            .iter()
            .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
    {
        return Ok(None);
    }
    let fit = self
        .fit_result
        .as_ref()
        .ok_or_else(|| FittedModelError::MissingField {
            reason: "measure-jet extrapolation variance requires the canonical \
                     fit_result payload; refit"
                .to_string(),
        })?;
    // Remap the saved spec's feature columns onto `data` via `col_map`.
    let spec = crate::survival::predict::resolve_termspec_for_prediction(
        &self.resolved_termspec,
        self.training_headers.as_ref(),
        col_map,
        "resolved_termspec",
    )
    .map_err(|e| FittedModelError::SchemaMismatch {
        reason: format!("measure-jet extrapolation variance: {e}"),
    })?;
    // A one-row design is built only to read `design.penaltyinfo`, i.e. which
    // fitted lambda index belongs to which measure-jet penalty.
    let probe = data.slice(ndarray::s![0..1, ..]);
    let design = build_term_collection_design(probe, &spec).map_err(|e| {
        FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
            ),
        }
    })?;
    let lambdas = &fit.lambdas;
    let phi_scale = fit.coefficient_covariance_scale();
    let mut total = Array1::<f64>::zeros(data.nrows());
    let mut contributed = false;
    for term in &spec.smooth_terms {
        let SmoothBasisSpec::MeasureJet {
            feature_cols,
            spec: mj,
            input_scales,
        } = &term.basis
        else {
            continue;
        };
        // Only frozen terms (user-provided centers plus frozen quadrature)
        // can be evaluated; other measure-jet terms are skipped with a warning.
        let (Some(frozen), CenterStrategy::UserProvided(centers)) =
            (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
        else {
            log::warn!(
                "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
                 quadrature); skipping its extrapolation variance",
                term.name
            );
            continue;
        };
        let n_levels = frozen.eps_band.len();
        let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
            lambdas
                .get(global_index)
                .copied()
                .ok_or_else(|| FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': penalty global index {global_index} out \
                         of bounds for {} fitted lambdas",
                        term.name,
                        lambdas.len()
                    ),
                })
        };
        // Collect this term's fitted amplitudes. A term uses either
        // `measure_jet_scale_<level>` penalties or a single `Primary` penalty
        // (the fused form).
        let mut per_scale: Vec<(usize, f64)> = Vec::new();
        let mut fused: Option<f64> = None;
        for info in &design.penaltyinfo {
            if info.termname.as_deref() != Some(term.name.as_str()) {
                continue;
            }
            match &info.penalty.source {
                PenaltySource::Other(label) => {
                    if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
                        let level: usize = level_txt.parse().map_err(|_| {
                            FittedModelError::SchemaMismatch {
                                reason: format!(
                                    "measure-jet term '{}': unparseable penalty label \
                                     '{label}'",
                                    term.name
                                ),
                            }
                        })?;
                        per_scale.push((level, read_lambda(info.global_index)?));
                    }
                }
                PenaltySource::Primary => {
                    fused = Some(read_lambda(info.global_index)?);
                }
                _ => {}
            }
        }
        // Declared outside the `if` so that `PerLevel` can borrow it.
        let mut lambda_phys = Vec::with_capacity(n_levels);
        let spectrum = if per_scale.is_empty() {
            let Some(lam) = fused else {
                log::warn!(
                    "measure-jet term '{}' has no fitted amplitude in the penalty \
                     layout; skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            let Some(c) = frozen.fused_penalty_normalization_scale else {
                log::warn!(
                    "measure-jet term '{}' is missing the fused penalty normalization scale; \
                     skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            // NOTE(review): `c == 0.0` is not rejected and would yield inf — confirm
            // upstream guarantees a positive normalization scale.
            MeasureJetExtrapolationSpectrum::Fused(lam / c)
        } else {
            // Require exactly one amplitude per band level, covering 0..n_levels.
            per_scale.sort_by_key(|&(level, _)| level);
            let levels_complete = per_scale.len() == n_levels
                && per_scale
                    .iter()
                    .enumerate()
                    .all(|(i, &(level, _))| level == i);
            if !levels_complete {
                log::warn!(
                    "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
                     scales; skipping its extrapolation variance",
                    term.name,
                    per_scale.len(),
                    n_levels
                );
                continue;
            }
            if frozen.penalty_normalization_scales.len() != n_levels {
                log::warn!(
                    "measure-jet term '{}': {} frozen penalty normalization scales for {} \
                     band scales; skipping its extrapolation variance",
                    term.name,
                    frozen.penalty_normalization_scales.len(),
                    n_levels
                );
                continue;
            }
            // Indexing is in bounds: levels are exactly 0..n_levels and the
            // scale vector has length n_levels (both checked above).
            lambda_phys.extend(
                per_scale
                    .iter()
                    .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
            );
            MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
        };
        // Gather this term's feature columns into a dense (rows x axes) query matrix.
        let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
        for (j, &col) in feature_cols.iter().enumerate() {
            if col >= data.ncols() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': prediction column {col} out of bounds \
                         for {} data columns",
                        term.name,
                        data.ncols()
                    ),
                });
            }
            queries.column_mut(j).assign(&data.column(col));
        }
        // Divide each axis by its input scale.
        // NOTE(review): zero or non-finite scales are not rejected here and
        // would propagate inf/NaN — confirm they are validated at fit time.
        if let Some(scales) = input_scales {
            if scales.len() != feature_cols.len() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': {} input scales for {} axes",
                        term.name,
                        scales.len(),
                        feature_cols.len()
                    ),
                });
            }
            for (j, &scale) in scales.iter().enumerate() {
                queries.column_mut(j).mapv_inplace(|v| v / scale);
            }
        }
        let support = gam_terms::basis::measure_jet_support_curve(
            queries.view(),
            centers.view(),
            frozen.masses.view(),
            &frozen.eps_band,
        )
        .map_err(|e| FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet term '{}': support curve failed: {e}",
                term.name
            ),
        })?;
        // One variance value per prediction row (one row of `support` each),
        // scaled by the coefficient covariance scale and summed across terms.
        for i in 0..data.nrows() {
            let v = gam_terms::basis::measure_jet_extrapolation_variance(
                support.row(i),
                &frozen.eps_band,
                &frozen.support_means,
                spectrum,
                Self::MEASURE_JET_COVERAGE_FLOOR,
            )
            .map_err(|e| FittedModelError::SchemaMismatch {
                reason: format!(
                    "measure-jet term '{}': extrapolation variance failed: {e}",
                    term.name
                ),
            })?;
            total[i] += phi_scale * v;
        }
        contributed = true;
    }
    Ok(contributed.then_some(total))
}
3577
3578 pub fn unified(&self) -> Option<&UnifiedFitResult> {
3580 self.payload().unified.as_ref()
3581 }
3582
3583 pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3584 let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3585 reason: format!("failed to read model '{}': {e}", path.display()),
3586 })?;
3587 let model: Self =
3588 serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3589 reason: format!("failed to parse model json: {e}"),
3590 })?;
3591 let model = model.with_synchronized_stateful_link_metadata();
3592 model.validate_for_persistence()?;
3593 model.validate_numeric_finiteness()?;
3594 Ok(model)
3595 }
3596
3597 pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3598 let normalized = self.clone().with_synchronized_stateful_link_metadata();
3599 normalized.validate_for_persistence()?;
3600 normalized.validate_numeric_finiteness()?;
3601 let parent = path.parent().unwrap_or_else(|| Path::new("."));
3608 let file_name = path
3609 .file_name()
3610 .and_then(|s| s.to_str())
3611 .unwrap_or("model.json");
3612 let pid = std::process::id();
3613 let nanos = std::time::SystemTime::now()
3614 .duration_since(std::time::UNIX_EPOCH)
3615 .map(|d| d.as_nanos())
3616 .unwrap_or(0);
3617 let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3618 let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3619 reason: format!("failed to write model '{}': {e}", tmp.display()),
3620 })?;
3621 let mut writer = std::io::BufWriter::new(file);
3622 let ser_result = serde_json::to_writer(&mut writer, &normalized);
3623 if let Err(e) = ser_result {
3624 std::io::Write::flush(&mut writer).ok();
3627 drop(writer);
3628 fs::remove_file(&tmp).ok();
3629 return Err(FittedModelError::PayloadCorrupt {
3630 reason: format!("failed to serialize model: {e}"),
3631 });
3632 }
3633 std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3634 reason: format!("failed to write model '{}': {e}", tmp.display()),
3635 })?;
3636 let inner = writer
3638 .into_inner()
3639 .map_err(|e| FittedModelError::PayloadCorrupt {
3640 reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3641 })?;
3642 inner.sync_all().ok();
3643 drop(inner);
3644 if let Err(e) = fs::rename(&tmp, path) {
3645 fs::remove_file(&tmp).ok();
3646 return Err(FittedModelError::PayloadCorrupt {
3647 reason: format!("failed to publish model '{}': {e}", path.display()),
3648 });
3649 }
3650 if let Ok(d) = fs::File::open(parent) {
3655 d.sync_all().ok();
3656 }
3657 Ok(())
3658 }
3659
3660 pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3661 self.data_schema
3662 .as_ref()
3663 .ok_or_else(|| FittedModelError::MissingField {
3664 reason: "model is missing data_schema; refit".to_string(),
3665 })
3666 }
3667
3668 pub fn saved_spline_scan(
3672 &self,
3673 ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3674 let Some(saved) = self.spline_scan.as_ref() else {
3675 return Ok(None);
3676 };
3677 let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3678 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3679 Ok(Some((saved.feature_column.as_str(), fit)))
3680 }
3681
3682 pub fn saved_residual_cascade(
3687 &self,
3688 ) -> Result<
3689 Option<(
3690 &[String],
3691 gam_solve::residual_cascade::ResidualCascadeFit,
3692 )>,
3693 FittedModelError,
3694 > {
3695 let Some(saved) = self.residual_cascade.as_ref() else {
3696 return Ok(None);
3697 };
3698 let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3699 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3700 Ok(Some((saved.feature_columns.as_slice(), fit)))
3701 }
3702
3703 pub fn random_effect_group_columns(&self) -> HashSet<String> {
3704 let Some(training_headers) = self.training_headers.as_ref() else {
3705 return HashSet::new();
3706 };
3707 let mut out = HashSet::<String>::new();
3708 for spec in self.saved_term_specs() {
3709 for term in &spec.random_effect_terms {
3710 if let Some(name) = training_headers.get(term.feature_col) {
3711 out.insert(name.clone());
3712 }
3713 }
3714 }
3715 out
3716 }
3717
3718 pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3719 self.validate_payload_version()?;
3733 if let Some(scan) = self.spline_scan.as_ref() {
3734 if self.fit_result.is_some() || self.unified.is_some() {
3739 return Err(FittedModelError::SchemaMismatch {
3740 reason: "spline-scan model must not also carry a dense fit_result/unified \
3741 payload; the representations are mutually exclusive"
3742 .to_string(),
3743 });
3744 }
3745 if self.model_kind != ModelKind::Standard
3746 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3747 {
3748 return Err(FittedModelError::SchemaMismatch {
3749 reason: format!(
3750 "spline-scan representation requires a standard Gaussian-identity model; \
3751 got model_kind={:?}, likelihood={:?}",
3752 self.model_kind,
3753 self.family_state.likelihood()
3754 ),
3755 });
3756 }
3757 if scan.feature_column.is_empty() {
3758 return Err(FittedModelError::MissingField {
3759 reason: "spline-scan model is missing its feature column name; refit"
3760 .to_string(),
3761 });
3762 }
3763 gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3764 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3765 if self.data_schema.is_none() {
3771 return Err(FittedModelError::MissingField {
3772 reason: "spline-scan model is missing data_schema; refit".to_string(),
3773 });
3774 }
3775 if self.training_headers.is_none() {
3776 return Err(FittedModelError::MissingField {
3777 reason: "spline-scan model is missing training_headers; refit".to_string(),
3778 });
3779 }
3780 return Ok(());
3781 } else if let Some(cascade) = self.residual_cascade.as_ref() {
3782 if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3786 return Err(FittedModelError::SchemaMismatch {
3787 reason: "residual-cascade model must not also carry spline_scan / \
3788 fit_result / unified payloads; the representations are \
3789 mutually exclusive"
3790 .to_string(),
3791 });
3792 }
3793 if self.model_kind != ModelKind::Standard
3794 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3795 {
3796 return Err(FittedModelError::SchemaMismatch {
3797 reason: format!(
3798 "residual-cascade representation requires a standard Gaussian-identity \
3799 model; got model_kind={:?}, likelihood={:?}",
3800 self.model_kind,
3801 self.family_state.likelihood()
3802 ),
3803 });
3804 }
3805 if cascade.feature_columns.is_empty()
3806 || !(2..=3).contains(&cascade.feature_columns.len())
3807 {
3808 return Err(FittedModelError::MissingField {
3809 reason: format!(
3810 "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3811 cascade.feature_columns.len()
3812 ),
3813 });
3814 }
3815 gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3816 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3817 if self.data_schema.is_none() {
3818 return Err(FittedModelError::MissingField {
3819 reason: "residual-cascade model is missing data_schema; refit".to_string(),
3820 });
3821 }
3822 if self.training_headers.is_none() {
3823 return Err(FittedModelError::MissingField {
3824 reason: "residual-cascade model is missing training_headers; refit".to_string(),
3825 });
3826 }
3827 return Ok(());
3828 } else if self.fit_result.is_none() {
3829 return Err(FittedModelError::MissingField {
3830 reason: "model is missing canonical fit_result payload; refit".to_string(),
3831 });
3832 }
3833 if self.data_schema.is_none() {
3834 return Err(FittedModelError::MissingField {
3835 reason: "model is missing data_schema; refit".to_string(),
3836 });
3837 }
3838 if self.training_headers.is_none() {
3839 return Err(FittedModelError::MissingField {
3840 reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3841 .to_string(),
3842 });
3843 }
3844 let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3845 FittedModelError::MissingField {
3846 reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3847 .to_string(),
3848 }
3849 })?;
3850 validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3851
3852 if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3853 return Err(FittedModelError::MissingField {
3854 reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3855 .to_string(),
3856 });
3857 }
3858 if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3859 validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3860 }
3861 if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3862 let score = self.transformation_score_calibration.ok_or_else(|| {
3863 FittedModelError::MissingField {
3864 reason: "transformation-normal model is missing transformation_score_calibration; refit"
3865 .to_string(),
3866 }
3867 })?;
3868 score.validate("transformation-normal model")?;
3869 }
3870 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3871 if self.formula_logslope.is_none() {
3872 return Err(FittedModelError::MissingField {
3873 reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3874 });
3875 }
3876 if self.z_column.is_none() {
3877 return Err(FittedModelError::MissingField {
3878 reason: "marginal-slope model is missing z_column; refit".to_string(),
3879 });
3880 }
3881 let z_normalization =
3882 self.latent_z_normalization
3883 .ok_or_else(|| FittedModelError::MissingField {
3884 reason: "marginal-slope model is missing latent_z_normalization; refit"
3885 .to_string(),
3886 })?;
3887 z_normalization.validate("marginal-slope model")?;
3888 let latent_measure =
3889 self.latent_measure
3890 .as_ref()
3891 .ok_or_else(|| FittedModelError::MissingField {
3892 reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3893 })?;
3894 latent_measure
3895 .validate("marginal-slope model latent_measure")
3896 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3897 if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3898 return Err(FittedModelError::MissingField {
3899 reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3900 });
3901 }
3902 if self.resolved_termspec_logslope.as_ref().is_none() {
3903 return Err(FittedModelError::MissingField {
3904 reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3905 .to_string(),
3906 });
3907 }
3908 match self.family_state.frailty() {
3909 Some(FrailtySpec::None)
3910 | Some(FrailtySpec::GaussianShift {
3911 sigma_fixed: Some(_),
3912 }) => {}
3913 Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3914 return Err(FittedModelError::IncompatibleConfig {
3915 reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3916 .to_string(),
3917 });
3918 }
3919 Some(FrailtySpec::HazardMultiplier { .. }) => {
3920 return Err(FittedModelError::IncompatibleConfig {
3921 reason: "marginal-slope model does not support HazardMultiplier frailty"
3922 .to_string(),
3923 });
3924 }
3925 None => {
3926 return Err(FittedModelError::MissingField {
3927 reason: "marginal-slope model is missing family_state.frailty; refit"
3928 .to_string(),
3929 });
3930 }
3931 }
3932 }
3933
3934 if let FittedFamily::Survival {
3935 survival_likelihood,
3936 frailty,
3937 ..
3938 } = &self.family_state
3939 {
3940 if matches!(
3941 survival_likelihood.as_deref(),
3942 Some("latent") | Some("latent-binary")
3943 ) {
3944 return Err(FittedModelError::SchemaMismatch {
3945 reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3946 .to_string(),
3947 });
3948 }
3949 if survival_likelihood.as_deref() == Some("marginal-slope") {
3950 if self.formula_logslope.is_none() {
3951 return Err(FittedModelError::MissingField {
3952 reason: "survival marginal-slope model is missing formula_logslope; refit"
3953 .to_string(),
3954 });
3955 }
3956 if self.z_column.is_none() {
3957 return Err(FittedModelError::MissingField {
3958 reason: "survival marginal-slope model is missing z_column; refit"
3959 .to_string(),
3960 });
3961 }
3962 let z_normalization =
3963 self.latent_z_normalization
3964 .ok_or_else(|| {
3965 FittedModelError::MissingField {
3966 reason:
3967 "survival marginal-slope model is missing latent_z_normalization; refit"
3968 .to_string(),
3969 }
3970 })?;
3971 z_normalization.validate("survival marginal-slope model")?;
3972 let latent_measure =
3973 self.latent_measure
3974 .as_ref()
3975 .ok_or_else(|| FittedModelError::MissingField {
3976 reason:
3977 "survival marginal-slope model is missing latent_measure; refit"
3978 .to_string(),
3979 })?;
3980 latent_measure
3981 .validate("survival marginal-slope model latent_measure")
3982 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3983 if self.logslope_baseline.is_none() {
3984 return Err(FittedModelError::MissingField {
3985 reason: "survival marginal-slope model is missing logslope_baseline; refit"
3986 .to_string(),
3987 });
3988 }
3989 if self.resolved_termspec_logslope.as_ref().is_none() {
3990 return Err(FittedModelError::MissingField {
3991 reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3992 .to_string(),
3993 });
3994 }
3995 match frailty {
3996 FrailtySpec::None
3997 | FrailtySpec::GaussianShift {
3998 sigma_fixed: Some(_),
3999 } => {}
4000 FrailtySpec::GaussianShift { sigma_fixed: None } => {
4001 return Err(FittedModelError::IncompatibleConfig {
4002 reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4003 .to_string(),
4004 });
4005 }
4006 FrailtySpec::HazardMultiplier { .. } => {
4007 return Err(FittedModelError::IncompatibleConfig {
4008 reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4009 .to_string(),
4010 });
4011 }
4012 }
4013 } else if !matches!(frailty, FrailtySpec::None) {
4014 return Err(FittedModelError::IncompatibleConfig {
4015 reason:
4016 "non-marginal survival models do not currently persist a frailty modifier"
4017 .to_string(),
4018 });
4019 }
4020 if self.survival_time_basis.is_none() {
4028 return Err(FittedModelError::MissingField {
4029 reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4030 });
4031 }
4032 if self.survival_time_anchor.is_none() {
4033 return Err(FittedModelError::MissingField {
4034 reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4035 });
4036 }
4037 }
4038 if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4039 match frailty {
4040 FrailtySpec::HazardMultiplier {
4041 sigma_fixed: Some(_),
4042 ..
4043 } => {}
4044 FrailtySpec::HazardMultiplier {
4045 sigma_fixed: None, ..
4046 } => {
4047 return Err(FittedModelError::IncompatibleConfig {
4048 reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4049 .to_string(),
4050 });
4051 }
4052 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4053 return Err(FittedModelError::IncompatibleConfig {
4054 reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4055 .to_string(),
4056 });
4057 }
4058 }
4059 if self.survival_likelihood.as_deref() != Some("latent") {
4060 return Err(FittedModelError::SchemaMismatch {
4061 reason: "latent survival model must persist survival_likelihood=latent"
4062 .to_string(),
4063 });
4064 }
4065 }
4066 if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4067 match frailty {
4068 FrailtySpec::HazardMultiplier {
4069 sigma_fixed: Some(_),
4070 ..
4071 } => {}
4072 FrailtySpec::HazardMultiplier {
4073 sigma_fixed: None, ..
4074 } => {
4075 return Err(FittedModelError::IncompatibleConfig {
4076 reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4077 .to_string(),
4078 });
4079 }
4080 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4081 return Err(FittedModelError::IncompatibleConfig {
4082 reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4083 .to_string(),
4084 });
4085 }
4086 }
4087 if self.survival_likelihood.as_deref() != Some("latent-binary") {
4088 return Err(FittedModelError::SchemaMismatch {
4089 reason: "latent binary model must persist survival_likelihood=latent-binary"
4090 .to_string(),
4091 });
4092 }
4093 }
4094
4095 let family_likelihood = match &self.family_state {
4096 FittedFamily::Standard { likelihood, .. }
4097 | FittedFamily::LocationScale { likelihood, .. }
4098 | FittedFamily::MarginalSlope { likelihood, .. }
4099 | FittedFamily::Survival { likelihood, .. }
4100 | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4101 FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4102 };
4103 let is_standard_or_location_scale = matches!(
4104 self.family_state,
4105 FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4106 );
4107 if is_standard_or_location_scale
4108 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4109 {
4110 self.saved_sas_state()?;
4111 }
4112 if is_standard_or_location_scale
4113 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4114 {
4115 self.saved_beta_logistic_state()?;
4116 }
4117 if is_standard_or_location_scale
4118 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4119 {
4120 self.saved_mixture_state()?;
4121 }
4122 if matches!(self.family_state, FittedFamily::Standard { .. })
4123 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4124 {
4125 self.saved_latent_cloglog_state()?;
4126 }
4127 if matches!(self.family_state, FittedFamily::LocationScale { .. })
4128 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4129 {
4130 return Err(FittedModelError::IncompatibleConfig {
4131 reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4132 .to_string(),
4133 });
4134 }
4135 if matches!(self.family_state, FittedFamily::Survival { .. })
4136 && self.survival_likelihood.is_none()
4137 {
4138 return Err(FittedModelError::MissingField {
4139 reason: "saved survival model is missing survival_likelihood metadata; refit"
4140 .to_string(),
4141 });
4142 }
4143 let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4144 || self.linkwiggle_degree.is_some()
4145 || self.beta_link_wiggle.is_some()
4146 || self
4147 .fit_result
4148 .as_ref()
4149 .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4150 .is_some();
4151 let saved_link_wiggle = self.saved_link_wiggle()?;
4152 if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4153 return Err(FittedModelError::SchemaMismatch {
4154 reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4155 .to_string(),
4156 });
4157 }
4158 let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4159 || self.baseline_timewiggle_degree.is_some()
4160 || self.baseline_timewiggle_penalty_orders.is_some()
4161 || self.baseline_timewiggle_double_penalty.is_some()
4162 || self.beta_baseline_timewiggle.is_some()
4163 || self.beta_baseline_timewiggle_by_cause.is_some();
4164 let is_joint_cause_specific = self
4165 .survival_cause_count
4166 .is_some_and(|cause_count| cause_count > 1);
4167 if has_any_saved_baseline_time_wiggle {
4168 if is_joint_cause_specific {
4169 let complete = self.baseline_timewiggle_knots.is_some()
4170 && self.baseline_timewiggle_degree.is_some()
4171 && self.baseline_timewiggle_penalty_orders.is_some()
4172 && self.baseline_timewiggle_double_penalty.is_some()
4173 && self.beta_baseline_timewiggle_by_cause.is_some();
4174 if !complete {
4175 return Err(FittedModelError::SchemaMismatch {
4176 reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4177 .to_string(),
4178 });
4179 }
4180 } else if self.saved_baseline_time_wiggle()?.is_none() {
4181 return Err(FittedModelError::SchemaMismatch {
4182 reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4183 .to_string(),
4184 });
4185 }
4186 }
4187 if self
4188 .survival_likelihood
4189 .as_deref()
4190 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4191 {
4192 validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4193 }
4194
4195 if let Some(runtime) = self.score_warp_runtime.as_ref() {
4205 runtime.validate_exact_replay_contract().map_err(|err| {
4206 FittedModelError::PayloadCorrupt {
4207 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4208 }
4209 })?;
4210 }
4211 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4212 runtime.validate_exact_replay_contract().map_err(|err| {
4213 FittedModelError::PayloadCorrupt {
4214 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4215 }
4216 })?;
4217 }
4218 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4219 validate_marginal_slope_saved_fit(
4220 self.fit_result.as_ref().expect("checked above"),
4221 self.score_warp_runtime.as_ref(),
4222 self.link_deviation_runtime.as_ref(),
4223 "fit_result",
4224 )?;
4225 let unified = self
4226 .unified
4227 .as_ref()
4228 .ok_or_else(|| FittedModelError::MissingField {
4229 reason: "marginal-slope model is missing unified fit payload; refit"
4230 .to_string(),
4231 })?;
4232 validate_marginal_slope_saved_fit(
4233 unified,
4234 self.score_warp_runtime.as_ref(),
4235 self.link_deviation_runtime.as_ref(),
4236 "unified",
4237 )?;
4238 }
4239 if self
4240 .survival_likelihood
4241 .as_deref()
4242 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4243 {
4244 validate_survival_marginal_slope_saved_fit(
4245 self.fit_result.as_ref().expect("checked above"),
4246 self.score_warp_runtime.as_ref(),
4247 self.link_deviation_runtime.as_ref(),
4248 "fit_result",
4249 )?;
4250 if let Some(unified) = self.unified.as_ref() {
4251 validate_survival_marginal_slope_saved_fit(
4252 unified,
4253 self.score_warp_runtime.as_ref(),
4254 self.link_deviation_runtime.as_ref(),
4255 "unified",
4256 )?;
4257 }
4258 }
4259
4260 Ok(())
4271 }
4272
4273 pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4274 let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4275 if let Some(fit) = self.fit_result.as_ref() {
4276 fit.validate_numeric_finiteness()
4277 .map_err(|e| corrupt(e.to_string()))?;
4278 }
4279
4280 for (name, opt) in [
4281 ("survival_baseline_scale", self.survival_baseline_scale),
4282 ("survival_baseline_shape", self.survival_baseline_shape),
4283 ("survival_baseline_rate", self.survival_baseline_rate),
4284 ("survival_baseline_makeham", self.survival_baseline_makeham),
4285 (
4286 "survival_time_smooth_lambda",
4287 self.survival_time_smooth_lambda,
4288 ),
4289 ("survival_time_anchor", self.survival_time_anchor),
4290 ("survivalridge_lambda", self.survivalridge_lambda),
4291 ] {
4292 if let Some(v) = opt {
4293 ensure_finite_scalar(name, v).map_err(corrupt)?;
4294 }
4295 }
4296
4297 if let Some(v) = self.beta_noise.as_ref() {
4298 validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4299 }
4300 if let Some(v) = self.noise_projection.as_ref() {
4301 validate_all_finite("noise_projection", v.iter().flatten().copied())
4302 .map_err(corrupt)?;
4303 if self.noise_projection_ridge_alpha.is_none() {
4304 return Err(FittedModelError::MissingField {
4305 reason:
4306 "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4307 .to_string(),
4308 });
4309 }
4310 }
4311 if let Some(v) = self.noise_center.as_ref() {
4312 validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4313 }
4314 if let Some(v) = self.noise_scale.as_ref() {
4315 validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4316 }
4317 if let Some(v) = self.noise_projection_ridge_alpha {
4318 ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4319 if v < 0.0 {
4320 return Err(FittedModelError::InvalidInput {
4321 reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4322 });
4323 }
4324 }
4325 if let Some(v) = self.gaussian_response_scale {
4326 ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4327 }
4328 if let Some(v) = self.beta_link_wiggle.as_ref() {
4329 validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4330 }
4331 if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4332 validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4333 }
4334 if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4335 validate_all_finite(
4336 "beta_baseline_timewiggle_by_cause",
4337 v.iter().flatten().copied(),
4338 )
4339 .map_err(corrupt)?;
4340 }
4341 if let Some(v) = self.latent_z_normalization {
4342 v.validate("latent_z_normalization")?;
4343 }
4344 if let Some(v) = self.latent_measure.as_ref() {
4345 v.validate("latent_measure").map_err(corrupt)?;
4346 }
4347 if let Some(v) = self.survival_beta_time.as_ref() {
4348 validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4349 }
4350 if let Some(v) = self.survival_beta_threshold.as_ref() {
4351 validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4352 }
4353 if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4354 validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4355 }
4356 if let Some(v) = self.survival_noise_projection.as_ref() {
4357 validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4358 .map_err(corrupt)?;
4359 if self.survival_noise_projection_ridge_alpha.is_none() {
4360 return Err(FittedModelError::MissingField {
4361 reason:
4362 "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4363 .to_string(),
4364 });
4365 }
4366 }
4367 if let Some(v) = self.survival_noise_center.as_ref() {
4368 validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4369 }
4370 if let Some(v) = self.survival_noise_projection_ridge_alpha {
4371 ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4372 if v < 0.0 {
4373 return Err(FittedModelError::InvalidInput {
4374 reason: format!(
4375 "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4376 ),
4377 });
4378 }
4379 }
4380 if let Some(v) = self.survival_noise_scale.as_ref() {
4381 validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4382 }
4383 if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4384 validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4385 .map_err(corrupt)?;
4386 }
4387 if let Some(v) = self.sas_param_covariance.as_ref() {
4388 validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4389 .map_err(corrupt)?;
4390 }
4391 Ok(())
4392 }
4393}
4394
4395fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4396 a.rows().into_iter().map(|row| row.to_vec()).collect()
4397}
4398
4399use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4400
4401fn validate_frozen_term_collectionspec(
4402 spec: &TermCollectionSpec,
4403 label: &str,
4404) -> Result<(), FittedModelError> {
4405 spec.validate_frozen(label)
4406 .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4407}
4408
4409impl Deref for FittedModel {
4410 type Target = FittedModelPayload;
4411
4412 fn deref(&self) -> &Self::Target {
4413 self.payload()
4414 }
4415}
4416
4417impl DerefMut for FittedModel {
4418 fn deref_mut(&mut self) -> &mut Self::Target {
4419 self.payload_mut()
4420 }
4421}
4422
4423pub fn survival_baseline_config_from_model(
4428 model: &FittedModel,
4429) -> Result<SurvivalBaselineConfig, FittedModelError> {
4430 let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4431 FittedModelError::MissingField {
4432 reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4433 }
4434 })?;
4435 parse_survival_baseline_config(
4436 target,
4437 model.survival_baseline_scale,
4438 model.survival_baseline_shape,
4439 model.survival_baseline_rate,
4440 model.survival_baseline_makeham,
4441 )
4442 .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4443}
4444
4445pub fn load_survival_time_basis_config_from_model(
4446 model: &FittedModel,
4447) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4448 match model
4449 .survival_time_basis
4450 .as_deref()
4451 .ok_or_else(|| FittedModelError::MissingField {
4452 reason: "saved survival model missing survival_time_basis".to_string(),
4453 })?
4454 .to_ascii_lowercase()
4455 .as_str()
4456 {
4457 "none" => Ok(SurvivalTimeBasisConfig::None),
4458 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4459 "bspline" => {
4460 let degree =
4461 model
4462 .survival_time_degree
4463 .ok_or_else(|| FittedModelError::MissingField {
4464 reason: "saved survival bspline model missing survival_time_degree"
4465 .to_string(),
4466 })?;
4467 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4468 FittedModelError::MissingField {
4469 reason: "saved survival bspline model missing survival_time_knots".to_string(),
4470 }
4471 })?;
4472 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4473 if degree < 1 || knots.is_empty() {
4474 return Err(FittedModelError::SchemaMismatch {
4475 reason: "saved survival bspline time basis metadata is invalid".to_string(),
4476 });
4477 }
4478 Ok(SurvivalTimeBasisConfig::BSpline {
4479 degree,
4480 knots: Array1::from_vec(knots),
4481 smooth_lambda,
4482 })
4483 }
4484 "ispline" => {
4485 let degree =
4486 model
4487 .survival_time_degree
4488 .ok_or_else(|| FittedModelError::MissingField {
4489 reason: "saved survival ispline model missing survival_time_degree"
4490 .to_string(),
4491 })?;
4492 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4493 FittedModelError::MissingField {
4494 reason: "saved survival ispline model missing survival_time_knots".to_string(),
4495 }
4496 })?;
4497 let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4498 FittedModelError::MissingField {
4499 reason: "saved survival ispline model missing survival_time_keep_cols"
4500 .to_string(),
4501 }
4502 })?;
4503 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4504 if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4505 return Err(FittedModelError::SchemaMismatch {
4506 reason: "saved survival ispline time basis metadata is invalid".to_string(),
4507 });
4508 }
4509 Ok(SurvivalTimeBasisConfig::ISpline {
4510 degree,
4511 knots: Array1::from_vec(knots),
4512 keep_cols,
4513 smooth_lambda,
4514 })
4515 }
4516 other => Err(FittedModelError::IncompatibleConfig {
4517 reason: format!("unsupported saved survival_time_basis '{other}'"),
4518 }),
4519 }
4520}
4521
4522#[cfg(test)]
4523mod tests {
4524 use super::*;
4525 use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4526 use crate::survival::lognormal_kernel::FrailtySpec;
4527 use gam_solve::pirls::PirlsStatus;
4528 use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4529 use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4530 use gam_data::SchemaColumn;
4531 use ndarray::{Array1, Array2, array};
4532
4533 fn empty_termspec() -> TermCollectionSpec {
4534 TermCollectionSpec {
4535 linear_terms: vec![],
4536 random_effect_terms: vec![],
4537 smooth_terms: vec![],
4538 }
4539 }
4540
4541 #[test]
4545 fn spline_scan_payload_round_trips_and_validates() {
4546 let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4547 let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4548 let w = vec![1.0_f64; x.len()];
4549 let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4550 let make_payload = || {
4551 crate::inference::model_payload_builders::assemble_spline_scan_payload(
4552 "y ~ s(x)".to_string(),
4553 "x".to_string(),
4554 &fit,
4555 DataSchema {
4556 columns: vec![
4557 SchemaColumn {
4558 name: "y".to_string(),
4559 kind: ColumnKindTag::Continuous,
4560 levels: vec![],
4561 },
4562 SchemaColumn {
4563 name: "x".to_string(),
4564 kind: ColumnKindTag::Continuous,
4565 levels: vec![],
4566 },
4567 ],
4568 },
4569 vec!["x".to_string()],
4570 vec![(0.0, 1.0)],
4571 )
4572 };
4573 let model = FittedModel::from_payload(make_payload());
4576 model
4577 .validate_for_persistence()
4578 .expect("scan model validates");
4579 model
4580 .validate_numeric_finiteness()
4581 .expect("scan model is finite");
4582
4583 let json = serde_json::to_string(&model).expect("serialize model");
4584 let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4585 restored
4586 .validate_for_persistence()
4587 .expect("restored scan model validates");
4588 let (column, replay) = restored
4589 .saved_spline_scan()
4590 .expect("restore scan fit")
4591 .expect("payload carries the scan representation");
4592 assert_eq!(column, "x");
4593 for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4594 let (m0, v0) = fit.predict(xq).expect("predict original");
4595 let (m1, v1) = replay.predict(xq).expect("predict replayed");
4596 assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4597 assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4598 }
4599
4600 let mut dense = make_payload();
4602 dense.spline_scan = None;
4603 let err = FittedModel::from_payload(dense)
4604 .validate_for_persistence()
4605 .expect_err("dense payload without fit_result must be rejected");
4606 assert!(err.to_string().contains("fit_result"));
4607
4608 let mut corrupt = make_payload();
4610 corrupt
4611 .spline_scan
4612 .as_mut()
4613 .expect("scan channel present")
4614 .state
4615 .knots
4616 .truncate(2);
4617 FittedModel::from_payload(corrupt)
4618 .validate_for_persistence()
4619 .expect_err("corrupt scan state must be rejected");
4620 let mut unnamed = make_payload();
4621 unnamed
4622 .spline_scan
4623 .as_mut()
4624 .expect("scan channel present")
4625 .feature_column
4626 .clear();
4627 FittedModel::from_payload(unnamed)
4628 .validate_for_persistence()
4629 .expect_err("missing feature column must be rejected");
4630 }
4631
4632 fn standard_gaussian_payload() -> FittedModelPayload {
4633 FittedModelPayload::new(
4634 MODEL_PAYLOAD_VERSION,
4635 "y ~ 1".to_string(),
4636 ModelKind::Standard,
4637 FittedFamily::Standard {
4638 likelihood: LikelihoodSpec::gaussian_identity(),
4639 link: Some(StandardLink::Identity),
4640 latent_cloglog_state: None,
4641 mixture_state: None,
4642 sas_state: None,
4643 },
4644 "gaussian".to_string(),
4645 )
4646 }
4647
4648 fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4649 SavedCompiledFlexBlock {
4650 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4651 breakpoints: vec![-1.0, 1.0],
4652 basis_dim,
4653 span_c0: vec![vec![0.0; basis_dim]],
4654 span_c1: vec![vec![0.0; basis_dim]],
4655 span_c2: vec![vec![0.0; basis_dim]],
4656 span_c3: vec![vec![0.0; basis_dim]],
4657 anchor_correction: None,
4658 anchor_components: Vec::new(),
4659 }
4660 }
4661
4662 fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4663 let beta = Array1::from_vec(
4664 blocks
4665 .iter()
4666 .flat_map(|block| block.beta.iter().copied())
4667 .collect(),
4668 );
4669 let p = beta.len();
4670 UnifiedFitResult {
4671 blocks,
4672 log_lambdas: Array1::zeros(0),
4673 lambdas: Array1::zeros(0),
4674 likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4675 likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4676 log_likelihood_normalization: LogLikelihoodNormalization::Full,
4677 log_likelihood: 0.0,
4678 deviance: 0.0,
4679 reml_score: 0.0,
4680 stable_penalty_term: 0.0,
4681 penalized_objective: 0.0,
4682 used_device: false,
4683 outer_iterations: 0,
4684 outer_cost_evals: 0,
4685 inner_pirls_solves: 0,
4686 outer_converged: true,
4687 outer_gradient_norm: None,
4688 standard_deviation: 1.0,
4689 covariance_conditional: Some(Array2::zeros((p, p))),
4690 covariance_corrected: Some(Array2::zeros((p, p))),
4691 inference: None,
4692 fitted_link: FittedLinkState::Standard(None),
4693 geometry: None,
4694 block_states: vec![],
4695 beta,
4696 pirls_status: PirlsStatus::Converged,
4697 max_abs_eta: 0.0,
4698 constraint_kkt: None,
4699 artifacts: FitArtifacts {
4700 pirls: None,
4701 null_space_logdet: None,
4702 null_space_dim: None,
4703 survival_link_wiggle_knots: None,
4704 survival_link_wiggle_degree: None,
4705 criterion_certificate: None,
4706 rho_posterior_certificate: None,
4707 rho_posterior_escalation: None,
4708 rho_covariance: None,
4709 },
4710 inner_cycles: 0,
4711 }
4712 }
4713
4714 fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4715 let mut payload = FittedModelPayload::new(
4716 version,
4717 "y ~ 1".to_string(),
4718 ModelKind::MarginalSlope,
4719 FittedFamily::MarginalSlope {
4720 likelihood: LikelihoodSpec::binomial_probit(),
4721 base_link: InverseLink::Standard(StandardLink::Probit),
4722 frailty: FrailtySpec::None,
4723 },
4724 "bernoulli-marginal-slope".to_string(),
4725 );
4726 payload.fit_result = Some(fit.clone());
4727 payload.unified = Some(fit);
4728 payload.data_schema = Some(DataSchema {
4729 columns: vec![SchemaColumn {
4730 name: "z".to_string(),
4731 kind: ColumnKindTag::Continuous,
4732 levels: vec![],
4733 }],
4734 });
4735 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4736 payload.resolved_termspec = Some(empty_termspec());
4737 payload.resolved_termspec_logslope = Some(empty_termspec());
4738 payload.formula_logslope = Some("1".to_string());
4739 payload.z_column = Some("z".to_string());
4740 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4741 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4742 payload.marginal_baseline = Some(0.0);
4743 payload.logslope_baseline = Some(0.0);
4744 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4745 payload
4746 }
4747
4748 #[test]
4749 fn from_payload_synchronizes_used_device_from_saved_fit() {
4750 let mut fit = saved_fit(vec![
4751 FittedBlock {
4752 beta: Array1::from_vec(vec![0.25]),
4753 role: BlockRole::Mean,
4754 edf: 1.0,
4755 lambdas: Array1::zeros(0),
4756 },
4757 FittedBlock {
4758 beta: Array1::from_vec(vec![0.5]),
4759 role: BlockRole::Scale,
4760 edf: 1.0,
4761 lambdas: Array1::zeros(0),
4762 },
4763 ]);
4764 fit.used_device = true;
4765 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4766 payload.used_device = false;
4767
4768 let model = FittedModel::from_payload(payload);
4769
4770 assert!(model.payload().used_device);
4771 }
4772
4773 fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4774 let mut payload = FittedModelPayload::new(
4775 version,
4776 "Surv(entry, exit, event) ~ 1".to_string(),
4777 ModelKind::Survival,
4778 FittedFamily::Survival {
4779 likelihood: LikelihoodSpec::royston_parmar(),
4780 survival_likelihood: Some("marginal-slope".to_string()),
4781 survival_distribution: Some(ResidualDistribution::Gaussian),
4782 frailty: FrailtySpec::None,
4783 },
4784 "survival".to_string(),
4785 );
4786 payload.fit_result = Some(fit.clone());
4787 payload.unified = Some(fit);
4788 payload.survival_likelihood = Some("marginal-slope".to_string());
4789 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4790 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4791 payload.data_schema = Some(DataSchema {
4792 columns: vec![SchemaColumn {
4793 name: "z".to_string(),
4794 kind: ColumnKindTag::Continuous,
4795 levels: vec![],
4796 }],
4797 });
4798 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4799 payload.resolved_termspec = Some(empty_termspec());
4800 payload.resolved_termspec_logslope = Some(empty_termspec());
4801 payload.formula_logslope = Some("1".to_string());
4802 payload.z_column = Some("z".to_string());
4803 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4804 payload.logslope_baseline = Some(0.0);
4805 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4806 payload
4807 }
4808
4809 fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4810 let mut payload = FittedModelPayload::new(
4816 MODEL_PAYLOAD_VERSION,
4817 "y ~ x".to_string(),
4818 ModelKind::LocationScale,
4819 FittedFamily::LocationScale {
4820 likelihood: LikelihoodSpec::gamma_log(),
4821 base_link: Some(InverseLink::Standard(StandardLink::Log)),
4822 },
4823 "gamma-location-scale".to_string(),
4824 );
4825 payload.data_schema = Some(DataSchema {
4826 columns: vec![
4827 SchemaColumn {
4828 name: "y".to_string(),
4829 kind: ColumnKindTag::Continuous,
4830 levels: vec![],
4831 },
4832 SchemaColumn {
4833 name: "x".to_string(),
4834 kind: ColumnKindTag::Continuous,
4835 levels: vec![],
4836 },
4837 ],
4838 });
4839 payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4840 payload.resolved_termspec = Some(empty_termspec());
4841 payload.resolved_termspec_noise = Some(empty_termspec());
4842 payload.formula_noise = Some("x".to_string());
4843 payload.beta_noise = Some(vec![0.0]);
4844 payload.link = Some(InverseLink::Standard(StandardLink::Log));
4845 payload
4846 }
4847
4848 #[test]
4855 fn dispersion_location_scale_payload_is_not_classified_binomial() {
4856 let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4857 assert_eq!(
4858 model.predict_model_class(),
4859 PredictModelClass::DispersionLocationScale,
4860 "Gamma dispersion location-scale must route through the dispersion \
4861 predictor, not the binomial threshold-scale class",
4862 );
4863 assert!(
4864 !matches!(
4865 model.predict_model_class(),
4866 PredictModelClass::BinomialLocationScale
4867 ),
4868 "dispersion location-scale must never be classified as binomial",
4869 );
4870
4871 for likelihood in [
4873 LikelihoodSpec::gamma_log(),
4874 LikelihoodSpec::new(
4875 ResponseFamily::NegativeBinomial {
4876 theta: 1.0,
4877 theta_fixed: false,
4878 },
4879 InverseLink::Standard(StandardLink::Log),
4880 ),
4881 LikelihoodSpec::new(
4882 ResponseFamily::Beta { phi: 1.0 },
4883 InverseLink::Standard(StandardLink::Logit),
4884 ),
4885 LikelihoodSpec::new(
4886 ResponseFamily::Tweedie { p: 1.5 },
4887 InverseLink::Standard(StandardLink::Log),
4888 ),
4889 ] {
4890 let mut payload = gamma_dispersion_location_scale_payload();
4891 payload.family_state = FittedFamily::LocationScale {
4892 base_link: Some(likelihood.link.clone()),
4893 likelihood: likelihood.clone(),
4894 };
4895 let model = FittedModel::from_payload(payload);
4896 assert_eq!(
4897 model.predict_model_class(),
4898 PredictModelClass::DispersionLocationScale,
4899 "dispersion family {:?} mis-classified",
4900 likelihood.response,
4901 );
4902 }
4903 }
4904
4905 #[test]
4906 fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4907 let data = array![[100.0], [-100.0]];
4908 let col_map = HashMap::from([("g".to_string(), 0usize)]);
4909
4910 let mut plain_payload = standard_gaussian_payload();
4911 plain_payload.data_schema = Some(DataSchema {
4912 columns: vec![SchemaColumn {
4913 name: "g".to_string(),
4914 kind: ColumnKindTag::Continuous,
4915 levels: vec![],
4916 }],
4917 });
4918 plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4919 plain_payload.resolved_termspec = Some(empty_termspec());
4920 let plain = FittedModel::from_payload(plain_payload.clone());
4921 let clipped = plain
4922 .axis_clip_to_training_ranges(data.view(), &col_map)
4923 .expect("ordinary continuous axis should clip outside the training range");
4924 assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4925
4926 let mut group_payload = plain_payload;
4927 let mut group_spec = empty_termspec();
4928 group_spec
4929 .random_effect_terms
4930 .push(gam_terms::smooth::RandomEffectTermSpec {
4931 name: "g".to_string(),
4932 feature_col: 0,
4933 drop_first_level: false,
4934 penalized: true,
4935 frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4936 });
4937 group_payload.resolved_termspec = Some(group_spec);
4938 let group_model = FittedModel::from_payload(group_payload);
4939
4940 assert_eq!(
4941 group_model.random_effect_group_columns(),
4942 HashSet::from(["g".to_string()])
4943 );
4944
4945 assert_eq!(
4946 group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4947 None,
4948 "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4949 );
4950 }
4951
4952 #[test]
4953 fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4954 let fit = saved_fit(vec![
4955 FittedBlock {
4956 beta: array![0.1],
4957 role: BlockRole::Mean,
4958 edf: 1.0,
4959 lambdas: Array1::zeros(0),
4960 },
4961 FittedBlock {
4962 beta: array![0.2],
4963 role: BlockRole::Scale,
4964 edf: 1.0,
4965 lambdas: Array1::zeros(0),
4966 },
4967 FittedBlock {
4968 beta: array![0.3],
4969 role: BlockRole::Mean,
4970 edf: 1.0,
4971 lambdas: Array1::zeros(0),
4972 },
4973 ]);
4974 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4975 payload.score_warp_runtime = Some(anchored_runtime(2));
4976
4977 let err = FittedModel::from_payload(payload)
4978 .validate_for_persistence()
4979 .expect_err("marginal-slope score-warp basis mismatch should fail validation");
4980 assert!(err.to_string().contains("score-warp coefficient mismatch"));
4981 }
4982
4983 #[test]
4984 fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
4985 let fit = saved_fit(vec![
4986 FittedBlock {
4987 beta: array![0.1],
4988 role: BlockRole::Time,
4989 edf: 1.0,
4990 lambdas: Array1::zeros(0),
4991 },
4992 FittedBlock {
4993 beta: array![0.2],
4994 role: BlockRole::Mean,
4995 edf: 1.0,
4996 lambdas: Array1::zeros(0),
4997 },
4998 FittedBlock {
4999 beta: array![0.3],
5000 role: BlockRole::Scale,
5001 edf: 1.0,
5002 lambdas: Array1::zeros(0),
5003 },
5004 FittedBlock {
5005 beta: array![0.4],
5006 role: BlockRole::LinkWiggle,
5007 edf: 1.0,
5008 lambdas: Array1::zeros(0),
5009 },
5010 ]);
5011 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5012 payload.link_deviation_runtime = Some(anchored_runtime(2));
5013
5014 let err = FittedModel::from_payload(payload)
5015 .saved_prediction_runtime()
5016 .expect_err(
5017 "survival marginal-slope link basis mismatch should fail runtime validation",
5018 );
5019 assert!(
5020 err.to_string()
5021 .contains("link-deviation coefficient mismatch")
5022 );
5023 }
5024
5025 #[test]
5026 fn apply_survival_time_basis_writes_all_required_fields() {
5027 use crate::survival::construction::SavedSurvivalTimeBasis;
5028
5029 let fit = saved_fit(vec![
5030 FittedBlock {
5031 beta: array![0.1],
5032 role: BlockRole::Time,
5033 edf: 1.0,
5034 lambdas: Array1::zeros(0),
5035 },
5036 FittedBlock {
5037 beta: array![0.2],
5038 role: BlockRole::Mean,
5039 edf: 1.0,
5040 lambdas: Array1::zeros(0),
5041 },
5042 FittedBlock {
5043 beta: array![0.3],
5044 role: BlockRole::Scale,
5045 edf: 1.0,
5046 lambdas: Array1::zeros(0),
5047 },
5048 ]);
5049 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5050
5051 let snapshot = SavedSurvivalTimeBasis {
5056 basisname: "royston-parmar".to_string(),
5057 degree: Some(3),
5058 knots: Some(vec![0.0, 1.0, 2.0]),
5059 keep_cols: Some(vec![0, 2]),
5060 smooth_lambda: Some(0.5),
5061 anchor: 0.25,
5062 };
5063 payload.apply_survival_time_basis(&snapshot);
5064
5065 assert_eq!(
5066 payload.survival_time_basis.as_deref(),
5067 Some("royston-parmar")
5068 );
5069 assert_eq!(payload.survival_time_degree, Some(3));
5070 assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5071 assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5072 assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5073 assert_eq!(payload.survival_time_anchor, Some(0.25));
5074 }
5075
5076 #[test]
5077 fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5078 let fit = saved_fit(vec![
5079 FittedBlock {
5080 beta: array![0.1],
5081 role: BlockRole::Time,
5082 edf: 1.0,
5083 lambdas: Array1::zeros(0),
5084 },
5085 FittedBlock {
5086 beta: array![0.2],
5087 role: BlockRole::Mean,
5088 edf: 1.0,
5089 lambdas: Array1::zeros(0),
5090 },
5091 FittedBlock {
5092 beta: array![0.3],
5093 role: BlockRole::Scale,
5094 edf: 1.0,
5095 lambdas: Array1::zeros(0),
5096 },
5097 ]);
5098 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5099 payload.survival_time_basis = Some("ispline".to_string());
5105
5106 let err = FittedModel::from_payload(payload)
5107 .validate_for_persistence()
5108 .expect_err("survival model without time-anchor metadata should fail validation");
5109 assert!(err.to_string().contains("missing survival_time_anchor"));
5110 }
5111
5112 #[test]
5113 fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5114 let fit = saved_fit(vec![
5115 FittedBlock {
5116 beta: array![0.1],
5117 role: BlockRole::Time,
5118 edf: 1.0,
5119 lambdas: Array1::zeros(0),
5120 },
5121 FittedBlock {
5122 beta: array![0.2],
5123 role: BlockRole::Mean,
5124 edf: 1.0,
5125 lambdas: Array1::zeros(0),
5126 },
5127 FittedBlock {
5128 beta: array![0.3],
5129 role: BlockRole::Scale,
5130 edf: 1.0,
5131 lambdas: Array1::zeros(0),
5132 },
5133 ]);
5134 let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5135
5136 let err = FittedModel::from_payload(payload)
5137 .validate_for_persistence()
5138 .expect_err("survival model without time-basis metadata should fail validation");
5139 assert!(err.to_string().contains("missing survival_time_basis"));
5140 }
5141
5142 #[test]
5143 fn saved_prediction_runtime_rejects_stale_payload_version() {
5144 let fit = saved_fit(vec![
5145 FittedBlock {
5146 beta: array![0.1],
5147 role: BlockRole::Mean,
5148 edf: 1.0,
5149 lambdas: Array1::zeros(0),
5150 },
5151 FittedBlock {
5152 beta: array![0.2],
5153 role: BlockRole::Scale,
5154 edf: 1.0,
5155 lambdas: Array1::zeros(0),
5156 },
5157 ]);
5158 let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5159
5160 let err = FittedModel::from_payload(payload)
5161 .saved_prediction_runtime()
5162 .expect_err("stale payload version should fail before runtime assembly");
5163 assert!(err.to_string().contains("payload schema mismatch"));
5164 }
5165}