Skip to main content

gam_models/inference/
model.rs

1use gam_terms::basis::BasisOptions;
2use gam_solve::estimate::{BlockRole, FittedLinkState, UnifiedFitResult};
3use crate::bms::{
4    LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
5};
6use crate::survival::construction::{
7    SurvivalBaselineConfig, SurvivalTimeBasisConfig, parse_survival_baseline_config,
8};
9use crate::survival::location_scale::ResidualDistribution;
10use crate::survival::lognormal_kernel::FrailtySpec;
11use crate::wiggle::{
12    monotone_wiggle_basis_with_derivative_order, validate_monotone_wiggle_beta_nonnegative,
13};
14use gam_terms::inference::formula_dsl::{
15    inverse_link_supports_joint_wiggle, joint_wiggle_unsupported_link_message, parse_formula,
16    parse_surv_interval_response, parse_surv_response, parsed_term_column_names,
17};
18use gam_solve::mixture_link::{state_from_beta_logisticspec, state_from_sasspec};
19use gam_terms::smooth::{AdaptiveRegularizationDiagnostics, TermCollectionSpec};
20use gam_problem::types::{
21    InverseLink, LatentCLogLogState, LikelihoodSpec, MixtureLinkState, ResponseFamily, SasLinkSpec,
22    SasLinkState, StandardLink,
23};
24use gam_runtime::span::span_index_for_breakpoints;
25// The data-schema value types live in the `gam-data` foundation crate; they
26// were previously authored here and are still named `gam::inference::model::{
27// ColumnKindTag, DataSchema, SchemaColumn}` by a broad set of integration tests
28// and by saved-payload consumers. Re-export them so that public path stays
29// valid rather than forcing every caller onto the relocated crate path.
30pub use gam_data::{ColumnKindTag, DataSchema, SchemaColumn};
31use ndarray::{Array1, Array2};
32use serde::{Deserialize, Serialize};
33use serde_json::Value as JsonValue;
34use std::collections::{BTreeMap, HashMap, HashSet};
35use std::fs;
36use std::ops::{Deref, DerefMut};
37use std::path::Path;
38
/// Canonical saved-model payload schema version.
///
/// Every `FittedModelPayload` written by any binary (CLI `gam`, gam-pyffi,
/// downstream library users) must set this as its `version` field, and every
/// load path asserts equality via `validate_for_persistence`. Bump this when:
///   - A required field is added to `FittedModelPayload` and the set of
///     Option<T> fields that must be `Some(...)` for a given `family_state`
///     changes (otherwise the `#[serde(default)]` decode would silently fill
///     the new field with `None` when loading an older model and the CLI
///     predict path would run with stale metadata).
///   - The on-wire shape of any `serde`-tagged enum variant changes such that
///     older payloads no longer round-trip losslessly.
///   - The semantics of an existing field change (e.g. sign convention,
///     coordinate frame) in a way that predict output would silently diverge
///     between old and new readers.
///
/// Do NOT bump for purely additive `Option<T>` fields that the save-time
/// invariant (`validate_for_persistence`) does not yet require. Those are
/// forward-compatible.
///
/// Note: `FittedModelPayload::new` does not fill this in automatically — the
/// version is an explicit constructor argument, so writers must pass this
/// constant themselves.
pub const MODEL_PAYLOAD_VERSION: u32 = 7;

/// Schema-free saved-model metadata keyed by stable group id.
///
/// The values are JSON rather than a typed enum because group provenance is
/// supplied by caller-owned catalogs. `FittedModelPayload::group_metadata`
/// wraps this in `Option` with `#[serde(default)]`, so model files written
/// before the field existed deserialize as `None`.
///
/// A `BTreeMap` (rather than a `HashMap`) iterates keys in sorted order, so the
/// serialized key order is deterministic across saves.
pub type GroupMetadata = BTreeMap<String, JsonValue>;
67
/// Saved exact spline-scan fit (#1030/#1034): the predict-time feature column
/// plus the lossless smoother state the Gaussian-bridge `predict` replays.
///
/// Stored in `FittedModelPayload::spline_scan`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedSplineScan {
    /// Training column name feeding the single 1-D smooth at predict time.
    pub feature_column: String,
    /// Serialized smoother state; its layout is owned by
    /// `gam_solve::spline_scan::SplineScanState`.
    pub state: gam_solve::spline_scan::SplineScanState,
}

/// Saved multiresolution residual-cascade fit (#1032): the predict-time feature
/// columns (d ∈ {2, 3}) plus the serializable cascade state that `from_state`
/// rebuilds a predict-capable `ResidualCascadeFit` from. The cascade is a
/// DIFFERENT posterior from the dense Duchon/Matérn term — never a silent swap.
///
/// Stored in `FittedModelPayload::residual_cascade`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedResidualCascade {
    /// Training column names for the d ∈ {2, 3} scattered-smooth coordinates.
    pub feature_columns: Vec<String>,
    /// Serialized cascade state; its layout is owned by
    /// `gam_solve::residual_cascade::ResidualCascadeState`.
    pub state: gam_solve::residual_cascade::ResidualCascadeState,
}
87
/// Typed error surface for the saved-model code in this module
/// (`inference/model.rs`).
///
/// Every variant carries a free-form `reason: String` payload; `Display`
/// emits exactly that payload, so converting a `FittedModelError` into
/// `String` (via the `From` impl below) is byte-equivalent to the pre-
/// refactor `Err(format!(...))` / `Err("...".to_string())` strings that
/// the same call sites produced. This lets external callers keep using
/// `?` against `Result<_, String>` without source changes — the typed
/// enum is purely an in-module discipline gain.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FittedModelError {
    /// Saved payload structure / shape / version disagrees with what the
    /// current binary expects (e.g. covariance shape, block ordering,
    /// schema version, C2 continuity, out-of-range span/basis indices).
    SchemaMismatch { reason: String },
    /// Saved payload bytes / numeric content are corrupt or unreadable
    /// (non-finite scalars, invalid JSON, IO failure, malformed stateful
    /// link state).
    PayloadCorrupt { reason: String },
    /// A required field that the current code path needs is absent from
    /// the payload (typically `..; refit` errors).
    MissingField { reason: String },
    /// A combination of saved-model options is not supported by the
    /// current binary (unsupported deployment-extension kind, unsupported
    /// kernel marker, unsupported survival_time_basis variant, etc.).
    IncompatibleConfig { reason: String },
    /// An input value rejected by a save-time sanity gate (e.g. negative
    /// ridge alpha).
    InvalidInput { reason: String },
}

// NOTE(review): `impl_reason_error_boilerplate!` is defined outside this file.
// Presumably it generates the `Display` (printing `reason` verbatim),
// `std::error::Error`, and `From<FittedModelError> for String` impls that the
// enum docs above rely on — confirm against the macro definition. Every
// variant must be listed here.
impl_reason_error_boilerplate! {
    FittedModelError {
        SchemaMismatch,
        PayloadCorrupt,
        MissingField,
        IncompatibleConfig,
        InvalidInput,
    }
}
128
129// Boundary conversions so external `Result<_, EstimationError>` /
130// `Result<_, SurvivalPredictError>` call sites can propagate with `?`.
131// Survival prediction keeps the model-layer source so the chain identifies
132// the payload/schema failure that triggered the prediction error.
133impl From<FittedModelError> for gam_solve::model_types::EstimationError {
134    fn from(err: FittedModelError) -> Self {
135        gam_solve::model_types::EstimationError::InvalidInput(err.to_string())
136    }
137}
138
139impl From<FittedModelError> for crate::survival::predict::SurvivalPredictError {
140    fn from(err: FittedModelError) -> Self {
141        crate::survival::predict::SurvivalPredictError::ModelPayload {
142            context: "saved-model survival prediction payload",
143            source: err,
144        }
145    }
146}
147
148#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
149pub struct SavedLatentZNormalization {
150    pub mean: f64,
151    pub sd: f64,
152}
153
154impl SavedLatentZNormalization {
155    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
156        if !self.mean.is_finite() {
157            return Err(FittedModelError::PayloadCorrupt {
158                reason: format!("{context} latent z mean must be finite"),
159            });
160        }
161        if !(self.sd.is_finite() && self.sd > 1e-12) {
162            return Err(FittedModelError::PayloadCorrupt {
163                reason: format!(
164                    "{context} latent z sd must be finite and > 1e-12; got {}",
165                    self.sd
166                ),
167            });
168        }
169        Ok::<(), _>(())
170    }
171
172    pub fn apply(&self, z: &Array1<f64>, context: &str) -> Result<Array1<f64>, FittedModelError> {
173        self.validate(context)?;
174        if z.iter().any(|value| !value.is_finite()) {
175            return Err(FittedModelError::PayloadCorrupt {
176                reason: format!("{context} requires finite z values"),
177            });
178        }
179        Ok(z.mapv(|zi| (zi - self.mean) / self.sd))
180    }
181}
182
183pub const TRANSFORMATION_SCORE_PIT_CLIP_EPS: f64 = 1.0e-12;
184
185#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
186#[serde(rename_all = "kebab-case")]
187#[derive(Default)]
188pub enum TransformationScoreKind {
189    #[default]
190    FiniteSupportPit,
191}
192
193#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
194pub struct TransformationScoreCalibration {
195    #[serde(default)]
196    pub score_kind: TransformationScoreKind,
197    #[serde(default = "default_transformation_score_pit_clip_eps")]
198    pub clip_eps: f64,
199}
200
201const fn default_transformation_score_pit_clip_eps() -> f64 {
202    TRANSFORMATION_SCORE_PIT_CLIP_EPS
203}
204
205impl TransformationScoreCalibration {
206    pub fn finite_support_pit() -> Self {
207        Self {
208            score_kind: TransformationScoreKind::FiniteSupportPit,
209            clip_eps: TRANSFORMATION_SCORE_PIT_CLIP_EPS,
210        }
211    }
212
213    pub fn validate(&self, context: &str) -> Result<(), FittedModelError> {
214        if self.score_kind != TransformationScoreKind::FiniteSupportPit {
215            return Err(FittedModelError::IncompatibleConfig {
216                reason: format!("{context} supports only finite-support CTN PIT score semantics"),
217            });
218        }
219        if !(self.clip_eps.is_finite() && self.clip_eps > 0.0 && self.clip_eps < 0.5) {
220            return Err(FittedModelError::IncompatibleConfig {
221                reason: format!(
222                    "{context} requires PIT clip_eps in (0, 0.5), got {}",
223                    self.clip_eps
224                ),
225            });
226        }
227        Ok::<(), _>(())
228    }
229}
230
/// Complete saved-model payload: formula, family, fitted state, and the
/// auxiliary metadata that prediction needs to replay a fit.
///
/// The five leading fields (`version`, `formula`, `model_kind`,
/// `family_state`, `family`) carry no serde default, so they are required in
/// every payload. Every later field is `#[serde(default)]` except `link`, which
/// is an `Option` (serde reads a missing `Option` field as `None`). As a
/// result, payloads written before a field existed still deserialize, with the
/// missing field read as `None`, an empty `Vec`, or `false`. Which optional
/// fields must be populated for a given `family_state` is enforced at save
/// time by `validate_for_persistence` (not part of this view).
#[derive(Clone, Serialize, Deserialize)]
pub struct FittedModelPayload {
    /// Payload schema version; writers set this to `MODEL_PAYLOAD_VERSION`.
    pub version: u32,
    pub formula: String,
    pub model_kind: ModelKind,
    pub family_state: FittedFamily,
    pub family: String,
    /// Human-readable advisories produced while materializing this model —
    /// e.g. an mgcv-style "k was reduced to the data support" note when a
    /// cubic-regression marginal is capped, or a basis-degradation note. These
    /// are surfaced to CLI users via `print_inference_summary`; persisting them
    /// here lets the Python (gamfit) interface surface the SAME advisories as
    /// warnings / `model.notes` instead of silently dropping them at the FFI
    /// boundary (#1543). `#[serde(default)]` keeps older payloads (which had no
    /// such field) deserializing cleanly as "no notes".
    #[serde(default)]
    pub inference_notes: Vec<String>,
    #[serde(default)]
    pub used_device: bool,
    /// NOTE(review): same type as `unified` below. How the two are split is
    /// decided by code outside this view; the `spline_scan` docs state that a
    /// scan model carries no dense `fit_result`. Confirm which of the two each
    /// predict path reads.
    #[serde(default)]
    pub fit_result: Option<UnifiedFitResult>,
    /// Unified (family-agnostic) representation of the fit result.
    #[serde(default)]
    pub unified: Option<UnifiedFitResult>,
    /// Exact O(n) spline-scan fit representation (#1030/#1034): the
    /// state-space smoothing-spline posterior of a single 1-D Gaussian
    /// smooth. When `Some`, this standard Gaussian model's predictions
    /// replay the Gaussian bridge from this state and the model carries no
    /// dense `fit_result` (the representations are mutually exclusive —
    /// enforced by `validate_for_persistence`). `#[serde(default)]` so older
    /// payloads read as: not a scan model.
    #[serde(default)]
    pub spline_scan: Option<SavedSplineScan>,
    /// O(n log n) multiresolution residual-cascade fit (#1032): the persisted
    /// multilevel Wendland-frame state for a single scattered 2–3D Gaussian
    /// smooth past the dense-kernel cliff. When `Some`, predictions replay the
    /// cascade posterior; mutually exclusive with `spline_scan`/`fit_result`.
    /// `#[serde(default)]` keeps forward-compatibility with older payloads.
    #[serde(default)]
    pub residual_cascade: Option<SavedResidualCascade>,
    #[serde(default)]
    pub data_schema: Option<DataSchema>,
    /// The only field after the required five without `#[serde(default)]`;
    /// being an `Option`, a missing key still deserializes as `None`.
    pub link: Option<InverseLink>,
    #[serde(default)]
    pub mixture_link_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub sas_param_covariance: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub formula_noise: Option<String>,
    #[serde(default)]
    pub formula_logslope: Option<String>,
    #[serde(default)]
    pub formula_logslopes: Option<Vec<String>>,
    #[serde(default)]
    pub offset_column: Option<String>,
    #[serde(default)]
    pub noise_offset_column: Option<String>,
    /// Name of the analytic prior-weights column used at fit time (`weights=`),
    /// persisted so replicate/generative sampling can re-resolve the per-row
    /// weights and draw heteroskedastic Gaussian observation noise
    /// `sigma_i = sigma_hat / sqrt(w_i)` (#2025). `None` for an unweighted fit.
    #[serde(default)]
    pub weight_column: Option<String>,
    #[serde(default)]
    pub beta_noise: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub noise_non_intercept_start: Option<usize>,
    /// Tikhonov ridge alpha used by `solve_scale_projection` when fitting
    /// `noise_projection`.  Persisted so prediction-time replay is identical
    /// to fit-time projection.
    #[serde(default)]
    pub noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub gaussian_response_scale: Option<f64>,
    #[serde(default)]
    pub linkwiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub linkwiggle_degree: Option<usize>,
    #[serde(default)]
    pub beta_link_wiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub baseline_timewiggle_degree: Option<usize>,
    #[serde(default)]
    pub baseline_timewiggle_penalty_orders: Option<Vec<usize>>,
    #[serde(default)]
    pub baseline_timewiggle_double_penalty: Option<bool>,
    #[serde(default)]
    pub beta_baseline_timewiggle: Option<Vec<f64>>,
    #[serde(default)]
    pub beta_baseline_timewiggle_by_cause: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub z_column: Option<String>,
    #[serde(default)]
    pub z_columns: Option<Vec<String>>,
    #[serde(default)]
    pub latent_z_normalization: Option<SavedLatentZNormalization>,
    #[serde(default)]
    pub latent_score_contract: Option<SavedLatentScoreContract>,
    #[serde(default)]
    pub latent_measure: Option<LatentMeasureKind>,
    /// Optional rank-INT calibration for the latent score (BMS family).
    /// When `Some`, the marginal-slope predictor routes the input `z`
    /// through [`LatentZRankIntCalibration::apply_at_predict`] before the
    /// closed-form standard-normal kernel, matching fit-time semantics.
    /// `#[serde(default)]` so models persisted before this field existed
    /// continue to deserialize cleanly (interpreted as: no calibration).
    #[serde(default)]
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Optional conditional location-scale calibration of the latent score
    /// (#905, BMS family). When `Some`, the marginal-slope predictor replaces
    /// the (normalized) input `z` by `ζ = (z − m(C))/√v(C)` — rebuilding the
    /// conditioning span `a(C)` from the marginal prediction design — before
    /// the closed-form standard-normal kernel, matching fit-time semantics.
    /// Mutually exclusive with `latent_z_rank_int_calibration`. `#[serde(default)]`
    /// so pre-existing models deserialize cleanly (interpreted as: no
    /// conditional calibration).
    #[serde(default)]
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    #[serde(default)]
    pub marginal_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baseline: Option<f64>,
    #[serde(default)]
    pub logslope_baselines: Option<Vec<f64>>,
    #[serde(default)]
    pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
    #[serde(default)]
    pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
    /// Width `p₁` of the survival marginal-slope absorbed Stage-1 influence block
    /// (#461) when present (the dedicated trailing absorber block). Predict drops
    /// its `γ`; this records the column count so the predictor can account for
    /// the extra block and slice `γ` out of the joint covariance.
    #[serde(default)]
    pub influence_absorber_width: Option<usize>,
    // Survival-model metadata: the fields below prefixed `survival_` (plus
    // `survivalspec` / `survivalridge_lambda`).
    #[serde(default)]
    pub survival_entry: Option<String>,
    #[serde(default)]
    pub survival_exit: Option<String>,
    #[serde(default)]
    pub survival_event: Option<String>,
    #[serde(default)]
    pub survivalspec: Option<String>,
    #[serde(default)]
    pub survival_cause_count: Option<usize>,
    #[serde(default)]
    pub survival_endpoint_names: Option<Vec<String>>,
    #[serde(default)]
    pub survival_baseline_target: Option<String>,
    #[serde(default)]
    pub survival_baseline_scale: Option<f64>,
    #[serde(default)]
    pub survival_baseline_shape: Option<f64>,
    #[serde(default)]
    pub survival_baseline_rate: Option<f64>,
    #[serde(default)]
    pub survival_baseline_makeham: Option<f64>,
    #[serde(default)]
    pub survival_time_basis: Option<String>,
    #[serde(default)]
    pub survival_time_degree: Option<usize>,
    #[serde(default)]
    pub survival_time_knots: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_time_keep_cols: Option<Vec<usize>>,
    #[serde(default)]
    pub survival_time_smooth_lambda: Option<f64>,
    #[serde(default)]
    pub survival_time_anchor: Option<f64>,
    #[serde(default)]
    pub survivalridge_lambda: Option<f64>,
    #[serde(default)]
    pub survival_likelihood: Option<String>,
    #[serde(default)]
    pub survival_beta_time: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_beta_threshold: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_beta_log_sigma: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_projection: Option<Vec<Vec<f64>>>,
    #[serde(default)]
    pub survival_noise_center: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_scale: Option<Vec<f64>>,
    #[serde(default)]
    pub survival_noise_non_intercept_start: Option<usize>,
    /// Survival analog of `noise_projection_ridge_alpha`: the Tikhonov ridge
    /// used when fitting the survival log-sigma projection.
    #[serde(default)]
    pub survival_noise_projection_ridge_alpha: Option<f64>,
    #[serde(default)]
    pub survival_distribution: Option<ResidualDistribution>,
    /// Training column names. `append_deployment_extension_columns` takes these
    /// as a parameter to map a term's training column index to a prediction
    /// column by name.
    #[serde(default)]
    pub training_headers: Option<Vec<String>>,
    /// Container type of the table the model was fitted on, as detected by the
    /// Python binding (`"pandas"`, `"polars"`, `"pyarrow"`, `"numpy"`, or an
    /// ambiguous tag such as `"unknown"`). This is presentation provenance, not
    /// math: `gamfit.Model.predict` uses it as the output-container fallback
    /// when the *prediction input* is itself container-ambiguous (a `dict` of
    /// columns or a `list` of record dicts). Persisting it in the model payload
    /// makes the fallback survive `save`/`load` and `dumps`/`loads`, so a
    /// reloaded model predicts into the same container as the in-memory one.
    /// `None` for older payloads (and for fits that never recorded a kind), in
    /// which case the fallback degrades to `"dict"`, matching pre-persistence
    /// behaviour for unknown-kind training tables.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub training_table_kind: Option<String>,
    /// Per-column (min, max) of the training input matrix, parallel to
    /// `training_headers`. At predict time, inputs are axis-clipped to these
    /// ranges so that out-of-distribution points evaluate at the nearest face
    /// of the training bounding box rather than extrapolating polynomial
    /// trends from polyharmonic / spline bases beyond the data envelope. Old
    /// model JSONs that pre-date this field load with `None`, in which case
    /// the predict path falls through unchanged (no clipping).
    #[serde(default)]
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// User-supplied per-group metadata, keyed by stable group identifier.
    ///
    /// This is intentionally schema-free JSON so provenance maps can carry
    /// mixed scalar/list/object values. Missing in older payloads means no
    /// group metadata was persisted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub group_metadata: Option<GroupMetadata>,
    /// Deployment-time no-refit group extensions applied after fitting.
    ///
    /// Each entry records the requested group coordinate, caller metadata, and
    /// prior used to initialize the inserted coefficient. The active
    /// prediction contract lives in `data_schema` + `resolved_termspec`; this
    /// ledger preserves provenance without requiring a refit.
    /// `append_deployment_extension_columns` turns these entries into
    /// indicator design columns at predict time. An empty list is omitted on
    /// save.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub deployment_extensions: Vec<SavedDeploymentExtension>,
    /// Transformation-normal: B-spline knots for the response-direction basis.
    #[serde(default)]
    pub transformation_response_knots: Option<Vec<f64>>,
    /// Transformation-normal: deviation nullspace transform matrix (row-major).
    #[serde(default)]
    pub transformation_response_transform: Option<Vec<Vec<f64>>>,
    /// Transformation-normal: B-spline degree for the response basis.
    #[serde(default)]
    pub transformation_response_degree: Option<usize>,
    /// Transformation-normal: median of the response used for anchoring.
    #[serde(default)]
    pub transformation_response_median: Option<f64>,
    /// Transformation-normal saved score contract. The score is the exact
    /// finite-support PIT:
    /// z = Phi^{-1}((Phi(h) - Phi(h_L)) / (Phi(h_U) - Phi(h_L))).
    #[serde(default)]
    pub transformation_score_calibration: Option<TransformationScoreCalibration>,
    /// Resolved term specification of the main predictor. Required by
    /// `append_deployment_extension_columns` whenever deployment extensions
    /// are present (it looks up random-effect terms here).
    #[serde(default)]
    pub resolved_termspec: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_noise: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslope: Option<TermCollectionSpec>,
    #[serde(default)]
    pub resolved_termspec_logslopes: Option<Vec<TermCollectionSpec>>,
    #[serde(default)]
    pub adaptive_regularization_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    /// Precomputed exact Gaussian-identity jackknife+ statistics (#942).
    ///
    /// Populated *only* for a standard Gaussian-identity model fit with unit
    /// prior weights, where the closed-form Sherman–Morrison leave-one-out
    /// substrate gives a distribution-free finite-sample (≥ level) prediction
    /// interval with no held-out fold. When `Some`, `predict(interval=level)`
    /// auto-routes through it (the MAGIC default); when `None` — any other
    /// family/link, reweighted rows, or an older payload — predict falls back
    /// to the model-based posterior band and labels the provenance honestly.
    /// `#[serde(default)]` so pre-existing models deserialize as: no jackknife+
    /// substrate available.
    #[serde(default)]
    pub gaussian_jackknife_plus:
        Option<crate::inference::full_conformal::GaussianJackknifePlusStats>,
    /// Precomputed substrate for the EXACT Gaussian-identity full-conformal set
    /// (#942 Layer 1 + the frozen-ρ self-diagnostic).
    ///
    /// Populated under the SAME eligibility as `gaussian_jackknife_plus`
    /// (Gaussian-identity, unit prior weights, offset-free, no link wiggle). It
    /// persists the training design + response + frozen penalty `Sλ` so the
    /// distribution-free EXACT prediction set (a union of intervals, valid for
    /// any penalized smooth) can be replayed per test point — one Cholesky each,
    /// zero refits — and surfaces the frozen-ρ certificate flag. `None` for any
    /// ineligible model or an older payload, in which case the exact-set predict
    /// path errors with a clear message and the caller uses jackknife+ or the
    /// posterior band. `#[serde(default)]` so pre-existing models deserialize as
    /// no exact substrate available.
    #[serde(default)]
    pub full_conformal: Option<crate::inference::full_conformal::ExactFullConformalSubstrate>,
}
527
/// One deployment-time (no-refit) group extension recorded in
/// `FittedModelPayload::deployment_extensions`.
///
/// At predict time `append_deployment_extension_columns` turns each entry into
/// an indicator column appended after the base design.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedDeploymentExtension {
    /// Identifier used in error messages about this extension.
    pub name: String,
    /// Extension kind; `append_deployment_extension_columns` only accepts
    /// `"random-effect-level"` and rejects anything else.
    pub kind: String,
    /// Name of the random-effect term (matched against
    /// `resolved_termspec.random_effect_terms[*].name`) whose feature column
    /// carries the new level.
    pub term: String,
    /// JSON form of the level. NOTE(review): not read by the prediction
    /// helper in this file — matching uses `level_bits`.
    pub level: JsonValue,
    /// Raw `f64::to_bits()` pattern of the level code. A prediction row gets
    /// indicator 1.0 only when its feature value is bit-identical (so `0.0` and
    /// `-0.0` would not match).
    pub level_bits: u64,
    /// Coefficient slot of the new level. After sorting, the extensions must
    /// occupy `p, p + 1, …`, where `p` is the base design width (append-only).
    pub coefficient_index: usize,
    /// Presumably the prior mean used to initialize the inserted coefficient
    /// (per the `deployment_extensions` field docs) — confirm against the
    /// writer.
    pub coefficient_mean: f64,
    /// Presumably the prior variance of the inserted coefficient — confirm
    /// against the writer.
    pub coefficient_variance: f64,
    /// Optional caller-supplied metadata; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<JsonValue>,
    /// Optional prior description; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prior: Option<JsonValue>,
}
543
544/// Append deployment-only extension columns to the fitted design coordinate system.
545///
546/// No-refit group extension adds a new coefficient block after the fitted
547/// coefficient vector:
548///
549///   beta_ext = [beta_old, beta_new],    beta_new = mu_new.
550///
551/// For a new random-effect level g, the appended basis is the indicator
552/// e_g(x_i) = 1{x_i == g}.  The old fitted basis X_old is not rebuilt or
553/// reordered, so rows that do not exercise g have
554///
555///   eta_ext = X_old beta_old + 0 * beta_new = eta_old.
556///
557/// Rows at the new level get the exact prior-mean shift e_g beta_new.  This
558/// helper enforces the coordinate identity by requiring extension coefficient
559/// indices to be the consecutive tail columns of the base design.
560pub fn append_deployment_extension_columns(
561    model: &FittedModelPayload,
562    data: ndarray::ArrayView2<'_, f64>,
563    col_map: &HashMap<String, usize>,
564    training_headers: Option<&Vec<String>>,
565    base_design: Array2<f64>,
566) -> Result<Array2<f64>, FittedModelError> {
567    if model.deployment_extensions.is_empty() {
568        return Ok(base_design);
569    }
570    if base_design.nrows() != data.nrows() {
571        return Err(FittedModelError::SchemaMismatch {
572            reason: format!(
573                "deployment extension design row mismatch: base design has {} rows but data has {}",
574                base_design.nrows(),
575                data.nrows()
576            ),
577        });
578    }
579    let spec = model
580        .resolved_termspec
581        .as_ref()
582        .ok_or_else(|| FittedModelError::MissingField {
583            reason: "deployment extension prediction requires saved resolved_termspec; refit"
584                .to_string(),
585        })?;
586    let n = base_design.nrows();
587    let p_old = base_design.ncols();
588    let mut extensions: Vec<&SavedDeploymentExtension> =
589        model.deployment_extensions.iter().collect();
590    extensions.sort_by_key(|extension| extension.coefficient_index);
591    for (tail_idx, extension) in extensions.iter().enumerate() {
592        let expected = p_old + tail_idx;
593        if extension.coefficient_index != expected {
594            return Err(FittedModelError::SchemaMismatch {
595                reason: format!(
596                    "deployment extension '{}' has coefficient index {}, expected append-only index {}",
597                    extension.name, extension.coefficient_index, expected
598                ),
599            });
600        }
601    }
602
603    let mut out = Array2::<f64>::zeros((n, p_old + extensions.len()));
604    out.slice_mut(ndarray::s![.., ..p_old]).assign(&base_design);
605    for (tail_idx, extension) in extensions.into_iter().enumerate() {
606        if extension.kind != "random-effect-level" {
607            return Err(FittedModelError::IncompatibleConfig {
608                reason: format!(
609                    "unsupported deployment extension kind '{}' for '{}'",
610                    extension.kind, extension.name
611                ),
612            });
613        }
614        let term = spec
615            .random_effect_terms
616            .iter()
617            .find(|term| term.name == extension.term)
618            .ok_or_else(|| FittedModelError::MissingField {
619                reason: format!(
620                    "deployment extension '{}' references unknown random-effect term '{}'",
621                    extension.name, extension.term
622                ),
623            })?;
624        let prediction_col = training_headers
625            .and_then(|headers| headers.get(term.feature_col))
626            .and_then(|name| col_map.get(name))
627            .copied()
628            .unwrap_or(term.feature_col);
629        if prediction_col >= data.ncols() {
630            return Err(FittedModelError::SchemaMismatch {
631                reason: format!(
632                    "deployment extension '{}' feature column {} out of bounds for {} prediction columns",
633                    extension.name,
634                    prediction_col,
635                    data.ncols()
636                ),
637            });
638        }
639        let col = p_old + tail_idx;
640        for row in 0..n {
641            if data[[row, prediction_col]].to_bits() == extension.level_bits {
642                out[[row, col]] = 1.0;
643            }
644        }
645    }
646    Ok(out)
647}
648
/// Saved description of the latent-score contract, stored in
/// `FittedModelPayload::latent_score_contract`.
///
/// NOTE(review): no consumer of this struct is visible in this part of the
/// file; the field notes below are inferred from names and should be confirmed
/// against the writer/reader. Unlike most payload structs here, the `Option`
/// fields carry no `#[serde(default)]` attribute (serde still reads a missing
/// `Option` as `None`), while `conditioning_columns` and the two
/// normalization scalars are required keys.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SavedLatentScoreContract {
    /// Free-form tag naming the score semantics.
    pub semantics: String,
    /// Presumably identifies the transform that produced the score, if any.
    pub source_transform_id: Option<String>,
    /// Presumably mirrors `SavedLatentZNormalization::mean` — confirm.
    pub normalization_mean: f64,
    /// Presumably mirrors `SavedLatentZNormalization::sd` — confirm.
    pub normalization_sd: f64,
    /// Optional clip epsilon; presumably a PIT clip as in
    /// `TransformationScoreCalibration::clip_eps` — confirm.
    pub clip_eps: Option<f64>,
    /// Column names the score is conditioned on.
    pub conditioning_columns: Vec<String>,
}
658
659impl FittedModelPayload {
660    pub fn new(
661        version: u32,
662        formula: String,
663        model_kind: ModelKind,
664        family_state: FittedFamily,
665        family: String,
666    ) -> Self {
667        Self {
668            version,
669            formula,
670            model_kind,
671            family_state,
672            family,
673            inference_notes: Vec::new(),
674            used_device: false,
675            fit_result: None,
676            unified: None,
677            spline_scan: None,
678            residual_cascade: None,
679            data_schema: None,
680            link: None,
681            mixture_link_param_covariance: None,
682            sas_param_covariance: None,
683            formula_noise: None,
684            formula_logslope: None,
685            formula_logslopes: None,
686            offset_column: None,
687            noise_offset_column: None,
688            weight_column: None,
689            beta_noise: None,
690            noise_projection: None,
691            noise_center: None,
692            noise_scale: None,
693            noise_non_intercept_start: None,
694            noise_projection_ridge_alpha: None,
695            gaussian_response_scale: None,
696            linkwiggle_knots: None,
697            linkwiggle_degree: None,
698            beta_link_wiggle: None,
699            baseline_timewiggle_knots: None,
700            baseline_timewiggle_degree: None,
701            baseline_timewiggle_penalty_orders: None,
702            baseline_timewiggle_double_penalty: None,
703            beta_baseline_timewiggle: None,
704            beta_baseline_timewiggle_by_cause: None,
705            z_column: None,
706            z_columns: None,
707            latent_z_normalization: None,
708            latent_score_contract: None,
709            latent_measure: None,
710            latent_z_rank_int_calibration: None,
711            latent_z_conditional_calibration: None,
712            marginal_baseline: None,
713            logslope_baseline: None,
714            logslope_baselines: None,
715            score_warp_runtime: None,
716            link_deviation_runtime: None,
717            influence_absorber_width: None,
718            survival_entry: None,
719            survival_exit: None,
720            survival_event: None,
721            survivalspec: None,
722            survival_cause_count: None,
723            survival_endpoint_names: None,
724            survival_baseline_target: None,
725            survival_baseline_scale: None,
726            survival_baseline_shape: None,
727            survival_baseline_rate: None,
728            survival_baseline_makeham: None,
729            survival_time_basis: None,
730            survival_time_degree: None,
731            survival_time_knots: None,
732            survival_time_keep_cols: None,
733            survival_time_smooth_lambda: None,
734            survival_time_anchor: None,
735            survivalridge_lambda: None,
736            survival_likelihood: None,
737            survival_beta_time: None,
738            survival_beta_threshold: None,
739            survival_beta_log_sigma: None,
740            survival_noise_projection: None,
741            survival_noise_center: None,
742            survival_noise_scale: None,
743            survival_noise_non_intercept_start: None,
744            survival_noise_projection_ridge_alpha: None,
745            survival_distribution: None,
746            training_headers: None,
747            training_table_kind: None,
748            training_feature_ranges: None,
749            group_metadata: None,
750            deployment_extensions: Vec::new(),
751            transformation_response_knots: None,
752            transformation_response_transform: None,
753            transformation_response_degree: None,
754            transformation_response_median: None,
755            transformation_score_calibration: None,
756            resolved_termspec: None,
757            resolved_termspec_noise: None,
758            resolved_termspec_logslope: None,
759            resolved_termspec_logslopes: None,
760            adaptive_regularization_diagnostics: None,
761            gaussian_jackknife_plus: None,
762            full_conformal: None,
763        }
764    }
765
766    pub fn set_training_feature_metadata(
767        &mut self,
768        headers: Vec<String>,
769        feature_ranges: Vec<(f64, f64)>,
770    ) {
771        self.training_headers = Some(headers);
772        self.training_feature_ranges = Some(feature_ranges);
773    }
774
775    fn synchronize_empty_feature_contract(&mut self) {
776        if self.fit_result.is_none() {
777            return;
778        }
779        let Some(schema) = self.data_schema.as_ref() else {
780            return;
781        };
782        if !schema.columns.is_empty() {
783            return;
784        }
785        self.training_headers.get_or_insert_with(Vec::new);
786        self.resolved_termspec
787            .get_or_insert_with(|| TermCollectionSpec {
788                linear_terms: Vec::new(),
789                smooth_terms: Vec::new(),
790                random_effect_terms: Vec::new(),
791            });
792    }
793
794    /// Write the persistable time-basis snapshot for a survival model.
795    ///
796    /// This is the only path that should populate the `survival_time_*`
797    /// fields used by the loader. Routing every FFI builder through this
798    /// helper guarantees no builder can silently drop a field — the
799    /// marginal-slope save→load bug was a builder that
800    /// missed `survival_time_basis`.
801    pub fn apply_survival_time_basis(
802        &mut self,
803        snapshot: &crate::survival::construction::SavedSurvivalTimeBasis,
804    ) {
805        self.survival_time_basis = Some(snapshot.basisname.clone());
806        self.survival_time_degree = snapshot.degree;
807        self.survival_time_knots = snapshot.knots.clone();
808        self.survival_time_keep_cols = snapshot.keep_cols.clone();
809        self.survival_time_smooth_lambda = snapshot.smooth_lambda;
810        self.survival_time_anchor = Some(snapshot.anchor);
811    }
812
813    fn validate_payload_version(&self) -> Result<(), FittedModelError> {
814        if self.version != MODEL_PAYLOAD_VERSION {
815            return Err(FittedModelError::SchemaMismatch {
816                reason: format!(
817                    "saved model payload schema mismatch: file has version={}, \
818                 this binary expects MODEL_PAYLOAD_VERSION={}. \
819                 Refit with the current CLI, or rebuild the reader at the same \
820                 version the model was written with.",
821                    self.version, MODEL_PAYLOAD_VERSION
822                ),
823            });
824        }
825        Ok(())
826    }
827}
828
/// Top-level saved-model envelope.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is written, in kebab-case, to a `model_type` field (`"standard"`,
/// `"location-scale"`, `"marginal-slope"`, `"survival"`,
/// `"transformation-normal"`) next to the variant's `payload` field. Every
/// variant wraps the same [`FittedModelPayload`]; the variant only selects
/// the model kind.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub enum FittedModel {
    Standard { payload: FittedModelPayload },
    LocationScale { payload: FittedModelPayload },
    MarginalSlope { payload: FittedModelPayload },
    Survival { payload: FittedModelPayload },
    TransformationNormal { payload: FittedModelPayload },
}
838
/// Model-kind discriminator stored inside [`FittedModelPayload`] (it is the
/// `model_kind` argument of `FittedModelPayload::new`).
///
/// Its variants mirror those of [`FittedModel`]; serde encodes each one as a
/// kebab-case string (e.g. `LocationScale` → `"location-scale"`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum ModelKind {
    Standard,
    LocationScale,
    MarginalSlope,
    Survival,
    TransformationNormal,
}
848
/// Persisted family / link state of a fitted model.
///
/// Serialized with serde's internally tagged representation: the variant name
/// is written, in kebab-case, to a `family_kind` field (e.g.
/// `"latent-survival"`). Fields marked `#[serde(default)]` may be missing from
/// saved JSON and then deserialize as `None`; unmarked fields are required.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "family_kind", rename_all = "kebab-case")]
pub enum FittedFamily {
    /// Standard family: likelihood plus an optional link and optional
    /// per-link parameter states (latent cloglog, mixture, SAS).
    Standard {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        link: Option<StandardLink>,
        #[serde(default)]
        latent_cloglog_state: Option<LatentCLogLogState>,
        #[serde(default)]
        mixture_state: Option<MixtureLinkState>,
        #[serde(default)]
        sas_state: Option<SasLinkState>,
    },
    /// Location-scale family: likelihood plus an optional base inverse link.
    LocationScale {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        base_link: Option<InverseLink>,
    },
    /// Marginal-slope family: the base inverse link and frailty are required.
    MarginalSlope {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        frailty: FrailtySpec,
    },
    /// Survival family: optional survival-likelihood label and residual
    /// distribution; frailty is required.
    Survival {
        likelihood: LikelihoodSpec,
        #[serde(default)]
        survival_likelihood: Option<String>,
        #[serde(default)]
        survival_distribution: Option<ResidualDistribution>,
        frailty: FrailtySpec,
    },
    /// Latent survival family: only the frailty spec is persisted (no
    /// `LikelihoodSpec` is stored for this variant).
    LatentSurvival {
        frailty: FrailtySpec,
    },
    /// Latent binary family: only the frailty spec is persisted (no
    /// `LikelihoodSpec` is stored for this variant).
    LatentBinary {
        frailty: FrailtySpec,
    },
    /// Transformation-normal family: likelihood only.
    TransformationNormal {
        likelihood: LikelihoodSpec,
    },
}
891
/// Prediction path selected for a saved model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictModelClass {
    Standard,
    GaussianLocationScale,
    BinomialLocationScale,
    /// Genuine-dispersion location-scale (#913): NegativeBinomial / Gamma / Beta
    /// / Tweedie mean families fitted with a `noise_formula` overdispersion
    /// channel. Predicted through the GLM mean inverse link rather than the
    /// binomial threshold-scale predictor.
    DispersionLocationScale,
    BernoulliMarginalSlope,
    Survival,
    TransformationNormal,
}

impl PredictModelClass {
    /// Human-readable label for this prediction class, suitable for
    /// diagnostics and error messages.
    #[inline]
    pub const fn name(self) -> &'static str {
        let label = match self {
            PredictModelClass::Standard => "standard",
            PredictModelClass::GaussianLocationScale => "gaussian location-scale",
            PredictModelClass::BinomialLocationScale => "binomial location-scale",
            PredictModelClass::DispersionLocationScale => "dispersion location-scale",
            PredictModelClass::BernoulliMarginalSlope => "bernoulli marginal-slope",
            PredictModelClass::Survival => "survival",
            PredictModelClass::TransformationNormal => "transformation-normal",
        };
        label
    }
}
921
/// In-memory link-wiggle runtime reconstructed from a saved model.
///
/// `knots` and `degree` parameterize the monotone wiggle basis evaluated with
/// `monotone_wiggle_basis_with_derivative_order`; `beta` holds one coefficient
/// per basis column (`constrained_basis` rejects a length mismatch). Not
/// serializable: this type only derives `Clone` and `Debug`.
#[derive(Clone, Debug)]
pub struct SavedLinkWiggleRuntime {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub beta: Vec<f64>,
}

/// In-memory baseline time-wiggle runtime reconstructed from a saved model.
///
/// `validate_global_monotonicity` checks `beta` with
/// `validate_monotone_wiggle_beta_nonnegative`. `penalty_orders` and
/// `double_penalty` presumably record the fit-time penalty configuration —
/// confirm against the builder. Not serializable: this type only derives
/// `Clone` and `Debug`.
#[derive(Clone, Debug)]
pub struct SavedBaselineTimeWiggleRuntime {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub penalty_orders: Vec<usize>,
    pub double_penalty: bool,
    pub beta: Vec<f64>,
}
932    pub degree: usize,
933    pub penalty_orders: Vec<usize>,
934    pub double_penalty: bool,
935    pub beta: Vec<f64>,
936}
937
938// Re-export so saved-model consumers can refer to the anchor-block tag
939// without reaching across module boundaries.
940pub use crate::bms::deviation_runtime::ParametricAnchorBlock;
941
942#[derive(Clone, Debug, Serialize, Deserialize)]
943pub struct SavedCompiledFlexBlock {
944    pub kernel: String,
945    pub breakpoints: Vec<f64>,
946    pub basis_dim: usize,
947    pub span_c0: Vec<Vec<f64>>,
948    pub span_c1: Vec<Vec<f64>>,
949    pub span_c2: Vec<Vec<f64>>,
950    pub span_c3: Vec<Vec<f64>>,
951    /// Cross-block anchor-residual coefficient matrix `M` of shape
952    /// `d × basis_dim`. When present, predict-time evaluation subtracts
953    /// `n_row · M` from each cubic-span row (where `n_row` stacks the
954    /// per-row parametric anchor values in the order given by
955    /// `anchor_components`).
956    #[serde(default)]
957    pub anchor_correction: Option<Vec<Vec<f64>>>,
958    /// Ordered list of parametric anchor components whose stacked row
959    /// values combine into `n_row`. Empty unless
960    /// `anchor_correction` is `Some`.
961    #[serde(default)]
962    pub anchor_components: Vec<SavedAnchorComponent>,
963}
964
965#[derive(Clone, Debug, Serialize, Deserialize)]
966pub struct SavedAnchorComponent {
967    pub kind: SavedAnchorKind,
968}
969
970#[derive(Clone, Debug, Serialize, Deserialize)]
971pub enum SavedAnchorKind {
972    Parametric {
973        block: ParametricAnchorBlock,
974        ncols: usize,
975    },
976    /// Flex-evaluation anchor (sibling flex block's reparameterised basis,
977    /// evaluated at training rows at fit time and at predict rows at
978    /// predict time). The predictor stacks `ncols` columns from the
979    /// sibling runtime's `design(arg)` into `n_row`.
980    FlexEvaluation { ncols: usize },
981}
982
983#[derive(Clone, Debug)]
984pub struct SavedPredictionRuntime {
985    pub model_class: PredictModelClass,
986    pub likelihood: LikelihoodSpec,
987    pub inverse_link: Option<InverseLink>,
988    pub link_wiggle: Option<SavedLinkWiggleRuntime>,
989    pub baseline_time_wiggle: Option<SavedBaselineTimeWiggleRuntime>,
990    pub score_warp: Option<SavedCompiledFlexBlock>,
991    pub link_deviation: Option<SavedCompiledFlexBlock>,
992    /// Rank-INT latent-z calibration carried into the predictor build.
993    /// `None` for non-BMS models and for BMS fits whose latent measure
994    /// did not require rank-INT calibration.
995    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
996    /// Conditional location-scale latent-z calibration (#905) carried into the
997    /// predictor build. `None` for non-BMS models and for BMS fits whose Auto
998    /// path did not detect a conditional `E[z|C]`/`Var(z|C)` shift. When
999    /// `Some`, the predictor replaces the normalized `z` by `ζ = (z−m(C))/√v(C)`
1000    /// using the marginal prediction design as the conditioning span.
1001    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
1002    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the
1003    /// survival marginal-slope fit hosted a dedicated additive absorber (the
1004    /// trailing block). `None` when no CTN Stage-1 chain produced an influence
1005    /// Jacobian. At predict the absorber's `γ` is DROPPED (the orthogonalized
1006    /// β̂ is a training-fit property), so the predictor uses this width only to
1007    /// (a) account for the extra trailing block in the saved block count and
1008    /// (b) slice `γ`'s rows/cols out of the joint covariance. Survival hosts the
1009    /// absorber as its own block (unlike the BMS A2 widened-marginal design),
1010    /// so it never widens any persisted prediction design.
1011    pub influence_absorber_width: Option<usize>,
1012}
1013
1014pub fn gaussian_location_scale_mean_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1015    fit.block_by_role(BlockRole::Location)
1016        .or_else(|| fit.block_by_role(BlockRole::Mean))
1017        .map(|block| block.beta.clone())
1018}
1019
1020pub fn binomial_location_scale_threshold_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1021    fit.block_by_role(BlockRole::Threshold)
1022        .or_else(|| fit.block_by_role(BlockRole::Location))
1023        .or_else(|| fit.block_by_role(BlockRole::Mean))
1024        .map(|block| block.beta.clone())
1025}
1026
1027pub fn location_scale_noise_beta(fit: &UnifiedFitResult) -> Option<Array1<f64>> {
1028    fit.block_by_role(BlockRole::Scale)
1029        .map(|block| block.beta.clone())
1030}
1031
1032/// Whether a `ModelKind::LocationScale` likelihood's response is one of the
1033/// genuine-dispersion mean families (#913) — NegativeBinomial, Gamma, Beta or
1034/// Tweedie. These carry a `noise_formula` overdispersion channel and must be
1035/// predicted through the GLM mean inverse link (the
1036/// [`PredictModelClass::DispersionLocationScale`] path), NOT the binomial
1037/// threshold-scale predictor. The binomial location-scale (BMS ordinal) path is
1038/// the only other non-Gaussian location-scale family, with a `Binomial`
1039/// response.
1040fn is_dispersion_location_scale_response(response: &gam_problem::types::ResponseFamily) -> bool {
1041    use gam_problem::types::ResponseFamily;
1042    matches!(
1043        response,
1044        ResponseFamily::NegativeBinomial { .. }
1045            | ResponseFamily::Gamma
1046            | ResponseFamily::Beta { .. }
1047            | ResponseFamily::Tweedie { .. }
1048    )
1049}
1050
1051fn validate_location_scale_saved_fit(
1052    fit: &UnifiedFitResult,
1053    model_class: PredictModelClass,
1054    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1055) -> Result<(), FittedModelError> {
1056    let primary = match model_class {
1057        // Gaussian and dispersion (#913) location-scale both predict the mean
1058        // through the Location block; the binomial threshold-scale class reads
1059        // the Threshold block instead.
1060        PredictModelClass::GaussianLocationScale | PredictModelClass::DispersionLocationScale => {
1061            gaussian_location_scale_mean_beta(fit)
1062        }
1063        PredictModelClass::BinomialLocationScale => binomial_location_scale_threshold_beta(fit),
1064        _ => None,
1065    }
1066    .ok_or_else(|| FittedModelError::MissingField {
1067        reason: match model_class {
1068            PredictModelClass::GaussianLocationScale => {
1069                "gaussian-location-scale saved fit is missing mean/location block".to_string()
1070            }
1071            PredictModelClass::DispersionLocationScale => {
1072                "dispersion-location-scale saved fit is missing mean/location block".to_string()
1073            }
1074            PredictModelClass::BinomialLocationScale => {
1075                "binomial-location-scale saved fit is missing threshold/location block".to_string()
1076            }
1077            _ => "location-scale saved fit is missing primary block".to_string(),
1078        },
1079    })?;
1080
1081    let scale = location_scale_noise_beta(fit).ok_or_else(|| FittedModelError::MissingField {
1082        reason: "location-scale saved fit is missing scale block".to_string(),
1083    })?;
1084    let expected =
1085        primary.len() + scale.len() + link_wiggle.map_or(0, |runtime| runtime.beta.len());
1086
1087    if let Some(cov) = fit.beta_covariance()
1088        && (cov.nrows() != expected || cov.ncols() != expected)
1089    {
1090        return Err(FittedModelError::SchemaMismatch {
1091            reason: format!(
1092                "location-scale saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1093                cov.nrows(),
1094                cov.ncols(),
1095                expected,
1096                expected
1097            ),
1098        });
1099    }
1100    if let Some(cov) = fit.beta_covariance_corrected()
1101        && (cov.nrows() != expected || cov.ncols() != expected)
1102    {
1103        return Err(FittedModelError::SchemaMismatch {
1104            reason: format!(
1105                "location-scale saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1106                cov.nrows(),
1107                cov.ncols(),
1108                expected,
1109                expected
1110            ),
1111        });
1112    }
1113    Ok(())
1114}
1115
1116fn validate_survival_saved_block_matches_payload(
1117    fit: &UnifiedFitResult,
1118    role: BlockRole,
1119    payload_beta: Option<&Vec<f64>>,
1120    label: &str,
1121) -> Result<usize, FittedModelError> {
1122    let block = fit
1123        .block_by_role(role)
1124        .ok_or_else(|| FittedModelError::MissingField {
1125            reason: format!("location-scale survival saved fit is missing {label} block"),
1126        })?;
1127    if let Some(saved) = payload_beta
1128        && block.beta.to_vec() != *saved
1129    {
1130        return Err(FittedModelError::SchemaMismatch {
1131            reason: format!(
1132                "location-scale survival saved {label} coefficients disagree with fit_result"
1133            ),
1134        });
1135    }
1136    Ok(block.beta.len())
1137}
1138
1139fn validate_survival_location_scale_saved_fit(
1140    payload: &FittedModelPayload,
1141    link_wiggle: Option<&SavedLinkWiggleRuntime>,
1142) -> Result<(), FittedModelError> {
1143    let fit = payload
1144        .fit_result
1145        .as_ref()
1146        .ok_or_else(|| FittedModelError::MissingField {
1147            reason: "location-scale survival model is missing canonical fit_result payload"
1148                .to_string(),
1149        })?;
1150    let p_time = validate_survival_saved_block_matches_payload(
1151        fit,
1152        BlockRole::Time,
1153        payload.survival_beta_time.as_ref(),
1154        "time",
1155    )?;
1156    let p_threshold = validate_survival_saved_block_matches_payload(
1157        fit,
1158        BlockRole::Threshold,
1159        payload.survival_beta_threshold.as_ref(),
1160        "threshold",
1161    )?;
1162    let p_log_sigma = validate_survival_saved_block_matches_payload(
1163        fit,
1164        BlockRole::Scale,
1165        payload.survival_beta_log_sigma.as_ref(),
1166        "log-sigma",
1167    )?;
1168    let p_wiggle = match link_wiggle {
1169        Some(runtime) => {
1170            let block = fit.block_by_role(BlockRole::LinkWiggle).ok_or_else(|| {
1171                FittedModelError::MissingField {
1172                    reason: "location-scale survival saved fit is missing link-wiggle block"
1173                        .to_string(),
1174                }
1175            })?;
1176            if block.beta.to_vec() != runtime.beta {
1177                return Err(FittedModelError::SchemaMismatch {
1178                    reason:
1179                        "location-scale survival saved link-wiggle coefficients disagree with fit_result"
1180                            .to_string(),
1181                });
1182            }
1183            runtime.beta.len()
1184        }
1185        None => {
1186            if fit.block_by_role(BlockRole::LinkWiggle).is_some() {
1187                return Err(FittedModelError::SchemaMismatch {
1188                    reason:
1189                        "location-scale survival saved fit has a LinkWiggle block without payload metadata"
1190                            .to_string(),
1191                });
1192            }
1193            0
1194        }
1195    };
1196    let expected = p_time + p_threshold + p_log_sigma + p_wiggle;
1197
1198    if let Some(cov) = fit.beta_covariance()
1199        && (cov.nrows() != expected || cov.ncols() != expected)
1200    {
1201        return Err(FittedModelError::SchemaMismatch {
1202            reason: format!(
1203                "location-scale survival saved conditional covariance shape mismatch: got {}x{}, expected {}x{}",
1204                cov.nrows(),
1205                cov.ncols(),
1206                expected,
1207                expected
1208            ),
1209        });
1210    }
1211    if let Some(cov) = fit.beta_covariance_corrected()
1212        && (cov.nrows() != expected || cov.ncols() != expected)
1213    {
1214        return Err(FittedModelError::SchemaMismatch {
1215            reason: format!(
1216                "location-scale survival saved corrected covariance shape mismatch: got {}x{}, expected {}x{}",
1217                cov.nrows(),
1218                cov.ncols(),
1219                expected,
1220                expected
1221            ),
1222        });
1223    }
1224    Ok(())
1225}
1226
1227fn validate_marginal_slope_saved_fit(
1228    fit: &UnifiedFitResult,
1229    score_warp: Option<&SavedCompiledFlexBlock>,
1230    link_deviation: Option<&SavedCompiledFlexBlock>,
1231    fit_label: &str,
1232) -> Result<(), FittedModelError> {
1233    validate_marginal_slope_saved_fit_impl(
1234        fit,
1235        score_warp,
1236        link_deviation,
1237        fit_label,
1238        "bernoulli",
1239        2,
1240        "marginal, logslope",
1241    )
1242}
1243
1244fn validate_survival_marginal_slope_saved_fit(
1245    fit: &UnifiedFitResult,
1246    score_warp: Option<&SavedCompiledFlexBlock>,
1247    link_deviation: Option<&SavedCompiledFlexBlock>,
1248    fit_label: &str,
1249) -> Result<(), FittedModelError> {
1250    validate_marginal_slope_saved_fit_impl(
1251        fit,
1252        score_warp,
1253        link_deviation,
1254        fit_label,
1255        "survival",
1256        3,
1257        "time, marginal, slope",
1258    )
1259}
1260
1261/// Shared block-count + coefficient-dimension validation for the bernoulli
1262/// and survival marginal-slope saved-fit gates. The only family-specific
1263/// inputs are the family kind string ("bernoulli" / "survival"), the base
1264/// block count (2 for bernoulli, 3 for survival — the survival path has an
1265/// extra time block), and the base-block role list rendered in the error
1266/// message ("marginal, logslope" / "time, marginal, slope"). The score-warp
1267/// / link-deviation tail follows the same shape in both families.
1268fn validate_marginal_slope_saved_fit_impl(
1269    fit: &UnifiedFitResult,
1270    score_warp: Option<&SavedCompiledFlexBlock>,
1271    link_deviation: Option<&SavedCompiledFlexBlock>,
1272    fit_label: &str,
1273    family_kind: &str,
1274    base_block_count: usize,
1275    base_block_role_list: &str,
1276) -> Result<(), FittedModelError> {
1277    let expected_blocks = base_block_count
1278        + usize::from(score_warp.is_some())
1279        + usize::from(link_deviation.is_some());
1280    if fit.blocks.len() != expected_blocks {
1281        let score_warp_suffix = if score_warp.is_some() {
1282            ", score-warp"
1283        } else {
1284            ""
1285        };
1286        let link_deviation_suffix = if link_deviation.is_some() {
1287            ", link-deviation"
1288        } else {
1289            ""
1290        };
1291        return Err(FittedModelError::SchemaMismatch {
1292            reason: format!(
1293                "{family_kind} marginal-slope saved {fit_label} requires {expected_blocks} blocks [{base_block_role_list}{score_warp_suffix}{link_deviation_suffix}], got {}",
1294                fit.blocks.len(),
1295            ),
1296        });
1297    }
1298    if let Some(runtime) = score_warp {
1299        let beta = &fit.blocks[base_block_count].beta;
1300        if beta.len() != runtime.basis_dim {
1301            return Err(FittedModelError::SchemaMismatch {
1302                reason: format!(
1303                    "{family_kind} marginal-slope saved {fit_label} score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
1304                    beta.len(),
1305                    runtime.basis_dim
1306                ),
1307            });
1308        }
1309    }
1310    if let Some(runtime) = link_deviation {
1311        let idx = base_block_count + usize::from(score_warp.is_some());
1312        let beta = &fit.blocks[idx].beta;
1313        if beta.len() != runtime.basis_dim {
1314            return Err(FittedModelError::SchemaMismatch {
1315                reason: format!(
1316                    "{family_kind} marginal-slope saved {fit_label} link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
1317                    beta.len(),
1318                    runtime.basis_dim
1319                ),
1320            });
1321        }
1322    }
1323    Ok(())
1324}
1325
1326impl SavedLinkWiggleRuntime {
1327    fn validate_monotone_derivative(
1328        &self,
1329        q0: &Array1<f64>,
1330    ) -> Result<Array1<f64>, FittedModelError> {
1331        // Monotonicity is verified pointwise at the actual evaluation grid `q0`
1332        // (the predict η). The fit guarantees a strictly-increasing warped link
1333        // across the training η range (#1596); checking at `q0` here flags an
1334        // extrapolation point where the learnable link genuinely turns
1335        // non-invertible, without rejecting the whole model for a sign dip in the
1336        // basis tail far outside any data or evaluation point.
1337        let d_constrained = self.constrained_basis(q0, BasisOptions::first_derivative())?;
1338        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1339        let dq_dq0 = d_constrained.dot(&beta_link_wiggle) + 1.0;
1340        if let Some((idx, value)) = dq_dq0.iter().copied().enumerate().find(|(_, v)| *v <= 0.0) {
1341            return Err(FittedModelError::PayloadCorrupt {
1342                reason: format!(
1343                    "saved link-wiggle is not monotone at row {idx}: dq/dq0={value:.3e} <= 0"
1344                ),
1345            });
1346        }
1347        Ok(dq_dq0)
1348    }
1349
1350    pub fn constrained_basis(
1351        &self,
1352        q0: &Array1<f64>,
1353        basis_options: BasisOptions,
1354    ) -> Result<Array2<f64>, FittedModelError> {
1355        let knot_arr = Array1::from_vec(self.knots.clone());
1356        let constrained = monotone_wiggle_basis_with_derivative_order(
1357            q0.view(),
1358            &knot_arr,
1359            self.degree,
1360            basis_options.derivative_order,
1361        )
1362        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1363        if constrained.ncols() != self.beta.len() {
1364            return Err(FittedModelError::SchemaMismatch {
1365                reason: format!(
1366                    "saved link-wiggle dimension mismatch: coefficients have {} entries but basis has {} columns",
1367                    self.beta.len(),
1368                    constrained.ncols()
1369                ),
1370            });
1371        }
1372        Ok(constrained)
1373    }
1374
1375    pub fn design(&self, q0: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1376        self.validate_monotone_derivative(q0)?;
1377        self.constrained_basis(q0, BasisOptions::value())
1378    }
1379
1380    pub fn basis_row_scalar(&self, q0: f64) -> Result<Array1<f64>, FittedModelError> {
1381        let q = Array1::from_vec(vec![q0]);
1382        let x = self.design(&q)?;
1383        if x.nrows() != 1 {
1384            return Err(FittedModelError::SchemaMismatch {
1385                reason: format!(
1386                    "saved link-wiggle scalar evaluation expected 1 row, got {}",
1387                    x.nrows()
1388                ),
1389            });
1390        }
1391        Ok(x.row(0).to_owned())
1392    }
1393
1394    pub fn apply(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1395        self.validate_monotone_derivative(q0)?;
1396        let xwiggle = self.constrained_basis(q0, BasisOptions::value())?;
1397        let beta_link_wiggle = Array1::from_vec(self.beta.clone());
1398        Ok(q0 + &xwiggle.dot(&beta_link_wiggle))
1399    }
1400
1401    pub fn derivative_q0(&self, q0: &Array1<f64>) -> Result<Array1<f64>, FittedModelError> {
1402        self.validate_monotone_derivative(q0)
1403    }
1404}
1405
1406impl SavedBaselineTimeWiggleRuntime {
1407    pub fn validate_global_monotonicity(&self) -> Result<(), FittedModelError> {
1408        validate_monotone_wiggle_beta_nonnegative(&self.beta, "saved baseline-timewiggle")
1409            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1410    }
1411}
1412
1413impl SavedCompiledFlexBlock {
1414    pub(crate) fn validate_exact_replay_contract(&self) -> Result<(), FittedModelError> {
1415        if self.kernel.is_empty() {
1416            return Err(FittedModelError::SchemaMismatch {
1417                reason: "saved anchored deviation runtime is missing the exact kernel marker"
1418                    .to_string(),
1419            });
1420        }
1421        if self.kernel != crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL {
1422            return Err(FittedModelError::IncompatibleConfig {
1423                reason: format!(
1424                    "saved anchored deviation runtime uses unsupported kernel '{}'; expected {}",
1425                    self.kernel,
1426                    crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL
1427                ),
1428            });
1429        }
1430        if self.basis_dim == 0 {
1431            return Err(FittedModelError::SchemaMismatch {
1432                reason: format!(
1433                    "saved anchored deviation runtime basis_dim must be positive, got {}",
1434                    self.basis_dim
1435                ),
1436            });
1437        }
1438        if self.breakpoints.len() < 2 {
1439            return Err(FittedModelError::SchemaMismatch {
1440                reason: format!(
1441                    "saved anchored deviation runtime requires at least two breakpoints, got {}",
1442                    self.breakpoints.len()
1443                ),
1444            });
1445        }
1446        for window in self.breakpoints.windows(2) {
1447            let left = window[0];
1448            let right = window[1];
1449            if !left.is_finite() || !right.is_finite() || right <= left {
1450                return Err(FittedModelError::PayloadCorrupt {
1451                    reason: format!(
1452                        "saved anchored deviation runtime breakpoints must be finite and strictly increasing, got [{left}, {right}]"
1453                    ),
1454                });
1455            }
1456        }
1457        let span_count = self.breakpoints.len() - 1;
1458        self.validate_coefficient_matrix(&self.span_c0, "c0", span_count)?;
1459        self.validate_coefficient_matrix(&self.span_c1, "c1", span_count)?;
1460        self.validate_coefficient_matrix(&self.span_c2, "c2", span_count)?;
1461        self.validate_coefficient_matrix(&self.span_c3, "c3", span_count)?;
1462        self.validate_c2_span_continuity()?;
1463        self.validate_anchor_residual_shape()?;
1464        Ok(())
1465    }
1466
1467    fn validate_anchor_residual_shape(&self) -> Result<(), FittedModelError> {
1468        let coeffs = match self.anchor_correction.as_ref() {
1469            Some(c) => c,
1470            None => {
1471                if !self.anchor_components.is_empty() {
1472                    return Err(FittedModelError::SchemaMismatch {
1473                        reason:
1474                            "saved anchored deviation runtime has anchor_components but no anchor_correction"
1475                                .to_string(),
1476                    });
1477                }
1478                return Ok(());
1479            }
1480        };
1481        let d: usize = self
1482            .anchor_components
1483            .iter()
1484            .map(|c| match &c.kind {
1485                SavedAnchorKind::Parametric { ncols, .. } => *ncols,
1486                SavedAnchorKind::FlexEvaluation { ncols } => *ncols,
1487            })
1488            .sum();
1489        if coeffs.len() != d {
1490            return Err(FittedModelError::SchemaMismatch {
1491                reason: format!(
1492                    "saved anchored deviation runtime anchor_correction has {} rows; expected {} (sum of component ncols)",
1493                    coeffs.len(),
1494                    d,
1495                ),
1496            });
1497        }
1498        for (i, row) in coeffs.iter().enumerate() {
1499            if row.len() != self.basis_dim {
1500                return Err(FittedModelError::SchemaMismatch {
1501                    reason: format!(
1502                        "saved anchored deviation runtime anchor_correction row {} has width {}, expected basis_dim {}",
1503                        i,
1504                        row.len(),
1505                        self.basis_dim,
1506                    ),
1507                });
1508            }
1509            for (j, &v) in row.iter().enumerate() {
1510                if !v.is_finite() {
1511                    return Err(FittedModelError::PayloadCorrupt {
1512                        reason: format!(
1513                            "saved anchored deviation runtime anchor_correction ({i},{j}) is non-finite"
1514                        ),
1515                    });
1516                }
1517            }
1518        }
1519        Ok(())
1520    }
1521
1522    fn validate_c2_span_continuity(&self) -> Result<(), FittedModelError> {
1523        const TOL: f64 = 1e-8;
1524        for span_idx in 1..self.breakpoints.len() - 1 {
1525            let left_span = span_idx - 1;
1526            let right_span = span_idx;
1527            let width = self.breakpoints[span_idx] - self.breakpoints[left_span];
1528            for basis_idx in 0..self.basis_dim {
1529                let left_value = self.span_c0[left_span][basis_idx]
1530                    + self.span_c1[left_span][basis_idx] * width
1531                    + self.span_c2[left_span][basis_idx] * width * width
1532                    + self.span_c3[left_span][basis_idx] * width * width * width;
1533                let left_d1 = self.span_c1[left_span][basis_idx]
1534                    + 2.0 * self.span_c2[left_span][basis_idx] * width
1535                    + 3.0 * self.span_c3[left_span][basis_idx] * width * width;
1536                let left_d2 = 2.0 * self.span_c2[left_span][basis_idx]
1537                    + 6.0 * self.span_c3[left_span][basis_idx] * width;
1538                let right_value = self.span_c0[right_span][basis_idx];
1539                let right_d1 = self.span_c1[right_span][basis_idx];
1540                let right_d2 = 2.0 * self.span_c2[right_span][basis_idx];
1541                if (left_value - right_value).abs() > TOL
1542                    || (left_d1 - right_d1).abs() > TOL
1543                    || (left_d2 - right_d2).abs() > TOL
1544                {
1545                    return Err(FittedModelError::SchemaMismatch {
1546                        reason: format!(
1547                            "saved anchored deviation runtime must be C2 cubic at breakpoint {span_idx}, basis {basis_idx}: value jump={:.3e}, d1 jump={:.3e}, d2 jump={:.3e}",
1548                            left_value - right_value,
1549                            left_d1 - right_d1,
1550                            left_d2 - right_d2
1551                        ),
1552                    });
1553                }
1554            }
1555        }
1556        Ok(())
1557    }
1558
1559    fn validate_coefficient_matrix(
1560        &self,
1561        matrix: &[Vec<f64>],
1562        label: &str,
1563        expected_rows: usize,
1564    ) -> Result<(), FittedModelError> {
1565        if matrix.len() != expected_rows {
1566            return Err(FittedModelError::SchemaMismatch {
1567                reason: format!(
1568                    "saved anchored deviation runtime {label} row count mismatch: got {}, expected {}",
1569                    matrix.len(),
1570                    expected_rows
1571                ),
1572            });
1573        }
1574        for (row_idx, row) in matrix.iter().enumerate() {
1575            if row.len() != self.basis_dim {
1576                return Err(FittedModelError::SchemaMismatch {
1577                    reason: format!(
1578                        "saved anchored deviation runtime {label} row {} has width {}, expected {}",
1579                        row_idx,
1580                        row.len(),
1581                        self.basis_dim
1582                    ),
1583                });
1584            }
1585            for (j, &value) in row.iter().enumerate() {
1586                if !value.is_finite() {
1587                    return Err(FittedModelError::PayloadCorrupt {
1588                        reason: format!(
1589                            "saved anchored deviation runtime {label} entry ({row_idx},{j}) is non-finite"
1590                        ),
1591                    });
1592                }
1593            }
1594        }
1595        Ok(())
1596    }
1597
1598    fn right_boundary_basis_value(&self, basis_idx: usize) -> f64 {
1599        let last_span = self.breakpoints.len() - 2;
1600        let width = self.breakpoints[last_span + 1] - self.breakpoints[last_span];
1601        self.span_c0[last_span][basis_idx]
1602            + self.span_c1[last_span][basis_idx] * width
1603            + self.span_c2[last_span][basis_idx] * width * width
1604            + self.span_c3[last_span][basis_idx] * width * width * width
1605    }
1606
1607    fn evaluate_span_polynomial_design(
1608        &self,
1609        values: &Array1<f64>,
1610        derivative_order: usize,
1611    ) -> Result<Array2<f64>, FittedModelError> {
1612        self.validate_exact_replay_contract()?;
1613        let (left_ep, right_ep) = self.support_interval()?;
1614        let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
1615        for (row_idx, &value) in values.iter().enumerate() {
1616            if !value.is_finite() {
1617                return Err(FittedModelError::PayloadCorrupt {
1618                    reason: format!(
1619                        "saved anchored deviation runtime design value at row {row_idx} is non-finite ({value})"
1620                    ),
1621                });
1622            }
1623            if value < left_ep {
1624                if derivative_order == 0 {
1625                    for basis_idx in 0..self.basis_dim {
1626                        out[[row_idx, basis_idx]] = self.span_c0[0][basis_idx];
1627                    }
1628                }
1629                continue;
1630            }
1631            if value > right_ep {
1632                if derivative_order == 0 {
1633                    for basis_idx in 0..self.basis_dim {
1634                        out[[row_idx, basis_idx]] = self.right_boundary_basis_value(basis_idx);
1635                    }
1636                }
1637                continue;
1638            }
1639            let span_idx = self.left_biased_span_index_for(value)?;
1640            let t = value - self.breakpoints[span_idx];
1641            for basis_idx in 0..self.basis_dim {
1642                let c0 = self.span_c0[span_idx][basis_idx];
1643                let c1 = self.span_c1[span_idx][basis_idx];
1644                let c2 = self.span_c2[span_idx][basis_idx];
1645                let c3 = self.span_c3[span_idx][basis_idx];
1646                out[[row_idx, basis_idx]] = match derivative_order {
1647                    0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
1648                    1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
1649                    2 => 2.0 * c2 + 6.0 * c3 * t,
1650                    3 => 6.0 * c3,
1651                    4 => 0.0,
1652                    other => {
1653                        return Err(FittedModelError::IncompatibleConfig {
1654                            reason: format!(
1655                                "saved anchored deviation runtime only supports derivative orders up to 4, got {other}"
1656                            ),
1657                        });
1658                    }
1659                };
1660            }
1661        }
1662        Ok(out)
1663    }
1664
1665    pub fn breakpoints(&self) -> Result<Vec<f64>, FittedModelError> {
1666        self.validate_exact_replay_contract()?;
1667        Ok(self.breakpoints.clone())
1668    }
1669
1670    pub fn span_count(&self) -> Result<usize, FittedModelError> {
1671        Ok(self.breakpoints()?.windows(2).count())
1672    }
1673
1674    pub fn span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1675        let points = self.breakpoints()?;
1676        span_index_for_breakpoints(&points, value, "saved anchored deviation span lookup")
1677            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })
1678    }
1679
1680    fn left_biased_span_index_for(&self, value: f64) -> Result<usize, FittedModelError> {
1681        let mut span_idx = span_index_for_breakpoints(
1682            &self.breakpoints,
1683            value,
1684            "saved anchored deviation span lookup",
1685        )
1686        .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
1687        // LEFT-bias at interior breakpoints mirrors DeviationRuntime. The
1688        // saved cubic basis is C2, but d3 remains span-local.
1689        if span_idx > 0 && value == self.breakpoints[span_idx] {
1690            span_idx -= 1;
1691        }
1692        Ok(span_idx)
1693    }
1694
1695    pub fn local_cubic_on_span(
1696        &self,
1697        beta: &Array1<f64>,
1698        span_idx: usize,
1699    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1700        self.validate_exact_replay_contract()?;
1701        if beta.len() != self.basis_dim {
1702            return Err(FittedModelError::SchemaMismatch {
1703                reason: format!(
1704                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1705                    beta.len(),
1706                    self.basis_dim
1707                ),
1708            });
1709        }
1710        self.local_cubic_on_span_validated(beta, span_idx)
1711    }
1712
1713    fn local_cubic_on_span_validated(
1714        &self,
1715        beta: &Array1<f64>,
1716        span_idx: usize,
1717    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1718        let points = &self.breakpoints;
1719        if span_idx + 1 >= points.len() {
1720            return Err(FittedModelError::SchemaMismatch {
1721                reason: format!(
1722                    "saved anchored deviation span index {} out of range for {} spans",
1723                    span_idx,
1724                    points.len() - 1
1725                ),
1726            });
1727        }
1728        let left = points[span_idx];
1729        let right = points[span_idx + 1];
1730        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1731            left,
1732            right,
1733            c0: self.span_c0[span_idx]
1734                .iter()
1735                .zip(beta.iter())
1736                .map(|(coeff, weight)| coeff * weight)
1737                .sum(),
1738            c1: self.span_c1[span_idx]
1739                .iter()
1740                .zip(beta.iter())
1741                .map(|(coeff, weight)| coeff * weight)
1742                .sum(),
1743            c2: self.span_c2[span_idx]
1744                .iter()
1745                .zip(beta.iter())
1746                .map(|(coeff, weight)| coeff * weight)
1747                .sum(),
1748            c3: self.span_c3[span_idx]
1749                .iter()
1750                .zip(beta.iter())
1751                .map(|(coeff, weight)| coeff * weight)
1752                .sum(),
1753        })
1754    }
1755
1756    pub fn basis_span_cubic(
1757        &self,
1758        span_idx: usize,
1759        basis_idx: usize,
1760    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1761        self.validate_exact_replay_contract()?;
1762        if basis_idx >= self.basis_dim {
1763            return Err(FittedModelError::SchemaMismatch {
1764                reason: format!(
1765                    "saved anchored deviation basis index {} out of range for {} coefficients",
1766                    basis_idx, self.basis_dim
1767                ),
1768            });
1769        }
1770        self.basis_span_cubic_validated(span_idx, basis_idx)
1771    }
1772
1773    fn basis_span_cubic_validated(
1774        &self,
1775        span_idx: usize,
1776        basis_idx: usize,
1777    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1778        let points = &self.breakpoints;
1779        if span_idx + 1 >= points.len() {
1780            return Err(FittedModelError::SchemaMismatch {
1781                reason: format!(
1782                    "saved anchored deviation span index {} out of range for {} spans",
1783                    span_idx,
1784                    points.len() - 1
1785                ),
1786            });
1787        }
1788        Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1789            left: points[span_idx],
1790            right: points[span_idx + 1],
1791            c0: self.span_c0[span_idx][basis_idx],
1792            c1: self.span_c1[span_idx][basis_idx],
1793            c2: self.span_c2[span_idx][basis_idx],
1794            c3: self.span_c3[span_idx][basis_idx],
1795        })
1796    }
1797
1798    pub fn basis_cubic_at(
1799        &self,
1800        basis_idx: usize,
1801        value: f64,
1802    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1803        self.validate_exact_replay_contract()?;
1804        if basis_idx >= self.basis_dim {
1805            return Err(FittedModelError::SchemaMismatch {
1806                reason: format!(
1807                    "saved anchored deviation basis index {} out of range for {} coefficients",
1808                    basis_idx, self.basis_dim
1809                ),
1810            });
1811        }
1812        let (left_ep, right_ep) = self.support_interval()?;
1813        if value < left_ep {
1814            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1815                left: left_ep,
1816                right: left_ep + 1.0,
1817                c0: self.span_c0[0][basis_idx],
1818                c1: 0.0,
1819                c2: 0.0,
1820                c3: 0.0,
1821            });
1822        }
1823        if value > right_ep {
1824            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1825                left: right_ep,
1826                right: right_ep + 1.0,
1827                c0: self.right_boundary_basis_value(basis_idx),
1828                c1: 0.0,
1829                c2: 0.0,
1830                c3: 0.0,
1831            });
1832        }
1833        let span_idx = self.left_biased_span_index_for(value)?;
1834        self.basis_span_cubic_validated(span_idx, basis_idx)
1835    }
1836
1837    pub fn local_cubic_at(
1838        &self,
1839        beta: &Array1<f64>,
1840        value: f64,
1841    ) -> Result<crate::cubic_cell_kernel::LocalSpanCubic, FittedModelError> {
1842        self.validate_exact_replay_contract()?;
1843        if beta.len() != self.basis_dim {
1844            return Err(FittedModelError::SchemaMismatch {
1845                reason: format!(
1846                    "saved anchored deviation coefficient length mismatch: got {}, expected {}",
1847                    beta.len(),
1848                    self.basis_dim
1849                ),
1850            });
1851        }
1852        let (left_ep, right_ep) = self.support_interval()?;
1853        if value < left_ep {
1854            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1855                left: left_ep,
1856                right: left_ep + 1.0,
1857                c0: self.span_c0[0]
1858                    .iter()
1859                    .zip(beta.iter())
1860                    .map(|(coeff, weight)| coeff * weight)
1861                    .sum(),
1862                c1: 0.0,
1863                c2: 0.0,
1864                c3: 0.0,
1865            });
1866        }
1867        if value > right_ep {
1868            return Ok(crate::cubic_cell_kernel::LocalSpanCubic {
1869                left: right_ep,
1870                right: right_ep + 1.0,
1871                c0: (0..self.basis_dim)
1872                    .map(|basis_idx| self.right_boundary_basis_value(basis_idx) * beta[basis_idx])
1873                    .sum(),
1874                c1: 0.0,
1875                c2: 0.0,
1876                c3: 0.0,
1877            });
1878        }
1879        let span_idx = self.left_biased_span_index_for(value)?;
1880        self.local_cubic_on_span_validated(beta, span_idx)
1881    }
1882
1883    fn support_interval(&self) -> Result<(f64, f64), FittedModelError> {
1884        let points = self.breakpoints()?;
1885        match (points.first(), points.last()) {
1886            (Some(&left), Some(&right)) => Ok((left, right)),
1887            _ => Err(FittedModelError::MissingField {
1888                reason: "saved anchored deviation runtime is missing support breakpoints"
1889                    .to_string(),
1890            }),
1891        }
1892    }
1893
1894    pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, FittedModelError> {
1895        // Note: when the saved runtime carries an anchor residual
1896        // (cross-block orthogonalisation), the value `design()` returns
1897        // is the raw cubic span output *without* the per-row `n_row · M`
1898        // subtraction. Callers used inside BMS prediction must either
1899        // switch to `design_with_anchor_rows` (when the per-row anchor
1900        // rows are available) or call `design_uncorrected` explicitly and
1901        // apply the subtraction at the call site. For runtimes without a
1902        // residual the two paths coincide.
1903        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1904    }
1905
1906    /// Raw cubic-span design without any anchor-residual subtraction.
1907    ///
1908    /// Exposed for callers that intend to apply the `n_row · M` correction
1909    /// post-hoc (e.g., BMS `link_terms_value_d1` subtracts a precomputed
1910    /// `correction.dot(beta)` scalar from the linear-predictor contribution
1911    /// rather than building a full anchor-row matrix). Equivalent to
1912    /// `design()` when no residual is present.
1913    pub fn design_uncorrected(
1914        &self,
1915        values: &Array1<f64>,
1916    ) -> Result<Array2<f64>, FittedModelError> {
1917        self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)
1918    }
1919
1920    /// Evaluate the residual-corrected design at the supplied values.
1921    ///
1922    /// `anchor_rows` must be an `n × d` matrix where `n == values.len()`
1923    /// and `d == sum of anchor_components ncols`. Each row holds
1924    /// the concatenated parametric anchor design at the same prediction
1925    /// row as the corresponding `values[i]`. When the runtime has no
1926    /// anchor residual, `anchor_rows` must have zero columns (or be
1927    /// `Array2::zeros((n, 0))`).
1928    pub fn design_with_anchor_rows(
1929        &self,
1930        values: &Array1<f64>,
1931        anchor_rows: ndarray::ArrayView2<f64>,
1932    ) -> Result<Array2<f64>, FittedModelError> {
1933        let mut out =
1934            self.evaluate_span_polynomial_design(values, BasisOptions::value().derivative_order)?;
1935        if let Some(m_rows) = self.anchor_correction.as_ref() {
1936            let d = m_rows.len();
1937            if anchor_rows.nrows() != values.len() {
1938                return Err(FittedModelError::SchemaMismatch {
1939                    reason: format!(
1940                        "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
1941                        anchor_rows.nrows(),
1942                        values.len(),
1943                    ),
1944                });
1945            }
1946            if anchor_rows.ncols() != d {
1947                return Err(FittedModelError::SchemaMismatch {
1948                    reason: format!(
1949                        "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
1950                        anchor_rows.ncols(),
1951                        d,
1952                    ),
1953                });
1954            }
1955            // Materialise M (d × basis_dim) once.
1956            let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
1957            for (i, row) in m_rows.iter().enumerate() {
1958                if row.len() != self.basis_dim {
1959                    return Err(FittedModelError::SchemaMismatch {
1960                        reason: format!(
1961                            "design_with_anchor_rows: anchor_correction row {} has length {}, expected basis_dim {}",
1962                            i,
1963                            row.len(),
1964                            self.basis_dim,
1965                        ),
1966                    });
1967                }
1968                for (j, &v) in row.iter().enumerate() {
1969                    m_dense[[i, j]] = v;
1970                }
1971            }
1972            // The compiler bakes the orthonormalising rotation into M, so
1973            // the predict-time subtraction is simply `n_anchor_rows · M`.
1974            let subtract = anchor_rows.dot(&m_dense);
1975            out = out - subtract;
1976        } else if anchor_rows.ncols() != 0 {
1977            return Err(FittedModelError::SchemaMismatch {
1978                reason: format!(
1979                    "design_with_anchor_rows: runtime has no anchor residual but anchor_rows has {} cols",
1980                    anchor_rows.ncols(),
1981                ),
1982            });
1983        }
1984        Ok(out)
1985    }
1986
1987    /// Build the n × basis_dim per-row, per-basis correction matrix
1988    /// `N · M` for a batch of predict rows.
1989    ///
1990    /// `n_anchor_rows` is the n × d matrix of stacked parametric anchor
1991    /// rows at the prediction rows (concatenation of the marginal and
1992    /// logslope design rows in component order). Returns `None` when the
1993    /// runtime has no anchor residual (zero-cost path).
1994    pub fn anchor_correction_matrix(
1995        &self,
1996        n_anchor_rows: ndarray::ArrayView2<f64>,
1997    ) -> Result<Option<Array2<f64>>, FittedModelError> {
1998        let Some(m_rows) = self.anchor_correction.as_ref() else {
1999            return Ok(None);
2000        };
2001        let d = m_rows.len();
2002        if n_anchor_rows.ncols() != d {
2003            return Err(FittedModelError::SchemaMismatch {
2004                reason: format!(
2005                    "anchor_correction_matrix: anchor_rows has {} cols, expected {} (sum of component ncols)",
2006                    n_anchor_rows.ncols(),
2007                    d,
2008                ),
2009            });
2010        }
2011        let mut m_dense = Array2::<f64>::zeros((d, self.basis_dim));
2012        for (i, row) in m_rows.iter().enumerate() {
2013            if row.len() != self.basis_dim {
2014                return Err(FittedModelError::SchemaMismatch {
2015                    reason: format!(
2016                        "anchor_correction_matrix: M row {} has length {}, expected basis_dim {}",
2017                        i,
2018                        row.len(),
2019                        self.basis_dim,
2020                    ),
2021                });
2022            }
2023            for (j, &v) in row.iter().enumerate() {
2024                m_dense[[i, j]] = v;
2025            }
2026        }
2027        // The compiler bakes the orthonormalising rotation into M, so
2028        // the predict-time correction is simply `n_anchor_rows · M`.
2029        Ok(Some(n_anchor_rows.dot(&m_dense)))
2030    }
2031
2032    pub fn first_derivative_design(
2033        &self,
2034        values: &Array1<f64>,
2035    ) -> Result<Array2<f64>, FittedModelError> {
2036        self.evaluate_span_polynomial_design(
2037            values,
2038            BasisOptions::first_derivative().derivative_order,
2039        )
2040    }
2041
2042    pub fn second_derivative_design(
2043        &self,
2044        values: &Array1<f64>,
2045    ) -> Result<Array2<f64>, FittedModelError> {
2046        self.evaluate_span_polynomial_design(
2047            values,
2048            BasisOptions::second_derivative().derivative_order,
2049        )
2050    }
2051}
2052
2053impl FittedFamily {
2054    #[inline]
2055    pub fn likelihood(&self) -> LikelihoodSpec {
2056        let spec = match self {
2057            Self::Standard { likelihood, .. }
2058            | Self::LocationScale { likelihood, .. }
2059            | Self::MarginalSlope { likelihood, .. }
2060            | Self::Survival { likelihood, .. }
2061            | Self::TransformationNormal { likelihood, .. } => likelihood,
2062            Self::LatentSurvival { .. } | Self::LatentBinary { .. } => {
2063                return LikelihoodSpec::royston_parmar();
2064            }
2065        };
2066        spec.clone()
2067    }
2068
2069    #[inline]
2070    pub fn frailty(&self) -> Option<&FrailtySpec> {
2071        match self {
2072            Self::MarginalSlope { frailty, .. }
2073            | Self::Survival { frailty, .. }
2074            | Self::LatentSurvival { frailty }
2075            | Self::LatentBinary { frailty } => Some(frailty),
2076            _ => None,
2077        }
2078    }
2079}
2080
2081/// Recursively collect the feature columns of a smooth basis whose out-of-hull
2082/// evaluation is bounded, so they can be exempted from the predict-time axis
2083/// clip (see [`FittedModel::training_smooth_extrapolation_axes`]). Wrapper bases
2084/// (`by=`, factor-smooth, sum-to-zero) delegate to their inner smooth; `Sphere`
2085/// and `Pca` are intentionally not collected.
2086fn collect_smooth_extrapolation_axes(
2087    basis: &gam_terms::smooth::SmoothBasisSpec,
2088    n_training_headers: usize,
2089    out: &mut std::collections::HashSet<usize>,
2090) {
2091    use gam_terms::smooth::SmoothBasisSpec;
2092    let push = |col: usize, out: &mut std::collections::HashSet<usize>| {
2093        if col < n_training_headers {
2094            out.insert(col);
2095        }
2096    };
2097    match basis {
2098        // 1D B-spline: first-order linear extension off the boundary slope.
2099        SmoothBasisSpec::BSpline1D { feature_col, .. } => push(*feature_col, out),
2100        // Tensor B-spline: each margin linear-extends independently. Periodic
2101        // margins are additionally (and harmlessly) exempted via the periodic set.
2102        SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
2103            for &c in feature_cols {
2104                push(c, out);
2105            }
2106        }
2107        // Radial bases with a bounded out-of-hull contract: Duchon / thin-plate
2108        // are linear outside the data span (natural-spline boundary conditions),
2109        // Matérn reverts to its mean as the kernel decays. Measure-jet shares
2110        // the Matérn contract (Gaussian representers decay to the parametric
2111        // layer off the data support) — and off-web queries are exactly the
2112        // ones its support diagnostic must see unclipped.
2113        SmoothBasisSpec::ThinPlate { feature_cols, .. }
2114        | SmoothBasisSpec::Matern { feature_cols, .. }
2115        | SmoothBasisSpec::MeasureJet { feature_cols, .. }
2116        | SmoothBasisSpec::Duchon { feature_cols, .. } => {
2117            for &c in feature_cols {
2118                push(c, out);
2119            }
2120        }
2121        // Factor-smooth: the continuous marginal axes are B-splines that
2122        // linear-extend; the group column is categorical and is left to the
2123        // random-effect / level-lookup machinery.
2124        SmoothBasisSpec::FactorSmooth { spec } => {
2125            for &c in &spec.continuous_cols {
2126                push(c, out);
2127            }
2128        }
2129        // Wrappers delegate to the inner smooth they modulate / replicate.
2130        SmoothBasisSpec::ByVariable { inner, .. }
2131        | SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2132            collect_smooth_extrapolation_axes(inner, n_training_headers, out)
2133        }
2134        SmoothBasisSpec::BySmooth { smooth, .. } => {
2135            collect_smooth_extrapolation_axes(smooth, n_training_headers, out)
2136        }
2137        // Sphere: latitude is clipped to manifold bounds, longitude is periodic —
2138        // both handled elsewhere with non-plateau semantics. Pca: no extrapolation
2139        // contract, stays clipped. ConstantCurvature: chart coordinates must stay
2140        // inside the κ-stereographic chart (open ball for κ < 0), so clipping new
2141        // data to the training range is the safe out-of-hull behavior.
2142        SmoothBasisSpec::Sphere { .. }
2143        | SmoothBasisSpec::ConstantCurvature { .. }
2144        | SmoothBasisSpec::Pca { .. } => {}
2145    }
2146}
2147
2148/// Collect training-column indices that feed a *numeric* `by=` multiplier of a
2149/// varying-coefficient smooth (`s(x, by=z)`).
2150///
2151/// A numeric by-variable enters the model as a linear multiplier of the centred
2152/// smooth basis, so the prediction is exactly affine in `z` for any fixed `x`:
2153/// `pred(x, z) = intercept + z·f(x)`. The by-multiplier is therefore an
2154/// *unbounded linear* extrapolant — exactly like a parametric `linear` axis
2155/// (already exempt from the clip via `training_linear_axes`), not a
2156/// bounded-basis axis. Clamping it to the training range `[min(z), max(z)]`
2157/// before the design is built would make the varying-coefficient effect plateau
2158/// outside the sampled range (slope 0 above the max) and silently replace the
2159/// natural `z == 0` baseline with the `z == min` prediction — the prediction is
2160/// no longer affine in `z`. Exempt it from the predict-time axis clip.
2161///
2162/// Factor `by=` columns are categorical group labels handled by the
2163/// level-lookup / random-effect machinery and are deliberately *not* exempted
2164/// here. Returned indices reference `self.training_headers`, matching the
2165/// iteration in `axis_clip_to_training_ranges`.
2166fn collect_by_variable_numeric_axes(
2167    basis: &gam_terms::smooth::SmoothBasisSpec,
2168    n_training_headers: usize,
2169    out: &mut std::collections::HashSet<usize>,
2170) {
2171    use gam_terms::smooth::{BySmoothKind, ByVarKind, SmoothBasisSpec};
2172    match basis {
2173        SmoothBasisSpec::ByVariable {
2174            inner,
2175            by_col,
2176            kind,
2177            ..
2178        } => {
2179            if matches!(kind, BySmoothKind::Numeric) && *by_col < n_training_headers {
2180                out.insert(*by_col);
2181            }
2182            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2183        }
2184        SmoothBasisSpec::BySmooth { smooth, by_kind } => {
2185            if let ByVarKind::Numeric { feature_col } = by_kind
2186                && *feature_col < n_training_headers
2187            {
2188                out.insert(*feature_col);
2189            }
2190            collect_by_variable_numeric_axes(smooth, n_training_headers, out);
2191        }
2192        SmoothBasisSpec::FactorSumToZero { inner, .. } => {
2193            collect_by_variable_numeric_axes(inner, n_training_headers, out);
2194        }
2195        _ => {}
2196    }
2197}
2198
2199impl FittedModel {
2200    /// Axis-clip each continuous new-data column to the (min, max) range
2201    /// observed in training. Categorical and binary columns are left
2202    /// untouched so unseen levels surface rather than being silently remapped
2203    /// onto seen ones. Returns `Some(clipped_copy)` only if at least one
2204    /// value was actually clipped; otherwise `None` so callers can avoid
2205    /// owning a redundant copy. Pre-2026-04-29 model JSONs that lack the
2206    /// `training_feature_ranges` field deserialize to `None` and pass through
2207    /// unchanged.
2208    pub fn axis_clip_to_training_ranges(
2209        &self,
2210        data: ndarray::ArrayView2<'_, f64>,
2211        col_map: &std::collections::HashMap<String, usize>,
2212    ) -> Option<ndarray::Array2<f64>> {
2213        let training_headers = self.training_headers.as_ref()?;
2214        let ranges = self.training_feature_ranges.as_ref()?;
2215        if training_headers.len() != ranges.len() {
2216            return None;
2217        }
2218        let mut kind_by_header: std::collections::HashMap<&str, ColumnKindTag> =
2219            std::collections::HashMap::new();
2220        if let Some(schema) = self.data_schema.as_ref() {
2221            for col in &schema.columns {
2222                kind_by_header.insert(col.name.as_str(), col.kind);
2223            }
2224        }
2225        // Periodic axes (sphere longitude, periodic-B-spline 1D, periodic
2226        // tensor margins) must never be clipped to the training range:
2227        // clamping a value just past the seam to the training extreme breaks
2228        // the cyclic invariant f(x₀) = f(x₀ + period) at predict time and
2229        // shows up as a visible seam in surface plots.
2230        let periodic_axes = self.training_periodic_axes(training_headers);
2231        // Parametric/linear-term axes must never be clipped either: a linear
2232        // term's contract is η = β0 + β1·x, i.e. genuine linear extrapolation.
2233        // Clamping its input to the training extreme turns predict into a
2234        // piecewise-constant plateau outside the training hull and freezes the
2235        // prediction SE at the boundary (the clamped x feeds xᵀ Var(β) x), so
2236        // credible intervals stop widening with distance from the data. This
2237        // mirrors how periodic axes are exempted just above.
2238        let linear_axes = self.training_linear_axes(training_headers.len());
2239        // Random-effect grouping axes are categorical even when their source
2240        // column is numeric. Clipping them would remap an unseen group label to
2241        // a boundary training level instead of letting the random-effect block
2242        // encode it as the prior-mean zero effect.
2243        let random_effect_axes = self.training_random_effect_axes(training_headers.len());
2244        // Non-parametric smooth axes whose basis extrapolates boundedly on its
2245        // own (B-spline linear extension, Duchon/thin-plate natural-spline linear
2246        // tail, Matérn kernel decay). Clamping their input to the training extreme
2247        // hands the basis an already-clamped coordinate, so its extrapolation
2248        // never fires and predict freezes at a boundary plateau — diverging from
2249        // the raw design path, which does not clip. Exempt them so both paths go
2250        // through the single basis-layer extrapolation. See the method doc.
2251        let smooth_extrapolation_axes =
2252            self.training_smooth_extrapolation_axes(training_headers.len());
2253        // Numeric `by=` multipliers of a varying-coefficient smooth `s(x, by=z)`
2254        // are linear multipliers of the centred basis (prediction affine in z),
2255        // so — like parametric linear axes — clipping them to the training range
2256        // turns the varying-coefficient effect into a boundary plateau and
2257        // destroys the z==0 baseline. Exempt them from the clip.
2258        let by_variable_axes = self.training_by_variable_numeric_axes(training_headers.len());
2259        // Sphere latitude is a closed-manifold coordinate: its clip bounds are
2260        // the manifold's intrinsic domain ([-π/2, π/2] or [-90, 90]), not the
2261        // sampled range, so a pole prediction reaches the true pole instead of
2262        // being clamped to a near-pole latitude (see the method doc).
2263        let sphere_lat_bounds = self.training_sphere_latitude_bounds(training_headers);
2264        let mut clipped = data.to_owned();
2265        let mut any_clipped = false;
2266        for (col_in_training, (header, &(lo, hi))) in
2267            training_headers.iter().zip(ranges.iter()).enumerate()
2268        {
2269            let (lo, hi) = sphere_lat_bounds
2270                .get(&col_in_training)
2271                .copied()
2272                .unwrap_or((lo, hi));
2273            if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
2274                continue;
2275            }
2276            if !matches!(
2277                kind_by_header.get(header.as_str()).copied(),
2278                Some(ColumnKindTag::Continuous)
2279            ) {
2280                continue;
2281            }
2282            if periodic_axes.contains(&col_in_training) {
2283                continue;
2284            }
2285            if linear_axes.contains(&col_in_training) {
2286                continue;
2287            }
2288            if random_effect_axes.contains(&col_in_training) {
2289                continue;
2290            }
2291            if smooth_extrapolation_axes.contains(&col_in_training) {
2292                continue;
2293            }
2294            if by_variable_axes.contains(&col_in_training) {
2295                continue;
2296            }
2297            let Some(&col_idx) = col_map.get(header) else {
2298                continue;
2299            };
2300            if col_idx >= clipped.ncols() {
2301                continue;
2302            }
2303            let mut col = clipped.column_mut(col_idx);
2304            for v in col.iter_mut() {
2305                if v.is_finite() {
2306                    if *v < lo {
2307                        *v = lo;
2308                        any_clipped = true;
2309                    } else if *v > hi {
2310                        *v = hi;
2311                        any_clipped = true;
2312                    }
2313                }
2314            }
2315        }
2316        if any_clipped { Some(clipped) } else { None }
2317    }
2318
2319    fn saved_term_specs(&self) -> Vec<&TermCollectionSpec> {
2320        let mut specs: Vec<&TermCollectionSpec> = [
2321            self.resolved_termspec.as_ref(),
2322            self.resolved_termspec_noise.as_ref(),
2323            self.resolved_termspec_logslope.as_ref(),
2324        ]
2325        .into_iter()
2326        .flatten()
2327        .collect();
2328        if let Some(logslopes) = self.resolved_termspec_logslopes.as_ref() {
2329            specs.extend(logslopes.iter());
2330        }
2331        specs
2332    }
2333
2334    /// Collect the set of training-column indices that are periodic axes —
2335    /// i.e. features for which a periodic basis (sphere longitude, periodic
2336    /// B-spline 1D, periodic tensor margin) must be allowed to take any
2337    /// real value at predict time and not be clamped to the training range.
2338    /// Returned indices reference `self.training_headers` (training-time
2339    /// layout), matching the iteration in `axis_clip_to_training_ranges`.
2340    fn training_periodic_axes(
2341        &self,
2342        training_headers: &[String],
2343    ) -> std::collections::HashSet<usize> {
2344        use gam_terms::basis::BSplineKnotSpec;
2345        use gam_terms::smooth::SmoothBasisSpec;
2346        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2347        let Some(spec) = self.resolved_termspec.as_ref() else {
2348            return out;
2349        };
2350        for term in &spec.smooth_terms {
2351            match &term.basis {
2352                // Sphere terms: longitude (second feature col) is always
2353                // periodic and exempt from clipping. Latitude is not periodic
2354                // but is a closed-manifold coordinate, so it is clipped to the
2355                // manifold's intrinsic bounds rather than the sampled range —
2356                // see `training_sphere_latitude_bounds`.
2357                SmoothBasisSpec::Sphere { feature_cols, .. } => {
2358                    if let Some(&lon_col) = feature_cols.get(1)
2359                        && lon_col < training_headers.len()
2360                    {
2361                        out.insert(lon_col);
2362                    }
2363                }
2364                // 1D periodic B-spline: the single feature column is periodic.
2365                SmoothBasisSpec::BSpline1D { feature_col, spec } => {
2366                    if matches!(spec.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2367                        && *feature_col < training_headers.len()
2368                    {
2369                        out.insert(*feature_col);
2370                    }
2371                }
2372                // Tensor B-spline: each axis whose marginal knotspec is
2373                // PeriodicUniform is periodic; mark those columns.
2374                SmoothBasisSpec::TensorBSpline { feature_cols, spec } => {
2375                    for (i, marginal) in spec.marginalspecs.iter().enumerate() {
2376                        if matches!(marginal.knotspec, BSplineKnotSpec::PeriodicUniform { .. })
2377                            && let Some(&col) = feature_cols.get(i)
2378                            && col < training_headers.len()
2379                        {
2380                            out.insert(col);
2381                        }
2382                    }
2383                }
2384                _ => {}
2385            }
2386        }
2387        out
2388    }
2389
2390    /// Collect the set of training-column indices that feed a parametric/linear
2391    /// term — on *any* modelled surface (mean, noise/scale, log-slope). A
2392    /// linear term realises the design column `∏ feature_cols` and contributes
2393    /// `β·(that product)` to the linear predictor, so its inputs must be allowed
2394    /// to take any real value at predict time: clamping them to the training
2395    /// range would replace genuine linear extrapolation with a boundary plateau
2396    /// (and freeze the prediction SE at the hull edge). Returned indices
2397    /// reference `self.training_headers` (training-time layout), matching the
2398    /// iteration in `axis_clip_to_training_ranges`. Wilkinson-Rogers `:`
2399    /// interactions contribute every column in their `feature_cols` product.
2400    fn training_linear_axes(&self, n_training_headers: usize) -> std::collections::HashSet<usize> {
2401        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2402        for spec in self.saved_term_specs() {
2403            for term in &spec.linear_terms {
2404                for col in term.effective_feature_cols() {
2405                    if col < n_training_headers {
2406                        out.insert(col);
2407                    }
2408                }
2409            }
2410        }
2411        out
2412    }
2413
2414    /// Collect the set of training-column indices that feed random-effect
2415    /// grouping terms. These columns are categorical model axes regardless of
2416    /// the ingest schema's scalar storage type, so prediction must leave them
2417    /// untouched and let the frozen random-effect levels decide whether a row is
2418    /// a seen level or the zero-effect unseen-level fallback.
2419    fn training_random_effect_axes(
2420        &self,
2421        n_training_headers: usize,
2422    ) -> std::collections::HashSet<usize> {
2423        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2424        for spec in self.saved_term_specs() {
2425            for term in &spec.random_effect_terms {
2426                if term.feature_col < n_training_headers {
2427                    out.insert(term.feature_col);
2428                }
2429            }
2430        }
2431        out
2432    }
2433
2434    /// Collect the set of training-column indices that feed a non-parametric
2435    /// smooth whose basis performs its own *bounded* extrapolation outside the
2436    /// training hull — on any modelled surface (mean, noise/scale, log-slope).
2437    ///
2438    /// These columns must be exempt from the predict-time axis clip for the same
2439    /// reason periodic/linear/random-effect axes are: the clip clamps a new value
2440    /// to the training extreme *before* the design is built, so the basis is
2441    /// handed an already-clamped coordinate and its extrapolation machinery never
2442    /// fires. The result is a piecewise-constant plateau frozen at the boundary
2443    /// fitted value (with a prediction SE frozen at the hull edge), and — worse —
2444    /// a model that yields *different* predictions through the `FittedModel`
2445    /// predict pipeline than through the raw `build_term_collection_design` path,
2446    /// which does not clip. Exempting these axes routes both entry points through
2447    /// the single basis-layer extrapolation, restoring internal consistency.
2448    ///
2449    /// Only bases with a *bounded* out-of-hull contract are listed, so removing
2450    /// the clip cannot reintroduce the wild basis blow-up the clip guards against:
2451    ///   - B-spline 1D / tensor margins: first-order linear extension off the
2452    ///     boundary slope (`apply_linear_extension_from_first_derivative`) — grows
2453    ///     at most linearly.
2454    ///   - Duchon / thin-plate: the natural-spline boundary conditions make the
2455    ///     fit linear outside the data span — also at most linear growth.
2456    ///   - Matérn: the kernel decays with distance, so the fit reverts smoothly to
2457    ///     its (low-order polynomial / constant) mean far from the data — bounded.
2458    /// `sphere()` axes are deliberately *not* listed: latitude is a closed-manifold
2459    /// coordinate clipped to its intrinsic bounds (`training_sphere_latitude_bounds`)
2460    /// and longitude is periodic (`training_periodic_axes`); both already have the
2461    /// correct, non-plateau handling. `Pca` projections have no extrapolation
2462    /// contract and stay clipped.
2463    ///
2464    /// Returned indices reference `self.training_headers` (training-time layout),
2465    /// matching the iteration in `axis_clip_to_training_ranges`.
2466    fn training_smooth_extrapolation_axes(
2467        &self,
2468        n_training_headers: usize,
2469    ) -> std::collections::HashSet<usize> {
2470        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2471        for spec in self.saved_term_specs() {
2472            for term in &spec.smooth_terms {
2473                collect_smooth_extrapolation_axes(&term.basis, n_training_headers, &mut out);
2474            }
2475        }
2476        out
2477    }
2478
2479    /// Collect the set of training-column indices that feed a *numeric* `by=`
2480    /// multiplier of a varying-coefficient smooth. These columns are linear
2481    /// multipliers (`pred = intercept + z·f(x)`), so they must be exempt from
2482    /// the predict-time axis clip for the same reason parametric linear axes
2483    /// are — see `collect_by_variable_numeric_axes` for the full rationale.
2484    fn training_by_variable_numeric_axes(
2485        &self,
2486        n_training_headers: usize,
2487    ) -> std::collections::HashSet<usize> {
2488        let mut out: std::collections::HashSet<usize> = std::collections::HashSet::new();
2489        for spec in self.saved_term_specs() {
2490            for term in &spec.smooth_terms {
2491                collect_by_variable_numeric_axes(&term.basis, n_training_headers, &mut out);
2492            }
2493        }
2494        out
2495    }
2496
2497    /// Manifold-intrinsic clip bounds for sphere *latitude* columns.
2498    ///
2499    /// A `sphere(lat, lon)` smooth charts the closed manifold S²: the poles
2500    /// (lat = ±π/2, or ±90°) are interior limit points of that manifold, not
2501    /// endpoints of an unbounded axis to extrapolate along. A finite training
2502    /// sample never reaches the pole exactly, so clamping a pole prediction to
2503    /// the *observed* latitude extreme lands at a near-pole latitude where the
2504    /// Wahba SOS kernel's `cos(lat)·cos(lat_c)·cos(Δlon)` term has not yet
2505    /// damped — and the predictor then sweeps a spurious `cos(lon)` profile at
2506    /// what is physically a single point, reintroducing the pole artefact the
2507    /// SOS basis exists to remove.
2508    ///
2509    /// The correct clip bound for this coordinate is therefore the manifold's
2510    /// intrinsic domain — `[-π/2, π/2]` radians or `[-90, 90]` degrees — not
2511    /// the sampled range. Clamping to those bounds keeps the pole reachable
2512    /// (single-valued in longitude) while still mapping any out-of-domain
2513    /// latitude onto the manifold boundary. Longitude needs no entry here: it
2514    /// is periodic and already exempted from clipping entirely
2515    /// (`training_periodic_axes`). Returned indices reference
2516    /// `self.training_headers`, matching the iteration in
2517    /// `axis_clip_to_training_ranges`.
2518    fn training_sphere_latitude_bounds(
2519        &self,
2520        training_headers: &[String],
2521    ) -> std::collections::HashMap<usize, (f64, f64)> {
2522        use gam_terms::smooth::SmoothBasisSpec;
2523        let mut out: std::collections::HashMap<usize, (f64, f64)> =
2524            std::collections::HashMap::new();
2525        let Some(spec) = self.resolved_termspec.as_ref() else {
2526            return out;
2527        };
2528        for term in &spec.smooth_terms {
2529            if let SmoothBasisSpec::Sphere { feature_cols, spec } = &term.basis
2530                && let Some(&lat_col) = feature_cols.first()
2531                && lat_col < training_headers.len()
2532            {
2533                let bound = if spec.radians {
2534                    std::f64::consts::FRAC_PI_2
2535                } else {
2536                    90.0
2537                };
2538                out.insert(lat_col, (-bound, bound));
2539            }
2540        }
2541        out
2542    }
2543
2544    pub fn from_payload(mut payload: FittedModelPayload) -> Self {
2545        let likelihood = payload.family_state.likelihood();
2546        let class = match payload.model_kind {
2547            ModelKind::Survival => PredictModelClass::Survival,
2548            ModelKind::MarginalSlope => PredictModelClass::BernoulliMarginalSlope,
2549            ModelKind::TransformationNormal => PredictModelClass::TransformationNormal,
2550            ModelKind::LocationScale => {
2551                if likelihood == LikelihoodSpec::gaussian_identity() {
2552                    PredictModelClass::GaussianLocationScale
2553                } else if is_dispersion_location_scale_response(&likelihood.response) {
2554                    PredictModelClass::DispersionLocationScale
2555                } else {
2556                    PredictModelClass::BinomialLocationScale
2557                }
2558            }
2559            ModelKind::Standard => PredictModelClass::Standard,
2560        };
2561        match class {
2562            PredictModelClass::Survival => {
2563                payload.model_kind = ModelKind::Survival;
2564                Self::Survival { payload }
2565            }
2566            PredictModelClass::BernoulliMarginalSlope => {
2567                payload.model_kind = ModelKind::MarginalSlope;
2568                Self::MarginalSlope { payload }
2569            }
2570            PredictModelClass::TransformationNormal => {
2571                payload.model_kind = ModelKind::TransformationNormal;
2572                Self::TransformationNormal { payload }
2573            }
2574            PredictModelClass::GaussianLocationScale
2575            | PredictModelClass::BinomialLocationScale
2576            | PredictModelClass::DispersionLocationScale => {
2577                payload.model_kind = ModelKind::LocationScale;
2578                Self::LocationScale { payload }
2579            }
2580            PredictModelClass::Standard => {
2581                payload.model_kind = ModelKind::Standard;
2582                Self::Standard { payload }
2583            }
2584        }
2585        .with_synchronized_stateful_link_metadata()
2586    }
2587
2588    #[inline]
2589    pub fn payload(&self) -> &FittedModelPayload {
2590        match self {
2591            Self::Standard { payload }
2592            | Self::LocationScale { payload }
2593            | Self::MarginalSlope { payload }
2594            | Self::Survival { payload }
2595            | Self::TransformationNormal { payload } => payload,
2596        }
2597    }
2598
2599    #[inline]
2600    fn payload_mut(&mut self) -> &mut FittedModelPayload {
2601        match self {
2602            Self::Standard { payload }
2603            | Self::LocationScale { payload }
2604            | Self::MarginalSlope { payload }
2605            | Self::Survival { payload }
2606            | Self::TransformationNormal { payload } => payload,
2607        }
2608    }
2609
2610    fn with_synchronized_stateful_link_metadata(mut self) -> Self {
2611        self.synchronize_stateful_link_metadata();
2612        self
2613    }
2614
    /// Copy derived metadata from the stored fit onto the payload's own fields.
    ///
    /// The fit result is taken from `fit_result`, falling back to `unified` when
    /// `fit_result` is `None`. The method:
    /// - sets `used_device` from that fit (false when neither is present);
    /// - calls `synchronize_empty_feature_contract`;
    /// - when a fit exists and the family is `FittedFamily::Standard` with a
    ///   matching likelihood, mirrors the fitted link's state (and, where
    ///   present, its parameter covariance) into the payload.
    ///
    /// Any other family/link combination leaves the link-state fields unchanged.
    fn synchronize_stateful_link_metadata(&mut self) {
        let payload = self.payload_mut();
        payload.used_device = payload
            .fit_result
            .as_ref()
            .or(payload.unified.as_ref())
            .is_some_and(|fit| fit.used_device);
        payload.synchronize_empty_feature_contract();
        // No stored fit: nothing further to mirror.
        let Some(fit) = payload.fit_result.as_ref().or(payload.unified.as_ref()) else {
            return;
        };
        // Each arm requires both the family's likelihood predicate and the
        // fitted link variant to agree before copying state across.
        match (&mut payload.family_state, &fit.fitted_link) {
            (
                FittedFamily::Standard {
                    likelihood,
                    latent_cloglog_state,
                    ..
                },
                FittedLinkState::LatentCLogLog { state },
            ) if likelihood.is_latent_cloglog() => {
                *latent_cloglog_state = Some(*state);
            }
            (
                FittedFamily::Standard {
                    likelihood,
                    sas_state,
                    ..
                },
                FittedLinkState::Sas { state, covariance },
            ) if likelihood.is_binomial_sas() => {
                *sas_state = Some(*state);
                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
            }
            // Beta-logistic link state is stored in the same `sas_state` /
            // `sas_param_covariance` slots as the SAS link.
            (
                FittedFamily::Standard {
                    likelihood,
                    sas_state,
                    ..
                },
                FittedLinkState::BetaLogistic { state, covariance },
            ) if likelihood.is_binomial_beta_logistic() => {
                *sas_state = Some(*state);
                payload.sas_param_covariance = covariance.as_ref().map(array2_to_nestedvec);
            }
            (
                FittedFamily::Standard {
                    likelihood,
                    mixture_state,
                    ..
                },
                FittedLinkState::Mixture { state, covariance },
            ) if likelihood.is_binomial_mixture() => {
                *mixture_state = Some(state.clone());
                payload.mixture_link_param_covariance =
                    covariance.as_ref().map(array2_to_nestedvec);
            }
            _ => {}
        }
    }
2674
2675    #[inline]
2676    pub fn likelihood(&self) -> LikelihoodSpec {
2677        self.payload().family_state.likelihood()
2678    }
2679
2680    /// Columns this model consumes from a prediction frame — its *input
2681    /// contract*.
2682    ///
2683    /// Every variable named by the main formula (features, interaction margins,
2684    /// random-effect groups, and a smooth's `by=` column), the survival
2685    /// entry/exit columns or the transformation-normal response, the auxiliary
2686    /// noise / logslope formula columns, and the offset / noise-offset /
2687    /// latent-`z` columns. The event-indicator and the plain response of a
2688    /// standard model are deliberately excluded: they are not needed to *form*
2689    /// a prediction (the conformal-calibration fold layers the response back on
2690    /// separately).
2691    ///
2692    /// This is the single authority shared by the CLI and PyFFI predict paths.
2693    /// A prediction frame column that is *not* in this set is irrelevant to the
2694    /// model and must be ignored rather than strict-encoded against the
2695    /// training schema — otherwise an unrelated ID/label column with a held-out
2696    /// categorical level aborts predict (#840).
2697    pub fn prediction_required_columns(
2698        &self,
2699    ) -> Result<std::collections::BTreeSet<String>, String> {
2700        let payload = self.payload();
2701        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2702        let mut required = std::collections::BTreeSet::<String>::new();
2703        parsed_term_column_names(&parsed.terms, &mut required);
2704
2705        if let Some((entry, exit, _event)) =
2706            parse_surv_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2707        {
2708            if let Some(entry) = entry {
2709                required.insert(entry);
2710            }
2711            required.insert(exit);
2712        } else if let Some((left, right, _event)) =
2713            parse_surv_interval_response(parsed.response.as_str()).map_err(|e| e.to_string())?
2714        {
2715            required.insert(left);
2716            required.insert(right);
2717        }
2718        // A transformation-normal (CTM) prediction returns the response-scale
2719        // conditional mean E[Y|x], a function of the covariates alone (issue
2720        // #1612). The earlier implementation precomputed the PIT h(y|x) of the
2721        // supplied response, which made the outcome column mandatory at predict
2722        // time; the response is no longer required, so a covariate-only frame
2723        // must predict without it.
2724
2725        if let Some(offset) = payload.offset_column.as_ref() {
2726            required.insert(offset.clone());
2727        }
2728        if let Some(noise_offset) = payload.noise_offset_column.as_ref() {
2729            required.insert(noise_offset.clone());
2730        }
2731        if matches!(
2732            self.predict_model_class(),
2733            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
2734        ) {
2735            if let Some(z_column) = payload.z_column.as_ref() {
2736                required.remove("z");
2737                required.insert(z_column.clone());
2738            }
2739        }
2740        if let Some(noise_formula) = payload.formula_noise.as_ref() {
2741            self.add_auxiliary_formula_columns(
2742                &mut required,
2743                noise_formula,
2744                parsed.response.as_str(),
2745            )?;
2746        }
2747        if let Some(logslope_formula) = payload.formula_logslope.as_ref() {
2748            if logslope_formula != "same-as-main" {
2749                self.add_auxiliary_formula_columns(
2750                    &mut required,
2751                    logslope_formula,
2752                    parsed.response.as_str(),
2753                )?;
2754            }
2755        }
2756        Ok(required)
2757    }
2758
2759    /// Columns a *post-fit diagnostic* command (diagnose / sample / report)
2760    /// needs **beyond** [`Self::prediction_required_columns`].
2761    ///
2762    /// Prediction deliberately drops a standard GAM's bare response so a
2763    /// prediction frame may omit it (#840 / #864). Diagnostics are statements
2764    /// *about* that observed response — residuals, R², posterior likelihoods,
2765    /// leave-one-out — so the response must be present. This returns the bare
2766    /// response column when the prediction projection would otherwise drop it,
2767    /// and nothing when the response is already prediction-required (survival
2768    /// `Surv(...)` time/event columns, the transformation-normal response) or
2769    /// is not a plain data column.
2770    ///
2771    /// Centralising the intent here is what makes it *structurally impossible*
2772    /// for a diagnostic command to silently drop the response: callers use
2773    /// `load_dataset…_for_diagnostics`, which always folds these in, instead of
2774    /// each remembering to thread an `extra_required` response by hand.
2775    pub fn diagnostic_extra_columns(&self) -> Result<Vec<String>, String> {
2776        let payload = self.payload();
2777        let parsed = parse_formula(payload.formula.as_str()).map_err(|e| e.to_string())?;
2778        // Prior (case) weights never enter the linear predictor, so
2779        // `prediction_required_columns` deliberately omits the weight column and
2780        // a prediction frame may drop it. Diagnostics are weight-aware, though:
2781        // `diagnose` reconstructs the ALO working weights `w_i = prior_i ·
2782        // Fisher_i` (and the refit fallback re-weights the same way), so the
2783        // weight column must be loaded. Fold it in here — the single seam that
2784        // makes it structurally impossible for a diagnostic command to silently
2785        // drop a needed column — regardless of the response-shape early-outs
2786        // below, since it is orthogonal to the response.
2787        let mut extras: Vec<String> = Vec::new();
2788        if let Some(weight_column) = payload.weight_column.as_ref() {
2789            extras.push(weight_column.clone());
2790        }
2791        // Survival responses are `Surv(...)` expressions, not bare columns; the
2792        // underlying entry/exit columns are already prediction-required.
2793        if parse_surv_response(parsed.response.as_str())
2794            .map_err(|e| e.to_string())?
2795            .is_some()
2796            || parse_surv_interval_response(parsed.response.as_str())
2797                .map_err(|e| e.to_string())?
2798                .is_some()
2799        {
2800            return Ok(extras);
2801        }
2802        let response = parsed.response.trim();
2803        // A response that is empty, or a function-call expression rather than a
2804        // plain data column, has no bare column to re-add.
2805        if response.is_empty() || response.contains('(') {
2806            return Ok(extras);
2807        }
2808        // Already prediction-required (e.g. transformation-normal re-adds it):
2809        // nothing extra to fold in.
2810        if self.prediction_required_columns()?.contains(response) {
2811            return Ok(extras);
2812        }
2813        extras.push(response.to_string());
2814        Ok(extras)
2815    }
2816
2817    /// Add the columns referenced by an auxiliary (noise / logslope) formula,
2818    /// which may be supplied as a full `lhs ~ rhs` formula or as a bare RHS.
2819    fn add_auxiliary_formula_columns(
2820        &self,
2821        required: &mut std::collections::BTreeSet<String>,
2822        formula_or_rhs: &str,
2823        response: &str,
2824    ) -> Result<(), String> {
2825        let trimmed = formula_or_rhs.trim();
2826        if trimmed.is_empty() || trimmed == "1" {
2827            return Ok(());
2828        }
2829        let formula = if trimmed.contains('~') {
2830            trimmed.to_string()
2831        } else {
2832            format!("{response} ~ {trimmed}")
2833        };
2834        let parsed = parse_formula(formula.as_str()).map_err(|e| e.to_string())?;
2835        parsed_term_column_names(&parsed.terms, required);
2836        Ok(())
2837    }
2838
2839    #[inline]
2840    pub fn predict_model_class(&self) -> PredictModelClass {
2841        match &self.payload().family_state {
2842            FittedFamily::Survival { .. }
2843            | FittedFamily::LatentSurvival { .. }
2844            | FittedFamily::LatentBinary { .. } => PredictModelClass::Survival,
2845            FittedFamily::MarginalSlope { .. } => PredictModelClass::BernoulliMarginalSlope,
2846            FittedFamily::TransformationNormal { .. } => PredictModelClass::TransformationNormal,
2847            FittedFamily::LocationScale { likelihood, .. } if likelihood.is_gaussian_identity() => {
2848                PredictModelClass::GaussianLocationScale
2849            }
2850            FittedFamily::LocationScale { likelihood, .. }
2851                if is_dispersion_location_scale_response(&likelihood.response) =>
2852            {
2853                PredictModelClass::DispersionLocationScale
2854            }
2855            FittedFamily::LocationScale { .. } => PredictModelClass::BinomialLocationScale,
2856            FittedFamily::Standard { .. } => PredictModelClass::Standard,
2857        }
2858    }
2859
2860    pub fn saved_link_wiggle(&self) -> Result<Option<SavedLinkWiggleRuntime>, FittedModelError> {
2861        let payload = self.payload();
2862        let (knots, degree) = match (
2863            payload.linkwiggle_knots.as_ref(),
2864            payload.linkwiggle_degree,
2865        ) {
2866            (None, None) => return Ok(None),
2867            (Some(knots), Some(degree)) => (knots.clone(), degree),
2868            _ => {
2869                return Err(FittedModelError::SchemaMismatch {
2870                    reason:
2871                        "saved model has partial link-wiggle metadata; expected linkwiggle_knots and linkwiggle_degree together"
2872                            .to_string(),
2873                })
2874            }
2875        };
2876        let resolved_link = self.resolved_inverse_link()?;
2877        let saved_link_disallows_wiggle = resolved_link
2878            .as_ref()
2879            .is_some_and(|link| !inverse_link_supports_joint_wiggle(link))
2880            || payload
2881                .link
2882                .as_ref()
2883                .is_some_and(|link| !inverse_link_supports_joint_wiggle(link));
2884        if saved_link_disallows_wiggle {
2885            return Err(FittedModelError::IncompatibleConfig {
2886                reason: joint_wiggle_unsupported_link_message("link wiggle"),
2887            });
2888        }
2889        let beta = match self.predict_model_class() {
2890            // #1596: the frozen-basis de-aliased standard link-warp is fit in a
2891            // reduced, identifiable coordinate `γ` and the fit_result LinkWiggle
2892            // block stores `γ` (its true free parameters). The full-width
2893            // standard-basis lift `β_w = Z·γ`, the coefficients the predict-time
2894            // I-spline basis multiplies, is persisted in `payload.beta_link_wiggle`
2895            // — prefer it when present. Without it (the dynamic-basis path) the
2896            // block coefficients ARE the standard-basis warp, read directly.
2897            PredictModelClass::Standard if payload.beta_link_wiggle.is_some() => {
2898                payload.beta_link_wiggle.clone().expect("checked is_some")
2899            }
2900            PredictModelClass::Standard => {
2901                let fit = payload.fit_result.as_ref().ok_or_else(|| {
2902                    FittedModelError::MissingField {
2903                        reason:
2904                            "standard link-wiggle model is missing canonical fit_result payload"
2905                                .to_string(),
2906                    }
2907                })?;
2908                if fit.blocks.len() != 2
2909                    || fit.blocks[0].role != BlockRole::Mean
2910                    || fit.blocks[1].role != BlockRole::LinkWiggle
2911                {
2912                    return Err(FittedModelError::SchemaMismatch {
2913                        reason:
2914                            "standard link-wiggle models must store blocks in [Mean, LinkWiggle] order"
2915                                .to_string(),
2916                    });
2917                }
2918                fit.block_by_role(BlockRole::LinkWiggle)
2919                    .ok_or_else(|| FittedModelError::MissingField {
2920                        reason:
2921                            "standard link-wiggle model is missing LinkWiggle coefficient block"
2922                                .to_string(),
2923                    })?
2924                    .beta
2925                    .to_vec()
2926            }
2927            _ => payload
2928                .beta_link_wiggle
2929                .clone()
2930                .ok_or_else(|| FittedModelError::MissingField {
2931                    reason:
2932                        "saved model has link-wiggle metadata but is missing payload.beta_link_wiggle"
2933                            .to_string(),
2934                })?,
2935        };
2936        Ok(Some(SavedLinkWiggleRuntime {
2937            knots,
2938            degree,
2939            beta,
2940        }))
2941    }
2942
2943    pub fn saved_baseline_time_wiggle(
2944        &self,
2945    ) -> Result<Option<SavedBaselineTimeWiggleRuntime>, FittedModelError> {
2946        let payload = self.payload();
2947        if payload
2948            .survival_cause_count
2949            .is_some_and(|cause_count| cause_count > 1)
2950            && payload.beta_baseline_timewiggle.is_none()
2951            && payload.beta_baseline_timewiggle_by_cause.is_some()
2952        {
2953            return Err(FittedModelError::SchemaMismatch {
2954                reason:
2955                    "joint cause-specific survival stores baseline-timewiggle coefficients per cause"
2956                        .to_string(),
2957            });
2958        }
2959        match (
2960            payload.baseline_timewiggle_knots.as_ref(),
2961            payload.baseline_timewiggle_degree,
2962            payload.baseline_timewiggle_penalty_orders.as_ref(),
2963            payload.baseline_timewiggle_double_penalty,
2964            payload.beta_baseline_timewiggle.as_ref(),
2965        ) {
2966            (None, None, None, None, None) => Ok(None),
2967            (Some(knots), Some(degree), Some(penalty_orders), Some(double_penalty), Some(beta)) => {
2968                Ok(Some(SavedBaselineTimeWiggleRuntime {
2969                    knots: knots.clone(),
2970                    degree,
2971                    penalty_orders: penalty_orders.clone(),
2972                    double_penalty,
2973                    beta: beta.clone(),
2974                }))
2975            }
2976            _ => Err(FittedModelError::SchemaMismatch {
2977                reason:
2978                    "saved model has partial baseline-timewiggle metadata; expected knots+degree+penalty_order+double_penalty+beta_baseline_timewiggle together"
2979                        .to_string(),
2980            }),
2981        }
2982    }
2983
2984    /// Whether this model has a link wiggle component with complete metadata.
2985    #[inline]
2986    pub fn has_link_wiggle(&self) -> bool {
2987        self.saved_link_wiggle()
2988            .map(|runtime| runtime.is_some())
2989            .unwrap_or(false)
2990    }
2991
2992    /// Whether this model has a baseline-time wiggle component with complete metadata.
2993    #[inline]
2994    pub fn has_baseline_time_wiggle(&self) -> bool {
2995        let payload = self.payload();
2996        if payload
2997            .survival_cause_count
2998            .is_some_and(|cause_count| cause_count > 1)
2999        {
3000            return payload.baseline_timewiggle_knots.is_some()
3001                && payload.baseline_timewiggle_degree.is_some()
3002                && payload.baseline_timewiggle_penalty_orders.is_some()
3003                && payload.baseline_timewiggle_double_penalty.is_some()
3004                && payload.beta_baseline_timewiggle_by_cause.is_some();
3005        }
3006        self.saved_baseline_time_wiggle()
3007            .map(|runtime| runtime.is_some())
3008            .unwrap_or(false)
3009    }
3010
3011    /// Whether the default point prediction must integrate the inverse link
3012    /// over the coefficient posterior — reporting the posterior mean
3013    /// `E[g⁻¹(Xβ)]` — rather than plugging in the posterior mode `g⁻¹(Xβ̂)`.
3014    ///
3015    /// SPEC (issue #960): the posterior mean is *always* the default point
3016    /// estimate (never MAP). It is observably distinct from the plug-in exactly
3017    /// when the inverse link is *curved* over the posterior's uncertainty, so
3018    /// `E[g⁻¹(η)] ≠ g⁻¹(E[η])` by Jensen. The curvature-based classification is:
3019    ///   * all log-link families (Poisson / Gamma / Tweedie / NegativeBinomial):
3020    ///     `E[exp η] = exp(η + se²/2) ≠ exp(η)` (log-normal MGF);
3021    ///   * all Binomial links (logit / probit / cloglog / SAS / BetaLogistic /
3022    ///     Mixture / LatentCLogLog): bounded sigmoidal inverse links;
3023    ///   * Beta (logit link): `E[σ(η)] ≠ σ(E[η])`;
3024    ///   * Royston–Parmar (curved survival-probability inverse link).
3025    /// The integral collapses to the plug-in (so the cheaper plug-in path is
3026    /// exact and taken instead) only for the effectively-linear identity-link
3027    /// Gaussian. Any model carrying a link wiggle or baseline-time wiggle is
3028    /// curved regardless of family. This curvature partition mirrors
3029    /// `families::family_runtime::posterior_mean`, the compute path that produces the
3030    /// corrected mean for each of these families.
3031    ///
3032    /// This is the single source of truth shared by the CLI (`gam predict`)
3033    /// and the Python FFI prediction path so the two can never drift on which
3034    /// models receive the posterior-mean correction.
3035    #[inline]
3036    pub fn prediction_uses_posterior_mean(&self) -> bool {
3037        let family = self.likelihood();
3038        let curved_family = match &family.response {
3039            // Identity-link Gaussian: inverse link is linear, so the posterior
3040            // mean equals the plug-in and the cheaper exact path is taken.
3041            ResponseFamily::Gaussian => false,
3042            // Log-link families: E[exp η] = exp(η + se²/2) ≠ exp(η).
3043            ResponseFamily::Poisson
3044            | ResponseFamily::Gamma
3045            | ResponseFamily::Tweedie { .. }
3046            | ResponseFamily::NegativeBinomial { .. } => true,
3047            // Beta (logit link): E[σ(η)] ≠ σ(E[η]).
3048            ResponseFamily::Beta { .. } => true,
3049            // Royston–Parmar: curved survival-probability inverse link.
3050            ResponseFamily::RoystonParmar => true,
3051            // Binomial: every link variant (logit / probit / cloglog / SAS /
3052            // BetaLogistic / Mixture / LatentCLogLog) is a curved sigmoid.
3053            ResponseFamily::Binomial => matches!(
3054                &family.link,
3055                InverseLink::Standard(_)
3056                    | InverseLink::Sas(_)
3057                    | InverseLink::BetaLogistic(_)
3058                    | InverseLink::Mixture(_)
3059                    | InverseLink::LatentCLogLog(_)
3060            ),
3061        };
3062        curved_family || self.has_link_wiggle() || self.has_baseline_time_wiggle()
3063    }
3064
3065    pub fn saved_prediction_runtime(&self) -> Result<SavedPredictionRuntime, FittedModelError> {
3066        self.payload().validate_payload_version()?;
3067        if matches!(
3068            self.predict_model_class(),
3069            PredictModelClass::BernoulliMarginalSlope | PredictModelClass::Survival
3070        ) {
3071            if let Some(runtime) = self.payload().score_warp_runtime.as_ref() {
3072                runtime.validate_exact_replay_contract().map_err(|err| {
3073                    FittedModelError::PayloadCorrupt {
3074                        reason: format!("saved anchored score-warp runtime is invalid: {err}"),
3075                    }
3076                })?;
3077            }
3078            if let Some(runtime) = self.payload().link_deviation_runtime.as_ref() {
3079                runtime.validate_exact_replay_contract().map_err(|err| {
3080                    FittedModelError::PayloadCorrupt {
3081                        reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
3082                    }
3083                })?;
3084            }
3085        }
3086        let runtime = SavedPredictionRuntime {
3087            model_class: self.predict_model_class(),
3088            likelihood: self.likelihood(),
3089            inverse_link: self.resolved_inverse_link()?,
3090            link_wiggle: self.saved_link_wiggle()?,
3091            baseline_time_wiggle: self.saved_baseline_time_wiggle()?,
3092            score_warp: self.payload().score_warp_runtime.clone(),
3093            link_deviation: self.payload().link_deviation_runtime.clone(),
3094            latent_z_rank_int_calibration: self.payload().latent_z_rank_int_calibration.clone(),
3095            latent_z_conditional_calibration: self
3096                .payload()
3097                .latent_z_conditional_calibration
3098                .clone(),
3099            influence_absorber_width: self.payload().influence_absorber_width,
3100        };
3101        if matches!(
3102            runtime.model_class,
3103            PredictModelClass::GaussianLocationScale
3104                | PredictModelClass::BinomialLocationScale
3105                | PredictModelClass::DispersionLocationScale
3106        ) {
3107            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3108                FittedModelError::MissingField {
3109                    reason: "location-scale model is missing canonical fit_result payload"
3110                        .to_string(),
3111                }
3112            })?;
3113            validate_location_scale_saved_fit(
3114                fit,
3115                runtime.model_class,
3116                runtime.link_wiggle.as_ref(),
3117            )?;
3118        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3119            && self
3120                .payload()
3121                .survival_likelihood
3122                .as_deref()
3123                .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
3124        {
3125            validate_survival_location_scale_saved_fit(
3126                self.payload(),
3127                runtime.link_wiggle.as_ref(),
3128            )?;
3129        } else if matches!(
3130            runtime.model_class,
3131            PredictModelClass::BernoulliMarginalSlope
3132        ) {
3133            let unified =
3134                self.payload()
3135                    .unified
3136                    .as_ref()
3137                    .ok_or_else(|| FittedModelError::MissingField {
3138                        reason: "marginal-slope model is missing unified fit payload; refit"
3139                            .to_string(),
3140                    })?;
3141            validate_marginal_slope_saved_fit(
3142                unified,
3143                runtime.score_warp.as_ref(),
3144                runtime.link_deviation.as_ref(),
3145                "unified",
3146            )?;
3147        } else if matches!(runtime.model_class, PredictModelClass::Survival)
3148            && self
3149                .payload()
3150                .survival_likelihood
3151                .as_deref()
3152                .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
3153        {
3154            let fit = self.payload().fit_result.as_ref().ok_or_else(|| {
3155                FittedModelError::MissingField {
3156                    reason: "survival marginal-slope model is missing canonical fit_result payload"
3157                        .to_string(),
3158                }
3159            })?;
3160            validate_survival_marginal_slope_saved_fit(
3161                fit,
3162                runtime.score_warp.as_ref(),
3163                runtime.link_deviation.as_ref(),
3164                "fit_result",
3165            )?;
3166        }
3167        Ok(runtime)
3168    }
3169
3170    pub fn saved_sas_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3171        let payload = self.payload();
3172        let raw = match &payload.family_state {
3173            FittedFamily::Standard {
3174                likelihood,
3175                sas_state,
3176                ..
3177            } if likelihood.is_binomial_sas() => {
3178                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3179                    reason: "binomial-sas model is missing state in family_state.sas_state"
3180                        .to_string(),
3181                })?
3182            }
3183            FittedFamily::LocationScale {
3184                likelihood,
3185                base_link,
3186            } if likelihood.is_binomial_sas() => match base_link {
3187                Some(InverseLink::Sas(state)) => *state,
3188                _ => {
3189                    return Err(FittedModelError::MissingField {
3190                        reason: "binomial-sas location-scale model is missing SAS base_link state"
3191                            .to_string(),
3192                    });
3193                }
3194            },
3195            _ => return Ok(None),
3196        };
3197        state_from_sasspec(SasLinkSpec {
3198            initial_epsilon: raw.epsilon,
3199            initial_log_delta: raw.log_delta,
3200        })
3201        .map(Some)
3202        .map_err(|e| FittedModelError::PayloadCorrupt {
3203            reason: format!("invalid saved SAS link state: {e}"),
3204        })
3205    }
3206
3207    pub fn saved_beta_logistic_state(&self) -> Result<Option<SasLinkState>, FittedModelError> {
3208        let payload = self.payload();
3209        let raw = match &payload.family_state {
3210            FittedFamily::Standard {
3211                likelihood,
3212                sas_state,
3213                ..
3214            } if likelihood.is_binomial_beta_logistic() => {
3215                (*sas_state).ok_or_else(|| FittedModelError::MissingField {
3216                    reason:
3217                        "binomial-beta-logistic model is missing state in family_state.sas_state"
3218                            .to_string(),
3219                })?
3220            }
3221            FittedFamily::LocationScale {
3222                likelihood,
3223                base_link,
3224            } if likelihood.is_binomial_beta_logistic() => match base_link {
3225                Some(InverseLink::BetaLogistic(state)) => *state,
3226                _ => {
3227                    return Err(FittedModelError::MissingField {
3228                        reason:
3229                            "binomial-beta-logistic location-scale model is missing beta-logistic base_link state"
3230                                .to_string(),
3231                    });
3232                }
3233            },
3234            _ => return Ok(None),
3235        };
3236        state_from_beta_logisticspec(SasLinkSpec {
3237            initial_epsilon: raw.epsilon,
3238            initial_log_delta: raw.log_delta,
3239        })
3240        .map(Some)
3241        .map_err(|e| FittedModelError::PayloadCorrupt {
3242            reason: format!("invalid saved Beta-Logistic link state: {e}"),
3243        })
3244    }
3245
3246    pub fn saved_mixture_state(&self) -> Result<Option<MixtureLinkState>, FittedModelError> {
3247        let payload = self.payload();
3248        match &payload.family_state {
3249            FittedFamily::Standard {
3250                likelihood,
3251                mixture_state,
3252                ..
3253            } if likelihood.is_binomial_mixture() => mixture_state
3254                .clone()
3255                .ok_or_else(|| FittedModelError::MissingField {
3256                    reason: "binomial-mixture model is missing state in family_state.mixture_state"
3257                        .to_string(),
3258                })
3259                .map(Some),
3260            FittedFamily::LocationScale {
3261                likelihood,
3262                base_link,
3263            } if likelihood.is_binomial_mixture() => match base_link {
3264                Some(InverseLink::Mixture(state)) => Ok(Some(state.clone())),
3265                _ => Err(FittedModelError::MissingField {
3266                    reason:
3267                        "binomial-mixture location-scale model is missing mixture base_link state"
3268                            .to_string(),
3269                }),
3270            },
3271            _ => Ok(None),
3272        }
3273    }
3274
3275    pub fn saved_latent_cloglog_state(
3276        &self,
3277    ) -> Result<Option<LatentCLogLogState>, FittedModelError> {
3278        let payload = self.payload();
3279        match &payload.family_state {
3280            FittedFamily::Standard {
3281                likelihood,
3282                latent_cloglog_state,
3283                ..
3284            } if likelihood.is_latent_cloglog() => latent_cloglog_state
3285                .ok_or_else(|| FittedModelError::MissingField {
3286                    reason:
3287                        "latent-cloglog-binomial model is missing state in family_state.latent_cloglog_state"
3288                            .to_string(),
3289                })
3290                .map(Some),
3291            _ => Ok(None),
3292        }
3293    }
3294
3295    pub fn resolved_inverse_link(&self) -> Result<Option<InverseLink>, FittedModelError> {
3296        let stateful = if let Some(state) = self.saved_mixture_state()? {
3297            Some(InverseLink::Mixture(state))
3298        } else if let Some(state) = self.saved_latent_cloglog_state()? {
3299            Some(InverseLink::LatentCLogLog(state))
3300        } else if let Some(state) = self.saved_beta_logistic_state()? {
3301            Some(InverseLink::BetaLogistic(state))
3302        } else {
3303            self.saved_sas_state()?.map(InverseLink::Sas)
3304        };
3305        match &self.payload().family_state {
3306            FittedFamily::LocationScale { base_link, .. } => Ok(base_link.clone().or(stateful)),
3307            FittedFamily::Standard { link, .. } => {
3308                Ok(stateful.or_else(|| link.map(InverseLink::Standard)))
3309            }
3310            FittedFamily::MarginalSlope { base_link, .. } => Ok(Some(base_link.clone())),
3311            FittedFamily::Survival { .. }
3312            | FittedFamily::LatentSurvival { .. }
3313            | FittedFamily::LatentBinary { .. } => Ok(None),
3314            FittedFamily::TransformationNormal { .. } => Ok(None),
3315        }
3316    }
3317
3318    /// V∞ §5 coverage floor for the measure-jet extrapolation variance: a
3319    /// band level "covers" a query once its kernel mass reaches this fraction
3320    /// of that level's web-averaged support. Magic-by-default (no dial):
3321    /// 0.05 keeps the ε★ gate's bounded discontinuity at ≤ 5 % of the
3322    /// spectrum's total prior ignorance (see the monotonicity theorem in
3323    /// `terms/basis/measure_jet_predict.rs`) while still refusing credit
3324    /// for stray sub-floor kernel mass at levels finer than the first
3325    /// covering scale.
3326    const MEASURE_JET_COVERAGE_FLOOR: f64 = 0.05;
3327
3328    /// V∞ §5 producer: per-row measure-jet extrapolation variance on the η
3329    /// scale for a prediction batch (`docs/measure_jet_v_infinity.md`).
3330    ///
3331    /// For every frozen measure-jet term in `resolved_termspec` this prices
3332    /// the off-support ignorance of the fitted multiscale spectrum at each
3333    /// query row: support curve from the frozen nodes/masses/band
3334    /// ([`gam_terms::basis::measure_jet_support_curve`]), fitted per-scale
3335    /// amplitudes λ̂_ℓ read from the fit's `lambdas` through the replayed
3336    /// design's penalty layout, folded through
3337    /// [`gam_terms::basis::measure_jet_extrapolation_variance`] and scaled by
3338    /// the fit's coefficient-covariance scale φ̂ so the result sits on Vp's
3339    /// η-variance scale. Terms not yet frozen (no `frozen_quadrature` or
3340    /// non-`UserProvided` centers) are skipped with a warning. Returns
3341    /// `Ok(None)` when no measure-jet term contributes, so callers leave
3342    /// `PredictUncertaintyOptions::extrapolation_variance` untouched.
3343    ///
3344    /// `data` must be the RAW (unclipped) prediction matrix in prediction
3345    /// column order — clipping to the training ranges would freeze the
3346    /// distance signal at the hull and defeat the honesty contract — and
3347    /// `col_map` the prediction header → column map (the same map handed to
3348    /// the design builder). This is the minimal-plumbing producer seam: the
3349    /// option-building callers (CLI predict, FFI) hold exactly
3350    /// `(model, data, col_map)` at the point where they assemble
3351    /// `PredictUncertaintyOptions`, and the fusion in
3352    /// `predict_gamwith_uncertainty` adds the array AFTER its multiplicative
3353    /// inflations: `Var_total = Var_Vp·inflation + Var_extrap`.
3354    pub fn measure_jet_extrapolation_variance(
3355        &self,
3356        data: ndarray::ArrayView2<'_, f64>,
3357        col_map: &HashMap<String, usize>,
3358    ) -> Result<Option<Array1<f64>>, FittedModelError> {
3359        use gam_terms::basis::{CenterStrategy, MeasureJetExtrapolationSpectrum, PenaltySource};
3360        use gam_terms::smooth::build_term_collection_design;
3361        use gam_terms::smooth::SmoothBasisSpec;
3362        let Some(saved_spec) = self.resolved_termspec.as_ref() else {
3363            return Ok(None);
3364        };
3365        if data.nrows() == 0
3366            || !saved_spec
3367                .smooth_terms
3368                .iter()
3369                .any(|t| matches!(t.basis, SmoothBasisSpec::MeasureJet { .. }))
3370        {
3371            return Ok(None);
3372        }
3373        let fit = self
3374            .fit_result
3375            .as_ref()
3376            .ok_or_else(|| FittedModelError::MissingField {
3377                reason: "measure-jet extrapolation variance requires the canonical \
3378                    fit_result payload; refit"
3379                    .to_string(),
3380            })?;
3381        let spec = crate::survival::predict::resolve_termspec_for_prediction(
3382            &self.resolved_termspec,
3383            self.training_headers.as_ref(),
3384            col_map,
3385            "resolved_termspec",
3386        )
3387        .map_err(|e| FittedModelError::SchemaMismatch {
3388            reason: format!("measure-jet extrapolation variance: {e}"),
3389        })?;
3390        // Penalty layout replay: the global penalty indices (→ `fit.lambdas`)
3391        // come from the SAME design builder the predict pipeline uses. One
3392        // probe row suffices — for a frozen spec the penalty layout is
3393        // row-count-invariant (centers, masses, band, and identifiability
3394        // transforms all replay verbatim) — keeping this O(centers²) instead
3395        // of duplicating the full O(rows·centers) prediction design build.
3396        let probe = data.slice(ndarray::s![0..1, ..]);
3397        let design = build_term_collection_design(probe, &spec).map_err(|e| {
3398            FittedModelError::SchemaMismatch {
3399                reason: format!(
3400                    "measure-jet extrapolation variance: penalty-layout replay failed: {e}"
3401                ),
3402            }
3403        })?;
3404        let lambdas = &fit.lambdas;
3405        // λ̂ are fitted on Frobenius-normalized penalties. The term loop
3406        // unnormalizes them to physical precisions before pricing; multiplying
3407        // by the coefficient-covariance scale puts Var_extrap on the same
3408        // η-variance scale as Vp.
3409        let phi_scale = fit.coefficient_covariance_scale();
3410        let mut total = Array1::<f64>::zeros(data.nrows());
3411        let mut contributed = false;
3412        for term in &spec.smooth_terms {
3413            let SmoothBasisSpec::MeasureJet {
3414                feature_cols,
3415                spec: mj,
3416                input_scales,
3417            } = &term.basis
3418            else {
3419                continue;
3420            };
3421            let (Some(frozen), CenterStrategy::UserProvided(centers)) =
3422                (mj.frozen_quadrature.as_ref(), &mj.center_strategy)
3423            else {
3424                log::warn!(
3425                    "measure-jet term '{}' is not frozen (UserProvided centers + frozen \
3426                    quadrature); skipping its extrapolation variance",
3427                    term.name
3428                );
3429                continue;
3430            };
3431            let n_levels = frozen.eps_band.len();
3432            // λ̂ per level from the replayed layout: per-scale candidates carry
3433            // `PenaltySource::Other("measure_jet_scale_ℓ")`; fused
3434            // (pinned-order) mode carries one Primary charged once for the
3435            // whole band. The DoublePenaltyNullspace ridge is EXCLUDED — it shrinks
3436            // coefficients, it is not a scale amplitude, and counting it would
3437            // double-charge the spectrum.
3438            let read_lambda = |global_index: usize| -> Result<f64, FittedModelError> {
3439                lambdas
3440                    .get(global_index)
3441                    .copied()
3442                    .ok_or_else(|| FittedModelError::SchemaMismatch {
3443                        reason: format!(
3444                            "measure-jet term '{}': penalty global index {global_index} out \
3445                            of bounds for {} fitted lambdas",
3446                            term.name,
3447                            lambdas.len()
3448                        ),
3449                    })
3450            };
3451            let mut per_scale: Vec<(usize, f64)> = Vec::new();
3452            let mut fused: Option<f64> = None;
3453            for info in &design.penaltyinfo {
3454                if info.termname.as_deref() != Some(term.name.as_str()) {
3455                    continue;
3456                }
3457                match &info.penalty.source {
3458                    PenaltySource::Other(label) => {
3459                        if let Some(level_txt) = label.strip_prefix("measure_jet_scale_") {
3460                            let level: usize = level_txt.parse().map_err(|_| {
3461                                FittedModelError::SchemaMismatch {
3462                                    reason: format!(
3463                                        "measure-jet term '{}': unparseable penalty label \
3464                                        '{label}'",
3465                                        term.name
3466                                    ),
3467                                }
3468                            })?;
3469                            per_scale.push((level, read_lambda(info.global_index)?));
3470                        }
3471                    }
3472                    PenaltySource::Primary => {
3473                        fused = Some(read_lambda(info.global_index)?);
3474                    }
3475                    _ => {}
3476                }
3477            }
3478            let mut lambda_phys = Vec::with_capacity(n_levels);
3479            let spectrum = if per_scale.is_empty() {
3480                let Some(lam) = fused else {
3481                    log::warn!(
3482                        "measure-jet term '{}' has no fitted amplitude in the penalty \
3483                        layout; skipping its extrapolation variance",
3484                        term.name
3485                    );
3486                    continue;
3487                };
3488                let Some(c) = frozen.fused_penalty_normalization_scale else {
3489                    log::warn!(
3490                        "measure-jet term '{}' is missing the fused penalty normalization scale; \
3491                        skipping its extrapolation variance",
3492                        term.name
3493                    );
3494                    continue;
3495                };
3496                MeasureJetExtrapolationSpectrum::Fused(lam / c)
3497            } else {
3498                per_scale.sort_by_key(|&(level, _)| level);
3499                let levels_complete = per_scale.len() == n_levels
3500                    && per_scale
3501                        .iter()
3502                        .enumerate()
3503                        .all(|(i, &(level, _))| level == i);
3504                if !levels_complete {
3505                    log::warn!(
3506                        "measure-jet term '{}': {} fitted per-scale amplitudes for {} band \
3507                        scales; skipping its extrapolation variance",
3508                        term.name,
3509                        per_scale.len(),
3510                        n_levels
3511                    );
3512                    continue;
3513                }
3514                if frozen.penalty_normalization_scales.len() != n_levels {
3515                    log::warn!(
3516                        "measure-jet term '{}': {} frozen penalty normalization scales for {} \
3517                        band scales; skipping its extrapolation variance",
3518                        term.name,
3519                        frozen.penalty_normalization_scales.len(),
3520                        n_levels
3521                    );
3522                    continue;
3523                }
3524                lambda_phys.extend(
3525                    per_scale
3526                        .iter()
3527                        .map(|&(level, lam)| lam / frozen.penalty_normalization_scales[level]),
3528                );
3529                MeasureJetExtrapolationSpectrum::PerLevel(&lambda_phys)
3530            };
3531            // Query rows in the frozen geometry's coordinates: select the
3532            // term's axes and replay the per-axis standardization exactly as
3533            // the build dispatch does (divide by σ_a when input_scales is
3534            // Some; the persisted centers are already post-standardization).
3535            let mut queries = Array2::<f64>::zeros((data.nrows(), feature_cols.len()));
3536            for (j, &col) in feature_cols.iter().enumerate() {
3537                if col >= data.ncols() {
3538                    return Err(FittedModelError::SchemaMismatch {
3539                        reason: format!(
3540                            "measure-jet term '{}': prediction column {col} out of bounds \
3541                            for {} data columns",
3542                            term.name,
3543                            data.ncols()
3544                        ),
3545                    });
3546                }
3547                queries.column_mut(j).assign(&data.column(col));
3548            }
3549            if let Some(scales) = input_scales {
3550                if scales.len() != feature_cols.len() {
3551                    return Err(FittedModelError::SchemaMismatch {
3552                        reason: format!(
3553                            "measure-jet term '{}': {} input scales for {} axes",
3554                            term.name,
3555                            scales.len(),
3556                            feature_cols.len()
3557                        ),
3558                    });
3559                }
3560                for (j, &scale) in scales.iter().enumerate() {
3561                    queries.column_mut(j).mapv_inplace(|v| v / scale);
3562                }
3563            }
3564            let support = gam_terms::basis::measure_jet_support_curve(
3565                queries.view(),
3566                centers.view(),
3567                frozen.masses.view(),
3568                &frozen.eps_band,
3569            )
3570            .map_err(|e| FittedModelError::SchemaMismatch {
3571                reason: format!(
3572                    "measure-jet term '{}': support curve failed: {e}",
3573                    term.name
3574                ),
3575            })?;
3576            for i in 0..data.nrows() {
3577                let v = gam_terms::basis::measure_jet_extrapolation_variance(
3578                    support.row(i),
3579                    &frozen.eps_band,
3580                    &frozen.support_means,
3581                    spectrum,
3582                    Self::MEASURE_JET_COVERAGE_FLOOR,
3583                )
3584                .map_err(|e| FittedModelError::SchemaMismatch {
3585                    reason: format!(
3586                        "measure-jet term '{}': extrapolation variance failed: {e}",
3587                        term.name
3588                    ),
3589                })?;
3590                total[i] += phi_scale * v;
3591            }
3592            contributed = true;
3593        }
3594        Ok(contributed.then_some(total))
3595    }
3596
3597    /// Access the unified fit result, if stored.
3598    pub fn unified(&self) -> Option<&UnifiedFitResult> {
3599        self.payload().unified.as_ref()
3600    }
3601
3602    pub fn load_from_path(path: &Path) -> Result<Self, FittedModelError> {
3603        let payload = fs::read_to_string(path).map_err(|e| FittedModelError::PayloadCorrupt {
3604            reason: format!("failed to read model '{}': {e}", path.display()),
3605        })?;
3606        let model: Self =
3607            serde_json::from_str(&payload).map_err(|e| FittedModelError::PayloadCorrupt {
3608                reason: format!("failed to parse model json: {e}"),
3609            })?;
3610        let model = model.with_synchronized_stateful_link_metadata();
3611        model.validate_for_persistence()?;
3612        model.validate_numeric_finiteness()?;
3613        Ok(model)
3614    }
3615
3616    pub fn save_to_path(&self, path: &Path) -> Result<(), FittedModelError> {
3617        let normalized = self.clone().with_synchronized_stateful_link_metadata();
3618        normalized.validate_for_persistence()?;
3619        normalized.validate_numeric_finiteness()?;
3620        // Write to a sibling temp file, fsync, then rename into place so a
3621        // crash mid-write never corrupts the user's existing saved fit.
3622        // Concurrent writers to the same path each have a distinct temp
3623        // suffix (pid + nanos), so neither stomps the other's in-flight
3624        // bytes; the rename winner is last-rename-wins, which is the
3625        // expected last-write-wins semantics for a single canonical path.
3626        let parent = path.parent().unwrap_or_else(|| Path::new("."));
3627        let file_name = path
3628            .file_name()
3629            .and_then(|s| s.to_str())
3630            .unwrap_or("model.json");
3631        let pid = std::process::id();
3632        let nanos = std::time::SystemTime::now()
3633            .duration_since(std::time::UNIX_EPOCH)
3634            .map(|d| d.as_nanos())
3635            .unwrap_or(0);
3636        let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nanos:x}"));
3637        let file = fs::File::create(&tmp).map_err(|e| FittedModelError::PayloadCorrupt {
3638            reason: format!("failed to write model '{}': {e}", tmp.display()),
3639        })?;
3640        let mut writer = std::io::BufWriter::new(file);
3641        let ser_result = serde_json::to_writer(&mut writer, &normalized);
3642        if let Err(e) = ser_result {
3643            // Best-effort temp cleanup on serialization failure. flush
3644            // returns io::Result<()>; discarding via `.ok()` is enough.
3645            std::io::Write::flush(&mut writer).ok();
3646            drop(writer);
3647            fs::remove_file(&tmp).ok();
3648            return Err(FittedModelError::PayloadCorrupt {
3649                reason: format!("failed to serialize model: {e}"),
3650            });
3651        }
3652        std::io::Write::flush(&mut writer).map_err(|e| FittedModelError::PayloadCorrupt {
3653            reason: format!("failed to write model '{}': {e}", tmp.display()),
3654        })?;
3655        // Recover the underlying File to fsync its contents before rename.
3656        let inner = writer
3657            .into_inner()
3658            .map_err(|e| FittedModelError::PayloadCorrupt {
3659                reason: format!("failed to flush model '{}': {}", tmp.display(), e.error()),
3660            })?;
3661        inner.sync_all().ok();
3662        drop(inner);
3663        if let Err(e) = fs::rename(&tmp, path) {
3664            fs::remove_file(&tmp).ok();
3665            return Err(FittedModelError::PayloadCorrupt {
3666                reason: format!("failed to publish model '{}': {e}", path.display()),
3667            });
3668        }
3669        // fsync the parent directory so the rename itself is durable
3670        // across a crash; without this, the rename can be lost even though
3671        // file contents reached disk. Best-effort on platforms that don't
3672        // support opening a directory for fsync.
3673        if let Ok(d) = fs::File::open(parent) {
3674            d.sync_all().ok();
3675        }
3676        Ok(())
3677    }
3678
3679    pub fn require_data_schema(&self) -> Result<&DataSchema, FittedModelError> {
3680        self.data_schema
3681            .as_ref()
3682            .ok_or_else(|| FittedModelError::MissingField {
3683                reason: "model is missing data_schema; refit".to_string(),
3684            })
3685    }
3686
3687    /// Restore the exact in-memory spline-scan fit from a scan-bearing
3688    /// payload (#1030/#1034). `Ok(None)` for dense models; the returned
3689    /// `predict` replays the training Gaussian bridge bit-for-bit.
3690    pub fn saved_spline_scan(
3691        &self,
3692    ) -> Result<Option<(&str, gam_solve::spline_scan::SplineScanFit)>, FittedModelError> {
3693        let Some(saved) = self.spline_scan.as_ref() else {
3694            return Ok(None);
3695        };
3696        let fit = gam_solve::spline_scan::SplineScanFit::from_state(&saved.state)
3697            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3698        Ok(Some((saved.feature_column.as_str(), fit)))
3699    }
3700
3701    /// Restore the in-memory residual-cascade fit from a cascade-bearing
3702    /// payload (#1032). `Ok(None)` for non-cascade models; the returned fit
3703    /// replays the multilevel Wendland-frame posterior for the d ∈ {2, 3}
3704    /// feature columns at each predict point.
3705    pub fn saved_residual_cascade(
3706        &self,
3707    ) -> Result<
3708        Option<(
3709            &[String],
3710            gam_solve::residual_cascade::ResidualCascadeFit,
3711        )>,
3712        FittedModelError,
3713    > {
3714        let Some(saved) = self.residual_cascade.as_ref() else {
3715            return Ok(None);
3716        };
3717        let fit = gam_solve::residual_cascade::ResidualCascadeFit::from_state(&saved.state)
3718            .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3719        Ok(Some((saved.feature_columns.as_slice(), fit)))
3720    }
3721
3722    pub fn random_effect_group_columns(&self) -> HashSet<String> {
3723        let Some(training_headers) = self.training_headers.as_ref() else {
3724            return HashSet::new();
3725        };
3726        let mut out = HashSet::<String>::new();
3727        for spec in self.saved_term_specs() {
3728            for term in &spec.random_effect_terms {
3729                if let Some(name) = training_headers.get(term.feature_col) {
3730                    out.insert(name.clone());
3731                }
3732            }
3733        }
3734        out
3735    }
3736
3737    pub fn validate_for_persistence(&self) -> Result<(), FittedModelError> {
3738        // Hard version gate. The struct's ~40 Option<T> fields carry
3739        // `#[serde(default)]`, which is by design forward-compatible: old
3740        // payloads missing a new optional field decode with `None`. BUT:
3741        // when a new CLI release adds a required field for some family_state
3742        // (enforced below), an older model loaded by the newer CLI would have
3743        // `None` in that slot and the family-specific branch below would
3744        // correctly reject it — unless the new field also happens to slot
3745        // under a branch that hasn't been touched. Conversely, a newer model
3746        // loaded by an older CLI silently drops fields the older struct
3747        // doesn't know about. Both directions are silent-drift hazards. We
3748        // close them with an exact-version check anchored to the canonical
3749        // MODEL_PAYLOAD_VERSION constant — every payload must round-trip
3750        // identically between writers and readers running the same schema.
3751        self.validate_payload_version()?;
3752        if let Some(scan) = self.spline_scan.as_ref() {
3753            // Spline-scan representation (#1030/#1034): the smoother state IS
3754            // the fit. It is exclusive with the dense representation, only
3755            // standard Gaussian-identity models can carry it, and the state
3756            // must restore cleanly so predict never sees a corrupt snapshot.
3757            if self.fit_result.is_some() || self.unified.is_some() {
3758                return Err(FittedModelError::SchemaMismatch {
3759                    reason: "spline-scan model must not also carry a dense fit_result/unified \
3760                             payload; the representations are mutually exclusive"
3761                        .to_string(),
3762                });
3763            }
3764            if self.model_kind != ModelKind::Standard
3765                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3766            {
3767                return Err(FittedModelError::SchemaMismatch {
3768                    reason: format!(
3769                        "spline-scan representation requires a standard Gaussian-identity model; \
3770                         got model_kind={:?}, likelihood={:?}",
3771                        self.model_kind,
3772                        self.family_state.likelihood()
3773                    ),
3774                });
3775            }
3776            if scan.feature_column.is_empty() {
3777                return Err(FittedModelError::MissingField {
3778                    reason: "spline-scan model is missing its feature column name; refit"
3779                        .to_string(),
3780                });
3781            }
3782            gam_solve::spline_scan::SplineScanFit::from_state(&scan.state)
3783                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3784            // A scan model carries NO dense design, so the dense-path
3785            // requirements below (resolved_termspec, fit_result finiteness,
3786            // family-specific blocks) do not apply. Enforce only the metadata
3787            // predict actually consumes — the feature column resolves against
3788            // training_headers / data_schema — then accept.
3789            if self.data_schema.is_none() {
3790                return Err(FittedModelError::MissingField {
3791                    reason: "spline-scan model is missing data_schema; refit".to_string(),
3792                });
3793            }
3794            if self.training_headers.is_none() {
3795                return Err(FittedModelError::MissingField {
3796                    reason: "spline-scan model is missing training_headers; refit".to_string(),
3797                });
3798            }
3799            return Ok(());
3800        } else if let Some(cascade) = self.residual_cascade.as_ref() {
3801            // Residual-cascade representation (#1032): a multilevel
3802            // Wendland-frame model for a scattered d ∈ {2,3} Gaussian smooth.
3803            // Exclusive with the dense representation and with the scan.
3804            if self.spline_scan.is_some() || self.fit_result.is_some() || self.unified.is_some() {
3805                return Err(FittedModelError::SchemaMismatch {
3806                    reason: "residual-cascade model must not also carry spline_scan / \
3807                             fit_result / unified payloads; the representations are \
3808                             mutually exclusive"
3809                        .to_string(),
3810                });
3811            }
3812            if self.model_kind != ModelKind::Standard
3813                || self.family_state.likelihood() != LikelihoodSpec::gaussian_identity()
3814            {
3815                return Err(FittedModelError::SchemaMismatch {
3816                    reason: format!(
3817                        "residual-cascade representation requires a standard Gaussian-identity \
3818                         model; got model_kind={:?}, likelihood={:?}",
3819                        self.model_kind,
3820                        self.family_state.likelihood()
3821                    ),
3822                });
3823            }
3824            if cascade.feature_columns.is_empty()
3825                || !(2..=3).contains(&cascade.feature_columns.len())
3826            {
3827                return Err(FittedModelError::MissingField {
3828                    reason: format!(
3829                        "residual-cascade model needs 2 or 3 feature columns; got {}; refit",
3830                        cascade.feature_columns.len()
3831                    ),
3832                });
3833            }
3834            gam_solve::residual_cascade::ResidualCascadeFit::from_state(&cascade.state)
3835                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3836            if self.data_schema.is_none() {
3837                return Err(FittedModelError::MissingField {
3838                    reason: "residual-cascade model is missing data_schema; refit".to_string(),
3839                });
3840            }
3841            if self.training_headers.is_none() {
3842                return Err(FittedModelError::MissingField {
3843                    reason: "residual-cascade model is missing training_headers; refit".to_string(),
3844                });
3845            }
3846            return Ok(());
3847        } else if self.fit_result.is_none() {
3848            return Err(FittedModelError::MissingField {
3849                reason: "model is missing canonical fit_result payload; refit".to_string(),
3850            });
3851        }
3852        if self.data_schema.is_none() {
3853            return Err(FittedModelError::MissingField {
3854                reason: "model is missing data_schema; refit".to_string(),
3855            });
3856        }
3857        if self.training_headers.is_none() {
3858            return Err(FittedModelError::MissingField {
3859                reason: "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
3860                    .to_string(),
3861            });
3862        }
3863        let spec = self.resolved_termspec.as_ref().ok_or_else(|| {
3864            FittedModelError::MissingField {
3865                reason: "model is missing resolved_termspec; refit to guarantee train/predict design consistency"
3866                    .to_string(),
3867            }
3868        })?;
3869        validate_frozen_term_collectionspec(spec, "resolved_termspec")?;
3870
3871        if self.formula_noise.is_some() && self.resolved_termspec_noise.is_none() {
3872            return Err(FittedModelError::MissingField {
3873                reason: "model defines formula_noise but is missing resolved_termspec_noise; refit"
3874                    .to_string(),
3875            });
3876        }
3877        if let Some(spec_noise) = self.resolved_termspec_noise.as_ref() {
3878            validate_frozen_term_collectionspec(spec_noise, "resolved_termspec_noise")?;
3879        }
3880        if matches!(self.family_state, FittedFamily::TransformationNormal { .. }) {
3881            let score = self.transformation_score_calibration.ok_or_else(|| {
3882                FittedModelError::MissingField {
3883                    reason: "transformation-normal model is missing transformation_score_calibration; refit"
3884                        .to_string(),
3885                }
3886            })?;
3887            score.validate("transformation-normal model")?;
3888        }
3889        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
3890            if self.formula_logslope.is_none() {
3891                return Err(FittedModelError::MissingField {
3892                    reason: "marginal-slope model is missing formula_logslope; refit".to_string(),
3893                });
3894            }
3895            if self.z_column.is_none() {
3896                return Err(FittedModelError::MissingField {
3897                    reason: "marginal-slope model is missing z_column; refit".to_string(),
3898                });
3899            }
3900            let z_normalization =
3901                self.latent_z_normalization
3902                    .ok_or_else(|| FittedModelError::MissingField {
3903                        reason: "marginal-slope model is missing latent_z_normalization; refit"
3904                            .to_string(),
3905                    })?;
3906            z_normalization.validate("marginal-slope model")?;
3907            let latent_measure =
3908                self.latent_measure
3909                    .as_ref()
3910                    .ok_or_else(|| FittedModelError::MissingField {
3911                        reason: "marginal-slope model is missing latent_measure; refit".to_string(),
3912                    })?;
3913            latent_measure
3914                .validate("marginal-slope model latent_measure")
3915                .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
3916            if self.marginal_baseline.is_none() || self.logslope_baseline.is_none() {
3917                return Err(FittedModelError::MissingField {
3918                    reason: "marginal-slope model is missing baseline offsets; refit".to_string(),
3919                });
3920            }
3921            if self.resolved_termspec_logslope.as_ref().is_none() {
3922                return Err(FittedModelError::MissingField {
3923                    reason: "marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
3924                        .to_string(),
3925                });
3926            }
3927            match self.family_state.frailty() {
3928                Some(FrailtySpec::None)
3929                | Some(FrailtySpec::GaussianShift {
3930                    sigma_fixed: Some(_),
3931                }) => {}
3932                Some(FrailtySpec::GaussianShift { sigma_fixed: None }) => {
3933                    return Err(FittedModelError::IncompatibleConfig {
3934                        reason: "marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
3935                            .to_string(),
3936                    });
3937                }
3938                Some(FrailtySpec::HazardMultiplier { .. }) => {
3939                    return Err(FittedModelError::IncompatibleConfig {
3940                        reason: "marginal-slope model does not support HazardMultiplier frailty"
3941                            .to_string(),
3942                    });
3943                }
3944                None => {
3945                    return Err(FittedModelError::MissingField {
3946                        reason: "marginal-slope model is missing family_state.frailty; refit"
3947                            .to_string(),
3948                    });
3949                }
3950            }
3951        }
3952
3953        if let FittedFamily::Survival {
3954            survival_likelihood,
3955            frailty,
3956            ..
3957        } = &self.family_state
3958        {
3959            if matches!(
3960                survival_likelihood.as_deref(),
3961                Some("latent") | Some("latent-binary")
3962            ) {
3963                return Err(FittedModelError::SchemaMismatch {
3964                    reason: "latent hazard-window models must persist explicit family_state metadata, not generic survival metadata"
3965                        .to_string(),
3966                });
3967            }
3968            if survival_likelihood.as_deref() == Some("marginal-slope") {
3969                if self.formula_logslope.is_none() {
3970                    return Err(FittedModelError::MissingField {
3971                        reason: "survival marginal-slope model is missing formula_logslope; refit"
3972                            .to_string(),
3973                    });
3974                }
3975                if self.z_column.is_none() {
3976                    return Err(FittedModelError::MissingField {
3977                        reason: "survival marginal-slope model is missing z_column; refit"
3978                            .to_string(),
3979                    });
3980                }
3981                let z_normalization =
3982                    self.latent_z_normalization
3983                        .ok_or_else(|| {
3984                            FittedModelError::MissingField {
3985                        reason:
3986                            "survival marginal-slope model is missing latent_z_normalization; refit"
3987                                .to_string(),
3988                    }
3989                        })?;
3990                z_normalization.validate("survival marginal-slope model")?;
3991                let latent_measure =
3992                    self.latent_measure
3993                        .as_ref()
3994                        .ok_or_else(|| FittedModelError::MissingField {
3995                            reason:
3996                                "survival marginal-slope model is missing latent_measure; refit"
3997                                    .to_string(),
3998                        })?;
3999                latent_measure
4000                    .validate("survival marginal-slope model latent_measure")
4001                    .map_err(|reason| FittedModelError::PayloadCorrupt { reason })?;
4002                if self.logslope_baseline.is_none() {
4003                    return Err(FittedModelError::MissingField {
4004                        reason: "survival marginal-slope model is missing logslope_baseline; refit"
4005                            .to_string(),
4006                    });
4007                }
4008                if self.resolved_termspec_logslope.as_ref().is_none() {
4009                    return Err(FittedModelError::MissingField {
4010                        reason: "survival marginal-slope model is missing resolved_termspec_logslope for the logslope surface"
4011                            .to_string(),
4012                    });
4013                }
4014                match frailty {
4015                    FrailtySpec::None
4016                    | FrailtySpec::GaussianShift {
4017                        sigma_fixed: Some(_),
4018                    } => {}
4019                    FrailtySpec::GaussianShift { sigma_fixed: None } => {
4020                        return Err(FittedModelError::IncompatibleConfig {
4021                            reason: "survival marginal-slope model requires a fixed GaussianShift sigma in family_state.frailty"
4022                                .to_string(),
4023                        });
4024                    }
4025                    FrailtySpec::HazardMultiplier { .. } => {
4026                        return Err(FittedModelError::IncompatibleConfig {
4027                            reason: "survival marginal-slope model does not support HazardMultiplier frailty"
4028                                .to_string(),
4029                        });
4030                    }
4031                }
4032            } else if !matches!(frailty, FrailtySpec::None) {
4033                return Err(FittedModelError::IncompatibleConfig {
4034                    reason:
4035                        "non-marginal survival models do not currently persist a frailty modifier"
4036                            .to_string(),
4037                });
4038            }
4039            // Non-latent survival predict reconstructs the baseline-time
4040            // basis via `load_survival_time_basis_config_from_model` and
4041            // anchors that basis at `survival_time_anchor`; both are
4042            // required for the saved model to be loadable. The CLI's
4043            // marginal-slope+time-wiggle save path previously dropped one or
4044            // the other on partial-write, producing models that loaded but
4045            // would panic at the first predict. Enforce both before persisting.
4046            if self.survival_time_basis.is_none() {
4047                return Err(FittedModelError::MissingField {
4048                    reason: "survival model is missing survival_time_basis; refit to persist the baseline-time basis configuration".to_string(),
4049                });
4050            }
4051            if self.survival_time_anchor.is_none() {
4052                return Err(FittedModelError::MissingField {
4053                    reason: "survival model is missing survival_time_anchor; refit to persist the baseline-time anchor".to_string(),
4054                });
4055            }
4056        }
4057        if let FittedFamily::LatentSurvival { frailty } = &self.family_state {
4058            match frailty {
4059                FrailtySpec::HazardMultiplier {
4060                    sigma_fixed: Some(_),
4061                    ..
4062                } => {}
4063                FrailtySpec::HazardMultiplier {
4064                    sigma_fixed: None, ..
4065                } => {
4066                    return Err(FittedModelError::IncompatibleConfig {
4067                        reason: "latent survival model requires a fixed HazardMultiplier sigma in family_state.frailty"
4068                            .to_string(),
4069                    });
4070                }
4071                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4072                    return Err(FittedModelError::IncompatibleConfig {
4073                        reason: "latent survival model requires a fixed HazardMultiplier frailty specification"
4074                            .to_string(),
4075                    });
4076                }
4077            }
4078            if self.survival_likelihood.as_deref() != Some("latent") {
4079                return Err(FittedModelError::SchemaMismatch {
4080                    reason: "latent survival model must persist survival_likelihood=latent"
4081                        .to_string(),
4082                });
4083            }
4084        }
4085        if let FittedFamily::LatentBinary { frailty } = &self.family_state {
4086            match frailty {
4087                FrailtySpec::HazardMultiplier {
4088                    sigma_fixed: Some(_),
4089                    ..
4090                } => {}
4091                FrailtySpec::HazardMultiplier {
4092                    sigma_fixed: None, ..
4093                } => {
4094                    return Err(FittedModelError::IncompatibleConfig {
4095                        reason: "latent binary model requires a fixed HazardMultiplier sigma in family_state.frailty"
4096                            .to_string(),
4097                    });
4098                }
4099                FrailtySpec::GaussianShift { .. } | FrailtySpec::None => {
4100                    return Err(FittedModelError::IncompatibleConfig {
4101                        reason: "latent binary model requires a fixed HazardMultiplier frailty specification"
4102                            .to_string(),
4103                    });
4104                }
4105            }
4106            if self.survival_likelihood.as_deref() != Some("latent-binary") {
4107                return Err(FittedModelError::SchemaMismatch {
4108                    reason: "latent binary model must persist survival_likelihood=latent-binary"
4109                        .to_string(),
4110                });
4111            }
4112        }
4113
4114        let family_likelihood = match &self.family_state {
4115            FittedFamily::Standard { likelihood, .. }
4116            | FittedFamily::LocationScale { likelihood, .. }
4117            | FittedFamily::MarginalSlope { likelihood, .. }
4118            | FittedFamily::Survival { likelihood, .. }
4119            | FittedFamily::TransformationNormal { likelihood, .. } => Some(likelihood),
4120            FittedFamily::LatentSurvival { .. } | FittedFamily::LatentBinary { .. } => None,
4121        };
4122        let is_standard_or_location_scale = matches!(
4123            self.family_state,
4124            FittedFamily::Standard { .. } | FittedFamily::LocationScale { .. }
4125        );
4126        if is_standard_or_location_scale
4127            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_sas)
4128        {
4129            self.saved_sas_state()?;
4130        }
4131        if is_standard_or_location_scale
4132            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_beta_logistic)
4133        {
4134            self.saved_beta_logistic_state()?;
4135        }
4136        if is_standard_or_location_scale
4137            && family_likelihood.is_some_and(LikelihoodSpec::is_binomial_mixture)
4138        {
4139            self.saved_mixture_state()?;
4140        }
4141        if matches!(self.family_state, FittedFamily::Standard { .. })
4142            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4143        {
4144            self.saved_latent_cloglog_state()?;
4145        }
4146        if matches!(self.family_state, FittedFamily::LocationScale { .. })
4147            && family_likelihood.is_some_and(LikelihoodSpec::is_latent_cloglog)
4148        {
4149            return Err(FittedModelError::IncompatibleConfig {
4150                reason: "latent-cloglog-binomial is not supported for location-scale saved models"
4151                    .to_string(),
4152            });
4153        }
4154        if matches!(self.family_state, FittedFamily::Survival { .. })
4155            && self.survival_likelihood.is_none()
4156        {
4157            return Err(FittedModelError::MissingField {
4158                reason: "saved survival model is missing survival_likelihood metadata; refit"
4159                    .to_string(),
4160            });
4161        }
4162        let has_any_saved_link_wiggle = self.linkwiggle_knots.is_some()
4163            || self.linkwiggle_degree.is_some()
4164            || self.beta_link_wiggle.is_some()
4165            || self
4166                .fit_result
4167                .as_ref()
4168                .and_then(|fit| fit.block_by_role(BlockRole::LinkWiggle))
4169                .is_some();
4170        let saved_link_wiggle = self.saved_link_wiggle()?;
4171        if has_any_saved_link_wiggle && saved_link_wiggle.is_none() {
4172            return Err(FittedModelError::SchemaMismatch {
4173                reason: "saved model has incomplete link-wiggle state; expected metadata and coefficients"
4174                    .to_string(),
4175            });
4176        }
4177        let has_any_saved_baseline_time_wiggle = self.baseline_timewiggle_knots.is_some()
4178            || self.baseline_timewiggle_degree.is_some()
4179            || self.baseline_timewiggle_penalty_orders.is_some()
4180            || self.baseline_timewiggle_double_penalty.is_some()
4181            || self.beta_baseline_timewiggle.is_some()
4182            || self.beta_baseline_timewiggle_by_cause.is_some();
4183        let is_joint_cause_specific = self
4184            .survival_cause_count
4185            .is_some_and(|cause_count| cause_count > 1);
4186        if has_any_saved_baseline_time_wiggle {
4187            if is_joint_cause_specific {
4188                let complete = self.baseline_timewiggle_knots.is_some()
4189                    && self.baseline_timewiggle_degree.is_some()
4190                    && self.baseline_timewiggle_penalty_orders.is_some()
4191                    && self.baseline_timewiggle_double_penalty.is_some()
4192                    && self.beta_baseline_timewiggle_by_cause.is_some();
4193                if !complete {
4194                    return Err(FittedModelError::SchemaMismatch {
4195                        reason: "saved joint cause-specific survival model has incomplete baseline-timewiggle state; expected metadata and per-cause coefficients"
4196                            .to_string(),
4197                    });
4198                }
4199            } else if self.saved_baseline_time_wiggle()?.is_none() {
4200                return Err(FittedModelError::SchemaMismatch {
4201                    reason: "saved model has incomplete baseline-timewiggle state; expected metadata and coefficients"
4202                        .to_string(),
4203                });
4204            }
4205        }
4206        if self
4207            .survival_likelihood
4208            .as_deref()
4209            .is_some_and(|value| value.eq_ignore_ascii_case("location-scale"))
4210        {
4211            validate_survival_location_scale_saved_fit(self.payload(), saved_link_wiggle.as_ref())?;
4212        }
4213
4214        // Validate anchored-deviation replay contracts at LOAD/SAVE time rather
4215        // than waiting for first predict call. Previously these contracts
4216        // (span table dimensions, coefficient matrices, etc.) were only
4217        // asserted inside `saved_prediction_runtime`, which runs on the first
4218        // predict invocation. A corrupted runtime would therefore pass
4219        // `load_from_path` silently and fail later under a different error
4220        // surface. Enforcing the same check here makes the model self-
4221        // diagnostic: `gam fit` catches its own bad output at save, and
4222        // `gam predict` catches bad input at load rather than mid-pipeline.
4223        if let Some(runtime) = self.score_warp_runtime.as_ref() {
4224            runtime.validate_exact_replay_contract().map_err(|err| {
4225                FittedModelError::PayloadCorrupt {
4226                    reason: format!("saved anchored score-warp runtime is invalid: {err}"),
4227                }
4228            })?;
4229        }
4230        if let Some(runtime) = self.link_deviation_runtime.as_ref() {
4231            runtime.validate_exact_replay_contract().map_err(|err| {
4232                FittedModelError::PayloadCorrupt {
4233                    reason: format!("saved anchored link-deviation runtime is invalid: {err}"),
4234                }
4235            })?;
4236        }
4237        if matches!(self.family_state, FittedFamily::MarginalSlope { .. }) {
4238            validate_marginal_slope_saved_fit(
4239                self.fit_result.as_ref().expect("checked above"),
4240                self.score_warp_runtime.as_ref(),
4241                self.link_deviation_runtime.as_ref(),
4242                "fit_result",
4243            )?;
4244            let unified = self
4245                .unified
4246                .as_ref()
4247                .ok_or_else(|| FittedModelError::MissingField {
4248                    reason: "marginal-slope model is missing unified fit payload; refit"
4249                        .to_string(),
4250                })?;
4251            validate_marginal_slope_saved_fit(
4252                unified,
4253                self.score_warp_runtime.as_ref(),
4254                self.link_deviation_runtime.as_ref(),
4255                "unified",
4256            )?;
4257        }
4258        if self
4259            .survival_likelihood
4260            .as_deref()
4261            .is_some_and(|value| value.eq_ignore_ascii_case("marginal-slope"))
4262        {
4263            validate_survival_marginal_slope_saved_fit(
4264                self.fit_result.as_ref().expect("checked above"),
4265                self.score_warp_runtime.as_ref(),
4266                self.link_deviation_runtime.as_ref(),
4267                "fit_result",
4268            )?;
4269            if let Some(unified) = self.unified.as_ref() {
4270                validate_survival_marginal_slope_saved_fit(
4271                    unified,
4272                    self.score_warp_runtime.as_ref(),
4273                    self.link_deviation_runtime.as_ref(),
4274                    "unified",
4275                )?;
4276            }
4277        }
4278
4279        // Posterior-mean / uncertainty backends are validated at predict time
4280        // by `prediction_backend_from_model`, which has access to the actual
4281        // requested mode and emits the canonical "nonlinear posterior-mean
4282        // prediction requires either covariance or a saved penalized Hessian"
4283        // error.  Save-time we deliberately do NOT enforce that gate: a fit
4284        // produced for MAP / plug-in scoring can be persisted and replayed
4285        // without ever needing a covariance backend, and gating it here would
4286        // refuse legitimate MAP-only saves whose `UnifiedFitResult` carries
4287        // beta + lambdas without a stabilized Hessian.
4288
4289        Ok(())
4290    }
4291
4292    pub fn validate_numeric_finiteness(&self) -> Result<(), FittedModelError> {
4293        let corrupt = |reason: String| FittedModelError::PayloadCorrupt { reason };
4294        if let Some(fit) = self.fit_result.as_ref() {
4295            fit.validate_numeric_finiteness()
4296                .map_err(|e| corrupt(e.to_string()))?;
4297        }
4298
4299        for (name, opt) in [
4300            ("survival_baseline_scale", self.survival_baseline_scale),
4301            ("survival_baseline_shape", self.survival_baseline_shape),
4302            ("survival_baseline_rate", self.survival_baseline_rate),
4303            ("survival_baseline_makeham", self.survival_baseline_makeham),
4304            (
4305                "survival_time_smooth_lambda",
4306                self.survival_time_smooth_lambda,
4307            ),
4308            ("survival_time_anchor", self.survival_time_anchor),
4309            ("survivalridge_lambda", self.survivalridge_lambda),
4310        ] {
4311            if let Some(v) = opt {
4312                ensure_finite_scalar(name, v).map_err(corrupt)?;
4313            }
4314        }
4315
4316        if let Some(v) = self.beta_noise.as_ref() {
4317            validate_all_finite("beta_noise", v.iter().copied()).map_err(corrupt)?;
4318        }
4319        if let Some(v) = self.noise_projection.as_ref() {
4320            validate_all_finite("noise_projection", v.iter().flatten().copied())
4321                .map_err(corrupt)?;
4322            if self.noise_projection_ridge_alpha.is_none() {
4323                return Err(FittedModelError::MissingField {
4324                    reason:
4325                        "model has noise_projection but is missing noise_projection_ridge_alpha; refit"
4326                            .to_string(),
4327                });
4328            }
4329        }
4330        if let Some(v) = self.noise_center.as_ref() {
4331            validate_all_finite("noise_center", v.iter().copied()).map_err(corrupt)?;
4332        }
4333        if let Some(v) = self.noise_scale.as_ref() {
4334            validate_all_finite("noise_scale", v.iter().copied()).map_err(corrupt)?;
4335        }
4336        if let Some(v) = self.noise_projection_ridge_alpha {
4337            ensure_finite_scalar("noise_projection_ridge_alpha", v).map_err(corrupt)?;
4338            if v < 0.0 {
4339                return Err(FittedModelError::InvalidInput {
4340                    reason: format!("noise_projection_ridge_alpha must be non-negative, got {v}"),
4341                });
4342            }
4343        }
4344        if let Some(v) = self.gaussian_response_scale {
4345            ensure_finite_scalar("gaussian_response_scale", v).map_err(corrupt)?;
4346        }
4347        if let Some(v) = self.beta_link_wiggle.as_ref() {
4348            validate_all_finite("beta_link_wiggle", v.iter().copied()).map_err(corrupt)?;
4349        }
4350        if let Some(v) = self.beta_baseline_timewiggle.as_ref() {
4351            validate_all_finite("beta_baseline_timewiggle", v.iter().copied()).map_err(corrupt)?;
4352        }
4353        if let Some(v) = self.beta_baseline_timewiggle_by_cause.as_ref() {
4354            validate_all_finite(
4355                "beta_baseline_timewiggle_by_cause",
4356                v.iter().flatten().copied(),
4357            )
4358            .map_err(corrupt)?;
4359        }
4360        if let Some(v) = self.latent_z_normalization {
4361            v.validate("latent_z_normalization")?;
4362        }
4363        if let Some(v) = self.latent_measure.as_ref() {
4364            v.validate("latent_measure").map_err(corrupt)?;
4365        }
4366        if let Some(v) = self.survival_beta_time.as_ref() {
4367            validate_all_finite("survival_beta_time", v.iter().copied()).map_err(corrupt)?;
4368        }
4369        if let Some(v) = self.survival_beta_threshold.as_ref() {
4370            validate_all_finite("survival_beta_threshold", v.iter().copied()).map_err(corrupt)?;
4371        }
4372        if let Some(v) = self.survival_beta_log_sigma.as_ref() {
4373            validate_all_finite("survival_beta_log_sigma", v.iter().copied()).map_err(corrupt)?;
4374        }
4375        if let Some(v) = self.survival_noise_projection.as_ref() {
4376            validate_all_finite("survival_noise_projection", v.iter().flatten().copied())
4377                .map_err(corrupt)?;
4378            if self.survival_noise_projection_ridge_alpha.is_none() {
4379                return Err(FittedModelError::MissingField {
4380                    reason:
4381                        "model has survival_noise_projection but is missing survival_noise_projection_ridge_alpha; refit"
4382                            .to_string(),
4383                });
4384            }
4385        }
4386        if let Some(v) = self.survival_noise_center.as_ref() {
4387            validate_all_finite("survival_noise_center", v.iter().copied()).map_err(corrupt)?;
4388        }
4389        if let Some(v) = self.survival_noise_projection_ridge_alpha {
4390            ensure_finite_scalar("survival_noise_projection_ridge_alpha", v).map_err(corrupt)?;
4391            if v < 0.0 {
4392                return Err(FittedModelError::InvalidInput {
4393                    reason: format!(
4394                        "survival_noise_projection_ridge_alpha must be non-negative, got {v}"
4395                    ),
4396                });
4397            }
4398        }
4399        if let Some(v) = self.survival_noise_scale.as_ref() {
4400            validate_all_finite("survival_noise_scale", v.iter().copied()).map_err(corrupt)?;
4401        }
4402        if let Some(v) = self.mixture_link_param_covariance.as_ref() {
4403            validate_all_finite("mixture_link_param_covariance", v.iter().flatten().copied())
4404                .map_err(corrupt)?;
4405        }
4406        if let Some(v) = self.sas_param_covariance.as_ref() {
4407            validate_all_finite("sas_param_covariance", v.iter().flatten().copied())
4408                .map_err(corrupt)?;
4409        }
4410        Ok(())
4411    }
4412}
4413
4414fn array2_to_nestedvec(a: &ndarray::Array2<f64>) -> Vec<Vec<f64>> {
4415    a.rows().into_iter().map(|row| row.to_vec()).collect()
4416}
4417
4418use gam_solve::estimate::{ensure_finite_scalar, validate_all_finite};
4419
4420fn validate_frozen_term_collectionspec(
4421    spec: &TermCollectionSpec,
4422    label: &str,
4423) -> Result<(), FittedModelError> {
4424    spec.validate_frozen(label)
4425        .map_err(|reason| FittedModelError::SchemaMismatch { reason })
4426}
4427
4428impl Deref for FittedModel {
4429    type Target = FittedModelPayload;
4430
4431    fn deref(&self) -> &Self::Target {
4432        self.payload()
4433    }
4434}
4435
4436impl DerefMut for FittedModel {
4437    fn deref_mut(&mut self) -> &mut Self::Target {
4438        self.payload_mut()
4439    }
4440}
4441
4442// ---------------------------------------------------------------------------
4443// Reconstruct library types from saved models
4444// ---------------------------------------------------------------------------
4445
4446pub fn survival_baseline_config_from_model(
4447    model: &FittedModel,
4448) -> Result<SurvivalBaselineConfig, FittedModelError> {
4449    let target = model.survival_baseline_target.as_deref().ok_or_else(|| {
4450        FittedModelError::MissingField {
4451            reason: "saved survival model missing survival_baseline_target; refit".to_string(),
4452        }
4453    })?;
4454    parse_survival_baseline_config(
4455        target,
4456        model.survival_baseline_scale,
4457        model.survival_baseline_shape,
4458        model.survival_baseline_rate,
4459        model.survival_baseline_makeham,
4460    )
4461    .map_err(|reason| FittedModelError::IncompatibleConfig { reason })
4462}
4463
4464pub fn load_survival_time_basis_config_from_model(
4465    model: &FittedModel,
4466) -> Result<SurvivalTimeBasisConfig, FittedModelError> {
4467    match model
4468        .survival_time_basis
4469        .as_deref()
4470        .ok_or_else(|| FittedModelError::MissingField {
4471            reason: "saved survival model missing survival_time_basis".to_string(),
4472        })?
4473        .to_ascii_lowercase()
4474        .as_str()
4475    {
4476        "none" => Ok(SurvivalTimeBasisConfig::None),
4477        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
4478        "bspline" => {
4479            let degree =
4480                model
4481                    .survival_time_degree
4482                    .ok_or_else(|| FittedModelError::MissingField {
4483                        reason: "saved survival bspline model missing survival_time_degree"
4484                            .to_string(),
4485                    })?;
4486            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4487                FittedModelError::MissingField {
4488                    reason: "saved survival bspline model missing survival_time_knots".to_string(),
4489                }
4490            })?;
4491            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4492            if degree < 1 || knots.is_empty() {
4493                return Err(FittedModelError::SchemaMismatch {
4494                    reason: "saved survival bspline time basis metadata is invalid".to_string(),
4495                });
4496            }
4497            Ok(SurvivalTimeBasisConfig::BSpline {
4498                degree,
4499                knots: Array1::from_vec(knots),
4500                smooth_lambda,
4501            })
4502        }
4503        "ispline" => {
4504            let degree =
4505                model
4506                    .survival_time_degree
4507                    .ok_or_else(|| FittedModelError::MissingField {
4508                        reason: "saved survival ispline model missing survival_time_degree"
4509                            .to_string(),
4510                    })?;
4511            let knots = model.survival_time_knots.clone().ok_or_else(|| {
4512                FittedModelError::MissingField {
4513                    reason: "saved survival ispline model missing survival_time_knots".to_string(),
4514                }
4515            })?;
4516            let keep_cols = model.survival_time_keep_cols.clone().ok_or_else(|| {
4517                FittedModelError::MissingField {
4518                    reason: "saved survival ispline model missing survival_time_keep_cols"
4519                        .to_string(),
4520                }
4521            })?;
4522            let smooth_lambda = model.survival_time_smooth_lambda.unwrap_or(1e-2);
4523            if degree < 1 || knots.is_empty() || keep_cols.is_empty() {
4524                return Err(FittedModelError::SchemaMismatch {
4525                    reason: "saved survival ispline time basis metadata is invalid".to_string(),
4526                });
4527            }
4528            Ok(SurvivalTimeBasisConfig::ISpline {
4529                degree,
4530                knots: Array1::from_vec(knots),
4531                keep_cols,
4532                smooth_lambda,
4533            })
4534        }
4535        other => Err(FittedModelError::IncompatibleConfig {
4536            reason: format!("unsupported saved survival_time_basis '{other}'"),
4537        }),
4538    }
4539}
4540
4541#[cfg(test)]
4542mod tests {
4543    use super::*;
4544    use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
4545    use crate::survival::lognormal_kernel::FrailtySpec;
4546    use gam_solve::pirls::PirlsStatus;
4547    use gam_solve::estimate::{FitArtifacts, FittedBlock, FittedLinkState};
4548    use gam_problem::types::{LikelihoodScaleMetadata, LogLikelihoodNormalization};
4549    use gam_data::SchemaColumn;
4550    use ndarray::{Array1, Array2, array};
4551
4552    fn empty_termspec() -> TermCollectionSpec {
4553        TermCollectionSpec {
4554            linear_terms: vec![],
4555            random_effect_terms: vec![],
4556            smooth_terms: vec![],
4557        }
4558    }
4559
4560    /// #1030/#1034: a scan-bearing payload must round-trip through JSON +
4561    /// `validate_for_persistence` and replay the training Gaussian bridge
4562    /// bit-for-bit; structural corruption must fail loudly at validation.
4563    #[test]
4564    fn spline_scan_payload_round_trips_and_validates() {
4565        let x: Vec<f64> = (0..40).map(|i| i as f64 / 39.0).collect();
4566        let y: Vec<f64> = x.iter().map(|&v| (4.0 * v).sin() + 0.1 * v).collect();
4567        let w = vec![1.0_f64; x.len()];
4568        let fit = gam_solve::spline_scan::fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
4569        let make_payload = || {
4570            crate::inference::model_payload_builders::assemble_spline_scan_payload(
4571                "y ~ s(x)".to_string(),
4572                "x".to_string(),
4573                &fit,
4574                DataSchema {
4575                    columns: vec![
4576                        SchemaColumn {
4577                            name: "y".to_string(),
4578                            kind: ColumnKindTag::Continuous,
4579                            levels: vec![],
4580                        },
4581                        SchemaColumn {
4582                            name: "x".to_string(),
4583                            kind: ColumnKindTag::Continuous,
4584                            levels: vec![],
4585                        },
4586                    ],
4587                },
4588                vec!["x".to_string()],
4589                vec![(0.0, 1.0)],
4590            )
4591        };
4592        // The on-disk form is the FittedModel tagged enum; validation and the
4593        // scan accessor live on FittedModel (Deref only goes Model -> Payload).
4594        let model = FittedModel::from_payload(make_payload());
4595        model
4596            .validate_for_persistence()
4597            .expect("scan model validates");
4598        model
4599            .validate_numeric_finiteness()
4600            .expect("scan model is finite");
4601
4602        let json = serde_json::to_string(&model).expect("serialize model");
4603        let restored: FittedModel = serde_json::from_str(&json).expect("parse model");
4604        restored
4605            .validate_for_persistence()
4606            .expect("restored scan model validates");
4607        let (column, replay) = restored
4608            .saved_spline_scan()
4609            .expect("restore scan fit")
4610            .expect("payload carries the scan representation");
4611        assert_eq!(column, "x");
4612        for &xq in &[-0.1, 0.0, 0.31, 0.5, 0.77, 1.0, 1.4] {
4613            let (m0, v0) = fit.predict(xq).expect("predict original");
4614            let (m1, v1) = replay.predict(xq).expect("predict replayed");
4615            assert_eq!(m0.to_bits(), m1.to_bits(), "mean drift at x={xq}");
4616            assert_eq!(v0.to_bits(), v1.to_bits(), "variance drift at x={xq}");
4617        }
4618
4619        // A dense model without the scan channel still requires fit_result.
4620        let mut dense = make_payload();
4621        dense.spline_scan = None;
4622        let err = FittedModel::from_payload(dense)
4623            .validate_for_persistence()
4624            .expect_err("dense payload without fit_result must be rejected");
4625        assert!(err.to_string().contains("fit_result"));
4626
4627        // Structural corruption fails at validation, not inside predict.
4628        let mut corrupt = make_payload();
4629        corrupt
4630            .spline_scan
4631            .as_mut()
4632            .expect("scan channel present")
4633            .state
4634            .knots
4635            .truncate(2);
4636        FittedModel::from_payload(corrupt)
4637            .validate_for_persistence()
4638            .expect_err("corrupt scan state must be rejected");
4639        let mut unnamed = make_payload();
4640        unnamed
4641            .spline_scan
4642            .as_mut()
4643            .expect("scan channel present")
4644            .feature_column
4645            .clear();
4646        FittedModel::from_payload(unnamed)
4647            .validate_for_persistence()
4648            .expect_err("missing feature column must be rejected");
4649    }
4650
4651    fn standard_gaussian_payload() -> FittedModelPayload {
4652        FittedModelPayload::new(
4653            MODEL_PAYLOAD_VERSION,
4654            "y ~ 1".to_string(),
4655            ModelKind::Standard,
4656            FittedFamily::Standard {
4657                likelihood: LikelihoodSpec::gaussian_identity(),
4658                link: Some(StandardLink::Identity),
4659                latent_cloglog_state: None,
4660                mixture_state: None,
4661                sas_state: None,
4662            },
4663            "gaussian".to_string(),
4664        )
4665    }
4666
4667    fn anchored_runtime(basis_dim: usize) -> SavedCompiledFlexBlock {
4668        SavedCompiledFlexBlock {
4669            kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
4670            breakpoints: vec![-1.0, 1.0],
4671            basis_dim,
4672            span_c0: vec![vec![0.0; basis_dim]],
4673            span_c1: vec![vec![0.0; basis_dim]],
4674            span_c2: vec![vec![0.0; basis_dim]],
4675            span_c3: vec![vec![0.0; basis_dim]],
4676            anchor_correction: None,
4677            anchor_components: Vec::new(),
4678        }
4679    }
4680
4681    fn saved_fit(blocks: Vec<FittedBlock>) -> UnifiedFitResult {
4682        let beta = Array1::from_vec(
4683            blocks
4684                .iter()
4685                .flat_map(|block| block.beta.iter().copied())
4686                .collect(),
4687        );
4688        let p = beta.len();
4689        UnifiedFitResult {
4690            blocks,
4691            log_lambdas: Array1::zeros(0),
4692            lambdas: Array1::zeros(0),
4693            likelihood_family: Some(LikelihoodSpec::binomial_probit()),
4694            likelihood_scale: LikelihoodScaleMetadata::Unspecified,
4695            log_likelihood_normalization: LogLikelihoodNormalization::Full,
4696            log_likelihood: 0.0,
4697            deviance: 0.0,
4698            reml_score: 0.0,
4699            stable_penalty_term: 0.0,
4700            penalized_objective: 0.0,
4701            used_device: false,
4702            outer_iterations: 0,
4703            outer_cost_evals: 0,
4704            inner_pirls_solves: 0,
4705            outer_converged: true,
4706            outer_gradient_norm: None,
4707            standard_deviation: 1.0,
4708            covariance_conditional: Some(Array2::zeros((p, p))),
4709            covariance_corrected: Some(Array2::zeros((p, p))),
4710            inference: None,
4711            fitted_link: FittedLinkState::Standard(None),
4712            geometry: None,
4713            block_states: vec![],
4714            beta,
4715            pirls_status: PirlsStatus::Converged,
4716            max_abs_eta: 0.0,
4717            constraint_kkt: None,
4718            artifacts: FitArtifacts {
4719                pirls: None,
4720                null_space_logdet: None,
4721                null_space_dim: None,
4722                survival_link_wiggle_knots: None,
4723                survival_link_wiggle_degree: None,
4724                criterion_certificate: None,
4725                rho_posterior_certificate: None,
4726                rho_posterior_escalation: None,
4727                rho_covariance: None,
4728                joint_log_lambdas: None,
4729            },
4730            inner_cycles: 0,
4731        }
4732    }
4733
4734    fn marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4735        let mut payload = FittedModelPayload::new(
4736            version,
4737            "y ~ 1".to_string(),
4738            ModelKind::MarginalSlope,
4739            FittedFamily::MarginalSlope {
4740                likelihood: LikelihoodSpec::binomial_probit(),
4741                base_link: InverseLink::Standard(StandardLink::Probit),
4742                frailty: FrailtySpec::None,
4743            },
4744            "bernoulli-marginal-slope".to_string(),
4745        );
4746        payload.fit_result = Some(fit.clone());
4747        payload.unified = Some(fit);
4748        payload.data_schema = Some(DataSchema {
4749            columns: vec![SchemaColumn {
4750                name: "z".to_string(),
4751                kind: ColumnKindTag::Continuous,
4752                levels: vec![],
4753            }],
4754        });
4755        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4756        payload.resolved_termspec = Some(empty_termspec());
4757        payload.resolved_termspec_logslope = Some(empty_termspec());
4758        payload.formula_logslope = Some("1".to_string());
4759        payload.z_column = Some("z".to_string());
4760        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4761        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4762        payload.marginal_baseline = Some(0.0);
4763        payload.logslope_baseline = Some(0.0);
4764        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4765        payload
4766    }
4767
4768    #[test]
4769    fn from_payload_synchronizes_used_device_from_saved_fit() {
4770        let mut fit = saved_fit(vec![
4771            FittedBlock {
4772                beta: Array1::from_vec(vec![0.25]),
4773                role: BlockRole::Mean,
4774                edf: 1.0,
4775                lambdas: Array1::zeros(0),
4776            },
4777            FittedBlock {
4778                beta: Array1::from_vec(vec![0.5]),
4779                role: BlockRole::Scale,
4780                edf: 1.0,
4781                lambdas: Array1::zeros(0),
4782            },
4783        ]);
4784        fit.used_device = true;
4785        let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4786        payload.used_device = false;
4787
4788        let model = FittedModel::from_payload(payload);
4789
4790        assert!(model.payload().used_device);
4791    }
4792
4793    fn survival_marginal_slope_payload(version: u32, fit: UnifiedFitResult) -> FittedModelPayload {
4794        let mut payload = FittedModelPayload::new(
4795            version,
4796            "Surv(entry, exit, event) ~ 1".to_string(),
4797            ModelKind::Survival,
4798            FittedFamily::Survival {
4799                likelihood: LikelihoodSpec::royston_parmar(),
4800                survival_likelihood: Some("marginal-slope".to_string()),
4801                survival_distribution: Some(ResidualDistribution::Gaussian),
4802                frailty: FrailtySpec::None,
4803            },
4804            "survival".to_string(),
4805        );
4806        payload.fit_result = Some(fit.clone());
4807        payload.unified = Some(fit);
4808        payload.survival_likelihood = Some("marginal-slope".to_string());
4809        payload.survival_distribution = Some(ResidualDistribution::Gaussian);
4810        payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
4811        payload.data_schema = Some(DataSchema {
4812            columns: vec![SchemaColumn {
4813                name: "z".to_string(),
4814                kind: ColumnKindTag::Continuous,
4815                levels: vec![],
4816            }],
4817        });
4818        payload.set_training_feature_metadata(vec!["z".to_string()], vec![(0.0, 0.0)]);
4819        payload.resolved_termspec = Some(empty_termspec());
4820        payload.resolved_termspec_logslope = Some(empty_termspec());
4821        payload.formula_logslope = Some("1".to_string());
4822        payload.z_column = Some("z".to_string());
4823        payload.latent_z_normalization = Some(SavedLatentZNormalization { mean: 0.0, sd: 1.0 });
4824        payload.logslope_baseline = Some(0.0);
4825        payload.link = Some(InverseLink::Standard(StandardLink::Probit));
4826        payload
4827    }
4828
4829    fn gamma_dispersion_location_scale_payload() -> FittedModelPayload {
4830        // A #913 genuine-dispersion location-scale model: Gamma mean family with
4831        // a log-precision `noise_formula` channel. Its likelihood response is
4832        // non-Gaussian and non-Binomial, so the predict-path classifier must
4833        // route it to `DispersionLocationScale`, NOT the binomial threshold-scale
4834        // class (issue #1064).
4835        let mut payload = FittedModelPayload::new(
4836            MODEL_PAYLOAD_VERSION,
4837            "y ~ x".to_string(),
4838            ModelKind::LocationScale,
4839            FittedFamily::LocationScale {
4840                likelihood: LikelihoodSpec::gamma_log(),
4841                base_link: Some(InverseLink::Standard(StandardLink::Log)),
4842            },
4843            "gamma-location-scale".to_string(),
4844        );
4845        payload.data_schema = Some(DataSchema {
4846            columns: vec![
4847                SchemaColumn {
4848                    name: "y".to_string(),
4849                    kind: ColumnKindTag::Continuous,
4850                    levels: vec![],
4851                },
4852                SchemaColumn {
4853                    name: "x".to_string(),
4854                    kind: ColumnKindTag::Continuous,
4855                    levels: vec![],
4856                },
4857            ],
4858        });
4859        payload.set_training_feature_metadata(vec!["x".to_string()], vec![(-1.0, 1.0)]);
4860        payload.resolved_termspec = Some(empty_termspec());
4861        payload.resolved_termspec_noise = Some(empty_termspec());
4862        payload.formula_noise = Some("x".to_string());
4863        payload.beta_noise = Some(vec![0.0]);
4864        payload.link = Some(InverseLink::Standard(StandardLink::Log));
4865        payload
4866    }
4867
4868    /// #1064 regression: a dispersion location-scale (#913) payload must be
4869    /// classified as `DispersionLocationScale` at every predict-path entry —
4870    /// both `from_payload` (load) and `predict_model_class` (runtime) — and never
4871    /// fall through to the binomial threshold-scale class. Before the fix the
4872    /// non-Gaussian `else` arm mis-routed every dispersion model to
4873    /// `BinomialLocationScale`, predicting the wrong family/link.
4874    #[test]
4875    fn dispersion_location_scale_payload_is_not_classified_binomial() {
4876        let model = FittedModel::from_payload(gamma_dispersion_location_scale_payload());
4877        assert_eq!(
4878            model.predict_model_class(),
4879            PredictModelClass::DispersionLocationScale,
4880            "Gamma dispersion location-scale must route through the dispersion \
4881             predictor, not the binomial threshold-scale class",
4882        );
4883        assert!(
4884            !matches!(
4885                model.predict_model_class(),
4886                PredictModelClass::BinomialLocationScale
4887            ),
4888            "dispersion location-scale must never be classified as binomial",
4889        );
4890
4891        // Each of the four #913 dispersion mean families classifies the same way.
4892        for likelihood in [
4893            LikelihoodSpec::gamma_log(),
4894            LikelihoodSpec::new(
4895                ResponseFamily::NegativeBinomial {
4896                    theta: 1.0,
4897                    theta_fixed: false,
4898                },
4899                InverseLink::Standard(StandardLink::Log),
4900            ),
4901            LikelihoodSpec::new(
4902                ResponseFamily::Beta { phi: 1.0 },
4903                InverseLink::Standard(StandardLink::Logit),
4904            ),
4905            LikelihoodSpec::new(
4906                ResponseFamily::Tweedie { p: 1.5 },
4907                InverseLink::Standard(StandardLink::Log),
4908            ),
4909        ] {
4910            let mut payload = gamma_dispersion_location_scale_payload();
4911            payload.family_state = FittedFamily::LocationScale {
4912                base_link: Some(likelihood.link.clone()),
4913                likelihood: likelihood.clone(),
4914            };
4915            let model = FittedModel::from_payload(payload);
4916            assert_eq!(
4917                model.predict_model_class(),
4918                PredictModelClass::DispersionLocationScale,
4919                "dispersion family {:?} mis-classified",
4920                likelihood.response,
4921            );
4922        }
4923    }
4924
4925    #[test]
4926    fn axis_clip_leaves_numeric_random_effect_group_axis_unclipped() {
4927        let data = array![[100.0], [-100.0]];
4928        let col_map = HashMap::from([("g".to_string(), 0usize)]);
4929
4930        let mut plain_payload = standard_gaussian_payload();
4931        plain_payload.data_schema = Some(DataSchema {
4932            columns: vec![SchemaColumn {
4933                name: "g".to_string(),
4934                kind: ColumnKindTag::Continuous,
4935                levels: vec![],
4936            }],
4937        });
4938        plain_payload.set_training_feature_metadata(vec!["g".to_string()], vec![(0.0, 7.0)]);
4939        plain_payload.resolved_termspec = Some(empty_termspec());
4940        let plain = FittedModel::from_payload(plain_payload.clone());
4941        let clipped = plain
4942            .axis_clip_to_training_ranges(data.view(), &col_map)
4943            .expect("ordinary continuous axis should clip outside the training range");
4944        assert_eq!(clipped.column(0).to_vec(), vec![7.0, 0.0]);
4945
4946        let mut group_payload = plain_payload;
4947        let mut group_spec = empty_termspec();
4948        group_spec
4949            .random_effect_terms
4950            .push(gam_terms::smooth::RandomEffectTermSpec {
4951                name: "g".to_string(),
4952                feature_col: 0,
4953                drop_first_level: false,
4954                penalized: true,
4955                frozen_levels: Some(vec![0.0_f64.to_bits(), 7.0_f64.to_bits()]),
4956            });
4957        group_payload.resolved_termspec = Some(group_spec);
4958        let group_model = FittedModel::from_payload(group_payload);
4959
4960        assert_eq!(
4961            group_model.random_effect_group_columns(),
4962            HashSet::from(["g".to_string()])
4963        );
4964
4965        assert_eq!(
4966            group_model.axis_clip_to_training_ranges(data.view(), &col_map),
4967            None,
4968            "numeric group labels must reach RandomEffectOperator as unseen levels, not be clipped to boundary seen levels"
4969        );
4970    }
4971
4972    #[test]
4973    fn validate_for_persistence_rejects_marginal_slope_score_warp_basis_mismatch() {
4974        let fit = saved_fit(vec![
4975            FittedBlock {
4976                beta: array![0.1],
4977                role: BlockRole::Mean,
4978                edf: 1.0,
4979                lambdas: Array1::zeros(0),
4980            },
4981            FittedBlock {
4982                beta: array![0.2],
4983                role: BlockRole::Scale,
4984                edf: 1.0,
4985                lambdas: Array1::zeros(0),
4986            },
4987            FittedBlock {
4988                beta: array![0.3],
4989                role: BlockRole::Mean,
4990                edf: 1.0,
4991                lambdas: Array1::zeros(0),
4992            },
4993        ]);
4994        let mut payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
4995        payload.score_warp_runtime = Some(anchored_runtime(2));
4996
4997        let err = FittedModel::from_payload(payload)
4998            .validate_for_persistence()
4999            .expect_err("marginal-slope score-warp basis mismatch should fail validation");
5000        assert!(err.to_string().contains("score-warp coefficient mismatch"));
5001    }
5002
5003    #[test]
5004    fn saved_prediction_runtime_rejects_survival_marginal_slope_link_basis_mismatch() {
5005        let fit = saved_fit(vec![
5006            FittedBlock {
5007                beta: array![0.1],
5008                role: BlockRole::Time,
5009                edf: 1.0,
5010                lambdas: Array1::zeros(0),
5011            },
5012            FittedBlock {
5013                beta: array![0.2],
5014                role: BlockRole::Mean,
5015                edf: 1.0,
5016                lambdas: Array1::zeros(0),
5017            },
5018            FittedBlock {
5019                beta: array![0.3],
5020                role: BlockRole::Scale,
5021                edf: 1.0,
5022                lambdas: Array1::zeros(0),
5023            },
5024            FittedBlock {
5025                beta: array![0.4],
5026                role: BlockRole::LinkWiggle,
5027                edf: 1.0,
5028                lambdas: Array1::zeros(0),
5029            },
5030        ]);
5031        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5032        payload.link_deviation_runtime = Some(anchored_runtime(2));
5033
5034        let err = FittedModel::from_payload(payload)
5035            .saved_prediction_runtime()
5036            .expect_err(
5037                "survival marginal-slope link basis mismatch should fail runtime validation",
5038            );
5039        assert!(
5040            err.to_string()
5041                .contains("link-deviation coefficient mismatch")
5042        );
5043    }
5044
5045    #[test]
5046    fn apply_survival_time_basis_writes_all_required_fields() {
5047        use crate::survival::construction::SavedSurvivalTimeBasis;
5048
5049        let fit = saved_fit(vec![
5050            FittedBlock {
5051                beta: array![0.1],
5052                role: BlockRole::Time,
5053                edf: 1.0,
5054                lambdas: Array1::zeros(0),
5055            },
5056            FittedBlock {
5057                beta: array![0.2],
5058                role: BlockRole::Mean,
5059                edf: 1.0,
5060                lambdas: Array1::zeros(0),
5061            },
5062            FittedBlock {
5063                beta: array![0.3],
5064                role: BlockRole::Scale,
5065                edf: 1.0,
5066                lambdas: Array1::zeros(0),
5067            },
5068        ]);
5069        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5070
5071        // Snapshot writes must match every persisted survival_time_* field —
5072        // forgetting one is exactly the marginal-slope save
5073        // regression. Routing through `apply_survival_time_basis` is the
5074        // structural contract that prevents that recurrence.
5075        let snapshot = SavedSurvivalTimeBasis {
5076            basisname: "royston-parmar".to_string(),
5077            degree: Some(3),
5078            knots: Some(vec![0.0, 1.0, 2.0]),
5079            keep_cols: Some(vec![0, 2]),
5080            smooth_lambda: Some(0.5),
5081            anchor: 0.25,
5082        };
5083        payload.apply_survival_time_basis(&snapshot);
5084
5085        assert_eq!(
5086            payload.survival_time_basis.as_deref(),
5087            Some("royston-parmar")
5088        );
5089        assert_eq!(payload.survival_time_degree, Some(3));
5090        assert_eq!(payload.survival_time_knots, Some(vec![0.0, 1.0, 2.0]));
5091        assert_eq!(payload.survival_time_keep_cols, Some(vec![0, 2]));
5092        assert_eq!(payload.survival_time_smooth_lambda, Some(0.5));
5093        assert_eq!(payload.survival_time_anchor, Some(0.25));
5094    }
5095
5096    #[test]
5097    fn validate_for_persistence_rejects_survival_without_time_anchor_metadata() {
5098        let fit = saved_fit(vec![
5099            FittedBlock {
5100                beta: array![0.1],
5101                role: BlockRole::Time,
5102                edf: 1.0,
5103                lambdas: Array1::zeros(0),
5104            },
5105            FittedBlock {
5106                beta: array![0.2],
5107                role: BlockRole::Mean,
5108                edf: 1.0,
5109                lambdas: Array1::zeros(0),
5110            },
5111            FittedBlock {
5112                beta: array![0.3],
5113                role: BlockRole::Scale,
5114                edf: 1.0,
5115                lambdas: Array1::zeros(0),
5116            },
5117        ]);
5118        let mut payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5119        // Pass the time_basis presence check but deliberately omit the
5120        // anchor — this is exactly the partial-write shape that the CLI's
5121        // marginal-slope+time-wiggle save path had before the structural
5122        // refactor (main.rs previously set basis/degree/knots/keep_cols/
5123        // smooth_lambda but forgot the anchor).
5124        payload.survival_time_basis = Some("ispline".to_string());
5125
5126        let err = FittedModel::from_payload(payload)
5127            .validate_for_persistence()
5128            .expect_err("survival model without time-anchor metadata should fail validation");
5129        assert!(err.to_string().contains("missing survival_time_anchor"));
5130    }
5131
5132    #[test]
5133    fn validate_for_persistence_rejects_survival_without_time_basis_metadata() {
5134        let fit = saved_fit(vec![
5135            FittedBlock {
5136                beta: array![0.1],
5137                role: BlockRole::Time,
5138                edf: 1.0,
5139                lambdas: Array1::zeros(0),
5140            },
5141            FittedBlock {
5142                beta: array![0.2],
5143                role: BlockRole::Mean,
5144                edf: 1.0,
5145                lambdas: Array1::zeros(0),
5146            },
5147            FittedBlock {
5148                beta: array![0.3],
5149                role: BlockRole::Scale,
5150                edf: 1.0,
5151                lambdas: Array1::zeros(0),
5152            },
5153        ]);
5154        let payload = survival_marginal_slope_payload(MODEL_PAYLOAD_VERSION, fit);
5155
5156        let err = FittedModel::from_payload(payload)
5157            .validate_for_persistence()
5158            .expect_err("survival model without time-basis metadata should fail validation");
5159        assert!(err.to_string().contains("missing survival_time_basis"));
5160    }
5161
5162    #[test]
5163    fn saved_prediction_runtime_rejects_stale_payload_version() {
5164        let fit = saved_fit(vec![
5165            FittedBlock {
5166                beta: array![0.1],
5167                role: BlockRole::Mean,
5168                edf: 1.0,
5169                lambdas: Array1::zeros(0),
5170            },
5171            FittedBlock {
5172                beta: array![0.2],
5173                role: BlockRole::Scale,
5174                edf: 1.0,
5175                lambdas: Array1::zeros(0),
5176            },
5177        ]);
5178        let payload = marginal_slope_payload(MODEL_PAYLOAD_VERSION - 1, fit);
5179
5180        let err = FittedModel::from_payload(payload)
5181            .saved_prediction_runtime()
5182            .expect_err("stale payload version should fail before runtime assembly");
5183        assert!(err.to_string().contains("payload schema mismatch"));
5184    }
5185}