Skip to main content

gam_models/inference/
model.rs

1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4    LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7    SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12    monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15    inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16    parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21    InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22    SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25// The data-schema value types live in the `gam-data` foundation crate; they
26// were previously authored here and are still named `gam::inference::model::{
27// ColumnKindTag, DataSchema, SchemaColumn}` by a broad set of integration tests
28// and by saved-payload consumers. Re-export them so that public path stays
29// valid rather than forcing every caller onto the relocated crate path.
30pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Schema version stamped on every saved-model payload.
///
/// Any binary that writes a `FittedModelPayload` (CLI `gam`, gam-pyffi,
/// downstream library users) must store this value in the payload's `version`
/// field, and every load path asserts equality through
/// `validate_for_persistence`.
///
/// Bump the version when any of the following happens:
///   - A required field is added to `FittedModelPayload` and the set of
///     `Option<T>` fields that must be `Some(...)` for a given `family_state`
///     changes. Without a bump, the `#[serde(default)]` decode would silently
///     fill the new field with `None` when an older model is loaded, and the
///     CLI predict path would run with stale metadata.
///   - The on-wire shape of any `serde`-tagged enum variant changes so that
///     older payloads no longer round-trip losslessly.
///   - The meaning of an existing field changes (for example its sign
///     convention or coordinate frame) in a way that would make predict
///     output silently diverge between old and new readers.
///
/// Do NOT bump for purely additive `Option<T>` fields that the save-time
/// invariant (`validate_for_persistence`) does not yet require; those are
/// forward-compatible.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;
59
60/// Schema-free saved-model metadata keyed by stable group id.
61///
62/// The values are JSON rather than a typed enum because group provenance is
63/// supplied by caller-owned catalogs. `FittedModelPayload::group_metadata`
64/// wraps this in `Option` with `#[serde(default)]`, so model files written
65/// before the field existed deserialize as `None`.
66pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
68/// Saved exact spline-scan fit (#1030/#1034): the predict-time feature column
69/// plus the lossless smoother state the Gaussian-bridge `predict` replays.
70#[derive(Clone, Debug, Serialize, Deserialize)]
71pub struct SavedSplineScan {
72    /// Training column name feeding the single 1-D smooth at predict time.
73    pub feature_column: String,
74    pub state: gam_solve::spline_scan::SplineScanState,
75}
76
77/// Saved multiresolution residual-cascade fit (#1032): the predict-time feature
78/// columns (d ∈ {2, 3}) plus the serializable cascade state that `from_state`
79/// rebuilds a predict-capable `ResidualCascadeFit` from. The cascade is a
80/// DIFFERENT posterior from the dense Duchon/Matérn term — never a silent swap.
81#[derive(Clone, Debug, Serialize, Deserialize)]
82pub struct SavedResidualCascade {
83    /// Training column names for the d ∈ {2, 3} scattered-smooth coordinates.
84    pub feature_columns: Vec<String>,
85    pub state: gam_solve::residual_cascade::ResidualCascadeState,
86}
87
88/// Typed error surface for `src/inference/model.rs` saved-model code.
89///
90/// Every variant carries a free-form `reason: String` payload; `Display`
91/// emits exactly that payload, so converting a `FittedModelError` into
92/// `String` (via the `From` impl below) is byte-equivalent to the pre-
93/// refactor `Err(format!(...))` / `Err("...".to_string())` strings that
94/// the same call sites produced. This lets external callers keep using
95/// `?` against `Result<_, String>` without source changes — the typed
96/// enum is purely an in-module discipline gain.
97#[derive(Clone, Debug, PartialEq, Eq)]
98pub enum FittedModelError {
99    /// Saved payload structure / shape / version disagrees with what the
100    /// current binary expects (e.g. covariance shape, block ordering,
101    /// schema version, C2 continuity, out-of-range span/basis indices).
102    SchemaMismatch { reason: String },
103    /// Saved payload bytes / numeric content are corrupt or unreadable
104    /// (non-finite scalars, invalid JSON, IO failure, malformed stateful
105    /// link state).
106    PayloadCorrupt { reason: String },
107    /// A required field that the current code path needs is absent from
108    /// the payload (typically `..; refit` errors).
109    MissingField { reason: String },
110    /// A combination of saved-model options is not supported by the
111    /// current binary (unsupported deployment-extension kind, unsupported
112    /// kernel marker, unsupported survival_time_basis variant, etc.).
113    IncompatibleConfig { reason: String },
114    /// An input value rejected by a save-time sanity gate (e.g. negative
115    /// ridge alpha).
116    InvalidInput { reason: String },
117}
118
119impl_reason_error_boilerplate! {
120    FittedModelError {
121        SchemaMismatch,
122        PayloadCorrupt,
123        MissingField,
124        IncompatibleConfig,
125        InvalidInput,
126    }
127}
128
129// Boundary conversions so external `Result<_, EstimationError>` /
130// `Result<_, SurvivalPredictError>` call sites can propagate with `?`.
131// Survival prediction keeps the model-layer source so the chain identifies
132// the payload/schema failure that triggered the prediction error.
133impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134    fn from(err: FittedModelError) -> Self {
135        gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136    }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140    fn from(err: FittedModelError) -> Self {
141        crate::survival::predict::SurvivalPredictError::ModelPayload {
142            context: "saved-model survival prediction payload",
143            source: err,
144        }
145    }
146}
147
148#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
149pub struct SavedLatentZNormalization {
150    pub mean: f64,
151    pub sd: f64,
152}
153
154impl SavedLatentZNormalization {
155    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156        if !self.mean.is_finite() {
157            return Err(FittedModelError::PayloadCorrupt {
158                reason: format!("{context} latent z mean must be finite"),
159            });
160        }
161        if !(self.sd.is_finite() && self.sd > 1e-12) {
162            return Err(FittedModelError::PayloadCorrupt {
163                reason: format!(
164                    "{context} latent z sd must be finite and > 1e-12; got {}",
165                    self.sd
166                ),
167            });
168        }
169        Ok::<(), _>(())
170    }
171
172    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173        self.validate(context)?;
174        if z.iter().any(|value| !value.is_finite()) {
175            return Err(FittedModelError::PayloadCorrupt {
176                reason: format!("{context} requires finite z values"),
177            });
178        }
179        Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180    }
181}
182
/// Default PIT clip epsilon for `TransformationScoreCalibration`.
///
/// Used both by `TransformationScoreCalibration::finite_support_pit` and as
/// the serde default for the `clip_eps` field.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
185#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
186#[serde(rename_all = "kebab-case")]
187#[derive(Default)]
188pub enum TransformationScoreKind {
189    #[default]
190    FiniteSupportPit,
191}
192
193#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
194pub struct TransformationScoreCalibration {
195    #[serde(default)]
196    pub score_kind: TransformationScoreKind,
197    #[serde(default = "default_transformation_score_pit_clip_eps")]
198    pub clip_eps: f64,
199}
200
201const fn default_transformation_score_pit_clip_eps() -> f64 {
202    TRANSFORMATION_SCORE_PIT_CLIP_EPS
203}
204
205impl TransformationScoreCalibration {
206    pub fn finite_support_pit() -> Self {
207        Self {
208            score_kind: TransformationScoreKind::FiniteSupportPit,
209            clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210        }
211    }
212
213    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214        if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215            return Err(FittedModelError::IncompatibleConfig {
216                reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217            });
218        }
219        if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220            return Err(FittedModelError::IncompatibleConfig {
221                reason: format!(
222                    "{context} requires PIT clip_eps in (0, 0.5), got {}",
223                    self.clip_eps
224                ),
225            });
226        }
227        Ok::<(), _>(())
228    }
229}
230
231#[derive(Clone, Serialize, Deserialize)]
232pub struct FittedModelPayload {
233    pub version: u32,
234    pub formula: String,
235    pub model_kind: ModelKind,
236    pub family_state: FittedFamily,
237    pub family: String,
238    /// Human-readable advisories produced while materializing this model —
239    /// e.g. an mgcv-style "k was reduced to the data support" note when a
240    /// cubic-regression marginal is capped, or a basis-degradation note. These
241    /// are surfaced to CLI users via `print_inference_summary`; persisting them
242    /// here lets the Python (gamfit) interface surface the SAME advisories as
243    /// warnings / `model.notes` instead of silently dropping them at the FFI
244    /// boundary (#1543). `#[serde(default)]` keeps older payloads (which had no
245    /// such field) deserializing cleanly as "no notes".
246    #[serde(default)]
247    pub inference_notes: Vec<String>,
248    #[serde(default)]
249    pub used_device: bool,
250    #[serde(default)]
251    pub fit_result: Option<UnifiedFitResult>,
252    /// Unified (family-agnostic) representation of the fit result.
253    #[serde(default)]
254    pub unified: Option<UnifiedFitResult>,
255    /// Exact O(n) spline-scan fit representation (#1030/#1034): the
256    /// state-space smoothing-spline posterior of a single 1-D Gaussian
257    /// smooth. When `Some`, this standard Gaussian model's predictions
258    /// replay the Gaussian bridge from this state and the model carries no
259    /// dense `fit_result` (the representations are mutually exclusive —
260    /// enforced by `validate_for_persistence`). `#[serde(default)]` so older
261    /// payloads read as: not a scan model.
262    #[serde(default)]
263    pub spline_scan: Option<SavedSplineScan>,
264    /// O(n log n) multiresolution residual-cascade fit (#1032): the persisted
265    /// multilevel Wendland-frame state for a single scattered 2–3D Gaussian
266    /// smooth past the dense-kernel cliff. When `Some`, predictions replay the
267    /// cascade posterior; mutually exclusive with `spline_scan`/`fit_result`.
268    /// `#[serde(default)]` keeps forward-compatibility with older payloads.
269    #[serde(default)]
270    pub residual_cascade: Option<SavedResidualCascade>,
271    #[serde(default)]
272    pub data_schema: Option<DataSchema>,
273    pub link: Option<InverseLink>,
274    #[serde(default)]
275    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
276    #[serde(default)]
277    pub sas_param_covariance: Option<Vec<Vec<f64>>>,
278    #[serde(default)]
279    pub formula_noise: Option<String>,
280    #[serde(default)]
281    pub formula_logslope: Option<String>,
282    #[serde(default)]
283    pub formula_logslopes: Option<Vec<String>>,
284    #[serde(default)]
285    pub offset_column: Option<String>,
286    #[serde(default)]
287    pub noise_offset_column: Option<String>,
288    #[serde(default)]
289    pub beta_noise: Option<Vec<f64>>,
290    #[serde(default)]
291    pub noise_projection: Option<Vec<Vec<f64>>>,
292    #[serde(default)]
293    pub noise_center: Option<Vec<f64>>,
294    #[serde(default)]
295    pub noise_scale: Option<Vec<f64>>,
296    #[serde(default)]
297    pub noise_non_intercept_start: Option<usize>,
298    /// Tikhonov ridge alpha used by `solve_scale_projection` when fitting
299    /// `noise_projection`.  Persisted so prediction-time replay is identical
300    /// to fit-time projection.
301    #[serde(default)]
302    pub noise_projection_ridge_alpha: Option<f64>,
303    #[serde(default)]
304    pub gaussian_response_scale: Option<f64>,
305    #[serde(default)]
306    pub linkwiggle_knots: Option<Vec<f64>>,
307    #[serde(default)]
308    pub linkwiggle_degree: Option<usize>,
309    #[serde(default)]
310    pub beta_link_wiggle: Option<Vec<f64>>,
311    #[serde(default)]
312    pub baseline_timewiggle_knots: Option<Vec<f64>>,
313    #[serde(default)]
314    pub baseline_timewiggle_degree: Option<usize>,
315    #[serde(default)]
316    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
317    #[serde(default)]
318    pub baseline_timewiggle_double_penalty: Option<bool>,
319    #[serde(default)]
320    pub beta_baseline_timewiggle: Option<Vec<f64>>,
321    #[serde(default)]
322    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
323    #[serde(default)]
324    pub z_column: Option<String>,
325    #[serde(default)]
326    pub z_columns: Option<Vec<String>>,
327    #[serde(default)]
328    pub latent_z_normalization: Option<SavedLatentZNormalization>,
329    #[serde(default)]
330    pub latent_score_contract: Option<SavedLatentScoreContract>,
331    #[serde(default)]
332    pub latent_measure: Option<LatentMeasureKind>,
333    /// Optional rank-INT calibration for the latent score (BMS family).
334    /// When `Some`, the marginal-slope predictor routes the input `z`
335    /// through [`LatentZRankIntCalibration::apply_at_predict`] before the
336    /// closed-form standard-normal kernel, matching fit-time semantics.
337    /// `#[serde(default)]` so models persisted before this field existed
338    /// continue to deserialize cleanly (interpreted as: no calibration).
339    #[serde(default)]
340    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
341    /// Optional conditional location-scale calibration of the latent score
342    /// (#905, BMS family). When `Some`, the marginal-slope predictor replaces
343    /// the (normalized) input `z` by `ζ = (z − m(C))/√v(C)` — rebuilding the
344    /// conditioning span `a(C)` from the marginal prediction design — before
345    /// the closed-form standard-normal kernel, matching fit-time semantics.
346    /// Mutually exclusive with `latent_z_rank_int_calibration`. `#[serde(default)]`
347    /// so pre-existing models deserialize cleanly (interpreted as: no
348    /// conditional calibration).
349    #[serde(default)]
350    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
351    #[serde(default)]
352    pub marginal_baseline: Option<f64>,
353    #[serde(default)]
354    pub logslope_baseline: Option<f64>,
355    #[serde(default)]
356    pub logslope_baselines: Option<Vec<f64>>,
357    #[serde(default)]
358    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
359    #[serde(default)]
360    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
361    /// Width `p₁` of the survival marginal-slope absorbed Stage-1 influence block
362    /// (#461) when present (the dedicated trailing absorber block). Predict drops
363    /// its `γ`; this records the column count so the predictor can account for
364    /// the extra block and slice `γ` out of the joint covariance.
365    #[serde(default)]
366    pub influence_absorber_width: Option<usize>,
367    #[serde(default)]
368    pub survival_entry: Option<String>,
369    #[serde(default)]
370    pub survival_exit: Option<String>,
371    #[serde(default)]
372    pub survival_event: Option<String>,
373    #[serde(default)]
374    pub survivalspec: Option<String>,
375    #[serde(default)]
376    pub survival_cause_count: Option<usize>,
377    #[serde(default)]
378    pub survival_endpoint_names: Option<Vec<String>>,
379    #[serde(default)]
380    pub survival_baseline_target: Option<String>,
381    #[serde(default)]
382    pub survival_baseline_scale: Option<f64>,
383    #[serde(default)]
384    pub survival_baseline_shape: Option<f64>,
385    #[serde(default)]
386    pub survival_baseline_rate: Option<f64>,
387    #[serde(default)]
388    pub survival_baseline_makeham: Option<f64>,
389    #[serde(default)]
390    pub survival_time_basis: Option<String>,
391    #[serde(default)]
392    pub survival_time_degree: Option<usize>,
393    #[serde(default)]
394    pub survival_time_knots: Option<Vec<f64>>,
395    #[serde(default)]
396    pub survival_time_keep_cols: Option<Vec<usize>>,
397    #[serde(default)]
398    pub survival_time_smooth_lambda: Option<f64>,
399    #[serde(default)]
400    pub survival_time_anchor: Option<f64>,
401    #[serde(default)]
402    pub survivalridge_lambda: Option<f64>,
403    #[serde(default)]
404    pub survival_likelihood: Option<String>,
405    #[serde(default)]
406    pub survival_beta_time: Option<Vec<f64>>,
407    #[serde(default)]
408    pub survival_beta_threshold: Option<Vec<f64>>,
409    #[serde(default)]
410    pub survival_beta_log_sigma: Option<Vec<f64>>,
411    #[serde(default)]
412    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
413    #[serde(default)]
414    pub survival_noise_center: Option<Vec<f64>>,
415    #[serde(default)]
416    pub survival_noise_scale: Option<Vec<f64>>,
417    #[serde(default)]
418    pub survival_noise_non_intercept_start: Option<usize>,
419    /// Survival analog of `noise_projection_ridge_alpha`: the Tikhonov ridge
420    /// used when fitting the survival log-sigma projection.
421    #[serde(default)]
422    pub survival_noise_projection_ridge_alpha: Option<f64>,
423    #[serde(default)]
424    pub survival_distribution: Option<ResidualDistribution>,
425    #[serde(default)]
426    pub training_headers: Option<Vec<String>>,
427    /// Container type of the table the model was fitted on, as detected by the
428    /// Python binding (`"pandas"`, `"polars"`, `"pyarrow"`, `"numpy"`, or an
429    /// ambiguous tag such as `"unknown"`). This is presentation provenance, not
430    /// math: `gamfit.Model.predict` uses it as the output-container fallback
431    /// when the *prediction input* is itself container-ambiguous (a `dict` of
432    /// columns or a `list` of record dicts). Persisting it in the model payload
433    /// makes the fallback survive `save`/`load` and `dumps`/`loads`, so a
434    /// reloaded model predicts into the same container as the in-memory one.
435    /// `None` for older payloads (and for fits that never recorded a kind), in
436    /// which case the fallback degrades to `"dict"`, matching pre-persistence
437    /// behaviour for unknown-kind training tables.
438    #[serde(default, skip_serializing_if = "Option::is_none")]
439    pub training_table_kind: Option<String>,
440    /// Per-column (min, max) of the training input matrix, parallel to
441    /// `training_headers`. At predict time, inputs are axis-clipped to these
442    /// ranges so that out-of-distribution points evaluate at the nearest face
443    /// of the training bounding box rather than extrapolating polynomial
444    /// trends from polyharmonic / spline bases beyond the data envelope. Old
445    /// model JSONs that pre-date this field load with `None`, in which case
446    /// the predict path falls through unchanged (no clipping).
447    #[serde(default)]
448    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
449    /// User-supplied per-group metadata, keyed by stable group identifier.
450    ///
451    /// This is intentionally schema-free JSON so provenance maps can carry
452    /// mixed scalar/list/object values. Missing in older payloads means no
453    /// group metadata was persisted.
454    #[serde(default, skip_serializing_if = "Option::is_none")]
455    pub group_metadata: Option<GroupMetadata>,
456    /// Deployment-time no-refit group extensions applied after fitting.
457    ///
458    /// Each entry records the requested group coordinate, caller metadata, and
459    /// prior used to initialize the inserted coefficient. The active
460    /// prediction contract lives in `data_schema` + `resolved_termspec`; this
461    /// ledger preserves provenance without requiring a refit.
462    #[serde(default, skip_serializing_if = "Vec::is_empty")]
463    pub deployment_extensions: Vec<SavedDeploymentExtension>,
464    /// Transformation-normal: B-spline knots for the response-direction basis.
465    #[serde(default)]
466    pub transformation_response_knots: Option<Vec<f64>>,
467    /// Transformation-normal: deviation nullspace transform matrix (row-major).
468    #[serde(default)]
469    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
470    /// Transformation-normal: B-spline degree for the response basis.
471    #[serde(default)]
472    pub transformation_response_degree: Option<usize>,
473    /// Transformation-normal: median of the response used for anchoring.
474    #[serde(default)]
475    pub transformation_response_median: Option<f64>,
476    /// Transformation-normal saved score contract. The score is the exact
477    /// finite-support PIT:
478    /// z = Phi^{-1}((Phi(h) - Phi(h_L)) / (Phi(h_U) - Phi(h_L))).
479    #[serde(default)]
480    pub transformation_score_calibration: Option<TransformationScoreCalibration>,
481    #[serde(default)]
482    pub resolved_termspec: Option<TermCollectionSpec>,
483    #[serde(default)]
484    pub resolved_termspec_noise: Option<TermCollectionSpec>,
485    #[serde(default)]
486    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
487    #[serde(default)]
488    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
489    #[serde(default)]
490    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
491    /// Precomputed exact Gaussian-identity jackknife+ statistics (#942).
492    ///
493    /// Populated *only* for a standard Gaussian-identity model fit with unit
494    /// prior weights, where the closed-form Sherman–Morrison leave-one-out
495    /// substrate gives a distribution-free finite-sample (≥ level) prediction
496    /// interval with no held-out fold. When `Some`, `predict(interval=level)`
497    /// auto-routes through it (the MAGIC default); when `None` — any other
498    /// family/link, reweighted rows, or an older payload — predict falls back
499    /// to the model-based posterior band and labels the provenance honestly.
500    /// `#[serde(default)]` so pre-existing models deserialize as: no jackknife+
501    /// substrate available.
502    #[serde(default)]
503    pub gaussian_jackknife_plus:
504        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
505    /// Precomputed substrate for the EXACT Gaussian-identity full-conformal set
506    /// (#942 Layer 1 + the frozen-ρ self-diagnostic).
507    ///
508    /// Populated under the SAME eligibility as `gaussian_jackknife_plus`
509    /// (Gaussian-identity, unit prior weights, offset-free, no link wiggle). It
510    /// persists the training design + response + frozen penalty `Sλ` so the
511    /// distribution-free EXACT prediction set (a union of intervals, valid for
512    /// any penalized smooth) can be replayed per test point — one Cholesky each,
513    /// zero refits — and surfaces the frozen-ρ certificate flag. `None` for any
514    /// ineligible model or an older payload, in which case the exact-set predict
515    /// path errors with a clear message and the caller uses jackknife+ or the
516    /// posterior band. `#[serde(default)]` so pre-existing models deserialize as
517    /// no exact substrate available.
518    #[serde(default)]
519    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
520}
521
522#[derive(Clone, Debug, Serialize, Deserialize)]
523pub struct SavedDeploymentExtension {
524    pub name: String,
525    pub kind: String,
526    pub term: String,
527    pub level: JsonValue,
528    pub level_bits: u64,
529    pub coefficient_index: usize,
530    pub coefficient_mean: f64,
531    pub coefficient_variance: f64,
532    #[serde(default, skip_serializing_if = "Option::is_none")]
533    pub metadata: Option<JsonValue>,
534    #[serde(default, skip_serializing_if = "Option::is_none")]
535    pub prior: Option<JsonValue>,
536}
537
538/// Append deployment-only extension columns to the fitted design coordinate system.
539///
540/// No-refit group extension adds a new coefficient block after the fitted
541/// coefficient vector:
542///
543///   beta_ext = [beta_old, beta_new],    beta_new = mu_new.
544///
545/// For a new random-effect level g, the appended basis is the indicator
546/// e_g(x_i) = 1{x_i == g}.  The old fitted basis X_old is not rebuilt or
547/// reordered, so rows that do not exercise g have
548///
549///   eta_ext = X_old beta_old + 0 * beta_new = eta_old.
550///
551/// Rows at the new level get the exact prior-mean shift e_g beta_new.  This
552/// helper enforces the coordinate identity by requiring extension coefficient
553/// indices to be the consecutive tail columns of the base design.
554pub fn append_deployment_extension_columns(
555    model: &FittedModelPayload,
556    data: ndarray::ArrayView2<'_, f64>,
557    col_map: &HashMap<String, usize>,
558    training_headers: Option<&Vec<String>>,
559    base_design: Array2<f64>,
560) -> Result<Array2<f64>, FittedModelError> {
561    if model.deployment_extensions.is_empty() {
562        return Ok(base_design);
563    }
564    if base_design.nrows() != data.nrows() {
565        return Err(FittedModelError::SchemaMismatch {
566            reason: format!(
567                "deployment extension design row mismatch: base design has {} rows but data has {}",
568                base_design.nrows(),
569                data.nrows()
570            ),
571        });
572    }
573    let spec = model
574        .resolved_termspec
575        .as_ref()
576        .ok_or_else(|| FittedModelError::MissingField {
577            reason: "deployment extension prediction requires saved resolved_termspec; refit"
578                .to_string(),
579        })?;
580    let n = base_design.nrows();
581    let p_old = base_design.ncols();
582    let mut extensions: Vec<&SavedDeploymentExtension> =
583        model.deployment_extensions.iter().collect();
584    extensions.sort_by_key(|extension| extension.coefficient_index);
585    for (tail_idx, extension) in extensions.iter().enumerate() {
586        let expected = p_old + tail_idx;
587        if extension.coefficient_index != expected {
588            return Err(FittedModelError::SchemaMismatch {
589                reason: format!(
590                    "deployment extension '{}' has coefficient index {}, expected append-only index {}",
591                    extension.name, extension.coefficient_index, expected
592                ),
593            });
594        }
595    }
596
597    let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
598    out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
599    for (tail_idx, extension) in extensions.into_iter().enumerate() {
600        if extension.kind != "random-effect-level" {
601            return Err(FittedModelError::IncompatibleConfig {
602                reason: format!(
603                    "unsupported deployment extension kind '{}' for '{}'",
604                    extension.kind, extension.name
605                ),
606            });
607        }
608        let term = spec
609            .random_effect_terms
610            .iter()
611            .find(|term| term.name == extension.term)
612            .ok_or_else(|| FittedModelError::MissingField {
613                reason: format!(
614                    "deployment extension '{}' references unknown random-effect term '{}'",
615                    extension.name, extension.term
616                ),
617            })?;
618        let prediction_col = training_headers
619            .and_then(|headers| headers.get(term.feature_col))
620            .and_then(|name| col_map.get(name))
621            .copied()
622            .unwrap_or(term.feature_col);
623        if prediction_col >= data.ncols() {
624            return Err(FittedModelError::SchemaMismatch {
625                reason: format!(
626                    "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
627                    extension.name,
628                    prediction_col,
629                    data.ncols()
630                ),
631            });
632        }
633        let col = p_old + tail_idx;
634        for row in 0..n {
635            if data[[row, prediction_col]].to_bits() == extension.level_bits {
636                out[[row, col]] = 1.0;
637            }
638        }
639    }
640    Ok(out)
641}
642
643#[derive(Clone, Debug, Serialize, Deserialize)]
644pub struct SavedLatentScoreContract {
645    pub semantics: String,
646    pub source_transform_id: Option<String>,
647    pub normalization_mean: f64,
648    pub normalization_sd: f64,
649    pub clip_eps: Option<f64>,
650    pub conditioning_columns: Vec<String>,
651}
652
653impl FittedModelPayload {
    /// Build a payload that carries only the five identifying fields supplied
    /// by the caller.
    ///
    /// Every other field starts at its empty value: optional fields are
    /// `None`, `inference_notes` and `deployment_extensions` are empty
    /// vectors, and `used_device` is `false`. Model-specific state is filled
    /// in afterwards, e.g. through `set_training_feature_metadata` or
    /// `apply_survival_time_basis`.
    ///
    /// The `version` is stored as given; it is compared against
    /// `MODEL_PAYLOAD_VERSION` only later, in `validate_payload_version`.
    pub fn new(
        version: u32,
        formula: String,
        model_kind: ModelKind,
        family_state: FittedFamily,
        family: String,
    ) -> Self {
        Self {
            // Caller-supplied identity of the model.
            version,
            formula,
            model_kind,
            family_state,
            family,
            // Everything below is the "nothing recorded yet" default.
            inference_notes: Vec::new(),
            used_device: false,
            fit_result: None,
            unified: None,
            spline_scan: None,
            residual_cascade: None,
            data_schema: None,
            link: None,
            mixture_link_param_covariance: None,
            sas_param_covariance: None,
            formula_noise: None,
            formula_logslope: None,
            formula_logslopes: None,
            offset_column: None,
            noise_offset_column: None,
            beta_noise: None,
            noise_projection: None,
            noise_center: None,
            noise_scale: None,
            noise_non_intercept_start: None,
            noise_projection_ridge_alpha: None,
            gaussian_response_scale: None,
            linkwiggle_knots: None,
            linkwiggle_degree: None,
            beta_link_wiggle: None,
            baseline_timewiggle_knots: None,
            baseline_timewiggle_degree: None,
            baseline_timewiggle_penalty_orders: None,
            baseline_timewiggle_double_penalty: None,
            beta_baseline_timewiggle: None,
            beta_baseline_timewiggle_by_cause: None,
            z_column: None,
            z_columns: None,
            latent_z_normalization: None,
            latent_score_contract: None,
            latent_measure: None,
            latent_z_rank_int_calibration: None,
            latent_z_conditional_calibration: None,
            marginal_baseline: None,
            logslope_baseline: None,
            logslope_baselines: None,
            score_warp_runtime: None,
            link_deviation_runtime: None,
            influence_absorber_width: None,
            survival_entry: None,
            survival_exit: None,
            survival_event: None,
            survivalspec: None,
            survival_cause_count: None,
            survival_endpoint_names: None,
            survival_baseline_target: None,
            survival_baseline_scale: None,
            survival_baseline_shape: None,
            survival_baseline_rate: None,
            survival_baseline_makeham: None,
            // The `survival_time_*` fields are written together by
            // `apply_survival_time_basis`.
            survival_time_basis: None,
            survival_time_degree: None,
            survival_time_knots: None,
            survival_time_keep_cols: None,
            survival_time_smooth_lambda: None,
            survival_time_anchor: None,
            survivalridge_lambda: None,
            survival_likelihood: None,
            survival_beta_time: None,
            survival_beta_threshold: None,
            survival_beta_log_sigma: None,
            survival_noise_projection: None,
            survival_noise_center: None,
            survival_noise_scale: None,
            survival_noise_non_intercept_start: None,
            survival_noise_projection_ridge_alpha: None,
            survival_distribution: None,
            training_headers: None,
            training_table_kind: None,
            training_feature_ranges: None,
            group_metadata: None,
            deployment_extensions: Vec::new(),
            transformation_response_knots: None,
            transformation_response_transform: None,
            transformation_response_degree: None,
            transformation_response_median: None,
            transformation_score_calibration: None,
            resolved_termspec: None,
            resolved_termspec_noise: None,
            resolved_termspec_logslope: None,
            resolved_termspec_logslopes: None,
            adaptive_regularization_diagnostics: None,
            gaussian_jackknife_plus: None,
            full_conformal: None,
        }
    }
758
759    pub fn set_training_feature_metadata(
760        &mut self,
761        headers: Vec<String>,
762        feature_ranges: Vec<(f64, f64)>,
763    ) {
764        self.training_headers = Some(headers);
765        self.training_feature_ranges = Some(feature_ranges);
766    }
767
768    fn synchronize_empty_feature_contract(&mut self) {
769        if self.fit_result.is_none() {
770            return;
771        }
772        let Some(schema) = self.data_schema.as_ref() else {
773            return;
774        };
775        if !schema.columns.is_empty() {
776            return;
777        }
778        self.training_headers.get_or_insert_with(Vec::new);
779        self.resolved_termspec
780            .get_or_insert_with(|| TermCollectionSpec {
781                linear_terms: Vec::new(),
782                smooth_terms: Vec::new(),
783                random_effect_terms: Vec::new(),
784            });
785    }
786
787    /// Write the persistable time-basis snapshot for a survival model.
788    ///
789    /// This is the only path that should populate the `survival_time_*`
790    /// fields used by the loader. Routing every FFI builder through this
791    /// helper guarantees no builder can silently drop a field — the
792    /// marginal-slope save→load bug was a builder that
793    /// missed `survival_time_basis`.
794    pub fn apply_survival_time_basis(
795        &mut self,
796        snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
797    ) {
798        self.survival_time_basis = Some(snapshot.basisname.clone());
799        self.survival_time_degree = snapshot.degree;
800        self.survival_time_knots = snapshot.knots.clone();
801        self.survival_time_keep_cols = snapshot.keep_cols.clone();
802        self.survival_time_smooth_lambda = snapshot.smooth_lambda;
803        self.survival_time_anchor = Some(snapshot.anchor);
804    }
805
806    fn validate_payload_version(&self) -> Result<(), FittedModelError> {
807        if self.version != MODEL_PAYLOAD_VERSION {
808            return Err(FittedModelError::SchemaMismatch {
809                reason: format!(
810                    "saved model payload schema mismatch: file has version={}, \
811                 this binary expects MODEL_PAYLOAD_VERSION={}. \
812                 Refit with the current CLI, or rebuild the reader at the same \
813                 version the model was written with.",
814                    self.version, MODEL_PAYLOAD_VERSION
815                ),
816            });
817        }
818        Ok(())
819    }
820}
821
/// A saved model, discriminated by model type.
///
/// Every variant wraps the same [`FittedModelPayload`]; only the tag differs.
/// Per the serde attributes, the value is serialized with an internal
/// `model_type` tag whose values are the kebab-case variant names
/// (e.g. `location-scale`).
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    Standard { payload: FittedModelPayload },
    LocationScale { payload: FittedModelPayload },
    MarginalSlope { payload: FittedModelPayload },
    Survival { payload: FittedModelPayload },
    TransformationNormal { payload: FittedModelPayload },
}
831
/// Payload-free tag for the kind of model a payload describes.
///
/// The variant names mirror those of [`FittedModel`]; values serialize as
/// kebab-case strings per the serde attribute.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    Standard,
    LocationScale,
    MarginalSlope,
    Survival,
    TransformationNormal,
}
841
/// Family-specific state persisted alongside a fitted model.
///
/// Serialized with an internal `family_kind` tag using kebab-case variant
/// names. Fields marked `#[serde(default)]` may be absent from a saved file
/// and then deserialize as `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Standard model: a likelihood plus optional link and link-state
    /// records.
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    /// Location-scale model: a likelihood plus an optional base inverse link.
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    /// Marginal-slope model; the base link and frailty spec are mandatory.
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    /// Survival model: likelihood and frailty are mandatory, the survival
    /// likelihood name and residual distribution are optional.
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Latent survival family; carries only a frailty spec.
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Latent binary family; carries only a frailty spec.
    LatentBinary {
        frailty: FrailtySpec,
    },
    /// Transformation-normal model; carries only a likelihood.
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
884
/// Prediction-time classification of a saved model, used to select the
/// predictor path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    Standard,
    GaussianLocationScale,
    BinomialLocationScale,
    /// Location-scale with a genuine dispersion channel (#913): a
    /// NegativeBinomial, Gamma, Beta or Tweedie mean family fitted with a
    /// `noise_formula` overdispersion channel. Prediction goes through the
    /// GLM mean inverse link rather than the binomial threshold-scale
    /// predictor.
    DispersionLocationScale,
    BernoulliMarginalSlope,
    Survival,
    TransformationNormal,
}

impl PredictModelClass {
    /// Human-readable, lowercase label for this class.
    #[inline]
    pub const fn name(self) -> &'static str {
        let label: &'static str = match self {
            PredictModelClass::Standard => "standard",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
        };
        label
    }
}
914
/// Predict-time description of a saved link wiggle.
///
/// `knots` and `degree` are handed to
/// `monotone_wiggle_basis_with_derivative_order` to rebuild the basis; `beta`
/// holds one coefficient per basis column (the length is checked against the
/// rebuilt basis in `constrained_basis`).
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    /// Knot vector used to rebuild the wiggle basis.
    pub knots: Vec<f64>,
    /// Degree passed to the basis builder.
    pub degree: usize,
    /// Wiggle coefficients, one per basis column.
    pub beta: Vec<f64>,
}
921
/// Predict-time description of a saved baseline time wiggle.
///
/// Only `beta` is read in this section (by `validate_global_monotonicity`);
/// the other fields are consumed elsewhere.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
    /// Knot vector of the time-wiggle basis.
    pub knots: Vec<f64>,
    /// Degree of the time-wiggle basis.
    pub degree: usize,
    /// Penalty orders recorded with the fit (presumably difference-penalty
    /// orders — confirm).
    pub penalty_orders: Vec<usize>,
    /// Whether a double penalty was recorded for the fit.
    pub double_penalty: bool,
    /// Wiggle coefficients; checked by
    /// `validate_monotone_wiggle_beta_nonnegative`.
    pub beta: Vec<f64>,
}
930
931// Re-export so saved-model consumers can refer to the anchor-block tag
932// without reaching across module boundaries.
933pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
934
/// Serialized, precompiled flex block (used for the score-warp and
/// link-deviation runtimes).
///
/// `validate_exact_replay_contract` checks the invariants noted on the fields
/// below before the block is used.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
    /// Kernel marker; must be non-empty and equal to
    /// `cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL`.
    pub kernel: String,
    /// Span boundaries; at least two, all finite and strictly increasing.
    pub breakpoints: Vec<f64>,
    /// Number of basis functions; must be positive.
    pub basis_dim: usize,
    /// Per-span coefficient tables `c0`..`c3`, each validated against the
    /// span count `breakpoints.len() - 1`. Presumably the cubic polynomial
    /// coefficients of each basis function on each span, ordered by power —
    /// confirm against `validate_coefficient_matrix`.
    pub span_c0: Vec<Vec<f64>>,
    pub span_c1: Vec<Vec<f64>>,
    pub span_c2: Vec<Vec<f64>>,
    pub span_c3: Vec<Vec<f64>>,
    /// Cross-block anchor-residual coefficient matrix `M` of shape
    /// `d × basis_dim`. When present, predict-time evaluation subtracts
    /// `n_row · M` from each cubic-span row (where `n_row` stacks the
    /// per-row parametric anchor values in the order given by
    /// `anchor_components`).
    #[serde(default)]
    pub anchor_correction: Option<Vec<Vec<f64>>>,
    /// Ordered list of parametric anchor components whose stacked row
    /// values combine into `n_row`. Empty unless
    /// `anchor_correction` is `Some`.
    #[serde(default)]
    pub anchor_components: Vec<SavedAnchorComponent>,
}
957
/// One entry of `SavedCompiledFlexBlock::anchor_components`; a thin wrapper
/// around the component's [`SavedAnchorKind`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
    /// Which kind of anchor this component is, and how many columns it
    /// contributes.
    pub kind: SavedAnchorKind,
}
962
/// Kind of anchor component in a saved flex block.
///
/// Each variant carries `ncols`; the sum of `ncols` over all components must
/// equal the row count of `anchor_correction`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
    /// Anchor drawn from a parametric block, identified by `block`, that
    /// contributes `ncols` columns.
    Parametric {
        block: ParametricAnchorBlock,
        ncols: usize,
    },
    /// Flex-evaluation anchor (sibling flex block's reparameterised basis,
    /// evaluated at training rows at fit time and at predict rows at
    /// predict time). The predictor stacks `ncols` columns from the
    /// sibling runtime's `design(arg)` into `n_row`.
    FlexEvaluation { ncols: usize },
}
975
/// Everything the predictor build needs from a saved model, gathered into one
/// non-serialized value.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
    /// Which predictor path the model uses.
    pub model_class: PredictModelClass,
    /// Likelihood specification of the saved fit.
    pub likelihood: LikelihoodSpec,
    /// Inverse link, when the model records one.
    pub inverse_link: Option<InverseLink>,
    /// Saved link wiggle, if the fit had one.
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
    /// Saved baseline time wiggle, if the fit had one.
    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
    /// Compiled score-warp flex block, if present.
    pub score_warp: Option<SavedCompiledFlexBlock>,
    /// Compiled link-deviation flex block, if present.
    pub link_deviation: Option<SavedCompiledFlexBlock>,
    /// Rank-INT latent-z calibration carried into the predictor build.
    /// `None` for non-BMS models and for BMS fits whose latent measure
    /// did not require rank-INT calibration.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Conditional location-scale latent-z calibration (#905) carried into the
    /// predictor build. `None` for non-BMS models and for BMS fits whose Auto
    /// path did not detect a conditional `E[z|C]`/`Var(z|C)` shift. When
    /// `Some`, the predictor replaces the normalized `z` by `ζ = (z−m(C))/√v(C)`
    /// using the marginal prediction design as the conditioning span.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the
    /// survival marginal-slope fit hosted a dedicated additive absorber (the
    /// trailing block). `None` when no CTN Stage-1 chain produced an influence
    /// Jacobian. At predict the absorber's `γ` is DROPPED (the orthogonalized
    /// β̂ is a training-fit property), so the predictor uses this width only to
    /// (a) account for the extra trailing block in the saved block count and
    /// (b) slice `γ`'s rows/cols out of the joint covariance. Survival hosts the
    /// absorber as its own block (unlike the BMS A2 widened-marginal design),
    /// so it never widens any persisted prediction design.
    pub influence_absorber_width: Option<usize>,
}
1006
1007pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1008    fit.block_by_role(BlockRole::Location)
1009        .or_else(|| fit.block_by_role(BlockRole::Mean))
1010        .map(|block| block.beta.clone())
1011}
1012
1013pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1014    fit.block_by_role(BlockRole::Threshold)
1015        .or_else(|| fit.block_by_role(BlockRole::Location))
1016        .or_else(|| fit.block_by_role(BlockRole::Mean))
1017        .map(|block| block.beta.clone())
1018}
1019
1020pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021    fit.block_by_role(BlockRole::Scale)
1022        .map(|block| block.beta.clone())
1023}
1024
1025/// Whether a `ModelKind::LocationScale` likelihood's response is one of the
1026/// genuine-dispersion mean families (#913) — NegativeBinomial, Gamma, Beta or
1027/// Tweedie. These carry a `noise_formula` overdispersion channel and must be
1028/// predicted through the GLM mean inverse link (the
1029/// [`PredictModelClass::DispersionLocationScale`] path), NOT the binomial
1030/// threshold-scale predictor. The binomial location-scale (BMS ordinal) path is
1031/// the only other non-Gaussian location-scale family, with a `Binomial`
1032/// response.
1033fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1034    use gam_problem::types::ResponseFamily;
1035    matches!(
1036        response,
1037        ResponseFamily::NegativeBinomial { .. }
1038            | ResponseFamily::Gamma
1039            | ResponseFamily::Beta { .. }
1040            | ResponseFamily::Tweedie { .. }
1041    )
1042}
1043
1044fn validate_location_scale_saved_fit(
1045    fit: &UnifiedFitResult,
1046    model_class: PredictModelClass,
1047    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1048) -> Result<(), FittedModelError> {
1049    let primary = match model_class {
1050        // Gaussian and dispersion (#913) location-scale both predict the mean
1051        // through the Location block; the binomial threshold-scale class reads
1052        // the Threshold block instead.
1053        PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1054            gaussian_location_scale_mean_beta(fit)
1055        }
1056        PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1057        _ => None,
1058    }
1059    .ok_or_else(|| FittedModelError::MissingField {
1060        reason: match model_class {
1061            PredictModelClass::GaussianLocationScale => {
1062                "gaussian-location-scale saved fit is missing mean/location block".to_string()
1063            }
1064            PredictModelClass::DispersionLocationScale => {
1065                "dispersion-location-scale saved fit is missing mean/location block".to_string()
1066            }
1067            PredictModelClass::BinomialLocationScale => {
1068                "binomial-location-scale saved fit is missing threshold/location block".to_string()
1069            }
1070            _ => "location-scale saved fit is missing primary block".to_string(),
1071        },
1072    })?;
1073
1074    let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1075        reason: "location-scale saved fit is missing scale block".to_string(),
1076    })?;
1077    let expected =
1078        primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1079
1080    if let Some(cov) = fit.beta_covariance()
1081        && (cov.nrows() != expected || cov.ncols() != expected)
1082    {
1083        return Err(FittedModelError::SchemaMismatch {
1084            reason: format!(
1085                "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1086                cov.nrows(),
1087                cov.ncols(),
1088                expected,
1089                expected
1090            ),
1091        });
1092    }
1093    if let Some(cov) = fit.beta_covariance_corrected()
1094        && (cov.nrows() != expected || cov.ncols() != expected)
1095    {
1096        return Err(FittedModelError::SchemaMismatch {
1097            reason: format!(
1098                "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1099                cov.nrows(),
1100                cov.ncols(),
1101                expected,
1102                expected
1103            ),
1104        });
1105    }
1106    Ok(())
1107}
1108
1109fn validate_survival_saved_block_matches_payload(
1110    fit: &UnifiedFitResult,
1111    role: BlockRole,
1112    payload_beta: Option<&Vec<f64>>,
1113    label: &str,
1114) -> Result<usize, FittedModelError> {
1115    let block = fit
1116        .block_by_role(role)
1117        .ok_or_else(|| FittedModelError::MissingField {
1118            reason: format!("location-scale survival saved fit is missing {label} block"),
1119        })?;
1120    if let Some(saved) = payload_beta
1121        && block.beta.to_vec() != *saved
1122    {
1123        return Err(FittedModelError::SchemaMismatch {
1124            reason: format!(
1125                "location-scale survival saved {label} coefficients disagree with fit_result"
1126            ),
1127        });
1128    }
1129    Ok(block.beta.len())
1130}
1131
1132fn validate_survival_location_scale_saved_fit(
1133    payload: &FittedModelPayload,
1134    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1135) -> Result<(), FittedModelError> {
1136    let fit = payload
1137        .fit_result
1138        .as_ref()
1139        .ok_or_else(|| FittedModelError::MissingField {
1140            reason: "location-scale survival model is missing canonical fit_result payload"
1141                .to_string(),
1142        })?;
1143    let p_time = validate_survival_saved_block_matches_payload(
1144        fit,
1145        BlockRole::Time,
1146        payload.survival_beta_time.as_ref(),
1147        "time",
1148    )?;
1149    let p_threshold = validate_survival_saved_block_matches_payload(
1150        fit,
1151        BlockRole::Threshold,
1152        payload.survival_beta_threshold.as_ref(),
1153        "threshold",
1154    )?;
1155    let p_log_sigma = validate_survival_saved_block_matches_payload(
1156        fit,
1157        BlockRole::Scale,
1158        payload.survival_beta_log_sigma.as_ref(),
1159        "log-sigma",
1160    )?;
1161    let p_wiggle = match link_wiggle {
1162        Some(runtime) => {
1163            let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1164                FittedModelError::MissingField {
1165                    reason: "location-scale survival saved fit is missing link-wiggle block"
1166                        .to_string(),
1167                }
1168            })?;
1169            if block.beta.to_vec() != runtime.beta {
1170                return Err(FittedModelError::SchemaMismatch {
1171                    reason:
1172                        "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1173                            .to_string(),
1174                });
1175            }
1176            runtime.beta.len()
1177        }
1178        None => {
1179            if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1180                return Err(FittedModelError::SchemaMismatch {
1181                    reason:
1182                        "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1183                            .to_string(),
1184                });
1185            }
1186            0
1187        }
1188    };
1189    let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1190
1191    if let Some(cov) = fit.beta_covariance()
1192        && (cov.nrows() != expected || cov.ncols() != expected)
1193    {
1194        return Err(FittedModelError::SchemaMismatch {
1195            reason: format!(
1196                "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1197                cov.nrows(),
1198                cov.ncols(),
1199                expected,
1200                expected
1201            ),
1202        });
1203    }
1204    if let Some(cov) = fit.beta_covariance_corrected()
1205        && (cov.nrows() != expected || cov.ncols() != expected)
1206    {
1207        return Err(FittedModelError::SchemaMismatch {
1208            reason: format!(
1209                "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1210                cov.nrows(),
1211                cov.ncols(),
1212                expected,
1213                expected
1214            ),
1215        });
1216    }
1217    Ok(())
1218}
1219
1220fn validate_marginal_slope_saved_fit(
1221    fit: &UnifiedFitResult,
1222    score_warp: Option<&SavedCompiledFlexBlock>,
1223    link_deviation: Option<&SavedCompiledFlexBlock>,
1224    fit_label: &str,
1225) -> Result<(), FittedModelError> {
1226    validate_marginal_slope_saved_fit_impl(
1227        fit,
1228        score_warp,
1229        link_deviation,
1230        fit_label,
1231        "bernoulli",
1232        2,
1233        "marginal, logslope",
1234    )
1235}
1236
1237fn validate_survival_marginal_slope_saved_fit(
1238    fit: &UnifiedFitResult,
1239    score_warp: Option<&SavedCompiledFlexBlock>,
1240    link_deviation: Option<&SavedCompiledFlexBlock>,
1241    fit_label: &str,
1242) -> Result<(), FittedModelError> {
1243    validate_marginal_slope_saved_fit_impl(
1244        fit,
1245        score_warp,
1246        link_deviation,
1247        fit_label,
1248        "survival",
1249        3,
1250        "time, marginal, slope",
1251    )
1252}
1253
1254/// Shared block-count + coefficient-dimension validation for the bernoulli
1255/// and survival marginal-slope saved-fit gates. The only family-specific
1256/// inputs are the family kind string ("bernoulli" / "survival"), the base
1257/// block count (2 for bernoulli, 3 for survival — the survival path has an
1258/// extra time block), and the base-block role list rendered in the error
1259/// message ("marginal, logslope" / "time, marginal, slope"). The score-warp
1260/// / link-deviation tail follows the same shape in both families.
1261fn validate_marginal_slope_saved_fit_impl(
1262    fit: &UnifiedFitResult,
1263    score_warp: Option<&SavedCompiledFlexBlock>,
1264    link_deviation: Option<&SavedCompiledFlexBlock>,
1265    fit_label: &str,
1266    family_kind: &str,
1267    base_block_count: usize,
1268    base_block_role_list: &str,
1269) -> Result<(), FittedModelError> {
1270    let expected_blocks = base_block_count
1271        + usize::from(score_warp.is_some())
1272        + usize::from(link_deviation.is_some());
1273    if fit.blocks.len() != expected_blocks {
1274        let score_warp_suffix = if score_warp.is_some() {
1275            ", score-warp"
1276        } else {
1277            ""
1278        };
1279        let link_deviation_suffix = if link_deviation.is_some() {
1280            ", link-deviation"
1281        } else {
1282            ""
1283        };
1284        return Err(FittedModelError::SchemaMismatch {
1285            reason: format!(
1286                "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1287                fit.blocks.len(),
1288            ),
1289        });
1290    }
1291    if let Some(runtime) = score_warp {
1292        let beta = &fit.blocks[base_block_count].beta;
1293        if beta.len() != runtime.basis_dim {
1294            return Err(FittedModelError::SchemaMismatch {
1295                reason: format!(
1296                    "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1297                    beta.len(),
1298                    runtime.basis_dim
1299                ),
1300            });
1301        }
1302    }
1303    if let Some(runtime) = link_deviation {
1304        let idx = base_block_count + usize::from(score_warp.is_some());
1305        let beta = &fit.blocks[idx].beta;
1306        if beta.len() != runtime.basis_dim {
1307            return Err(FittedModelError::SchemaMismatch {
1308                reason: format!(
1309                    "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1310                    beta.len(),
1311                    runtime.basis_dim
1312                ),
1313            });
1314        }
1315    }
1316    Ok(())
1317}
1318
1319impl SavedLinkWiggleRuntime {
1320    fn validate_monotone_derivative(
1321        &self,
1322        q0: &Array1<f64>,
1323    ) -> Result<Array1<f64>, FittedModelError> {
1324        // Monotonicity is verified pointwise at the actual evaluation grid `q0`
1325        // (the predict η). The fit guarantees a strictly-increasing warped link
1326        // across the training η range (#1596); checking at `q0` here flags an
1327        // extrapolation point where the learnable link genuinely turns
1328        // non-invertible, without rejecting the whole model for a sign dip in the
1329        // basis tail far outside any data or evaluation point.
1330        let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1331        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1332        let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1333        if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1334            return Err(FittedModelError::PayloadCorrupt {
1335                reason: format!(
1336                    "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1337                ),
1338            });
1339        }
1340        Ok(dq_dq0)
1341    }
1342
1343    pub fn constrained_basis(
1344        &self,
1345        q0: &Array1<f64>,
1346        basis_options: BasisOptions,
1347    ) -> Result<Array2<f64>, FittedModelError> {
1348        let knot_arr = Array1::from_vec(self.knots.clone());
1349        let constrained = monotone_wiggle_basis_with_derivative_order(
1350            q0.view(),
1351            &knot_arr,
1352            self.degree,
1353            basis_options.derivative_order,
1354        )
1355        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1356        if constrained.ncols() != self.beta.len() {
1357            return Err(FittedModelError::SchemaMismatch {
1358                reason: format!(
1359                    "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1360                    self.beta.len(),
1361                    constrained.ncols()
1362                ),
1363            });
1364        }
1365        Ok(constrained)
1366    }
1367
1368    pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1369        self.validate_monotone_derivative(q0)?;
1370        self.constrained_basis(q0, BasisOptions::value())
1371    }
1372
1373    pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1374        let q = Array1::from_vec(vec![q0]);
1375        let x = self.design(&q)?;
1376        if x.nrows() != 1 {
1377            return Err(FittedModelError::SchemaMismatch {
1378                reason: format!(
1379                    "saved link-wiggle scalar evaluation expected 1 row, got {}",
1380                    x.nrows()
1381                ),
1382            });
1383        }
1384        Ok(x.row(0).to_owned())
1385    }
1386
1387    pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1388        self.validate_monotone_derivative(q0)?;
1389        let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1390        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1391        Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1392    }
1393
1394    pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395        self.validate_monotone_derivative(q0)
1396    }
1397}
1398
1399impl SavedBaselineTimeWiggleRuntime {
1400    pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1401        validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1402            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1403    }
1404}
1405
1406impl SavedCompiledFlexBlock {
1407    pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1408        if self.kernel.is_empty() {
1409            return Err(FittedModelError::SchemaMismatch {
1410                reason: "saved anchored deviation runtime is missing the exact kernel marker"
1411                    .to_string(),
1412            });
1413        }
1414        if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1415            return Err(FittedModelError::IncompatibleConfig {
1416                reason: format!(
1417                    "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1418                    self.kernel,
1419                    crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1420                ),
1421            });
1422        }
1423        if self.basis_dim == 0 {
1424            return Err(FittedModelError::SchemaMismatch {
1425                reason: format!(
1426                    "saved anchored deviation runtime basis_dim must be positive, got {}",
1427                    self.basis_dim
1428                ),
1429            });
1430        }
1431        if self.breakpoints.len() < 2 {
1432            return Err(FittedModelError::SchemaMismatch {
1433                reason: format!(
1434                    "saved anchored deviation runtime requires at least two breakpoints, got {}",
1435                    self.breakpoints.len()
1436                ),
1437            });
1438        }
1439        for window in self.breakpoints.windows(2) {
1440            let left = window[0];
1441            let right = window[1];
1442            if !left.is_finite() || !right.is_finite() || right <= left {
1443                return Err(FittedModelError::PayloadCorrupt {
1444                    reason: format!(
1445                        "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1446                    ),
1447                });
1448            }
1449        }
1450        let span_count = self.breakpoints.len() - 1;
1451        self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1452        self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1453        self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1454        self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1455        self.validate_c2_span_continuity()?;
1456        self.validate_anchor_residual_shape()?;
1457        Ok(())
1458    }
1459
1460    fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1461        let coeffs = match self.anchor_correction.as_ref() {
1462            Some(c) => c,
1463            None => {
1464                if !self.anchor_components.is_empty() {
1465                    return Err(FittedModelError::SchemaMismatch {
1466                        reason:
1467                            "saved anchored deviation runtime has anchor_components but no anchor_correction"
1468                                .to_string(),
1469                    });
1470                }
1471                return Ok(());
1472            }
1473        };
1474        let d: usize = self
1475            .anchor_components
1476            .iter()
1477            .map(|c| match &c.kind {
1478                SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1479                SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1480            })
1481            .sum();
1482        if coeffs.len() != d {
1483            return Err(FittedModelError::SchemaMismatch {
1484                reason: format!(
1485                    "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1486                    coeffs.len(),
1487                    d,
1488                ),
1489            });
1490        }
1491        for (i, row) in coeffs.iter().enumerate() {
1492            if row.len() != self.basis_dim {
1493                return Err(FittedModelError::SchemaMismatch {
1494                    reason: format!(
1495                        "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1496                        i,
1497                        row.len(),
1498                        self.basis_dim,
1499                    ),
1500                });
1501            }
1502            for (j, &v) in row.iter().enumerate() {
1503                if !v.is_finite() {
1504                    return Err(FittedModelError::PayloadCorrupt {
1505                        reason: format!(
1506                            "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1507                        ),
1508                    });
1509                }
1510            }
1511        }
1512        Ok(())
1513    }
1514
1515    fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1516        const TOL: f64 = 1e-8;
1517        for span_idx in 1..self.breakpoints.len() - 1 {
1518            let left_span = span_idx - 1;
1519            let right_span = span_idx;
1520            let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1521            for basis_idx in 0..self.basis_dim {
1522                let left_value = self.span_c0[left_span][basis_idx]
1523                    + self.span_c1[left_span][basis_idx] * width
1524                    + self.span_c2[left_span][basis_idx] * width * width
1525                    + self.span_c3[left_span][basis_idx] * width * width * width;
1526                let left_d1 = self.span_c1[left_span][basis_idx]
1527                    + 2.0 * self.span_c2[left_span][basis_idx] * width
1528                    + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1529                let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1530                    + 6.0 * self.span_c3[left_span][basis_idx] * width;
1531                let right_value = self.span_c0[right_span][basis_idx];
1532                let right_d1 = self.span_c1[right_span][basis_idx];
1533                let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1534                if (left_value - right_value).abs() > TOL
1535                    || (left_d1 - right_d1).abs() > TOL
1536                    || (left_d2 - right_d2).abs() > TOL
1537                {
1538                    return Err(FittedModelError::SchemaMismatch {
1539                        reason: format!(
1540                            "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1541                            left_value - right_value,
1542                            left_d1 - right_d1,
1543                            left_d2 - right_d2
1544                        ),
1545                    });
1546                }
1547            }
1548        }
1549        Ok(())
1550    }
1551
1552    fn validate_coefficient_matrix(
1553        &self,
1554        matrix: &[Vec<f64>],
1555        label: &str,
1556        expected_rows: usize,
1557    ) -> Result<(), FittedModelError> {
1558        if matrix.len() != expected_rows {
1559            return Err(FittedModelError::SchemaMismatch {
1560                reason: format!(
1561                    "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1562                    matrix.len(),
1563                    expected_rows
1564                ),
1565            });
1566        }
1567        for (row_idx, row) in matrix.iter().enumerate() {
1568            if row.len() != self.basis_dim {
1569                return Err(FittedModelError::SchemaMismatch {
1570                    reason: format!(
1571                        "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1572                        row_idx,
1573                        row.len(),
1574                        self.basis_dim
1575                    ),
1576                });
1577            }
1578            for (j, &value) in row.iter().enumerate() {
1579                if !value.is_finite() {
1580                    return Err(FittedModelError::PayloadCorrupt {
1581                        reason: format!(
1582                            "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1583                        ),
1584                    });
1585                }
1586            }
1587        }
1588        Ok(())
1589    }
1590
1591    fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1592        let last_span = self.breakpoints.len() - 2;
1593        let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1594        self.span_c0[last_span][basis_idx]
1595            + self.span_c1[last_span][basis_idx] * width
1596            + self.span_c2[last_span][basis_idx] * width * width
1597            + self.span_c3[last_span][basis_idx] * width * width * width
1598    }
1599
1600    fn evaluate_span_polynomial_design(
1601        &self,
1602        values: &Array1<f64>,
1603        derivative_order: usize,
1604    ) -> Result<Array2<f64>, FittedModelError> {
1605        self.validate_exact_replay_contract()?;
1606        let (left_ep, right_ep) = self.support_interval()?;
1607        let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1608        for (row_idx, &value) in values.iter().enumerate() {
1609            if !value.is_finite() {
1610                return Err(FittedModelError::PayloadCorrupt {
1611                    reason: format!(
1612                        "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1613                    ),
1614                });
1615            }
1616            if value < left_ep {
1617                if derivative_order == 0 {
1618                    for basis_idx in 0..self.basis_dim {
1619                        out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1620                    }
1621                }
1622                continue;
1623            }
1624            if value > right_ep {
1625                if derivative_order == 0 {
1626                    for basis_idx in 0..self.basis_dim {
1627                        out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1628                    }
1629                }
1630                continue;
1631            }
1632            let span_idx = self.left_biased_span_index_for(value)?;
1633            let t = value - self.breakpoints[span_idx];
1634            for basis_idx in 0..self.basis_dim {
1635                let c0 = self.span_c0[span_idx][basis_idx];
1636                let c1 = self.span_c1[span_idx][basis_idx];
1637                let c2 = self.span_c2[span_idx][basis_idx];
1638                let c3 = self.span_c3[span_idx][basis_idx];
1639                out[[row_idx, basis_idx]] = match derivative_order {
1640                    0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1641                    1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1642                    2 => 2.0 * c2 + 6.0 * c3 * t,
1643                    3 => 6.0 * c3,
1644                    4 => 0.0,
1645                    other => {
1646                        return Err(FittedModelError::IncompatibleConfig {
1647                            reason: format!(
1648                                "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1649                            ),
1650                        });
1651                    }
1652                };
1653            }
1654        }
1655        Ok(out)
1656    }
1657
1658    pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1659        self.validate_exact_replay_contract()?;
1660        Ok(self.breakpoints.clone())
1661    }
1662
1663    pub fn span_count(&self) -> Result<usize, FittedModelError> {
1664        Ok(self.breakpoints()?.windows(2).count())
1665    }
1666
1667    pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1668        let points = self.breakpoints()?;
1669        span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1670            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1671    }
1672
1673    fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1674        let mut span_idx = span_index_for_breakpoints(
1675            &self.breakpoints,
1676            value,
1677            "saved anchored deviation span lookup",
1678        )
1679        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1680        // LEFT-bias at interior breakpoints mirrors DeviationRuntime. The
1681        // saved cubic basis is C2, but d3 remains span-local.
1682        if span_idx > 0 && value == self.breakpoints[span_idx] {
1683            span_idx -= 1;
1684        }
1685        Ok(span_idx)
1686    }
1687
1688    pub fn local_cubic_on_span(
1689        &self,
1690        beta: &Array1<f64>,
1691        span_idx: usize,
1692    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1693        self.validate_exact_replay_contract()?;
1694        if beta.len() != self.basis_dim {
1695            return Err(FittedModelError::SchemaMismatch {
1696                reason: format!(
1697                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1698                    beta.len(),
1699                    self.basis_dim
1700                ),
1701            });
1702        }
1703        self.local_cubic_on_span_validated(beta, span_idx)
1704    }
1705
1706    fn local_cubic_on_span_validated(
1707        &self,
1708        beta: &Array1<f64>,
1709        span_idx: usize,
1710    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1711        let points = &self.breakpoints;
1712        if span_idx + 1 >= points.len() {
1713            return Err(FittedModelError::SchemaMismatch {
1714                reason: format!(
1715                    "saved anchored deviation span index {} out of range for {} spans",
1716                    span_idx,
1717                    points.len() - 1
1718                ),
1719            });
1720        }
1721        let left = points[span_idx];
1722        let right = points[span_idx + 1];
1723        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1724            left,
1725            right,
1726            c0: self.span_c0[span_idx]
1727                .iter()
1728                .zip(beta.iter())
1729                .map(|(coeff, weight)| coeff * weight)
1730                .sum(),
1731            c1: self.span_c1[span_idx]
1732                .iter()
1733                .zip(beta.iter())
1734                .map(|(coeff, weight)| coeff * weight)
1735                .sum(),
1736            c2: self.span_c2[span_idx]
1737                .iter()
1738                .zip(beta.iter())
1739                .map(|(coeff, weight)| coeff * weight)
1740                .sum(),
1741            c3: self.span_c3[span_idx]
1742                .iter()
1743                .zip(beta.iter())
1744                .map(|(coeff, weight)| coeff * weight)
1745                .sum(),
1746        })
1747    }
1748
1749    pub fn basis_span_cubic(
1750        &self,
1751        span_idx: usize,
1752        basis_idx: usize,
1753    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1754        self.validate_exact_replay_contract()?;
1755        if basis_idx >= self.basis_dim {
1756            return Err(FittedModelError::SchemaMismatch {
1757                reason: format!(
1758                    "saved anchored deviation basis index {} out of range for {} coefficients",
1759                    basis_idx, self.basis_dim
1760                ),
1761            });
1762        }
1763        self.basis_span_cubic_validated(span_idx, basis_idx)
1764    }
1765
1766    fn basis_span_cubic_validated(
1767        &self,
1768        span_idx: usize,
1769        basis_idx: usize,
1770    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1771        let points = &self.breakpoints;
1772        if span_idx + 1 >= points.len() {
1773            return Err(FittedModelError::SchemaMismatch {
1774                reason: format!(
1775                    "saved anchored deviation span index {} out of range for {} spans",
1776                    span_idx,
1777                    points.len() - 1
1778                ),
1779            });
1780        }
1781        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1782            left: points[span_idx],
1783            right: points[span_idx + 1],
1784            c0: self.span_c0[span_idx][basis_idx],
1785            c1: self.span_c1[span_idx][basis_idx],
1786            c2: self.span_c2[span_idx][basis_idx],
1787            c3: self.span_c3[span_idx][basis_idx],
1788        })
1789    }
1790
1791    pub fn basis_cubic_at(
1792        &self,
1793        basis_idx: usize,
1794        value: f64,
1795    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1796        self.validate_exact_replay_contract()?;
1797        if basis_idx >= self.basis_dim {
1798            return Err(FittedModelError::SchemaMismatch {
1799                reason: format!(
1800                    "saved anchored deviation basis index {} out of range for {} coefficients",
1801                    basis_idx, self.basis_dim
1802                ),
1803            });
1804        }
1805        let (left_ep, right_ep) = self.support_interval()?;
1806        if value < left_ep {
1807            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1808                left: left_ep,
1809                right: left_ep + 1.0,
1810                c0: self.span_c0[0][basis_idx],
1811                c1: 0.0,
1812                c2: 0.0,
1813                c3: 0.0,
1814            });
1815        }
1816        if value > right_ep {
1817            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1818                left: right_ep,
1819                right: right_ep + 1.0,
1820                c0: self.right_boundary_basis_value(basis_idx),
1821                c1: 0.0,
1822                c2: 0.0,
1823                c3: 0.0,
1824            });
1825        }
1826        let span_idx = self.left_biased_span_index_for(value)?;
1827        self.basis_span_cubic_validated(span_idx, basis_idx)
1828    }
1829
1830    pub fn local_cubic_at(
1831        &self,
1832        beta: &Array1<f64>,
1833        value: f64,
1834    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1835        self.validate_exact_replay_contract()?;
1836        if beta.len() != self.basis_dim {
1837            return Err(FittedModelError::SchemaMismatch {
1838                reason: format!(
1839                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1840                    beta.len(),
1841                    self.basis_dim
1842                ),
1843            });
1844        }
1845        let (left_ep, right_ep) = self.support_interval()?;
1846        if value < left_ep {
1847            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1848                left: left_ep,
1849                right: left_ep + 1.0,
1850                c0: self.span_c0[0]
1851                    .iter()
1852                    .zip(beta.iter())
1853                    .map(|(coeff, weight)| coeff * weight)
1854                    .sum(),
1855                c1: 0.0,
1856                c2: 0.0,
1857                c3: 0.0,
1858            });
1859        }
1860        if value > right_ep {
1861            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1862                left: right_ep,
1863                right: right_ep + 1.0,
1864                c0: (0..self.basis_dim)
1865                    .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1866                    .sum(),
1867                c1: 0.0,
1868                c2: 0.0,
1869                c3: 0.0,
1870            });
1871        }
1872        let span_idx = self.left_biased_span_index_for(value)?;
1873        self.local_cubic_on_span_validated(beta, span_idx)
1874    }
1875
1876    fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1877        let points = self.breakpoints()?;
1878        match (points.first(), points.last()) {
1879            (Some(&left), Some(&right)) => Ok((left, right)),
1880            _ => Err(FittedModelError::MissingField {
1881                reason: "saved anchored deviation runtime is missing support breakpoints"
1882                    .to_string(),
1883            }),
1884        }
1885    }
1886
1887    pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1888        // Note: when the saved runtime carries an anchor residual
1889        // (cross-block orthogonalisation), the value `design()` returns
1890        // is the raw cubic span output *without* the per-row `n_row · M`
1891        // subtraction. Callers used inside BMS prediction must either
1892        // switch to `design_with_anchor_rows` (when the per-row anchor
1893        // rows are available) or call `design_uncorrected` explicitly and
1894        // apply the subtraction at the call site. For runtimes without a
1895        // residual the two paths coincide.
1896        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1897    }
1898
1899    /// Raw cubic-span design without any anchor-residual subtraction.
1900    ///
1901    /// Exposed for callers that intend to apply the `n_row · M` correction
1902    /// post-hoc (e.g., BMS `link_terms_value_d1` subtracts a precomputed
1903    /// `correction.dot(beta)` scalar from the linear-predictor contribution
1904    /// rather than building a full anchor-row matrix). Equivalent to
1905    /// `design()` when no residual is present.
1906    pub fn design_uncorrected(
1907        &self,
1908        values: &Array1<f64>,
1909    ) -> Result<Array2<f64>, FittedModelError> {
1910        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1911    }
1912
1913    /// Evaluate the residual-corrected design at the supplied values.
1914    ///
1915    /// `anchor_rows` must be an `n × d` matrix where `n == values.len()`
1916    /// and `d == sum of anchor_components ncols`. Each row holds
1917    /// the concatenated parametric anchor design at the same prediction
1918    /// row as the corresponding `values[i]`. When the runtime has no
1919    /// anchor residual, `anchor_rows` must have zero columns (or be
1920    /// `Array2::zeros((n, 0))`).
1921    pub fn design_with_anchor_rows(
1922        &self,
1923        values: &Array1<f64>,
1924        anchor_rows: ndarray::ArrayView2<f64>,
1925    ) -> Result<Array2<f64>, FittedModelError> {
1926        let mut out =
1927            self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1928        if let Some(m_rows) = self.anchor_correction.as_ref() {
1929            let d = m_rows.len();
1930            if anchor_rows.nrows() != values.len() {
1931                return Err(FittedModelError::SchemaMismatch {
1932                    reason: format!(
1933                        "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1934                        anchor_rows.nrows(),
1935                        values.len(),
1936                    ),
1937                });
1938            }
1939            if anchor_rows.ncols() != d {
1940                return Err(FittedModelError::SchemaMismatch {
1941                    reason: format!(
1942                        "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1943                        anchor_rows.ncols(),
1944                        d,
1945                    ),
1946                });
1947            }
1948            // Materialise M (d × basis_dim) once.
1949            let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1950            for (i, row) in m_rows.iter().enumerate() {
1951                if row.len() != self.basis_dim {
1952                    return Err(FittedModelError::SchemaMismatch {
1953                        reason: format!(
1954                            "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1955                            i,
1956                            row.len(),
1957                            self.basis_dim,
1958                        ),
1959                    });
1960                }
1961                for (j, &v) in row.iter().enumerate() {
1962                    m_dense[[i, j]] = v;
1963                }
1964            }
1965            // The compiler bakes the orthonormalising rotation into M, so
1966            // the predict-time subtraction is simply `n_anchor_rows · M`.
1967            let subtract = anchor_rows.dot(&m_dense);
1968            out = out - subtract;
1969        } else if anchor_rows.ncols() != 0 {
1970            return Err(FittedModelError::SchemaMismatch {
1971                reason: format!(
1972                    "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1973                    anchor_rows.ncols(),
1974                ),
1975            });
1976        }
1977        Ok(out)
1978    }
1979
1980    /// Build the n × basis_dim per-row, per-basis correction matrix
1981    /// `N · M` for a batch of predict rows.
1982    ///
1983    /// `n_anchor_rows` is the n × d matrix of stacked parametric anchor
1984    /// rows at the prediction rows (concatenation of the marginal and
1985    /// logslope design rows in component order). Returns `None` when the
1986    /// runtime has no anchor residual (zero-cost path).
1987    pub fn anchor_correction_matrix(
1988        &self,
1989        n_anchor_rows: ndarray::ArrayView2<f64>,
1990    ) -> Result<Option<Array2<f64>>, FittedModelError> {
1991        let Some(m_rows) = self.anchor_correction.as_ref() else {
1992            return Ok(None);
1993        };
1994        let d = m_rows.len();
1995        if n_anchor_rows.ncols() != d {
1996            return Err(FittedModelError::SchemaMismatch {
1997                reason: format!(
1998                    "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
1999                    n_anchor_rows.ncols(),
2000                    d,
2001                ),
2002            });
2003        }
2004        let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2005        for (i, row) in m_rows.iter().enumerate() {
2006            if row.len() != self.basis_dim {
2007                return Err(FittedModelError::SchemaMismatch {
2008                    reason: format!(
2009                        "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2010                        i,
2011                        row.len(),
2012                        self.basis_dim,
2013                    ),
2014                });
2015            }
2016            for (j, &v) in row.iter().enumerate() {
2017                m_dense[[i, j]] = v;
2018            }
2019        }
2020        // The compiler bakes the orthonormalising rotation into M, so
2021        // the predict-time correction is simply `n_anchor_rows · M`.
2022        Ok(Some(n_anchor_rows.dot(&m_dense)))
2023    }
2024
2025    pub fn first_derivative_design(
2026        &self,
2027        values: &Array1<f64>,
2028    ) -> Result<Array2<f64>, FittedModelError> {
2029        self.evaluate_span_polynomial_design(
2030            values,
2031            BasisOptions::first_derivative().derivative_order,
2032        )
2033    }
2034
2035    pub fn second_derivative_design(
2036        &self,
2037        values: &Array1<f64>,
2038    ) -> Result<Array2<f64>, FittedModelError> {
2039        self.evaluate_span_polynomial_design(
2040            values,
2041            BasisOptions::second_derivative().derivative_order,
2042        )
2043    }
2044}
2045
2046impl FittedFamily {
2047    #[inline]
2048    pub fn likelihood(&self) -> LikelihoodSpec {
2049        let spec = match self {
2050            Self::Standard { likelihood, .. }
2051            | Self::LocationScale { likelihood, .. }
2052            | Self::MarginalSlope { likelihood, .. }
2053            | Self::Survival { likelihood, .. }
2054            | Self::TransformationNormal { likelihood, .. } => likelihood,
2055            Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2056                return LikelihoodSpec::royston_parmar();
2057            }
2058        };
2059        spec.clone()
2060    }
2061
2062    #[inline]
2063    pub fn frailty(&self) -> Option<&FrailtySpec> {
2064        match self {
2065            Self::MarginalSlope { frailty, .. }
2066            | Self::Survival { frailty, .. }
2067            | Self::LatentSurvival { frailty }
2068            | Self::LatentBinary { frailty } => Some(frailty),
2069            _ => None,
2070        }
2071    }
2072}
2073
2074/// Recursively collect the feature columns of a smooth basis whose out-of-hull
2075/// evaluation is bounded, so they can be exempted from the predict-time axis
2076/// clip (see [`FittedModel::training_smooth_extrapolation_axes`]). Wrapper bases
2077/// (`by=`, factor-smooth, sum-to-zero) delegate to their inner smooth; `Sphere`
2078/// and `Pca` are intentionally not collected.
2079fn collect_smooth_extrapolation_axes(
2080    basis: &gam_terms::smooth::SmoothBasisSpec,
2081    n_training_headers: usize,
2082    out: &mut std::collections::HashSet<usize>,
2083) {
2084    use gam_terms::smooth::SmoothBasisSpec;
2085    let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2086        if col < n_training_headers {
2087            out.insert(col);
2088        }
2089    };
2090    match basis {
2091        // 1D B-spline: first-order linear extension off the boundary slope.
2092        SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2093        // Tensor B-spline: each margin linear-extends independently. Periodic
2094        // margins are additionally (and harmlessly) exempted via the periodic set.
2095        SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2096            for &c in feature_cols {
2097                push(c, out);
2098            }
2099        }
2100        // Radial bases with a bounded out-of-hull contract: Duchon / thin-plate
2101        // are linear outside the data span (natural-spline boundary conditions),
2102        // Matérn reverts to its mean as the kernel decays. Measure-jet shares
2103        // the Matérn contract (Gaussian representers decay to the parametric
2104        // layer off the data support) — and off-web queries are exactly the
2105        // ones its support diagnostic must see unclipped.
2106        SmoothBasisSpec::ThinPlate { feature_cols, .. }
2107        | SmoothBasisSpec::Matern { feature_cols, .. }
2108        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2109        | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2110            for &c in feature_cols {
2111                push(c, out);
2112            }
2113        }
2114        // Factor-smooth: the continuous marginal axes are B-splines that
2115        // linear-extend; the group column is categorical and is left to the
2116        // random-effect / level-lookup machinery.
2117        SmoothBasisSpec::FactorSmooth { spec } => {
2118            for &c in &spec.continuous_cols {
2119                push(c, out);
2120            }
2121        }
2122        // Wrappers delegate to the inner smooth they modulate / replicate.
2123        SmoothBasisSpec::ByVariable { inner, .. }
2124        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2125            collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2126        }
2127        SmoothBasisSpec::BySmooth { smooth, .. } => {
2128            collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2129        }
2130        // Sphere: latitude is clipped to manifold bounds, longitude is periodic —
2131        // both handled elsewhere with non-plateau semantics. Pca: no extrapolation
2132        // contract, stays clipped. ConstantCurvature: chart coordinates must stay
2133        // inside the κ-stereographic chart (open ball for κ < 0), so clipping new
2134        // data to the training range is the safe out-of-hull behavior.
2135        SmoothBasisSpec::Sphere { .. }
2136        | SmoothBasisSpec::ConstantCurvature { .. }
2137        | SmoothBasisSpec::Pca { .. } => {}
2138    }
2139}
2140
2141/// Collect training-column indices that feed a *numeric* `by=` multiplier of a
2142/// varying-coefficient smooth (`s(x, by=z)`).
2143///
2144/// A numeric by-variable enters the model as a linear multiplier of the centred
2145/// smooth basis, so the prediction is exactly affine in `z` for any fixed `x`:
2146/// `pred(x, z) = intercept + z·f(x)`. The by-multiplier is therefore an
2147/// *unbounded linear* extrapolant — exactly like a parametric `linear` axis
2148/// (already exempt from the clip via `training_linear_axes`), not a
2149/// bounded-basis axis. Clamping it to the training range `[min(z), max(z)]`
2150/// before the design is built would make the varying-coefficient effect plateau
2151/// outside the sampled range (slope 0 above the max) and silently replace the
2152/// natural `z == 0` baseline with the `z == min` prediction — the prediction is
2153/// no longer affine in `z`. Exempt it from the predict-time axis clip.
2154///
2155/// Factor `by=` columns are categorical group labels handled by the
2156/// level-lookup / random-effect machinery and are deliberately *not* exempted
2157/// here. Returned indices reference `self.training_headers`, matching the
2158/// iteration in `axis_clip_to_training_ranges`.
2159fn collect_by_variable_numeric_axes(
2160    basis: &gam_terms::smooth::SmoothBasisSpec,
2161    n_training_headers: usize,
2162    out: &mut std::collections::HashSet<usize>,
2163) {
2164    use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2165    match basis {
2166        SmoothBasisSpec::ByVariable {
2167            inner,
2168            by_col,
2169            kind,
2170            ..
2171        } => {
2172            if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2173                out.insert(*by_col);
2174            }
2175            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2176        }
2177        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2178            if let ByVarKind::Numeric { feature_col } = by_kind
2179                && *feature_col < n_training_headers
2180            {
2181                out.insert(*feature_col);
2182            }
2183            collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2184        }
2185        SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2186            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2187        }
2188        _ => {}
2189    }
2190}
2191
2192impl FittedModel {
2193    /// Axis-clip each continuous new-data column to the (min, max) range
2194    /// observed in training. Categorical and binary columns are left
2195    /// untouched so unseen levels surface rather than being silently remapped
2196    /// onto seen ones. Returns `Some(clipped_copy)` only if at least one
2197    /// value was actually clipped; otherwise `None` so callers can avoid
2198    /// owning a redundant copy. Pre-2026-04-29 model JSONs that lack the
2199    /// `training_feature_ranges` field deserialize to `None` and pass through
2200    /// unchanged.
2201    pub fn axis_clip_to_training_ranges(
2202        &self,
2203        data: ndarray::ArrayView2<'_, f64>,
2204        col_map: &std::collections::HashMap<String, usize>,
2205    ) -> Option<ndarray::Array2<f64>> {
2206        let training_headers = self.training_headers.as_ref()?;
2207        let ranges = self.training_feature_ranges.as_ref()?;
2208        if training_headers.len() != ranges.len() {
2209            return None;
2210        }
2211        let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2212            std::collections::HashMap::new();
2213        if let Some(schema) = self.data_schema.as_ref() {
2214            for col in &schema.columns {
2215                kind_by_header.insert(col.name.as_str(), col.kind);
2216            }
2217        }
2218        // Periodic axes (sphere longitude, periodic-B-spline 1D, periodic
2219        // tensor margins) must never be clipped to the training range:
2220        // clamping a value just past the seam to the training extreme breaks
2221        // the cyclic invariant f(x₀) = f(x₀ + period) at predict time and
2222        // shows up as a visible seam in surface plots.
2223        let periodic_axes = self.training_periodic_axes(training_headers);
2224        // Parametric/linear-term axes must never be clipped either: a linear
2225        // term's contract is η = β0 + β1·x, i.e. genuine linear extrapolation.
2226        // Clamping its input to the training extreme turns predict into a
2227        // piecewise-constant plateau outside the training hull and freezes the
2228        // prediction SE at the boundary (the clamped x feeds xᵀ Var(β) x), so
2229        // credible intervals stop widening with distance from the data. This
2230        // mirrors how periodic axes are exempted just above.
2231        let linear_axes = self.training_linear_axes(training_headers.len());
2232        // Random-effect grouping axes are categorical even when their source
2233        // column is numeric. Clipping them would remap an unseen group label to
2234        // a boundary training level instead of letting the random-effect block
2235        // encode it as the prior-mean zero effect.
2236        let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2237        // Non-parametric smooth axes whose basis extrapolates boundedly on its
2238        // own (B-spline linear extension, Duchon/thin-plate natural-spline linear
2239        // tail, Matérn kernel decay). Clamping their input to the training extreme
2240        // hands the basis an already-clamped coordinate, so its extrapolation
2241        // never fires and predict freezes at a boundary plateau — diverging from
2242        // the raw design path, which does not clip. Exempt them so both paths go
2243        // through the single basis-layer extrapolation. See the method doc.
2244        let smooth_extrapolation_axes =
2245            self.training_smooth_extrapolation_axes(training_headers.len());
2246        // Numeric `by=` multipliers of a varying-coefficient smooth `s(x, by=z)`
2247        // are linear multipliers of the centred basis (prediction affine in z),
2248        // so — like parametric linear axes — clipping them to the training range
2249        // turns the varying-coefficient effect into a boundary plateau and
2250        // destroys the z==0 baseline. Exempt them from the clip.
2251        let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2252        // Sphere latitude is a closed-manifold coordinate: its clip bounds are
2253        // the manifold's intrinsic domain ([-π/2, π/2] or [-90, 90]), not the
2254        // sampled range, so a pole prediction reaches the true pole instead of
2255        // being clamped to a near-pole latitude (see the method doc).
2256        let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2257        let mut clipped = data.to_owned();
2258        let mut any_clipped = false;
2259        for (col_in_training, (header, &(lo, hi))) in
2260            training_headers.iter().zip(ranges.iter()).enumerate()
2261        {
2262            let (lo, hi) = sphere_lat_bounds
2263                .get(&col_in_training)
2264                .copied()
2265                .unwrap_or((lo, hi));
2266            if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2267                continue;
2268            }
2269            if !matches!(
2270                kind_by_header.get(header.as_str()).copied(),
2271                Some(ColumnKindTag::Continuous)
2272            ) {
2273                continue;
2274            }
2275            if periodic_axes.contains(&col_in_training) {
2276                continue;
2277            }
2278            if linear_axes.contains(&col_in_training) {
2279                continue;
2280            }
2281            if random_effect_axes.contains(&col_in_training) {
2282                continue;
2283            }
2284            if smooth_extrapolation_axes.contains(&col_in_training) {
2285                continue;
2286            }
2287            if by_variable_axes.contains(&col_in_training) {
2288                continue;
2289            }
2290            let Some(&col_idx) = col_map.get(header) else {
2291                continue;
2292            };
2293            if col_idx >= clipped.ncols() {
2294                continue;
2295            }
2296            let mut col = clipped.column_mut(col_idx);
2297            for v in col.iter_mut() {
2298                if v.is_finite() {
2299                    if *v < lo {
2300                        *v = lo;
2301                        any_clipped = true;
2302                    } else if *v > hi {
2303                        *v = hi;
2304                        any_clipped = true;
2305                    }
2306                }
2307            }
2308        }
2309        if any_clipped { Some(clipped) } else { None }
2310    }
2311
2312    fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2313        let mut specs: Vec<&TermCollectionSpec> = [
2314            self.resolved_termspec.as_ref(),
2315            self.resolved_termspec_noise.as_ref(),
2316            self.resolved_termspec_logslope.as_ref(),
2317        ]
2318        .into_iter()
2319        .flatten()
2320        .collect();
2321        if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2322            specs.extend(logslopes.iter());
2323        }
2324        specs
2325    }
2326
2327    /// Collect the set of training-column indices that are periodic axes —
2328    /// i.e. features for which a periodic basis (sphere longitude, periodic
2329    /// B-spline 1D, periodic tensor margin) must be allowed to take any
2330    /// real value at predict time and not be clamped to the training range.
2331    /// Returned indices reference `self.training_headers` (training-time
2332    /// layout), matching the iteration in `axis_clip_to_training_ranges`.
2333    fn training_periodic_axes(
2334        &self,
2335        training_headers: &[String],
2336    ) -> std::collections::HashSet<usize> {
2337        use gam_terms::basis::BSplineKnotSpec;
2338        use gam_terms::smooth::SmoothBasisSpec;
2339        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2340        let Some(spec) = self.resolved_termspec.as_ref() else {
2341            return out;
2342        };
2343        for term in &spec.smooth_terms {
2344            match &term.basis {
2345                // Sphere terms: longitude (second feature col) is always
2346                // periodic and exempt from clipping. Latitude is not periodic
2347                // but is a closed-manifold coordinate, so it is clipped to the
2348                // manifold's intrinsic bounds rather than the sampled range —
2349                // see `training_sphere_latitude_bounds`.
2350                SmoothBasisSpec::Sphere { feature_cols, .. } => {
2351                    if let Some(&lon_col) = feature_cols.get(1)
2352                        && lon_col < training_headers.len()
2353                    {
2354                        out.insert(lon_col);
2355                    }
2356                }
2357                // 1D periodic B-spline: the single feature column is periodic.
2358                SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2359                    if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2360                        && *feature_col < training_headers.len()
2361                    {
2362                        out.insert(*feature_col);
2363                    }
2364                }
2365                // Tensor B-spline: each axis whose marginal knotspec is
2366                // PeriodicUniform is periodic; mark those columns.
2367                SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2368                    for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2369                        if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2370                            && let Some(&col) = feature_cols.get(i)
2371                            && col < training_headers.len()
2372                        {
2373                            out.insert(col);
2374                        }
2375                    }
2376                }
2377                _ => {}
2378            }
2379        }
2380        out
2381    }
2382
2383    /// Collect the set of training-column indices that feed a parametric/linear
2384    /// term — on *any* modelled surface (mean, noise/scale, log-slope). A
2385    /// linear term realises the design column `∏ feature_cols` and contributes
2386    /// `β·(that product)` to the linear predictor, so its inputs must be allowed
2387    /// to take any real value at predict time: clamping them to the training
2388    /// range would replace genuine linear extrapolation with a boundary plateau
2389    /// (and freeze the prediction SE at the hull edge). Returned indices
2390    /// reference `self.training_headers` (training-time layout), matching the
2391    /// iteration in `axis_clip_to_training_ranges`. Wilkinson-Rogers `:`
2392    /// interactions contribute every column in their `feature_cols` product.
2393    fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2394        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2395        for spec in self.saved_term_specs() {
2396            for term in &spec.linear_terms {
2397                for col in term.effective_feature_cols() {
2398                    if col < n_training_headers {
2399                        out.insert(col);
2400                    }
2401                }
2402            }
2403        }
2404        out
2405    }
2406
2407    /// Collect the set of training-column indices that feed random-effect
2408    /// grouping terms. These columns are categorical model axes regardless of
2409    /// the ingest schema's scalar storage type, so prediction must leave them
2410    /// untouched and let the frozen random-effect levels decide whether a row is
2411    /// a seen level or the zero-effect unseen-level fallback.
2412    fn training_random_effect_axes(
2413        &self,
2414        n_training_headers: usize,
2415    ) -> std::collections::HashSet<usize> {
2416        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2417        for spec in self.saved_term_specs() {
2418            for term in &spec.random_effect_terms {
2419                if term.feature_col < n_training_headers {
2420                    out.insert(term.feature_col);
2421                }
2422            }
2423        }
2424        out
2425    }
2426
2427    /// Collect the set of training-column indices that feed a non-parametric
2428    /// smooth whose basis performs its own *bounded* extrapolation outside the
2429    /// training hull — on any modelled surface (mean, noise/scale, log-slope).
2430    ///
2431    /// These columns must be exempt from the predict-time axis clip for the same
2432    /// reason periodic/linear/random-effect axes are: the clip clamps a new value
2433    /// to the training extreme *before* the design is built, so the basis is
2434    /// handed an already-clamped coordinate and its extrapolation machinery never
2435    /// fires. The result is a piecewise-constant plateau frozen at the boundary
2436    /// fitted value (with a prediction SE frozen at the hull edge), and — worse —
2437    /// a model that yields *different* predictions through the `FittedModel`
2438    /// predict pipeline than through the raw `build_term_collection_design` path,
2439    /// which does not clip. Exempting these axes routes both entry points through
2440    /// the single basis-layer extrapolation, restoring internal consistency.
2441    ///
2442    /// Only bases with a *bounded* out-of-hull contract are listed, so removing
2443    /// the clip cannot reintroduce the wild basis blow-up the clip guards against:
2444    ///   - B-spline 1D / tensor margins: first-order linear extension off the
2445    ///     boundary slope (`apply_linear_extension_from_first_derivative`) — grows
2446    ///     at most linearly.
2447    ///   - Duchon / thin-plate: the natural-spline boundary conditions make the
2448    ///     fit linear outside the data span — also at most linear growth.
2449    ///   - Matérn: the kernel decays with distance, so the fit reverts smoothly to
2450    ///     its (low-order polynomial / constant) mean far from the data — bounded.
2451    /// `sphere()` axes are deliberately *not* listed: latitude is a closed-manifold
2452    /// coordinate clipped to its intrinsic bounds (`training_sphere_latitude_bounds`)
2453    /// and longitude is periodic (`training_periodic_axes`); both already have the
2454    /// correct, non-plateau handling. `Pca` projections have no extrapolation
2455    /// contract and stay clipped.
2456    ///
2457    /// Returned indices reference `self.training_headers` (training-time layout),
2458    /// matching the iteration in `axis_clip_to_training_ranges`.
2459    fn training_smooth_extrapolation_axes(
2460        &self,
2461        n_training_headers: usize,
2462    ) -> std::collections::HashSet<usize> {
2463        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2464        for spec in self.saved_term_specs() {
2465            for term in &spec.smooth_terms {
2466                collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2467            }
2468        }
2469        out
2470    }
2471
2472    /// Collect the set of training-column indices that feed a *numeric* `by=`
2473    /// multiplier of a varying-coefficient smooth. These columns are linear
2474    /// multipliers (`pred = intercept + z·f(x)`), so they must be exempt from
2475    /// the predict-time axis clip for the same reason parametric linear axes
2476    /// are — see `collect_by_variable_numeric_axes` for the full rationale.
2477    fn training_by_variable_numeric_axes(
2478        &self,
2479        n_training_headers: usize,
2480    ) -> std::collections::HashSet<usize> {
2481        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2482        for spec in self.saved_term_specs() {
2483            for term in &spec.smooth_terms {
2484                collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2485            }
2486        }
2487        out
2488    }
2489
2490    /// Manifold-intrinsic clip bounds for sphere *latitude* columns.
2491    ///
2492    /// A `sphere(lat, lon)` smooth charts the closed manifold S²: the poles
2493    /// (lat = ±π/2, or ±90°) are interior limit points of that manifold, not
2494    /// endpoints of an unbounded axis to extrapolate along. A finite training
2495    /// sample never reaches the pole exactly, so clamping a pole prediction to
2496    /// the *observed* latitude extreme lands at a near-pole latitude where the
2497    /// Wahba SOS kernel's `cos(lat)·cos(lat_c)·cos(Δlon)` term has not yet
2498    /// damped — and the predictor then sweeps a spurious `cos(lon)` profile at
2499    /// what is physically a single point, reintroducing the pole artefact the
2500    /// SOS basis exists to remove.
2501    ///
2502    /// The correct clip bound for this coordinate is therefore the manifold's
2503    /// intrinsic domain — `[-π/2, π/2]` radians or `[-90, 90]` degrees — not
2504    /// the sampled range. Clamping to those bounds keeps the pole reachable
2505    /// (single-valued in longitude) while still mapping any out-of-domain
2506    /// latitude onto the manifold boundary. Longitude needs no entry here: it
2507    /// is periodic and already exempted from clipping entirely
2508    /// (`training_periodic_axes`). Returned indices reference
2509    /// `self.training_headers`, matching the iteration in
2510    /// `axis_clip_to_training_ranges`.
2511    fn training_sphere_latitude_bounds(
2512        &self,
2513        training_headers: &[String],
2514    ) -> std::collections::HashMap<usize, (f64, f64)> {
2515        use gam_terms::smooth::SmoothBasisSpec;
2516        let mut out: std::collections::HashMap<usize, (f64, f64)> =
2517            std::collections::HashMap::new();
2518        let Some(spec) = self.resolved_termspec.as_ref() else {
2519            return out;
2520        };
2521        for term in &spec.smooth_terms {
2522            if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2523                && let Some(&lat_col) = feature_cols.first()
2524                && lat_col < training_headers.len()
2525            {
2526                let bound = if spec.radians {
2527                    std::f64::consts::FRAC_PI_2
2528                } else {
2529                    90.0
2530                };
2531                out.insert(lat_col, (-bound, bound));
2532            }
2533        }
2534        out
2535    }
2536
2537    pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2538        let likelihood = payload.family_state.likelihood();
2539        let class = match payload.model_kind {
2540            ModelKind::Survival => PredictModelClass::Survival,
2541            ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2542            ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2543            ModelKind::LocationScale => {
2544                if likelihood == LikelihoodSpec::gaussian_identity() {
2545                    PredictModelClass::GaussianLocationScale
2546                } else if is_dispersion_location_scale_response(&likelihood.response) {
2547                    PredictModelClass::DispersionLocationScale
2548                } else {
2549                    PredictModelClass::BinomialLocationScale
2550                }
2551            }
2552            ModelKind::Standard => PredictModelClass::Standard,
2553        };
2554        match class {
2555            PredictModelClass::Survival => {
2556                payload.model_kind = ModelKind::Survival;
2557                Self::Survival { payload }
2558            }
2559            PredictModelClass::BernoulliMarginalSlope => {
2560                payload.model_kind = ModelKind::MarginalSlope;
2561                Self::MarginalSlope { payload }
2562            }
2563            PredictModelClass::TransformationNormal => {
2564                payload.model_kind = ModelKind::TransformationNormal;
2565                Self::TransformationNormal { payload }
2566            }
2567            PredictModelClass::GaussianLocationScale
2568            | PredictModelClass::BinomialLocationScale
2569            | PredictModelClass::DispersionLocationScale => {
2570                payload.model_kind = ModelKind::LocationScale;
2571                Self::LocationScale { payload }
2572            }
2573            PredictModelClass::Standard => {
2574                payload.model_kind = ModelKind::Standard;
2575                Self::Standard { payload }
2576            }
2577        }
2578        .with_synchronized_stateful_link_metadata()
2579    }
2580
2581    #[inline]
2582    pub fn payload(&self) -> &FittedModelPayload {
2583        match self {
2584            Self::Standard { payload }
2585            | Self::LocationScale { payload }
2586            | Self::MarginalSlope { payload }
2587            | Self::Survival { payload }
2588            | Self::TransformationNormal { payload } => payload,
2589        }
2590    }
2591
2592    #[inline]
2593    fn payload_mut(&mut self) -> &mut FittedModelPayload {
2594        match self {
2595            Self::Standard { payload }
2596            | Self::LocationScale { payload }
2597            | Self::MarginalSlope { payload }
2598            | Self::Survival { payload }
2599            | Self::TransformationNormal { payload } => payload,
2600        }
2601    }
2602
2603    fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2604        self.synchronize_stateful_link_metadata();
2605        self
2606    }
2607
2608    fn synchronize_stateful_link_metadata(&mut self) {
2609        let payload = self.payload_mut();
2610        payload.used_device = payload
2611            .fit_result
2612            .as_ref()
2613            .or(payload.unified.as_ref())
2614            .is_some_and(|fit| fit.used_device);
2615        payload.synchronize_empty_feature_contract();
2616        let Some(fit) = payload.fit_result.as_ref() else {
2617            return;
2618        };
2619        match (&mut payload.family_state, &fit.fitted_link) {
2620            (
2621                FittedFamily::Standard {
2622                    likelihood,
2623                    latent_cloglog_state,
2624                    ..
2625                },
2626                FittedLinkState::LatentCLogLog { state },
2627            ) if likelihood.is_latent_cloglog() => {
2628                *latent_cloglog_state = Some(*state);
2629            }
2630            (
2631                FittedFamily::Standard {
2632                    likelihood,
2633                    sas_state,
2634                    ..
2635                },
2636                FittedLinkState::Sas { state, covariance },
2637            ) if likelihood.is_binomial_sas() => {
2638                *sas_state = Some(*state);
2639                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
2640            }
2641            (
2642                FittedFamily::Standard {
2643                    likelihood,
2644                    sas_state,
2645                    ..
2646                },
2647                FittedLinkState::BetaLogistic { state, covariance },
2648            ) if likelihood.is_binomial_beta_logistic() => {
2649                *sas_state = Some(*state);
2650                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
2651            }
2652            (
2653                FittedFamily::Standard {
2654                    likelihood,
2655                    mixture_state,
2656                    ..
2657                },
2658                FittedLinkState::Mixture { state, covariance },
2659            ) if likelihood.is_binomial_mixture() => {
2660                *mixture_state = Some(state.clone());
2661                payload.mixture_link_param_covariance =
2662                    covariance.as_ref().map(array2_to_nestedvec);
2663            }
2664            _ => {}
2665        }
2666    }
2667
2668    #[inline]
2669    pub fn likelihood(&self) -> LikelihoodSpec {
2670        self.payload().family_state.likelihood()
2671    }
2672
2673    /// Columns this model consumes from a prediction frame — its *input
2674    /// contract*.
2675    ///
2676    /// Every variable named by the main formula (features, interaction margins,
2677    /// random-effect groups, and a smooth's `by=` column), the survival
2678    /// entry/exit columns or the transformation-normal response, the auxiliary
2679    /// noise / logslope formula columns, and the offset / noise-offset /
2680    /// latent-`z` columns. The event-indicator and the plain response of a
2681    /// standard model are deliberately excluded: they are not needed to *form*
2682    /// a prediction (the conformal-calibration fold layers the response back on
2683    /// separately).
2684    ///
2685    /// This is the single authority shared by the CLI and PyFFI predict paths.
2686    /// A prediction frame column that is *not* in this set is irrelevant to the
2687    /// model and must be ignored rather than strict-encoded against the
2688    /// training schema — otherwise an unrelated ID/label column with a held-out
2689    /// categorical level aborts predict (#840).
2690    pub fn prediction_required_columns(
2691        &self,
2692    ) -> Result<std::collections::BTreeSet<String>, String> {
2693        let payload = self.payload();
2694        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2695        let mut required = std::collections::BTreeSet::<String>::new();
2696        parsed_term_column_names(&parsed.terms, &mut required);
2697
2698        if let Some((entry, exit, _event)) =
2699            parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2700        {
2701            if let Some(entry) = entry {
2702                required.insert(entry);
2703            }
2704            required.insert(exit);
2705        } else if let Some((left, right, _event)) =
2706            parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707        {
2708            required.insert(left);
2709            required.insert(right);
2710        } else if matches!(
2711            self.predict_model_class(),
2712            PredictModelClass::TransformationNormal
2713        ) {
2714            let response = parsed.response.trim();
2715            if !response.is_empty() && !response.starts_with("Surv(") {
2716                required.insert(response.to_string());
2717            }
2718        }
2719
2720        if let Some(offset) = payload.offset_column.as_ref() {
2721            required.insert(offset.clone());
2722        }
2723        if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2724            required.insert(noise_offset.clone());
2725        }
2726        if matches!(
2727            self.predict_model_class(),
2728            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2729        ) {
2730            if let Some(z_column) = payload.z_column.as_ref() {
2731                required.remove("z");
2732                required.insert(z_column.clone());
2733            }
2734        }
2735        if let Some(noise_formula) = payload.formula_noise.as_ref() {
2736            self.add_auxiliary_formula_columns(
2737                &mut required,
2738                noise_formula,
2739                parsed.response.as_str(),
2740            )?;
2741        }
2742        if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2743            if logslope_formula != "same-as-main" {
2744                self.add_auxiliary_formula_columns(
2745                    &mut required,
2746                    logslope_formula,
2747                    parsed.response.as_str(),
2748                )?;
2749            }
2750        }
2751        Ok(required)
2752    }
2753
2754    /// Columns a *post-fit diagnostic* command (diagnose / sample / report)
2755    /// needs **beyond** [`Self::prediction_required_columns`].
2756    ///
2757    /// Prediction deliberately drops a standard GAM's bare response so a
2758    /// prediction frame may omit it (#840 / #864). Diagnostics are statements
2759    /// *about* that observed response — residuals, R², posterior likelihoods,
2760    /// leave-one-out — so the response must be present. This returns the bare
2761    /// response column when the prediction projection would otherwise drop it,
2762    /// and nothing when the response is already prediction-required (survival
2763    /// `Surv(...)` time/event columns, the transformation-normal response) or
2764    /// is not a plain data column.
2765    ///
2766    /// Centralising the intent here is what makes it *structurally impossible*
2767    /// for a diagnostic command to silently drop the response: callers use
2768    /// `load_dataset…_for_diagnostics`, which always folds these in, instead of
2769    /// each remembering to thread an `extra_required` response by hand.
2770    pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2771        let payload = self.payload();
2772        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2773        // Survival responses are `Surv(...)` expressions, not bare columns; the
2774        // underlying entry/exit columns are already prediction-required.
2775        if parse_surv_response(parsed.response.as_str())
2776            .map_err(|e| e.to_string())?
2777            .is_some()
2778            || parse_surv_interval_response(parsed.response.as_str())
2779                .map_err(|e| e.to_string())?
2780                .is_some()
2781        {
2782            return Ok(Vec::new());
2783        }
2784        let response = parsed.response.trim();
2785        // A response that is empty, or a function-call expression rather than a
2786        // plain data column, has no bare column to re-add.
2787        if response.is_empty() || response.contains('(') {
2788            return Ok(Vec::new());
2789        }
2790        // Already prediction-required (e.g. transformation-normal re-adds it):
2791        // nothing extra to fold in.
2792        if self.prediction_required_columns()?.contains(response) {
2793            return Ok(Vec::new());
2794        }
2795        Ok(vec![response.to_string()])
2796    }
2797
2798    /// Add the columns referenced by an auxiliary (noise / logslope) formula,
2799    /// which may be supplied as a full `lhs ~ rhs` formula or as a bare RHS.
2800    fn add_auxiliary_formula_columns(
2801        &self,
2802        required: &mut std::collections::BTreeSet<String>,
2803        formula_or_rhs: &str,
2804        response: &str,
2805    ) -> Result<(), String> {
2806        let trimmed = formula_or_rhs.trim();
2807        if trimmed.is_empty() || trimmed == "1" {
2808            return Ok(());
2809        }
2810        let formula = if trimmed.contains('~') {
2811            trimmed.to_string()
2812        } else {
2813            format!("{response} ~ {trimmed}")
2814        };
2815        let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2816        parsed_term_column_names(&parsed.terms, required);
2817        Ok(())
2818    }
2819
2820    #[inline]
2821    pub fn predict_model_class(&self) -> PredictModelClass {
2822        match &self.payload().family_state {
2823            FittedFamily::Survival { .. }
2824            | FittedFamily::LatentSurvival { .. }
2825            | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2826            FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2827            FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2828            FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2829                PredictModelClass::GaussianLocationScale
2830            }
2831            FittedFamily::LocationScale { likelihood, .. }
2832                if is_dispersion_location_scale_response(&likelihood.response) =>
2833            {
2834                PredictModelClass::DispersionLocationScale
2835            }
2836            FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2837            FittedFamily::Standard { .. } => PredictModelClass::Standard,
2838        }
2839    }
2840
2841    pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2842        let payload = self.payload();
2843        let (knots, degree) = match (
2844            payload.linkwiggle_knots.as_ref(),
2845            payload.linkwiggle_degree,
2846        ) {
2847            (None, None) => return Ok(None),
2848            (Some(knots), Some(degree)) => (knots.clone(), degree),
2849            _ => {
2850                return Err(FittedModelError::SchemaMismatch {
2851                    reason:
2852                        "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2853                            .to_string(),
2854                })
2855            }
2856        };
2857        let resolved_link = self.resolved_inverse_link()?;
2858        let saved_link_disallows_wiggle = resolved_link
2859            .as_ref()
2860            .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2861            || payload
2862                .link
2863                .as_ref()
2864                .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2865        if saved_link_disallows_wiggle {
2866            return Err(FittedModelError::IncompatibleConfig {
2867                reason: joint_wiggle_unsupported_link_message("link wiggle"),
2868            });
2869        }
2870        let beta = match self.predict_model_class() {
2871            // #1596: the frozen-basis de-aliased standard link-warp is fit in a
2872            // reduced, identifiable coordinate `γ` and the fit_result LinkWiggle
2873            // block stores `γ` (its true free parameters). The full-width
2874            // standard-basis lift `β_w = Z·γ`, the coefficients the predict-time
2875            // I-spline basis multiplies, is persisted in `payload.beta_link_wiggle`
2876            // — prefer it when present. Without it (the dynamic-basis path) the
2877            // block coefficients ARE the standard-basis warp, read directly.
2878            PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2879                payload.beta_link_wiggle.clone().expect("checked is_some")
2880            }
2881            PredictModelClass::Standard => {
2882                let fit = payload.fit_result.as_ref().ok_or_else(|| {
2883                    FittedModelError::MissingField {
2884                        reason:
2885                            "standard link-wiggle model is missing canonical fit_result payload"
2886                                .to_string(),
2887                    }
2888                })?;
2889                if fit.blocks.len() != 2
2890                    || fit.blocks[0].role != BlockRole::Mean
2891                    || fit.blocks[1].role != BlockRole::LinkWiggle
2892                {
2893                    return Err(FittedModelError::SchemaMismatch {
2894                        reason:
2895                            "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2896                                .to_string(),
2897                    });
2898                }
2899                fit.block_by_role(BlockRole::LinkWiggle)
2900                    .ok_or_else(|| FittedModelError::MissingField {
2901                        reason:
2902                            "standard link-wiggle model is missing LinkWiggle coefficient block"
2903                                .to_string(),
2904                    })?
2905                    .beta
2906                    .to_vec()
2907            }
2908            _ => payload
2909                .beta_link_wiggle
2910                .clone()
2911                .ok_or_else(|| FittedModelError::MissingField {
2912                    reason:
2913                        "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2914                            .to_string(),
2915                })?,
2916        };
2917        Ok(Some(SavedLinkWiggleRuntime {
2918            knots,
2919            degree,
2920            beta,
2921        }))
2922    }
2923
2924    pub fn saved_baseline_time_wiggle(
2925        &self,
2926    ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2927        let payload = self.payload();
2928        if payload
2929            .survival_cause_count
2930            .is_some_and(|cause_count| cause_count > 1)
2931            && payload.beta_baseline_timewiggle.is_none()
2932            && payload.beta_baseline_timewiggle_by_cause.is_some()
2933        {
2934            return Err(FittedModelError::SchemaMismatch {
2935                reason:
2936                    "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2937                        .to_string(),
2938            });
2939        }
2940        match (
2941            payload.baseline_timewiggle_knots.as_ref(),
2942            payload.baseline_timewiggle_degree,
2943            payload.baseline_timewiggle_penalty_orders.as_ref(),
2944            payload.baseline_timewiggle_double_penalty,
2945            payload.beta_baseline_timewiggle.as_ref(),
2946        ) {
2947            (None, None, None, None, None) => Ok(None),
2948            (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2949                Ok(Some(SavedBaselineTimeWiggleRuntime {
2950                    knots: knots.clone(),
2951                    degree,
2952                    penalty_orders: penalty_orders.clone(),
2953                    double_penalty,
2954                    beta: beta.clone(),
2955                }))
2956            }
2957            _ => Err(FittedModelError::SchemaMismatch {
2958                reason:
2959                    "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2960                        .to_string(),
2961            }),
2962        }
2963    }
2964
2965    /// Whether this model has a link wiggle component with complete metadata.
2966    #[inline]
2967    pub fn has_link_wiggle(&self) -> bool {
2968        self.saved_link_wiggle()
2969            .map(|runtime| runtime.is_some())
2970            .unwrap_or(false)
2971    }
2972
2973    /// Whether this model has a baseline-time wiggle component with complete metadata.
2974    #[inline]
2975    pub fn has_baseline_time_wiggle(&self) -> bool {
2976        let payload = self.payload();
2977        if payload
2978            .survival_cause_count
2979            .is_some_and(|cause_count| cause_count > 1)
2980        {
2981            return payload.baseline_timewiggle_knots.is_some()
2982                && payload.baseline_timewiggle_degree.is_some()
2983                && payload.baseline_timewiggle_penalty_orders.is_some()
2984                && payload.baseline_timewiggle_double_penalty.is_some()
2985                && payload.beta_baseline_timewiggle_by_cause.is_some();
2986        }
2987        self.saved_baseline_time_wiggle()
2988            .map(|runtime| runtime.is_some())
2989            .unwrap_or(false)
2990    }
2991
2992    /// Whether the default point prediction must integrate the inverse link
2993    /// over the coefficient posterior — reporting the posterior mean
2994    /// `E[g⁻¹(Xβ)]` — rather than plugging in the posterior mode `g⁻¹(Xβ̂)`.
2995    ///
2996    /// SPEC (issue #960): the posterior mean is *always* the default point
2997    /// estimate (never MAP). It is observably distinct from the plug-in exactly
2998    /// when the inverse link is *curved* over the posterior's uncertainty, so
2999    /// `E[g⁻¹(η)] ≠ g⁻¹(E[η])` by Jensen. The curvature-based classification is:
3000    ///   * all log-link families (Poisson / Gamma / Tweedie / NegativeBinomial):
3001    ///     `E[exp η] = exp(η + se²/2) ≠ exp(η)` (log-normal MGF);
3002    ///   * all Binomial links (logit / probit / cloglog / SAS / BetaLogistic /
3003    ///     Mixture / LatentCLogLog): bounded sigmoidal inverse links;
3004    ///   * Beta (logit link): `E[σ(η)] ≠ σ(E[η])`;
3005    ///   * Royston–Parmar (curved survival-probability inverse link).
3006    /// The integral collapses to the plug-in (so the cheaper plug-in path is
3007    /// exact and taken instead) only for the effectively-linear identity-link
3008    /// Gaussian. Any model carrying a link wiggle or baseline-time wiggle is
3009    /// curved regardless of family. This curvature partition mirrors
3010    /// `families::family_runtime::posterior_mean`, the compute path that produces the
3011    /// corrected mean for each of these families.
3012    ///
3013    /// This is the single source of truth shared by the CLI (`gam predict`)
3014    /// and the Python FFI prediction path so the two can never drift on which
3015    /// models receive the posterior-mean correction.
3016    #[inline]
3017    pub fn prediction_uses_posterior_mean(&self) -> bool {
3018        let family = self.likelihood();
3019        let curved_family = match &family.response {
3020            // Identity-link Gaussian: inverse link is linear, so the posterior
3021            // mean equals the plug-in and the cheaper exact path is taken.
3022            ResponseFamily::Gaussian => false,
3023            // Log-link families: E[exp η] = exp(η + se²/2) ≠ exp(η).
3024            ResponseFamily::Poisson
3025            | ResponseFamily::Gamma
3026            | ResponseFamily::Tweedie { .. }
3027            | ResponseFamily::NegativeBinomial { .. } => true,
3028            // Beta (logit link): E[σ(η)] ≠ σ(E[η]).
3029            ResponseFamily::Beta { .. } => true,
3030            // Royston–Parmar: curved survival-probability inverse link.
3031            ResponseFamily::RoystonParmar => true,
3032            // Binomial: every link variant (logit / probit / cloglog / SAS /
3033            // BetaLogistic / Mixture / LatentCLogLog) is a curved sigmoid.
3034            ResponseFamily::Binomial => matches!(
3035                &family.link,
3036                InverseLink::Standard(_)
3037                    | InverseLink::Sas(_)
3038                    | InverseLink::BetaLogistic(_)
3039                    | InverseLink::Mixture(_)
3040                    | InverseLink::LatentCLogLog(_)
3041            ),
3042        };
3043        curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3044    }
3045
3046    pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3047        self.payload().validate_payload_version()?;
3048        if matches!(
3049            self.predict_model_class(),
3050            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3051        ) {
3052            if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3053                runtime.validate_exact_replay_contract().map_err(|err| {
3054                    FittedModelError::PayloadCorrupt {
3055                        reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3056                    }
3057                })?;
3058            }
3059            if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3060                runtime.validate_exact_replay_contract().map_err(|err| {
3061                    FittedModelError::PayloadCorrupt {
3062                        reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3063                    }
3064                })?;
3065            }
3066        }
3067        let runtime = SavedPredictionRuntime {
3068            model_class: self.predict_model_class(),
3069            likelihood: self.likelihood(),
3070            inverse_link: self.resolved_inverse_link()?,
3071            link_wiggle: self.saved_link_wiggle()?,
3072            baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3073            score_warp: self.payload().score_warp_runtime.clone(),
3074            link_deviation: self.payload().link_deviation_runtime.clone(),
3075            latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3076            latent_z_conditional_calibration: self
3077                .payload()
3078                .latent_z_conditional_calibration
3079                .clone(),
3080            influence_absorber_width: self.payload().influence_absorber_width,
3081        };
3082        if matches!(
3083            runtime.model_class,
3084            PredictModelClass::GaussianLocationScale
3085                | PredictModelClass::BinomialLocationScale
3086                | PredictModelClass::DispersionLocationScale
3087        ) {
3088            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3089                FittedModelError::MissingField {
3090                    reason: "location-scale model is missing canonical fit_result payload"
3091                        .to_string(),
3092                }
3093            })?;
3094            validate_location_scale_saved_fit(
3095                fit,
3096                runtime.model_class,
3097                runtime.link_wiggle.as_ref(),
3098            )?;
3099        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3100            && self
3101                .payload()
3102                .survival_likelihood
3103                .as_deref()
3104                .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3105        {
3106            validate_survival_location_scale_saved_fit(
3107                self.payload(),
3108                runtime.link_wiggle.as_ref(),
3109            )?;
3110        } else if matches!(
3111            runtime.model_class,
3112            PredictModelClass::BernoulliMarginalSlope
3113        ) {
3114            let unified =
3115                self.payload()
3116                    .unified
3117                    .as_ref()
3118                    .ok_or_else(|| FittedModelError::MissingField {
3119                        reason: "marginal-slope model is missing unified fit payload; refit"
3120                            .to_string(),
3121                    })?;
3122            validate_marginal_slope_saved_fit(
3123                unified,
3124                runtime.score_warp.as_ref(),
3125                runtime.link_deviation.as_ref(),
3126                "unified",
3127            )?;
3128        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3129            && self
3130                .payload()
3131                .survival_likelihood
3132                .as_deref()
3133                .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3134        {
3135            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3136                FittedModelError::MissingField {
3137                    reason: "survival marginal-slope model is missing canonical fit_result payload"
3138                        .to_string(),
3139                }
3140            })?;
3141            validate_survival_marginal_slope_saved_fit(
3142                fit,
3143                runtime.score_warp.as_ref(),
3144                runtime.link_deviation.as_ref(),
3145                "fit_result",
3146            )?;
3147        }
3148        Ok(runtime)
3149    }
3150
3151    pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3152        let payload = self.payload();
3153        let raw = match &payload.family_state {
3154            FittedFamily::Standard {
3155                likelihood,
3156                sas_state,
3157                ..
3158            } if likelihood.is_binomial_sas() => {
3159                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3160                    reason: "binomial-sas model is missing state in family_state.sas_state"
3161                        .to_string(),
3162                })?
3163            }
3164            FittedFamily::LocationScale {
3165                likelihood,
3166                base_link,
3167            } if likelihood.is_binomial_sas() => match base_link {
3168                Some(InverseLink::Sas(state)) => *state,
3169                _ => {
3170                    return Err(FittedModelError::MissingField {
3171                        reason: "binomial-sas location-scale model is missing SAS base_link state"
3172                            .to_string(),
3173                    });
3174                }
3175            },
3176            _ => return Ok(None),
3177        };
3178        state_from_sasspec(SasLinkSpec {
3179            initial_epsilon: raw.epsilon,
3180            initial_log_delta: raw.log_delta,
3181        })
3182        .map(Some)
3183        .map_err(|e| FittedModelError::PayloadCorrupt {
3184            reason: format!("invalid saved SAS link state: {e}"),
3185        })
3186    }
3187
3188    pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3189        let payload = self.payload();
3190        let raw = match &payload.family_state {
3191            FittedFamily::Standard {
3192                likelihood,
3193                sas_state,
3194                ..
3195            } if likelihood.is_binomial_beta_logistic() => {
3196                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3197                    reason:
3198                        "binomial-beta-logistic model is missing state in family_state.sas_state"
3199                            .to_string(),
3200                })?
3201            }
3202            FittedFamily::LocationScale {
3203                likelihood,
3204                base_link,
3205            } if likelihood.is_binomial_beta_logistic() => match base_link {
3206                Some(InverseLink::BetaLogistic(state)) => *state,
3207                _ => {
3208                    return Err(FittedModelError::MissingField {
3209                        reason:
3210                            "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3211                                .to_string(),
3212                    });
3213                }
3214            },
3215            _ => return Ok(None),
3216        };
3217        state_from_beta_logisticspec(SasLinkSpec {
3218            initial_epsilon: raw.epsilon,
3219            initial_log_delta: raw.log_delta,
3220        })
3221        .map(Some)
3222        .map_err(|e| FittedModelError::PayloadCorrupt {
3223            reason: format!("invalid saved Beta-Logistic link state: {e}"),
3224        })
3225    }
3226
3227    pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3228        let payload = self.payload();
3229        match &payload.family_state {
3230            FittedFamily::Standard {
3231                likelihood,
3232                mixture_state,
3233                ..
3234            } if likelihood.is_binomial_mixture() => mixture_state
3235                .clone()
3236                .ok_or_else(|| FittedModelError::MissingField {
3237                    reason: "binomial-mixture model is missing state in family_state.mixture_state"
3238                        .to_string(),
3239                })
3240                .map(Some),
3241            FittedFamily::LocationScale {
3242                likelihood,
3243                base_link,
3244            } if likelihood.is_binomial_mixture() => match base_link {
3245                Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3246                _ => Err(FittedModelError::MissingField {
3247                    reason:
3248                        "binomial-mixture location-scale model is missing mixture base_link state"
3249                            .to_string(),
3250                }),
3251            },
3252            _ => Ok(None),
3253        }
3254    }
3255
3256    pub fn saved_latent_cloglog_state(
3257        &self,
3258    ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3259        let payload = self.payload();
3260        match &payload.family_state {
3261            FittedFamily::Standard {
3262                likelihood,
3263                latent_cloglog_state,
3264                ..
3265            } if likelihood.is_latent_cloglog() => latent_cloglog_state
3266                .ok_or_else(|| FittedModelError::MissingField {
3267                    reason:
3268                        "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3269                            .to_string(),
3270                })
3271                .map(Some),
3272            _ => Ok(None),
3273        }
3274    }
3275
3276    pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3277        let stateful = if let Some(state) = self.saved_mixture_state()? {
3278            Some(InverseLink::Mixture(state))
3279        } else if let Some(state) = self.saved_latent_cloglog_state()? {
3280            Some(InverseLink::LatentCLogLog(state))
3281        } else if let Some(state) = self.saved_beta_logistic_state()? {
3282            Some(InverseLink::BetaLogistic(state))
3283        } else {
3284            self.saved_sas_state()?.map(InverseLink::Sas)
3285        };
3286        match &self.payload().family_state {
3287            FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3288            FittedFamily::Standard { link, .. } => {
3289                Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3290            }
3291            FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3292            FittedFamily::Survival { .. }
3293            | FittedFamily::LatentSurvival { .. }
3294            | FittedFamily::LatentBinary { .. } => Ok(None),
3295            FittedFamily::TransformationNormal { .. } => Ok(None),
3296        }
3297    }
3298
    /// V∞ §5 coverage floor for the measure-jet extrapolation variance: a
    /// band level "covers" a query once its kernel mass reaches this fraction
    /// of that level's web-averaged support.
    ///
    /// The value is fixed rather than user-tunable ("magic-by-default", no
    /// dial): 0.05 keeps the ε★ gate's bounded discontinuity at ≤ 5 % of the
    /// spectrum's total prior ignorance (see the monotonicity theorem in
    /// `terms/basis/measure_jet_predict.rs`) while still refusing credit
    /// for stray sub-floor kernel mass at levels finer than the first
    /// covering scale.
    const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3308
3309    /// V∞ §5 producer: per-row measure-jet extrapolation variance on the η
3310    /// scale for a prediction batch (`docs/measure_jet_v_infinity.md`).
3311    ///
3312    /// For every frozen measure-jet term in `resolved_termspec` this prices
3313    /// the off-support ignorance of the fitted multiscale spectrum at each
3314    /// query row: support curve from the frozen nodes/masses/band
3315    /// ([`gam_terms::basis::measure_jet_support_curve`]), fitted per-scale
3316    /// amplitudes λ̂_ℓ read from the fit's `lambdas` through the replayed
3317    /// design's penalty layout, folded through
3318    /// [`gam_terms::basis::measure_jet_extrapolation_variance`] and scaled by
3319    /// the fit's coefficient-covariance scale φ̂ so the result sits on Vp's
3320    /// η-variance scale. Terms not yet frozen (no `frozen_quadrature` or
3321    /// non-`UserProvided` centers) are skipped with a warning. Returns
3322    /// `Ok(None)` when no measure-jet term contributes, so callers leave
3323    /// `PredictUncertaintyOptions::extrapolation_variance` untouched.
3324    ///
3325    /// `data` must be the RAW (unclipped) prediction matrix in prediction
3326    /// column order — clipping to the training ranges would freeze the
3327    /// distance signal at the hull and defeat the honesty contract — and
3328    /// `col_map` the prediction header → column map (the same map handed to
3329    /// the design builder). This is the minimal-plumbing producer seam: the
3330    /// option-building callers (CLI predict, FFI) hold exactly
3331    /// `(model, data, col_map)` at the point where they assemble
3332    /// `PredictUncertaintyOptions`, and the fusion in
3333    /// `predict_gamwith_uncertainty` adds the array AFTER its multiplicative
3334    /// inflations: `Var_total = Var_Vp·inflation + Var_extrap`.
3335    pub fn measure_jet_extrapolation_variance(
3336        &self,
3337        data: ndarray::ArrayView2<'_, f64>,
3338        col_map: &HashMap<String, usize>,
3339    ) -> Result<Option<Array1<f64>>, FittedModelError> {
3340        use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
3341        use gam_terms::smooth::build_term_collection_design;
3342        use gam_terms::smooth::SmoothBasisSpec;
3343        let Some(saved_spec) = self.resolved_termspec.as_ref() else {
3344            return Ok(None);
3345        };
3346        if data.nrows() == 0
3347            || !saved_spec
3348                .smooth_terms
3349                .iter()
3350                .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
3351        {
3352            return Ok(None);
3353        }
3354        let fit = self
3355            .fit_result
3356            .as_ref()
3357            .ok_or_else(|| FittedModelError::MissingField {
3358                reason: "measure-jet extrapolation variance requires the canonical \
3359                    fit_result payload; refit"
3360                    .to_string(),
3361            })?;
3362        let spec = crate::survival::predict::resolve_termspec_for_prediction(
3363            &self.resolved_termspec,
3364            self.training_headers.as_ref(),
3365            col_map,
3366            "resolved_termspec",
3367        )
3368        .map_err(|e| FittedModelError::SchemaMismatch {
3369            reason: format!("measure-jet extrapolation variance: {e}"),
3370        })?;
3371        // Penalty layout replay: the global penalty indices (→ `fit.lambdas`)
3372        // come from the SAME design builder the predict pipeline uses. One
3373        // probe row suffices — for a frozen spec the penalty layout is
3374        // row-count-invariant (centers, masses, band, and identifiability
3375        // transforms all replay verbatim) — keeping this O(centers²) instead
3376        // of duplicating the full O(rows·centers) prediction design build.
3377        let probe = data.slice(ndarray::s![0..1, ..]);
3378        let design = build_term_collection_design(probe, &spec).map_err(|e| {
3379            FittedModelError::SchemaMismatch {
3380                reason: format!(
3381                    "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
3382                ),
3383            }
3384        })?;
3385        let lambdas = &fit.lambdas;
3386        // λ̂ are fitted on Frobenius-normalized penalties. The term loop
3387        // unnormalizes them to physical precisions before pricing; multiplying
3388        // by the coefficient-covariance scale puts Var_extrap on the same
3389        // η-variance scale as Vp.
3390        let phi_scale = fit.coefficient_covariance_scale();
3391        let mut total = Array1::<f64>::zeros(data.nrows());
3392        let mut contributed = false;
3393        for term in &spec.smooth_terms {
3394            let SmoothBasisSpec::MeasureJet {
3395                feature_cols,
3396                spec: mj,
3397                input_scales,
3398            } = &term.basis
3399            else {
3400                continue;
3401            };
3402            let (Some(frozen), CenterStrategy::UserProvided(centers)) =
3403                (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
3404            else {
3405                log::warn!(
3406                    "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
3407                    quadrature); skipping its extrapolation variance",
3408                    term.name
3409                );
3410                continue;
3411            };
3412            let n_levels = frozen.eps_band.len();
3413            // λ̂ per level from the replayed layout: per-scale candidates carry
3414            // `PenaltySource::Other("measure_jet_scale_ℓ")`; fused
3415            // (pinned-order) mode carries one Primary charged once for the
3416            // whole band. The DoublePenaltyNullspace ridge is EXCLUDED — it shrinks
3417            // coefficients, it is not a scale amplitude, and counting it would
3418            // double-charge the spectrum.
3419            let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
3420                lambdas
3421                    .get(global_index)
3422                    .copied()
3423                    .ok_or_else(|| FittedModelError::SchemaMismatch {
3424                        reason: format!(
3425                            "measure-jet term '{}': penalty global index {global_index} out \
3426                            of bounds for {} fitted lambdas",
3427                            term.name,
3428                            lambdas.len()
3429                        ),
3430                    })
3431            };
3432            let mut per_scale: Vec<(usize, f64)> = Vec::new();
3433            let mut fused: Option<f64> = None;
3434            for info in &design.penaltyinfo {
3435                if info.termname.as_deref() != Some(term.name.as_str()) {
3436                    continue;
3437                }
3438                match &info.penalty.source {
3439                    PenaltySource::Other(label) => {
3440                        if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
3441                            let level: usize = level_txt.parse().map_err(|_| {
3442                                FittedModelError::SchemaMismatch {
3443                                    reason: format!(
3444                                        "measure-jet term '{}': unparseable penalty label \
3445                                        '{label}'",
3446                                        term.name
3447                                    ),
3448                                }
3449                            })?;
3450                            per_scale.push((level, read_lambda(info.global_index)?));
3451                        }
3452                    }
3453                    PenaltySource::Primary => {
3454                        fused = Some(read_lambda(info.global_index)?);
3455                    }
3456                    _ => {}
3457                }
3458            }
3459            let mut lambda_phys = Vec::with_capacity(n_levels);
3460            let spectrum = if per_scale.is_empty() {
3461                let Some(lam) = fused else {
3462                    log::warn!(
3463                        "measure-jet term '{}' has no fitted amplitude in the penalty \
3464                        layout; skipping its extrapolation variance",
3465                        term.name
3466                    );
3467                    continue;
3468                };
3469                let Some(c) = frozen.fused_penalty_normalization_scale else {
3470                    log::warn!(
3471                        "measure-jet term '{}' is missing the fused penalty normalization scale; \
3472                        skipping its extrapolation variance",
3473                        term.name
3474                    );
3475                    continue;
3476                };
3477                MeasureJetExtrapolationSpectrum::Fused(lam / c)
3478            } else {
3479                per_scale.sort_by_key(|&(level, _)| level);
3480                let levels_complete = per_scale.len() == n_levels
3481                    && per_scale
3482                        .iter()
3483                        .enumerate()
3484                        .all(|(i, &(level, _))| level == i);
3485                if !levels_complete {
3486                    log::warn!(
3487                        "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
3488                        scales; skipping its extrapolation variance",
3489                        term.name,
3490                        per_scale.len(),
3491                        n_levels
3492                    );
3493                    continue;
3494                }
3495                if frozen.penalty_normalization_scales.len() != n_levels {
3496                    log::warn!(
3497                        "measure-jet term '{}': {} frozen penalty normalization scales for {} \
3498                        band scales; skipping its extrapolation variance",
3499                        term.name,
3500                        frozen.penalty_normalization_scales.len(),
3501                        n_levels
3502                    );
3503                    continue;
3504                }
3505                lambda_phys.extend(
3506                    per_scale
3507                        .iter()
3508                        .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
3509                );
3510                MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
3511            };
3512            // Query rows in the frozen geometry's coordinates: select the
3513            // term's axes and replay the per-axis standardization exactly as
3514            // the build dispatch does (divide by σ_a when input_scales is
3515            // Some; the persisted centers are already post-standardization).
3516            let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
3517            for (j, &col) in feature_cols.iter().enumerate() {
3518                if col >= data.ncols() {
3519                    return Err(FittedModelError::SchemaMismatch {
3520                        reason: format!(
3521                            "measure-jet term '{}': prediction column {col} out of bounds \
3522                            for {} data columns",
3523                            term.name,
3524                            data.ncols()
3525                        ),
3526                    });
3527                }
3528                queries.column_mut(j).assign(&data.column(col));
3529            }
3530            if let Some(scales) = input_scales {
3531                if scales.len() != feature_cols.len() {
3532                    return Err(FittedModelError::SchemaMismatch {
3533                        reason: format!(
3534                            "measure-jet term '{}': {} input scales for {} axes",
3535                            term.name,
3536                            scales.len(),
3537                            feature_cols.len()
3538                        ),
3539                    });
3540                }
3541                for (j, &scale) in scales.iter().enumerate() {
3542                    queries.column_mut(j).mapv_inplace(|v| v / scale);
3543                }
3544            }
3545            let support = gam_terms::basis::measure_jet_support_curve(
3546                queries.view(),
3547                centers.view(),
3548                frozen.masses.view(),
3549                &frozen.eps_band,
3550            )
3551            .map_err(|e| FittedModelError::SchemaMismatch {
3552                reason: format!(
3553                    "measure-jet term '{}': support curve failed: {e}",
3554                    term.name
3555                ),
3556            })?;
3557            for i in 0..data.nrows() {
3558                let v = gam_terms::basis::measure_jet_extrapolation_variance(
3559                    support.row(i),
3560                    &frozen.eps_band,
3561                    &frozen.support_means,
3562                    spectrum,
3563                    Self::MEASURE_JET_COVERAGE_FLOOR,
3564                )
3565                .map_err(|e| FittedModelError::SchemaMismatch {
3566                    reason: format!(
3567                        "measure-jet term '{}': extrapolation variance failed: {e}",
3568                        term.name
3569                    ),
3570                })?;
3571                total[i] += phi_scale * v;
3572            }
3573            contributed = true;
3574        }
3575        Ok(contributed.then_some(total))
3576    }
3577
3578    /// Access the unified fit result, if stored.
3579    pub fn unified(&self) -> Option<&UnifiedFitResult> {
3580        self.payload().unified.as_ref()
3581    }
3582
3583    pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3584        let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3585            reason: format!("failed to read model '{}': {e}", path.display()),
3586        })?;
3587        let model: Self =
3588            serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3589                reason: format!("failed to parse model json: {e}"),
3590            })?;
3591        let model = model.with_synchronized_stateful_link_metadata();
3592        model.validate_for_persistence()?;
3593        model.validate_numeric_finiteness()?;
3594        Ok(model)
3595    }
3596
3597    pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3598        let normalized = self.clone().with_synchronized_stateful_link_metadata();
3599        normalized.validate_for_persistence()?;
3600        normalized.validate_numeric_finiteness()?;
3601        // Write to a sibling temp file, fsync, then rename into place so a
3602        // crash mid-write never corrupts the user's existing saved fit.
3603        // Concurrent writers to the same path each have a distinct temp
3604        // suffix (pid + nanos), so neither stomps the other's in-flight
3605        // bytes; the rename winner is last-rename-wins, which is the
3606        // expected last-write-wins semantics for a single canonical path.
3607        let parent = path.parent().unwrap_or_else(|| Path::new("."));
3608        let file_name = path
3609            .file_name()
3610            .and_then(|s| s.to_str())
3611            .unwrap_or("model.json");
3612        let pid = std::process::id();
3613        let nanos = std::time::SystemTime::now()
3614            .duration_since(std::time::UNIX_EPOCH)
3615            .map(|d| d.as_nanos())
3616            .unwrap_or(0);
3617        let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3618        let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3619            reason: format!("failed to write model '{}': {e}", tmp.display()),
3620        })?;
3621        let mut writer = std::io::BufWriter::new(file);
3622        let ser_result = serde_json::to_writer(&mut writer, &normalized);
3623        if let Err(e) = ser_result {
3624            // Best-effort temp cleanup on serialization failure. flush
3625            // returns io::Result<()>; discarding via `.ok()` is enough.
3626            std::io::Write::flush(&mut writer).ok();
3627            drop(writer);
3628            fs::remove_file(&tmp).ok();
3629            return Err(FittedModelError::PayloadCorrupt {
3630                reason: format!("failed to serialize model: {e}"),
3631            });
3632        }
3633        std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3634            reason: format!("failed to write model '{}': {e}", tmp.display()),
3635        })?;
3636        // Recover the underlying File to fsync its contents before rename.
3637        let inner = writer
3638            .into_inner()
3639            .map_err(|e| FittedModelError::PayloadCorrupt {
3640                reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3641            })?;
3642        inner.sync_all().ok();
3643        drop(inner);
3644        if let Err(e) = fs::rename(&tmp, path) {
3645            fs::remove_file(&tmp).ok();
3646            return Err(FittedModelError::PayloadCorrupt {
3647                reason: format!("failed to publish model '{}': {e}", path.display()),
3648            });
3649        }
3650        // fsync the parent directory so the rename itself is durable
3651        // across a crash; without this, the rename can be lost even though
3652        // file contents reached disk. Best-effort on platforms that don't
3653        // support opening a directory for fsync.
3654        if let Ok(d) = fs::File::open(parent) {
3655            d.sync_all().ok();
3656        }
3657        Ok(())
3658    }
3659
3660    pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3661        self.data_schema
3662            .as_ref()
3663            .ok_or_else(|| FittedModelError::MissingField {
3664                reason: "model is missing data_schema; refit".to_string(),
3665            })
3666    }
3667
3668    /// Restore the exact in-memory spline-scan fit from a scan-bearing
3669    /// payload (#1030/#1034). `Ok(None)` for dense models; the returned
3670    /// `predict` replays the training Gaussian bridge bit-for-bit.
3671    pub fn saved_spline_scan(
3672        &self,
3673    ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3674        let Some(saved) = self.spline_scan.as_ref() else {
3675            return Ok(None);
3676        };
3677        let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3678            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3679        Ok(Some((saved.feature_column.as_str(), fit)))
3680    }
3681
3682    /// Restore the in-memory residual-cascade fit from a cascade-bearing
3683    /// payload (#1032). `Ok(None)` for non-cascade models; the returned fit
3684    /// replays the multilevel Wendland-frame posterior for the d ∈ {2, 3}
3685    /// feature columns at each predict point.
3686    pub fn saved_residual_cascade(
3687        &self,
3688    ) -> Result<
3689        Option<(
3690            &[String],
3691            gam_solve::residual_cascade::ResidualCascadeFit,
3692        )>,
3693        FittedModelError,
3694    > {
3695        let Some(saved) = self.residual_cascade.as_ref() else {
3696            return Ok(None);
3697        };
3698        let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3699            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3700        Ok(Some((saved.feature_columns.as_slice(), fit)))
3701    }
3702
3703    pub fn random_effect_group_columns(&self) -> HashSet<String> {
3704        let Some(training_headers) = self.training_headers.as_ref() else {
3705            return HashSet::new();
3706        };
3707        let mut out = HashSet::<String>::new();
3708        for spec in self.saved_term_specs() {
3709            for term in &spec.random_effect_terms {
3710                if let Some(name) = training_headers.get(term.feature_col) {
3711                    out.insert(name.clone());
3712                }
3713            }
3714        }
3715        out
3716    }
3717
3718    pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3719        // Hard version gate. The struct's ~40 Option<T> fields carry
3720        // `#[serde(default)]`, which is by design forward-compatible: old
3721        // payloads missing a new optional field decode with `None`. BUT:
3722        // when a new CLI release adds a required field for some family_state
3723        // (enforced below), an older model loaded by the newer CLI would have
3724        // `None` in that slot and the family-specific branch below would
3725        // correctly reject it — unless the new field also happens to slot
3726        // under a branch that hasn't been touched. Conversely, a newer model
3727        // loaded by an older CLI silently drops fields the older struct
3728        // doesn't know about. Both directions are silent-drift hazards. We
3729        // close them with an exact-version check anchored to the canonical
3730        // MODEL_PAYLOAD_VERSION constant — every payload must round-trip
3731        // identically between writers and readers running the same schema.
3732        self.validate_payload_version()?;
3733        if let Some(scan) = self.spline_scan.as_ref() {
3734            // Spline-scan representation (#1030/#1034): the smoother state IS
3735            // the fit. It is exclusive with the dense representation, only
3736            // standard Gaussian-identity models can carry it, and the state
3737            // must restore cleanly so predict never sees a corrupt snapshot.
3738            if self.fit_result.is_some() || self.unified.is_some() {
3739                return Err(FittedModelError::SchemaMismatch {
3740                    reason: "spline-scan model must not also carry a dense fit_result/unified \
3741                             payload; the representations are mutually exclusive"
3742                        .to_string(),
3743                });
3744            }
3745            if self.model_kind != ModelKind::Standard
3746                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3747            {
3748                return Err(FittedModelError::SchemaMismatch {
3749                    reason: format!(
3750                        "spline-scan representation requires a standard Gaussian-identity model; \
3751                         got model_kind={:?}, likelihood={:?}",
3752                        self.model_kind,
3753                        self.family_state.likelihood()
3754                    ),
3755                });
3756            }
3757            if scan.feature_column.is_empty() {
3758                return Err(FittedModelError::MissingField {
3759                    reason: "spline-scan model is missing its feature column name; refit"
3760                        .to_string(),
3761                });
3762            }
3763            gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3764                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3765            // A scan model carries NO dense design, so the dense-path
3766            // requirements below (resolved_termspec, fit_result finiteness,
3767            // family-specific blocks) do not apply. Enforce only the metadata
3768            // predict actually consumes — the feature column resolves against
3769            // training_headers / data_schema — then accept.
3770            if self.data_schema.is_none() {
3771                return Err(FittedModelError::MissingField {
3772                    reason: "spline-scan model is missing data_schema; refit".to_string(),
3773                });
3774            }
3775            if self.training_headers.is_none() {
3776                return Err(FittedModelError::MissingField {
3777                    reason: "spline-scan model is missing training_headers; refit".to_string(),
3778                });
3779            }
3780            return Ok(());
3781        } else if let Some(cascade) = self.residual_cascade.as_ref() {
3782            // Residual-cascade representation (#1032): a multilevel
3783            // Wendland-frame model for a scattered d ∈ {2,3} Gaussian smooth.
3784            // Exclusive with the dense representation and with the scan.
3785            if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3786                return Err(FittedModelError::SchemaMismatch {
3787                    reason: "residual-cascade model must not also carry spline_scan / \
3788                             fit_result / unified payloads; the representations are \
3789                             mutually exclusive"
3790                        .to_string(),
3791                });
3792            }
3793            if self.model_kind != ModelKind::Standard
3794                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3795            {
3796                return Err(FittedModelError::SchemaMismatch {
3797                    reason: format!(
3798                        "residual-cascade representation requires a standard Gaussian-identity \
3799                         model; got model_kind={:?}, likelihood={:?}",
3800                        self.model_kind,
3801                        self.family_state.likelihood()
3802                    ),
3803                });
3804            }
3805            if cascade.feature_columns.is_empty()
3806                || !(2..=3).contains(&cascade.feature_columns.len())
3807            {
3808                return Err(FittedModelError::MissingField {
3809                    reason: format!(
3810                        "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3811                        cascade.feature_columns.len()
3812                    ),
3813                });
3814            }
3815            gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3816                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3817            if self.data_schema.is_none() {
3818                return Err(FittedModelError::MissingField {
3819                    reason: "residual-cascade model is missing data_schema; refit".to_string(),
3820                });
3821            }
3822            if self.training_headers.is_none() {
3823                return Err(FittedModelError::MissingField {
3824                    reason: "residual-cascade model is missing training_headers; refit".to_string(),
3825                });
3826            }
3827            return Ok(());
3828        } else if self.fit_result.is_none() {
3829            return Err(FittedModelError::MissingField {
3830                reason: "model is missing canonical fit_result payload; refit".to_string(),
3831            });
3832        }
3833        if self.data_schema.is_none() {
3834            return Err(FittedModelError::MissingField {
3835                reason: "model is missing data_schema; refit".to_string(),
3836            });
3837        }
3838        if self.training_headers.is_none() {
3839            return Err(FittedModelError::MissingField {
3840                reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3841                    .to_string(),
3842            });
3843        }
3844        let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3845            FittedModelError::MissingField {
3846                reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3847                    .to_string(),
3848            }
3849        })?;
3850        validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3851
3852        if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3853            return Err(FittedModelError::MissingField {
3854                reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3855                    .to_string(),
3856            });
3857        }
3858        if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3859            validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3860        }
3861        if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3862            let score = self.transformation_score_calibration.ok_or_else(|| {
3863                FittedModelError::MissingField {
3864                    reason: "transformation-normal model is missing transformation_score_calibration; refit"
3865                        .to_string(),
3866                }
3867            })?;
3868            score.validate("transformation-normal model")?;
3869        }
3870        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3871            if self.formula_logslope.is_none() {
3872                return Err(FittedModelError::MissingField {
3873                    reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3874                });
3875            }
3876            if self.z_column.is_none() {
3877                return Err(FittedModelError::MissingField {
3878                    reason: "marginal-slope model is missing z_column; refit".to_string(),
3879                });
3880            }
3881            let z_normalization =
3882                self.latent_z_normalization
3883                    .ok_or_else(|| FittedModelError::MissingField {
3884                        reason: "marginal-slope model is missing latent_z_normalization; refit"
3885                            .to_string(),
3886                    })?;
3887            z_normalization.validate("marginal-slope model")?;
3888            let latent_measure =
3889                self.latent_measure
3890                    .as_ref()
3891                    .ok_or_else(|| FittedModelError::MissingField {
3892                        reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3893                    })?;
3894            latent_measure
3895                .validate("marginal-slope model latent_measure")
3896                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3897            if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3898                return Err(FittedModelError::MissingField {
3899                    reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3900                });
3901            }
3902            if self.resolved_termspec_logslope.as_ref().is_none() {
3903                return Err(FittedModelError::MissingField {
3904                    reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3905                        .to_string(),
3906                });
3907            }
3908            match self.family_state.frailty() {
3909                Some(FrailtySpec::None)
3910                | Some(FrailtySpec::GaussianShift {
3911                    sigma_fixed: Some(_),
3912                }) => {}
3913                Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3914                    return Err(FittedModelError::IncompatibleConfig {
3915                        reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3916                            .to_string(),
3917                    });
3918                }
3919                Some(FrailtySpec::HazardMultiplier { .. }) => {
3920                    return Err(FittedModelError::IncompatibleConfig {
3921                        reason: "marginal-slope model does not support HazardMultiplier frailty"
3922                            .to_string(),
3923                    });
3924                }
3925                None => {
3926                    return Err(FittedModelError::MissingField {
3927                        reason: "marginal-slope model is missing family_state.frailty; refit"
3928                            .to_string(),
3929                    });
3930                }
3931            }
3932        }
3933
3934        if let FittedFamily::Survival {
3935            survival_likelihood,
3936            frailty,
3937            ..
3938        } = &self.family_state
3939        {
3940            if matches!(
3941                survival_likelihood.as_deref(),
3942                Some("latent") | Some("latent-binary")
3943            ) {
3944                return Err(FittedModelError::SchemaMismatch {
3945                    reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3946                        .to_string(),
3947                });
3948            }
3949            if survival_likelihood.as_deref() == Some("marginal-slope") {
3950                if self.formula_logslope.is_none() {
3951                    return Err(FittedModelError::MissingField {
3952                        reason: "survival marginal-slope model is missing formula_logslope; refit"
3953                            .to_string(),
3954                    });
3955                }
3956                if self.z_column.is_none() {
3957                    return Err(FittedModelError::MissingField {
3958                        reason: "survival marginal-slope model is missing z_column; refit"
3959                            .to_string(),
3960                    });
3961                }
3962                let z_normalization =
3963                    self.latent_z_normalization
3964                        .ok_or_else(|| {
3965                            FittedModelError::MissingField {
3966                        reason:
3967                            "survival marginal-slope model is missing latent_z_normalization; refit"
3968                                .to_string(),
3969                    }
3970                        })?;
3971                z_normalization.validate("survival marginal-slope model")?;
3972                let latent_measure =
3973                    self.latent_measure
3974                        .as_ref()
3975                        .ok_or_else(|| FittedModelError::MissingField {
3976                            reason:
3977                                "survival marginal-slope model is missing latent_measure; refit"
3978                                    .to_string(),
3979                        })?;
3980                latent_measure
3981                    .validate("survival marginal-slope model latent_measure")
3982                    .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3983                if self.logslope_baseline.is_none() {
3984                    return Err(FittedModelError::MissingField {
3985                        reason: "survival marginal-slope model is missing logslope_baseline; refit"
3986                            .to_string(),
3987                    });
3988                }
3989                if self.resolved_termspec_logslope.as_ref().is_none() {
3990                    return Err(FittedModelError::MissingField {
3991                        reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3992                            .to_string(),
3993                    });
3994                }
3995                match frailty {
3996                    FrailtySpec::None
3997                    | FrailtySpec::GaussianShift {
3998                        sigma_fixed: Some(_),
3999                    } => {}
4000                    FrailtySpec::GaussianShift { sigma_fixed: None } => {
4001                        return Err(FittedModelError::IncompatibleConfig {
4002                            reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4003                                .to_string(),
4004                        });
4005                    }
4006                    FrailtySpec::HazardMultiplier { .. } => {
4007                        return Err(FittedModelError::IncompatibleConfig {
4008                            reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4009                                .to_string(),
4010                        });
4011                    }
4012                }
4013            } else if !matches!(frailty, FrailtySpec::None) {
4014                return Err(FittedModelError::IncompatibleConfig {
4015                    reason:
4016                        "non-marginal survival models do not currently persist a frailty modifier"
4017                            .to_string(),
4018                });
4019            }
4020            // Non-latent survival predict reconstructs the baseline-time
4021            // basis via `load_survival_time_basis_config_from_model` and
4022            // anchors that basis at `survival_time_anchor`; both are
4023            // required for the saved model to be loadable. The CLI's
4024            // marginal-slope+time-wiggle save path previously dropped one or
4025            // the other on partial-write, producing models that loaded but
4026            // would panic at the first predict. Enforce both before persisting.
4027            if self.survival_time_basis.is_none() {
4028                return Err(FittedModelError::MissingField {
4029                    reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4030                });
4031            }
4032            if self.survival_time_anchor.is_none() {
4033                return Err(FittedModelError::MissingField {
4034                    reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4035                });
4036            }
4037        }
4038        if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4039            match frailty {
4040                FrailtySpec::HazardMultiplier {
4041                    sigma_fixed: Some(_),
4042                    ..
4043                } => {}
4044                FrailtySpec::HazardMultiplier {
4045                    sigma_fixed: None, ..
4046                } => {
4047                    return Err(FittedModelError::IncompatibleConfig {
4048                        reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4049                            .to_string(),
4050                    });
4051                }
4052                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4053                    return Err(FittedModelError::IncompatibleConfig {
4054                        reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4055                            .to_string(),
4056                    });
4057                }
4058            }
4059            if self.survival_likelihood.as_deref() != Some("latent") {
4060                return Err(FittedModelError::SchemaMismatch {
4061                    reason: "latent survival model must persist survival_likelihood=latent"
4062                        .to_string(),
4063                });
4064            }
4065        }
4066        if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4067            match frailty {
4068                FrailtySpec::HazardMultiplier {
4069                    sigma_fixed: Some(_),
4070                    ..
4071                } => {}
4072                FrailtySpec::HazardMultiplier {
4073                    sigma_fixed: None, ..
4074                } => {
4075                    return Err(FittedModelError::IncompatibleConfig {
4076                        reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4077                            .to_string(),
4078                    });
4079                }
4080                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4081                    return Err(FittedModelError::IncompatibleConfig {
4082                        reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4083                            .to_string(),
4084                    });
4085                }
4086            }
4087            if self.survival_likelihood.as_deref() != Some("latent-binary") {
4088                return Err(FittedModelError::SchemaMismatch {
4089                    reason: "latent binary model must persist survival_likelihood=latent-binary"
4090                        .to_string(),
4091                });
4092            }
4093        }
4094
4095        let family_likelihood = match &self.family_state {
4096            FittedFamily::Standard { likelihood, .. }
4097            | FittedFamily::LocationScale { likelihood, .. }
4098            | FittedFamily::MarginalSlope { likelihood, .. }
4099            | FittedFamily::Survival { likelihood, .. }
4100            | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4101            FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4102        };
4103        let is_standard_or_location_scale = matches!(
4104            self.family_state,
4105            FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4106        );
4107        if is_standard_or_location_scale
4108            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4109        {
4110            self.saved_sas_state()?;
4111        }
4112        if is_standard_or_location_scale
4113            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4114        {
4115            self.saved_beta_logistic_state()?;
4116        }
4117        if is_standard_or_location_scale
4118            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4119        {
4120            self.saved_mixture_state()?;
4121        }
4122        if matches!(self.family_state, FittedFamily::Standard { .. })
4123            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4124        {
4125            self.saved_latent_cloglog_state()?;
4126        }
4127        if matches!(self.family_state, FittedFamily::LocationScale { .. })
4128            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4129        {
4130            return Err(FittedModelError::IncompatibleConfig {
4131                reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4132                    .to_string(),
4133            });
4134        }
4135        if matches!(self.family_state, FittedFamily::Survival { .. })
4136            && self.survival_likelihood.is_none()
4137        {
4138            return Err(FittedModelError::MissingField {
4139                reason: "saved survival model is missing survival_likelihood metadata; refit"
4140                    .to_string(),
4141            });
4142        }
4143        let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4144            || self.linkwiggle_degree.is_some()
4145            || self.beta_link_wiggle.is_some()
4146            || self
4147                .fit_result
4148                .as_ref()
4149                .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4150                .is_some();
4151        let saved_link_wiggle = self.saved_link_wiggle()?;
4152        if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4153            return Err(FittedModelError::SchemaMismatch {
4154                reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4155                    .to_string(),
4156            });
4157        }
4158        let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4159            || self.baseline_timewiggle_degree.is_some()
4160            || self.baseline_timewiggle_penalty_orders.is_some()
4161            || self.baseline_timewiggle_double_penalty.is_some()
4162            || self.beta_baseline_timewiggle.is_some()
4163            || self.beta_baseline_timewiggle_by_cause.is_some();
4164        let is_joint_cause_specific = self
4165            .survival_cause_count
4166            .is_some_and(|cause_count| cause_count > 1);
4167        if has_any_saved_baseline_time_wiggle {
4168            if is_joint_cause_specific {
4169                let complete = self.baseline_timewiggle_knots.is_some()
4170                    && self.baseline_timewiggle_degree.is_some()
4171                    && self.baseline_timewiggle_penalty_orders.is_some()
4172                    && self.baseline_timewiggle_double_penalty.is_some()
4173                    && self.beta_baseline_timewiggle_by_cause.is_some();
4174                if !complete {
4175                    return Err(FittedModelError::SchemaMismatch {
4176                        reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4177                            .to_string(),
4178                    });
4179                }
4180            } else if self.saved_baseline_time_wiggle()?.is_none() {
4181                return Err(FittedModelError::SchemaMismatch {
4182                    reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4183                        .to_string(),
4184                });
4185            }
4186        }
4187        if self
4188            .survival_likelihood
4189            .as_deref()
4190            .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4191        {
4192            validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4193        }
4194
4195        // Validate anchored-deviation replay contracts at LOAD/SAVE time rather
4196        // than waiting for first predict call. Previously these contracts
4197        // (span table dimensions, coefficient matrices, etc.) were only
4198        // asserted inside `saved_prediction_runtime`, which runs on the first
4199        // predict invocation. A corrupted runtime would therefore pass
4200        // `load_from_path` silently and fail later under a different error
4201        // surface. Enforcing the same check here makes the model self-
4202        // diagnostic: `gam fit` catches its own bad output at save, and
4203        // `gam predict` catches bad input at load rather than mid-pipeline.
4204        if let Some(runtime) = self.score_warp_runtime.as_ref() {
4205            runtime.validate_exact_replay_contract().map_err(|err| {
4206                FittedModelError::PayloadCorrupt {
4207                    reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4208                }
4209            })?;
4210        }
4211        if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4212            runtime.validate_exact_replay_contract().map_err(|err| {
4213                FittedModelError::PayloadCorrupt {
4214                    reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4215                }
4216            })?;
4217        }
4218        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4219            validate_marginal_slope_saved_fit(
4220                self.fit_result.as_ref().expect("checked above"),
4221                self.score_warp_runtime.as_ref(),
4222                self.link_deviation_runtime.as_ref(),
4223                "fit_result",
4224            )?;
4225            let unified = self
4226                .unified
4227                .as_ref()
4228                .ok_or_else(|| FittedModelError::MissingField {
4229                    reason: "marginal-slope model is missing unified fit payload; refit"
4230                        .to_string(),
4231                })?;
4232            validate_marginal_slope_saved_fit(
4233                unified,
4234                self.score_warp_runtime.as_ref(),
4235                self.link_deviation_runtime.as_ref(),
4236                "unified",
4237            )?;
4238        }
4239        if self
4240            .survival_likelihood
4241            .as_deref()
4242            .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4243        {
4244            validate_survival_marginal_slope_saved_fit(
4245                self.fit_result.as_ref().expect("checked above"),
4246                self.score_warp_runtime.as_ref(),
4247                self.link_deviation_runtime.as_ref(),
4248                "fit_result",
4249            )?;
4250            if let Some(unified) = self.unified.as_ref() {
4251                validate_survival_marginal_slope_saved_fit(
4252                    unified,
4253                    self.score_warp_runtime.as_ref(),
4254                    self.link_deviation_runtime.as_ref(),
4255                    "unified",
4256                )?;
4257            }
4258        }
4259
4260        // Posterior-mean / uncertainty backends are validated at predict time
4261        // by `prediction_backend_from_model`, which has access to the actual
4262        // requested mode and emits the canonical "nonlinear posterior-mean
4263        // prediction requires either covariance or a saved penalized Hessian"
4264        // error.  Save-time we deliberately do NOT enforce that gate: a fit
4265        // produced for MAP / plug-in scoring can be persisted and replayed
4266        // without ever needing a covariance backend, and gating it here would
4267        // refuse legitimate MAP-only saves whose `UnifiedFitResult` carries
4268        // beta + lambdas without a stabilized Hessian.
4269
4270        Ok(())
4271    }
4272
4273    pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4274        let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4275        if let Some(fit) = self.fit_result.as_ref() {
4276            fit.validate_numeric_finiteness()
4277                .map_err(|e| corrupt(e.to_string()))?;
4278        }
4279
4280        for (name, opt) in [
4281            ("survival_baseline_scale", self.survival_baseline_scale),
4282            ("survival_baseline_shape", self.survival_baseline_shape),
4283            ("survival_baseline_rate", self.survival_baseline_rate),
4284            ("survival_baseline_makeham", self.survival_baseline_makeham),
4285            (
4286                "survival_time_smooth_lambda",
4287                self.survival_time_smooth_lambda,
4288            ),
4289            ("survival_time_anchor", self.survival_time_anchor),
4290            ("survivalridge_lambda", self.survivalridge_lambda),
4291        ] {
4292            if let Some(v) = opt {
4293                ensure_finite_scalar(name, v).map_err(corrupt)?;
4294            }
4295        }
4296
4297        if let Some(v) = self.beta_noise.as_ref() {
4298            validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4299        }
4300        if let Some(v) = self.noise_projection.as_ref() {
4301            validate_all_finite("noise_projection", v.iter().flatten().copied())
4302                .map_err(corrupt)?;
4303            if self.noise_projection_ridge_alpha.is_none() {
4304                return Err(FittedModelError::MissingField {
4305                    reason:
4306                        "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4307                            .to_string(),
4308                });
4309            }
4310        }
4311        if let Some(v) = self.noise_center.as_ref() {
4312            validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4313        }
4314        if let Some(v) = self.noise_scale.as_ref() {
4315            validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4316        }
4317        if let Some(v) = self.noise_projection_ridge_alpha {
4318            ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4319            if v < 0.0 {
4320                return Err(FittedModelError::InvalidInput {
4321                    reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4322                });
4323            }
4324        }
4325        if let Some(v) = self.gaussian_response_scale {
4326            ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4327        }
4328        if let Some(v) = self.beta_link_wiggle.as_ref() {
4329            validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4330        }
4331        if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4332            validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4333        }
4334        if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4335            validate_all_finite(
4336                "beta_baseline_timewiggle_by_cause",
4337                v.iter().flatten().copied(),
4338            )
4339            .map_err(corrupt)?;
4340        }
4341        if let Some(v) = self.latent_z_normalization {
4342            v.validate("latent_z_normalization")?;
4343        }
4344        if let Some(v) = self.latent_measure.as_ref() {
4345            v.validate("latent_measure").map_err(corrupt)?;
4346        }
4347        if let Some(v) = self.survival_beta_time.as_ref() {
4348            validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4349        }
4350        if let Some(v) = self.survival_beta_threshold.as_ref() {
4351            validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4352        }
4353        if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4354            validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4355        }
4356        if let Some(v) = self.survival_noise_projection.as_ref() {
4357            validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4358                .map_err(corrupt)?;
4359            if self.survival_noise_projection_ridge_alpha.is_none() {
4360                return Err(FittedModelError::MissingField {
4361                    reason:
4362                        "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4363                            .to_string(),
4364                });
4365            }
4366        }
4367        if let Some(v) = self.survival_noise_center.as_ref() {
4368            validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4369        }
4370        if let Some(v) = self.survival_noise_projection_ridge_alpha {
4371            ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4372            if v < 0.0 {
4373                return Err(FittedModelError::InvalidInput {
4374                    reason: format!(
4375                        "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4376                    ),
4377                });
4378            }
4379        }
4380        if let Some(v) = self.survival_noise_scale.as_ref() {
4381            validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4382        }
4383        if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4384            validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4385                .map_err(corrupt)?;
4386        }
4387        if let Some(v) = self.sas_param_covariance.as_ref() {
4388            validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4389                .map_err(corrupt)?;
4390        }
4391        Ok(())
4392    }
4393}
4394
4395fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4396    a.rows().into_iter().map(|row| row.to_vec()).collect()
4397}
4398
4399use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4400
4401fn validate_frozen_term_collectionspec(
4402    spec: &TermCollectionSpec,
4403    label: &str,
4404) -> Result<(), FittedModelError> {
4405    spec.validate_frozen(label)
4406        .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4407}
4408
4409impl Deref for FittedModel {
4410    type Target = FittedModelPayload;
4411
4412    fn deref(&self) -> &Self::Target {
4413        self.payload()
4414    }
4415}
4416
4417impl DerefMut for FittedModel {
4418    fn deref_mut(&mut self) -> &mut Self::Target {
4419        self.payload_mut()
4420    }
4421}
4422
4423// ---------------------------------------------------------------------------
4424// Reconstruct library types from saved models
4425// ---------------------------------------------------------------------------
4426
4427pub fn survival_baseline_config_from_model(
4428    model: &FittedModel,
4429) -> Result<SurvivalBaselineConfig, FittedModelError> {
4430    let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4431        FittedModelError::MissingField {
4432            reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4433        }
4434    })?;
4435    parse_survival_baseline_config(
4436        target,
4437        model.survival_baseline_scale,
4438        model.survival_baseline_shape,
4439        model.survival_baseline_rate,
4440        model.survival_baseline_makeham,
4441    )
4442    .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4443}
4444
4445pub fn load_survival_time_basis_config_from_model(
4446    model: &FittedModel,
4447) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4448    match model
4449        .survival_time_basis
4450        .as_deref()
4451        .ok_or_else(|| FittedModelError::MissingField {
4452            reason: "saved survival model missing survival_time_basis".to_string(),
4453        })?
4454        .to_ascii_lowercase()
4455        .as_str()
4456    {
4457        "none" => Ok(SurvivalTimeBasisConfig::None),
4458        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4459        "bspline" => {
4460            let degree =
4461                model
4462                    .survival_time_degree
4463                    .ok_or_else(|| FittedModelError::MissingField {
4464                        reason: "saved survival bspline model missing survival_time_degree"
4465                            .to_string(),
4466                    })?;
4467            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4468                FittedModelError::MissingField {
4469                    reason: "saved survival bspline model missing survival_time_knots".to_string(),
4470                }
4471            })?;
4472            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4473            if degree < 1 || knots.is_empty() {
4474                return Err(FittedModelError::SchemaMismatch {
4475                    reason: "saved survival bspline time basis metadata is invalid".to_string(),
4476                });
4477            }
4478            Ok(SurvivalTimeBasisConfig::BSpline {
4479                degree,
4480                knots: Array1::from_vec(knots),
4481                smooth_lambda,
4482            })
4483        }
4484        "ispline" => {
4485            let degree =
4486                model
4487                    .survival_time_degree
4488                    .ok_or_else(|| FittedModelError::MissingField {
4489                        reason: "saved survival ispline model missing survival_time_degree"
4490                            .to_string(),
4491                    })?;
4492            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4493                FittedModelError::MissingField {
4494                    reason: "saved survival ispline model missing survival_time_knots".to_string(),
4495                }
4496            })?;
4497            let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4498                FittedModelError::MissingField {
4499                    reason: "saved survival ispline model missing survival_time_keep_cols"
4500                        .to_string(),
4501                }
4502            })?;
4503            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4504            if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4505                return Err(FittedModelError::SchemaMismatch {
4506                    reason: "saved survival ispline time basis metadata is invalid".to_string(),
4507                });
4508            }
4509            Ok(SurvivalTimeBasisConfig::ISpline {
4510                degree,
4511                knots: Array1::from_vec(knots),
4512                keep_cols,
4513                smooth_lambda,
4514            })
4515        }
4516        other => Err(FittedModelError::IncompatibleConfig {
4517            reason: format!("unsupported saved survival_time_basis '{other}'"),
4518        }),
4519    }
4520}
4521
4522#[cfg(test)]
4523mod tests {
4524    use super::*;
4525    use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4526    use crate::survival::lognormal_kernel::FrailtySpec;
4527    use gam_solve::pirls::PirlsStatus;
4528    use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4529    use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4530    use gam_data::SchemaColumn;
4531    use ndarray::{Array1, Array2, array};
4532
4533    fn empty_termspec() -> TermCollectionSpec {
4534        TermCollectionSpec {
4535            linear_terms: vec![],
4536            random_effect_terms: vec![],
4537            smooth_terms: vec![],
4538        }
4539    }
4540
4541    /// #1030/#1034: a scan-bearing payload must round-trip through JSON +
4542    /// `validate_for_persistence` and replay the training Gaussian bridge
4543    /// bit-for-bit; structural corruption must fail loudly at validation.
4544    #[test]
4545    fn spline_scan_payload_round_trips_and_validates() {
4546        let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4547        let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4548        let w = vec![1.0_f64; x.len()];
4549        let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4550        let make_payload = || {
4551            crate::inference::model_payload_builders::assemble_spline_scan_payload(
4552                "y ~ s(x)".to_string(),
4553                "x".to_string(),
4554                &fit,
4555                DataSchema {
4556                    columns: vec![
4557                        SchemaColumn {
4558                            name: "y".to_string(),
4559                            kind: ColumnKindTag::Continuous,
4560                            levels: vec![],
4561                        },
4562                        SchemaColumn {
4563                            name: "x".to_string(),
4564                            kind: ColumnKindTag::Continuous,
4565                            levels: vec![],
4566                        },
4567                    ],
4568                },
4569                vec!["x".to_string()],
4570                vec![(0.0, 1.0)],
4571            )
4572        };
4573        // The on-disk form is the FittedModel tagged enum; validation and the
4574        // scan accessor live on FittedModel (Deref only goes Model -> Payload).
4575        let model = FittedModel::from_payload(make_payload());
4576        model
4577            .validate_for_persistence()
4578            .expect("scan model validates");
4579        model
4580            .validate_numeric_finiteness()
4581            .expect("scan model is finite");
4582
4583        let json = serde_json::to_string(&model).expect("serialize model");
4584        let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4585        restored
4586            .validate_for_persistence()
4587            .expect("restored scan model validates");
4588        let (column, replay) = restored
4589            .saved_spline_scan()
4590            .expect("restore scan fit")
4591            .expect("payload carries the scan representation");
4592        assert_eq!(column, "x");
4593        for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4594            let (m0, v0) = fit.predict(xq).expect("predict original");
4595            let (m1, v1) = replay.predict(xq).expect("predict replayed");
4596            assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4597            assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4598        }
4599
4600        // A dense model without the scan channel still requires fit_result.
4601        let mut dense = make_payload();
4602        dense.spline_scan = None;
4603        let err = FittedModel::from_payload(dense)
4604            .validate_for_persistence()
4605            .expect_err("dense payload without fit_result must be rejected");
4606        assert!(err.to_string().contains("fit_result"));
4607
4608        // Structural corruption fails at validation, not inside predict.
4609        let mut corrupt = make_payload();
4610        corrupt
4611            .spline_scan
4612            .as_mut()
4613            .expect("scan channel present")
4614            .state
4615            .knots
4616            .truncate(2);
4617        FittedModel::from_payload(corrupt)
4618            .validate_for_persistence()
4619            .expect_err("corrupt scan state must be rejected");
4620        let mut unnamed = make_payload();
4621        unnamed
4622            .spline_scan
4623            .as_mut()
4624            .expect("scan channel present")
4625            .feature_column
4626            .clear();
4627        FittedModel::from_payload(unnamed)
4628            .validate_for_persistence()
4629            .expect_err("missing feature column must be rejected");
4630    }
4631
4632    fn standard_gaussian_payload() -> FittedModelPayload {
4633        FittedModelPayload::new(
4634            MODEL_PAYLOAD_VERSION,
4635            "y ~ 1".to_string(),
4636            ModelKind::Standard,
4637            FittedFamily::Standard {
4638                likelihood: LikelihoodSpec::gaussian_identity(),
4639                link: Some(StandardLink::Identity),
4640                latent_cloglog_state: None,
4641                mixture_state: None,
4642                sas_state: None,
4643            },
4644            "gaussian".to_string(),
4645        )
4646    }
4647
4648    fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4649        SavedCompiledFlexBlock {
4650            kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4651            breakpoints: vec![-1.0, 1.0],
4652            basis_dim,
4653            span_c0: vec![vec![0.0; basis_dim]],
4654            span_c1: vec![vec![0.0; basis_dim]],
4655            span_c2: vec![vec![0.0; basis_dim]],
4656            span_c3: vec![vec![0.0; basis_dim]],
4657            anchor_correction: None,
4658            anchor_components: Vec::new(),
4659        }
4660    }
4661
4662    fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4663        let beta = Array1::from_vec(
4664            blocks
4665                .iter()
4666                .flat_map(|block| block.beta.iter().copied())
4667                .collect(),
4668        );
4669        let p = beta.len();
4670        UnifiedFitResult {
4671            blocks,
4672            log_lambdas: Array1::zeros(0),
4673            lambdas: Array1::zeros(0),
4674            likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4675            likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4676            log_likelihood_normalization: LogLikelihoodNormalization::Full,
4677            log_likelihood: 0.0,
4678            deviance: 0.0,
4679            reml_score: 0.0,
4680            stable_penalty_term: 0.0,
4681            penalized_objective: 0.0,
4682            used_device: false,
4683            outer_iterations: 0,
4684            outer_cost_evals: 0,
4685            inner_pirls_solves: 0,
4686            outer_converged: true,
4687            outer_gradient_norm: None,
4688            standard_deviation: 1.0,
4689            covariance_conditional: Some(Array2::zeros((p, p))),
4690            covariance_corrected: Some(Array2::zeros((p, p))),
4691            inference: None,
4692            fitted_link: FittedLinkState::Standard(None),
4693            geometry: None,
4694            block_states: vec![],
4695            beta,
4696            pirls_status: PirlsStatus::Converged,
4697            max_abs_eta: 0.0,
4698            constraint_kkt: None,
4699            artifacts: FitArtifacts {
4700                pirls: None,
4701                null_space_logdet: None,
4702                null_space_dim: None,
4703                survival_link_wiggle_knots: None,
4704                survival_link_wiggle_degree: None,
4705                criterion_certificate: None,
4706                rho_posterior_certificate: None,
4707                rho_posterior_escalation: None,
4708                rho_covariance: None,
4709            },
4710            inner_cycles: 0,
4711        }
4712    }
4713
4714    fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4715        let mut payload = FittedModelPayload::new(
4716            version,
4717            "y ~ 1".to_string(),
4718            ModelKind::MarginalSlope,
4719            FittedFamily::MarginalSlope {
4720                likelihood: LikelihoodSpec::binomial_probit(),
4721                base_link: InverseLink::Standard(StandardLink::Probit),
4722                frailty: FrailtySpec::None,
4723            },
4724            "bernoulli-marginal-slope".to_string(),
4725        );
4726        payload.fit_result = Some(fit.clone());
4727        payload.unified = Some(fit);
4728        payload.data_schema = Some(DataSchema {
4729            columns: vec![SchemaColumn {
4730                name: "z".to_string(),
4731                kind: ColumnKindTag::Continuous,
4732                levels: vec![],
4733            }],
4734        });
4735        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4736        payload.resolved_termspec = Some(empty_termspec());
4737        payload.resolved_termspec_logslope = Some(empty_termspec());
4738        payload.formula_logslope = Some("1".to_string());
4739        payload.z_column = Some("z".to_string());
4740        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4741        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4742        payload.marginal_baseline = Some(0.0);
4743        payload.logslope_baseline = Some(0.0);
4744        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4745        payload
4746    }
4747
4748    #[test]
4749    fn from_payload_synchronizes_used_device_from_saved_fit() {
4750        let mut fit = saved_fit(vec![
4751            FittedBlock {
4752                beta: Array1::from_vec(vec![0.25]),
4753                role: BlockRole::Mean,
4754                edf: 1.0,
4755                lambdas: Array1::zeros(0),
4756            },
4757            FittedBlock {
4758                beta: Array1::from_vec(vec![0.5]),
4759                role: BlockRole::Scale,
4760                edf: 1.0,
4761                lambdas: Array1::zeros(0),
4762            },
4763        ]);
4764        fit.used_device = true;
4765        let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4766        payload.used_device = false;
4767
4768        let model = FittedModel::from_payload(payload);
4769
4770        assert!(model.payload().used_device);
4771    }
4772
4773    fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4774        let mut payload = FittedModelPayload::new(
4775            version,
4776            "Surv(entry, exit, event) ~ 1".to_string(),
4777            ModelKind::Survival,
4778            FittedFamily::Survival {
4779                likelihood: LikelihoodSpec::royston_parmar(),
4780                survival_likelihood: Some("marginal-slope".to_string()),
4781                survival_distribution: Some(ResidualDistribution::Gaussian),
4782                frailty: FrailtySpec::None,
4783            },
4784            "survival".to_string(),
4785        );
4786        payload.fit_result = Some(fit.clone());
4787        payload.unified = Some(fit);
4788        payload.survival_likelihood = Some("marginal-slope".to_string());
4789        payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4790        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4791        payload.data_schema = Some(DataSchema {
4792            columns: vec![SchemaColumn {
4793                name: "z".to_string(),
4794                kind: ColumnKindTag::Continuous,
4795                levels: vec![],
4796            }],
4797        });
4798        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4799        payload.resolved_termspec = Some(empty_termspec());
4800        payload.resolved_termspec_logslope = Some(empty_termspec());
4801        payload.formula_logslope = Some("1".to_string());
4802        payload.z_column = Some("z".to_string());
4803        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4804        payload.logslope_baseline = Some(0.0);
4805        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4806        payload
4807    }
4808
4809    fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4810        // A #913 genuine-dispersion location-scale model: Gamma mean family with
4811        // a log-precision `noise_formula` channel. Its likelihood response is
4812        // non-Gaussian and non-Binomial, so the predict-path classifier must
4813        // route it to `DispersionLocationScale`, NOT the binomial threshold-scale
4814        // class (issue #1064).
4815        let mut payload = FittedModelPayload::new(
4816            MODEL_PAYLOAD_VERSION,
4817            "y ~ x".to_string(),
4818            ModelKind::LocationScale,
4819            FittedFamily::LocationScale {
4820                likelihood: LikelihoodSpec::gamma_log(),
4821                base_link: Some(InverseLink::Standard(StandardLink::Log)),
4822            },
4823            "gamma-location-scale".to_string(),
4824        );
4825        payload.data_schema = Some(DataSchema {
4826            columns: vec![
4827                SchemaColumn {
4828                    name: "y".to_string(),
4829                    kind: ColumnKindTag::Continuous,
4830                    levels: vec![],
4831                },
4832                SchemaColumn {
4833                    name: "x".to_string(),
4834                    kind: ColumnKindTag::Continuous,
4835                    levels: vec![],
4836                },
4837            ],
4838        });
4839        payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4840        payload.resolved_termspec = Some(empty_termspec());
4841        payload.resolved_termspec_noise = Some(empty_termspec());
4842        payload.formula_noise = Some("x".to_string());
4843        payload.beta_noise = Some(vec![0.0]);
4844        payload.link = Some(InverseLink::Standard(StandardLink::Log));
4845        payload
4846    }
4847
4848    /// #1064 regression: a dispersion location-scale (#913) payload must be
4849    /// classified as `DispersionLocationScale` at every predict-path entry —
4850    /// both `from_payload` (load) and `predict_model_class` (runtime) — and never
4851    /// fall through to the binomial threshold-scale class. Before the fix the
4852    /// non-Gaussian `else` arm mis-routed every dispersion model to
4853    /// `BinomialLocationScale`, predicting the wrong family/link.
4854    #[test]
4855    fn dispersion_location_scale_payload_is_not_classified_binomial() {
4856        let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4857        assert_eq!(
4858            model.predict_model_class(),
4859            PredictModelClass::DispersionLocationScale,
4860            "Gamma dispersion location-scale must route through the dispersion \
4861             predictor, not the binomial threshold-scale class",
4862        );
4863        assert!(
4864            !matches!(
4865                model.predict_model_class(),
4866                PredictModelClass::BinomialLocationScale
4867            ),
4868            "dispersion location-scale must never be classified as binomial",
4869        );
4870
4871        // Each of the four #913 dispersion mean families classifies the same way.
4872        for likelihood in [
4873            LikelihoodSpec::gamma_log(),
4874            LikelihoodSpec::new(
4875                ResponseFamily::NegativeBinomial {
4876                    theta: 1.0,
4877                    theta_fixed: false,
4878                },
4879                InverseLink::Standard(StandardLink::Log),
4880            ),
4881            LikelihoodSpec::new(
4882                ResponseFamily::Beta { phi: 1.0 },
4883                InverseLink::Standard(StandardLink::Logit),
4884            ),
4885            LikelihoodSpec::new(
4886                ResponseFamily::Tweedie { p: 1.5 },
4887                InverseLink::Standard(StandardLink::Log),
4888            ),
4889        ] {
4890            let mut payload = gamma_dispersion_location_scale_payload();
4891            payload.family_state = FittedFamily::LocationScale {
4892                base_link: Some(likelihood.link.clone()),
4893                likelihood: likelihood.clone(),
4894            };
4895            let model = FittedModel::from_payload(payload);
4896            assert_eq!(
4897                model.predict_model_class(),
4898                PredictModelClass::DispersionLocationScale,
4899                "dispersion family {:?} mis-classified",
4900                likelihood.response,
4901            );
4902        }
4903    }
4904
4905    #[test]
4906    fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4907        let data = array![[100.0], [-100.0]];
4908        let col_map = HashMap::from([("g".to_string(), 0usize)]);
4909
4910        let mut plain_payload = standard_gaussian_payload();
4911        plain_payload.data_schema = Some(DataSchema {
4912            columns: vec![SchemaColumn {
4913                name: "g".to_string(),
4914                kind: ColumnKindTag::Continuous,
4915                levels: vec![],
4916            }],
4917        });
4918        plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4919        plain_payload.resolved_termspec = Some(empty_termspec());
4920        let plain = FittedModel::from_payload(plain_payload.clone());
4921        let clipped = plain
4922            .axis_clip_to_training_ranges(data.view(), &col_map)
4923            .expect("ordinary continuous axis should clip outside the training range");
4924        assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4925
4926        let mut group_payload = plain_payload;
4927        let mut group_spec = empty_termspec();
4928        group_spec
4929            .random_effect_terms
4930            .push(gam_terms::smooth::RandomEffectTermSpec {
4931                name: "g".to_string(),
4932                feature_col: 0,
4933                drop_first_level: false,
4934                penalized: true,
4935                frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4936            });
4937        group_payload.resolved_termspec = Some(group_spec);
4938        let group_model = FittedModel::from_payload(group_payload);
4939
4940        assert_eq!(
4941            group_model.random_effect_group_columns(),
4942            HashSet::from(["g".to_string()])
4943        );
4944
4945        assert_eq!(
4946            group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4947            None,
4948            "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4949        );
4950    }
4951
4952    #[test]
4953    fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4954        let fit = saved_fit(vec![
4955            FittedBlock {
4956                beta: array![0.1],
4957                role: BlockRole::Mean,
4958                edf: 1.0,
4959                lambdas: Array1::zeros(0),
4960            },
4961            FittedBlock {
4962                beta: array![0.2],
4963                role: BlockRole::Scale,
4964                edf: 1.0,
4965                lambdas: Array1::zeros(0),
4966            },
4967            FittedBlock {
4968                beta: array![0.3],
4969                role: BlockRole::Mean,
4970                edf: 1.0,
4971                lambdas: Array1::zeros(0),
4972            },
4973        ]);
4974        let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4975        payload.score_warp_runtime = Some(anchored_runtime(2));
4976
4977        let err = FittedModel::from_payload(payload)
4978            .validate_for_persistence()
4979            .expect_err("marginal-slope score-warp basis mismatch should fail validation");
4980        assert!(err.to_string().contains("score-warp coefficient mismatch"));
4981    }
4982
4983    #[test]
4984    fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
4985        let fit = saved_fit(vec![
4986            FittedBlock {
4987                beta: array![0.1],
4988                role: BlockRole::Time,
4989                edf: 1.0,
4990                lambdas: Array1::zeros(0),
4991            },
4992            FittedBlock {
4993                beta: array![0.2],
4994                role: BlockRole::Mean,
4995                edf: 1.0,
4996                lambdas: Array1::zeros(0),
4997            },
4998            FittedBlock {
4999                beta: array![0.3],
5000                role: BlockRole::Scale,
5001                edf: 1.0,
5002                lambdas: Array1::zeros(0),
5003            },
5004            FittedBlock {
5005                beta: array![0.4],
5006                role: BlockRole::LinkWiggle,
5007                edf: 1.0,
5008                lambdas: Array1::zeros(0),
5009            },
5010        ]);
5011        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5012        payload.link_deviation_runtime = Some(anchored_runtime(2));
5013
5014        let err = FittedModel::from_payload(payload)
5015            .saved_prediction_runtime()
5016            .expect_err(
5017                "survival marginal-slope link basis mismatch should fail runtime validation",
5018            );
5019        assert!(
5020            err.to_string()
5021                .contains("link-deviation coefficient mismatch")
5022        );
5023    }
5024
5025    #[test]
5026    fn apply_survival_time_basis_writes_all_required_fields() {
5027        use crate::survival::construction::SavedSurvivalTimeBasis;
5028
5029        let fit = saved_fit(vec![
5030            FittedBlock {
5031                beta: array![0.1],
5032                role: BlockRole::Time,
5033                edf: 1.0,
5034                lambdas: Array1::zeros(0),
5035            },
5036            FittedBlock {
5037                beta: array![0.2],
5038                role: BlockRole::Mean,
5039                edf: 1.0,
5040                lambdas: Array1::zeros(0),
5041            },
5042            FittedBlock {
5043                beta: array![0.3],
5044                role: BlockRole::Scale,
5045                edf: 1.0,
5046                lambdas: Array1::zeros(0),
5047            },
5048        ]);
5049        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5050
5051        // Snapshot writes must match every persisted survival_time_* field —
5052        // forgetting one is exactly the marginal-slope save
5053        // regression. Routing through `apply_survival_time_basis` is the
5054        // structural contract that prevents that recurrence.
5055        let snapshot = SavedSurvivalTimeBasis {
5056            basisname: "royston-parmar".to_string(),
5057            degree: Some(3),
5058            knots: Some(vec![0.0, 1.0, 2.0]),
5059            keep_cols: Some(vec![0, 2]),
5060            smooth_lambda: Some(0.5),
5061            anchor: 0.25,
5062        };
5063        payload.apply_survival_time_basis(&snapshot);
5064
5065        assert_eq!(
5066            payload.survival_time_basis.as_deref(),
5067            Some("royston-parmar")
5068        );
5069        assert_eq!(payload.survival_time_degree, Some(3));
5070        assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5071        assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5072        assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5073        assert_eq!(payload.survival_time_anchor, Some(0.25));
5074    }
5075
5076    #[test]
5077    fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5078        let fit = saved_fit(vec![
5079            FittedBlock {
5080                beta: array![0.1],
5081                role: BlockRole::Time,
5082                edf: 1.0,
5083                lambdas: Array1::zeros(0),
5084            },
5085            FittedBlock {
5086                beta: array![0.2],
5087                role: BlockRole::Mean,
5088                edf: 1.0,
5089                lambdas: Array1::zeros(0),
5090            },
5091            FittedBlock {
5092                beta: array![0.3],
5093                role: BlockRole::Scale,
5094                edf: 1.0,
5095                lambdas: Array1::zeros(0),
5096            },
5097        ]);
5098        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5099        // Pass the time_basis presence check but deliberately omit the
5100        // anchor — this is exactly the partial-write shape that the CLI's
5101        // marginal-slope+time-wiggle save path had before the structural
5102        // refactor (main.rs previously set basis/degree/knots/keep_cols/
5103        // smooth_lambda but forgot the anchor).
5104        payload.survival_time_basis = Some("ispline".to_string());
5105
5106        let err = FittedModel::from_payload(payload)
5107            .validate_for_persistence()
5108            .expect_err("survival model without time-anchor metadata should fail validation");
5109        assert!(err.to_string().contains("missing survival_time_anchor"));
5110    }
5111
5112    #[test]
5113    fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5114        let fit = saved_fit(vec![
5115            FittedBlock {
5116                beta: array![0.1],
5117                role: BlockRole::Time,
5118                edf: 1.0,
5119                lambdas: Array1::zeros(0),
5120            },
5121            FittedBlock {
5122                beta: array![0.2],
5123                role: BlockRole::Mean,
5124                edf: 1.0,
5125                lambdas: Array1::zeros(0),
5126            },
5127            FittedBlock {
5128                beta: array![0.3],
5129                role: BlockRole::Scale,
5130                edf: 1.0,
5131                lambdas: Array1::zeros(0),
5132            },
5133        ]);
5134        let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5135
5136        let err = FittedModel::from_payload(payload)
5137            .validate_for_persistence()
5138            .expect_err("survival model without time-basis metadata should fail validation");
5139        assert!(err.to_string().contains("missing survival_time_basis"));
5140    }
5141
5142    #[test]
5143    fn saved_prediction_runtime_rejects_stale_payload_version() {
5144        let fit = saved_fit(vec![
5145            FittedBlock {
5146                beta: array![0.1],
5147                role: BlockRole::Mean,
5148                edf: 1.0,
5149                lambdas: Array1::zeros(0),
5150            },
5151            FittedBlock {
5152                beta: array![0.2],
5153                role: BlockRole::Scale,
5154                edf: 1.0,
5155                lambdas: Array1::zeros(0),
5156            },
5157        ]);
5158        let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5159
5160        let err = FittedModel::from_payload(payload)
5161            .saved_prediction_runtime()
5162            .expect_err("stale payload version should fail before runtime assembly");
5163        assert!(err.to_string().contains("payload schema mismatch"));
5164    }
5165}