// gam_models/inference/model.rs

1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4    LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7    SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12    monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15    inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16    parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21    InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22    SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25// The data-schema value types live in the `gam-data` foundation crate; they
26// were previously authored here and are still named `gam::inference::model::{
27// ColumnKindTag, DataSchema, SchemaColumn}` by a broad set of integration tests
28// and by saved-payload consumers. Re-export them so that public path stays
29// valid rather than forcing every caller onto the relocated crate path.
30pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Canonical saved-model payload schema version.
///
/// Every `FittedModelPayload` written by any binary (CLI `gam`, gam-pyffi,
/// downstream library users) must set this as its `version` field, and every
/// load path asserts equality via `validate_for_persistence`. Bump this when:
///   - A required field is added to `FittedModelPayload` and the set of
///     `Option<T>` fields that must be `Some(...)` for a given `family_state`
///     changes (otherwise the `#[serde(default)]` decode would silently fill
///     the new field with `None` when loading an older model and the CLI
///     predict path would run with stale metadata).
///   - The on-wire shape of any `serde`-tagged enum variant changes such that
///     older payloads no longer round-trip losslessly.
///   - The semantics of an existing field change (e.g. sign convention,
///     coordinate frame) in a way that predict output would silently diverge
///     between old and new readers.
///
/// Do NOT bump for purely additive `Option<T>` fields that the save-time
/// invariant (`validate_for_persistence`) does not yet require. Those are
/// forward-compatible.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;
59
60/// Schema-free saved-model metadata keyed by stable group id.
61///
62/// The values are JSON rather than a typed enum because group provenance is
63/// supplied by caller-owned catalogs. `FittedModelPayload::group_metadata`
64/// wraps this in `Option` with `#[serde(default)]`, so model files written
65/// before the field existed deserialize as `None`.
66pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
68/// Saved exact spline-scan fit (#1030/#1034): the predict-time feature column
69/// plus the lossless smoother state the Gaussian-bridge `predict` replays.
70#[derive(Clone, Debug, Serialize, Deserialize)]
71pub struct SavedSplineScan {
72    /// Training column name feeding the single 1-D smooth at predict time.
73    pub feature_column: String,
74    pub state: gam_solve::spline_scan::SplineScanState,
75}
76
77/// Saved multiresolution residual-cascade fit (#1032): the predict-time feature
78/// columns (d ∈ {2, 3}) plus the serializable cascade state that `from_state`
79/// rebuilds a predict-capable `ResidualCascadeFit` from. The cascade is a
80/// DIFFERENT posterior from the dense Duchon/Matérn term — never a silent swap.
81#[derive(Clone, Debug, Serialize, Deserialize)]
82pub struct SavedResidualCascade {
83    /// Training column names for the d ∈ {2, 3} scattered-smooth coordinates.
84    pub feature_columns: Vec<String>,
85    pub state: gam_solve::residual_cascade::ResidualCascadeState,
86}
87
/// Typed error surface for `src/inference/model.rs` saved-model code.
///
/// Every variant carries a free-form `reason: String` payload; `Display`
/// emits exactly that payload, so converting a `FittedModelError` into
/// `String` (via the `From` impl below) is byte-equivalent to the
/// pre-refactor `Err(format!(...))` / `Err("...".to_string())` strings that
/// the same call sites produced. This lets external callers keep using
/// `?` against `Result<_, String>` without source changes — the typed
/// enum is purely an in-module discipline gain.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
    /// Saved payload structure / shape / version disagrees with what the
    /// current binary expects (e.g. covariance shape, block ordering,
    /// schema version, C2 continuity, out-of-range span/basis indices).
    SchemaMismatch { reason: String },
    /// Saved payload bytes / numeric content are corrupt or unreadable
    /// (non-finite scalars, invalid JSON, IO failure, malformed stateful
    /// link state).
    PayloadCorrupt { reason: String },
    /// A required field that the current code path needs is absent from
    /// the payload (typically `..; refit` errors).
    MissingField { reason: String },
    /// A combination of saved-model options is not supported by the
    /// current binary (unsupported deployment-extension kind, unsupported
    /// kernel marker, unsupported survival_time_basis variant, etc.).
    IncompatibleConfig { reason: String },
    /// An input value rejected by a save-time sanity gate (e.g. negative
    /// ridge alpha).
    InvalidInput { reason: String },
}
118
119impl_reason_error_boilerplate! {
120    FittedModelError {
121        SchemaMismatch,
122        PayloadCorrupt,
123        MissingField,
124        IncompatibleConfig,
125        InvalidInput,
126    }
127}
128
129// Boundary conversions so external `Result<_, EstimationError>` /
130// `Result<_, SurvivalPredictError>` call sites can propagate with `?`.
131// Survival prediction keeps the model-layer source so the chain identifies
132// the payload/schema failure that triggered the prediction error.
133impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134    fn from(err: FittedModelError) -> Self {
135        gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136    }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140    fn from(err: FittedModelError) -> Self {
141        crate::survival::predict::SurvivalPredictError::ModelPayload {
142            context: "saved-model survival prediction payload",
143            source: err,
144        }
145    }
146}
147
148#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
149pub struct SavedLatentZNormalization {
150    pub mean: f64,
151    pub sd: f64,
152}
153
154impl SavedLatentZNormalization {
155    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156        if !self.mean.is_finite() {
157            return Err(FittedModelError::PayloadCorrupt {
158                reason: format!("{context} latent z mean must be finite"),
159            });
160        }
161        if !(self.sd.is_finite() && self.sd > 1e-12) {
162            return Err(FittedModelError::PayloadCorrupt {
163                reason: format!(
164                    "{context} latent z sd must be finite and > 1e-12; got {}",
165                    self.sd
166                ),
167            });
168        }
169        Ok::<(), _>(())
170    }
171
172    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173        self.validate(context)?;
174        if z.iter().any(|value| !value.is_finite()) {
175            return Err(FittedModelError::PayloadCorrupt {
176                reason: format!("{context} requires finite z values"),
177            });
178        }
179        Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180    }
181}
182
/// Default PIT clip epsilon for `TransformationScoreCalibration`: used by
/// `TransformationScoreCalibration::finite_support_pit` and as the serde
/// default for a missing `clip_eps`.
pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
185#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
186#[serde(rename_all = "kebab-case")]
187#[derive(Default)]
188pub enum TransformationScoreKind {
189    #[default]
190    FiniteSupportPit,
191}
192
193#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
194pub struct TransformationScoreCalibration {
195    #[serde(default)]
196    pub score_kind: TransformationScoreKind,
197    #[serde(default = "default_transformation_score_pit_clip_eps")]
198    pub clip_eps: f64,
199}
200
201const fn default_transformation_score_pit_clip_eps() -> f64 {
202    TRANSFORMATION_SCORE_PIT_CLIP_EPS
203}
204
205impl TransformationScoreCalibration {
206    pub fn finite_support_pit() -> Self {
207        Self {
208            score_kind: TransformationScoreKind::FiniteSupportPit,
209            clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210        }
211    }
212
213    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214        if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215            return Err(FittedModelError::IncompatibleConfig {
216                reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217            });
218        }
219        if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220            return Err(FittedModelError::IncompatibleConfig {
221                reason: format!(
222                    "{context} requires PIT clip_eps in (0, 0.5), got {}",
223                    self.clip_eps
224                ),
225            });
226        }
227        Ok::<(), _>(())
228    }
229}
230
231#[derive(Clone, Serialize, Deserialize)]
232pub struct FittedModelPayload {
233    pub version: u32,
234    pub formula: String,
235    pub model_kind: ModelKind,
236    pub family_state: FittedFamily,
237    pub family: String,
238    /// Human-readable advisories produced while materializing this model —
239    /// e.g. an mgcv-style "k was reduced to the data support" note when a
240    /// cubic-regression marginal is capped, or a basis-degradation note. These
241    /// are surfaced to CLI users via `print_inference_summary`; persisting them
242    /// here lets the Python (gamfit) interface surface the SAME advisories as
243    /// warnings / `model.notes` instead of silently dropping them at the FFI
244    /// boundary (#1543). `#[serde(default)]` keeps older payloads (which had no
245    /// such field) deserializing cleanly as "no notes".
246    #[serde(default)]
247    pub inference_notes: Vec<String>,
248    #[serde(default)]
249    pub used_device: bool,
250    #[serde(default)]
251    pub fit_result: Option<UnifiedFitResult>,
252    /// Unified (family-agnostic) representation of the fit result.
253    #[serde(default)]
254    pub unified: Option<UnifiedFitResult>,
255    /// Exact O(n) spline-scan fit representation (#1030/#1034): the
256    /// state-space smoothing-spline posterior of a single 1-D Gaussian
257    /// smooth. When `Some`, this standard Gaussian model's predictions
258    /// replay the Gaussian bridge from this state and the model carries no
259    /// dense `fit_result` (the representations are mutually exclusive —
260    /// enforced by `validate_for_persistence`). `#[serde(default)]` so older
261    /// payloads read as: not a scan model.
262    #[serde(default)]
263    pub spline_scan: Option<SavedSplineScan>,
264    /// O(n log n) multiresolution residual-cascade fit (#1032): the persisted
265    /// multilevel Wendland-frame state for a single scattered 2–3D Gaussian
266    /// smooth past the dense-kernel cliff. When `Some`, predictions replay the
267    /// cascade posterior; mutually exclusive with `spline_scan`/`fit_result`.
268    /// `#[serde(default)]` keeps forward-compatibility with older payloads.
269    #[serde(default)]
270    pub residual_cascade: Option<SavedResidualCascade>,
271    #[serde(default)]
272    pub data_schema: Option<DataSchema>,
273    pub link: Option<InverseLink>,
274    #[serde(default)]
275    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
276    #[serde(default)]
277    pub sas_param_covariance: Option<Vec<Vec<f64>>>,
278    #[serde(default)]
279    pub formula_noise: Option<String>,
280    #[serde(default)]
281    pub formula_logslope: Option<String>,
282    #[serde(default)]
283    pub formula_logslopes: Option<Vec<String>>,
284    #[serde(default)]
285    pub offset_column: Option<String>,
286    #[serde(default)]
287    pub noise_offset_column: Option<String>,
288    #[serde(default)]
289    pub beta_noise: Option<Vec<f64>>,
290    #[serde(default)]
291    pub noise_projection: Option<Vec<Vec<f64>>>,
292    #[serde(default)]
293    pub noise_center: Option<Vec<f64>>,
294    #[serde(default)]
295    pub noise_scale: Option<Vec<f64>>,
296    #[serde(default)]
297    pub noise_non_intercept_start: Option<usize>,
298    /// Tikhonov ridge alpha used by `solve_scale_projection` when fitting
299    /// `noise_projection`.  Persisted so prediction-time replay is identical
300    /// to fit-time projection.
301    #[serde(default)]
302    pub noise_projection_ridge_alpha: Option<f64>,
303    #[serde(default)]
304    pub gaussian_response_scale: Option<f64>,
305    #[serde(default)]
306    pub linkwiggle_knots: Option<Vec<f64>>,
307    #[serde(default)]
308    pub linkwiggle_degree: Option<usize>,
309    #[serde(default)]
310    pub beta_link_wiggle: Option<Vec<f64>>,
311    #[serde(default)]
312    pub baseline_timewiggle_knots: Option<Vec<f64>>,
313    #[serde(default)]
314    pub baseline_timewiggle_degree: Option<usize>,
315    #[serde(default)]
316    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
317    #[serde(default)]
318    pub baseline_timewiggle_double_penalty: Option<bool>,
319    #[serde(default)]
320    pub beta_baseline_timewiggle: Option<Vec<f64>>,
321    #[serde(default)]
322    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
323    #[serde(default)]
324    pub z_column: Option<String>,
325    #[serde(default)]
326    pub z_columns: Option<Vec<String>>,
327    #[serde(default)]
328    pub latent_z_normalization: Option<SavedLatentZNormalization>,
329    #[serde(default)]
330    pub latent_score_contract: Option<SavedLatentScoreContract>,
331    #[serde(default)]
332    pub latent_measure: Option<LatentMeasureKind>,
333    /// Optional rank-INT calibration for the latent score (BMS family).
334    /// When `Some`, the marginal-slope predictor routes the input `z`
335    /// through [`LatentZRankIntCalibration::apply_at_predict`] before the
336    /// closed-form standard-normal kernel, matching fit-time semantics.
337    /// `#[serde(default)]` so models persisted before this field existed
338    /// continue to deserialize cleanly (interpreted as: no calibration).
339    #[serde(default)]
340    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
341    /// Optional conditional location-scale calibration of the latent score
342    /// (#905, BMS family). When `Some`, the marginal-slope predictor replaces
343    /// the (normalized) input `z` by `ζ = (z − m(C))/√v(C)` — rebuilding the
344    /// conditioning span `a(C)` from the marginal prediction design — before
345    /// the closed-form standard-normal kernel, matching fit-time semantics.
346    /// Mutually exclusive with `latent_z_rank_int_calibration`. `#[serde(default)]`
347    /// so pre-existing models deserialize cleanly (interpreted as: no
348    /// conditional calibration).
349    #[serde(default)]
350    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
351    #[serde(default)]
352    pub marginal_baseline: Option<f64>,
353    #[serde(default)]
354    pub logslope_baseline: Option<f64>,
355    #[serde(default)]
356    pub logslope_baselines: Option<Vec<f64>>,
357    #[serde(default)]
358    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
359    #[serde(default)]
360    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
361    /// Width `p₁` of the survival marginal-slope absorbed Stage-1 influence block
362    /// (#461) when present (the dedicated trailing absorber block). Predict drops
363    /// its `γ`; this records the column count so the predictor can account for
364    /// the extra block and slice `γ` out of the joint covariance.
365    #[serde(default)]
366    pub influence_absorber_width: Option<usize>,
367    #[serde(default)]
368    pub survival_entry: Option<String>,
369    #[serde(default)]
370    pub survival_exit: Option<String>,
371    #[serde(default)]
372    pub survival_event: Option<String>,
373    #[serde(default)]
374    pub survivalspec: Option<String>,
375    #[serde(default)]
376    pub survival_cause_count: Option<usize>,
377    #[serde(default)]
378    pub survival_endpoint_names: Option<Vec<String>>,
379    #[serde(default)]
380    pub survival_baseline_target: Option<String>,
381    #[serde(default)]
382    pub survival_baseline_scale: Option<f64>,
383    #[serde(default)]
384    pub survival_baseline_shape: Option<f64>,
385    #[serde(default)]
386    pub survival_baseline_rate: Option<f64>,
387    #[serde(default)]
388    pub survival_baseline_makeham: Option<f64>,
389    #[serde(default)]
390    pub survival_time_basis: Option<String>,
391    #[serde(default)]
392    pub survival_time_degree: Option<usize>,
393    #[serde(default)]
394    pub survival_time_knots: Option<Vec<f64>>,
395    #[serde(default)]
396    pub survival_time_keep_cols: Option<Vec<usize>>,
397    #[serde(default)]
398    pub survival_time_smooth_lambda: Option<f64>,
399    #[serde(default)]
400    pub survival_time_anchor: Option<f64>,
401    #[serde(default)]
402    pub survivalridge_lambda: Option<f64>,
403    #[serde(default)]
404    pub survival_likelihood: Option<String>,
405    #[serde(default)]
406    pub survival_beta_time: Option<Vec<f64>>,
407    #[serde(default)]
408    pub survival_beta_threshold: Option<Vec<f64>>,
409    #[serde(default)]
410    pub survival_beta_log_sigma: Option<Vec<f64>>,
411    #[serde(default)]
412    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
413    #[serde(default)]
414    pub survival_noise_center: Option<Vec<f64>>,
415    #[serde(default)]
416    pub survival_noise_scale: Option<Vec<f64>>,
417    #[serde(default)]
418    pub survival_noise_non_intercept_start: Option<usize>,
419    /// Survival analog of `noise_projection_ridge_alpha`: the Tikhonov ridge
420    /// used when fitting the survival log-sigma projection.
421    #[serde(default)]
422    pub survival_noise_projection_ridge_alpha: Option<f64>,
423    #[serde(default)]
424    pub survival_distribution: Option<ResidualDistribution>,
425    #[serde(default)]
426    pub training_headers: Option<Vec<String>>,
427    /// Container type of the table the model was fitted on, as detected by the
428    /// Python binding (`"pandas"`, `"polars"`, `"pyarrow"`, `"numpy"`, or an
429    /// ambiguous tag such as `"unknown"`). This is presentation provenance, not
430    /// math: `gamfit.Model.predict` uses it as the output-container fallback
431    /// when the *prediction input* is itself container-ambiguous (a `dict` of
432    /// columns or a `list` of record dicts). Persisting it in the model payload
433    /// makes the fallback survive `save`/`load` and `dumps`/`loads`, so a
434    /// reloaded model predicts into the same container as the in-memory one.
435    /// `None` for older payloads (and for fits that never recorded a kind), in
436    /// which case the fallback degrades to `"dict"`, matching pre-persistence
437    /// behaviour for unknown-kind training tables.
438    #[serde(default, skip_serializing_if = "Option::is_none")]
439    pub training_table_kind: Option<String>,
440    /// Per-column (min, max) of the training input matrix, parallel to
441    /// `training_headers`. At predict time, inputs are axis-clipped to these
442    /// ranges so that out-of-distribution points evaluate at the nearest face
443    /// of the training bounding box rather than extrapolating polynomial
444    /// trends from polyharmonic / spline bases beyond the data envelope. Old
445    /// model JSONs that pre-date this field load with `None`, in which case
446    /// the predict path falls through unchanged (no clipping).
447    #[serde(default)]
448    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
449    /// User-supplied per-group metadata, keyed by stable group identifier.
450    ///
451    /// This is intentionally schema-free JSON so provenance maps can carry
452    /// mixed scalar/list/object values. Missing in older payloads means no
453    /// group metadata was persisted.
454    #[serde(default, skip_serializing_if = "Option::is_none")]
455    pub group_metadata: Option<GroupMetadata>,
456    /// Deployment-time no-refit group extensions applied after fitting.
457    ///
458    /// Each entry records the requested group coordinate, caller metadata, and
459    /// prior used to initialize the inserted coefficient. The active
460    /// prediction contract lives in `data_schema` + `resolved_termspec`; this
461    /// ledger preserves provenance without requiring a refit.
462    #[serde(default, skip_serializing_if = "Vec::is_empty")]
463    pub deployment_extensions: Vec<SavedDeploymentExtension>,
464    /// Transformation-normal: B-spline knots for the response-direction basis.
465    #[serde(default)]
466    pub transformation_response_knots: Option<Vec<f64>>,
467    /// Transformation-normal: deviation nullspace transform matrix (row-major).
468    #[serde(default)]
469    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
470    /// Transformation-normal: B-spline degree for the response basis.
471    #[serde(default)]
472    pub transformation_response_degree: Option<usize>,
473    /// Transformation-normal: median of the response used for anchoring.
474    #[serde(default)]
475    pub transformation_response_median: Option<f64>,
476    /// Transformation-normal saved score contract. The score is the exact
477    /// finite-support PIT:
478    /// z = Phi^{-1}((Phi(h) - Phi(h_L)) / (Phi(h_U) - Phi(h_L))).
479    #[serde(default)]
480    pub transformation_score_calibration: Option<TransformationScoreCalibration>,
481    #[serde(default)]
482    pub resolved_termspec: Option<TermCollectionSpec>,
483    #[serde(default)]
484    pub resolved_termspec_noise: Option<TermCollectionSpec>,
485    #[serde(default)]
486    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
487    #[serde(default)]
488    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
489    #[serde(default)]
490    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
491    /// Precomputed exact Gaussian-identity jackknife+ statistics (#942).
492    ///
493    /// Populated *only* for a standard Gaussian-identity model fit with unit
494    /// prior weights, where the closed-form Sherman–Morrison leave-one-out
495    /// substrate gives a distribution-free finite-sample (≥ level) prediction
496    /// interval with no held-out fold. When `Some`, `predict(interval=level)`
497    /// auto-routes through it (the MAGIC default); when `None` — any other
498    /// family/link, reweighted rows, or an older payload — predict falls back
499    /// to the model-based posterior band and labels the provenance honestly.
500    /// `#[serde(default)]` so pre-existing models deserialize as: no jackknife+
501    /// substrate available.
502    #[serde(default)]
503    pub gaussian_jackknife_plus:
504        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
505    /// Precomputed substrate for the EXACT Gaussian-identity full-conformal set
506    /// (#942 Layer 1 + the frozen-ρ self-diagnostic).
507    ///
508    /// Populated under the SAME eligibility as `gaussian_jackknife_plus`
509    /// (Gaussian-identity, unit prior weights, offset-free, no link wiggle). It
510    /// persists the training design + response + frozen penalty `Sλ` so the
511    /// distribution-free EXACT prediction set (a union of intervals, valid for
512    /// any penalized smooth) can be replayed per test point — one Cholesky each,
513    /// zero refits — and surfaces the frozen-ρ certificate flag. `None` for any
514    /// ineligible model or an older payload, in which case the exact-set predict
515    /// path errors with a clear message and the caller uses jackknife+ or the
516    /// posterior band. `#[serde(default)]` so pre-existing models deserialize as
517    /// no exact substrate available.
518    #[serde(default)]
519    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
520}
521
522#[derive(Clone, Debug, Serialize, Deserialize)]
523pub struct SavedDeploymentExtension {
524    pub name: String,
525    pub kind: String,
526    pub term: String,
527    pub level: JsonValue,
528    pub level_bits: u64,
529    pub coefficient_index: usize,
530    pub coefficient_mean: f64,
531    pub coefficient_variance: f64,
532    #[serde(default, skip_serializing_if = "Option::is_none")]
533    pub metadata: Option<JsonValue>,
534    #[serde(default, skip_serializing_if = "Option::is_none")]
535    pub prior: Option<JsonValue>,
536}
537
538/// Append deployment-only extension columns to the fitted design coordinate system.
539///
540/// No-refit group extension adds a new coefficient block after the fitted
541/// coefficient vector:
542///
543///   beta_ext = [beta_old, beta_new],    beta_new = mu_new.
544///
545/// For a new random-effect level g, the appended basis is the indicator
546/// e_g(x_i) = 1{x_i == g}.  The old fitted basis X_old is not rebuilt or
547/// reordered, so rows that do not exercise g have
548///
549///   eta_ext = X_old beta_old + 0 * beta_new = eta_old.
550///
551/// Rows at the new level get the exact prior-mean shift e_g beta_new.  This
552/// helper enforces the coordinate identity by requiring extension coefficient
553/// indices to be the consecutive tail columns of the base design.
554pub fn append_deployment_extension_columns(
555    model: &FittedModelPayload,
556    data: ndarray::ArrayView2<'_, f64>,
557    col_map: &HashMap<String, usize>,
558    training_headers: Option<&Vec<String>>,
559    base_design: Array2<f64>,
560) -> Result<Array2<f64>, FittedModelError> {
561    if model.deployment_extensions.is_empty() {
562        return Ok(base_design);
563    }
564    if base_design.nrows() != data.nrows() {
565        return Err(FittedModelError::SchemaMismatch {
566            reason: format!(
567                "deployment extension design row mismatch: base design has {} rows but data has {}",
568                base_design.nrows(),
569                data.nrows()
570            ),
571        });
572    }
573    let spec = model
574        .resolved_termspec
575        .as_ref()
576        .ok_or_else(|| FittedModelError::MissingField {
577            reason: "deployment extension prediction requires saved resolved_termspec; refit"
578                .to_string(),
579        })?;
580    let n = base_design.nrows();
581    let p_old = base_design.ncols();
582    let mut extensions: Vec<&SavedDeploymentExtension> =
583        model.deployment_extensions.iter().collect();
584    extensions.sort_by_key(|extension| extension.coefficient_index);
585    for (tail_idx, extension) in extensions.iter().enumerate() {
586        let expected = p_old + tail_idx;
587        if extension.coefficient_index != expected {
588            return Err(FittedModelError::SchemaMismatch {
589                reason: format!(
590                    "deployment extension '{}' has coefficient index {}, expected append-only index {}",
591                    extension.name, extension.coefficient_index, expected
592                ),
593            });
594        }
595    }
596
597    let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
598    out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
599    for (tail_idx, extension) in extensions.into_iter().enumerate() {
600        if extension.kind != "random-effect-level" {
601            return Err(FittedModelError::IncompatibleConfig {
602                reason: format!(
603                    "unsupported deployment extension kind '{}' for '{}'",
604                    extension.kind, extension.name
605                ),
606            });
607        }
608        let term = spec
609            .random_effect_terms
610            .iter()
611            .find(|term| term.name == extension.term)
612            .ok_or_else(|| FittedModelError::MissingField {
613                reason: format!(
614                    "deployment extension '{}' references unknown random-effect term '{}'",
615                    extension.name, extension.term
616                ),
617            })?;
618        let prediction_col = training_headers
619            .and_then(|headers| headers.get(term.feature_col))
620            .and_then(|name| col_map.get(name))
621            .copied()
622            .unwrap_or(term.feature_col);
623        if prediction_col >= data.ncols() {
624            return Err(FittedModelError::SchemaMismatch {
625                reason: format!(
626                    "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
627                    extension.name,
628                    prediction_col,
629                    data.ncols()
630                ),
631            });
632        }
633        let col = p_old + tail_idx;
634        for row in 0..n {
635            if data[[row, prediction_col]].to_bits() == extension.level_bits {
636                out[[row, col]] = 1.0;
637            }
638        }
639    }
640    Ok(out)
641}
642
643#[derive(Clone, Debug, Serialize, Deserialize)]
644pub struct SavedLatentScoreContract {
645    pub semantics: String,
646    pub source_transform_id: Option<String>,
647    pub normalization_mean: f64,
648    pub normalization_sd: f64,
649    pub clip_eps: Option<f64>,
650    pub conditioning_columns: Vec<String>,
651}
652
653impl FittedModelPayload {
    /// Build a payload that carries only its identifying fields.
    ///
    /// `version`, `formula`, `model_kind`, `family_state` and `family` are
    /// stored as given. `inference_notes` and `deployment_extensions` start as
    /// empty vectors, `used_device` starts as `false`, and every other field
    /// starts as `None`; callers fill those in afterwards.
    pub fn new(
        version: u32,
        formula: String,
        model_kind: ModelKind,
        family_state: FittedFamily,
        family: String,
    ) -> Self {
        Self {
            // Caller-supplied identity of the model.
            version,
            formula,
            model_kind,
            family_state,
            family,
            // Everything below is an empty/absent default.
            inference_notes: Vec::new(),
            used_device: false,
            fit_result: None,
            unified: None,
            spline_scan: None,
            residual_cascade: None,
            data_schema: None,
            link: None,
            mixture_link_param_covariance: None,
            sas_param_covariance: None,
            formula_noise: None,
            formula_logslope: None,
            formula_logslopes: None,
            offset_column: None,
            noise_offset_column: None,
            beta_noise: None,
            noise_projection: None,
            noise_center: None,
            noise_scale: None,
            noise_non_intercept_start: None,
            noise_projection_ridge_alpha: None,
            gaussian_response_scale: None,
            linkwiggle_knots: None,
            linkwiggle_degree: None,
            beta_link_wiggle: None,
            baseline_timewiggle_knots: None,
            baseline_timewiggle_degree: None,
            baseline_timewiggle_penalty_orders: None,
            baseline_timewiggle_double_penalty: None,
            beta_baseline_timewiggle: None,
            beta_baseline_timewiggle_by_cause: None,
            z_column: None,
            z_columns: None,
            latent_z_normalization: None,
            latent_score_contract: None,
            latent_measure: None,
            latent_z_rank_int_calibration: None,
            latent_z_conditional_calibration: None,
            marginal_baseline: None,
            logslope_baseline: None,
            logslope_baselines: None,
            score_warp_runtime: None,
            link_deviation_runtime: None,
            influence_absorber_width: None,
            survival_entry: None,
            survival_exit: None,
            survival_event: None,
            survivalspec: None,
            survival_cause_count: None,
            survival_endpoint_names: None,
            survival_baseline_target: None,
            survival_baseline_scale: None,
            survival_baseline_shape: None,
            survival_baseline_rate: None,
            survival_baseline_makeham: None,
            // The `survival_time_*` fields are written together by
            // `apply_survival_time_basis`.
            survival_time_basis: None,
            survival_time_degree: None,
            survival_time_knots: None,
            survival_time_keep_cols: None,
            survival_time_smooth_lambda: None,
            survival_time_anchor: None,
            survivalridge_lambda: None,
            survival_likelihood: None,
            survival_beta_time: None,
            survival_beta_threshold: None,
            survival_beta_log_sigma: None,
            survival_noise_projection: None,
            survival_noise_center: None,
            survival_noise_scale: None,
            survival_noise_non_intercept_start: None,
            survival_noise_projection_ridge_alpha: None,
            survival_distribution: None,
            training_headers: None,
            training_table_kind: None,
            training_feature_ranges: None,
            group_metadata: None,
            deployment_extensions: Vec::new(),
            transformation_response_knots: None,
            transformation_response_transform: None,
            transformation_response_degree: None,
            transformation_response_median: None,
            transformation_score_calibration: None,
            resolved_termspec: None,
            resolved_termspec_noise: None,
            resolved_termspec_logslope: None,
            resolved_termspec_logslopes: None,
            adaptive_regularization_diagnostics: None,
            gaussian_jackknife_plus: None,
            full_conformal: None,
        }
    }
758
759    pub fn set_training_feature_metadata(
760        &mut self,
761        headers: Vec<String>,
762        feature_ranges: Vec<(f64, f64)>,
763    ) {
764        self.training_headers = Some(headers);
765        self.training_feature_ranges = Some(feature_ranges);
766    }
767
768    fn synchronize_empty_feature_contract(&mut self) {
769        if self.fit_result.is_none() {
770            return;
771        }
772        let Some(schema) = self.data_schema.as_ref() else {
773            return;
774        };
775        if !schema.columns.is_empty() {
776            return;
777        }
778        self.training_headers.get_or_insert_with(Vec::new);
779        self.resolved_termspec
780            .get_or_insert_with(|| TermCollectionSpec {
781                linear_terms: Vec::new(),
782                smooth_terms: Vec::new(),
783                random_effect_terms: Vec::new(),
784            });
785    }
786
787    /// Write the persistable time-basis snapshot for a survival model.
788    ///
789    /// This is the only path that should populate the `survival_time_*`
790    /// fields used by the loader. Routing every FFI builder through this
791    /// helper guarantees no builder can silently drop a field — the
792    /// marginal-slope save→load bug was a builder that
793    /// missed `survival_time_basis`.
794    pub fn apply_survival_time_basis(
795        &mut self,
796        snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
797    ) {
798        self.survival_time_basis = Some(snapshot.basisname.clone());
799        self.survival_time_degree = snapshot.degree;
800        self.survival_time_knots = snapshot.knots.clone();
801        self.survival_time_keep_cols = snapshot.keep_cols.clone();
802        self.survival_time_smooth_lambda = snapshot.smooth_lambda;
803        self.survival_time_anchor = Some(snapshot.anchor);
804    }
805
806    fn validate_payload_version(&self) -> Result<(), FittedModelError> {
807        if self.version != MODEL_PAYLOAD_VERSION {
808            return Err(FittedModelError::SchemaMismatch {
809                reason: format!(
810                    "saved model payload schema mismatch: file has version={}, \
811                 this binary expects MODEL_PAYLOAD_VERSION={}. \
812                 Refit with the current CLI, or rebuild the reader at the same \
813                 version the model was written with.",
814                    self.version, MODEL_PAYLOAD_VERSION
815                ),
816            });
817        }
818        Ok(())
819    }
820}
821
/// A saved model, tagged by model type.
///
/// Serialized as an internally tagged enum: the `model_type` key holds the
/// kebab-case variant name (for example `location-scale`). Every variant
/// carries the same [`FittedModelPayload`] under `payload`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    /// Tagged `standard`.
    Standard { payload: FittedModelPayload },
    /// Tagged `location-scale`.
    LocationScale { payload: FittedModelPayload },
    /// Tagged `marginal-slope`.
    MarginalSlope { payload: FittedModelPayload },
    /// Tagged `survival`.
    Survival { payload: FittedModelPayload },
    /// Tagged `transformation-normal`.
    TransformationNormal { payload: FittedModelPayload },
}
831
/// Payload-level model kind, serialized as a kebab-case string.
///
/// The variants have the same names as the variants of [`FittedModel`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    /// Serialized as `standard`.
    Standard,
    /// Serialized as `location-scale`.
    LocationScale,
    /// Serialized as `marginal-slope`.
    MarginalSlope,
    /// Serialized as `survival`.
    Survival,
    /// Serialized as `transformation-normal`.
    TransformationNormal,
}
841
/// Family state stored with a saved model.
///
/// Serialized as an internally tagged enum: the `family_kind` key holds the
/// kebab-case variant name. Fields marked `#[serde(default)]` may be absent
/// from the serialized form and then deserialize to `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Tagged `standard`: a likelihood plus optional link and link-state
    /// records.
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    /// Tagged `location-scale`: a likelihood plus an optional base inverse
    /// link.
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    /// Tagged `marginal-slope`: all three fields are required.
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    /// Tagged `survival`: `likelihood` and `frailty` are required; the
    /// survival likelihood name and residual distribution are optional.
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Tagged `latent-survival`: carries only a frailty specification.
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Tagged `latent-binary`: carries only a frailty specification.
    LatentBinary {
        frailty: FrailtySpec,
    },
    /// Tagged `transformation-normal`: carries only a likelihood.
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
884
/// Prediction-time classification of a saved model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    /// Reported as `"standard"`.
    Standard,
    /// Reported as `"gaussian location-scale"`.
    GaussianLocationScale,
    /// Reported as `"binomial location-scale"`.
    BinomialLocationScale,
    /// Genuine-dispersion location-scale (#913): NegativeBinomial, Gamma,
    /// Beta and Tweedie mean families fitted with a `noise_formula`
    /// overdispersion channel. These are predicted through the GLM mean
    /// inverse link, not through the binomial threshold-scale predictor.
    DispersionLocationScale,
    /// Reported as `"bernoulli marginal-slope"`.
    BernoulliMarginalSlope,
    /// Reported as `"survival"`.
    Survival,
    /// Reported as `"transformation-normal"`.
    TransformationNormal,
}

impl PredictModelClass {
    /// Human-readable, lower-case label for this class.
    ///
    /// The match is exhaustive on purpose: adding a variant without a label
    /// is a compile error.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            PredictModelClass::Standard => "standard",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
        }
    }
}
914
/// Saved link-wiggle state used at prediction time.
///
/// `knots` and `degree` are handed to the monotone wiggle basis builder, and
/// `beta` holds the coefficients applied to the resulting basis columns; the
/// basis must have exactly `beta.len()` columns.
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    /// Knot vector passed to the monotone wiggle basis.
    pub knots: Vec<f64>,
    /// Degree passed to the monotone wiggle basis.
    pub degree: usize,
    /// Link-wiggle coefficients, one per basis column.
    pub beta: Vec<f64>,
}
921
/// Saved baseline time-wiggle state used at prediction time.
///
/// NOTE(review): apart from `beta`, the field notes are inferred from the
/// field names; confirm against the code that builds this runtime.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
    /// Presumably the knot vector of the time-wiggle basis.
    pub knots: Vec<f64>,
    /// Presumably the degree of the time-wiggle basis.
    pub degree: usize,
    /// Presumably the penalty orders used for the time-wiggle smooth.
    pub penalty_orders: Vec<usize>,
    /// Presumably whether a double penalty was used — TODO confirm.
    pub double_penalty: bool,
    /// Time-wiggle coefficients; `validate_global_monotonicity` runs the
    /// non-negativity check on these.
    pub beta: Vec<f64>,
}
930
931// Re-export so saved-model consumers can refer to the anchor-block tag
932// without reaching across module boundaries.
933pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
934
/// Serializable anchored-deviation runtime, stored as per-span coefficient
/// tables.
///
/// `validate_exact_replay_contract` checks the invariants noted on the fields
/// below before the block is used.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedCompiledFlexBlock {
    /// Kernel marker; must be non-empty and equal to
    /// `crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL`.
    pub kernel: String,
    /// Span boundaries: at least two values, all finite and strictly
    /// increasing. There are `breakpoints.len() - 1` spans.
    pub breakpoints: Vec<f64>,
    /// Number of basis functions; must be positive.
    pub basis_dim: usize,
    /// Per-span coefficient table labelled `c0` by the validator; presumably
    /// the constant cubic coefficients — TODO confirm.
    pub span_c0: Vec<Vec<f64>>,
    /// Per-span coefficient table labelled `c1`.
    pub span_c1: Vec<Vec<f64>>,
    /// Per-span coefficient table labelled `c2`.
    pub span_c2: Vec<Vec<f64>>,
    /// Per-span coefficient table labelled `c3`.
    pub span_c3: Vec<Vec<f64>>,
    /// Cross-block anchor-residual coefficient matrix `M` of shape
    /// `d × basis_dim`. When present, predict-time evaluation subtracts
    /// `n_row · M` from each cubic-span row (where `n_row` stacks the
    /// per-row parametric anchor values in the order given by
    /// `anchor_components`).
    #[serde(default)]
    pub anchor_correction: Option<Vec<Vec<f64>>>,
    /// Ordered list of parametric anchor components whose stacked row
    /// values combine into `n_row`. Empty unless
    /// `anchor_correction` is `Some`.
    #[serde(default)]
    pub anchor_components: Vec<SavedAnchorComponent>,
}
957
/// One entry of `SavedCompiledFlexBlock::anchor_components`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedAnchorComponent {
    /// Which kind of anchor this component is and how many columns it
    /// contributes.
    pub kind: SavedAnchorKind,
}
962
/// Kind of a saved anchor component.
///
/// Both variants carry `ncols`, the number of columns the component
/// contributes; the anchor-shape validation sums these to get the expected
/// row count of `anchor_correction`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SavedAnchorKind {
    /// Anchor taken from a parametric block, identified by `block`.
    Parametric {
        block: ParametricAnchorBlock,
        ncols: usize,
    },
    /// Flex-evaluation anchor (sibling flex block's reparameterised basis,
    /// evaluated at training rows at fit time and at predict rows at
    /// predict time). The predictor stacks `ncols` columns from the
    /// sibling runtime's `design(arg)` into `n_row`.
    FlexEvaluation { ncols: usize },
}
975
/// Prediction-time view of a saved model: its class, likelihood, and the
/// optional runtime pieces a predictor may need.
#[derive(Clone, Debug)]
pub struct SavedPredictionRuntime {
    /// Which prediction path the model belongs to.
    pub model_class: PredictModelClass,
    /// Likelihood specification of the saved model.
    pub likelihood: LikelihoodSpec,
    /// Inverse link, when one is stored.
    pub inverse_link: Option<InverseLink>,
    /// Link-wiggle runtime, when one is stored.
    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
    /// Baseline time-wiggle runtime, when one is stored.
    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
    /// Score-warp flex block, when one is stored.
    pub score_warp: Option<SavedCompiledFlexBlock>,
    /// Link-deviation flex block, when one is stored.
    pub link_deviation: Option<SavedCompiledFlexBlock>,
    /// Rank-INT latent-z calibration carried into the predictor build.
    /// `None` for non-BMS models and for BMS fits whose latent measure
    /// did not require rank-INT calibration.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Conditional location-scale latent-z calibration (#905) carried into the
    /// predictor build. `None` for non-BMS models and for BMS fits whose Auto
    /// path did not detect a conditional `E[z|C]`/`Var(z|C)` shift. When
    /// `Some`, the predictor replaces the normalized `z` by `ζ = (z−m(C))/√v(C)`
    /// using the marginal prediction design as the conditioning span.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the
    /// survival marginal-slope fit hosted a dedicated additive absorber (the
    /// trailing block). `None` when no CTN Stage-1 chain produced an influence
    /// Jacobian. At predict the absorber's `γ` is DROPPED (the orthogonalized
    /// β̂ is a training-fit property), so the predictor uses this width only to
    /// (a) account for the extra trailing block in the saved block count and
    /// (b) slice `γ`'s rows/cols out of the joint covariance. Survival hosts the
    /// absorber as its own block (unlike the BMS A2 widened-marginal design),
    /// so it never widens any persisted prediction design.
    pub influence_absorber_width: Option<usize>,
}
1006
1007pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1008    fit.block_by_role(BlockRole::Location)
1009        .or_else(|| fit.block_by_role(BlockRole::Mean))
1010        .map(|block| block.beta.clone())
1011}
1012
1013pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1014    fit.block_by_role(BlockRole::Threshold)
1015        .or_else(|| fit.block_by_role(BlockRole::Location))
1016        .or_else(|| fit.block_by_role(BlockRole::Mean))
1017        .map(|block| block.beta.clone())
1018}
1019
1020pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021    fit.block_by_role(BlockRole::Scale)
1022        .map(|block| block.beta.clone())
1023}
1024
1025/// Whether a `ModelKind::LocationScale` likelihood's response is one of the
1026/// genuine-dispersion mean families (#913) — NegativeBinomial, Gamma, Beta or
1027/// Tweedie. These carry a `noise_formula` overdispersion channel and must be
1028/// predicted through the GLM mean inverse link (the
1029/// [`PredictModelClass::DispersionLocationScale`] path), NOT the binomial
1030/// threshold-scale predictor. The binomial location-scale (BMS ordinal) path is
1031/// the only other non-Gaussian location-scale family, with a `Binomial`
1032/// response.
1033fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1034    use gam_problem::types::ResponseFamily;
1035    matches!(
1036        response,
1037        ResponseFamily::NegativeBinomial { .. }
1038            | ResponseFamily::Gamma
1039            | ResponseFamily::Beta { .. }
1040            | ResponseFamily::Tweedie { .. }
1041    )
1042}
1043
1044fn validate_location_scale_saved_fit(
1045    fit: &UnifiedFitResult,
1046    model_class: PredictModelClass,
1047    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1048) -> Result<(), FittedModelError> {
1049    let primary = match model_class {
1050        // Gaussian and dispersion (#913) location-scale both predict the mean
1051        // through the Location block; the binomial threshold-scale class reads
1052        // the Threshold block instead.
1053        PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1054            gaussian_location_scale_mean_beta(fit)
1055        }
1056        PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1057        _ => None,
1058    }
1059    .ok_or_else(|| FittedModelError::MissingField {
1060        reason: match model_class {
1061            PredictModelClass::GaussianLocationScale => {
1062                "gaussian-location-scale saved fit is missing mean/location block".to_string()
1063            }
1064            PredictModelClass::DispersionLocationScale => {
1065                "dispersion-location-scale saved fit is missing mean/location block".to_string()
1066            }
1067            PredictModelClass::BinomialLocationScale => {
1068                "binomial-location-scale saved fit is missing threshold/location block".to_string()
1069            }
1070            _ => "location-scale saved fit is missing primary block".to_string(),
1071        },
1072    })?;
1073
1074    let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1075        reason: "location-scale saved fit is missing scale block".to_string(),
1076    })?;
1077    let expected =
1078        primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1079
1080    if let Some(cov) = fit.beta_covariance()
1081        && (cov.nrows() != expected || cov.ncols() != expected)
1082    {
1083        return Err(FittedModelError::SchemaMismatch {
1084            reason: format!(
1085                "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1086                cov.nrows(),
1087                cov.ncols(),
1088                expected,
1089                expected
1090            ),
1091        });
1092    }
1093    if let Some(cov) = fit.beta_covariance_corrected()
1094        && (cov.nrows() != expected || cov.ncols() != expected)
1095    {
1096        return Err(FittedModelError::SchemaMismatch {
1097            reason: format!(
1098                "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1099                cov.nrows(),
1100                cov.ncols(),
1101                expected,
1102                expected
1103            ),
1104        });
1105    }
1106    Ok(())
1107}
1108
1109fn validate_survival_saved_block_matches_payload(
1110    fit: &UnifiedFitResult,
1111    role: BlockRole,
1112    payload_beta: Option<&Vec<f64>>,
1113    label: &str,
1114) -> Result<usize, FittedModelError> {
1115    let block = fit
1116        .block_by_role(role)
1117        .ok_or_else(|| FittedModelError::MissingField {
1118            reason: format!("location-scale survival saved fit is missing {label} block"),
1119        })?;
1120    if let Some(saved) = payload_beta
1121        && block.beta.to_vec() != *saved
1122    {
1123        return Err(FittedModelError::SchemaMismatch {
1124            reason: format!(
1125                "location-scale survival saved {label} coefficients disagree with fit_result"
1126            ),
1127        });
1128    }
1129    Ok(block.beta.len())
1130}
1131
1132fn validate_survival_location_scale_saved_fit(
1133    payload: &FittedModelPayload,
1134    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1135) -> Result<(), FittedModelError> {
1136    let fit = payload
1137        .fit_result
1138        .as_ref()
1139        .ok_or_else(|| FittedModelError::MissingField {
1140            reason: "location-scale survival model is missing canonical fit_result payload"
1141                .to_string(),
1142        })?;
1143    let p_time = validate_survival_saved_block_matches_payload(
1144        fit,
1145        BlockRole::Time,
1146        payload.survival_beta_time.as_ref(),
1147        "time",
1148    )?;
1149    let p_threshold = validate_survival_saved_block_matches_payload(
1150        fit,
1151        BlockRole::Threshold,
1152        payload.survival_beta_threshold.as_ref(),
1153        "threshold",
1154    )?;
1155    let p_log_sigma = validate_survival_saved_block_matches_payload(
1156        fit,
1157        BlockRole::Scale,
1158        payload.survival_beta_log_sigma.as_ref(),
1159        "log-sigma",
1160    )?;
1161    let p_wiggle = match link_wiggle {
1162        Some(runtime) => {
1163            let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1164                FittedModelError::MissingField {
1165                    reason: "location-scale survival saved fit is missing link-wiggle block"
1166                        .to_string(),
1167                }
1168            })?;
1169            if block.beta.to_vec() != runtime.beta {
1170                return Err(FittedModelError::SchemaMismatch {
1171                    reason:
1172                        "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1173                            .to_string(),
1174                });
1175            }
1176            runtime.beta.len()
1177        }
1178        None => {
1179            if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1180                return Err(FittedModelError::SchemaMismatch {
1181                    reason:
1182                        "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1183                            .to_string(),
1184                });
1185            }
1186            0
1187        }
1188    };
1189    let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1190
1191    if let Some(cov) = fit.beta_covariance()
1192        && (cov.nrows() != expected || cov.ncols() != expected)
1193    {
1194        return Err(FittedModelError::SchemaMismatch {
1195            reason: format!(
1196                "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1197                cov.nrows(),
1198                cov.ncols(),
1199                expected,
1200                expected
1201            ),
1202        });
1203    }
1204    if let Some(cov) = fit.beta_covariance_corrected()
1205        && (cov.nrows() != expected || cov.ncols() != expected)
1206    {
1207        return Err(FittedModelError::SchemaMismatch {
1208            reason: format!(
1209                "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1210                cov.nrows(),
1211                cov.ncols(),
1212                expected,
1213                expected
1214            ),
1215        });
1216    }
1217    Ok(())
1218}
1219
1220fn validate_marginal_slope_saved_fit(
1221    fit: &UnifiedFitResult,
1222    score_warp: Option<&SavedCompiledFlexBlock>,
1223    link_deviation: Option<&SavedCompiledFlexBlock>,
1224    fit_label: &str,
1225) -> Result<(), FittedModelError> {
1226    validate_marginal_slope_saved_fit_impl(
1227        fit,
1228        score_warp,
1229        link_deviation,
1230        fit_label,
1231        "bernoulli",
1232        2,
1233        "marginal, logslope",
1234    )
1235}
1236
1237fn validate_survival_marginal_slope_saved_fit(
1238    fit: &UnifiedFitResult,
1239    score_warp: Option<&SavedCompiledFlexBlock>,
1240    link_deviation: Option<&SavedCompiledFlexBlock>,
1241    fit_label: &str,
1242) -> Result<(), FittedModelError> {
1243    validate_marginal_slope_saved_fit_impl(
1244        fit,
1245        score_warp,
1246        link_deviation,
1247        fit_label,
1248        "survival",
1249        3,
1250        "time, marginal, slope",
1251    )
1252}
1253
1254/// Shared block-count + coefficient-dimension validation for the bernoulli
1255/// and survival marginal-slope saved-fit gates. The only family-specific
1256/// inputs are the family kind string ("bernoulli" / "survival"), the base
1257/// block count (2 for bernoulli, 3 for survival — the survival path has an
1258/// extra time block), and the base-block role list rendered in the error
1259/// message ("marginal, logslope" / "time, marginal, slope"). The score-warp
1260/// / link-deviation tail follows the same shape in both families.
1261fn validate_marginal_slope_saved_fit_impl(
1262    fit: &UnifiedFitResult,
1263    score_warp: Option<&SavedCompiledFlexBlock>,
1264    link_deviation: Option<&SavedCompiledFlexBlock>,
1265    fit_label: &str,
1266    family_kind: &str,
1267    base_block_count: usize,
1268    base_block_role_list: &str,
1269) -> Result<(), FittedModelError> {
1270    let expected_blocks = base_block_count
1271        + usize::from(score_warp.is_some())
1272        + usize::from(link_deviation.is_some());
1273    if fit.blocks.len() != expected_blocks {
1274        let score_warp_suffix = if score_warp.is_some() {
1275            ", score-warp"
1276        } else {
1277            ""
1278        };
1279        let link_deviation_suffix = if link_deviation.is_some() {
1280            ", link-deviation"
1281        } else {
1282            ""
1283        };
1284        return Err(FittedModelError::SchemaMismatch {
1285            reason: format!(
1286                "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1287                fit.blocks.len(),
1288            ),
1289        });
1290    }
1291    if let Some(runtime) = score_warp {
1292        let beta = &fit.blocks[base_block_count].beta;
1293        if beta.len() != runtime.basis_dim {
1294            return Err(FittedModelError::SchemaMismatch {
1295                reason: format!(
1296                    "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1297                    beta.len(),
1298                    runtime.basis_dim
1299                ),
1300            });
1301        }
1302    }
1303    if let Some(runtime) = link_deviation {
1304        let idx = base_block_count + usize::from(score_warp.is_some());
1305        let beta = &fit.blocks[idx].beta;
1306        if beta.len() != runtime.basis_dim {
1307            return Err(FittedModelError::SchemaMismatch {
1308                reason: format!(
1309                    "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1310                    beta.len(),
1311                    runtime.basis_dim
1312                ),
1313            });
1314        }
1315    }
1316    Ok(())
1317}
1318
1319impl SavedLinkWiggleRuntime {
1320    fn validate_monotone_derivative(
1321        &self,
1322        q0: &Array1<f64>,
1323    ) -> Result<Array1<f64>, FittedModelError> {
1324        // Monotonicity is verified pointwise at the actual evaluation grid `q0`
1325        // (the predict η). The fit guarantees a strictly-increasing warped link
1326        // across the training η range (#1596); checking at `q0` here flags an
1327        // extrapolation point where the learnable link genuinely turns
1328        // non-invertible, without rejecting the whole model for a sign dip in the
1329        // basis tail far outside any data or evaluation point.
1330        let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1331        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1332        let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1333        if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1334            return Err(FittedModelError::PayloadCorrupt {
1335                reason: format!(
1336                    "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1337                ),
1338            });
1339        }
1340        Ok(dq_dq0)
1341    }
1342
1343    pub fn constrained_basis(
1344        &self,
1345        q0: &Array1<f64>,
1346        basis_options: BasisOptions,
1347    ) -> Result<Array2<f64>, FittedModelError> {
1348        let knot_arr = Array1::from_vec(self.knots.clone());
1349        let constrained = monotone_wiggle_basis_with_derivative_order(
1350            q0.view(),
1351            &knot_arr,
1352            self.degree,
1353            basis_options.derivative_order,
1354        )
1355        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1356        if constrained.ncols() != self.beta.len() {
1357            return Err(FittedModelError::SchemaMismatch {
1358                reason: format!(
1359                    "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1360                    self.beta.len(),
1361                    constrained.ncols()
1362                ),
1363            });
1364        }
1365        Ok(constrained)
1366    }
1367
1368    pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1369        self.validate_monotone_derivative(q0)?;
1370        self.constrained_basis(q0, BasisOptions::value())
1371    }
1372
1373    pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1374        let q = Array1::from_vec(vec![q0]);
1375        let x = self.design(&q)?;
1376        if x.nrows() != 1 {
1377            return Err(FittedModelError::SchemaMismatch {
1378                reason: format!(
1379                    "saved link-wiggle scalar evaluation expected 1 row, got {}",
1380                    x.nrows()
1381                ),
1382            });
1383        }
1384        Ok(x.row(0).to_owned())
1385    }
1386
1387    pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1388        self.validate_monotone_derivative(q0)?;
1389        let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1390        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1391        Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1392    }
1393
1394    pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395        self.validate_monotone_derivative(q0)
1396    }
1397}
1398
1399impl SavedBaselineTimeWiggleRuntime {
1400    pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1401        validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1402            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1403    }
1404}
1405
1406impl SavedCompiledFlexBlock {
1407    pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1408        if self.kernel.is_empty() {
1409            return Err(FittedModelError::SchemaMismatch {
1410                reason: "saved anchored deviation runtime is missing the exact kernel marker"
1411                    .to_string(),
1412            });
1413        }
1414        if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1415            return Err(FittedModelError::IncompatibleConfig {
1416                reason: format!(
1417                    "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1418                    self.kernel,
1419                    crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1420                ),
1421            });
1422        }
1423        if self.basis_dim == 0 {
1424            return Err(FittedModelError::SchemaMismatch {
1425                reason: format!(
1426                    "saved anchored deviation runtime basis_dim must be positive, got {}",
1427                    self.basis_dim
1428                ),
1429            });
1430        }
1431        if self.breakpoints.len() < 2 {
1432            return Err(FittedModelError::SchemaMismatch {
1433                reason: format!(
1434                    "saved anchored deviation runtime requires at least two breakpoints, got {}",
1435                    self.breakpoints.len()
1436                ),
1437            });
1438        }
1439        for window in self.breakpoints.windows(2) {
1440            let left = window[0];
1441            let right = window[1];
1442            if !left.is_finite() || !right.is_finite() || right <= left {
1443                return Err(FittedModelError::PayloadCorrupt {
1444                    reason: format!(
1445                        "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1446                    ),
1447                });
1448            }
1449        }
1450        let span_count = self.breakpoints.len() - 1;
1451        self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1452        self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1453        self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1454        self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1455        self.validate_c2_span_continuity()?;
1456        self.validate_anchor_residual_shape()?;
1457        Ok(())
1458    }
1459
1460    fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1461        let coeffs = match self.anchor_correction.as_ref() {
1462            Some(c) => c,
1463            None => {
1464                if !self.anchor_components.is_empty() {
1465                    return Err(FittedModelError::SchemaMismatch {
1466                        reason:
1467                            "saved anchored deviation runtime has anchor_components but no anchor_correction"
1468                                .to_string(),
1469                    });
1470                }
1471                return Ok(());
1472            }
1473        };
1474        let d: usize = self
1475            .anchor_components
1476            .iter()
1477            .map(|c| match &c.kind {
1478                SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1479                SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1480            })
1481            .sum();
1482        if coeffs.len() != d {
1483            return Err(FittedModelError::SchemaMismatch {
1484                reason: format!(
1485                    "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1486                    coeffs.len(),
1487                    d,
1488                ),
1489            });
1490        }
1491        for (i, row) in coeffs.iter().enumerate() {
1492            if row.len() != self.basis_dim {
1493                return Err(FittedModelError::SchemaMismatch {
1494                    reason: format!(
1495                        "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1496                        i,
1497                        row.len(),
1498                        self.basis_dim,
1499                    ),
1500                });
1501            }
1502            for (j, &v) in row.iter().enumerate() {
1503                if !v.is_finite() {
1504                    return Err(FittedModelError::PayloadCorrupt {
1505                        reason: format!(
1506                            "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1507                        ),
1508                    });
1509                }
1510            }
1511        }
1512        Ok(())
1513    }
1514
1515    fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1516        const TOL: f64 = 1e-8;
1517        for span_idx in 1..self.breakpoints.len() - 1 {
1518            let left_span = span_idx - 1;
1519            let right_span = span_idx;
1520            let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1521            for basis_idx in 0..self.basis_dim {
1522                let left_value = self.span_c0[left_span][basis_idx]
1523                    + self.span_c1[left_span][basis_idx] * width
1524                    + self.span_c2[left_span][basis_idx] * width * width
1525                    + self.span_c3[left_span][basis_idx] * width * width * width;
1526                let left_d1 = self.span_c1[left_span][basis_idx]
1527                    + 2.0 * self.span_c2[left_span][basis_idx] * width
1528                    + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1529                let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1530                    + 6.0 * self.span_c3[left_span][basis_idx] * width;
1531                let right_value = self.span_c0[right_span][basis_idx];
1532                let right_d1 = self.span_c1[right_span][basis_idx];
1533                let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1534                if (left_value - right_value).abs() > TOL
1535                    || (left_d1 - right_d1).abs() > TOL
1536                    || (left_d2 - right_d2).abs() > TOL
1537                {
1538                    return Err(FittedModelError::SchemaMismatch {
1539                        reason: format!(
1540                            "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1541                            left_value - right_value,
1542                            left_d1 - right_d1,
1543                            left_d2 - right_d2
1544                        ),
1545                    });
1546                }
1547            }
1548        }
1549        Ok(())
1550    }
1551
1552    fn validate_coefficient_matrix(
1553        &self,
1554        matrix: &[Vec<f64>],
1555        label: &str,
1556        expected_rows: usize,
1557    ) -> Result<(), FittedModelError> {
1558        if matrix.len() != expected_rows {
1559            return Err(FittedModelError::SchemaMismatch {
1560                reason: format!(
1561                    "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1562                    matrix.len(),
1563                    expected_rows
1564                ),
1565            });
1566        }
1567        for (row_idx, row) in matrix.iter().enumerate() {
1568            if row.len() != self.basis_dim {
1569                return Err(FittedModelError::SchemaMismatch {
1570                    reason: format!(
1571                        "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1572                        row_idx,
1573                        row.len(),
1574                        self.basis_dim
1575                    ),
1576                });
1577            }
1578            for (j, &value) in row.iter().enumerate() {
1579                if !value.is_finite() {
1580                    return Err(FittedModelError::PayloadCorrupt {
1581                        reason: format!(
1582                            "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1583                        ),
1584                    });
1585                }
1586            }
1587        }
1588        Ok(())
1589    }
1590
1591    fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1592        let last_span = self.breakpoints.len() - 2;
1593        let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1594        self.span_c0[last_span][basis_idx]
1595            + self.span_c1[last_span][basis_idx] * width
1596            + self.span_c2[last_span][basis_idx] * width * width
1597            + self.span_c3[last_span][basis_idx] * width * width * width
1598    }
1599
1600    fn evaluate_span_polynomial_design(
1601        &self,
1602        values: &Array1<f64>,
1603        derivative_order: usize,
1604    ) -> Result<Array2<f64>, FittedModelError> {
1605        self.validate_exact_replay_contract()?;
1606        let (left_ep, right_ep) = self.support_interval()?;
1607        let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1608        for (row_idx, &value) in values.iter().enumerate() {
1609            if !value.is_finite() {
1610                return Err(FittedModelError::PayloadCorrupt {
1611                    reason: format!(
1612                        "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1613                    ),
1614                });
1615            }
1616            if value < left_ep {
1617                if derivative_order == 0 {
1618                    for basis_idx in 0..self.basis_dim {
1619                        out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1620                    }
1621                }
1622                continue;
1623            }
1624            if value > right_ep {
1625                if derivative_order == 0 {
1626                    for basis_idx in 0..self.basis_dim {
1627                        out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1628                    }
1629                }
1630                continue;
1631            }
1632            let span_idx = self.left_biased_span_index_for(value)?;
1633            let t = value - self.breakpoints[span_idx];
1634            for basis_idx in 0..self.basis_dim {
1635                let c0 = self.span_c0[span_idx][basis_idx];
1636                let c1 = self.span_c1[span_idx][basis_idx];
1637                let c2 = self.span_c2[span_idx][basis_idx];
1638                let c3 = self.span_c3[span_idx][basis_idx];
1639                out[[row_idx, basis_idx]] = match derivative_order {
1640                    0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1641                    1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1642                    2 => 2.0 * c2 + 6.0 * c3 * t,
1643                    3 => 6.0 * c3,
1644                    4 => 0.0,
1645                    other => {
1646                        return Err(FittedModelError::IncompatibleConfig {
1647                            reason: format!(
1648                                "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1649                            ),
1650                        });
1651                    }
1652                };
1653            }
1654        }
1655        Ok(out)
1656    }
1657
1658    pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1659        self.validate_exact_replay_contract()?;
1660        Ok(self.breakpoints.clone())
1661    }
1662
1663    pub fn span_count(&self) -> Result<usize, FittedModelError> {
1664        Ok(self.breakpoints()?.windows(2).count())
1665    }
1666
1667    pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1668        let points = self.breakpoints()?;
1669        span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1670            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1671    }
1672
1673    fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1674        let mut span_idx = span_index_for_breakpoints(
1675            &self.breakpoints,
1676            value,
1677            "saved anchored deviation span lookup",
1678        )
1679        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1680        // LEFT-bias at interior breakpoints mirrors DeviationRuntime. The
1681        // saved cubic basis is C2, but d3 remains span-local.
1682        if span_idx > 0 && value == self.breakpoints[span_idx] {
1683            span_idx -= 1;
1684        }
1685        Ok(span_idx)
1686    }
1687
1688    pub fn local_cubic_on_span(
1689        &self,
1690        beta: &Array1<f64>,
1691        span_idx: usize,
1692    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1693        self.validate_exact_replay_contract()?;
1694        if beta.len() != self.basis_dim {
1695            return Err(FittedModelError::SchemaMismatch {
1696                reason: format!(
1697                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1698                    beta.len(),
1699                    self.basis_dim
1700                ),
1701            });
1702        }
1703        self.local_cubic_on_span_validated(beta, span_idx)
1704    }
1705
1706    fn local_cubic_on_span_validated(
1707        &self,
1708        beta: &Array1<f64>,
1709        span_idx: usize,
1710    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1711        let points = &self.breakpoints;
1712        if span_idx + 1 >= points.len() {
1713            return Err(FittedModelError::SchemaMismatch {
1714                reason: format!(
1715                    "saved anchored deviation span index {} out of range for {} spans",
1716                    span_idx,
1717                    points.len() - 1
1718                ),
1719            });
1720        }
1721        let left = points[span_idx];
1722        let right = points[span_idx + 1];
1723        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1724            left,
1725            right,
1726            c0: self.span_c0[span_idx]
1727                .iter()
1728                .zip(beta.iter())
1729                .map(|(coeff, weight)| coeff * weight)
1730                .sum(),
1731            c1: self.span_c1[span_idx]
1732                .iter()
1733                .zip(beta.iter())
1734                .map(|(coeff, weight)| coeff * weight)
1735                .sum(),
1736            c2: self.span_c2[span_idx]
1737                .iter()
1738                .zip(beta.iter())
1739                .map(|(coeff, weight)| coeff * weight)
1740                .sum(),
1741            c3: self.span_c3[span_idx]
1742                .iter()
1743                .zip(beta.iter())
1744                .map(|(coeff, weight)| coeff * weight)
1745                .sum(),
1746        })
1747    }
1748
1749    pub fn basis_span_cubic(
1750        &self,
1751        span_idx: usize,
1752        basis_idx: usize,
1753    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1754        self.validate_exact_replay_contract()?;
1755        if basis_idx >= self.basis_dim {
1756            return Err(FittedModelError::SchemaMismatch {
1757                reason: format!(
1758                    "saved anchored deviation basis index {} out of range for {} coefficients",
1759                    basis_idx, self.basis_dim
1760                ),
1761            });
1762        }
1763        self.basis_span_cubic_validated(span_idx, basis_idx)
1764    }
1765
1766    fn basis_span_cubic_validated(
1767        &self,
1768        span_idx: usize,
1769        basis_idx: usize,
1770    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1771        let points = &self.breakpoints;
1772        if span_idx + 1 >= points.len() {
1773            return Err(FittedModelError::SchemaMismatch {
1774                reason: format!(
1775                    "saved anchored deviation span index {} out of range for {} spans",
1776                    span_idx,
1777                    points.len() - 1
1778                ),
1779            });
1780        }
1781        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1782            left: points[span_idx],
1783            right: points[span_idx + 1],
1784            c0: self.span_c0[span_idx][basis_idx],
1785            c1: self.span_c1[span_idx][basis_idx],
1786            c2: self.span_c2[span_idx][basis_idx],
1787            c3: self.span_c3[span_idx][basis_idx],
1788        })
1789    }
1790
1791    pub fn basis_cubic_at(
1792        &self,
1793        basis_idx: usize,
1794        value: f64,
1795    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1796        self.validate_exact_replay_contract()?;
1797        if basis_idx >= self.basis_dim {
1798            return Err(FittedModelError::SchemaMismatch {
1799                reason: format!(
1800                    "saved anchored deviation basis index {} out of range for {} coefficients",
1801                    basis_idx, self.basis_dim
1802                ),
1803            });
1804        }
1805        let (left_ep, right_ep) = self.support_interval()?;
1806        if value < left_ep {
1807            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1808                left: left_ep,
1809                right: left_ep + 1.0,
1810                c0: self.span_c0[0][basis_idx],
1811                c1: 0.0,
1812                c2: 0.0,
1813                c3: 0.0,
1814            });
1815        }
1816        if value > right_ep {
1817            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1818                left: right_ep,
1819                right: right_ep + 1.0,
1820                c0: self.right_boundary_basis_value(basis_idx),
1821                c1: 0.0,
1822                c2: 0.0,
1823                c3: 0.0,
1824            });
1825        }
1826        let span_idx = self.left_biased_span_index_for(value)?;
1827        self.basis_span_cubic_validated(span_idx, basis_idx)
1828    }
1829
1830    pub fn local_cubic_at(
1831        &self,
1832        beta: &Array1<f64>,
1833        value: f64,
1834    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1835        self.validate_exact_replay_contract()?;
1836        if beta.len() != self.basis_dim {
1837            return Err(FittedModelError::SchemaMismatch {
1838                reason: format!(
1839                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1840                    beta.len(),
1841                    self.basis_dim
1842                ),
1843            });
1844        }
1845        let (left_ep, right_ep) = self.support_interval()?;
1846        if value < left_ep {
1847            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1848                left: left_ep,
1849                right: left_ep + 1.0,
1850                c0: self.span_c0[0]
1851                    .iter()
1852                    .zip(beta.iter())
1853                    .map(|(coeff, weight)| coeff * weight)
1854                    .sum(),
1855                c1: 0.0,
1856                c2: 0.0,
1857                c3: 0.0,
1858            });
1859        }
1860        if value > right_ep {
1861            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1862                left: right_ep,
1863                right: right_ep + 1.0,
1864                c0: (0..self.basis_dim)
1865                    .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1866                    .sum(),
1867                c1: 0.0,
1868                c2: 0.0,
1869                c3: 0.0,
1870            });
1871        }
1872        let span_idx = self.left_biased_span_index_for(value)?;
1873        self.local_cubic_on_span_validated(beta, span_idx)
1874    }
1875
1876    fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1877        let points = self.breakpoints()?;
1878        match (points.first(), points.last()) {
1879            (Some(&left), Some(&right)) => Ok((left, right)),
1880            _ => Err(FittedModelError::MissingField {
1881                reason: "saved anchored deviation runtime is missing support breakpoints"
1882                    .to_string(),
1883            }),
1884        }
1885    }
1886
1887    pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1888        // Note: when the saved runtime carries an anchor residual
1889        // (cross-block orthogonalisation), the value `design()` returns
1890        // is the raw cubic span output *without* the per-row `n_row · M`
1891        // subtraction. Callers used inside BMS prediction must either
1892        // switch to `design_with_anchor_rows` (when the per-row anchor
1893        // rows are available) or call `design_uncorrected` explicitly and
1894        // apply the subtraction at the call site. For runtimes without a
1895        // residual the two paths coincide.
1896        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1897    }
1898
1899    /// Raw cubic-span design without any anchor-residual subtraction.
1900    ///
1901    /// Exposed for callers that intend to apply the `n_row · M` correction
1902    /// post-hoc (e.g., BMS `link_terms_value_d1` subtracts a precomputed
1903    /// `correction.dot(beta)` scalar from the linear-predictor contribution
1904    /// rather than building a full anchor-row matrix). Equivalent to
1905    /// `design()` when no residual is present.
1906    pub fn design_uncorrected(
1907        &self,
1908        values: &Array1<f64>,
1909    ) -> Result<Array2<f64>, FittedModelError> {
1910        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1911    }
1912
1913    /// Evaluate the residual-corrected design at the supplied values.
1914    ///
1915    /// `anchor_rows` must be an `n × d` matrix where `n == values.len()`
1916    /// and `d == sum of anchor_components ncols`. Each row holds
1917    /// the concatenated parametric anchor design at the same prediction
1918    /// row as the corresponding `values[i]`. When the runtime has no
1919    /// anchor residual, `anchor_rows` must have zero columns (or be
1920    /// `Array2::zeros((n, 0))`).
1921    pub fn design_with_anchor_rows(
1922        &self,
1923        values: &Array1<f64>,
1924        anchor_rows: ndarray::ArrayView2<f64>,
1925    ) -> Result<Array2<f64>, FittedModelError> {
1926        let mut out =
1927            self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1928        if let Some(m_rows) = self.anchor_correction.as_ref() {
1929            let d = m_rows.len();
1930            if anchor_rows.nrows() != values.len() {
1931                return Err(FittedModelError::SchemaMismatch {
1932                    reason: format!(
1933                        "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1934                        anchor_rows.nrows(),
1935                        values.len(),
1936                    ),
1937                });
1938            }
1939            if anchor_rows.ncols() != d {
1940                return Err(FittedModelError::SchemaMismatch {
1941                    reason: format!(
1942                        "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1943                        anchor_rows.ncols(),
1944                        d,
1945                    ),
1946                });
1947            }
1948            // Materialise M (d × basis_dim) once.
1949            let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1950            for (i, row) in m_rows.iter().enumerate() {
1951                if row.len() != self.basis_dim {
1952                    return Err(FittedModelError::SchemaMismatch {
1953                        reason: format!(
1954                            "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1955                            i,
1956                            row.len(),
1957                            self.basis_dim,
1958                        ),
1959                    });
1960                }
1961                for (j, &v) in row.iter().enumerate() {
1962                    m_dense[[i, j]] = v;
1963                }
1964            }
1965            // The compiler bakes the orthonormalising rotation into M, so
1966            // the predict-time subtraction is simply `n_anchor_rows · M`.
1967            let subtract = anchor_rows.dot(&m_dense);
1968            out = out - subtract;
1969        } else if anchor_rows.ncols() != 0 {
1970            return Err(FittedModelError::SchemaMismatch {
1971                reason: format!(
1972                    "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1973                    anchor_rows.ncols(),
1974                ),
1975            });
1976        }
1977        Ok(out)
1978    }
1979
1980    /// Build the n × basis_dim per-row, per-basis correction matrix
1981    /// `N · M` for a batch of predict rows.
1982    ///
1983    /// `n_anchor_rows` is the n × d matrix of stacked parametric anchor
1984    /// rows at the prediction rows (concatenation of the marginal and
1985    /// logslope design rows in component order). Returns `None` when the
1986    /// runtime has no anchor residual (zero-cost path).
1987    pub fn anchor_correction_matrix(
1988        &self,
1989        n_anchor_rows: ndarray::ArrayView2<f64>,
1990    ) -> Result<Option<Array2<f64>>, FittedModelError> {
1991        let Some(m_rows) = self.anchor_correction.as_ref() else {
1992            return Ok(None);
1993        };
1994        let d = m_rows.len();
1995        if n_anchor_rows.ncols() != d {
1996            return Err(FittedModelError::SchemaMismatch {
1997                reason: format!(
1998                    "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
1999                    n_anchor_rows.ncols(),
2000                    d,
2001                ),
2002            });
2003        }
2004        let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2005        for (i, row) in m_rows.iter().enumerate() {
2006            if row.len() != self.basis_dim {
2007                return Err(FittedModelError::SchemaMismatch {
2008                    reason: format!(
2009                        "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2010                        i,
2011                        row.len(),
2012                        self.basis_dim,
2013                    ),
2014                });
2015            }
2016            for (j, &v) in row.iter().enumerate() {
2017                m_dense[[i, j]] = v;
2018            }
2019        }
2020        // The compiler bakes the orthonormalising rotation into M, so
2021        // the predict-time correction is simply `n_anchor_rows · M`.
2022        Ok(Some(n_anchor_rows.dot(&m_dense)))
2023    }
2024
2025    pub fn first_derivative_design(
2026        &self,
2027        values: &Array1<f64>,
2028    ) -> Result<Array2<f64>, FittedModelError> {
2029        self.evaluate_span_polynomial_design(
2030            values,
2031            BasisOptions::first_derivative().derivative_order,
2032        )
2033    }
2034
2035    pub fn second_derivative_design(
2036        &self,
2037        values: &Array1<f64>,
2038    ) -> Result<Array2<f64>, FittedModelError> {
2039        self.evaluate_span_polynomial_design(
2040            values,
2041            BasisOptions::second_derivative().derivative_order,
2042        )
2043    }
2044}
2045
2046impl FittedFamily {
2047    #[inline]
2048    pub fn likelihood(&self) -> LikelihoodSpec {
2049        let spec = match self {
2050            Self::Standard { likelihood, .. }
2051            | Self::LocationScale { likelihood, .. }
2052            | Self::MarginalSlope { likelihood, .. }
2053            | Self::Survival { likelihood, .. }
2054            | Self::TransformationNormal { likelihood, .. } => likelihood,
2055            Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2056                return LikelihoodSpec::royston_parmar();
2057            }
2058        };
2059        spec.clone()
2060    }
2061
2062    #[inline]
2063    pub fn frailty(&self) -> Option<&FrailtySpec> {
2064        match self {
2065            Self::MarginalSlope { frailty, .. }
2066            | Self::Survival { frailty, .. }
2067            | Self::LatentSurvival { frailty }
2068            | Self::LatentBinary { frailty } => Some(frailty),
2069            _ => None,
2070        }
2071    }
2072}
2073
2074/// Recursively collect the feature columns of a smooth basis whose out-of-hull
2075/// evaluation is bounded, so they can be exempted from the predict-time axis
2076/// clip (see [`FittedModel::training_smooth_extrapolation_axes`]). Wrapper bases
2077/// (`by=`, factor-smooth, sum-to-zero) delegate to their inner smooth; `Sphere`
2078/// and `Pca` are intentionally not collected.
2079fn collect_smooth_extrapolation_axes(
2080    basis: &gam_terms::smooth::SmoothBasisSpec,
2081    n_training_headers: usize,
2082    out: &mut std::collections::HashSet<usize>,
2083) {
2084    use gam_terms::smooth::SmoothBasisSpec;
2085    let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2086        if col < n_training_headers {
2087            out.insert(col);
2088        }
2089    };
2090    match basis {
2091        // 1D B-spline: first-order linear extension off the boundary slope.
2092        SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2093        // Tensor B-spline: each margin linear-extends independently. Periodic
2094        // margins are additionally (and harmlessly) exempted via the periodic set.
2095        SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2096            for &c in feature_cols {
2097                push(c, out);
2098            }
2099        }
2100        // Radial bases with a bounded out-of-hull contract: Duchon / thin-plate
2101        // are linear outside the data span (natural-spline boundary conditions),
2102        // Matérn reverts to its mean as the kernel decays. Measure-jet shares
2103        // the Matérn contract (Gaussian representers decay to the parametric
2104        // layer off the data support) — and off-web queries are exactly the
2105        // ones its support diagnostic must see unclipped.
2106        SmoothBasisSpec::ThinPlate { feature_cols, .. }
2107        | SmoothBasisSpec::Matern { feature_cols, .. }
2108        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2109        | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2110            for &c in feature_cols {
2111                push(c, out);
2112            }
2113        }
2114        // Factor-smooth: the continuous marginal axes are B-splines that
2115        // linear-extend; the group column is categorical and is left to the
2116        // random-effect / level-lookup machinery.
2117        SmoothBasisSpec::FactorSmooth { spec } => {
2118            for &c in &spec.continuous_cols {
2119                push(c, out);
2120            }
2121        }
2122        // Wrappers delegate to the inner smooth they modulate / replicate.
2123        SmoothBasisSpec::ByVariable { inner, .. }
2124        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2125            collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2126        }
2127        SmoothBasisSpec::BySmooth { smooth, .. } => {
2128            collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2129        }
2130        // Sphere: latitude is clipped to manifold bounds, longitude is periodic —
2131        // both handled elsewhere with non-plateau semantics. Pca: no extrapolation
2132        // contract, stays clipped. ConstantCurvature: chart coordinates must stay
2133        // inside the κ-stereographic chart (open ball for κ < 0), so clipping new
2134        // data to the training range is the safe out-of-hull behavior.
2135        SmoothBasisSpec::Sphere { .. }
2136        | SmoothBasisSpec::ConstantCurvature { .. }
2137        | SmoothBasisSpec::Pca { .. } => {}
2138    }
2139}
2140
2141/// Collect training-column indices that feed a *numeric* `by=` multiplier of a
2142/// varying-coefficient smooth (`s(x, by=z)`).
2143///
2144/// A numeric by-variable enters the model as a linear multiplier of the centred
2145/// smooth basis, so the prediction is exactly affine in `z` for any fixed `x`:
2146/// `pred(x, z) = intercept + z·f(x)`. The by-multiplier is therefore an
2147/// *unbounded linear* extrapolant — exactly like a parametric `linear` axis
2148/// (already exempt from the clip via `training_linear_axes`), not a
2149/// bounded-basis axis. Clamping it to the training range `[min(z), max(z)]`
2150/// before the design is built would make the varying-coefficient effect plateau
2151/// outside the sampled range (slope 0 above the max) and silently replace the
2152/// natural `z == 0` baseline with the `z == min` prediction — the prediction is
2153/// no longer affine in `z`. Exempt it from the predict-time axis clip.
2154///
2155/// Factor `by=` columns are categorical group labels handled by the
2156/// level-lookup / random-effect machinery and are deliberately *not* exempted
2157/// here. Returned indices reference `self.training_headers`, matching the
2158/// iteration in `axis_clip_to_training_ranges`.
2159fn collect_by_variable_numeric_axes(
2160    basis: &gam_terms::smooth::SmoothBasisSpec,
2161    n_training_headers: usize,
2162    out: &mut std::collections::HashSet<usize>,
2163) {
2164    use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2165    match basis {
2166        SmoothBasisSpec::ByVariable {
2167            inner,
2168            by_col,
2169            kind,
2170            ..
2171        } => {
2172            if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2173                out.insert(*by_col);
2174            }
2175            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2176        }
2177        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2178            if let ByVarKind::Numeric { feature_col } = by_kind
2179                && *feature_col < n_training_headers
2180            {
2181                out.insert(*feature_col);
2182            }
2183            collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2184        }
2185        SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2186            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2187        }
2188        _ => {}
2189    }
2190}
2191
2192impl FittedModel {
2193    /// Axis-clip each continuous new-data column to the (min, max) range
2194    /// observed in training. Categorical and binary columns are left
2195    /// untouched so unseen levels surface rather than being silently remapped
2196    /// onto seen ones. Returns `Some(clipped_copy)` only if at least one
2197    /// value was actually clipped; otherwise `None` so callers can avoid
2198    /// owning a redundant copy. Pre-2026-04-29 model JSONs that lack the
2199    /// `training_feature_ranges` field deserialize to `None` and pass through
2200    /// unchanged.
2201    pub fn axis_clip_to_training_ranges(
2202        &self,
2203        data: ndarray::ArrayView2<'_, f64>,
2204        col_map: &std::collections::HashMap<String, usize>,
2205    ) -> Option<ndarray::Array2<f64>> {
2206        let training_headers = self.training_headers.as_ref()?;
2207        let ranges = self.training_feature_ranges.as_ref()?;
2208        if training_headers.len() != ranges.len() {
2209            return None;
2210        }
2211        let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2212            std::collections::HashMap::new();
2213        if let Some(schema) = self.data_schema.as_ref() {
2214            for col in &schema.columns {
2215                kind_by_header.insert(col.name.as_str(), col.kind);
2216            }
2217        }
2218        // Periodic axes (sphere longitude, periodic-B-spline 1D, periodic
2219        // tensor margins) must never be clipped to the training range:
2220        // clamping a value just past the seam to the training extreme breaks
2221        // the cyclic invariant f(x₀) = f(x₀ + period) at predict time and
2222        // shows up as a visible seam in surface plots.
2223        let periodic_axes = self.training_periodic_axes(training_headers);
2224        // Parametric/linear-term axes must never be clipped either: a linear
2225        // term's contract is η = β0 + β1·x, i.e. genuine linear extrapolation.
2226        // Clamping its input to the training extreme turns predict into a
2227        // piecewise-constant plateau outside the training hull and freezes the
2228        // prediction SE at the boundary (the clamped x feeds xᵀ Var(β) x), so
2229        // credible intervals stop widening with distance from the data. This
2230        // mirrors how periodic axes are exempted just above.
2231        let linear_axes = self.training_linear_axes(training_headers.len());
2232        // Random-effect grouping axes are categorical even when their source
2233        // column is numeric. Clipping them would remap an unseen group label to
2234        // a boundary training level instead of letting the random-effect block
2235        // encode it as the prior-mean zero effect.
2236        let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2237        // Non-parametric smooth axes whose basis extrapolates boundedly on its
2238        // own (B-spline linear extension, Duchon/thin-plate natural-spline linear
2239        // tail, Matérn kernel decay). Clamping their input to the training extreme
2240        // hands the basis an already-clamped coordinate, so its extrapolation
2241        // never fires and predict freezes at a boundary plateau — diverging from
2242        // the raw design path, which does not clip. Exempt them so both paths go
2243        // through the single basis-layer extrapolation. See the method doc.
2244        let smooth_extrapolation_axes =
2245            self.training_smooth_extrapolation_axes(training_headers.len());
2246        // Numeric `by=` multipliers of a varying-coefficient smooth `s(x, by=z)`
2247        // are linear multipliers of the centred basis (prediction affine in z),
2248        // so — like parametric linear axes — clipping them to the training range
2249        // turns the varying-coefficient effect into a boundary plateau and
2250        // destroys the z==0 baseline. Exempt them from the clip.
2251        let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2252        // Sphere latitude is a closed-manifold coordinate: its clip bounds are
2253        // the manifold's intrinsic domain ([-π/2, π/2] or [-90, 90]), not the
2254        // sampled range, so a pole prediction reaches the true pole instead of
2255        // being clamped to a near-pole latitude (see the method doc).
2256        let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2257        let mut clipped = data.to_owned();
2258        let mut any_clipped = false;
2259        for (col_in_training, (header, &(lo, hi))) in
2260            training_headers.iter().zip(ranges.iter()).enumerate()
2261        {
2262            let (lo, hi) = sphere_lat_bounds
2263                .get(&col_in_training)
2264                .copied()
2265                .unwrap_or((lo, hi));
2266            if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2267                continue;
2268            }
2269            if !matches!(
2270                kind_by_header.get(header.as_str()).copied(),
2271                Some(ColumnKindTag::Continuous)
2272            ) {
2273                continue;
2274            }
2275            if periodic_axes.contains(&col_in_training) {
2276                continue;
2277            }
2278            if linear_axes.contains(&col_in_training) {
2279                continue;
2280            }
2281            if random_effect_axes.contains(&col_in_training) {
2282                continue;
2283            }
2284            if smooth_extrapolation_axes.contains(&col_in_training) {
2285                continue;
2286            }
2287            if by_variable_axes.contains(&col_in_training) {
2288                continue;
2289            }
2290            let Some(&col_idx) = col_map.get(header) else {
2291                continue;
2292            };
2293            if col_idx >= clipped.ncols() {
2294                continue;
2295            }
2296            let mut col = clipped.column_mut(col_idx);
2297            for v in col.iter_mut() {
2298                if v.is_finite() {
2299                    if *v < lo {
2300                        *v = lo;
2301                        any_clipped = true;
2302                    } else if *v > hi {
2303                        *v = hi;
2304                        any_clipped = true;
2305                    }
2306                }
2307            }
2308        }
2309        if any_clipped { Some(clipped) } else { None }
2310    }
2311
2312    fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2313        let mut specs: Vec<&TermCollectionSpec> = [
2314            self.resolved_termspec.as_ref(),
2315            self.resolved_termspec_noise.as_ref(),
2316            self.resolved_termspec_logslope.as_ref(),
2317        ]
2318        .into_iter()
2319        .flatten()
2320        .collect();
2321        if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2322            specs.extend(logslopes.iter());
2323        }
2324        specs
2325    }
2326
2327    /// Collect the set of training-column indices that are periodic axes —
2328    /// i.e. features for which a periodic basis (sphere longitude, periodic
2329    /// B-spline 1D, periodic tensor margin) must be allowed to take any
2330    /// real value at predict time and not be clamped to the training range.
2331    /// Returned indices reference `self.training_headers` (training-time
2332    /// layout), matching the iteration in `axis_clip_to_training_ranges`.
2333    fn training_periodic_axes(
2334        &self,
2335        training_headers: &[String],
2336    ) -> std::collections::HashSet<usize> {
2337        use gam_terms::basis::BSplineKnotSpec;
2338        use gam_terms::smooth::SmoothBasisSpec;
2339        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2340        let Some(spec) = self.resolved_termspec.as_ref() else {
2341            return out;
2342        };
2343        for term in &spec.smooth_terms {
2344            match &term.basis {
2345                // Sphere terms: longitude (second feature col) is always
2346                // periodic and exempt from clipping. Latitude is not periodic
2347                // but is a closed-manifold coordinate, so it is clipped to the
2348                // manifold's intrinsic bounds rather than the sampled range —
2349                // see `training_sphere_latitude_bounds`.
2350                SmoothBasisSpec::Sphere { feature_cols, .. } => {
2351                    if let Some(&lon_col) = feature_cols.get(1)
2352                        && lon_col < training_headers.len()
2353                    {
2354                        out.insert(lon_col);
2355                    }
2356                }
2357                // 1D periodic B-spline: the single feature column is periodic.
2358                SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2359                    if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2360                        && *feature_col < training_headers.len()
2361                    {
2362                        out.insert(*feature_col);
2363                    }
2364                }
2365                // Tensor B-spline: each axis whose marginal knotspec is
2366                // PeriodicUniform is periodic; mark those columns.
2367                SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2368                    for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2369                        if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2370                            && let Some(&col) = feature_cols.get(i)
2371                            && col < training_headers.len()
2372                        {
2373                            out.insert(col);
2374                        }
2375                    }
2376                }
2377                _ => {}
2378            }
2379        }
2380        out
2381    }
2382
2383    /// Collect the set of training-column indices that feed a parametric/linear
2384    /// term — on *any* modelled surface (mean, noise/scale, log-slope). A
2385    /// linear term realises the design column `∏ feature_cols` and contributes
2386    /// `β·(that product)` to the linear predictor, so its inputs must be allowed
2387    /// to take any real value at predict time: clamping them to the training
2388    /// range would replace genuine linear extrapolation with a boundary plateau
2389    /// (and freeze the prediction SE at the hull edge). Returned indices
2390    /// reference `self.training_headers` (training-time layout), matching the
2391    /// iteration in `axis_clip_to_training_ranges`. Wilkinson-Rogers `:`
2392    /// interactions contribute every column in their `feature_cols` product.
2393    fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2394        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2395        for spec in self.saved_term_specs() {
2396            for term in &spec.linear_terms {
2397                for col in term.effective_feature_cols() {
2398                    if col < n_training_headers {
2399                        out.insert(col);
2400                    }
2401                }
2402            }
2403        }
2404        out
2405    }
2406
2407    /// Collect the set of training-column indices that feed random-effect
2408    /// grouping terms. These columns are categorical model axes regardless of
2409    /// the ingest schema's scalar storage type, so prediction must leave them
2410    /// untouched and let the frozen random-effect levels decide whether a row is
2411    /// a seen level or the zero-effect unseen-level fallback.
2412    fn training_random_effect_axes(
2413        &self,
2414        n_training_headers: usize,
2415    ) -> std::collections::HashSet<usize> {
2416        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2417        for spec in self.saved_term_specs() {
2418            for term in &spec.random_effect_terms {
2419                if term.feature_col < n_training_headers {
2420                    out.insert(term.feature_col);
2421                }
2422            }
2423        }
2424        out
2425    }
2426
2427    /// Collect the set of training-column indices that feed a non-parametric
2428    /// smooth whose basis performs its own *bounded* extrapolation outside the
2429    /// training hull — on any modelled surface (mean, noise/scale, log-slope).
2430    ///
2431    /// These columns must be exempt from the predict-time axis clip for the same
2432    /// reason periodic/linear/random-effect axes are: the clip clamps a new value
2433    /// to the training extreme *before* the design is built, so the basis is
2434    /// handed an already-clamped coordinate and its extrapolation machinery never
2435    /// fires. The result is a piecewise-constant plateau frozen at the boundary
2436    /// fitted value (with a prediction SE frozen at the hull edge), and — worse —
2437    /// a model that yields *different* predictions through the `FittedModel`
2438    /// predict pipeline than through the raw `build_term_collection_design` path,
2439    /// which does not clip. Exempting these axes routes both entry points through
2440    /// the single basis-layer extrapolation, restoring internal consistency.
2441    ///
2442    /// Only bases with a *bounded* out-of-hull contract are listed, so removing
2443    /// the clip cannot reintroduce the wild basis blow-up the clip guards against:
2444    ///   - B-spline 1D / tensor margins: first-order linear extension off the
2445    ///     boundary slope (`apply_linear_extension_from_first_derivative`) — grows
2446    ///     at most linearly.
2447    ///   - Duchon / thin-plate: the natural-spline boundary conditions make the
2448    ///     fit linear outside the data span — also at most linear growth.
2449    ///   - Matérn: the kernel decays with distance, so the fit reverts smoothly to
2450    ///     its (low-order polynomial / constant) mean far from the data — bounded.
2451    /// `sphere()` axes are deliberately *not* listed: latitude is a closed-manifold
2452    /// coordinate clipped to its intrinsic bounds (`training_sphere_latitude_bounds`)
2453    /// and longitude is periodic (`training_periodic_axes`); both already have the
2454    /// correct, non-plateau handling. `Pca` projections have no extrapolation
2455    /// contract and stay clipped.
2456    ///
2457    /// Returned indices reference `self.training_headers` (training-time layout),
2458    /// matching the iteration in `axis_clip_to_training_ranges`.
2459    fn training_smooth_extrapolation_axes(
2460        &self,
2461        n_training_headers: usize,
2462    ) -> std::collections::HashSet<usize> {
2463        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2464        for spec in self.saved_term_specs() {
2465            for term in &spec.smooth_terms {
2466                collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2467            }
2468        }
2469        out
2470    }
2471
2472    /// Collect the set of training-column indices that feed a *numeric* `by=`
2473    /// multiplier of a varying-coefficient smooth. These columns are linear
2474    /// multipliers (`pred = intercept + z·f(x)`), so they must be exempt from
2475    /// the predict-time axis clip for the same reason parametric linear axes
2476    /// are — see `collect_by_variable_numeric_axes` for the full rationale.
2477    fn training_by_variable_numeric_axes(
2478        &self,
2479        n_training_headers: usize,
2480    ) -> std::collections::HashSet<usize> {
2481        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2482        for spec in self.saved_term_specs() {
2483            for term in &spec.smooth_terms {
2484                collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2485            }
2486        }
2487        out
2488    }
2489
2490    /// Manifold-intrinsic clip bounds for sphere *latitude* columns.
2491    ///
2492    /// A `sphere(lat, lon)` smooth charts the closed manifold S²: the poles
2493    /// (lat = ±π/2, or ±90°) are interior limit points of that manifold, not
2494    /// endpoints of an unbounded axis to extrapolate along. A finite training
2495    /// sample never reaches the pole exactly, so clamping a pole prediction to
2496    /// the *observed* latitude extreme lands at a near-pole latitude where the
2497    /// Wahba SOS kernel's `cos(lat)·cos(lat_c)·cos(Δlon)` term has not yet
2498    /// damped — and the predictor then sweeps a spurious `cos(lon)` profile at
2499    /// what is physically a single point, reintroducing the pole artefact the
2500    /// SOS basis exists to remove.
2501    ///
2502    /// The correct clip bound for this coordinate is therefore the manifold's
2503    /// intrinsic domain — `[-π/2, π/2]` radians or `[-90, 90]` degrees — not
2504    /// the sampled range. Clamping to those bounds keeps the pole reachable
2505    /// (single-valued in longitude) while still mapping any out-of-domain
2506    /// latitude onto the manifold boundary. Longitude needs no entry here: it
2507    /// is periodic and already exempted from clipping entirely
2508    /// (`training_periodic_axes`). Returned indices reference
2509    /// `self.training_headers`, matching the iteration in
2510    /// `axis_clip_to_training_ranges`.
2511    fn training_sphere_latitude_bounds(
2512        &self,
2513        training_headers: &[String],
2514    ) -> std::collections::HashMap<usize, (f64, f64)> {
2515        use gam_terms::smooth::SmoothBasisSpec;
2516        let mut out: std::collections::HashMap<usize, (f64, f64)> =
2517            std::collections::HashMap::new();
2518        let Some(spec) = self.resolved_termspec.as_ref() else {
2519            return out;
2520        };
2521        for term in &spec.smooth_terms {
2522            if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2523                && let Some(&lat_col) = feature_cols.first()
2524                && lat_col < training_headers.len()
2525            {
2526                let bound = if spec.radians {
2527                    std::f64::consts::FRAC_PI_2
2528                } else {
2529                    90.0
2530                };
2531                out.insert(lat_col, (-bound, bound));
2532            }
2533        }
2534        out
2535    }
2536
2537    pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2538        let likelihood = payload.family_state.likelihood();
2539        let class = match payload.model_kind {
2540            ModelKind::Survival => PredictModelClass::Survival,
2541            ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2542            ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2543            ModelKind::LocationScale => {
2544                if likelihood == LikelihoodSpec::gaussian_identity() {
2545                    PredictModelClass::GaussianLocationScale
2546                } else if is_dispersion_location_scale_response(&likelihood.response) {
2547                    PredictModelClass::DispersionLocationScale
2548                } else {
2549                    PredictModelClass::BinomialLocationScale
2550                }
2551            }
2552            ModelKind::Standard => PredictModelClass::Standard,
2553        };
2554        match class {
2555            PredictModelClass::Survival => {
2556                payload.model_kind = ModelKind::Survival;
2557                Self::Survival { payload }
2558            }
2559            PredictModelClass::BernoulliMarginalSlope => {
2560                payload.model_kind = ModelKind::MarginalSlope;
2561                Self::MarginalSlope { payload }
2562            }
2563            PredictModelClass::TransformationNormal => {
2564                payload.model_kind = ModelKind::TransformationNormal;
2565                Self::TransformationNormal { payload }
2566            }
2567            PredictModelClass::GaussianLocationScale
2568            | PredictModelClass::BinomialLocationScale
2569            | PredictModelClass::DispersionLocationScale => {
2570                payload.model_kind = ModelKind::LocationScale;
2571                Self::LocationScale { payload }
2572            }
2573            PredictModelClass::Standard => {
2574                payload.model_kind = ModelKind::Standard;
2575                Self::Standard { payload }
2576            }
2577        }
2578        .with_synchronized_stateful_link_metadata()
2579    }
2580
2581    #[inline]
2582    pub fn payload(&self) -> &FittedModelPayload {
2583        match self {
2584            Self::Standard { payload }
2585            | Self::LocationScale { payload }
2586            | Self::MarginalSlope { payload }
2587            | Self::Survival { payload }
2588            | Self::TransformationNormal { payload } => payload,
2589        }
2590    }
2591
2592    #[inline]
2593    fn payload_mut(&mut self) -> &mut FittedModelPayload {
2594        match self {
2595            Self::Standard { payload }
2596            | Self::LocationScale { payload }
2597            | Self::MarginalSlope { payload }
2598            | Self::Survival { payload }
2599            | Self::TransformationNormal { payload } => payload,
2600        }
2601    }
2602
2603    fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2604        self.synchronize_stateful_link_metadata();
2605        self
2606    }
2607
2608    fn synchronize_stateful_link_metadata(&mut self) {
2609        let payload = self.payload_mut();
2610        payload.used_device = payload
2611            .fit_result
2612            .as_ref()
2613            .or(payload.unified.as_ref())
2614            .is_some_and(|fit| fit.used_device);
2615        payload.synchronize_empty_feature_contract();
2616        let Some(fit) = payload.fit_result.as_ref() else {
2617            return;
2618        };
2619        match (&mut payload.family_state, &fit.fitted_link) {
2620            (
2621                FittedFamily::Standard {
2622                    likelihood,
2623                    latent_cloglog_state,
2624                    ..
2625                },
2626                FittedLinkState::LatentCLogLog { state },
2627            ) if likelihood.is_latent_cloglog() => {
2628                *latent_cloglog_state = Some(*state);
2629            }
2630            (
2631                FittedFamily::Standard {
2632                    likelihood,
2633                    sas_state,
2634                    ..
2635                },
2636                FittedLinkState::Sas { state, covariance },
2637            ) if likelihood.is_binomial_sas() => {
2638                *sas_state = Some(*state);
2639                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
2640            }
2641            (
2642                FittedFamily::Standard {
2643                    likelihood,
2644                    sas_state,
2645                    ..
2646                },
2647                FittedLinkState::BetaLogistic { state, covariance },
2648            ) if likelihood.is_binomial_beta_logistic() => {
2649                *sas_state = Some(*state);
2650                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
2651            }
2652            (
2653                FittedFamily::Standard {
2654                    likelihood,
2655                    mixture_state,
2656                    ..
2657                },
2658                FittedLinkState::Mixture { state, covariance },
2659            ) if likelihood.is_binomial_mixture() => {
2660                *mixture_state = Some(state.clone());
2661                payload.mixture_link_param_covariance =
2662                    covariance.as_ref().map(array2_to_nestedvec);
2663            }
2664            _ => {}
2665        }
2666    }
2667
2668    #[inline]
2669    pub fn likelihood(&self) -> LikelihoodSpec {
2670        self.payload().family_state.likelihood()
2671    }
2672
2673    /// Columns this model consumes from a prediction frame — its *input
2674    /// contract*.
2675    ///
2676    /// Every variable named by the main formula (features, interaction margins,
2677    /// random-effect groups, and a smooth's `by=` column), the survival
2678    /// entry/exit columns or the transformation-normal response, the auxiliary
2679    /// noise / logslope formula columns, and the offset / noise-offset /
2680    /// latent-`z` columns. The event-indicator and the plain response of a
2681    /// standard model are deliberately excluded: they are not needed to *form*
2682    /// a prediction (the conformal-calibration fold layers the response back on
2683    /// separately).
2684    ///
2685    /// This is the single authority shared by the CLI and PyFFI predict paths.
2686    /// A prediction frame column that is *not* in this set is irrelevant to the
2687    /// model and must be ignored rather than strict-encoded against the
2688    /// training schema — otherwise an unrelated ID/label column with a held-out
2689    /// categorical level aborts predict (#840).
2690    pub fn prediction_required_columns(
2691        &self,
2692    ) -> Result<std::collections::BTreeSet<String>, String> {
2693        let payload = self.payload();
2694        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2695        let mut required = std::collections::BTreeSet::<String>::new();
2696        parsed_term_column_names(&parsed.terms, &mut required);
2697
2698        if let Some((entry, exit, _event)) =
2699            parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2700        {
2701            if let Some(entry) = entry {
2702                required.insert(entry);
2703            }
2704            required.insert(exit);
2705        } else if let Some((left, right, _event)) =
2706            parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707        {
2708            required.insert(left);
2709            required.insert(right);
2710        }
2711        // A transformation-normal (CTM) prediction returns the response-scale
2712        // conditional mean E[Y|x], a function of the covariates alone (issue
2713        // #1612). The earlier implementation precomputed the PIT h(y|x) of the
2714        // supplied response, which made the outcome column mandatory at predict
2715        // time; the response is no longer required, so a covariate-only frame
2716        // must predict without it.
2717
2718        if let Some(offset) = payload.offset_column.as_ref() {
2719            required.insert(offset.clone());
2720        }
2721        if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2722            required.insert(noise_offset.clone());
2723        }
2724        if matches!(
2725            self.predict_model_class(),
2726            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2727        ) {
2728            if let Some(z_column) = payload.z_column.as_ref() {
2729                required.remove("z");
2730                required.insert(z_column.clone());
2731            }
2732        }
2733        if let Some(noise_formula) = payload.formula_noise.as_ref() {
2734            self.add_auxiliary_formula_columns(
2735                &mut required,
2736                noise_formula,
2737                parsed.response.as_str(),
2738            )?;
2739        }
2740        if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2741            if logslope_formula != "same-as-main" {
2742                self.add_auxiliary_formula_columns(
2743                    &mut required,
2744                    logslope_formula,
2745                    parsed.response.as_str(),
2746                )?;
2747            }
2748        }
2749        Ok(required)
2750    }
2751
2752    /// Columns a *post-fit diagnostic* command (diagnose / sample / report)
2753    /// needs **beyond** [`Self::prediction_required_columns`].
2754    ///
2755    /// Prediction deliberately drops a standard GAM's bare response so a
2756    /// prediction frame may omit it (#840 / #864). Diagnostics are statements
2757    /// *about* that observed response — residuals, R², posterior likelihoods,
2758    /// leave-one-out — so the response must be present. This returns the bare
2759    /// response column when the prediction projection would otherwise drop it,
2760    /// and nothing when the response is already prediction-required (survival
2761    /// `Surv(...)` time/event columns, the transformation-normal response) or
2762    /// is not a plain data column.
2763    ///
2764    /// Centralising the intent here is what makes it *structurally impossible*
2765    /// for a diagnostic command to silently drop the response: callers use
2766    /// `load_dataset…_for_diagnostics`, which always folds these in, instead of
2767    /// each remembering to thread an `extra_required` response by hand.
2768    pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2769        let payload = self.payload();
2770        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2771        // Survival responses are `Surv(...)` expressions, not bare columns; the
2772        // underlying entry/exit columns are already prediction-required.
2773        if parse_surv_response(parsed.response.as_str())
2774            .map_err(|e| e.to_string())?
2775            .is_some()
2776            || parse_surv_interval_response(parsed.response.as_str())
2777                .map_err(|e| e.to_string())?
2778                .is_some()
2779        {
2780            return Ok(Vec::new());
2781        }
2782        let response = parsed.response.trim();
2783        // A response that is empty, or a function-call expression rather than a
2784        // plain data column, has no bare column to re-add.
2785        if response.is_empty() || response.contains('(') {
2786            return Ok(Vec::new());
2787        }
2788        // Already prediction-required (e.g. transformation-normal re-adds it):
2789        // nothing extra to fold in.
2790        if self.prediction_required_columns()?.contains(response) {
2791            return Ok(Vec::new());
2792        }
2793        Ok(vec![response.to_string()])
2794    }
2795
2796    /// Add the columns referenced by an auxiliary (noise / logslope) formula,
2797    /// which may be supplied as a full `lhs ~ rhs` formula or as a bare RHS.
2798    fn add_auxiliary_formula_columns(
2799        &self,
2800        required: &mut std::collections::BTreeSet<String>,
2801        formula_or_rhs: &str,
2802        response: &str,
2803    ) -> Result<(), String> {
2804        let trimmed = formula_or_rhs.trim();
2805        if trimmed.is_empty() || trimmed == "1" {
2806            return Ok(());
2807        }
2808        let formula = if trimmed.contains('~') {
2809            trimmed.to_string()
2810        } else {
2811            format!("{response} ~ {trimmed}")
2812        };
2813        let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2814        parsed_term_column_names(&parsed.terms, required);
2815        Ok(())
2816    }
2817
2818    #[inline]
2819    pub fn predict_model_class(&self) -> PredictModelClass {
2820        match &self.payload().family_state {
2821            FittedFamily::Survival { .. }
2822            | FittedFamily::LatentSurvival { .. }
2823            | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2824            FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2825            FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2826            FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2827                PredictModelClass::GaussianLocationScale
2828            }
2829            FittedFamily::LocationScale { likelihood, .. }
2830                if is_dispersion_location_scale_response(&likelihood.response) =>
2831            {
2832                PredictModelClass::DispersionLocationScale
2833            }
2834            FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2835            FittedFamily::Standard { .. } => PredictModelClass::Standard,
2836        }
2837    }
2838
2839    pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2840        let payload = self.payload();
2841        let (knots, degree) = match (
2842            payload.linkwiggle_knots.as_ref(),
2843            payload.linkwiggle_degree,
2844        ) {
2845            (None, None) => return Ok(None),
2846            (Some(knots), Some(degree)) => (knots.clone(), degree),
2847            _ => {
2848                return Err(FittedModelError::SchemaMismatch {
2849                    reason:
2850                        "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2851                            .to_string(),
2852                })
2853            }
2854        };
2855        let resolved_link = self.resolved_inverse_link()?;
2856        let saved_link_disallows_wiggle = resolved_link
2857            .as_ref()
2858            .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2859            || payload
2860                .link
2861                .as_ref()
2862                .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2863        if saved_link_disallows_wiggle {
2864            return Err(FittedModelError::IncompatibleConfig {
2865                reason: joint_wiggle_unsupported_link_message("link wiggle"),
2866            });
2867        }
2868        let beta = match self.predict_model_class() {
2869            // #1596: the frozen-basis de-aliased standard link-warp is fit in a
2870            // reduced, identifiable coordinate `γ` and the fit_result LinkWiggle
2871            // block stores `γ` (its true free parameters). The full-width
2872            // standard-basis lift `β_w = Z·γ`, the coefficients the predict-time
2873            // I-spline basis multiplies, is persisted in `payload.beta_link_wiggle`
2874            // — prefer it when present. Without it (the dynamic-basis path) the
2875            // block coefficients ARE the standard-basis warp, read directly.
2876            PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2877                payload.beta_link_wiggle.clone().expect("checked is_some")
2878            }
2879            PredictModelClass::Standard => {
2880                let fit = payload.fit_result.as_ref().ok_or_else(|| {
2881                    FittedModelError::MissingField {
2882                        reason:
2883                            "standard link-wiggle model is missing canonical fit_result payload"
2884                                .to_string(),
2885                    }
2886                })?;
2887                if fit.blocks.len() != 2
2888                    || fit.blocks[0].role != BlockRole::Mean
2889                    || fit.blocks[1].role != BlockRole::LinkWiggle
2890                {
2891                    return Err(FittedModelError::SchemaMismatch {
2892                        reason:
2893                            "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2894                                .to_string(),
2895                    });
2896                }
2897                fit.block_by_role(BlockRole::LinkWiggle)
2898                    .ok_or_else(|| FittedModelError::MissingField {
2899                        reason:
2900                            "standard link-wiggle model is missing LinkWiggle coefficient block"
2901                                .to_string(),
2902                    })?
2903                    .beta
2904                    .to_vec()
2905            }
2906            _ => payload
2907                .beta_link_wiggle
2908                .clone()
2909                .ok_or_else(|| FittedModelError::MissingField {
2910                    reason:
2911                        "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2912                            .to_string(),
2913                })?,
2914        };
2915        Ok(Some(SavedLinkWiggleRuntime {
2916            knots,
2917            degree,
2918            beta,
2919        }))
2920    }
2921
2922    pub fn saved_baseline_time_wiggle(
2923        &self,
2924    ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2925        let payload = self.payload();
2926        if payload
2927            .survival_cause_count
2928            .is_some_and(|cause_count| cause_count > 1)
2929            && payload.beta_baseline_timewiggle.is_none()
2930            && payload.beta_baseline_timewiggle_by_cause.is_some()
2931        {
2932            return Err(FittedModelError::SchemaMismatch {
2933                reason:
2934                    "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2935                        .to_string(),
2936            });
2937        }
2938        match (
2939            payload.baseline_timewiggle_knots.as_ref(),
2940            payload.baseline_timewiggle_degree,
2941            payload.baseline_timewiggle_penalty_orders.as_ref(),
2942            payload.baseline_timewiggle_double_penalty,
2943            payload.beta_baseline_timewiggle.as_ref(),
2944        ) {
2945            (None, None, None, None, None) => Ok(None),
2946            (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2947                Ok(Some(SavedBaselineTimeWiggleRuntime {
2948                    knots: knots.clone(),
2949                    degree,
2950                    penalty_orders: penalty_orders.clone(),
2951                    double_penalty,
2952                    beta: beta.clone(),
2953                }))
2954            }
2955            _ => Err(FittedModelError::SchemaMismatch {
2956                reason:
2957                    "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2958                        .to_string(),
2959            }),
2960        }
2961    }
2962
2963    /// Whether this model has a link wiggle component with complete metadata.
2964    #[inline]
2965    pub fn has_link_wiggle(&self) -> bool {
2966        self.saved_link_wiggle()
2967            .map(|runtime| runtime.is_some())
2968            .unwrap_or(false)
2969    }
2970
2971    /// Whether this model has a baseline-time wiggle component with complete metadata.
2972    #[inline]
2973    pub fn has_baseline_time_wiggle(&self) -> bool {
2974        let payload = self.payload();
2975        if payload
2976            .survival_cause_count
2977            .is_some_and(|cause_count| cause_count > 1)
2978        {
2979            return payload.baseline_timewiggle_knots.is_some()
2980                && payload.baseline_timewiggle_degree.is_some()
2981                && payload.baseline_timewiggle_penalty_orders.is_some()
2982                && payload.baseline_timewiggle_double_penalty.is_some()
2983                && payload.beta_baseline_timewiggle_by_cause.is_some();
2984        }
2985        self.saved_baseline_time_wiggle()
2986            .map(|runtime| runtime.is_some())
2987            .unwrap_or(false)
2988    }
2989
2990    /// Whether the default point prediction must integrate the inverse link
2991    /// over the coefficient posterior — reporting the posterior mean
2992    /// `E[g⁻¹(Xβ)]` — rather than plugging in the posterior mode `g⁻¹(Xβ̂)`.
2993    ///
2994    /// SPEC (issue #960): the posterior mean is *always* the default point
2995    /// estimate (never MAP). It is observably distinct from the plug-in exactly
2996    /// when the inverse link is *curved* over the posterior's uncertainty, so
2997    /// `E[g⁻¹(η)] ≠ g⁻¹(E[η])` by Jensen. The curvature-based classification is:
2998    ///   * all log-link families (Poisson / Gamma / Tweedie / NegativeBinomial):
2999    ///     `E[exp η] = exp(η + se²/2) ≠ exp(η)` (log-normal MGF);
3000    ///   * all Binomial links (logit / probit / cloglog / SAS / BetaLogistic /
3001    ///     Mixture / LatentCLogLog): bounded sigmoidal inverse links;
3002    ///   * Beta (logit link): `E[σ(η)] ≠ σ(E[η])`;
3003    ///   * Royston–Parmar (curved survival-probability inverse link).
3004    /// The integral collapses to the plug-in (so the cheaper plug-in path is
3005    /// exact and taken instead) only for the effectively-linear identity-link
3006    /// Gaussian. Any model carrying a link wiggle or baseline-time wiggle is
3007    /// curved regardless of family. This curvature partition mirrors
3008    /// `families::family_runtime::posterior_mean`, the compute path that produces the
3009    /// corrected mean for each of these families.
3010    ///
3011    /// This is the single source of truth shared by the CLI (`gam predict`)
3012    /// and the Python FFI prediction path so the two can never drift on which
3013    /// models receive the posterior-mean correction.
3014    #[inline]
3015    pub fn prediction_uses_posterior_mean(&self) -> bool {
3016        let family = self.likelihood();
3017        let curved_family = match &family.response {
3018            // Identity-link Gaussian: inverse link is linear, so the posterior
3019            // mean equals the plug-in and the cheaper exact path is taken.
3020            ResponseFamily::Gaussian => false,
3021            // Log-link families: E[exp η] = exp(η + se²/2) ≠ exp(η).
3022            ResponseFamily::Poisson
3023            | ResponseFamily::Gamma
3024            | ResponseFamily::Tweedie { .. }
3025            | ResponseFamily::NegativeBinomial { .. } => true,
3026            // Beta (logit link): E[σ(η)] ≠ σ(E[η]).
3027            ResponseFamily::Beta { .. } => true,
3028            // Royston–Parmar: curved survival-probability inverse link.
3029            ResponseFamily::RoystonParmar => true,
3030            // Binomial: every link variant (logit / probit / cloglog / SAS /
3031            // BetaLogistic / Mixture / LatentCLogLog) is a curved sigmoid.
3032            ResponseFamily::Binomial => matches!(
3033                &family.link,
3034                InverseLink::Standard(_)
3035                    | InverseLink::Sas(_)
3036                    | InverseLink::BetaLogistic(_)
3037                    | InverseLink::Mixture(_)
3038                    | InverseLink::LatentCLogLog(_)
3039            ),
3040        };
3041        curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3042    }
3043
3044    pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3045        self.payload().validate_payload_version()?;
3046        if matches!(
3047            self.predict_model_class(),
3048            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3049        ) {
3050            if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3051                runtime.validate_exact_replay_contract().map_err(|err| {
3052                    FittedModelError::PayloadCorrupt {
3053                        reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3054                    }
3055                })?;
3056            }
3057            if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3058                runtime.validate_exact_replay_contract().map_err(|err| {
3059                    FittedModelError::PayloadCorrupt {
3060                        reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3061                    }
3062                })?;
3063            }
3064        }
3065        let runtime = SavedPredictionRuntime {
3066            model_class: self.predict_model_class(),
3067            likelihood: self.likelihood(),
3068            inverse_link: self.resolved_inverse_link()?,
3069            link_wiggle: self.saved_link_wiggle()?,
3070            baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3071            score_warp: self.payload().score_warp_runtime.clone(),
3072            link_deviation: self.payload().link_deviation_runtime.clone(),
3073            latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3074            latent_z_conditional_calibration: self
3075                .payload()
3076                .latent_z_conditional_calibration
3077                .clone(),
3078            influence_absorber_width: self.payload().influence_absorber_width,
3079        };
3080        if matches!(
3081            runtime.model_class,
3082            PredictModelClass::GaussianLocationScale
3083                | PredictModelClass::BinomialLocationScale
3084                | PredictModelClass::DispersionLocationScale
3085        ) {
3086            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3087                FittedModelError::MissingField {
3088                    reason: "location-scale model is missing canonical fit_result payload"
3089                        .to_string(),
3090                }
3091            })?;
3092            validate_location_scale_saved_fit(
3093                fit,
3094                runtime.model_class,
3095                runtime.link_wiggle.as_ref(),
3096            )?;
3097        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3098            && self
3099                .payload()
3100                .survival_likelihood
3101                .as_deref()
3102                .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3103        {
3104            validate_survival_location_scale_saved_fit(
3105                self.payload(),
3106                runtime.link_wiggle.as_ref(),
3107            )?;
3108        } else if matches!(
3109            runtime.model_class,
3110            PredictModelClass::BernoulliMarginalSlope
3111        ) {
3112            let unified =
3113                self.payload()
3114                    .unified
3115                    .as_ref()
3116                    .ok_or_else(|| FittedModelError::MissingField {
3117                        reason: "marginal-slope model is missing unified fit payload; refit"
3118                            .to_string(),
3119                    })?;
3120            validate_marginal_slope_saved_fit(
3121                unified,
3122                runtime.score_warp.as_ref(),
3123                runtime.link_deviation.as_ref(),
3124                "unified",
3125            )?;
3126        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3127            && self
3128                .payload()
3129                .survival_likelihood
3130                .as_deref()
3131                .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3132        {
3133            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3134                FittedModelError::MissingField {
3135                    reason: "survival marginal-slope model is missing canonical fit_result payload"
3136                        .to_string(),
3137                }
3138            })?;
3139            validate_survival_marginal_slope_saved_fit(
3140                fit,
3141                runtime.score_warp.as_ref(),
3142                runtime.link_deviation.as_ref(),
3143                "fit_result",
3144            )?;
3145        }
3146        Ok(runtime)
3147    }
3148
3149    pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3150        let payload = self.payload();
3151        let raw = match &payload.family_state {
3152            FittedFamily::Standard {
3153                likelihood,
3154                sas_state,
3155                ..
3156            } if likelihood.is_binomial_sas() => {
3157                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3158                    reason: "binomial-sas model is missing state in family_state.sas_state"
3159                        .to_string(),
3160                })?
3161            }
3162            FittedFamily::LocationScale {
3163                likelihood,
3164                base_link,
3165            } if likelihood.is_binomial_sas() => match base_link {
3166                Some(InverseLink::Sas(state)) => *state,
3167                _ => {
3168                    return Err(FittedModelError::MissingField {
3169                        reason: "binomial-sas location-scale model is missing SAS base_link state"
3170                            .to_string(),
3171                    });
3172                }
3173            },
3174            _ => return Ok(None),
3175        };
3176        state_from_sasspec(SasLinkSpec {
3177            initial_epsilon: raw.epsilon,
3178            initial_log_delta: raw.log_delta,
3179        })
3180        .map(Some)
3181        .map_err(|e| FittedModelError::PayloadCorrupt {
3182            reason: format!("invalid saved SAS link state: {e}"),
3183        })
3184    }
3185
3186    pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3187        let payload = self.payload();
3188        let raw = match &payload.family_state {
3189            FittedFamily::Standard {
3190                likelihood,
3191                sas_state,
3192                ..
3193            } if likelihood.is_binomial_beta_logistic() => {
3194                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3195                    reason:
3196                        "binomial-beta-logistic model is missing state in family_state.sas_state"
3197                            .to_string(),
3198                })?
3199            }
3200            FittedFamily::LocationScale {
3201                likelihood,
3202                base_link,
3203            } if likelihood.is_binomial_beta_logistic() => match base_link {
3204                Some(InverseLink::BetaLogistic(state)) => *state,
3205                _ => {
3206                    return Err(FittedModelError::MissingField {
3207                        reason:
3208                            "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3209                                .to_string(),
3210                    });
3211                }
3212            },
3213            _ => return Ok(None),
3214        };
3215        state_from_beta_logisticspec(SasLinkSpec {
3216            initial_epsilon: raw.epsilon,
3217            initial_log_delta: raw.log_delta,
3218        })
3219        .map(Some)
3220        .map_err(|e| FittedModelError::PayloadCorrupt {
3221            reason: format!("invalid saved Beta-Logistic link state: {e}"),
3222        })
3223    }
3224
3225    pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3226        let payload = self.payload();
3227        match &payload.family_state {
3228            FittedFamily::Standard {
3229                likelihood,
3230                mixture_state,
3231                ..
3232            } if likelihood.is_binomial_mixture() => mixture_state
3233                .clone()
3234                .ok_or_else(|| FittedModelError::MissingField {
3235                    reason: "binomial-mixture model is missing state in family_state.mixture_state"
3236                        .to_string(),
3237                })
3238                .map(Some),
3239            FittedFamily::LocationScale {
3240                likelihood,
3241                base_link,
3242            } if likelihood.is_binomial_mixture() => match base_link {
3243                Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3244                _ => Err(FittedModelError::MissingField {
3245                    reason:
3246                        "binomial-mixture location-scale model is missing mixture base_link state"
3247                            .to_string(),
3248                }),
3249            },
3250            _ => Ok(None),
3251        }
3252    }
3253
3254    pub fn saved_latent_cloglog_state(
3255        &self,
3256    ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3257        let payload = self.payload();
3258        match &payload.family_state {
3259            FittedFamily::Standard {
3260                likelihood,
3261                latent_cloglog_state,
3262                ..
3263            } if likelihood.is_latent_cloglog() => latent_cloglog_state
3264                .ok_or_else(|| FittedModelError::MissingField {
3265                    reason:
3266                        "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3267                            .to_string(),
3268                })
3269                .map(Some),
3270            _ => Ok(None),
3271        }
3272    }
3273
3274    pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3275        let stateful = if let Some(state) = self.saved_mixture_state()? {
3276            Some(InverseLink::Mixture(state))
3277        } else if let Some(state) = self.saved_latent_cloglog_state()? {
3278            Some(InverseLink::LatentCLogLog(state))
3279        } else if let Some(state) = self.saved_beta_logistic_state()? {
3280            Some(InverseLink::BetaLogistic(state))
3281        } else {
3282            self.saved_sas_state()?.map(InverseLink::Sas)
3283        };
3284        match &self.payload().family_state {
3285            FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3286            FittedFamily::Standard { link, .. } => {
3287                Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3288            }
3289            FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3290            FittedFamily::Survival { .. }
3291            | FittedFamily::LatentSurvival { .. }
3292            | FittedFamily::LatentBinary { .. } => Ok(None),
3293            FittedFamily::TransformationNormal { .. } => Ok(None),
3294        }
3295    }
3296
/// V∞ §5 coverage floor for the measure-jet extrapolation variance: a
/// band level "covers" a query once its kernel mass reaches this fraction
/// of that level's web-averaged support. Fixed by design (magic-by-default,
/// no user-facing dial): 0.05 keeps the ε★ gate's bounded discontinuity at
/// ≤ 5% of the spectrum's total prior ignorance (see the monotonicity
/// theorem in `terms/basis/measure_jet_predict.rs`) while still refusing
/// credit for stray sub-floor kernel mass at levels finer than the first
/// covering scale.
3305    const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3306
3307    /// V∞ §5 producer: per-row measure-jet extrapolation variance on the η
3308    /// scale for a prediction batch (`docs/measure_jet_v_infinity.md`).
3309    ///
3310    /// For every frozen measure-jet term in `resolved_termspec` this prices
3311    /// the off-support ignorance of the fitted multiscale spectrum at each
3312    /// query row: support curve from the frozen nodes/masses/band
3313    /// ([`gam_terms::basis::measure_jet_support_curve`]), fitted per-scale
3314    /// amplitudes λ̂_ℓ read from the fit's `lambdas` through the replayed
3315    /// design's penalty layout, folded through
3316    /// [`gam_terms::basis::measure_jet_extrapolation_variance`] and scaled by
3317    /// the fit's coefficient-covariance scale φ̂ so the result sits on Vp's
3318    /// η-variance scale. Terms not yet frozen (no `frozen_quadrature` or
3319    /// non-`UserProvided` centers) are skipped with a warning. Returns
3320    /// `Ok(None)` when no measure-jet term contributes, so callers leave
3321    /// `PredictUncertaintyOptions::extrapolation_variance` untouched.
3322    ///
3323    /// `data` must be the RAW (unclipped) prediction matrix in prediction
3324    /// column order — clipping to the training ranges would freeze the
3325    /// distance signal at the hull and defeat the honesty contract — and
3326    /// `col_map` the prediction header → column map (the same map handed to
3327    /// the design builder). This is the minimal-plumbing producer seam: the
3328    /// option-building callers (CLI predict, FFI) hold exactly
3329    /// `(model, data, col_map)` at the point where they assemble
3330    /// `PredictUncertaintyOptions`, and the fusion in
3331    /// `predict_gamwith_uncertainty` adds the array AFTER its multiplicative
3332    /// inflations: `Var_total = Var_Vp·inflation + Var_extrap`.
3333    pub fn measure_jet_extrapolation_variance(
3334        &self,
3335        data: ndarray::ArrayView2<'_, f64>,
3336        col_map: &HashMap<String, usize>,
3337    ) -> Result<Option<Array1<f64>>, FittedModelError> {
3338        use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
3339        use gam_terms::smooth::build_term_collection_design;
3340        use gam_terms::smooth::SmoothBasisSpec;
3341        let Some(saved_spec) = self.resolved_termspec.as_ref() else {
3342            return Ok(None);
3343        };
3344        if data.nrows() == 0
3345            || !saved_spec
3346                .smooth_terms
3347                .iter()
3348                .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
3349        {
3350            return Ok(None);
3351        }
3352        let fit = self
3353            .fit_result
3354            .as_ref()
3355            .ok_or_else(|| FittedModelError::MissingField {
3356                reason: "measure-jet extrapolation variance requires the canonical \
3357                    fit_result payload; refit"
3358                    .to_string(),
3359            })?;
3360        let spec = crate::survival::predict::resolve_termspec_for_prediction(
3361            &self.resolved_termspec,
3362            self.training_headers.as_ref(),
3363            col_map,
3364            "resolved_termspec",
3365        )
3366        .map_err(|e| FittedModelError::SchemaMismatch {
3367            reason: format!("measure-jet extrapolation variance: {e}"),
3368        })?;
3369        // Penalty layout replay: the global penalty indices (→ `fit.lambdas`)
3370        // come from the SAME design builder the predict pipeline uses. One
3371        // probe row suffices — for a frozen spec the penalty layout is
3372        // row-count-invariant (centers, masses, band, and identifiability
3373        // transforms all replay verbatim) — keeping this O(centers²) instead
3374        // of duplicating the full O(rows·centers) prediction design build.
3375        let probe = data.slice(ndarray::s![0..1, ..]);
3376        let design = build_term_collection_design(probe, &spec).map_err(|e| {
3377            FittedModelError::SchemaMismatch {
3378                reason: format!(
3379                    "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
3380                ),
3381            }
3382        })?;
3383        let lambdas = &fit.lambdas;
3384        // λ̂ are fitted on Frobenius-normalized penalties. The term loop
3385        // unnormalizes them to physical precisions before pricing; multiplying
3386        // by the coefficient-covariance scale puts Var_extrap on the same
3387        // η-variance scale as Vp.
3388        let phi_scale = fit.coefficient_covariance_scale();
3389        let mut total = Array1::<f64>::zeros(data.nrows());
3390        let mut contributed = false;
3391        for term in &spec.smooth_terms {
3392            let SmoothBasisSpec::MeasureJet {
3393                feature_cols,
3394                spec: mj,
3395                input_scales,
3396            } = &term.basis
3397            else {
3398                continue;
3399            };
3400            let (Some(frozen), CenterStrategy::UserProvided(centers)) =
3401                (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
3402            else {
3403                log::warn!(
3404                    "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
3405                    quadrature); skipping its extrapolation variance",
3406                    term.name
3407                );
3408                continue;
3409            };
3410            let n_levels = frozen.eps_band.len();
3411            // λ̂ per level from the replayed layout: per-scale candidates carry
3412            // `PenaltySource::Other("measure_jet_scale_ℓ")`; fused
3413            // (pinned-order) mode carries one Primary charged once for the
3414            // whole band. The DoublePenaltyNullspace ridge is EXCLUDED — it shrinks
3415            // coefficients, it is not a scale amplitude, and counting it would
3416            // double-charge the spectrum.
3417            let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
3418                lambdas
3419                    .get(global_index)
3420                    .copied()
3421                    .ok_or_else(|| FittedModelError::SchemaMismatch {
3422                        reason: format!(
3423                            "measure-jet term '{}': penalty global index {global_index} out \
3424                            of bounds for {} fitted lambdas",
3425                            term.name,
3426                            lambdas.len()
3427                        ),
3428                    })
3429            };
3430            let mut per_scale: Vec<(usize, f64)> = Vec::new();
3431            let mut fused: Option<f64> = None;
3432            for info in &design.penaltyinfo {
3433                if info.termname.as_deref() != Some(term.name.as_str()) {
3434                    continue;
3435                }
3436                match &info.penalty.source {
3437                    PenaltySource::Other(label) => {
3438                        if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
3439                            let level: usize = level_txt.parse().map_err(|_| {
3440                                FittedModelError::SchemaMismatch {
3441                                    reason: format!(
3442                                        "measure-jet term '{}': unparseable penalty label \
3443                                        '{label}'",
3444                                        term.name
3445                                    ),
3446                                }
3447                            })?;
3448                            per_scale.push((level, read_lambda(info.global_index)?));
3449                        }
3450                    }
3451                    PenaltySource::Primary => {
3452                        fused = Some(read_lambda(info.global_index)?);
3453                    }
3454                    _ => {}
3455                }
3456            }
3457            let mut lambda_phys = Vec::with_capacity(n_levels);
3458            let spectrum = if per_scale.is_empty() {
3459                let Some(lam) = fused else {
3460                    log::warn!(
3461                        "measure-jet term '{}' has no fitted amplitude in the penalty \
3462                        layout; skipping its extrapolation variance",
3463                        term.name
3464                    );
3465                    continue;
3466                };
3467                let Some(c) = frozen.fused_penalty_normalization_scale else {
3468                    log::warn!(
3469                        "measure-jet term '{}' is missing the fused penalty normalization scale; \
3470                        skipping its extrapolation variance",
3471                        term.name
3472                    );
3473                    continue;
3474                };
3475                MeasureJetExtrapolationSpectrum::Fused(lam / c)
3476            } else {
3477                per_scale.sort_by_key(|&(level, _)| level);
3478                let levels_complete = per_scale.len() == n_levels
3479                    && per_scale
3480                        .iter()
3481                        .enumerate()
3482                        .all(|(i, &(level, _))| level == i);
3483                if !levels_complete {
3484                    log::warn!(
3485                        "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
3486                        scales; skipping its extrapolation variance",
3487                        term.name,
3488                        per_scale.len(),
3489                        n_levels
3490                    );
3491                    continue;
3492                }
3493                if frozen.penalty_normalization_scales.len() != n_levels {
3494                    log::warn!(
3495                        "measure-jet term '{}': {} frozen penalty normalization scales for {} \
3496                        band scales; skipping its extrapolation variance",
3497                        term.name,
3498                        frozen.penalty_normalization_scales.len(),
3499                        n_levels
3500                    );
3501                    continue;
3502                }
3503                lambda_phys.extend(
3504                    per_scale
3505                        .iter()
3506                        .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
3507                );
3508                MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
3509            };
3510            // Query rows in the frozen geometry's coordinates: select the
3511            // term's axes and replay the per-axis standardization exactly as
3512            // the build dispatch does (divide by σ_a when input_scales is
3513            // Some; the persisted centers are already post-standardization).
3514            let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
3515            for (j, &col) in feature_cols.iter().enumerate() {
3516                if col >= data.ncols() {
3517                    return Err(FittedModelError::SchemaMismatch {
3518                        reason: format!(
3519                            "measure-jet term '{}': prediction column {col} out of bounds \
3520                            for {} data columns",
3521                            term.name,
3522                            data.ncols()
3523                        ),
3524                    });
3525                }
3526                queries.column_mut(j).assign(&data.column(col));
3527            }
3528            if let Some(scales) = input_scales {
3529                if scales.len() != feature_cols.len() {
3530                    return Err(FittedModelError::SchemaMismatch {
3531                        reason: format!(
3532                            "measure-jet term '{}': {} input scales for {} axes",
3533                            term.name,
3534                            scales.len(),
3535                            feature_cols.len()
3536                        ),
3537                    });
3538                }
3539                for (j, &scale) in scales.iter().enumerate() {
3540                    queries.column_mut(j).mapv_inplace(|v| v / scale);
3541                }
3542            }
3543            let support = gam_terms::basis::measure_jet_support_curve(
3544                queries.view(),
3545                centers.view(),
3546                frozen.masses.view(),
3547                &frozen.eps_band,
3548            )
3549            .map_err(|e| FittedModelError::SchemaMismatch {
3550                reason: format!(
3551                    "measure-jet term '{}': support curve failed: {e}",
3552                    term.name
3553                ),
3554            })?;
3555            for i in 0..data.nrows() {
3556                let v = gam_terms::basis::measure_jet_extrapolation_variance(
3557                    support.row(i),
3558                    &frozen.eps_band,
3559                    &frozen.support_means,
3560                    spectrum,
3561                    Self::MEASURE_JET_COVERAGE_FLOOR,
3562                )
3563                .map_err(|e| FittedModelError::SchemaMismatch {
3564                    reason: format!(
3565                        "measure-jet term '{}': extrapolation variance failed: {e}",
3566                        term.name
3567                    ),
3568                })?;
3569                total[i] += phi_scale * v;
3570            }
3571            contributed = true;
3572        }
3573        Ok(contributed.then_some(total))
3574    }
3575
3576    /// Access the unified fit result, if stored.
3577    pub fn unified(&self) -> Option<&UnifiedFitResult> {
3578        self.payload().unified.as_ref()
3579    }
3580
3581    pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3582        let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3583            reason: format!("failed to read model '{}': {e}", path.display()),
3584        })?;
3585        let model: Self =
3586            serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3587                reason: format!("failed to parse model json: {e}"),
3588            })?;
3589        let model = model.with_synchronized_stateful_link_metadata();
3590        model.validate_for_persistence()?;
3591        model.validate_numeric_finiteness()?;
3592        Ok(model)
3593    }
3594
3595    pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3596        let normalized = self.clone().with_synchronized_stateful_link_metadata();
3597        normalized.validate_for_persistence()?;
3598        normalized.validate_numeric_finiteness()?;
3599        // Write to a sibling temp file, fsync, then rename into place so a
3600        // crash mid-write never corrupts the user's existing saved fit.
3601        // Concurrent writers to the same path each have a distinct temp
3602        // suffix (pid + nanos), so neither stomps the other's in-flight
3603        // bytes; the rename winner is last-rename-wins, which is the
3604        // expected last-write-wins semantics for a single canonical path.
3605        let parent = path.parent().unwrap_or_else(|| Path::new("."));
3606        let file_name = path
3607            .file_name()
3608            .and_then(|s| s.to_str())
3609            .unwrap_or("model.json");
3610        let pid = std::process::id();
3611        let nanos = std::time::SystemTime::now()
3612            .duration_since(std::time::UNIX_EPOCH)
3613            .map(|d| d.as_nanos())
3614            .unwrap_or(0);
3615        let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3616        let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3617            reason: format!("failed to write model '{}': {e}", tmp.display()),
3618        })?;
3619        let mut writer = std::io::BufWriter::new(file);
3620        let ser_result = serde_json::to_writer(&mut writer, &normalized);
3621        if let Err(e) = ser_result {
3622            // Best-effort temp cleanup on serialization failure. flush
3623            // returns io::Result<()>; discarding via `.ok()` is enough.
3624            std::io::Write::flush(&mut writer).ok();
3625            drop(writer);
3626            fs::remove_file(&tmp).ok();
3627            return Err(FittedModelError::PayloadCorrupt {
3628                reason: format!("failed to serialize model: {e}"),
3629            });
3630        }
3631        std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3632            reason: format!("failed to write model '{}': {e}", tmp.display()),
3633        })?;
3634        // Recover the underlying File to fsync its contents before rename.
3635        let inner = writer
3636            .into_inner()
3637            .map_err(|e| FittedModelError::PayloadCorrupt {
3638                reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3639            })?;
3640        inner.sync_all().ok();
3641        drop(inner);
3642        if let Err(e) = fs::rename(&tmp, path) {
3643            fs::remove_file(&tmp).ok();
3644            return Err(FittedModelError::PayloadCorrupt {
3645                reason: format!("failed to publish model '{}': {e}", path.display()),
3646            });
3647        }
3648        // fsync the parent directory so the rename itself is durable
3649        // across a crash; without this, the rename can be lost even though
3650        // file contents reached disk. Best-effort on platforms that don't
3651        // support opening a directory for fsync.
3652        if let Ok(d) = fs::File::open(parent) {
3653            d.sync_all().ok();
3654        }
3655        Ok(())
3656    }
3657
3658    pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3659        self.data_schema
3660            .as_ref()
3661            .ok_or_else(|| FittedModelError::MissingField {
3662                reason: "model is missing data_schema; refit".to_string(),
3663            })
3664    }
3665
3666    /// Restore the exact in-memory spline-scan fit from a scan-bearing
3667    /// payload (#1030/#1034). `Ok(None)` for dense models; the returned
3668    /// `predict` replays the training Gaussian bridge bit-for-bit.
3669    pub fn saved_spline_scan(
3670        &self,
3671    ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3672        let Some(saved) = self.spline_scan.as_ref() else {
3673            return Ok(None);
3674        };
3675        let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3676            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3677        Ok(Some((saved.feature_column.as_str(), fit)))
3678    }
3679
3680    /// Restore the in-memory residual-cascade fit from a cascade-bearing
3681    /// payload (#1032). `Ok(None)` for non-cascade models; the returned fit
3682    /// replays the multilevel Wendland-frame posterior for the d ∈ {2, 3}
3683    /// feature columns at each predict point.
3684    pub fn saved_residual_cascade(
3685        &self,
3686    ) -> Result<
3687        Option<(
3688            &[String],
3689            gam_solve::residual_cascade::ResidualCascadeFit,
3690        )>,
3691        FittedModelError,
3692    > {
3693        let Some(saved) = self.residual_cascade.as_ref() else {
3694            return Ok(None);
3695        };
3696        let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3697            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3698        Ok(Some((saved.feature_columns.as_slice(), fit)))
3699    }
3700
3701    pub fn random_effect_group_columns(&self) -> HashSet<String> {
3702        let Some(training_headers) = self.training_headers.as_ref() else {
3703            return HashSet::new();
3704        };
3705        let mut out = HashSet::<String>::new();
3706        for spec in self.saved_term_specs() {
3707            for term in &spec.random_effect_terms {
3708                if let Some(name) = training_headers.get(term.feature_col) {
3709                    out.insert(name.clone());
3710                }
3711            }
3712        }
3713        out
3714    }
3715
3716    pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3717        // Hard version gate. The struct's ~40 Option<T> fields carry
3718        // `#[serde(default)]`, which is by design forward-compatible: old
3719        // payloads missing a new optional field decode with `None`. BUT:
3720        // when a new CLI release adds a required field for some family_state
3721        // (enforced below), an older model loaded by the newer CLI would have
3722        // `None` in that slot and the family-specific branch below would
3723        // correctly reject it — unless the new field also happens to slot
3724        // under a branch that hasn't been touched. Conversely, a newer model
3725        // loaded by an older CLI silently drops fields the older struct
3726        // doesn't know about. Both directions are silent-drift hazards. We
3727        // close them with an exact-version check anchored to the canonical
3728        // MODEL_PAYLOAD_VERSION constant — every payload must round-trip
3729        // identically between writers and readers running the same schema.
3730        self.validate_payload_version()?;
3731        if let Some(scan) = self.spline_scan.as_ref() {
3732            // Spline-scan representation (#1030/#1034): the smoother state IS
3733            // the fit. It is exclusive with the dense representation, only
3734            // standard Gaussian-identity models can carry it, and the state
3735            // must restore cleanly so predict never sees a corrupt snapshot.
3736            if self.fit_result.is_some() || self.unified.is_some() {
3737                return Err(FittedModelError::SchemaMismatch {
3738                    reason: "spline-scan model must not also carry a dense fit_result/unified \
3739                             payload; the representations are mutually exclusive"
3740                        .to_string(),
3741                });
3742            }
3743            if self.model_kind != ModelKind::Standard
3744                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3745            {
3746                return Err(FittedModelError::SchemaMismatch {
3747                    reason: format!(
3748                        "spline-scan representation requires a standard Gaussian-identity model; \
3749                         got model_kind={:?}, likelihood={:?}",
3750                        self.model_kind,
3751                        self.family_state.likelihood()
3752                    ),
3753                });
3754            }
3755            if scan.feature_column.is_empty() {
3756                return Err(FittedModelError::MissingField {
3757                    reason: "spline-scan model is missing its feature column name; refit"
3758                        .to_string(),
3759                });
3760            }
3761            gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3762                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3763            // A scan model carries NO dense design, so the dense-path
3764            // requirements below (resolved_termspec, fit_result finiteness,
3765            // family-specific blocks) do not apply. Enforce only the metadata
3766            // predict actually consumes — the feature column resolves against
3767            // training_headers / data_schema — then accept.
3768            if self.data_schema.is_none() {
3769                return Err(FittedModelError::MissingField {
3770                    reason: "spline-scan model is missing data_schema; refit".to_string(),
3771                });
3772            }
3773            if self.training_headers.is_none() {
3774                return Err(FittedModelError::MissingField {
3775                    reason: "spline-scan model is missing training_headers; refit".to_string(),
3776                });
3777            }
3778            return Ok(());
3779        } else if let Some(cascade) = self.residual_cascade.as_ref() {
3780            // Residual-cascade representation (#1032): a multilevel
3781            // Wendland-frame model for a scattered d ∈ {2,3} Gaussian smooth.
3782            // Exclusive with the dense representation and with the scan.
3783            if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3784                return Err(FittedModelError::SchemaMismatch {
3785                    reason: "residual-cascade model must not also carry spline_scan / \
3786                             fit_result / unified payloads; the representations are \
3787                             mutually exclusive"
3788                        .to_string(),
3789                });
3790            }
3791            if self.model_kind != ModelKind::Standard
3792                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3793            {
3794                return Err(FittedModelError::SchemaMismatch {
3795                    reason: format!(
3796                        "residual-cascade representation requires a standard Gaussian-identity \
3797                         model; got model_kind={:?}, likelihood={:?}",
3798                        self.model_kind,
3799                        self.family_state.likelihood()
3800                    ),
3801                });
3802            }
3803            if cascade.feature_columns.is_empty()
3804                || !(2..=3).contains(&cascade.feature_columns.len())
3805            {
3806                return Err(FittedModelError::MissingField {
3807                    reason: format!(
3808                        "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3809                        cascade.feature_columns.len()
3810                    ),
3811                });
3812            }
3813            gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3814                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3815            if self.data_schema.is_none() {
3816                return Err(FittedModelError::MissingField {
3817                    reason: "residual-cascade model is missing data_schema; refit".to_string(),
3818                });
3819            }
3820            if self.training_headers.is_none() {
3821                return Err(FittedModelError::MissingField {
3822                    reason: "residual-cascade model is missing training_headers; refit".to_string(),
3823                });
3824            }
3825            return Ok(());
3826        } else if self.fit_result.is_none() {
3827            return Err(FittedModelError::MissingField {
3828                reason: "model is missing canonical fit_result payload; refit".to_string(),
3829            });
3830        }
3831        if self.data_schema.is_none() {
3832            return Err(FittedModelError::MissingField {
3833                reason: "model is missing data_schema; refit".to_string(),
3834            });
3835        }
3836        if self.training_headers.is_none() {
3837            return Err(FittedModelError::MissingField {
3838                reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3839                    .to_string(),
3840            });
3841        }
3842        let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3843            FittedModelError::MissingField {
3844                reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3845                    .to_string(),
3846            }
3847        })?;
3848        validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3849
3850        if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3851            return Err(FittedModelError::MissingField {
3852                reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3853                    .to_string(),
3854            });
3855        }
3856        if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3857            validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3858        }
3859        if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3860            let score = self.transformation_score_calibration.ok_or_else(|| {
3861                FittedModelError::MissingField {
3862                    reason: "transformation-normal model is missing transformation_score_calibration; refit"
3863                        .to_string(),
3864                }
3865            })?;
3866            score.validate("transformation-normal model")?;
3867        }
3868        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3869            if self.formula_logslope.is_none() {
3870                return Err(FittedModelError::MissingField {
3871                    reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3872                });
3873            }
3874            if self.z_column.is_none() {
3875                return Err(FittedModelError::MissingField {
3876                    reason: "marginal-slope model is missing z_column; refit".to_string(),
3877                });
3878            }
3879            let z_normalization =
3880                self.latent_z_normalization
3881                    .ok_or_else(|| FittedModelError::MissingField {
3882                        reason: "marginal-slope model is missing latent_z_normalization; refit"
3883                            .to_string(),
3884                    })?;
3885            z_normalization.validate("marginal-slope model")?;
3886            let latent_measure =
3887                self.latent_measure
3888                    .as_ref()
3889                    .ok_or_else(|| FittedModelError::MissingField {
3890                        reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3891                    })?;
3892            latent_measure
3893                .validate("marginal-slope model latent_measure")
3894                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3895            if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3896                return Err(FittedModelError::MissingField {
3897                    reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3898                });
3899            }
3900            if self.resolved_termspec_logslope.as_ref().is_none() {
3901                return Err(FittedModelError::MissingField {
3902                    reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3903                        .to_string(),
3904                });
3905            }
3906            match self.family_state.frailty() {
3907                Some(FrailtySpec::None)
3908                | Some(FrailtySpec::GaussianShift {
3909                    sigma_fixed: Some(_),
3910                }) => {}
3911                Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3912                    return Err(FittedModelError::IncompatibleConfig {
3913                        reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3914                            .to_string(),
3915                    });
3916                }
3917                Some(FrailtySpec::HazardMultiplier { .. }) => {
3918                    return Err(FittedModelError::IncompatibleConfig {
3919                        reason: "marginal-slope model does not support HazardMultiplier frailty"
3920                            .to_string(),
3921                    });
3922                }
3923                None => {
3924                    return Err(FittedModelError::MissingField {
3925                        reason: "marginal-slope model is missing family_state.frailty; refit"
3926                            .to_string(),
3927                    });
3928                }
3929            }
3930        }
3931
3932        if let FittedFamily::Survival {
3933            survival_likelihood,
3934            frailty,
3935            ..
3936        } = &self.family_state
3937        {
3938            if matches!(
3939                survival_likelihood.as_deref(),
3940                Some("latent") | Some("latent-binary")
3941            ) {
3942                return Err(FittedModelError::SchemaMismatch {
3943                    reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3944                        .to_string(),
3945                });
3946            }
3947            if survival_likelihood.as_deref() == Some("marginal-slope") {
3948                if self.formula_logslope.is_none() {
3949                    return Err(FittedModelError::MissingField {
3950                        reason: "survival marginal-slope model is missing formula_logslope; refit"
3951                            .to_string(),
3952                    });
3953                }
3954                if self.z_column.is_none() {
3955                    return Err(FittedModelError::MissingField {
3956                        reason: "survival marginal-slope model is missing z_column; refit"
3957                            .to_string(),
3958                    });
3959                }
3960                let z_normalization =
3961                    self.latent_z_normalization
3962                        .ok_or_else(|| {
3963                            FittedModelError::MissingField {
3964                        reason:
3965                            "survival marginal-slope model is missing latent_z_normalization; refit"
3966                                .to_string(),
3967                    }
3968                        })?;
3969                z_normalization.validate("survival marginal-slope model")?;
3970                let latent_measure =
3971                    self.latent_measure
3972                        .as_ref()
3973                        .ok_or_else(|| FittedModelError::MissingField {
3974                            reason:
3975                                "survival marginal-slope model is missing latent_measure; refit"
3976                                    .to_string(),
3977                        })?;
3978                latent_measure
3979                    .validate("survival marginal-slope model latent_measure")
3980                    .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3981                if self.logslope_baseline.is_none() {
3982                    return Err(FittedModelError::MissingField {
3983                        reason: "survival marginal-slope model is missing logslope_baseline; refit"
3984                            .to_string(),
3985                    });
3986                }
3987                if self.resolved_termspec_logslope.as_ref().is_none() {
3988                    return Err(FittedModelError::MissingField {
3989                        reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3990                            .to_string(),
3991                    });
3992                }
3993                match frailty {
3994                    FrailtySpec::None
3995                    | FrailtySpec::GaussianShift {
3996                        sigma_fixed: Some(_),
3997                    } => {}
3998                    FrailtySpec::GaussianShift { sigma_fixed: None } => {
3999                        return Err(FittedModelError::IncompatibleConfig {
4000                            reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4001                                .to_string(),
4002                        });
4003                    }
4004                    FrailtySpec::HazardMultiplier { .. } => {
4005                        return Err(FittedModelError::IncompatibleConfig {
4006                            reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4007                                .to_string(),
4008                        });
4009                    }
4010                }
4011            } else if !matches!(frailty, FrailtySpec::None) {
4012                return Err(FittedModelError::IncompatibleConfig {
4013                    reason:
4014                        "non-marginal survival models do not currently persist a frailty modifier"
4015                            .to_string(),
4016                });
4017            }
4018            // Non-latent survival predict reconstructs the baseline-time
4019            // basis via `load_survival_time_basis_config_from_model` and
4020            // anchors that basis at `survival_time_anchor`; both are
4021            // required for the saved model to be loadable. The CLI's
4022            // marginal-slope+time-wiggle save path previously dropped one or
4023            // the other on partial-write, producing models that loaded but
4024            // would panic at the first predict. Enforce both before persisting.
4025            if self.survival_time_basis.is_none() {
4026                return Err(FittedModelError::MissingField {
4027                    reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4028                });
4029            }
4030            if self.survival_time_anchor.is_none() {
4031                return Err(FittedModelError::MissingField {
4032                    reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4033                });
4034            }
4035        }
4036        if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4037            match frailty {
4038                FrailtySpec::HazardMultiplier {
4039                    sigma_fixed: Some(_),
4040                    ..
4041                } => {}
4042                FrailtySpec::HazardMultiplier {
4043                    sigma_fixed: None, ..
4044                } => {
4045                    return Err(FittedModelError::IncompatibleConfig {
4046                        reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4047                            .to_string(),
4048                    });
4049                }
4050                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4051                    return Err(FittedModelError::IncompatibleConfig {
4052                        reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4053                            .to_string(),
4054                    });
4055                }
4056            }
4057            if self.survival_likelihood.as_deref() != Some("latent") {
4058                return Err(FittedModelError::SchemaMismatch {
4059                    reason: "latent survival model must persist survival_likelihood=latent"
4060                        .to_string(),
4061                });
4062            }
4063        }
4064        if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4065            match frailty {
4066                FrailtySpec::HazardMultiplier {
4067                    sigma_fixed: Some(_),
4068                    ..
4069                } => {}
4070                FrailtySpec::HazardMultiplier {
4071                    sigma_fixed: None, ..
4072                } => {
4073                    return Err(FittedModelError::IncompatibleConfig {
4074                        reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4075                            .to_string(),
4076                    });
4077                }
4078                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4079                    return Err(FittedModelError::IncompatibleConfig {
4080                        reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4081                            .to_string(),
4082                    });
4083                }
4084            }
4085            if self.survival_likelihood.as_deref() != Some("latent-binary") {
4086                return Err(FittedModelError::SchemaMismatch {
4087                    reason: "latent binary model must persist survival_likelihood=latent-binary"
4088                        .to_string(),
4089                });
4090            }
4091        }
4092
4093        let family_likelihood = match &self.family_state {
4094            FittedFamily::Standard { likelihood, .. }
4095            | FittedFamily::LocationScale { likelihood, .. }
4096            | FittedFamily::MarginalSlope { likelihood, .. }
4097            | FittedFamily::Survival { likelihood, .. }
4098            | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4099            FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4100        };
4101        let is_standard_or_location_scale = matches!(
4102            self.family_state,
4103            FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4104        );
4105        if is_standard_or_location_scale
4106            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4107        {
4108            self.saved_sas_state()?;
4109        }
4110        if is_standard_or_location_scale
4111            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4112        {
4113            self.saved_beta_logistic_state()?;
4114        }
4115        if is_standard_or_location_scale
4116            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4117        {
4118            self.saved_mixture_state()?;
4119        }
4120        if matches!(self.family_state, FittedFamily::Standard { .. })
4121            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4122        {
4123            self.saved_latent_cloglog_state()?;
4124        }
4125        if matches!(self.family_state, FittedFamily::LocationScale { .. })
4126            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4127        {
4128            return Err(FittedModelError::IncompatibleConfig {
4129                reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4130                    .to_string(),
4131            });
4132        }
4133        if matches!(self.family_state, FittedFamily::Survival { .. })
4134            && self.survival_likelihood.is_none()
4135        {
4136            return Err(FittedModelError::MissingField {
4137                reason: "saved survival model is missing survival_likelihood metadata; refit"
4138                    .to_string(),
4139            });
4140        }
4141        let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4142            || self.linkwiggle_degree.is_some()
4143            || self.beta_link_wiggle.is_some()
4144            || self
4145                .fit_result
4146                .as_ref()
4147                .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4148                .is_some();
4149        let saved_link_wiggle = self.saved_link_wiggle()?;
4150        if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4151            return Err(FittedModelError::SchemaMismatch {
4152                reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4153                    .to_string(),
4154            });
4155        }
4156        let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4157            || self.baseline_timewiggle_degree.is_some()
4158            || self.baseline_timewiggle_penalty_orders.is_some()
4159            || self.baseline_timewiggle_double_penalty.is_some()
4160            || self.beta_baseline_timewiggle.is_some()
4161            || self.beta_baseline_timewiggle_by_cause.is_some();
4162        let is_joint_cause_specific = self
4163            .survival_cause_count
4164            .is_some_and(|cause_count| cause_count > 1);
4165        if has_any_saved_baseline_time_wiggle {
4166            if is_joint_cause_specific {
4167                let complete = self.baseline_timewiggle_knots.is_some()
4168                    && self.baseline_timewiggle_degree.is_some()
4169                    && self.baseline_timewiggle_penalty_orders.is_some()
4170                    && self.baseline_timewiggle_double_penalty.is_some()
4171                    && self.beta_baseline_timewiggle_by_cause.is_some();
4172                if !complete {
4173                    return Err(FittedModelError::SchemaMismatch {
4174                        reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4175                            .to_string(),
4176                    });
4177                }
4178            } else if self.saved_baseline_time_wiggle()?.is_none() {
4179                return Err(FittedModelError::SchemaMismatch {
4180                    reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4181                        .to_string(),
4182                });
4183            }
4184        }
4185        if self
4186            .survival_likelihood
4187            .as_deref()
4188            .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4189        {
4190            validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4191        }
4192
4193        // Validate anchored-deviation replay contracts at LOAD/SAVE time rather
4194        // than waiting for first predict call. Previously these contracts
4195        // (span table dimensions, coefficient matrices, etc.) were only
4196        // asserted inside `saved_prediction_runtime`, which runs on the first
4197        // predict invocation. A corrupted runtime would therefore pass
4198        // `load_from_path` silently and fail later under a different error
4199        // surface. Enforcing the same check here makes the model self-
4200        // diagnostic: `gam fit` catches its own bad output at save, and
4201        // `gam predict` catches bad input at load rather than mid-pipeline.
4202        if let Some(runtime) = self.score_warp_runtime.as_ref() {
4203            runtime.validate_exact_replay_contract().map_err(|err| {
4204                FittedModelError::PayloadCorrupt {
4205                    reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4206                }
4207            })?;
4208        }
4209        if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4210            runtime.validate_exact_replay_contract().map_err(|err| {
4211                FittedModelError::PayloadCorrupt {
4212                    reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4213                }
4214            })?;
4215        }
4216        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4217            validate_marginal_slope_saved_fit(
4218                self.fit_result.as_ref().expect("checked above"),
4219                self.score_warp_runtime.as_ref(),
4220                self.link_deviation_runtime.as_ref(),
4221                "fit_result",
4222            )?;
4223            let unified = self
4224                .unified
4225                .as_ref()
4226                .ok_or_else(|| FittedModelError::MissingField {
4227                    reason: "marginal-slope model is missing unified fit payload; refit"
4228                        .to_string(),
4229                })?;
4230            validate_marginal_slope_saved_fit(
4231                unified,
4232                self.score_warp_runtime.as_ref(),
4233                self.link_deviation_runtime.as_ref(),
4234                "unified",
4235            )?;
4236        }
4237        if self
4238            .survival_likelihood
4239            .as_deref()
4240            .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4241        {
4242            validate_survival_marginal_slope_saved_fit(
4243                self.fit_result.as_ref().expect("checked above"),
4244                self.score_warp_runtime.as_ref(),
4245                self.link_deviation_runtime.as_ref(),
4246                "fit_result",
4247            )?;
4248            if let Some(unified) = self.unified.as_ref() {
4249                validate_survival_marginal_slope_saved_fit(
4250                    unified,
4251                    self.score_warp_runtime.as_ref(),
4252                    self.link_deviation_runtime.as_ref(),
4253                    "unified",
4254                )?;
4255            }
4256        }
4257
4258        // Posterior-mean / uncertainty backends are validated at predict time
4259        // by `prediction_backend_from_model`, which has access to the actual
4260        // requested mode and emits the canonical "nonlinear posterior-mean
4261        // prediction requires either covariance or a saved penalized Hessian"
4262        // error.  Save-time we deliberately do NOT enforce that gate: a fit
4263        // produced for MAP / plug-in scoring can be persisted and replayed
4264        // without ever needing a covariance backend, and gating it here would
4265        // refuse legitimate MAP-only saves whose `UnifiedFitResult` carries
4266        // beta + lambdas without a stabilized Hessian.
4267
4268        Ok(())
4269    }
4270
4271    pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4272        let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4273        if let Some(fit) = self.fit_result.as_ref() {
4274            fit.validate_numeric_finiteness()
4275                .map_err(|e| corrupt(e.to_string()))?;
4276        }
4277
4278        for (name, opt) in [
4279            ("survival_baseline_scale", self.survival_baseline_scale),
4280            ("survival_baseline_shape", self.survival_baseline_shape),
4281            ("survival_baseline_rate", self.survival_baseline_rate),
4282            ("survival_baseline_makeham", self.survival_baseline_makeham),
4283            (
4284                "survival_time_smooth_lambda",
4285                self.survival_time_smooth_lambda,
4286            ),
4287            ("survival_time_anchor", self.survival_time_anchor),
4288            ("survivalridge_lambda", self.survivalridge_lambda),
4289        ] {
4290            if let Some(v) = opt {
4291                ensure_finite_scalar(name, v).map_err(corrupt)?;
4292            }
4293        }
4294
4295        if let Some(v) = self.beta_noise.as_ref() {
4296            validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4297        }
4298        if let Some(v) = self.noise_projection.as_ref() {
4299            validate_all_finite("noise_projection", v.iter().flatten().copied())
4300                .map_err(corrupt)?;
4301            if self.noise_projection_ridge_alpha.is_none() {
4302                return Err(FittedModelError::MissingField {
4303                    reason:
4304                        "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4305                            .to_string(),
4306                });
4307            }
4308        }
4309        if let Some(v) = self.noise_center.as_ref() {
4310            validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4311        }
4312        if let Some(v) = self.noise_scale.as_ref() {
4313            validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4314        }
4315        if let Some(v) = self.noise_projection_ridge_alpha {
4316            ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4317            if v < 0.0 {
4318                return Err(FittedModelError::InvalidInput {
4319                    reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4320                });
4321            }
4322        }
4323        if let Some(v) = self.gaussian_response_scale {
4324            ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4325        }
4326        if let Some(v) = self.beta_link_wiggle.as_ref() {
4327            validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4328        }
4329        if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4330            validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4331        }
4332        if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4333            validate_all_finite(
4334                "beta_baseline_timewiggle_by_cause",
4335                v.iter().flatten().copied(),
4336            )
4337            .map_err(corrupt)?;
4338        }
4339        if let Some(v) = self.latent_z_normalization {
4340            v.validate("latent_z_normalization")?;
4341        }
4342        if let Some(v) = self.latent_measure.as_ref() {
4343            v.validate("latent_measure").map_err(corrupt)?;
4344        }
4345        if let Some(v) = self.survival_beta_time.as_ref() {
4346            validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4347        }
4348        if let Some(v) = self.survival_beta_threshold.as_ref() {
4349            validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4350        }
4351        if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4352            validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4353        }
4354        if let Some(v) = self.survival_noise_projection.as_ref() {
4355            validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4356                .map_err(corrupt)?;
4357            if self.survival_noise_projection_ridge_alpha.is_none() {
4358                return Err(FittedModelError::MissingField {
4359                    reason:
4360                        "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4361                            .to_string(),
4362                });
4363            }
4364        }
4365        if let Some(v) = self.survival_noise_center.as_ref() {
4366            validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4367        }
4368        if let Some(v) = self.survival_noise_projection_ridge_alpha {
4369            ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4370            if v < 0.0 {
4371                return Err(FittedModelError::InvalidInput {
4372                    reason: format!(
4373                        "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4374                    ),
4375                });
4376            }
4377        }
4378        if let Some(v) = self.survival_noise_scale.as_ref() {
4379            validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4380        }
4381        if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4382            validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4383                .map_err(corrupt)?;
4384        }
4385        if let Some(v) = self.sas_param_covariance.as_ref() {
4386            validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4387                .map_err(corrupt)?;
4388        }
4389        Ok(())
4390    }
4391}
4392
4393fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4394    a.rows().into_iter().map(|row| row.to_vec()).collect()
4395}
4396
4397use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4398
4399fn validate_frozen_term_collectionspec(
4400    spec: &TermCollectionSpec,
4401    label: &str,
4402) -> Result<(), FittedModelError> {
4403    spec.validate_frozen(label)
4404        .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4405}
4406
4407impl Deref for FittedModel {
4408    type Target = FittedModelPayload;
4409
4410    fn deref(&self) -> &Self::Target {
4411        self.payload()
4412    }
4413}
4414
4415impl DerefMut for FittedModel {
4416    fn deref_mut(&mut self) -> &mut Self::Target {
4417        self.payload_mut()
4418    }
4419}
4420
4421// ---------------------------------------------------------------------------
4422// Reconstruct library types from saved models
4423// ---------------------------------------------------------------------------
4424
4425pub fn survival_baseline_config_from_model(
4426    model: &FittedModel,
4427) -> Result<SurvivalBaselineConfig, FittedModelError> {
4428    let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4429        FittedModelError::MissingField {
4430            reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4431        }
4432    })?;
4433    parse_survival_baseline_config(
4434        target,
4435        model.survival_baseline_scale,
4436        model.survival_baseline_shape,
4437        model.survival_baseline_rate,
4438        model.survival_baseline_makeham,
4439    )
4440    .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4441}
4442
4443pub fn load_survival_time_basis_config_from_model(
4444    model: &FittedModel,
4445) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4446    match model
4447        .survival_time_basis
4448        .as_deref()
4449        .ok_or_else(|| FittedModelError::MissingField {
4450            reason: "saved survival model missing survival_time_basis".to_string(),
4451        })?
4452        .to_ascii_lowercase()
4453        .as_str()
4454    {
4455        "none" => Ok(SurvivalTimeBasisConfig::None),
4456        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4457        "bspline" => {
4458            let degree =
4459                model
4460                    .survival_time_degree
4461                    .ok_or_else(|| FittedModelError::MissingField {
4462                        reason: "saved survival bspline model missing survival_time_degree"
4463                            .to_string(),
4464                    })?;
4465            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4466                FittedModelError::MissingField {
4467                    reason: "saved survival bspline model missing survival_time_knots".to_string(),
4468                }
4469            })?;
4470            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4471            if degree < 1 || knots.is_empty() {
4472                return Err(FittedModelError::SchemaMismatch {
4473                    reason: "saved survival bspline time basis metadata is invalid".to_string(),
4474                });
4475            }
4476            Ok(SurvivalTimeBasisConfig::BSpline {
4477                degree,
4478                knots: Array1::from_vec(knots),
4479                smooth_lambda,
4480            })
4481        }
4482        "ispline" => {
4483            let degree =
4484                model
4485                    .survival_time_degree
4486                    .ok_or_else(|| FittedModelError::MissingField {
4487                        reason: "saved survival ispline model missing survival_time_degree"
4488                            .to_string(),
4489                    })?;
4490            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4491                FittedModelError::MissingField {
4492                    reason: "saved survival ispline model missing survival_time_knots".to_string(),
4493                }
4494            })?;
4495            let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4496                FittedModelError::MissingField {
4497                    reason: "saved survival ispline model missing survival_time_keep_cols"
4498                        .to_string(),
4499                }
4500            })?;
4501            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4502            if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4503                return Err(FittedModelError::SchemaMismatch {
4504                    reason: "saved survival ispline time basis metadata is invalid".to_string(),
4505                });
4506            }
4507            Ok(SurvivalTimeBasisConfig::ISpline {
4508                degree,
4509                knots: Array1::from_vec(knots),
4510                keep_cols,
4511                smooth_lambda,
4512            })
4513        }
4514        other => Err(FittedModelError::IncompatibleConfig {
4515            reason: format!("unsupported saved survival_time_basis '{other}'"),
4516        }),
4517    }
4518}
4519
4520#[cfg(test)]
4521mod tests {
4522    use super::*;
4523    use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4524    use crate::survival::lognormal_kernel::FrailtySpec;
4525    use gam_solve::pirls::PirlsStatus;
4526    use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4527    use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4528    use gam_data::SchemaColumn;
4529    use ndarray::{Array1, Array2, array};
4530
4531    fn empty_termspec() -> TermCollectionSpec {
4532        TermCollectionSpec {
4533            linear_terms: vec![],
4534            random_effect_terms: vec![],
4535            smooth_terms: vec![],
4536        }
4537    }
4538
4539    /// #1030/#1034: a scan-bearing payload must round-trip through JSON +
4540    /// `validate_for_persistence` and replay the training Gaussian bridge
4541    /// bit-for-bit; structural corruption must fail loudly at validation.
4542    #[test]
4543    fn spline_scan_payload_round_trips_and_validates() {
4544        let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4545        let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4546        let w = vec![1.0_f64; x.len()];
4547        let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4548        let make_payload = || {
4549            crate::inference::model_payload_builders::assemble_spline_scan_payload(
4550                "y ~ s(x)".to_string(),
4551                "x".to_string(),
4552                &fit,
4553                DataSchema {
4554                    columns: vec![
4555                        SchemaColumn {
4556                            name: "y".to_string(),
4557                            kind: ColumnKindTag::Continuous,
4558                            levels: vec![],
4559                        },
4560                        SchemaColumn {
4561                            name: "x".to_string(),
4562                            kind: ColumnKindTag::Continuous,
4563                            levels: vec![],
4564                        },
4565                    ],
4566                },
4567                vec!["x".to_string()],
4568                vec![(0.0, 1.0)],
4569            )
4570        };
4571        // The on-disk form is the FittedModel tagged enum; validation and the
4572        // scan accessor live on FittedModel (Deref only goes Model -> Payload).
4573        let model = FittedModel::from_payload(make_payload());
4574        model
4575            .validate_for_persistence()
4576            .expect("scan model validates");
4577        model
4578            .validate_numeric_finiteness()
4579            .expect("scan model is finite");
4580
4581        let json = serde_json::to_string(&model).expect("serialize model");
4582        let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4583        restored
4584            .validate_for_persistence()
4585            .expect("restored scan model validates");
4586        let (column, replay) = restored
4587            .saved_spline_scan()
4588            .expect("restore scan fit")
4589            .expect("payload carries the scan representation");
4590        assert_eq!(column, "x");
4591        for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4592            let (m0, v0) = fit.predict(xq).expect("predict original");
4593            let (m1, v1) = replay.predict(xq).expect("predict replayed");
4594            assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4595            assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4596        }
4597
4598        // A dense model without the scan channel still requires fit_result.
4599        let mut dense = make_payload();
4600        dense.spline_scan = None;
4601        let err = FittedModel::from_payload(dense)
4602            .validate_for_persistence()
4603            .expect_err("dense payload without fit_result must be rejected");
4604        assert!(err.to_string().contains("fit_result"));
4605
4606        // Structural corruption fails at validation, not inside predict.
4607        let mut corrupt = make_payload();
4608        corrupt
4609            .spline_scan
4610            .as_mut()
4611            .expect("scan channel present")
4612            .state
4613            .knots
4614            .truncate(2);
4615        FittedModel::from_payload(corrupt)
4616            .validate_for_persistence()
4617            .expect_err("corrupt scan state must be rejected");
4618        let mut unnamed = make_payload();
4619        unnamed
4620            .spline_scan
4621            .as_mut()
4622            .expect("scan channel present")
4623            .feature_column
4624            .clear();
4625        FittedModel::from_payload(unnamed)
4626            .validate_for_persistence()
4627            .expect_err("missing feature column must be rejected");
4628    }
4629
4630    fn standard_gaussian_payload() -> FittedModelPayload {
4631        FittedModelPayload::new(
4632            MODEL_PAYLOAD_VERSION,
4633            "y ~ 1".to_string(),
4634            ModelKind::Standard,
4635            FittedFamily::Standard {
4636                likelihood: LikelihoodSpec::gaussian_identity(),
4637                link: Some(StandardLink::Identity),
4638                latent_cloglog_state: None,
4639                mixture_state: None,
4640                sas_state: None,
4641            },
4642            "gaussian".to_string(),
4643        )
4644    }
4645
4646    fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4647        SavedCompiledFlexBlock {
4648            kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4649            breakpoints: vec![-1.0, 1.0],
4650            basis_dim,
4651            span_c0: vec![vec![0.0; basis_dim]],
4652            span_c1: vec![vec![0.0; basis_dim]],
4653            span_c2: vec![vec![0.0; basis_dim]],
4654            span_c3: vec![vec![0.0; basis_dim]],
4655            anchor_correction: None,
4656            anchor_components: Vec::new(),
4657        }
4658    }
4659
4660    fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4661        let beta = Array1::from_vec(
4662            blocks
4663                .iter()
4664                .flat_map(|block| block.beta.iter().copied())
4665                .collect(),
4666        );
4667        let p = beta.len();
4668        UnifiedFitResult {
4669            blocks,
4670            log_lambdas: Array1::zeros(0),
4671            lambdas: Array1::zeros(0),
4672            likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4673            likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4674            log_likelihood_normalization: LogLikelihoodNormalization::Full,
4675            log_likelihood: 0.0,
4676            deviance: 0.0,
4677            reml_score: 0.0,
4678            stable_penalty_term: 0.0,
4679            penalized_objective: 0.0,
4680            used_device: false,
4681            outer_iterations: 0,
4682            outer_cost_evals: 0,
4683            inner_pirls_solves: 0,
4684            outer_converged: true,
4685            outer_gradient_norm: None,
4686            standard_deviation: 1.0,
4687            covariance_conditional: Some(Array2::zeros((p, p))),
4688            covariance_corrected: Some(Array2::zeros((p, p))),
4689            inference: None,
4690            fitted_link: FittedLinkState::Standard(None),
4691            geometry: None,
4692            block_states: vec![],
4693            beta,
4694            pirls_status: PirlsStatus::Converged,
4695            max_abs_eta: 0.0,
4696            constraint_kkt: None,
4697            artifacts: FitArtifacts {
4698                pirls: None,
4699                null_space_logdet: None,
4700                null_space_dim: None,
4701                survival_link_wiggle_knots: None,
4702                survival_link_wiggle_degree: None,
4703                criterion_certificate: None,
4704                rho_posterior_certificate: None,
4705                rho_posterior_escalation: None,
4706                rho_covariance: None,
4707            },
4708            inner_cycles: 0,
4709        }
4710    }
4711
4712    fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4713        let mut payload = FittedModelPayload::new(
4714            version,
4715            "y ~ 1".to_string(),
4716            ModelKind::MarginalSlope,
4717            FittedFamily::MarginalSlope {
4718                likelihood: LikelihoodSpec::binomial_probit(),
4719                base_link: InverseLink::Standard(StandardLink::Probit),
4720                frailty: FrailtySpec::None,
4721            },
4722            "bernoulli-marginal-slope".to_string(),
4723        );
4724        payload.fit_result = Some(fit.clone());
4725        payload.unified = Some(fit);
4726        payload.data_schema = Some(DataSchema {
4727            columns: vec![SchemaColumn {
4728                name: "z".to_string(),
4729                kind: ColumnKindTag::Continuous,
4730                levels: vec![],
4731            }],
4732        });
4733        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4734        payload.resolved_termspec = Some(empty_termspec());
4735        payload.resolved_termspec_logslope = Some(empty_termspec());
4736        payload.formula_logslope = Some("1".to_string());
4737        payload.z_column = Some("z".to_string());
4738        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4739        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4740        payload.marginal_baseline = Some(0.0);
4741        payload.logslope_baseline = Some(0.0);
4742        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4743        payload
4744    }
4745
/// The saved fit reports `used_device = true` while the payload's own flag is
/// forced to `false`; after `FittedModel::from_payload` the payload flag must
/// read `true`.
#[test]
fn from_payload_synchronizes_used_device_from_saved_fit() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: Array1::from_vec(vec![coef]),
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };

    let mut device_fit = saved_fit(vec![
        unit_block(0.25, BlockRole::Mean),
        unit_block(0.5, BlockRole::Scale),
    ]);
    device_fit.used_device = true;

    // Make the payload-level flag disagree with the fit on purpose.
    let mut stale_payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, device_fit);
    stale_payload.used_device = false;

    let loaded = FittedModel::from_payload(stale_payload);

    assert!(loaded.payload().used_device);
}
4770
4771    fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4772        let mut payload = FittedModelPayload::new(
4773            version,
4774            "Surv(entry, exit, event) ~ 1".to_string(),
4775            ModelKind::Survival,
4776            FittedFamily::Survival {
4777                likelihood: LikelihoodSpec::royston_parmar(),
4778                survival_likelihood: Some("marginal-slope".to_string()),
4779                survival_distribution: Some(ResidualDistribution::Gaussian),
4780                frailty: FrailtySpec::None,
4781            },
4782            "survival".to_string(),
4783        );
4784        payload.fit_result = Some(fit.clone());
4785        payload.unified = Some(fit);
4786        payload.survival_likelihood = Some("marginal-slope".to_string());
4787        payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4788        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4789        payload.data_schema = Some(DataSchema {
4790            columns: vec![SchemaColumn {
4791                name: "z".to_string(),
4792                kind: ColumnKindTag::Continuous,
4793                levels: vec![],
4794            }],
4795        });
4796        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4797        payload.resolved_termspec = Some(empty_termspec());
4798        payload.resolved_termspec_logslope = Some(empty_termspec());
4799        payload.formula_logslope = Some("1".to_string());
4800        payload.z_column = Some("z".to_string());
4801        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4802        payload.logslope_baseline = Some(0.0);
4803        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4804        payload
4805    }
4806
4807    fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4808        // A #913 genuine-dispersion location-scale model: Gamma mean family with
4809        // a log-precision `noise_formula` channel. Its likelihood response is
4810        // non-Gaussian and non-Binomial, so the predict-path classifier must
4811        // route it to `DispersionLocationScale`, NOT the binomial threshold-scale
4812        // class (issue #1064).
4813        let mut payload = FittedModelPayload::new(
4814            MODEL_PAYLOAD_VERSION,
4815            "y ~ x".to_string(),
4816            ModelKind::LocationScale,
4817            FittedFamily::LocationScale {
4818                likelihood: LikelihoodSpec::gamma_log(),
4819                base_link: Some(InverseLink::Standard(StandardLink::Log)),
4820            },
4821            "gamma-location-scale".to_string(),
4822        );
4823        payload.data_schema = Some(DataSchema {
4824            columns: vec![
4825                SchemaColumn {
4826                    name: "y".to_string(),
4827                    kind: ColumnKindTag::Continuous,
4828                    levels: vec![],
4829                },
4830                SchemaColumn {
4831                    name: "x".to_string(),
4832                    kind: ColumnKindTag::Continuous,
4833                    levels: vec![],
4834                },
4835            ],
4836        });
4837        payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4838        payload.resolved_termspec = Some(empty_termspec());
4839        payload.resolved_termspec_noise = Some(empty_termspec());
4840        payload.formula_noise = Some("x".to_string());
4841        payload.beta_noise = Some(vec![0.0]);
4842        payload.link = Some(InverseLink::Standard(StandardLink::Log));
4843        payload
4844    }
4845
/// Regression test for #1064.
///
/// A dispersion location-scale (#913) payload loaded through `from_payload`
/// must report `DispersionLocationScale` from `predict_model_class` and must
/// never land in the binomial threshold-scale class. Prior to the fix, the
/// non-Gaussian `else` arm sent every dispersion model to
/// `BinomialLocationScale`, so predictions used the wrong family/link.
#[test]
fn dispersion_location_scale_payload_is_not_classified_binomial() {
    let gamma_model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
    assert_eq!(
        gamma_model.predict_model_class(),
        PredictModelClass::DispersionLocationScale,
        "Gamma dispersion location-scale must route through the dispersion \
         predictor, not the binomial threshold-scale class",
    );
    let routed_to_binomial = matches!(
        gamma_model.predict_model_class(),
        PredictModelClass::BinomialLocationScale
    );
    assert!(
        !routed_to_binomial,
        "dispersion location-scale must never be classified as binomial",
    );

    // The four #913 dispersion mean families must all classify identically.
    let log_link = || InverseLink::Standard(StandardLink::Log);
    let dispersion_specs = [
        LikelihoodSpec::gamma_log(),
        LikelihoodSpec::new(
            ResponseFamily::NegativeBinomial {
                theta: 1.0,
                theta_fixed: false,
            },
            log_link(),
        ),
        LikelihoodSpec::new(
            ResponseFamily::Beta { phi: 1.0 },
            InverseLink::Standard(StandardLink::Logit),
        ),
        LikelihoodSpec::new(ResponseFamily::Tweedie { p: 1.5 }, log_link()),
    ];
    for spec in dispersion_specs {
        // Swap only the family state; everything else stays the Gamma fixture.
        let mut variant = gamma_dispersion_location_scale_payload();
        let variant_link = spec.link.clone();
        variant.family_state = FittedFamily::LocationScale {
            likelihood: spec.clone(),
            base_link: Some(variant_link),
        };
        let variant_model = FittedModel::from_payload(variant);
        assert_eq!(
            variant_model.predict_model_class(),
            PredictModelClass::DispersionLocationScale,
            "dispersion family {:?} mis-classified",
            spec.response,
        );
    }
}
4902
/// Column `"g"` is clipped to its recorded `(0.0, 7.0)` metadata when it is an
/// ordinary continuous feature, but once the same column backs a
/// random-effect term, `axis_clip_to_training_ranges` must return `None` so
/// the values pass through untouched.
#[test]
fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
    let out_of_range = array![[100.0], [-100.0]];
    let column_index = HashMap::from([("g".to_string(), 0usize)]);

    // Baseline: `g` is a plain continuous feature, so both rows get clipped.
    let mut baseline_payload = standard_gaussian_payload();
    let g_schema = SchemaColumn {
        name: "g".to_string(),
        kind: ColumnKindTag::Continuous,
        levels: vec![],
    };
    baseline_payload.data_schema = Some(DataSchema {
        columns: vec![g_schema],
    });
    baseline_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
    baseline_payload.resolved_termspec = Some(empty_termspec());

    let baseline_model = FittedModel::from_payload(baseline_payload.clone());
    let clipped = baseline_model
        .axis_clip_to_training_ranges(out_of_range.view(), &column_index)
        .expect("ordinary continuous axis should clip outside the training range");
    assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);

    // Same payload, but `g` now also backs a random-effect grouping term.
    let grouping_term = gam_terms::smooth::RandomEffectTermSpec {
        name: "g".to_string(),
        feature_col: 0,
        drop_first_level: false,
        penalized: true,
        frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
    };
    let mut grouped_spec = empty_termspec();
    grouped_spec.random_effect_terms.push(grouping_term);

    let mut grouped_payload = baseline_payload;
    grouped_payload.resolved_termspec = Some(grouped_spec);
    let grouped_model = FittedModel::from_payload(grouped_payload);

    assert_eq!(
        grouped_model.random_effect_group_columns(),
        HashSet::from(["g".to_string()])
    );

    let grouped_clip = grouped_model.axis_clip_to_training_ranges(out_of_range.view(), &column_index);
    assert_eq!(
        grouped_clip,
        None,
        "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
    );
}
4949
/// A marginal-slope payload whose fit holds three single-coefficient blocks
/// (Mean, Scale, Mean) is given `anchored_runtime(2)` as its score-warp
/// runtime; `validate_for_persistence` must fail with a
/// "score-warp coefficient mismatch" error.
#[test]
fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let three_block_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Mean),
        unit_block(0.2, BlockRole::Scale),
        unit_block(0.3, BlockRole::Mean),
    ]);

    let mut mismatched = marginal_slope_payload(MODEL_PAYLOAD_VERSION, three_block_fit);
    mismatched.score_warp_runtime = Some(anchored_runtime(2));

    let failure = FittedModel::from_payload(mismatched)
        .validate_for_persistence()
        .expect_err("marginal-slope score-warp basis mismatch should fail validation");
    let rendered = failure.to_string();
    assert!(rendered.contains("score-warp coefficient mismatch"));
}
4980
/// A survival marginal-slope payload whose fit holds four single-coefficient
/// blocks (Time, Mean, Scale, LinkWiggle) is given `anchored_runtime(2)` as
/// its link-deviation runtime; `saved_prediction_runtime` must fail with a
/// "link-deviation coefficient mismatch" error.
#[test]
fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let wiggle_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Time),
        unit_block(0.2, BlockRole::Mean),
        unit_block(0.3, BlockRole::Scale),
        unit_block(0.4, BlockRole::LinkWiggle),
    ]);

    let mut mismatched = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, wiggle_fit);
    mismatched.link_deviation_runtime = Some(anchored_runtime(2));

    let failure = FittedModel::from_payload(mismatched)
        .saved_prediction_runtime()
        .expect_err("survival marginal-slope link basis mismatch should fail runtime validation");
    let rendered = failure.to_string();
    assert!(rendered.contains("link-deviation coefficient mismatch"));
}
5022
/// `apply_survival_time_basis` must copy every field of a
/// `SavedSurvivalTimeBasis` snapshot onto the matching `survival_time_*`
/// payload field.
#[test]
fn apply_survival_time_basis_writes_all_required_fields() {
    use crate::survival::construction::SavedSurvivalTimeBasis;

    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let survival_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Time),
        unit_block(0.2, BlockRole::Mean),
        unit_block(0.3, BlockRole::Scale),
    ]);
    let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, survival_fit);

    // Every persisted survival_time_* field has to be written from the
    // snapshot. Dropping one of them is precisely what the marginal-slope save
    // regression looked like; funnelling all writes through
    // `apply_survival_time_basis` is the structural guard against a repeat.
    let expected_knots = vec![0.0, 1.0, 2.0];
    let expected_keep_cols = vec![0, 2];
    let snapshot = SavedSurvivalTimeBasis {
        basisname: "royston-parmar".to_string(),
        degree: Some(3),
        knots: Some(expected_knots.clone()),
        keep_cols: Some(expected_keep_cols.clone()),
        smooth_lambda: Some(0.5),
        anchor: 0.25,
    };
    payload.apply_survival_time_basis(&snapshot);

    assert_eq!(
        payload.survival_time_basis.as_deref(),
        Some("royston-parmar")
    );
    assert_eq!(payload.survival_time_degree, Some(3));
    assert_eq!(payload.survival_time_knots, Some(expected_knots));
    assert_eq!(payload.survival_time_keep_cols, Some(expected_keep_cols));
    assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
    assert_eq!(payload.survival_time_anchor, Some(0.25));
}
5073
/// A survival payload that names a time basis but carries no anchor must be
/// rejected by `validate_for_persistence` with a
/// "missing survival_time_anchor" error.
#[test]
fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let survival_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Time),
        unit_block(0.2, BlockRole::Mean),
        unit_block(0.3, BlockRole::Scale),
    ]);

    // Setting only the basis name gets past the time-basis presence check
    // while leaving the anchor unset. That is the partial-write shape the
    // CLI's marginal-slope + time-wiggle save path produced before the
    // structural refactor: main.rs used to set basis/degree/knots/keep_cols/
    // smooth_lambda and omit the anchor.
    let mut anchorless = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, survival_fit);
    anchorless.survival_time_basis = Some("ispline".to_string());

    let failure = FittedModel::from_payload(anchorless)
        .validate_for_persistence()
        .expect_err("survival model without time-anchor metadata should fail validation");
    let rendered = failure.to_string();
    assert!(rendered.contains("missing survival_time_anchor"));
}
5109
/// A survival payload taken straight from the fixture, with no
/// `survival_time_basis` assigned by this test, must be rejected by
/// `validate_for_persistence` with a "missing survival_time_basis" error.
#[test]
fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let survival_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Time),
        unit_block(0.2, BlockRole::Mean),
        unit_block(0.3, BlockRole::Scale),
    ]);
    let basisless = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, survival_fit);

    let failure = FittedModel::from_payload(basisless)
        .validate_for_persistence()
        .expect_err("survival model without time-basis metadata should fail validation");
    let rendered = failure.to_string();
    assert!(rendered.contains("missing survival_time_basis"));
}
5139
/// A payload stamped with `MODEL_PAYLOAD_VERSION - 1` must be rejected by
/// `saved_prediction_runtime` with a "payload schema mismatch" error.
#[test]
fn saved_prediction_runtime_rejects_stale_payload_version() {
    // Builds a single-coefficient block with no smoothing parameters.
    let unit_block = |coef, role| FittedBlock {
        beta: array![coef],
        role,
        edf: 1.0,
        lambdas: Array1::zeros(0),
    };
    let two_block_fit = saved_fit(vec![
        unit_block(0.1, BlockRole::Mean),
        unit_block(0.2, BlockRole::Scale),
    ]);

    // One version behind the current schema.
    let stale_version = MODEL_PAYLOAD_VERSION - 1;
    let stale_payload = marginal_slope_payload(stale_version, two_block_fit);

    let failure = FittedModel::from_payload(stale_payload)
        .saved_prediction_runtime()
        .expect_err("stale payload version should fail before runtime assembly");
    let rendered = failure.to_string();
    assert!(rendered.contains("payload schema mismatch"));
}
5163}