1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4 LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7 SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12 monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15 inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16 parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21 InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22 SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Schema version this binary expects in a saved payload's `version` field.
///
/// `FittedModelPayload::validate_payload_version` rejects any payload whose
/// `version` is not exactly equal to this value.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;
59
/// Ordered map from string keys to arbitrary JSON values; the type of
/// `FittedModelPayload::group_metadata`.
pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
/// Serializable pairing of a feature column name with a spline-scan state.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedSplineScan {
    /// Name of the feature column this scan refers to.
    pub feature_column: String,
    /// State object owned by `gam_solve::spline_scan`; its contents are
    /// defined there, not in this file.
    pub state: gam_solve::spline_scan::SplineScanState,
}
76
/// Serializable pairing of feature column names with a residual-cascade state.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedResidualCascade {
    /// Names of the feature columns this cascade refers to.
    pub feature_columns: Vec<String>,
    /// State object owned by `gam_solve::residual_cascade`; its contents are
    /// defined there, not in this file.
    pub state: gam_solve::residual_cascade::ResidualCascadeState,
}
87
/// Error type for saved-model payload handling in this module.
///
/// Every variant carries a free-form `reason` string. Conversions into
/// `EstimationError` and `SurvivalPredictError` are provided by the `From`
/// impls in this file.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
    /// Used in this file for version, shape, index and coefficient
    /// disagreements between saved pieces.
    SchemaMismatch { reason: String },
    /// Used in this file for non-finite / out-of-range latent-z normalization.
    PayloadCorrupt { reason: String },
    /// Used in this file when a required saved field or fit block is absent.
    MissingField { reason: String },
    /// Used in this file for unsupported kinds and out-of-range configuration.
    IncompatibleConfig { reason: String },
    /// NOTE(review): not constructed in the visible part of this file;
    /// presumably for caller-supplied input problems — confirm at use sites.
    InvalidInput { reason: String },
}
118
// Project macro (defined outside this file) applied to every
// `FittedModelError` variant.
// NOTE(review): presumably generates the `Display`/`Error` boilerplate for
// `reason`-carrying variants — `err.to_string()` is called on this type in
// this file — confirm against the macro definition.
impl_reason_error_boilerplate! {
    FittedModelError {
        SchemaMismatch,
        PayloadCorrupt,
        MissingField,
        IncompatibleConfig,
        InvalidInput,
    }
}
128
129impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134 fn from(err: FittedModelError) -> Self {
135 gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136 }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140 fn from(err: FittedModelError) -> Self {
141 crate::survival::predict::SurvivalPredictError::ModelPayload {
142 context: "saved-model survival prediction payload",
143 source: err,
144 }
145 }
146}
147
/// Saved affine normalization for a latent `z` variable.
///
/// `apply` maps each value to `(z - mean) / sd`; `validate` requires a finite
/// `mean` and a finite `sd` strictly greater than `1e-12`.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct SavedLatentZNormalization {
    /// Value subtracted from each `z` entry.
    pub mean: f64,
    /// Divisor applied after centering.
    pub sd: f64,
}
153
154impl SavedLatentZNormalization {
155 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156 if !self.mean.is_finite() {
157 return Err(FittedModelError::PayloadCorrupt {
158 reason: format!("{context} latent z mean must be finite"),
159 });
160 }
161 if !(self.sd.is_finite() && self.sd > 1e-12) {
162 return Err(FittedModelError::PayloadCorrupt {
163 reason: format!(
164 "{context} latent z sd must be finite and > 1e-12; got {}",
165 self.sd
166 ),
167 });
168 }
169 Ok::<(), _>(())
170 }
171
172 pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173 self.validate(context)?;
174 if z.iter().any(|value| !value.is_finite()) {
175 return Err(FittedModelError::PayloadCorrupt {
176 reason: format!("{context} requires finite z values"),
177 });
178 }
179 Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180 }
181}
182
/// Default PIT clipping epsilon: used by
/// `TransformationScoreCalibration::finite_support_pit` and as the serde
/// default for `TransformationScoreCalibration::clip_eps`.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
/// Score semantics recorded in a `TransformationScoreCalibration`.
///
/// Serialized with kebab-case variant names. `FiniteSupportPit` is the only
/// variant and also the `Default`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
#[derive(Default)]
pub enum TransformationScoreKind {
    #[default]
    FiniteSupportPit,
}
192
/// Saved calibration settings for the transformation score.
///
/// Both fields are optional on deserialization and fall back to defaults.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TransformationScoreCalibration {
    /// Score semantics; defaults to `TransformationScoreKind::default()`.
    #[serde(default)]
    pub score_kind: TransformationScoreKind,
    /// PIT clipping epsilon; defaults to `TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
    /// `validate` requires it to lie strictly inside `(0, 0.5)`.
    #[serde(default = "default_transformation_score_pit_clip_eps")]
    pub clip_eps: f64,
}
200
/// Serde default provider for `TransformationScoreCalibration::clip_eps`;
/// returns `TRANSFORMATION_SCORE_PIT_CLIP_EPS`.
const fn default_transformation_score_pit_clip_eps() -> f64 {
    TRANSFORMATION_SCORE_PIT_CLIP_EPS
}
204
205impl TransformationScoreCalibration {
206 pub fn finite_support_pit() -> Self {
207 Self {
208 score_kind: TransformationScoreKind::FiniteSupportPit,
209 clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210 }
211 }
212
213 pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214 if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215 return Err(FittedModelError::IncompatibleConfig {
216 reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217 });
218 }
219 if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220 return Err(FittedModelError::IncompatibleConfig {
221 reason: format!(
222 "{context} requires PIT clip_eps in (0, 0.5), got {}",
223 self.clip_eps
224 ),
225 });
226 }
227 Ok::<(), _>(())
228 }
229}
230
/// Serializable payload of a fitted model.
///
/// `version`, `formula`, `model_kind`, `family_state` and `family` are set by
/// `FittedModelPayload::new`; every other field starts out `None` / empty
/// there. Apart from those five fields and `link`, every field is marked
/// `#[serde(default)]`, so it may be absent from a serialized payload.
///
/// The per-field meaning of most entries is defined by the code that writes
/// and reads them elsewhere; the section comments below only group fields by
/// name.
#[derive(Clone, Serialize, Deserialize)]
pub struct FittedModelPayload {
    /// Payload schema version; compared against `MODEL_PAYLOAD_VERSION` by
    /// `validate_payload_version`.
    pub version: u32,
    pub formula: String,
    pub model_kind: ModelKind,
    pub family_state: FittedFamily,
    pub family: String,
    #[serde(default)]
    pub inference_notes: Vec<String>,
    #[serde(default)]
    pub used_device: bool,
    /// Saved fit; `validate_survival_location_scale_saved_fit` in this file
    /// refers to it as the canonical fit_result payload and requires it.
    #[serde(default)]
    pub fit_result: Option<UnifiedFitResult>,
    #[serde(default)]
    pub unified: Option<UnifiedFitResult>,
    #[serde(default)]
    pub spline_scan: Option<SavedSplineScan>,
    #[serde(default)]
    pub residual_cascade: Option<SavedResidualCascade>,
    #[serde(default)]
    pub data_schema: Option<DataSchema>,
    pub link: Option<InverseLink>,
    #[serde(default)]
    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub sas_param_covariance: Option<Vec<Vec<f64>>>,

    // --- Auxiliary formulas and offsets ---
    #[serde(default)]
    pub formula_noise: Option<String>,
    #[serde(default)]
    pub formula_logslope: Option<String>,
    #[serde(default)]
    pub formula_logslopes: Option<Vec<String>>,
    #[serde(default)]
    pub offset_column: Option<String>,
    #[serde(default)]
    pub noise_offset_column: Option<String>,

    // --- Noise block ---
    #[serde(default)]
    pub beta_noise: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub gaussian_response_scale: Option<f64>,

    // --- Link wiggle and baseline time wiggle ---
    #[serde(default)]
    pub linkwiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub linkwiggle_degree: Option<usize>,
    #[serde(default)]
    pub beta_link_wiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_degree: Option<usize>,
    #[serde(default)]
    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    pub baseline_timewiggle_double_penalty: Option<bool>,
    #[serde(default)]
    pub beta_baseline_timewiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,

    // --- Latent z / marginal-slope state ---
    #[serde(default)]
    pub z_column: Option<String>,
    #[serde(default)]
    pub z_columns: Option<Vec<String>>,
    #[serde(default)]
    pub latent_z_normalization: Option<SavedLatentZNormalization>,
    #[serde(default)]
    pub latent_score_contract: Option<SavedLatentScoreContract>,
    #[serde(default)]
    pub latent_measure: Option<LatentMeasureKind>,
    #[serde(default)]
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    #[serde(default)]
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    #[serde(default)]
    pub marginal_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baselines: Option<Vec<f64>>,
    #[serde(default)]
    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub influence_absorber_width: Option<usize>,

    // --- Survival configuration and coefficients ---
    #[serde(default)]
    pub survival_entry: Option<String>,
    #[serde(default)]
    pub survival_exit: Option<String>,
    #[serde(default)]
    pub survival_event: Option<String>,
    #[serde(default)]
    pub survivalspec: Option<String>,
    #[serde(default)]
    pub survival_cause_count: Option<usize>,
    #[serde(default)]
    pub survival_endpoint_names: Option<Vec<String>>,
    #[serde(default)]
    pub survival_baseline_target: Option<String>,
    #[serde(default)]
    pub survival_baseline_scale: Option<f64>,
    #[serde(default)]
    pub survival_baseline_shape: Option<f64>,
    #[serde(default)]
    pub survival_baseline_rate: Option<f64>,
    #[serde(default)]
    pub survival_baseline_makeham: Option<f64>,
    // The six `survival_time_*` fields except `survival_time_smooth_lambda`'s
    // siblings are all written together by `apply_survival_time_basis`.
    #[serde(default)]
    pub survival_time_basis: Option<String>,
    #[serde(default)]
    pub survival_time_degree: Option<usize>,
    #[serde(default)]
    pub survival_time_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_time_keep_cols: Option<Vec<usize>>,
    #[serde(default)]
    pub survival_time_smooth_lambda: Option<f64>,
    #[serde(default)]
    pub survival_time_anchor: Option<f64>,
    #[serde(default)]
    pub survivalridge_lambda: Option<f64>,
    #[serde(default)]
    pub survival_likelihood: Option<String>,
    // When present, the three `survival_beta_*` vectors must equal the
    // Time / Threshold / Scale blocks of `fit_result`
    // (see `validate_survival_location_scale_saved_fit`).
    #[serde(default)]
    pub survival_beta_time: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_beta_threshold: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_beta_log_sigma: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub survival_noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_non_intercept_start: Option<usize>,
    #[serde(default)]
    pub survival_noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub survival_distribution: Option<ResidualDistribution>,

    // --- Training-table metadata ---
    /// Set together with `training_feature_ranges` by
    /// `set_training_feature_metadata`.
    #[serde(default)]
    pub training_headers: Option<Vec<String>>,
    /// Omitted from the serialized payload when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_table_kind: Option<String>,
    #[serde(default)]
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Omitted from the serialized payload when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub group_metadata: Option<GroupMetadata>,
    /// Extra design columns appended by `append_deployment_extension_columns`;
    /// omitted from the serialized payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub deployment_extensions: Vec<SavedDeploymentExtension>,

    // --- Transformation-normal state ---
    #[serde(default)]
    pub transformation_response_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub transformation_response_degree: Option<usize>,
    #[serde(default)]
    pub transformation_response_median: Option<f64>,
    #[serde(default)]
    pub transformation_score_calibration: Option<TransformationScoreCalibration>,

    // --- Resolved term specifications and diagnostics ---
    /// Required by `append_deployment_extension_columns` whenever
    /// `deployment_extensions` is non-empty.
    #[serde(default)]
    pub resolved_termspec: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_noise: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
    #[serde(default)]
    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    #[serde(default)]
    pub gaussian_jackknife_plus:
        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
    #[serde(default)]
    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
}
521
/// One extra design column appended after the base design at prediction time.
///
/// Consumed by `append_deployment_extension_columns`, which only accepts
/// `kind == "random-effect-level"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedDeploymentExtension {
    /// Identifier used in error messages.
    pub name: String,
    /// Extension kind tag; only `"random-effect-level"` is supported.
    pub kind: String,
    /// Name of the random-effect term (looked up in
    /// `resolved_termspec.random_effect_terms`) this extension belongs to.
    pub term: String,
    pub level: JsonValue,
    /// Bit pattern compared against `f64::to_bits` of the data cell to decide
    /// whether a row gets a `1.0` in this extension's column.
    pub level_bits: u64,
    /// Column index in the extended design; must continue the base design
    /// contiguously (append-only).
    pub coefficient_index: usize,
    pub coefficient_mean: f64,
    pub coefficient_variance: f64,
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<JsonValue>,
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prior: Option<JsonValue>,
}
537
538pub fn append_deployment_extension_columns(
555 model: &FittedModelPayload,
556 data: ndarray::ArrayView2<'_, f64>,
557 col_map: &HashMap<String, usize>,
558 training_headers: Option<&Vec<String>>,
559 base_design: Array2<f64>,
560) -> Result<Array2<f64>, FittedModelError> {
561 if model.deployment_extensions.is_empty() {
562 return Ok(base_design);
563 }
564 if base_design.nrows() != data.nrows() {
565 return Err(FittedModelError::SchemaMismatch {
566 reason: format!(
567 "deployment extension design row mismatch: base design has {} rows but data has {}",
568 base_design.nrows(),
569 data.nrows()
570 ),
571 });
572 }
573 let spec = model
574 .resolved_termspec
575 .as_ref()
576 .ok_or_else(|| FittedModelError::MissingField {
577 reason: "deployment extension prediction requires saved resolved_termspec; refit"
578 .to_string(),
579 })?;
580 let n = base_design.nrows();
581 let p_old = base_design.ncols();
582 let mut extensions: Vec<&SavedDeploymentExtension> =
583 model.deployment_extensions.iter().collect();
584 extensions.sort_by_key(|extension| extension.coefficient_index);
585 for (tail_idx, extension) in extensions.iter().enumerate() {
586 let expected = p_old + tail_idx;
587 if extension.coefficient_index != expected {
588 return Err(FittedModelError::SchemaMismatch {
589 reason: format!(
590 "deployment extension '{}' has coefficient index {}, expected append-only index {}",
591 extension.name, extension.coefficient_index, expected
592 ),
593 });
594 }
595 }
596
597 let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
598 out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
599 for (tail_idx, extension) in extensions.into_iter().enumerate() {
600 if extension.kind != "random-effect-level" {
601 return Err(FittedModelError::IncompatibleConfig {
602 reason: format!(
603 "unsupported deployment extension kind '{}' for '{}'",
604 extension.kind, extension.name
605 ),
606 });
607 }
608 let term = spec
609 .random_effect_terms
610 .iter()
611 .find(|term| term.name == extension.term)
612 .ok_or_else(|| FittedModelError::MissingField {
613 reason: format!(
614 "deployment extension '{}' references unknown random-effect term '{}'",
615 extension.name, extension.term
616 ),
617 })?;
618 let prediction_col = training_headers
619 .and_then(|headers| headers.get(term.feature_col))
620 .and_then(|name| col_map.get(name))
621 .copied()
622 .unwrap_or(term.feature_col);
623 if prediction_col >= data.ncols() {
624 return Err(FittedModelError::SchemaMismatch {
625 reason: format!(
626 "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
627 extension.name,
628 prediction_col,
629 data.ncols()
630 ),
631 });
632 }
633 let col = p_old + tail_idx;
634 for row in 0..n {
635 if data[[row, prediction_col]].to_bits() == extension.level_bits {
636 out[[row, col]] = 1.0;
637 }
638 }
639 }
640 Ok(out)
641}
642
/// Serializable description of a latent-score contract, stored in
/// `FittedModelPayload::latent_score_contract`.
///
/// NOTE(review): this chunk does not show where the fields are produced or
/// consumed; the field comments below are limited to what the types state.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedLatentScoreContract {
    /// Free-form semantics tag.
    pub semantics: String,
    /// Optional identifier of a source transform.
    pub source_transform_id: Option<String>,
    pub normalization_mean: f64,
    pub normalization_sd: f64,
    /// Optional clipping epsilon.
    pub clip_eps: Option<f64>,
    /// Names of conditioning columns.
    pub conditioning_columns: Vec<String>,
}
652
653impl FittedModelPayload {
654 pub fn new(
655 version: u32,
656 formula: String,
657 model_kind: ModelKind,
658 family_state: FittedFamily,
659 family: String,
660 ) -> Self {
661 Self {
662 version,
663 formula,
664 model_kind,
665 family_state,
666 family,
667 inference_notes: Vec::new(),
668 used_device: false,
669 fit_result: None,
670 unified: None,
671 spline_scan: None,
672 residual_cascade: None,
673 data_schema: None,
674 link: None,
675 mixture_link_param_covariance: None,
676 sas_param_covariance: None,
677 formula_noise: None,
678 formula_logslope: None,
679 formula_logslopes: None,
680 offset_column: None,
681 noise_offset_column: None,
682 beta_noise: None,
683 noise_projection: None,
684 noise_center: None,
685 noise_scale: None,
686 noise_non_intercept_start: None,
687 noise_projection_ridge_alpha: None,
688 gaussian_response_scale: None,
689 linkwiggle_knots: None,
690 linkwiggle_degree: None,
691 beta_link_wiggle: None,
692 baseline_timewiggle_knots: None,
693 baseline_timewiggle_degree: None,
694 baseline_timewiggle_penalty_orders: None,
695 baseline_timewiggle_double_penalty: None,
696 beta_baseline_timewiggle: None,
697 beta_baseline_timewiggle_by_cause: None,
698 z_column: None,
699 z_columns: None,
700 latent_z_normalization: None,
701 latent_score_contract: None,
702 latent_measure: None,
703 latent_z_rank_int_calibration: None,
704 latent_z_conditional_calibration: None,
705 marginal_baseline: None,
706 logslope_baseline: None,
707 logslope_baselines: None,
708 score_warp_runtime: None,
709 link_deviation_runtime: None,
710 influence_absorber_width: None,
711 survival_entry: None,
712 survival_exit: None,
713 survival_event: None,
714 survivalspec: None,
715 survival_cause_count: None,
716 survival_endpoint_names: None,
717 survival_baseline_target: None,
718 survival_baseline_scale: None,
719 survival_baseline_shape: None,
720 survival_baseline_rate: None,
721 survival_baseline_makeham: None,
722 survival_time_basis: None,
723 survival_time_degree: None,
724 survival_time_knots: None,
725 survival_time_keep_cols: None,
726 survival_time_smooth_lambda: None,
727 survival_time_anchor: None,
728 survivalridge_lambda: None,
729 survival_likelihood: None,
730 survival_beta_time: None,
731 survival_beta_threshold: None,
732 survival_beta_log_sigma: None,
733 survival_noise_projection: None,
734 survival_noise_center: None,
735 survival_noise_scale: None,
736 survival_noise_non_intercept_start: None,
737 survival_noise_projection_ridge_alpha: None,
738 survival_distribution: None,
739 training_headers: None,
740 training_table_kind: None,
741 training_feature_ranges: None,
742 group_metadata: None,
743 deployment_extensions: Vec::new(),
744 transformation_response_knots: None,
745 transformation_response_transform: None,
746 transformation_response_degree: None,
747 transformation_response_median: None,
748 transformation_score_calibration: None,
749 resolved_termspec: None,
750 resolved_termspec_noise: None,
751 resolved_termspec_logslope: None,
752 resolved_termspec_logslopes: None,
753 adaptive_regularization_diagnostics: None,
754 gaussian_jackknife_plus: None,
755 full_conformal: None,
756 }
757 }
758
759 pub fn set_training_feature_metadata(
760 &mut self,
761 headers: Vec<String>,
762 feature_ranges: Vec<(f64, f64)>,
763 ) {
764 self.training_headers = Some(headers);
765 self.training_feature_ranges = Some(feature_ranges);
766 }
767
768 fn synchronize_empty_feature_contract(&mut self) {
769 if self.fit_result.is_none() {
770 return;
771 }
772 let Some(schema) = self.data_schema.as_ref() else {
773 return;
774 };
775 if !schema.columns.is_empty() {
776 return;
777 }
778 self.training_headers.get_or_insert_with(Vec::new);
779 self.resolved_termspec
780 .get_or_insert_with(|| TermCollectionSpec {
781 linear_terms: Vec::new(),
782 smooth_terms: Vec::new(),
783 random_effect_terms: Vec::new(),
784 });
785 }
786
787 pub fn apply_survival_time_basis(
795 &mut self,
796 snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
797 ) {
798 self.survival_time_basis = Some(snapshot.basisname.clone());
799 self.survival_time_degree = snapshot.degree;
800 self.survival_time_knots = snapshot.knots.clone();
801 self.survival_time_keep_cols = snapshot.keep_cols.clone();
802 self.survival_time_smooth_lambda = snapshot.smooth_lambda;
803 self.survival_time_anchor = Some(snapshot.anchor);
804 }
805
806 fn validate_payload_version(&self) -> Result<(), FittedModelError> {
807 if self.version != MODEL_PAYLOAD_VERSION {
808 return Err(FittedModelError::SchemaMismatch {
809 reason: format!(
810 "saved model payload schema mismatch: file has version={}, \
811 this binary expects MODEL_PAYLOAD_VERSION={}. \
812 Refit with the current CLI, or rebuild the reader at the same \
813 version the model was written with.",
814 self.version, MODEL_PAYLOAD_VERSION
815 ),
816 });
817 }
818 Ok(())
819 }
820}
821
/// Top-level saved model container.
///
/// Serialized as an internally tagged enum: the variant is written under the
/// `model_type` key using kebab-case names. Every variant wraps the same
/// `FittedModelPayload` type.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    Standard { payload: FittedModelPayload },
    LocationScale { payload: FittedModelPayload },
    MarginalSlope { payload: FittedModelPayload },
    Survival { payload: FittedModelPayload },
    TransformationNormal { payload: FittedModelPayload },
}
831
/// Model kind stored in `FittedModelPayload::model_kind`.
///
/// Serialized with kebab-case variant names; the variants mirror those of
/// `FittedModel`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    Standard,
    LocationScale,
    MarginalSlope,
    Survival,
    TransformationNormal,
}
841
/// Family-specific state stored in `FittedModelPayload::family_state`.
///
/// Serialized as an internally tagged enum: the variant is written under the
/// `family_kind` key using kebab-case names. Fields marked `#[serde(default)]`
/// may be absent from a serialized payload.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Standard family: a likelihood plus optional link and link-state pieces.
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    /// Location-scale family with an optional base inverse link.
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    /// Marginal-slope family; base link and frailty are both required.
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    /// Survival family; frailty is required, the other extras are optional.
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Latent survival family; carries only a frailty specification.
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Latent binary family; carries only a frailty specification.
    LatentBinary {
        frailty: FrailtySpec,
    },
    /// Transformation-normal family; carries only a likelihood.
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
884
/// Prediction-time model classes distinguished by this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    Standard,
    GaussianLocationScale,
    BinomialLocationScale,
    DispersionLocationScale,
    BernoulliMarginalSlope,
    Survival,
    TransformationNormal,
}

impl PredictModelClass {
    /// Returns the fixed, lowercase, human-readable label of this class.
    ///
    /// Usable in `const` contexts.
    #[inline]
    pub const fn name(self) -> &'static str {
        let label: &'static str = match self {
            PredictModelClass::Standard => "standard",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
        };
        label
    }
}
914
/// In-memory (not serialized) link-wiggle state used at prediction time.
///
/// The validators in this file add `beta.len()` to the expected covariance
/// dimension and, for location-scale survival, require `beta` to equal the
/// fit's LinkWiggle block.
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    /// Knot sequence of the wiggle basis.
    pub knots: Vec<f64>,
    /// Degree of the wiggle basis.
    pub degree: usize,
    /// Wiggle coefficients.
    pub beta: Vec<f64>,
}
921
/// In-memory (not serialized) baseline time-wiggle state used at prediction
/// time. Field names mirror the `baseline_timewiggle_*` payload fields.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
    /// Knot sequence of the wiggle basis.
    pub knots: Vec<f64>,
    /// Degree of the wiggle basis.
    pub degree: usize,
    /// Penalty orders recorded for the wiggle.
    pub penalty_orders: Vec<usize>,
    /// Whether the double-penalty flag was recorded as set.
    pub double_penalty: bool,
    /// Wiggle coefficients.
    pub beta: Vec<f64>,
}
930
931pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
934
/// Serializable compiled flexible block, stored in the payload's
/// `score_warp_runtime` and `link_deviation_runtime` fields.
///
/// NOTE(review): the evaluation code is outside this chunk; `span_c0..span_c3`
/// presumably hold per-span coefficients of order 0..3 between consecutive
/// `breakpoints` — confirm against the evaluator.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
    /// Kernel identifier string.
    pub kernel: String,
    pub breakpoints: Vec<f64>,
    /// Number of basis coefficients; the marginal-slope validator in this
    /// file compares it with the length of the matching coefficient block.
    pub basis_dim: usize,
    pub span_c0: Vec<Vec<f64>>,
    pub span_c1: Vec<Vec<f64>>,
    pub span_c2: Vec<Vec<f64>>,
    pub span_c3: Vec<Vec<f64>>,
    /// Optional; `None` when absent from the serialized form.
    #[serde(default)]
    pub anchor_correction: Option<Vec<Vec<f64>>>,
    /// Optional; empty when absent from the serialized form.
    #[serde(default)]
    pub anchor_components: Vec<SavedAnchorComponent>,
}
957
/// One entry of `SavedCompiledFlexBlock::anchor_components`; wraps the
/// component's kind.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
    pub kind: SavedAnchorKind,
}
962
/// Kind of a saved anchor component. Both variants record a column count.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
    /// Anchor tied to a `ParametricAnchorBlock`, with `ncols` columns.
    Parametric {
        block: ParametricAnchorBlock,
        ncols: usize,
    },
    /// Anchor with `ncols` columns and no parametric block.
    /// NOTE(review): presumably evaluated from the flex block itself — confirm.
    FlexEvaluation { ncols: usize },
}
975
/// In-memory (not serialized) bundle of what prediction needs from a saved
/// model: the model class, the likelihood, and the optional runtime pieces.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
    pub model_class: PredictModelClass,
    pub likelihood: LikelihoodSpec,
    pub inverse_link: Option<InverseLink>,
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
    pub score_warp: Option<SavedCompiledFlexBlock>,
    pub link_deviation: Option<SavedCompiledFlexBlock>,
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    pub influence_absorber_width: Option<usize>,
}
1006
1007pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1008 fit.block_by_role(BlockRole::Location)
1009 .or_else(|| fit.block_by_role(BlockRole::Mean))
1010 .map(|block| block.beta.clone())
1011}
1012
1013pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1014 fit.block_by_role(BlockRole::Threshold)
1015 .or_else(|| fit.block_by_role(BlockRole::Location))
1016 .or_else(|| fit.block_by_role(BlockRole::Mean))
1017 .map(|block| block.beta.clone())
1018}
1019
1020pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021 fit.block_by_role(BlockRole::Scale)
1022 .map(|block| block.beta.clone())
1023}
1024
1025fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1034 use gam_problem::types::ResponseFamily;
1035 matches!(
1036 response,
1037 ResponseFamily::NegativeBinomial { .. }
1038 | ResponseFamily::Gamma
1039 | ResponseFamily::Beta { .. }
1040 | ResponseFamily::Tweedie { .. }
1041 )
1042}
1043
1044fn validate_location_scale_saved_fit(
1045 fit: &UnifiedFitResult,
1046 model_class: PredictModelClass,
1047 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1048) -> Result<(), FittedModelError> {
1049 let primary = match model_class {
1050 PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1054 gaussian_location_scale_mean_beta(fit)
1055 }
1056 PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1057 _ => None,
1058 }
1059 .ok_or_else(|| FittedModelError::MissingField {
1060 reason: match model_class {
1061 PredictModelClass::GaussianLocationScale => {
1062 "gaussian-location-scale saved fit is missing mean/location block".to_string()
1063 }
1064 PredictModelClass::DispersionLocationScale => {
1065 "dispersion-location-scale saved fit is missing mean/location block".to_string()
1066 }
1067 PredictModelClass::BinomialLocationScale => {
1068 "binomial-location-scale saved fit is missing threshold/location block".to_string()
1069 }
1070 _ => "location-scale saved fit is missing primary block".to_string(),
1071 },
1072 })?;
1073
1074 let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1075 reason: "location-scale saved fit is missing scale block".to_string(),
1076 })?;
1077 let expected =
1078 primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1079
1080 if let Some(cov) = fit.beta_covariance()
1081 && (cov.nrows() != expected || cov.ncols() != expected)
1082 {
1083 return Err(FittedModelError::SchemaMismatch {
1084 reason: format!(
1085 "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1086 cov.nrows(),
1087 cov.ncols(),
1088 expected,
1089 expected
1090 ),
1091 });
1092 }
1093 if let Some(cov) = fit.beta_covariance_corrected()
1094 && (cov.nrows() != expected || cov.ncols() != expected)
1095 {
1096 return Err(FittedModelError::SchemaMismatch {
1097 reason: format!(
1098 "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1099 cov.nrows(),
1100 cov.ncols(),
1101 expected,
1102 expected
1103 ),
1104 });
1105 }
1106 Ok(())
1107}
1108
1109fn validate_survival_saved_block_matches_payload(
1110 fit: &UnifiedFitResult,
1111 role: BlockRole,
1112 payload_beta: Option<&Vec<f64>>,
1113 label: &str,
1114) -> Result<usize, FittedModelError> {
1115 let block = fit
1116 .block_by_role(role)
1117 .ok_or_else(|| FittedModelError::MissingField {
1118 reason: format!("location-scale survival saved fit is missing {label} block"),
1119 })?;
1120 if let Some(saved) = payload_beta
1121 && block.beta.to_vec() != *saved
1122 {
1123 return Err(FittedModelError::SchemaMismatch {
1124 reason: format!(
1125 "location-scale survival saved {label} coefficients disagree with fit_result"
1126 ),
1127 });
1128 }
1129 Ok(block.beta.len())
1130}
1131
1132fn validate_survival_location_scale_saved_fit(
1133 payload: &FittedModelPayload,
1134 link_wiggle: Option<&SavedLinkWiggleRuntime>,
1135) -> Result<(), FittedModelError> {
1136 let fit = payload
1137 .fit_result
1138 .as_ref()
1139 .ok_or_else(|| FittedModelError::MissingField {
1140 reason: "location-scale survival model is missing canonical fit_result payload"
1141 .to_string(),
1142 })?;
1143 let p_time = validate_survival_saved_block_matches_payload(
1144 fit,
1145 BlockRole::Time,
1146 payload.survival_beta_time.as_ref(),
1147 "time",
1148 )?;
1149 let p_threshold = validate_survival_saved_block_matches_payload(
1150 fit,
1151 BlockRole::Threshold,
1152 payload.survival_beta_threshold.as_ref(),
1153 "threshold",
1154 )?;
1155 let p_log_sigma = validate_survival_saved_block_matches_payload(
1156 fit,
1157 BlockRole::Scale,
1158 payload.survival_beta_log_sigma.as_ref(),
1159 "log-sigma",
1160 )?;
1161 let p_wiggle = match link_wiggle {
1162 Some(runtime) => {
1163 let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1164 FittedModelError::MissingField {
1165 reason: "location-scale survival saved fit is missing link-wiggle block"
1166 .to_string(),
1167 }
1168 })?;
1169 if block.beta.to_vec() != runtime.beta {
1170 return Err(FittedModelError::SchemaMismatch {
1171 reason:
1172 "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1173 .to_string(),
1174 });
1175 }
1176 runtime.beta.len()
1177 }
1178 None => {
1179 if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1180 return Err(FittedModelError::SchemaMismatch {
1181 reason:
1182 "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1183 .to_string(),
1184 });
1185 }
1186 0
1187 }
1188 };
1189 let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1190
1191 if let Some(cov) = fit.beta_covariance()
1192 && (cov.nrows() != expected || cov.ncols() != expected)
1193 {
1194 return Err(FittedModelError::SchemaMismatch {
1195 reason: format!(
1196 "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1197 cov.nrows(),
1198 cov.ncols(),
1199 expected,
1200 expected
1201 ),
1202 });
1203 }
1204 if let Some(cov) = fit.beta_covariance_corrected()
1205 && (cov.nrows() != expected || cov.ncols() != expected)
1206 {
1207 return Err(FittedModelError::SchemaMismatch {
1208 reason: format!(
1209 "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1210 cov.nrows(),
1211 cov.ncols(),
1212 expected,
1213 expected
1214 ),
1215 });
1216 }
1217 Ok(())
1218}
1219
1220fn validate_marginal_slope_saved_fit(
1221 fit: &UnifiedFitResult,
1222 score_warp: Option<&SavedCompiledFlexBlock>,
1223 link_deviation: Option<&SavedCompiledFlexBlock>,
1224 fit_label: &str,
1225) -> Result<(), FittedModelError> {
1226 validate_marginal_slope_saved_fit_impl(
1227 fit,
1228 score_warp,
1229 link_deviation,
1230 fit_label,
1231 "bernoulli",
1232 2,
1233 "marginal, logslope",
1234 )
1235}
1236
1237fn validate_survival_marginal_slope_saved_fit(
1238 fit: &UnifiedFitResult,
1239 score_warp: Option<&SavedCompiledFlexBlock>,
1240 link_deviation: Option<&SavedCompiledFlexBlock>,
1241 fit_label: &str,
1242) -> Result<(), FittedModelError> {
1243 validate_marginal_slope_saved_fit_impl(
1244 fit,
1245 score_warp,
1246 link_deviation,
1247 fit_label,
1248 "survival",
1249 3,
1250 "time, marginal, slope",
1251 )
1252}
1253
1254fn validate_marginal_slope_saved_fit_impl(
1262 fit: &UnifiedFitResult,
1263 score_warp: Option<&SavedCompiledFlexBlock>,
1264 link_deviation: Option<&SavedCompiledFlexBlock>,
1265 fit_label: &str,
1266 family_kind: &str,
1267 base_block_count: usize,
1268 base_block_role_list: &str,
1269) -> Result<(), FittedModelError> {
1270 let expected_blocks = base_block_count
1271 + usize::from(score_warp.is_some())
1272 + usize::from(link_deviation.is_some());
1273 if fit.blocks.len() != expected_blocks {
1274 let score_warp_suffix = if score_warp.is_some() {
1275 ", score-warp"
1276 } else {
1277 ""
1278 };
1279 let link_deviation_suffix = if link_deviation.is_some() {
1280 ", link-deviation"
1281 } else {
1282 ""
1283 };
1284 return Err(FittedModelError::SchemaMismatch {
1285 reason: format!(
1286 "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1287 fit.blocks.len(),
1288 ),
1289 });
1290 }
1291 if let Some(runtime) = score_warp {
1292 let beta = &fit.blocks[base_block_count].beta;
1293 if beta.len() != runtime.basis_dim {
1294 return Err(FittedModelError::SchemaMismatch {
1295 reason: format!(
1296 "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1297 beta.len(),
1298 runtime.basis_dim
1299 ),
1300 });
1301 }
1302 }
1303 if let Some(runtime) = link_deviation {
1304 let idx = base_block_count + usize::from(score_warp.is_some());
1305 let beta = &fit.blocks[idx].beta;
1306 if beta.len() != runtime.basis_dim {
1307 return Err(FittedModelError::SchemaMismatch {
1308 reason: format!(
1309 "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1310 beta.len(),
1311 runtime.basis_dim
1312 ),
1313 });
1314 }
1315 }
1316 Ok(())
1317}
1318
1319impl SavedLinkWiggleRuntime {
1320 fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1321 validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved link-wiggle")
1322 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1323 }
1324
1325 fn validate_monotone_derivative(
1326 &self,
1327 q0: &Array1<f64>,
1328 ) -> Result<Array1<f64>, FittedModelError> {
1329 self.validate_global_monotonicity()?;
1330 let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1331 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1332 let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1333 if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1334 return Err(FittedModelError::PayloadCorrupt {
1335 reason: format!(
1336 "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1337 ),
1338 });
1339 }
1340 Ok(dq_dq0)
1341 }
1342
1343 pub fn constrained_basis(
1344 &self,
1345 q0: &Array1<f64>,
1346 basis_options: BasisOptions,
1347 ) -> Result<Array2<f64>, FittedModelError> {
1348 let knot_arr = Array1::from_vec(self.knots.clone());
1349 let constrained = monotone_wiggle_basis_with_derivative_order(
1350 q0.view(),
1351 &knot_arr,
1352 self.degree,
1353 basis_options.derivative_order,
1354 )
1355 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1356 if constrained.ncols() != self.beta.len() {
1357 return Err(FittedModelError::SchemaMismatch {
1358 reason: format!(
1359 "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1360 self.beta.len(),
1361 constrained.ncols()
1362 ),
1363 });
1364 }
1365 Ok(constrained)
1366 }
1367
1368 pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1369 self.validate_global_monotonicity()?;
1370 self.constrained_basis(q0, BasisOptions::value())
1371 }
1372
1373 pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1374 let q = Array1::from_vec(vec![q0]);
1375 let x = self.design(&q)?;
1376 if x.nrows() != 1 {
1377 return Err(FittedModelError::SchemaMismatch {
1378 reason: format!(
1379 "saved link-wiggle scalar evaluation expected 1 row, got {}",
1380 x.nrows()
1381 ),
1382 });
1383 }
1384 Ok(x.row(0).to_owned())
1385 }
1386
1387 pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1388 self.validate_monotone_derivative(q0)?;
1389 let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1390 let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1391 Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1392 }
1393
1394 pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395 self.validate_monotone_derivative(q0)
1396 }
1397}
1398
1399impl SavedBaselineTimeWiggleRuntime {
1400 pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1401 validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1402 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1403 }
1404}
1405
1406impl SavedCompiledFlexBlock {
1407 pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1408 if self.kernel.is_empty() {
1409 return Err(FittedModelError::SchemaMismatch {
1410 reason: "saved anchored deviation runtime is missing the exact kernel marker"
1411 .to_string(),
1412 });
1413 }
1414 if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1415 return Err(FittedModelError::IncompatibleConfig {
1416 reason: format!(
1417 "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1418 self.kernel,
1419 crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1420 ),
1421 });
1422 }
1423 if self.basis_dim == 0 {
1424 return Err(FittedModelError::SchemaMismatch {
1425 reason: format!(
1426 "saved anchored deviation runtime basis_dim must be positive, got {}",
1427 self.basis_dim
1428 ),
1429 });
1430 }
1431 if self.breakpoints.len() < 2 {
1432 return Err(FittedModelError::SchemaMismatch {
1433 reason: format!(
1434 "saved anchored deviation runtime requires at least two breakpoints, got {}",
1435 self.breakpoints.len()
1436 ),
1437 });
1438 }
1439 for window in self.breakpoints.windows(2) {
1440 let left = window[0];
1441 let right = window[1];
1442 if !left.is_finite() || !right.is_finite() || right <= left {
1443 return Err(FittedModelError::PayloadCorrupt {
1444 reason: format!(
1445 "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1446 ),
1447 });
1448 }
1449 }
1450 let span_count = self.breakpoints.len() - 1;
1451 self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1452 self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1453 self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1454 self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1455 self.validate_c2_span_continuity()?;
1456 self.validate_anchor_residual_shape()?;
1457 Ok(())
1458 }
1459
1460 fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1461 let coeffs = match self.anchor_correction.as_ref() {
1462 Some(c) => c,
1463 None => {
1464 if !self.anchor_components.is_empty() {
1465 return Err(FittedModelError::SchemaMismatch {
1466 reason:
1467 "saved anchored deviation runtime has anchor_components but no anchor_correction"
1468 .to_string(),
1469 });
1470 }
1471 return Ok(());
1472 }
1473 };
1474 let d: usize = self
1475 .anchor_components
1476 .iter()
1477 .map(|c| match &c.kind {
1478 SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1479 SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1480 })
1481 .sum();
1482 if coeffs.len() != d {
1483 return Err(FittedModelError::SchemaMismatch {
1484 reason: format!(
1485 "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1486 coeffs.len(),
1487 d,
1488 ),
1489 });
1490 }
1491 for (i, row) in coeffs.iter().enumerate() {
1492 if row.len() != self.basis_dim {
1493 return Err(FittedModelError::SchemaMismatch {
1494 reason: format!(
1495 "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1496 i,
1497 row.len(),
1498 self.basis_dim,
1499 ),
1500 });
1501 }
1502 for (j, &v) in row.iter().enumerate() {
1503 if !v.is_finite() {
1504 return Err(FittedModelError::PayloadCorrupt {
1505 reason: format!(
1506 "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1507 ),
1508 });
1509 }
1510 }
1511 }
1512 Ok(())
1513 }
1514
1515 fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1516 const TOL: f64 = 1e-8;
1517 for span_idx in 1..self.breakpoints.len() - 1 {
1518 let left_span = span_idx - 1;
1519 let right_span = span_idx;
1520 let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1521 for basis_idx in 0..self.basis_dim {
1522 let left_value = self.span_c0[left_span][basis_idx]
1523 + self.span_c1[left_span][basis_idx] * width
1524 + self.span_c2[left_span][basis_idx] * width * width
1525 + self.span_c3[left_span][basis_idx] * width * width * width;
1526 let left_d1 = self.span_c1[left_span][basis_idx]
1527 + 2.0 * self.span_c2[left_span][basis_idx] * width
1528 + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1529 let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1530 + 6.0 * self.span_c3[left_span][basis_idx] * width;
1531 let right_value = self.span_c0[right_span][basis_idx];
1532 let right_d1 = self.span_c1[right_span][basis_idx];
1533 let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1534 if (left_value - right_value).abs() > TOL
1535 || (left_d1 - right_d1).abs() > TOL
1536 || (left_d2 - right_d2).abs() > TOL
1537 {
1538 return Err(FittedModelError::SchemaMismatch {
1539 reason: format!(
1540 "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1541 left_value - right_value,
1542 left_d1 - right_d1,
1543 left_d2 - right_d2
1544 ),
1545 });
1546 }
1547 }
1548 }
1549 Ok(())
1550 }
1551
1552 fn validate_coefficient_matrix(
1553 &self,
1554 matrix: &[Vec<f64>],
1555 label: &str,
1556 expected_rows: usize,
1557 ) -> Result<(), FittedModelError> {
1558 if matrix.len() != expected_rows {
1559 return Err(FittedModelError::SchemaMismatch {
1560 reason: format!(
1561 "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1562 matrix.len(),
1563 expected_rows
1564 ),
1565 });
1566 }
1567 for (row_idx, row) in matrix.iter().enumerate() {
1568 if row.len() != self.basis_dim {
1569 return Err(FittedModelError::SchemaMismatch {
1570 reason: format!(
1571 "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1572 row_idx,
1573 row.len(),
1574 self.basis_dim
1575 ),
1576 });
1577 }
1578 for (j, &value) in row.iter().enumerate() {
1579 if !value.is_finite() {
1580 return Err(FittedModelError::PayloadCorrupt {
1581 reason: format!(
1582 "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1583 ),
1584 });
1585 }
1586 }
1587 }
1588 Ok(())
1589 }
1590
1591 fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1592 let last_span = self.breakpoints.len() - 2;
1593 let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1594 self.span_c0[last_span][basis_idx]
1595 + self.span_c1[last_span][basis_idx] * width
1596 + self.span_c2[last_span][basis_idx] * width * width
1597 + self.span_c3[last_span][basis_idx] * width * width * width
1598 }
1599
1600 fn evaluate_span_polynomial_design(
1601 &self,
1602 values: &Array1<f64>,
1603 derivative_order: usize,
1604 ) -> Result<Array2<f64>, FittedModelError> {
1605 self.validate_exact_replay_contract()?;
1606 let (left_ep, right_ep) = self.support_interval()?;
1607 let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1608 for (row_idx, &value) in values.iter().enumerate() {
1609 if !value.is_finite() {
1610 return Err(FittedModelError::PayloadCorrupt {
1611 reason: format!(
1612 "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1613 ),
1614 });
1615 }
1616 if value < left_ep {
1617 if derivative_order == 0 {
1618 for basis_idx in 0..self.basis_dim {
1619 out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1620 }
1621 }
1622 continue;
1623 }
1624 if value > right_ep {
1625 if derivative_order == 0 {
1626 for basis_idx in 0..self.basis_dim {
1627 out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1628 }
1629 }
1630 continue;
1631 }
1632 let span_idx = self.left_biased_span_index_for(value)?;
1633 let t = value - self.breakpoints[span_idx];
1634 for basis_idx in 0..self.basis_dim {
1635 let c0 = self.span_c0[span_idx][basis_idx];
1636 let c1 = self.span_c1[span_idx][basis_idx];
1637 let c2 = self.span_c2[span_idx][basis_idx];
1638 let c3 = self.span_c3[span_idx][basis_idx];
1639 out[[row_idx, basis_idx]] = match derivative_order {
1640 0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1641 1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1642 2 => 2.0 * c2 + 6.0 * c3 * t,
1643 3 => 6.0 * c3,
1644 4 => 0.0,
1645 other => {
1646 return Err(FittedModelError::IncompatibleConfig {
1647 reason: format!(
1648 "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1649 ),
1650 });
1651 }
1652 };
1653 }
1654 }
1655 Ok(out)
1656 }
1657
1658 pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1659 self.validate_exact_replay_contract()?;
1660 Ok(self.breakpoints.clone())
1661 }
1662
1663 pub fn span_count(&self) -> Result<usize, FittedModelError> {
1664 Ok(self.breakpoints()?.windows(2).count())
1665 }
1666
1667 pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1668 let points = self.breakpoints()?;
1669 span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1670 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1671 }
1672
1673 fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1674 let mut span_idx = span_index_for_breakpoints(
1675 &self.breakpoints,
1676 value,
1677 "saved anchored deviation span lookup",
1678 )
1679 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1680 if span_idx > 0 && value == self.breakpoints[span_idx] {
1683 span_idx -= 1;
1684 }
1685 Ok(span_idx)
1686 }
1687
1688 pub fn local_cubic_on_span(
1689 &self,
1690 beta: &Array1<f64>,
1691 span_idx: usize,
1692 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1693 self.validate_exact_replay_contract()?;
1694 if beta.len() != self.basis_dim {
1695 return Err(FittedModelError::SchemaMismatch {
1696 reason: format!(
1697 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1698 beta.len(),
1699 self.basis_dim
1700 ),
1701 });
1702 }
1703 self.local_cubic_on_span_validated(beta, span_idx)
1704 }
1705
1706 fn local_cubic_on_span_validated(
1707 &self,
1708 beta: &Array1<f64>,
1709 span_idx: usize,
1710 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1711 let points = &self.breakpoints;
1712 if span_idx + 1 >= points.len() {
1713 return Err(FittedModelError::SchemaMismatch {
1714 reason: format!(
1715 "saved anchored deviation span index {} out of range for {} spans",
1716 span_idx,
1717 points.len() - 1
1718 ),
1719 });
1720 }
1721 let left = points[span_idx];
1722 let right = points[span_idx + 1];
1723 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1724 left,
1725 right,
1726 c0: self.span_c0[span_idx]
1727 .iter()
1728 .zip(beta.iter())
1729 .map(|(coeff, weight)| coeff * weight)
1730 .sum(),
1731 c1: self.span_c1[span_idx]
1732 .iter()
1733 .zip(beta.iter())
1734 .map(|(coeff, weight)| coeff * weight)
1735 .sum(),
1736 c2: self.span_c2[span_idx]
1737 .iter()
1738 .zip(beta.iter())
1739 .map(|(coeff, weight)| coeff * weight)
1740 .sum(),
1741 c3: self.span_c3[span_idx]
1742 .iter()
1743 .zip(beta.iter())
1744 .map(|(coeff, weight)| coeff * weight)
1745 .sum(),
1746 })
1747 }
1748
1749 pub fn basis_span_cubic(
1750 &self,
1751 span_idx: usize,
1752 basis_idx: usize,
1753 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1754 self.validate_exact_replay_contract()?;
1755 if basis_idx >= self.basis_dim {
1756 return Err(FittedModelError::SchemaMismatch {
1757 reason: format!(
1758 "saved anchored deviation basis index {} out of range for {} coefficients",
1759 basis_idx, self.basis_dim
1760 ),
1761 });
1762 }
1763 self.basis_span_cubic_validated(span_idx, basis_idx)
1764 }
1765
1766 fn basis_span_cubic_validated(
1767 &self,
1768 span_idx: usize,
1769 basis_idx: usize,
1770 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1771 let points = &self.breakpoints;
1772 if span_idx + 1 >= points.len() {
1773 return Err(FittedModelError::SchemaMismatch {
1774 reason: format!(
1775 "saved anchored deviation span index {} out of range for {} spans",
1776 span_idx,
1777 points.len() - 1
1778 ),
1779 });
1780 }
1781 Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1782 left: points[span_idx],
1783 right: points[span_idx + 1],
1784 c0: self.span_c0[span_idx][basis_idx],
1785 c1: self.span_c1[span_idx][basis_idx],
1786 c2: self.span_c2[span_idx][basis_idx],
1787 c3: self.span_c3[span_idx][basis_idx],
1788 })
1789 }
1790
1791 pub fn basis_cubic_at(
1792 &self,
1793 basis_idx: usize,
1794 value: f64,
1795 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1796 self.validate_exact_replay_contract()?;
1797 if basis_idx >= self.basis_dim {
1798 return Err(FittedModelError::SchemaMismatch {
1799 reason: format!(
1800 "saved anchored deviation basis index {} out of range for {} coefficients",
1801 basis_idx, self.basis_dim
1802 ),
1803 });
1804 }
1805 let (left_ep, right_ep) = self.support_interval()?;
1806 if value < left_ep {
1807 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1808 left: left_ep,
1809 right: left_ep + 1.0,
1810 c0: self.span_c0[0][basis_idx],
1811 c1: 0.0,
1812 c2: 0.0,
1813 c3: 0.0,
1814 });
1815 }
1816 if value > right_ep {
1817 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1818 left: right_ep,
1819 right: right_ep + 1.0,
1820 c0: self.right_boundary_basis_value(basis_idx),
1821 c1: 0.0,
1822 c2: 0.0,
1823 c3: 0.0,
1824 });
1825 }
1826 let span_idx = self.left_biased_span_index_for(value)?;
1827 self.basis_span_cubic_validated(span_idx, basis_idx)
1828 }
1829
1830 pub fn local_cubic_at(
1831 &self,
1832 beta: &Array1<f64>,
1833 value: f64,
1834 ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1835 self.validate_exact_replay_contract()?;
1836 if beta.len() != self.basis_dim {
1837 return Err(FittedModelError::SchemaMismatch {
1838 reason: format!(
1839 "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1840 beta.len(),
1841 self.basis_dim
1842 ),
1843 });
1844 }
1845 let (left_ep, right_ep) = self.support_interval()?;
1846 if value < left_ep {
1847 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1848 left: left_ep,
1849 right: left_ep + 1.0,
1850 c0: self.span_c0[0]
1851 .iter()
1852 .zip(beta.iter())
1853 .map(|(coeff, weight)| coeff * weight)
1854 .sum(),
1855 c1: 0.0,
1856 c2: 0.0,
1857 c3: 0.0,
1858 });
1859 }
1860 if value > right_ep {
1861 return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1862 left: right_ep,
1863 right: right_ep + 1.0,
1864 c0: (0..self.basis_dim)
1865 .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1866 .sum(),
1867 c1: 0.0,
1868 c2: 0.0,
1869 c3: 0.0,
1870 });
1871 }
1872 let span_idx = self.left_biased_span_index_for(value)?;
1873 self.local_cubic_on_span_validated(beta, span_idx)
1874 }
1875
1876 fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1877 let points = self.breakpoints()?;
1878 match (points.first(), points.last()) {
1879 (Some(&left), Some(&right)) => Ok((left, right)),
1880 _ => Err(FittedModelError::MissingField {
1881 reason: "saved anchored deviation runtime is missing support breakpoints"
1882 .to_string(),
1883 }),
1884 }
1885 }
1886
1887 pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1888 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1897 }
1898
1899 pub fn design_uncorrected(
1907 &self,
1908 values: &Array1<f64>,
1909 ) -> Result<Array2<f64>, FittedModelError> {
1910 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1911 }
1912
1913 pub fn design_with_anchor_rows(
1922 &self,
1923 values: &Array1<f64>,
1924 anchor_rows: ndarray::ArrayView2<f64>,
1925 ) -> Result<Array2<f64>, FittedModelError> {
1926 let mut out =
1927 self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1928 if let Some(m_rows) = self.anchor_correction.as_ref() {
1929 let d = m_rows.len();
1930 if anchor_rows.nrows() != values.len() {
1931 return Err(FittedModelError::SchemaMismatch {
1932 reason: format!(
1933 "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1934 anchor_rows.nrows(),
1935 values.len(),
1936 ),
1937 });
1938 }
1939 if anchor_rows.ncols() != d {
1940 return Err(FittedModelError::SchemaMismatch {
1941 reason: format!(
1942 "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1943 anchor_rows.ncols(),
1944 d,
1945 ),
1946 });
1947 }
1948 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1950 for (i, row) in m_rows.iter().enumerate() {
1951 if row.len() != self.basis_dim {
1952 return Err(FittedModelError::SchemaMismatch {
1953 reason: format!(
1954 "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1955 i,
1956 row.len(),
1957 self.basis_dim,
1958 ),
1959 });
1960 }
1961 for (j, &v) in row.iter().enumerate() {
1962 m_dense[[i, j]] = v;
1963 }
1964 }
1965 let subtract = anchor_rows.dot(&m_dense);
1968 out = out - subtract;
1969 } else if anchor_rows.ncols() != 0 {
1970 return Err(FittedModelError::SchemaMismatch {
1971 reason: format!(
1972 "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1973 anchor_rows.ncols(),
1974 ),
1975 });
1976 }
1977 Ok(out)
1978 }
1979
1980 pub fn anchor_correction_matrix(
1988 &self,
1989 n_anchor_rows: ndarray::ArrayView2<f64>,
1990 ) -> Result<Option<Array2<f64>>, FittedModelError> {
1991 let Some(m_rows) = self.anchor_correction.as_ref() else {
1992 return Ok(None);
1993 };
1994 let d = m_rows.len();
1995 if n_anchor_rows.ncols() != d {
1996 return Err(FittedModelError::SchemaMismatch {
1997 reason: format!(
1998 "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
1999 n_anchor_rows.ncols(),
2000 d,
2001 ),
2002 });
2003 }
2004 let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2005 for (i, row) in m_rows.iter().enumerate() {
2006 if row.len() != self.basis_dim {
2007 return Err(FittedModelError::SchemaMismatch {
2008 reason: format!(
2009 "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2010 i,
2011 row.len(),
2012 self.basis_dim,
2013 ),
2014 });
2015 }
2016 for (j, &v) in row.iter().enumerate() {
2017 m_dense[[i, j]] = v;
2018 }
2019 }
2020 Ok(Some(n_anchor_rows.dot(&m_dense)))
2023 }
2024
2025 pub fn first_derivative_design(
2026 &self,
2027 values: &Array1<f64>,
2028 ) -> Result<Array2<f64>, FittedModelError> {
2029 self.evaluate_span_polynomial_design(
2030 values,
2031 BasisOptions::first_derivative().derivative_order,
2032 )
2033 }
2034
2035 pub fn second_derivative_design(
2036 &self,
2037 values: &Array1<f64>,
2038 ) -> Result<Array2<f64>, FittedModelError> {
2039 self.evaluate_span_polynomial_design(
2040 values,
2041 BasisOptions::second_derivative().derivative_order,
2042 )
2043 }
2044}
2045
2046impl FittedFamily {
2047 #[inline]
2048 pub fn likelihood(&self) -> LikelihoodSpec {
2049 let spec = match self {
2050 Self::Standard { likelihood, .. }
2051 | Self::LocationScale { likelihood, .. }
2052 | Self::MarginalSlope { likelihood, .. }
2053 | Self::Survival { likelihood, .. }
2054 | Self::TransformationNormal { likelihood, .. } => likelihood,
2055 Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2056 return LikelihoodSpec::royston_parmar();
2057 }
2058 };
2059 spec.clone()
2060 }
2061
2062 #[inline]
2063 pub fn frailty(&self) -> Option<&FrailtySpec> {
2064 match self {
2065 Self::MarginalSlope { frailty, .. }
2066 | Self::Survival { frailty, .. }
2067 | Self::LatentSurvival { frailty }
2068 | Self::LatentBinary { frailty } => Some(frailty),
2069 _ => None,
2070 }
2071 }
2072}
2073
2074fn collect_smooth_extrapolation_axes(
2080 basis: &gam_terms::smooth::SmoothBasisSpec,
2081 n_training_headers: usize,
2082 out: &mut std::collections::HashSet<usize>,
2083) {
2084 use gam_terms::smooth::SmoothBasisSpec;
2085 let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2086 if col < n_training_headers {
2087 out.insert(col);
2088 }
2089 };
2090 match basis {
2091 SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2093 SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2096 for &c in feature_cols {
2097 push(c, out);
2098 }
2099 }
2100 SmoothBasisSpec::ThinPlate { feature_cols, .. }
2107 | SmoothBasisSpec::Matern { feature_cols, .. }
2108 | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2109 | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2110 for &c in feature_cols {
2111 push(c, out);
2112 }
2113 }
2114 SmoothBasisSpec::FactorSmooth { spec } => {
2118 for &c in &spec.continuous_cols {
2119 push(c, out);
2120 }
2121 }
2122 SmoothBasisSpec::ByVariable { inner, .. }
2124 | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2125 collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2126 }
2127 SmoothBasisSpec::BySmooth { smooth, .. } => {
2128 collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2129 }
2130 SmoothBasisSpec::Sphere { .. }
2136 | SmoothBasisSpec::ConstantCurvature { .. }
2137 | SmoothBasisSpec::Pca { .. } => {}
2138 }
2139}
2140
2141fn collect_by_variable_numeric_axes(
2160 basis: &gam_terms::smooth::SmoothBasisSpec,
2161 n_training_headers: usize,
2162 out: &mut std::collections::HashSet<usize>,
2163) {
2164 use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2165 match basis {
2166 SmoothBasisSpec::ByVariable {
2167 inner,
2168 by_col,
2169 kind,
2170 ..
2171 } => {
2172 if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2173 out.insert(*by_col);
2174 }
2175 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2176 }
2177 SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2178 if let ByVarKind::Numeric { feature_col } = by_kind
2179 && *feature_col < n_training_headers
2180 {
2181 out.insert(*feature_col);
2182 }
2183 collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2184 }
2185 SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2186 collect_by_variable_numeric_axes(inner, n_training_headers, out);
2187 }
2188 _ => {}
2189 }
2190}
2191
2192impl FittedModel {
2193 pub fn axis_clip_to_training_ranges(
2202 &self,
2203 data: ndarray::ArrayView2<'_, f64>,
2204 col_map: &std::collections::HashMap<String, usize>,
2205 ) -> Option<ndarray::Array2<f64>> {
2206 let training_headers = self.training_headers.as_ref()?;
2207 let ranges = self.training_feature_ranges.as_ref()?;
2208 if training_headers.len() != ranges.len() {
2209 return None;
2210 }
2211 let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2212 std::collections::HashMap::new();
2213 if let Some(schema) = self.data_schema.as_ref() {
2214 for col in &schema.columns {
2215 kind_by_header.insert(col.name.as_str(), col.kind);
2216 }
2217 }
2218 let periodic_axes = self.training_periodic_axes(training_headers);
2224 let linear_axes = self.training_linear_axes(training_headers.len());
2232 let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2237 let smooth_extrapolation_axes =
2245 self.training_smooth_extrapolation_axes(training_headers.len());
2246 let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2252 let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2257 let mut clipped = data.to_owned();
2258 let mut any_clipped = false;
2259 for (col_in_training, (header, &(lo, hi))) in
2260 training_headers.iter().zip(ranges.iter()).enumerate()
2261 {
2262 let (lo, hi) = sphere_lat_bounds
2263 .get(&col_in_training)
2264 .copied()
2265 .unwrap_or((lo, hi));
2266 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2267 continue;
2268 }
2269 if !matches!(
2270 kind_by_header.get(header.as_str()).copied(),
2271 Some(ColumnKindTag::Continuous)
2272 ) {
2273 continue;
2274 }
2275 if periodic_axes.contains(&col_in_training) {
2276 continue;
2277 }
2278 if linear_axes.contains(&col_in_training) {
2279 continue;
2280 }
2281 if random_effect_axes.contains(&col_in_training) {
2282 continue;
2283 }
2284 if smooth_extrapolation_axes.contains(&col_in_training) {
2285 continue;
2286 }
2287 if by_variable_axes.contains(&col_in_training) {
2288 continue;
2289 }
2290 let Some(&col_idx) = col_map.get(header) else {
2291 continue;
2292 };
2293 if col_idx >= clipped.ncols() {
2294 continue;
2295 }
2296 let mut col = clipped.column_mut(col_idx);
2297 for v in col.iter_mut() {
2298 if v.is_finite() {
2299 if *v < lo {
2300 *v = lo;
2301 any_clipped = true;
2302 } else if *v > hi {
2303 *v = hi;
2304 any_clipped = true;
2305 }
2306 }
2307 }
2308 }
2309 if any_clipped { Some(clipped) } else { None }
2310 }
2311
2312 fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2313 let mut specs: Vec<&TermCollectionSpec> = [
2314 self.resolved_termspec.as_ref(),
2315 self.resolved_termspec_noise.as_ref(),
2316 self.resolved_termspec_logslope.as_ref(),
2317 ]
2318 .into_iter()
2319 .flatten()
2320 .collect();
2321 if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2322 specs.extend(logslopes.iter());
2323 }
2324 specs
2325 }
2326
2327 fn training_periodic_axes(
2334 &self,
2335 training_headers: &[String],
2336 ) -> std::collections::HashSet<usize> {
2337 use gam_terms::basis::BSplineKnotSpec;
2338 use gam_terms::smooth::SmoothBasisSpec;
2339 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2340 let Some(spec) = self.resolved_termspec.as_ref() else {
2341 return out;
2342 };
2343 for term in &spec.smooth_terms {
2344 match &term.basis {
2345 SmoothBasisSpec::Sphere { feature_cols, .. } => {
2351 if let Some(&lon_col) = feature_cols.get(1)
2352 && lon_col < training_headers.len()
2353 {
2354 out.insert(lon_col);
2355 }
2356 }
2357 SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2359 if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2360 && *feature_col < training_headers.len()
2361 {
2362 out.insert(*feature_col);
2363 }
2364 }
2365 SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2368 for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2369 if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2370 && let Some(&col) = feature_cols.get(i)
2371 && col < training_headers.len()
2372 {
2373 out.insert(col);
2374 }
2375 }
2376 }
2377 _ => {}
2378 }
2379 }
2380 out
2381 }
2382
2383 fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2394 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2395 for spec in self.saved_term_specs() {
2396 for term in &spec.linear_terms {
2397 for col in term.effective_feature_cols() {
2398 if col < n_training_headers {
2399 out.insert(col);
2400 }
2401 }
2402 }
2403 }
2404 out
2405 }
2406
2407 fn training_random_effect_axes(
2413 &self,
2414 n_training_headers: usize,
2415 ) -> std::collections::HashSet<usize> {
2416 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2417 for spec in self.saved_term_specs() {
2418 for term in &spec.random_effect_terms {
2419 if term.feature_col < n_training_headers {
2420 out.insert(term.feature_col);
2421 }
2422 }
2423 }
2424 out
2425 }
2426
2427 fn training_smooth_extrapolation_axes(
2460 &self,
2461 n_training_headers: usize,
2462 ) -> std::collections::HashSet<usize> {
2463 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2464 for spec in self.saved_term_specs() {
2465 for term in &spec.smooth_terms {
2466 collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2467 }
2468 }
2469 out
2470 }
2471
2472 fn training_by_variable_numeric_axes(
2478 &self,
2479 n_training_headers: usize,
2480 ) -> std::collections::HashSet<usize> {
2481 let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2482 for spec in self.saved_term_specs() {
2483 for term in &spec.smooth_terms {
2484 collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2485 }
2486 }
2487 out
2488 }
2489
2490 fn training_sphere_latitude_bounds(
2512 &self,
2513 training_headers: &[String],
2514 ) -> std::collections::HashMap<usize, (f64, f64)> {
2515 use gam_terms::smooth::SmoothBasisSpec;
2516 let mut out: std::collections::HashMap<usize, (f64, f64)> =
2517 std::collections::HashMap::new();
2518 let Some(spec) = self.resolved_termspec.as_ref() else {
2519 return out;
2520 };
2521 for term in &spec.smooth_terms {
2522 if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2523 && let Some(&lat_col) = feature_cols.first()
2524 && lat_col < training_headers.len()
2525 {
2526 let bound = if spec.radians {
2527 std::f64::consts::FRAC_PI_2
2528 } else {
2529 90.0
2530 };
2531 out.insert(lat_col, (-bound, bound));
2532 }
2533 }
2534 out
2535 }
2536
2537 pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2538 let likelihood = payload.family_state.likelihood();
2539 let class = match payload.model_kind {
2540 ModelKind::Survival => PredictModelClass::Survival,
2541 ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2542 ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2543 ModelKind::LocationScale => {
2544 if likelihood == LikelihoodSpec::gaussian_identity() {
2545 PredictModelClass::GaussianLocationScale
2546 } else if is_dispersion_location_scale_response(&likelihood.response) {
2547 PredictModelClass::DispersionLocationScale
2548 } else {
2549 PredictModelClass::BinomialLocationScale
2550 }
2551 }
2552 ModelKind::Standard => PredictModelClass::Standard,
2553 };
2554 match class {
2555 PredictModelClass::Survival => {
2556 payload.model_kind = ModelKind::Survival;
2557 Self::Survival { payload }
2558 }
2559 PredictModelClass::BernoulliMarginalSlope => {
2560 payload.model_kind = ModelKind::MarginalSlope;
2561 Self::MarginalSlope { payload }
2562 }
2563 PredictModelClass::TransformationNormal => {
2564 payload.model_kind = ModelKind::TransformationNormal;
2565 Self::TransformationNormal { payload }
2566 }
2567 PredictModelClass::GaussianLocationScale
2568 | PredictModelClass::BinomialLocationScale
2569 | PredictModelClass::DispersionLocationScale => {
2570 payload.model_kind = ModelKind::LocationScale;
2571 Self::LocationScale { payload }
2572 }
2573 PredictModelClass::Standard => {
2574 payload.model_kind = ModelKind::Standard;
2575 Self::Standard { payload }
2576 }
2577 }
2578 .with_synchronized_stateful_link_metadata()
2579 }
2580
2581 #[inline]
2582 pub fn payload(&self) -> &FittedModelPayload {
2583 match self {
2584 Self::Standard { payload }
2585 | Self::LocationScale { payload }
2586 | Self::MarginalSlope { payload }
2587 | Self::Survival { payload }
2588 | Self::TransformationNormal { payload } => payload,
2589 }
2590 }
2591
2592 #[inline]
2593 fn payload_mut(&mut self) -> &mut FittedModelPayload {
2594 match self {
2595 Self::Standard { payload }
2596 | Self::LocationScale { payload }
2597 | Self::MarginalSlope { payload }
2598 | Self::Survival { payload }
2599 | Self::TransformationNormal { payload } => payload,
2600 }
2601 }
2602
/// By-value convenience wrapper: runs `synchronize_stateful_link_metadata` on
/// the model and hands it back, so the call can be chained on a freshly
/// constructed or deserialized value.
fn with_synchronized_stateful_link_metadata(mut self) -> Self {
    self.synchronize_stateful_link_metadata();
    self
}
2607
/// Refreshes payload fields that are derived from the stored fit results.
///
/// Steps, in order:
/// 1. `used_device` is recomputed from `fit_result`, falling back to `unified`;
///    it becomes `false` when neither is present.
/// 2. `synchronize_empty_feature_contract` is invoked on the payload.
/// 3. When `fit_result` is present and the family is `Standard`, the fitted
///    link state (and, where the variant carries one, its covariance) is copied
///    into the matching family-state slot, provided the likelihood agrees with
///    the fitted link variant. Any other combination is left untouched.
fn synchronize_stateful_link_metadata(&mut self) {
    let payload = self.payload_mut();
    payload.used_device = payload
        .fit_result
        .as_ref()
        .or(payload.unified.as_ref())
        .is_some_and(|fit| fit.used_device);
    payload.synchronize_empty_feature_contract();
    // Link-state synchronization only reads `fit_result`; `unified` is not consulted here.
    let Some(fit) = payload.fit_result.as_ref() else {
        return;
    };
    match (&mut payload.family_state, &fit.fitted_link) {
        (
            FittedFamily::Standard {
                likelihood,
                latent_cloglog_state,
                ..
            },
            FittedLinkState::LatentCLogLog { state },
        ) if likelihood.is_latent_cloglog() => {
            *latent_cloglog_state = Some(*state);
        }
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::Sas { state, covariance },
        ) if likelihood.is_binomial_sas() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        // Beta-logistic state is stored in the same `sas_state` /
        // `sas_param_covariance` slots as the SAS link;
        // `saved_beta_logistic_state` reads it back from `sas_state`.
        (
            FittedFamily::Standard {
                likelihood,
                sas_state,
                ..
            },
            FittedLinkState::BetaLogistic { state, covariance },
        ) if likelihood.is_binomial_beta_logistic() => {
            *sas_state = Some(*state);
            payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
        }
        (
            FittedFamily::Standard {
                likelihood,
                mixture_state,
                ..
            },
            FittedLinkState::Mixture { state, covariance },
        ) if likelihood.is_binomial_mixture() => {
            *mixture_state = Some(state.clone());
            payload.mixture_link_param_covariance =
                covariance.as_ref().map(array2_to_nestedvec);
        }
        // Non-standard families, stateless links, and likelihood/link mismatches are left as-is.
        _ => {}
    }
}
2667
/// Returns the likelihood specification reported by the payload's `family_state`.
#[inline]
pub fn likelihood(&self) -> LikelihoodSpec {
    self.payload().family_state.likelihood()
}
2672
2673 pub fn prediction_required_columns(
2691 &self,
2692 ) -> Result<std::collections::BTreeSet<String>, String> {
2693 let payload = self.payload();
2694 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2695 let mut required = std::collections::BTreeSet::<String>::new();
2696 parsed_term_column_names(&parsed.terms, &mut required);
2697
2698 if let Some((entry, exit, _event)) =
2699 parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2700 {
2701 if let Some(entry) = entry {
2702 required.insert(entry);
2703 }
2704 required.insert(exit);
2705 } else if let Some((left, right, _event)) =
2706 parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707 {
2708 required.insert(left);
2709 required.insert(right);
2710 } else if matches!(
2711 self.predict_model_class(),
2712 PredictModelClass::TransformationNormal
2713 ) {
2714 let response = parsed.response.trim();
2715 if !response.is_empty() && !response.starts_with("Surv(") {
2716 required.insert(response.to_string());
2717 }
2718 }
2719
2720 if let Some(offset) = payload.offset_column.as_ref() {
2721 required.insert(offset.clone());
2722 }
2723 if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2724 required.insert(noise_offset.clone());
2725 }
2726 if matches!(
2727 self.predict_model_class(),
2728 PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2729 ) {
2730 if let Some(z_column) = payload.z_column.as_ref() {
2731 required.remove("z");
2732 required.insert(z_column.clone());
2733 }
2734 }
2735 if let Some(noise_formula) = payload.formula_noise.as_ref() {
2736 self.add_auxiliary_formula_columns(
2737 &mut required,
2738 noise_formula,
2739 parsed.response.as_str(),
2740 )?;
2741 }
2742 if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2743 if logslope_formula != "same-as-main" {
2744 self.add_auxiliary_formula_columns(
2745 &mut required,
2746 logslope_formula,
2747 parsed.response.as_str(),
2748 )?;
2749 }
2750 }
2751 Ok(required)
2752 }
2753
2754 pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2771 let payload = self.payload();
2772 let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2773 if parse_surv_response(parsed.response.as_str())
2776 .map_err(|e| e.to_string())?
2777 .is_some()
2778 || parse_surv_interval_response(parsed.response.as_str())
2779 .map_err(|e| e.to_string())?
2780 .is_some()
2781 {
2782 return Ok(Vec::new());
2783 }
2784 let response = parsed.response.trim();
2785 if response.is_empty() || response.contains('(') {
2788 return Ok(Vec::new());
2789 }
2790 if self.prediction_required_columns()?.contains(response) {
2793 return Ok(Vec::new());
2794 }
2795 Ok(vec![response.to_string()])
2796 }
2797
2798 fn add_auxiliary_formula_columns(
2801 &self,
2802 required: &mut std::collections::BTreeSet<String>,
2803 formula_or_rhs: &str,
2804 response: &str,
2805 ) -> Result<(), String> {
2806 let trimmed = formula_or_rhs.trim();
2807 if trimmed.is_empty() || trimmed == "1" {
2808 return Ok(());
2809 }
2810 let formula = if trimmed.contains('~') {
2811 trimmed.to_string()
2812 } else {
2813 format!("{response} ~ {trimmed}")
2814 };
2815 let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2816 parsed_term_column_names(&parsed.terms, required);
2817 Ok(())
2818 }
2819
2820 #[inline]
2821 pub fn predict_model_class(&self) -> PredictModelClass {
2822 match &self.payload().family_state {
2823 FittedFamily::Survival { .. }
2824 | FittedFamily::LatentSurvival { .. }
2825 | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2826 FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2827 FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2828 FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2829 PredictModelClass::GaussianLocationScale
2830 }
2831 FittedFamily::LocationScale { likelihood, .. }
2832 if is_dispersion_location_scale_response(&likelihood.response) =>
2833 {
2834 PredictModelClass::DispersionLocationScale
2835 }
2836 FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2837 FittedFamily::Standard { .. } => PredictModelClass::Standard,
2838 }
2839 }
2840
2841 pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2842 let payload = self.payload();
2843 let (knots, degree) = match (
2844 payload.linkwiggle_knots.as_ref(),
2845 payload.linkwiggle_degree,
2846 ) {
2847 (None, None) => return Ok(None),
2848 (Some(knots), Some(degree)) => (knots.clone(), degree),
2849 _ => {
2850 return Err(FittedModelError::SchemaMismatch {
2851 reason:
2852 "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2853 .to_string(),
2854 })
2855 }
2856 };
2857 let resolved_link = self.resolved_inverse_link()?;
2858 let saved_link_disallows_wiggle = resolved_link
2859 .as_ref()
2860 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2861 || payload
2862 .link
2863 .as_ref()
2864 .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2865 if saved_link_disallows_wiggle {
2866 return Err(FittedModelError::IncompatibleConfig {
2867 reason: joint_wiggle_unsupported_link_message("link wiggle"),
2868 });
2869 }
2870 let beta = match self.predict_model_class() {
2871 PredictModelClass::Standard => {
2872 if payload.beta_link_wiggle.is_some() {
2873 return Err(FittedModelError::SchemaMismatch {
2874 reason:
2875 "standard link-wiggle coefficients must be stored in fit_result LinkWiggle block, not payload.beta_link_wiggle"
2876 .to_string(),
2877 });
2878 }
2879 let fit = payload.fit_result.as_ref().ok_or_else(|| {
2880 FittedModelError::MissingField {
2881 reason:
2882 "standard link-wiggle model is missing canonical fit_result payload"
2883 .to_string(),
2884 }
2885 })?;
2886 if fit.blocks.len() != 2
2887 || fit.blocks[0].role != BlockRole::Mean
2888 || fit.blocks[1].role != BlockRole::LinkWiggle
2889 {
2890 return Err(FittedModelError::SchemaMismatch {
2891 reason:
2892 "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2893 .to_string(),
2894 });
2895 }
2896 fit.block_by_role(BlockRole::LinkWiggle)
2897 .ok_or_else(|| FittedModelError::MissingField {
2898 reason:
2899 "standard link-wiggle model is missing LinkWiggle coefficient block"
2900 .to_string(),
2901 })?
2902 .beta
2903 .to_vec()
2904 }
2905 _ => payload
2906 .beta_link_wiggle
2907 .clone()
2908 .ok_or_else(|| FittedModelError::MissingField {
2909 reason:
2910 "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2911 .to_string(),
2912 })?,
2913 };
2914 Ok(Some(SavedLinkWiggleRuntime {
2915 knots,
2916 degree,
2917 beta,
2918 }))
2919 }
2920
2921 pub fn saved_baseline_time_wiggle(
2922 &self,
2923 ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2924 let payload = self.payload();
2925 if payload
2926 .survival_cause_count
2927 .is_some_and(|cause_count| cause_count > 1)
2928 && payload.beta_baseline_timewiggle.is_none()
2929 && payload.beta_baseline_timewiggle_by_cause.is_some()
2930 {
2931 return Err(FittedModelError::SchemaMismatch {
2932 reason:
2933 "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2934 .to_string(),
2935 });
2936 }
2937 match (
2938 payload.baseline_timewiggle_knots.as_ref(),
2939 payload.baseline_timewiggle_degree,
2940 payload.baseline_timewiggle_penalty_orders.as_ref(),
2941 payload.baseline_timewiggle_double_penalty,
2942 payload.beta_baseline_timewiggle.as_ref(),
2943 ) {
2944 (None, None, None, None, None) => Ok(None),
2945 (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2946 Ok(Some(SavedBaselineTimeWiggleRuntime {
2947 knots: knots.clone(),
2948 degree,
2949 penalty_orders: penalty_orders.clone(),
2950 double_penalty,
2951 beta: beta.clone(),
2952 }))
2953 }
2954 _ => Err(FittedModelError::SchemaMismatch {
2955 reason:
2956 "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2957 .to_string(),
2958 }),
2959 }
2960 }
2961
2962 #[inline]
2964 pub fn has_link_wiggle(&self) -> bool {
2965 self.saved_link_wiggle()
2966 .map(|runtime| runtime.is_some())
2967 .unwrap_or(false)
2968 }
2969
2970 #[inline]
2972 pub fn has_baseline_time_wiggle(&self) -> bool {
2973 let payload = self.payload();
2974 if payload
2975 .survival_cause_count
2976 .is_some_and(|cause_count| cause_count > 1)
2977 {
2978 return payload.baseline_timewiggle_knots.is_some()
2979 && payload.baseline_timewiggle_degree.is_some()
2980 && payload.baseline_timewiggle_penalty_orders.is_some()
2981 && payload.baseline_timewiggle_double_penalty.is_some()
2982 && payload.beta_baseline_timewiggle_by_cause.is_some();
2983 }
2984 self.saved_baseline_time_wiggle()
2985 .map(|runtime| runtime.is_some())
2986 .unwrap_or(false)
2987 }
2988
2989 #[inline]
3014 pub fn prediction_uses_posterior_mean(&self) -> bool {
3015 let family = self.likelihood();
3016 let curved_family = match &family.response {
3017 ResponseFamily::Gaussian => false,
3020 ResponseFamily::Poisson
3022 | ResponseFamily::Gamma
3023 | ResponseFamily::Tweedie { .. }
3024 | ResponseFamily::NegativeBinomial { .. } => true,
3025 ResponseFamily::Beta { .. } => true,
3027 ResponseFamily::RoystonParmar => true,
3029 ResponseFamily::Binomial => matches!(
3032 &family.link,
3033 InverseLink::Standard(_)
3034 | InverseLink::Sas(_)
3035 | InverseLink::BetaLogistic(_)
3036 | InverseLink::Mixture(_)
3037 | InverseLink::LatentCLogLog(_)
3038 ),
3039 };
3040 curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3041 }
3042
/// Assembles the `SavedPredictionRuntime` for this model and validates the
/// stored fit against the layout expected for its model class.
///
/// # Errors
/// * Propagates the payload-version check and the errors of
///   `resolved_inverse_link`, `saved_link_wiggle`, and
///   `saved_baseline_time_wiggle`.
/// * `PayloadCorrupt` when a stored score-warp or link-deviation runtime fails
///   its exact-replay contract (marginal-slope and survival models only).
/// * `MissingField` when the fit payload required for the model class is absent.
/// * Whatever the class-specific `validate_*_saved_fit` helper reports.
pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
    self.payload().validate_payload_version()?;
    // The anchored runtimes are only validated for the marginal-slope and survival classes.
    if matches!(
        self.predict_model_class(),
        PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
    ) {
        if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
            runtime.validate_exact_replay_contract().map_err(|err| {
                FittedModelError::PayloadCorrupt {
                    reason: format!("saved anchored score-warp runtime is invalid: {err}"),
                }
            })?;
        }
        if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
            runtime.validate_exact_replay_contract().map_err(|err| {
                FittedModelError::PayloadCorrupt {
                    reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
                }
            })?;
        }
    }
    let runtime = SavedPredictionRuntime {
        model_class: self.predict_model_class(),
        likelihood: self.likelihood(),
        inverse_link: self.resolved_inverse_link()?,
        link_wiggle: self.saved_link_wiggle()?,
        baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
        score_warp: self.payload().score_warp_runtime.clone(),
        link_deviation: self.payload().link_deviation_runtime.clone(),
        latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
        latent_z_conditional_calibration: self
            .payload()
            .latent_z_conditional_calibration
            .clone(),
        influence_absorber_width: self.payload().influence_absorber_width,
    };
    // Class-specific validation. The branches are mutually exclusive and checked
    // in this order: location-scale, survival location-scale, marginal-slope,
    // survival marginal-slope. Other classes get no extra validation here.
    if matches!(
        runtime.model_class,
        PredictModelClass::GaussianLocationScale
            | PredictModelClass::BinomialLocationScale
            | PredictModelClass::DispersionLocationScale
    ) {
        let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
            FittedModelError::MissingField {
                reason: "location-scale model is missing canonical fit_result payload"
                    .to_string(),
            }
        })?;
        validate_location_scale_saved_fit(
            fit,
            runtime.model_class,
            runtime.link_wiggle.as_ref(),
        )?;
    } else if matches!(runtime.model_class, PredictModelClass::Survival)
        && self
            .payload()
            .survival_likelihood
            .as_deref()
            .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
    {
        validate_survival_location_scale_saved_fit(
            self.payload(),
            runtime.link_wiggle.as_ref(),
        )?;
    } else if matches!(
        runtime.model_class,
        PredictModelClass::BernoulliMarginalSlope
    ) {
        // Marginal-slope models are validated against the `unified` fit rather than `fit_result`.
        let unified =
            self.payload()
                .unified
                .as_ref()
                .ok_or_else(|| FittedModelError::MissingField {
                    reason: "marginal-slope model is missing unified fit payload; refit"
                        .to_string(),
                })?;
        validate_marginal_slope_saved_fit(
            unified,
            runtime.score_warp.as_ref(),
            runtime.link_deviation.as_ref(),
            "unified",
        )?;
    } else if matches!(runtime.model_class, PredictModelClass::Survival)
        && self
            .payload()
            .survival_likelihood
            .as_deref()
            .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
    {
        let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
            FittedModelError::MissingField {
                reason: "survival marginal-slope model is missing canonical fit_result payload"
                    .to_string(),
            }
        })?;
        validate_survival_marginal_slope_saved_fit(
            fit,
            runtime.score_warp.as_ref(),
            runtime.link_deviation.as_ref(),
            "fit_result",
        )?;
    }
    Ok(runtime)
}
3147
3148 pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3149 let payload = self.payload();
3150 let raw = match &payload.family_state {
3151 FittedFamily::Standard {
3152 likelihood,
3153 sas_state,
3154 ..
3155 } if likelihood.is_binomial_sas() => {
3156 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3157 reason: "binomial-sas model is missing state in family_state.sas_state"
3158 .to_string(),
3159 })?
3160 }
3161 FittedFamily::LocationScale {
3162 likelihood,
3163 base_link,
3164 } if likelihood.is_binomial_sas() => match base_link {
3165 Some(InverseLink::Sas(state)) => *state,
3166 _ => {
3167 return Err(FittedModelError::MissingField {
3168 reason: "binomial-sas location-scale model is missing SAS base_link state"
3169 .to_string(),
3170 });
3171 }
3172 },
3173 _ => return Ok(None),
3174 };
3175 state_from_sasspec(SasLinkSpec {
3176 initial_epsilon: raw.epsilon,
3177 initial_log_delta: raw.log_delta,
3178 })
3179 .map(Some)
3180 .map_err(|e| FittedModelError::PayloadCorrupt {
3181 reason: format!("invalid saved SAS link state: {e}"),
3182 })
3183 }
3184
3185 pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3186 let payload = self.payload();
3187 let raw = match &payload.family_state {
3188 FittedFamily::Standard {
3189 likelihood,
3190 sas_state,
3191 ..
3192 } if likelihood.is_binomial_beta_logistic() => {
3193 (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3194 reason:
3195 "binomial-beta-logistic model is missing state in family_state.sas_state"
3196 .to_string(),
3197 })?
3198 }
3199 FittedFamily::LocationScale {
3200 likelihood,
3201 base_link,
3202 } if likelihood.is_binomial_beta_logistic() => match base_link {
3203 Some(InverseLink::BetaLogistic(state)) => *state,
3204 _ => {
3205 return Err(FittedModelError::MissingField {
3206 reason:
3207 "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3208 .to_string(),
3209 });
3210 }
3211 },
3212 _ => return Ok(None),
3213 };
3214 state_from_beta_logisticspec(SasLinkSpec {
3215 initial_epsilon: raw.epsilon,
3216 initial_log_delta: raw.log_delta,
3217 })
3218 .map(Some)
3219 .map_err(|e| FittedModelError::PayloadCorrupt {
3220 reason: format!("invalid saved Beta-Logistic link state: {e}"),
3221 })
3222 }
3223
3224 pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3225 let payload = self.payload();
3226 match &payload.family_state {
3227 FittedFamily::Standard {
3228 likelihood,
3229 mixture_state,
3230 ..
3231 } if likelihood.is_binomial_mixture() => mixture_state
3232 .clone()
3233 .ok_or_else(|| FittedModelError::MissingField {
3234 reason: "binomial-mixture model is missing state in family_state.mixture_state"
3235 .to_string(),
3236 })
3237 .map(Some),
3238 FittedFamily::LocationScale {
3239 likelihood,
3240 base_link,
3241 } if likelihood.is_binomial_mixture() => match base_link {
3242 Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3243 _ => Err(FittedModelError::MissingField {
3244 reason:
3245 "binomial-mixture location-scale model is missing mixture base_link state"
3246 .to_string(),
3247 }),
3248 },
3249 _ => Ok(None),
3250 }
3251 }
3252
3253 pub fn saved_latent_cloglog_state(
3254 &self,
3255 ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3256 let payload = self.payload();
3257 match &payload.family_state {
3258 FittedFamily::Standard {
3259 likelihood,
3260 latent_cloglog_state,
3261 ..
3262 } if likelihood.is_latent_cloglog() => latent_cloglog_state
3263 .ok_or_else(|| FittedModelError::MissingField {
3264 reason:
3265 "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3266 .to_string(),
3267 })
3268 .map(Some),
3269 _ => Ok(None),
3270 }
3271 }
3272
3273 pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3274 let stateful = if let Some(state) = self.saved_mixture_state()? {
3275 Some(InverseLink::Mixture(state))
3276 } else if let Some(state) = self.saved_latent_cloglog_state()? {
3277 Some(InverseLink::LatentCLogLog(state))
3278 } else if let Some(state) = self.saved_beta_logistic_state()? {
3279 Some(InverseLink::BetaLogistic(state))
3280 } else {
3281 self.saved_sas_state()?.map(InverseLink::Sas)
3282 };
3283 match &self.payload().family_state {
3284 FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3285 FittedFamily::Standard { link, .. } => {
3286 Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3287 }
3288 FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3289 FittedFamily::Survival { .. }
3290 | FittedFamily::LatentSurvival { .. }
3291 | FittedFamily::LatentBinary { .. } => Ok(None),
3292 FittedFamily::TransformationNormal { .. } => Ok(None),
3293 }
3294 }
3295
/// Coverage floor forwarded as the final argument of
/// `gam_terms::basis::measure_jet_extrapolation_variance`; its precise meaning
/// is defined by that function.
const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3305
/// Computes a per-row extrapolation variance summed over the model's frozen
/// measure-jet smooth terms.
///
/// Returns `Ok(None)` when there is no mean term spec, when `data` has no rows,
/// when the saved spec contains no `MeasureJet` smooth, or when every
/// measure-jet term was skipped. Otherwise returns one value per row of `data`:
/// the sum over contributing terms of the fit's coefficient-covariance scale
/// times the term's extrapolation variance.
///
/// * `data` - prediction rows; measure-jet feature columns are read from it by
///   the column indices of the resolved term spec.
/// * `col_map` - column-name to index map handed to
///   `resolve_termspec_for_prediction`.
///
/// # Errors
/// * `MissingField` when `fit_result` is absent.
/// * `SchemaMismatch` when the term spec cannot be resolved, the design replay
///   fails, a penalty index or label is invalid, a feature column is out of
///   bounds, the input-scale count disagrees with the axis count, or a
///   `gam_terms::basis` helper fails.
pub fn measure_jet_extrapolation_variance(
    &self,
    data: ndarray::ArrayView2<'_, f64>,
    col_map: &HashMap<String, usize>,
) -> Result<Option<Array1<f64>>, FittedModelError> {
    use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
    use gam_terms::smooth::build_term_collection_design;
    use gam_terms::smooth::SmoothBasisSpec;
    let Some(saved_spec) = self.resolved_termspec.as_ref() else {
        return Ok(None);
    };
    // Nothing to do without rows or without at least one measure-jet smooth.
    if data.nrows() == 0
        || !saved_spec
            .smooth_terms
            .iter()
            .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
    {
        return Ok(None);
    }
    let fit = self
        .fit_result
        .as_ref()
        .ok_or_else(|| FittedModelError::MissingField {
            reason: "measure-jet extrapolation variance requires the canonical \
                     fit_result payload; refit"
                .to_string(),
        })?;
    let spec = crate::survival::predict::resolve_termspec_for_prediction(
        &self.resolved_termspec,
        self.training_headers.as_ref(),
        col_map,
        "resolved_termspec",
    )
    .map_err(|e| FittedModelError::SchemaMismatch {
        reason: format!("measure-jet extrapolation variance: {e}"),
    })?;
    // The design is rebuilt on a single probe row; only its `penaltyinfo` is read below.
    let probe = data.slice(ndarray::s![0..1, ..]);
    let design = build_term_collection_design(probe, &spec).map_err(|e| {
        FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
            ),
        }
    })?;
    let lambdas = &fit.lambdas;
    let phi_scale = fit.coefficient_covariance_scale();
    let mut total = Array1::<f64>::zeros(data.nrows());
    let mut contributed = false;
    for term in &spec.smooth_terms {
        let SmoothBasisSpec::MeasureJet {
            feature_cols,
            spec: mj,
            input_scales,
        } = &term.basis
        else {
            continue;
        };
        // Only terms with frozen quadrature and user-provided centers can be replayed.
        let (Some(frozen), CenterStrategy::UserProvided(centers)) =
            (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
        else {
            log::warn!(
                "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
                 quadrature); skipping its extrapolation variance",
                term.name
            );
            continue;
        };
        let n_levels = frozen.eps_band.len();
        // Bounds-checked lookup of a fitted lambda by its global penalty index.
        let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
            lambdas
                .get(global_index)
                .copied()
                .ok_or_else(|| FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': penalty global index {global_index} out \
                         of bounds for {} fitted lambdas",
                        term.name,
                        lambdas.len()
                    ),
                })
        };
        // Collect this term's lambdas: per-level ones are identified by the
        // "measure_jet_scale_<level>" label, the fused one by the Primary source.
        let mut per_scale: Vec<(usize, f64)> = Vec::new();
        let mut fused: Option<f64> = None;
        for info in &design.penaltyinfo {
            if info.termname.as_deref() != Some(term.name.as_str()) {
                continue;
            }
            match &info.penalty.source {
                PenaltySource::Other(label) => {
                    if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
                        let level: usize = level_txt.parse().map_err(|_| {
                            FittedModelError::SchemaMismatch {
                                reason: format!(
                                    "measure-jet term '{}': unparseable penalty label \
                                     '{label}'",
                                    term.name
                                ),
                            }
                        })?;
                        per_scale.push((level, read_lambda(info.global_index)?));
                    }
                }
                PenaltySource::Primary => {
                    fused = Some(read_lambda(info.global_index)?);
                }
                _ => {}
            }
        }
        // Declared before `spectrum` because the per-level variant borrows it.
        let mut lambda_phys = Vec::with_capacity(n_levels);
        let spectrum = if per_scale.is_empty() {
            let Some(lam) = fused else {
                log::warn!(
                    "measure-jet term '{}' has no fitted amplitude in the penalty \
                     layout; skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            let Some(c) = frozen.fused_penalty_normalization_scale else {
                log::warn!(
                    "measure-jet term '{}' is missing the fused penalty normalization scale; \
                     skipping its extrapolation variance",
                    term.name
                );
                continue;
            };
            MeasureJetExtrapolationSpectrum::Fused(lam / c)
        } else {
            per_scale.sort_by_key(|&(level, _)| level);
            // After sorting, the levels must be exactly 0..n_levels with no gaps or duplicates.
            let levels_complete = per_scale.len() == n_levels
                && per_scale
                    .iter()
                    .enumerate()
                    .all(|(i, &(level, _))| level == i);
            if !levels_complete {
                log::warn!(
                    "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
                     scales; skipping its extrapolation variance",
                    term.name,
                    per_scale.len(),
                    n_levels
                );
                continue;
            }
            if frozen.penalty_normalization_scales.len() != n_levels {
                log::warn!(
                    "measure-jet term '{}': {} frozen penalty normalization scales for {} \
                     band scales; skipping its extrapolation variance",
                    term.name,
                    frozen.penalty_normalization_scales.len(),
                    n_levels
                );
                continue;
            }
            // Indexing is in bounds: levels are 0..n_levels and the scale vector has n_levels entries.
            lambda_phys.extend(
                per_scale
                    .iter()
                    .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
            );
            MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
        };
        // Gather this term's feature columns into a dense query matrix.
        let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
        for (j, &col) in feature_cols.iter().enumerate() {
            if col >= data.ncols() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': prediction column {col} out of bounds \
                         for {} data columns",
                        term.name,
                        data.ncols()
                    ),
                });
            }
            queries.column_mut(j).assign(&data.column(col));
        }
        // Each query axis is divided by its stored input scale, when scales are present.
        if let Some(scales) = input_scales {
            if scales.len() != feature_cols.len() {
                return Err(FittedModelError::SchemaMismatch {
                    reason: format!(
                        "measure-jet term '{}': {} input scales for {} axes",
                        term.name,
                        scales.len(),
                        feature_cols.len()
                    ),
                });
            }
            for (j, &scale) in scales.iter().enumerate() {
                queries.column_mut(j).mapv_inplace(|v| v / scale);
            }
        }
        let support = gam_terms::basis::measure_jet_support_curve(
            queries.view(),
            centers.view(),
            frozen.masses.view(),
            &frozen.eps_band,
        )
        .map_err(|e| FittedModelError::SchemaMismatch {
            reason: format!(
                "measure-jet term '{}': support curve failed: {e}",
                term.name
            ),
        })?;
        // Accumulate the scaled variance row by row.
        for i in 0..data.nrows() {
            let v = gam_terms::basis::measure_jet_extrapolation_variance(
                support.row(i),
                &frozen.eps_band,
                &frozen.support_means,
                spectrum,
                Self::MEASURE_JET_COVERAGE_FLOOR,
            )
            .map_err(|e| FittedModelError::SchemaMismatch {
                reason: format!(
                    "measure-jet term '{}': extrapolation variance failed: {e}",
                    term.name
                ),
            })?;
            total[i] += phi_scale * v;
        }
        contributed = true;
    }
    // `None` when every measure-jet term was skipped.
    Ok(contributed.then_some(total))
}
3574
/// Borrows the payload's `unified` fit result, if one is stored.
pub fn unified(&self) -> Option<&UnifiedFitResult> {
    self.payload().unified.as_ref()
}
3579
3580 pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3581 let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3582 reason: format!("failed to read model '{}': {e}", path.display()),
3583 })?;
3584 let model: Self =
3585 serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3586 reason: format!("failed to parse model json: {e}"),
3587 })?;
3588 let model = model.with_synchronized_stateful_link_metadata();
3589 model.validate_for_persistence()?;
3590 model.validate_numeric_finiteness()?;
3591 Ok(model)
3592 }
3593
3594 pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3595 let normalized = self.clone().with_synchronized_stateful_link_metadata();
3596 normalized.validate_for_persistence()?;
3597 normalized.validate_numeric_finiteness()?;
3598 let parent = path.parent().unwrap_or_else(|| Path::new("."));
3605 let file_name = path
3606 .file_name()
3607 .and_then(|s| s.to_str())
3608 .unwrap_or("model.json");
3609 let pid = std::process::id();
3610 let nanos = std::time::SystemTime::now()
3611 .duration_since(std::time::UNIX_EPOCH)
3612 .map(|d| d.as_nanos())
3613 .unwrap_or(0);
3614 let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3615 let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3616 reason: format!("failed to write model '{}': {e}", tmp.display()),
3617 })?;
3618 let mut writer = std::io::BufWriter::new(file);
3619 let ser_result = serde_json::to_writer(&mut writer, &normalized);
3620 if let Err(e) = ser_result {
3621 std::io::Write::flush(&mut writer).ok();
3624 drop(writer);
3625 fs::remove_file(&tmp).ok();
3626 return Err(FittedModelError::PayloadCorrupt {
3627 reason: format!("failed to serialize model: {e}"),
3628 });
3629 }
3630 std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3631 reason: format!("failed to write model '{}': {e}", tmp.display()),
3632 })?;
3633 let inner = writer
3635 .into_inner()
3636 .map_err(|e| FittedModelError::PayloadCorrupt {
3637 reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3638 })?;
3639 inner.sync_all().ok();
3640 drop(inner);
3641 if let Err(e) = fs::rename(&tmp, path) {
3642 fs::remove_file(&tmp).ok();
3643 return Err(FittedModelError::PayloadCorrupt {
3644 reason: format!("failed to publish model '{}': {e}", path.display()),
3645 });
3646 }
3647 if let Ok(d) = fs::File::open(parent) {
3652 d.sync_all().ok();
3653 }
3654 Ok(())
3655 }
3656
3657 pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3658 self.data_schema
3659 .as_ref()
3660 .ok_or_else(|| FittedModelError::MissingField {
3661 reason: "model is missing data_schema; refit".to_string(),
3662 })
3663 }
3664
3665 pub fn saved_spline_scan(
3669 &self,
3670 ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3671 let Some(saved) = self.spline_scan.as_ref() else {
3672 return Ok(None);
3673 };
3674 let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3675 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3676 Ok(Some((saved.feature_column.as_str(), fit)))
3677 }
3678
3679 pub fn saved_residual_cascade(
3684 &self,
3685 ) -> Result<
3686 Option<(
3687 &[String],
3688 gam_solve::residual_cascade::ResidualCascadeFit,
3689 )>,
3690 FittedModelError,
3691 > {
3692 let Some(saved) = self.residual_cascade.as_ref() else {
3693 return Ok(None);
3694 };
3695 let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3696 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3697 Ok(Some((saved.feature_columns.as_slice(), fit)))
3698 }
3699
3700 pub fn random_effect_group_columns(&self) -> HashSet<String> {
3701 let Some(training_headers) = self.training_headers.as_ref() else {
3702 return HashSet::new();
3703 };
3704 let mut out = HashSet::<String>::new();
3705 for spec in self.saved_term_specs() {
3706 for term in &spec.random_effect_terms {
3707 if let Some(name) = training_headers.get(term.feature_col) {
3708 out.insert(name.clone());
3709 }
3710 }
3711 }
3712 out
3713 }
3714
3715 pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3716 self.validate_payload_version()?;
3730 if let Some(scan) = self.spline_scan.as_ref() {
3731 if self.fit_result.is_some() || self.unified.is_some() {
3736 return Err(FittedModelError::SchemaMismatch {
3737 reason: "spline-scan model must not also carry a dense fit_result/unified \
3738 payload; the representations are mutually exclusive"
3739 .to_string(),
3740 });
3741 }
3742 if self.model_kind != ModelKind::Standard
3743 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3744 {
3745 return Err(FittedModelError::SchemaMismatch {
3746 reason: format!(
3747 "spline-scan representation requires a standard Gaussian-identity model; \
3748 got model_kind={:?}, likelihood={:?}",
3749 self.model_kind,
3750 self.family_state.likelihood()
3751 ),
3752 });
3753 }
3754 if scan.feature_column.is_empty() {
3755 return Err(FittedModelError::MissingField {
3756 reason: "spline-scan model is missing its feature column name; refit"
3757 .to_string(),
3758 });
3759 }
3760 gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3761 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3762 if self.data_schema.is_none() {
3768 return Err(FittedModelError::MissingField {
3769 reason: "spline-scan model is missing data_schema; refit".to_string(),
3770 });
3771 }
3772 if self.training_headers.is_none() {
3773 return Err(FittedModelError::MissingField {
3774 reason: "spline-scan model is missing training_headers; refit".to_string(),
3775 });
3776 }
3777 return Ok(());
3778 } else if let Some(cascade) = self.residual_cascade.as_ref() {
3779 if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3783 return Err(FittedModelError::SchemaMismatch {
3784 reason: "residual-cascade model must not also carry spline_scan / \
3785 fit_result / unified payloads; the representations are \
3786 mutually exclusive"
3787 .to_string(),
3788 });
3789 }
3790 if self.model_kind != ModelKind::Standard
3791 || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3792 {
3793 return Err(FittedModelError::SchemaMismatch {
3794 reason: format!(
3795 "residual-cascade representation requires a standard Gaussian-identity \
3796 model; got model_kind={:?}, likelihood={:?}",
3797 self.model_kind,
3798 self.family_state.likelihood()
3799 ),
3800 });
3801 }
3802 if cascade.feature_columns.is_empty()
3803 || !(2..=3).contains(&cascade.feature_columns.len())
3804 {
3805 return Err(FittedModelError::MissingField {
3806 reason: format!(
3807 "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3808 cascade.feature_columns.len()
3809 ),
3810 });
3811 }
3812 gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3813 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3814 if self.data_schema.is_none() {
3815 return Err(FittedModelError::MissingField {
3816 reason: "residual-cascade model is missing data_schema; refit".to_string(),
3817 });
3818 }
3819 if self.training_headers.is_none() {
3820 return Err(FittedModelError::MissingField {
3821 reason: "residual-cascade model is missing training_headers; refit".to_string(),
3822 });
3823 }
3824 return Ok(());
3825 } else if self.fit_result.is_none() {
3826 return Err(FittedModelError::MissingField {
3827 reason: "model is missing canonical fit_result payload; refit".to_string(),
3828 });
3829 }
3830 if self.data_schema.is_none() {
3831 return Err(FittedModelError::MissingField {
3832 reason: "model is missing data_schema; refit".to_string(),
3833 });
3834 }
3835 if self.training_headers.is_none() {
3836 return Err(FittedModelError::MissingField {
3837 reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3838 .to_string(),
3839 });
3840 }
3841 let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3842 FittedModelError::MissingField {
3843 reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3844 .to_string(),
3845 }
3846 })?;
3847 validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3848
3849 if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3850 return Err(FittedModelError::MissingField {
3851 reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3852 .to_string(),
3853 });
3854 }
3855 if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3856 validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3857 }
3858 if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3859 let score = self.transformation_score_calibration.ok_or_else(|| {
3860 FittedModelError::MissingField {
3861 reason: "transformation-normal model is missing transformation_score_calibration; refit"
3862 .to_string(),
3863 }
3864 })?;
3865 score.validate("transformation-normal model")?;
3866 }
3867 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3868 if self.formula_logslope.is_none() {
3869 return Err(FittedModelError::MissingField {
3870 reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3871 });
3872 }
3873 if self.z_column.is_none() {
3874 return Err(FittedModelError::MissingField {
3875 reason: "marginal-slope model is missing z_column; refit".to_string(),
3876 });
3877 }
3878 let z_normalization =
3879 self.latent_z_normalization
3880 .ok_or_else(|| FittedModelError::MissingField {
3881 reason: "marginal-slope model is missing latent_z_normalization; refit"
3882 .to_string(),
3883 })?;
3884 z_normalization.validate("marginal-slope model")?;
3885 let latent_measure =
3886 self.latent_measure
3887 .as_ref()
3888 .ok_or_else(|| FittedModelError::MissingField {
3889 reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3890 })?;
3891 latent_measure
3892 .validate("marginal-slope model latent_measure")
3893 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3894 if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3895 return Err(FittedModelError::MissingField {
3896 reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3897 });
3898 }
3899 if self.resolved_termspec_logslope.as_ref().is_none() {
3900 return Err(FittedModelError::MissingField {
3901 reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3902 .to_string(),
3903 });
3904 }
3905 match self.family_state.frailty() {
3906 Some(FrailtySpec::None)
3907 | Some(FrailtySpec::GaussianShift {
3908 sigma_fixed: Some(_),
3909 }) => {}
3910 Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3911 return Err(FittedModelError::IncompatibleConfig {
3912 reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3913 .to_string(),
3914 });
3915 }
3916 Some(FrailtySpec::HazardMultiplier { .. }) => {
3917 return Err(FittedModelError::IncompatibleConfig {
3918 reason: "marginal-slope model does not support HazardMultiplier frailty"
3919 .to_string(),
3920 });
3921 }
3922 None => {
3923 return Err(FittedModelError::MissingField {
3924 reason: "marginal-slope model is missing family_state.frailty; refit"
3925 .to_string(),
3926 });
3927 }
3928 }
3929 }
3930
3931 if let FittedFamily::Survival {
3932 survival_likelihood,
3933 frailty,
3934 ..
3935 } = &self.family_state
3936 {
3937 if matches!(
3938 survival_likelihood.as_deref(),
3939 Some("latent") | Some("latent-binary")
3940 ) {
3941 return Err(FittedModelError::SchemaMismatch {
3942 reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3943 .to_string(),
3944 });
3945 }
3946 if survival_likelihood.as_deref() == Some("marginal-slope") {
3947 if self.formula_logslope.is_none() {
3948 return Err(FittedModelError::MissingField {
3949 reason: "survival marginal-slope model is missing formula_logslope; refit"
3950 .to_string(),
3951 });
3952 }
3953 if self.z_column.is_none() {
3954 return Err(FittedModelError::MissingField {
3955 reason: "survival marginal-slope model is missing z_column; refit"
3956 .to_string(),
3957 });
3958 }
3959 let z_normalization =
3960 self.latent_z_normalization
3961 .ok_or_else(|| {
3962 FittedModelError::MissingField {
3963 reason:
3964 "survival marginal-slope model is missing latent_z_normalization; refit"
3965 .to_string(),
3966 }
3967 })?;
3968 z_normalization.validate("survival marginal-slope model")?;
3969 let latent_measure =
3970 self.latent_measure
3971 .as_ref()
3972 .ok_or_else(|| FittedModelError::MissingField {
3973 reason:
3974 "survival marginal-slope model is missing latent_measure; refit"
3975 .to_string(),
3976 })?;
3977 latent_measure
3978 .validate("survival marginal-slope model latent_measure")
3979 .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3980 if self.logslope_baseline.is_none() {
3981 return Err(FittedModelError::MissingField {
3982 reason: "survival marginal-slope model is missing logslope_baseline; refit"
3983 .to_string(),
3984 });
3985 }
3986 if self.resolved_termspec_logslope.as_ref().is_none() {
3987 return Err(FittedModelError::MissingField {
3988 reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3989 .to_string(),
3990 });
3991 }
3992 match frailty {
3993 FrailtySpec::None
3994 | FrailtySpec::GaussianShift {
3995 sigma_fixed: Some(_),
3996 } => {}
3997 FrailtySpec::GaussianShift { sigma_fixed: None } => {
3998 return Err(FittedModelError::IncompatibleConfig {
3999 reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4000 .to_string(),
4001 });
4002 }
4003 FrailtySpec::HazardMultiplier { .. } => {
4004 return Err(FittedModelError::IncompatibleConfig {
4005 reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4006 .to_string(),
4007 });
4008 }
4009 }
4010 } else if !matches!(frailty, FrailtySpec::None) {
4011 return Err(FittedModelError::IncompatibleConfig {
4012 reason:
4013 "non-marginal survival models do not currently persist a frailty modifier"
4014 .to_string(),
4015 });
4016 }
4017 if self.survival_time_basis.is_none() {
4025 return Err(FittedModelError::MissingField {
4026 reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4027 });
4028 }
4029 if self.survival_time_anchor.is_none() {
4030 return Err(FittedModelError::MissingField {
4031 reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4032 });
4033 }
4034 }
4035 if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4036 match frailty {
4037 FrailtySpec::HazardMultiplier {
4038 sigma_fixed: Some(_),
4039 ..
4040 } => {}
4041 FrailtySpec::HazardMultiplier {
4042 sigma_fixed: None, ..
4043 } => {
4044 return Err(FittedModelError::IncompatibleConfig {
4045 reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4046 .to_string(),
4047 });
4048 }
4049 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4050 return Err(FittedModelError::IncompatibleConfig {
4051 reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4052 .to_string(),
4053 });
4054 }
4055 }
4056 if self.survival_likelihood.as_deref() != Some("latent") {
4057 return Err(FittedModelError::SchemaMismatch {
4058 reason: "latent survival model must persist survival_likelihood=latent"
4059 .to_string(),
4060 });
4061 }
4062 }
4063 if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4064 match frailty {
4065 FrailtySpec::HazardMultiplier {
4066 sigma_fixed: Some(_),
4067 ..
4068 } => {}
4069 FrailtySpec::HazardMultiplier {
4070 sigma_fixed: None, ..
4071 } => {
4072 return Err(FittedModelError::IncompatibleConfig {
4073 reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4074 .to_string(),
4075 });
4076 }
4077 FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4078 return Err(FittedModelError::IncompatibleConfig {
4079 reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4080 .to_string(),
4081 });
4082 }
4083 }
4084 if self.survival_likelihood.as_deref() != Some("latent-binary") {
4085 return Err(FittedModelError::SchemaMismatch {
4086 reason: "latent binary model must persist survival_likelihood=latent-binary"
4087 .to_string(),
4088 });
4089 }
4090 }
4091
4092 let family_likelihood = match &self.family_state {
4093 FittedFamily::Standard { likelihood, .. }
4094 | FittedFamily::LocationScale { likelihood, .. }
4095 | FittedFamily::MarginalSlope { likelihood, .. }
4096 | FittedFamily::Survival { likelihood, .. }
4097 | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4098 FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4099 };
4100 let is_standard_or_location_scale = matches!(
4101 self.family_state,
4102 FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4103 );
4104 if is_standard_or_location_scale
4105 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4106 {
4107 self.saved_sas_state()?;
4108 }
4109 if is_standard_or_location_scale
4110 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4111 {
4112 self.saved_beta_logistic_state()?;
4113 }
4114 if is_standard_or_location_scale
4115 && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4116 {
4117 self.saved_mixture_state()?;
4118 }
4119 if matches!(self.family_state, FittedFamily::Standard { .. })
4120 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4121 {
4122 self.saved_latent_cloglog_state()?;
4123 }
4124 if matches!(self.family_state, FittedFamily::LocationScale { .. })
4125 && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4126 {
4127 return Err(FittedModelError::IncompatibleConfig {
4128 reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4129 .to_string(),
4130 });
4131 }
4132 if matches!(self.family_state, FittedFamily::Survival { .. })
4133 && self.survival_likelihood.is_none()
4134 {
4135 return Err(FittedModelError::MissingField {
4136 reason: "saved survival model is missing survival_likelihood metadata; refit"
4137 .to_string(),
4138 });
4139 }
4140 let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4141 || self.linkwiggle_degree.is_some()
4142 || self.beta_link_wiggle.is_some()
4143 || self
4144 .fit_result
4145 .as_ref()
4146 .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4147 .is_some();
4148 let saved_link_wiggle = self.saved_link_wiggle()?;
4149 if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4150 return Err(FittedModelError::SchemaMismatch {
4151 reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4152 .to_string(),
4153 });
4154 }
4155 let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4156 || self.baseline_timewiggle_degree.is_some()
4157 || self.baseline_timewiggle_penalty_orders.is_some()
4158 || self.baseline_timewiggle_double_penalty.is_some()
4159 || self.beta_baseline_timewiggle.is_some()
4160 || self.beta_baseline_timewiggle_by_cause.is_some();
4161 let is_joint_cause_specific = self
4162 .survival_cause_count
4163 .is_some_and(|cause_count| cause_count > 1);
4164 if has_any_saved_baseline_time_wiggle {
4165 if is_joint_cause_specific {
4166 let complete = self.baseline_timewiggle_knots.is_some()
4167 && self.baseline_timewiggle_degree.is_some()
4168 && self.baseline_timewiggle_penalty_orders.is_some()
4169 && self.baseline_timewiggle_double_penalty.is_some()
4170 && self.beta_baseline_timewiggle_by_cause.is_some();
4171 if !complete {
4172 return Err(FittedModelError::SchemaMismatch {
4173 reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4174 .to_string(),
4175 });
4176 }
4177 } else if self.saved_baseline_time_wiggle()?.is_none() {
4178 return Err(FittedModelError::SchemaMismatch {
4179 reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4180 .to_string(),
4181 });
4182 }
4183 }
4184 if self
4185 .survival_likelihood
4186 .as_deref()
4187 .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4188 {
4189 validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4190 }
4191
4192 if let Some(runtime) = self.score_warp_runtime.as_ref() {
4202 runtime.validate_exact_replay_contract().map_err(|err| {
4203 FittedModelError::PayloadCorrupt {
4204 reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4205 }
4206 })?;
4207 }
4208 if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4209 runtime.validate_exact_replay_contract().map_err(|err| {
4210 FittedModelError::PayloadCorrupt {
4211 reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4212 }
4213 })?;
4214 }
4215 if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4216 validate_marginal_slope_saved_fit(
4217 self.fit_result.as_ref().expect("checked above"),
4218 self.score_warp_runtime.as_ref(),
4219 self.link_deviation_runtime.as_ref(),
4220 "fit_result",
4221 )?;
4222 let unified = self
4223 .unified
4224 .as_ref()
4225 .ok_or_else(|| FittedModelError::MissingField {
4226 reason: "marginal-slope model is missing unified fit payload; refit"
4227 .to_string(),
4228 })?;
4229 validate_marginal_slope_saved_fit(
4230 unified,
4231 self.score_warp_runtime.as_ref(),
4232 self.link_deviation_runtime.as_ref(),
4233 "unified",
4234 )?;
4235 }
4236 if self
4237 .survival_likelihood
4238 .as_deref()
4239 .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4240 {
4241 validate_survival_marginal_slope_saved_fit(
4242 self.fit_result.as_ref().expect("checked above"),
4243 self.score_warp_runtime.as_ref(),
4244 self.link_deviation_runtime.as_ref(),
4245 "fit_result",
4246 )?;
4247 if let Some(unified) = self.unified.as_ref() {
4248 validate_survival_marginal_slope_saved_fit(
4249 unified,
4250 self.score_warp_runtime.as_ref(),
4251 self.link_deviation_runtime.as_ref(),
4252 "unified",
4253 )?;
4254 }
4255 }
4256
4257 Ok(())
4268 }
4269
4270 pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4271 let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4272 if let Some(fit) = self.fit_result.as_ref() {
4273 fit.validate_numeric_finiteness()
4274 .map_err(|e| corrupt(e.to_string()))?;
4275 }
4276
4277 for (name, opt) in [
4278 ("survival_baseline_scale", self.survival_baseline_scale),
4279 ("survival_baseline_shape", self.survival_baseline_shape),
4280 ("survival_baseline_rate", self.survival_baseline_rate),
4281 ("survival_baseline_makeham", self.survival_baseline_makeham),
4282 (
4283 "survival_time_smooth_lambda",
4284 self.survival_time_smooth_lambda,
4285 ),
4286 ("survival_time_anchor", self.survival_time_anchor),
4287 ("survivalridge_lambda", self.survivalridge_lambda),
4288 ] {
4289 if let Some(v) = opt {
4290 ensure_finite_scalar(name, v).map_err(corrupt)?;
4291 }
4292 }
4293
4294 if let Some(v) = self.beta_noise.as_ref() {
4295 validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4296 }
4297 if let Some(v) = self.noise_projection.as_ref() {
4298 validate_all_finite("noise_projection", v.iter().flatten().copied())
4299 .map_err(corrupt)?;
4300 if self.noise_projection_ridge_alpha.is_none() {
4301 return Err(FittedModelError::MissingField {
4302 reason:
4303 "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4304 .to_string(),
4305 });
4306 }
4307 }
4308 if let Some(v) = self.noise_center.as_ref() {
4309 validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4310 }
4311 if let Some(v) = self.noise_scale.as_ref() {
4312 validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4313 }
4314 if let Some(v) = self.noise_projection_ridge_alpha {
4315 ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4316 if v < 0.0 {
4317 return Err(FittedModelError::InvalidInput {
4318 reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4319 });
4320 }
4321 }
4322 if let Some(v) = self.gaussian_response_scale {
4323 ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4324 }
4325 if let Some(v) = self.beta_link_wiggle.as_ref() {
4326 validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4327 }
4328 if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4329 validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4330 }
4331 if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4332 validate_all_finite(
4333 "beta_baseline_timewiggle_by_cause",
4334 v.iter().flatten().copied(),
4335 )
4336 .map_err(corrupt)?;
4337 }
4338 if let Some(v) = self.latent_z_normalization {
4339 v.validate("latent_z_normalization")?;
4340 }
4341 if let Some(v) = self.latent_measure.as_ref() {
4342 v.validate("latent_measure").map_err(corrupt)?;
4343 }
4344 if let Some(v) = self.survival_beta_time.as_ref() {
4345 validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4346 }
4347 if let Some(v) = self.survival_beta_threshold.as_ref() {
4348 validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4349 }
4350 if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4351 validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4352 }
4353 if let Some(v) = self.survival_noise_projection.as_ref() {
4354 validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4355 .map_err(corrupt)?;
4356 if self.survival_noise_projection_ridge_alpha.is_none() {
4357 return Err(FittedModelError::MissingField {
4358 reason:
4359 "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4360 .to_string(),
4361 });
4362 }
4363 }
4364 if let Some(v) = self.survival_noise_center.as_ref() {
4365 validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4366 }
4367 if let Some(v) = self.survival_noise_projection_ridge_alpha {
4368 ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4369 if v < 0.0 {
4370 return Err(FittedModelError::InvalidInput {
4371 reason: format!(
4372 "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4373 ),
4374 });
4375 }
4376 }
4377 if let Some(v) = self.survival_noise_scale.as_ref() {
4378 validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4379 }
4380 if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4381 validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4382 .map_err(corrupt)?;
4383 }
4384 if let Some(v) = self.sas_param_covariance.as_ref() {
4385 validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4386 .map_err(corrupt)?;
4387 }
4388 Ok(())
4389 }
4390}
4391
4392fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4393 a.rows().into_iter().map(|row| row.to_vec()).collect()
4394}
4395
4396use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4397
4398fn validate_frozen_term_collectionspec(
4399 spec: &TermCollectionSpec,
4400 label: &str,
4401) -> Result<(), FittedModelError> {
4402 spec.validate_frozen(label)
4403 .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4404}
4405
4406impl Deref for FittedModel {
4407 type Target = FittedModelPayload;
4408
4409 fn deref(&self) -> &Self::Target {
4410 self.payload()
4411 }
4412}
4413
4414impl DerefMut for FittedModel {
4415 fn deref_mut(&mut self) -> &mut Self::Target {
4416 self.payload_mut()
4417 }
4418}
4419
4420pub fn survival_baseline_config_from_model(
4425 model: &FittedModel,
4426) -> Result<SurvivalBaselineConfig, FittedModelError> {
4427 let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4428 FittedModelError::MissingField {
4429 reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4430 }
4431 })?;
4432 parse_survival_baseline_config(
4433 target,
4434 model.survival_baseline_scale,
4435 model.survival_baseline_shape,
4436 model.survival_baseline_rate,
4437 model.survival_baseline_makeham,
4438 )
4439 .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4440}
4441
4442pub fn load_survival_time_basis_config_from_model(
4443 model: &FittedModel,
4444) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4445 match model
4446 .survival_time_basis
4447 .as_deref()
4448 .ok_or_else(|| FittedModelError::MissingField {
4449 reason: "saved survival model missing survival_time_basis".to_string(),
4450 })?
4451 .to_ascii_lowercase()
4452 .as_str()
4453 {
4454 "none" => Ok(SurvivalTimeBasisConfig::None),
4455 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4456 "bspline" => {
4457 let degree =
4458 model
4459 .survival_time_degree
4460 .ok_or_else(|| FittedModelError::MissingField {
4461 reason: "saved survival bspline model missing survival_time_degree"
4462 .to_string(),
4463 })?;
4464 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4465 FittedModelError::MissingField {
4466 reason: "saved survival bspline model missing survival_time_knots".to_string(),
4467 }
4468 })?;
4469 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4470 if degree < 1 || knots.is_empty() {
4471 return Err(FittedModelError::SchemaMismatch {
4472 reason: "saved survival bspline time basis metadata is invalid".to_string(),
4473 });
4474 }
4475 Ok(SurvivalTimeBasisConfig::BSpline {
4476 degree,
4477 knots: Array1::from_vec(knots),
4478 smooth_lambda,
4479 })
4480 }
4481 "ispline" => {
4482 let degree =
4483 model
4484 .survival_time_degree
4485 .ok_or_else(|| FittedModelError::MissingField {
4486 reason: "saved survival ispline model missing survival_time_degree"
4487 .to_string(),
4488 })?;
4489 let knots = model.survival_time_knots.clone().ok_or_else(|| {
4490 FittedModelError::MissingField {
4491 reason: "saved survival ispline model missing survival_time_knots".to_string(),
4492 }
4493 })?;
4494 let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4495 FittedModelError::MissingField {
4496 reason: "saved survival ispline model missing survival_time_keep_cols"
4497 .to_string(),
4498 }
4499 })?;
4500 let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4501 if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4502 return Err(FittedModelError::SchemaMismatch {
4503 reason: "saved survival ispline time basis metadata is invalid".to_string(),
4504 });
4505 }
4506 Ok(SurvivalTimeBasisConfig::ISpline {
4507 degree,
4508 knots: Array1::from_vec(knots),
4509 keep_cols,
4510 smooth_lambda,
4511 })
4512 }
4513 other => Err(FittedModelError::IncompatibleConfig {
4514 reason: format!("unsupported saved survival_time_basis '{other}'"),
4515 }),
4516 }
4517}
4518
4519#[cfg(test)]
4520mod tests {
4521 use super::*;
4522 use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4523 use crate::survival::lognormal_kernel::FrailtySpec;
4524 use gam_solve::pirls::PirlsStatus;
4525 use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4526 use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4527 use gam_data::SchemaColumn;
4528 use ndarray::{Array1, Array2, array};
4529
4530 fn empty_termspec() -> TermCollectionSpec {
4531 TermCollectionSpec {
4532 linear_terms: vec![],
4533 random_effect_terms: vec![],
4534 smooth_terms: vec![],
4535 }
4536 }
4537
4538 #[test]
4542 fn spline_scan_payload_round_trips_and_validates() {
4543 let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4544 let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4545 let w = vec![1.0_f64; x.len()];
4546 let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4547 let make_payload = || {
4548 crate::inference::model_payload_builders::assemble_spline_scan_payload(
4549 "y ~ s(x)".to_string(),
4550 "x".to_string(),
4551 &fit,
4552 DataSchema {
4553 columns: vec![
4554 SchemaColumn {
4555 name: "y".to_string(),
4556 kind: ColumnKindTag::Continuous,
4557 levels: vec![],
4558 },
4559 SchemaColumn {
4560 name: "x".to_string(),
4561 kind: ColumnKindTag::Continuous,
4562 levels: vec![],
4563 },
4564 ],
4565 },
4566 vec!["x".to_string()],
4567 vec![(0.0, 1.0)],
4568 )
4569 };
4570 let model = FittedModel::from_payload(make_payload());
4573 model
4574 .validate_for_persistence()
4575 .expect("scan model validates");
4576 model
4577 .validate_numeric_finiteness()
4578 .expect("scan model is finite");
4579
4580 let json = serde_json::to_string(&model).expect("serialize model");
4581 let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4582 restored
4583 .validate_for_persistence()
4584 .expect("restored scan model validates");
4585 let (column, replay) = restored
4586 .saved_spline_scan()
4587 .expect("restore scan fit")
4588 .expect("payload carries the scan representation");
4589 assert_eq!(column, "x");
4590 for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4591 let (m0, v0) = fit.predict(xq).expect("predict original");
4592 let (m1, v1) = replay.predict(xq).expect("predict replayed");
4593 assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4594 assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4595 }
4596
4597 let mut dense = make_payload();
4599 dense.spline_scan = None;
4600 let err = FittedModel::from_payload(dense)
4601 .validate_for_persistence()
4602 .expect_err("dense payload without fit_result must be rejected");
4603 assert!(err.to_string().contains("fit_result"));
4604
4605 let mut corrupt = make_payload();
4607 corrupt
4608 .spline_scan
4609 .as_mut()
4610 .expect("scan channel present")
4611 .state
4612 .knots
4613 .truncate(2);
4614 FittedModel::from_payload(corrupt)
4615 .validate_for_persistence()
4616 .expect_err("corrupt scan state must be rejected");
4617 let mut unnamed = make_payload();
4618 unnamed
4619 .spline_scan
4620 .as_mut()
4621 .expect("scan channel present")
4622 .feature_column
4623 .clear();
4624 FittedModel::from_payload(unnamed)
4625 .validate_for_persistence()
4626 .expect_err("missing feature column must be rejected");
4627 }
4628
4629 fn standard_gaussian_payload() -> FittedModelPayload {
4630 FittedModelPayload::new(
4631 MODEL_PAYLOAD_VERSION,
4632 "y ~ 1".to_string(),
4633 ModelKind::Standard,
4634 FittedFamily::Standard {
4635 likelihood: LikelihoodSpec::gaussian_identity(),
4636 link: Some(StandardLink::Identity),
4637 latent_cloglog_state: None,
4638 mixture_state: None,
4639 sas_state: None,
4640 },
4641 "gaussian".to_string(),
4642 )
4643 }
4644
4645 fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4646 SavedCompiledFlexBlock {
4647 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4648 breakpoints: vec![-1.0, 1.0],
4649 basis_dim,
4650 span_c0: vec![vec![0.0; basis_dim]],
4651 span_c1: vec![vec![0.0; basis_dim]],
4652 span_c2: vec![vec![0.0; basis_dim]],
4653 span_c3: vec![vec![0.0; basis_dim]],
4654 anchor_correction: None,
4655 anchor_components: Vec::new(),
4656 }
4657 }
4658
4659 fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4660 let beta = Array1::from_vec(
4661 blocks
4662 .iter()
4663 .flat_map(|block| block.beta.iter().copied())
4664 .collect(),
4665 );
4666 let p = beta.len();
4667 UnifiedFitResult {
4668 blocks,
4669 log_lambdas: Array1::zeros(0),
4670 lambdas: Array1::zeros(0),
4671 likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4672 likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4673 log_likelihood_normalization: LogLikelihoodNormalization::Full,
4674 log_likelihood: 0.0,
4675 deviance: 0.0,
4676 reml_score: 0.0,
4677 stable_penalty_term: 0.0,
4678 penalized_objective: 0.0,
4679 used_device: false,
4680 outer_iterations: 0,
4681 outer_cost_evals: 0,
4682 outer_converged: true,
4683 outer_gradient_norm: None,
4684 standard_deviation: 1.0,
4685 covariance_conditional: Some(Array2::zeros((p, p))),
4686 covariance_corrected: Some(Array2::zeros((p, p))),
4687 inference: None,
4688 fitted_link: FittedLinkState::Standard(None),
4689 geometry: None,
4690 block_states: vec![],
4691 beta,
4692 pirls_status: PirlsStatus::Converged,
4693 max_abs_eta: 0.0,
4694 constraint_kkt: None,
4695 artifacts: FitArtifacts {
4696 pirls: None,
4697 null_space_logdet: None,
4698 null_space_dim: None,
4699 survival_link_wiggle_knots: None,
4700 survival_link_wiggle_degree: None,
4701 criterion_certificate: None,
4702 rho_posterior_certificate: None,
4703 rho_posterior_escalation: None,
4704 rho_covariance: None,
4705 },
4706 inner_cycles: 0,
4707 }
4708 }
4709
4710 fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4711 let mut payload = FittedModelPayload::new(
4712 version,
4713 "y ~ 1".to_string(),
4714 ModelKind::MarginalSlope,
4715 FittedFamily::MarginalSlope {
4716 likelihood: LikelihoodSpec::binomial_probit(),
4717 base_link: InverseLink::Standard(StandardLink::Probit),
4718 frailty: FrailtySpec::None,
4719 },
4720 "bernoulli-marginal-slope".to_string(),
4721 );
4722 payload.fit_result = Some(fit.clone());
4723 payload.unified = Some(fit);
4724 payload.data_schema = Some(DataSchema {
4725 columns: vec![SchemaColumn {
4726 name: "z".to_string(),
4727 kind: ColumnKindTag::Continuous,
4728 levels: vec![],
4729 }],
4730 });
4731 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4732 payload.resolved_termspec = Some(empty_termspec());
4733 payload.resolved_termspec_logslope = Some(empty_termspec());
4734 payload.formula_logslope = Some("1".to_string());
4735 payload.z_column = Some("z".to_string());
4736 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4737 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4738 payload.marginal_baseline = Some(0.0);
4739 payload.logslope_baseline = Some(0.0);
4740 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4741 payload
4742 }
4743
4744 #[test]
4745 fn from_payload_synchronizes_used_device_from_saved_fit() {
4746 let mut fit = saved_fit(vec![
4747 FittedBlock {
4748 beta: Array1::from_vec(vec![0.25]),
4749 role: BlockRole::Mean,
4750 edf: 1.0,
4751 lambdas: Array1::zeros(0),
4752 },
4753 FittedBlock {
4754 beta: Array1::from_vec(vec![0.5]),
4755 role: BlockRole::Scale,
4756 edf: 1.0,
4757 lambdas: Array1::zeros(0),
4758 },
4759 ]);
4760 fit.used_device = true;
4761 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4762 payload.used_device = false;
4763
4764 let model = FittedModel::from_payload(payload);
4765
4766 assert!(model.payload().used_device);
4767 }
4768
4769 fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4770 let mut payload = FittedModelPayload::new(
4771 version,
4772 "Surv(entry, exit, event) ~ 1".to_string(),
4773 ModelKind::Survival,
4774 FittedFamily::Survival {
4775 likelihood: LikelihoodSpec::royston_parmar(),
4776 survival_likelihood: Some("marginal-slope".to_string()),
4777 survival_distribution: Some(ResidualDistribution::Gaussian),
4778 frailty: FrailtySpec::None,
4779 },
4780 "survival".to_string(),
4781 );
4782 payload.fit_result = Some(fit.clone());
4783 payload.unified = Some(fit);
4784 payload.survival_likelihood = Some("marginal-slope".to_string());
4785 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4786 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4787 payload.data_schema = Some(DataSchema {
4788 columns: vec![SchemaColumn {
4789 name: "z".to_string(),
4790 kind: ColumnKindTag::Continuous,
4791 levels: vec![],
4792 }],
4793 });
4794 payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4795 payload.resolved_termspec = Some(empty_termspec());
4796 payload.resolved_termspec_logslope = Some(empty_termspec());
4797 payload.formula_logslope = Some("1".to_string());
4798 payload.z_column = Some("z".to_string());
4799 payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4800 payload.logslope_baseline = Some(0.0);
4801 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4802 payload
4803 }
4804
4805 fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4806 let mut payload = FittedModelPayload::new(
4812 MODEL_PAYLOAD_VERSION,
4813 "y ~ x".to_string(),
4814 ModelKind::LocationScale,
4815 FittedFamily::LocationScale {
4816 likelihood: LikelihoodSpec::gamma_log(),
4817 base_link: Some(InverseLink::Standard(StandardLink::Log)),
4818 },
4819 "gamma-location-scale".to_string(),
4820 );
4821 payload.data_schema = Some(DataSchema {
4822 columns: vec![
4823 SchemaColumn {
4824 name: "y".to_string(),
4825 kind: ColumnKindTag::Continuous,
4826 levels: vec![],
4827 },
4828 SchemaColumn {
4829 name: "x".to_string(),
4830 kind: ColumnKindTag::Continuous,
4831 levels: vec![],
4832 },
4833 ],
4834 });
4835 payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4836 payload.resolved_termspec = Some(empty_termspec());
4837 payload.resolved_termspec_noise = Some(empty_termspec());
4838 payload.formula_noise = Some("x".to_string());
4839 payload.beta_noise = Some(vec![0.0]);
4840 payload.link = Some(InverseLink::Standard(StandardLink::Log));
4841 payload
4842 }
4843
4844 #[test]
4851 fn dispersion_location_scale_payload_is_not_classified_binomial() {
4852 let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4853 assert_eq!(
4854 model.predict_model_class(),
4855 PredictModelClass::DispersionLocationScale,
4856 "Gamma dispersion location-scale must route through the dispersion \
4857 predictor, not the binomial threshold-scale class",
4858 );
4859 assert!(
4860 !matches!(
4861 model.predict_model_class(),
4862 PredictModelClass::BinomialLocationScale
4863 ),
4864 "dispersion location-scale must never be classified as binomial",
4865 );
4866
4867 for likelihood in [
4869 LikelihoodSpec::gamma_log(),
4870 LikelihoodSpec::new(
4871 ResponseFamily::NegativeBinomial {
4872 theta: 1.0,
4873 theta_fixed: false,
4874 },
4875 InverseLink::Standard(StandardLink::Log),
4876 ),
4877 LikelihoodSpec::new(
4878 ResponseFamily::Beta { phi: 1.0 },
4879 InverseLink::Standard(StandardLink::Logit),
4880 ),
4881 LikelihoodSpec::new(
4882 ResponseFamily::Tweedie { p: 1.5 },
4883 InverseLink::Standard(StandardLink::Log),
4884 ),
4885 ] {
4886 let mut payload = gamma_dispersion_location_scale_payload();
4887 payload.family_state = FittedFamily::LocationScale {
4888 base_link: Some(likelihood.link.clone()),
4889 likelihood: likelihood.clone(),
4890 };
4891 let model = FittedModel::from_payload(payload);
4892 assert_eq!(
4893 model.predict_model_class(),
4894 PredictModelClass::DispersionLocationScale,
4895 "dispersion family {:?} mis-classified",
4896 likelihood.response,
4897 );
4898 }
4899 }
4900
4901 #[test]
4902 fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4903 let data = array![[100.0], [-100.0]];
4904 let col_map = HashMap::from([("g".to_string(), 0usize)]);
4905
4906 let mut plain_payload = standard_gaussian_payload();
4907 plain_payload.data_schema = Some(DataSchema {
4908 columns: vec![SchemaColumn {
4909 name: "g".to_string(),
4910 kind: ColumnKindTag::Continuous,
4911 levels: vec![],
4912 }],
4913 });
4914 plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4915 plain_payload.resolved_termspec = Some(empty_termspec());
4916 let plain = FittedModel::from_payload(plain_payload.clone());
4917 let clipped = plain
4918 .axis_clip_to_training_ranges(data.view(), &col_map)
4919 .expect("ordinary continuous axis should clip outside the training range");
4920 assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4921
4922 let mut group_payload = plain_payload;
4923 let mut group_spec = empty_termspec();
4924 group_spec
4925 .random_effect_terms
4926 .push(gam_terms::smooth::RandomEffectTermSpec {
4927 name: "g".to_string(),
4928 feature_col: 0,
4929 drop_first_level: false,
4930 penalized: true,
4931 frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4932 });
4933 group_payload.resolved_termspec = Some(group_spec);
4934 let group_model = FittedModel::from_payload(group_payload);
4935
4936 assert_eq!(
4937 group_model.random_effect_group_columns(),
4938 HashSet::from(["g".to_string()])
4939 );
4940
4941 assert_eq!(
4942 group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4943 None,
4944 "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4945 );
4946 }
4947
4948 #[test]
4949 fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4950 let fit = saved_fit(vec![
4951 FittedBlock {
4952 beta: array![0.1],
4953 role: BlockRole::Mean,
4954 edf: 1.0,
4955 lambdas: Array1::zeros(0),
4956 },
4957 FittedBlock {
4958 beta: array![0.2],
4959 role: BlockRole::Scale,
4960 edf: 1.0,
4961 lambdas: Array1::zeros(0),
4962 },
4963 FittedBlock {
4964 beta: array![0.3],
4965 role: BlockRole::Mean,
4966 edf: 1.0,
4967 lambdas: Array1::zeros(0),
4968 },
4969 ]);
4970 let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4971 payload.score_warp_runtime = Some(anchored_runtime(2));
4972
4973 let err = FittedModel::from_payload(payload)
4974 .validate_for_persistence()
4975 .expect_err("marginal-slope score-warp basis mismatch should fail validation");
4976 assert!(err.to_string().contains("score-warp coefficient mismatch"));
4977 }
4978
4979 #[test]
4980 fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
4981 let fit = saved_fit(vec![
4982 FittedBlock {
4983 beta: array![0.1],
4984 role: BlockRole::Time,
4985 edf: 1.0,
4986 lambdas: Array1::zeros(0),
4987 },
4988 FittedBlock {
4989 beta: array![0.2],
4990 role: BlockRole::Mean,
4991 edf: 1.0,
4992 lambdas: Array1::zeros(0),
4993 },
4994 FittedBlock {
4995 beta: array![0.3],
4996 role: BlockRole::Scale,
4997 edf: 1.0,
4998 lambdas: Array1::zeros(0),
4999 },
5000 FittedBlock {
5001 beta: array![0.4],
5002 role: BlockRole::LinkWiggle,
5003 edf: 1.0,
5004 lambdas: Array1::zeros(0),
5005 },
5006 ]);
5007 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5008 payload.link_deviation_runtime = Some(anchored_runtime(2));
5009
5010 let err = FittedModel::from_payload(payload)
5011 .saved_prediction_runtime()
5012 .expect_err(
5013 "survival marginal-slope link basis mismatch should fail runtime validation",
5014 );
5015 assert!(
5016 err.to_string()
5017 .contains("link-deviation coefficient mismatch")
5018 );
5019 }
5020
5021 #[test]
5022 fn apply_survival_time_basis_writes_all_required_fields() {
5023 use crate::survival::construction::SavedSurvivalTimeBasis;
5024
5025 let fit = saved_fit(vec![
5026 FittedBlock {
5027 beta: array![0.1],
5028 role: BlockRole::Time,
5029 edf: 1.0,
5030 lambdas: Array1::zeros(0),
5031 },
5032 FittedBlock {
5033 beta: array![0.2],
5034 role: BlockRole::Mean,
5035 edf: 1.0,
5036 lambdas: Array1::zeros(0),
5037 },
5038 FittedBlock {
5039 beta: array![0.3],
5040 role: BlockRole::Scale,
5041 edf: 1.0,
5042 lambdas: Array1::zeros(0),
5043 },
5044 ]);
5045 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5046
5047 let snapshot = SavedSurvivalTimeBasis {
5052 basisname: "royston-parmar".to_string(),
5053 degree: Some(3),
5054 knots: Some(vec![0.0, 1.0, 2.0]),
5055 keep_cols: Some(vec![0, 2]),
5056 smooth_lambda: Some(0.5),
5057 anchor: 0.25,
5058 };
5059 payload.apply_survival_time_basis(&snapshot);
5060
5061 assert_eq!(
5062 payload.survival_time_basis.as_deref(),
5063 Some("royston-parmar")
5064 );
5065 assert_eq!(payload.survival_time_degree, Some(3));
5066 assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5067 assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5068 assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5069 assert_eq!(payload.survival_time_anchor, Some(0.25));
5070 }
5071
5072 #[test]
5073 fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5074 let fit = saved_fit(vec![
5075 FittedBlock {
5076 beta: array![0.1],
5077 role: BlockRole::Time,
5078 edf: 1.0,
5079 lambdas: Array1::zeros(0),
5080 },
5081 FittedBlock {
5082 beta: array![0.2],
5083 role: BlockRole::Mean,
5084 edf: 1.0,
5085 lambdas: Array1::zeros(0),
5086 },
5087 FittedBlock {
5088 beta: array![0.3],
5089 role: BlockRole::Scale,
5090 edf: 1.0,
5091 lambdas: Array1::zeros(0),
5092 },
5093 ]);
5094 let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5095 payload.survival_time_basis = Some("ispline".to_string());
5101
5102 let err = FittedModel::from_payload(payload)
5103 .validate_for_persistence()
5104 .expect_err("survival model without time-anchor metadata should fail validation");
5105 assert!(err.to_string().contains("missing survival_time_anchor"));
5106 }
5107
5108 #[test]
5109 fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5110 let fit = saved_fit(vec![
5111 FittedBlock {
5112 beta: array![0.1],
5113 role: BlockRole::Time,
5114 edf: 1.0,
5115 lambdas: Array1::zeros(0),
5116 },
5117 FittedBlock {
5118 beta: array![0.2],
5119 role: BlockRole::Mean,
5120 edf: 1.0,
5121 lambdas: Array1::zeros(0),
5122 },
5123 FittedBlock {
5124 beta: array![0.3],
5125 role: BlockRole::Scale,
5126 edf: 1.0,
5127 lambdas: Array1::zeros(0),
5128 },
5129 ]);
5130 let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5131
5132 let err = FittedModel::from_payload(payload)
5133 .validate_for_persistence()
5134 .expect_err("survival model without time-basis metadata should fail validation");
5135 assert!(err.to_string().contains("missing survival_time_basis"));
5136 }
5137
5138 #[test]
5139 fn saved_prediction_runtime_rejects_stale_payload_version() {
5140 let fit = saved_fit(vec![
5141 FittedBlock {
5142 beta: array![0.1],
5143 role: BlockRole::Mean,
5144 edf: 1.0,
5145 lambdas: Array1::zeros(0),
5146 },
5147 FittedBlock {
5148 beta: array![0.2],
5149 role: BlockRole::Scale,
5150 edf: 1.0,
5151 lambdas: Array1::zeros(0),
5152 },
5153 ]);
5154 let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5155
5156 let err = FittedModel::from_payload(payload)
5157 .saved_prediction_runtime()
5158 .expect_err("stale payload version should fail before runtime assembly");
5159 assert!(err.to_string().contains("payload schema mismatch"));
5160 }
5161}