Skip to main content

gam_models/inference/
model_payload_builders.rs

//! Shared, source-agnostic builders for saved-model payloads.
//!
//! The CLI (`src/main.rs`) and the Python FFI (`crates/gam-pyffi/src/lib.rs`)
//! both persist fitted models, and both used to assemble the serialized
//! [`FittedModelPayload`] independently. That meant the on-disk contract for a
//! given model kind could silently drift depending on whether the model was
//! created through the CLI or through Python — exactly the failure mode that
//! repeatedly bit the marginal-slope save→load path.
//!
//! This module assembles the *semantic* payload exactly once. Each caller is
//! responsible only for the source-specific work of producing the resolved
//! semantic inputs (the CLI threads them through from its argument parsing and
//! fit pipeline; the FFI freezes term collections from designs and re-derives
//! metadata from the [`FitConfig`]). Once both sides hand the same semantic
//! content to the same assembler, payload drift becomes impossible by
//! construction.
17
18use gam_solve::estimate::UnifiedFitResult;
19use crate::bms::deviation_runtime::AnchorComponentTag;
20use crate::bms::{
21    DeviationRuntime, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
22};
23use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
24use crate::scale_design::ScaleDeviationTransform;
25use crate::survival::construction::{
26    SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
27};
28use crate::survival::location_scale::{
29    ResidualDistribution, residual_distribution_from_inverse_link,
30};
31use crate::transformation_normal::TransformationNormalFamily;
32use crate::inference::model::{
33    FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind, SavedAnchorComponent,
34    SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization, SavedResidualCascade,
35    SavedSplineScan, TransformationScoreCalibration,
36};
37use gam_terms::smooth::TermCollectionSpec;
38use gam_problem::types::{
39    InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
40};
41use gam_data::DataSchema;
42use ndarray::Array2;
43
/// Family-name string written into saved Bernoulli marginal-slope models.
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";

/// Family-name string written into saved transformation-normal models.
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
49
50/// Serialize an anchored-deviation [`DeviationRuntime`] (score-warp or
51/// link-deviation block) into its persistable [`SavedCompiledFlexBlock`] form.
52///
53/// This is the single source of truth for that conversion; the CLI and FFI
54/// payload builders both route through it so the serialized flex contract
55/// cannot diverge between the two save paths.
56pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
57    let mut anchor_correction: Option<Vec<Vec<f64>>> = None;
58    let mut anchor_components: Vec<SavedAnchorComponent> = Vec::new();
59    if let Some(installed) = runtime.installed_flex_block() {
60        anchor_correction = Some(
61            installed
62                .anchor_correction
63                .rows()
64                .into_iter()
65                .map(|row| row.to_vec())
66                .collect::<Vec<Vec<f64>>>(),
67        );
68        for component in &installed.anchor_components {
69            anchor_components.push(SavedAnchorComponent {
70                kind: match component {
71                    AnchorComponentTag::Parametric { block, ncols } => {
72                        SavedAnchorKind::Parametric {
73                            block: *block,
74                            ncols: *ncols,
75                        }
76                    }
77                    AnchorComponentTag::FlexEvaluation { ncols } => {
78                        SavedAnchorKind::FlexEvaluation { ncols: *ncols }
79                    }
80                },
81            });
82        }
83    }
84    SavedCompiledFlexBlock {
85        kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
86        breakpoints: runtime.breakpoints().to_vec(),
87        basis_dim: runtime.basis_dim(),
88        span_c0: runtime
89            .span_c0()
90            .rows()
91            .into_iter()
92            .map(|row| row.to_vec())
93            .collect(),
94        span_c1: runtime
95            .span_c1()
96            .rows()
97            .into_iter()
98            .map(|row| row.to_vec())
99            .collect(),
100        span_c2: runtime
101            .span_c2()
102            .rows()
103            .into_iter()
104            .map(|row| row.to_vec())
105            .collect(),
106        span_c3: runtime
107            .span_c3()
108            .rows()
109            .into_iter()
110            .map(|row| row.to_vec())
111            .collect(),
112        anchor_correction,
113        anchor_components,
114    }
115}
116
/// Metadata that every saved payload carries but that the CLI and the FFI
/// obtain in different ways.
///
/// `training_feature_ranges` is an `Option` on purpose: the FFI path currently
/// persists headers without per-feature ranges, and `None` records "ranges
/// unknown" explicitly rather than disguising it as an empty vector.
pub struct SavedModelSourceMetadata {
    /// Column headers of the training data.
    pub training_headers: Vec<String>,
    /// Per-feature `(f64, f64)` ranges, when the source can supply them.
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Name of the offset column, if any.
    pub offset_column: Option<String>,
    /// Name of the noise-offset column, if any.
    pub noise_offset_column: Option<String>,
}
130
131impl SavedModelSourceMetadata {
132    fn apply_to(self, payload: &mut FittedModelPayload) {
133        match self.training_feature_ranges {
134            Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
135            None => payload.training_headers = Some(self.training_headers),
136        }
137        payload.offset_column = self.offset_column;
138        payload.noise_offset_column = self.noise_offset_column;
139    }
140}
141
142/// The resolved, source-agnostic semantic content of a Bernoulli
143/// marginal-slope saved model.
144///
145/// The CLI threads these in directly from its fit pipeline; the FFI produces
146/// them by freezing its term collections and reading the [`FitConfig`]. Either
147/// way, the assembler below turns them into the canonical payload.
148pub struct BernoulliMarginalSlopeInputs<'a> {
149    pub formula: String,
150    pub data_schema: DataSchema,
151    pub logslope_formula: String,
152    pub z_column: String,
153    pub resolved_marginalspec: TermCollectionSpec,
154    pub resolved_logslopespec: TermCollectionSpec,
155    pub fit_result: UnifiedFitResult,
156    /// Number of *raw* marginal design columns `p_m` (= the term-collection
157    /// marginal design's `ncols()` BEFORE any #461 influence-absorber widening).
158    ///
159    /// When the Stage-1 influence absorber is active (A2), the fitted marginal
160    /// block carries the widened coefficient `[β_m; γ]` (length `p_m + p₁`) and
161    /// the joint covariance is dimensioned over the widened block. The absorbed
162    /// influence columns `Z̃_infl` are a TRAINING-only leakage absorber that does
163    /// not exist at predict rows, so the persisted model must drop `γ` and the
164    /// marginalized-out covariance sub-block to stay self-consistent against the
165    /// raw `p_m` marginal design at predict. The assembler uses this to truncate
166    /// the fit result once (shared CLI + FFI). With no absorber it equals the
167    /// fitted block width and the truncation is a no-op.
168    pub p_marginal: usize,
169    pub baseline_marginal: f64,
170    pub baseline_logslope: f64,
171    pub latent_z_normalization: SavedLatentZNormalization,
172    pub latent_measure: LatentMeasureKind,
173    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
174    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
175    pub score_warp_runtime: Option<&'a DeviationRuntime>,
176    pub link_dev_runtime: Option<&'a DeviationRuntime>,
177    pub base_link: InverseLink,
178    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
179}
180
181/// Drop the #461 training-only influence-absorber coefficients `γ` from a fitted
182/// Bernoulli marginal-slope result so the persisted model is self-consistent
183/// against the raw `p_m`-column marginal design at predict.
184///
185/// When the A2 influence absorber is active the marginal block (block 0) is the
186/// widened `[β_m; γ]` (length `p_m + p₁`, with `γ` the contiguous trailing `p₁`
187/// columns — see bms `widen_marginal_dense_with_influence`) and the joint
188/// conditional covariance is dimensioned over the widened joint coefficient
189/// vector. The absorbed columns `Z̃_infl` exist only at training rows; predict
190/// reconstructs the marginal index from the raw `p_m` design and the
191/// orthogonalized `β̂_m` is a property of the training fit. So this:
192///
193///  * slices `blocks[0].beta` and `block_states[0].beta` to their first `p_m`
194///    entries (the flat `beta` is recomputed from the blocks by
195///    `try_from_parts`),
196///  * **marginalizes** `γ` out of the joint Gaussian by dropping the `γ`
197///    rows/cols from the conditional covariance — taking the corresponding
198///    SUB-BLOCK of `Σ` is the exact marginal of a joint Gaussian (no
199///    re-inversion), so the kept `[β_m | β_logslope | …]` covariance is the
200///    correct predictive uncertainty accounting for the fitted absorber,
201///  * drops the persisted joint penalized-Hessian geometry: it is a precision
202///    over the *widened* joint coefficient vector, so a sub-block would be the
203///    wrong marginalization, and the only predict path that consumes it is the
204///    covariance-fallback that re-inverts `H` — which post-truncation would have
205///    the wrong dimension anyway. With the dense (already-marginalized) `Σ`
206///    matching the predict dimension, that fallback is never taken, so dropping
207///    the geometry removes a stale, wrong-dimension path rather than a used one.
208///
209/// Block-level `edf` / `lambdas` are left untouched: they are fitted scalars
210/// that legitimately reflect the full model (the absorber consumed real dof at
211/// fit time) and are persisted as-is. With no absorber (`block0.len() == p_m`)
212/// this is a no-op clone.
213fn truncate_marginal_slope_influence_absorber(
214    fit_result: UnifiedFitResult,
215    p_marginal: usize,
216) -> Result<UnifiedFitResult, String> {
217    let Some(block0) = fit_result.blocks.first() else {
218        return Err("marginal-slope fit result has no coefficient blocks".to_string());
219    };
220    let widened_len = block0.beta.len();
221    if widened_len <= p_marginal {
222        // No influence absorber installed (or already raw width): nothing to drop.
223        return Ok(fit_result);
224    }
225    let p_influence = widened_len - p_marginal;
226
227    let UnifiedFitResult {
228        mut blocks,
229        log_lambdas,
230        lambdas,
231        likelihood_family,
232        likelihood_scale,
233        log_likelihood_normalization,
234        log_likelihood,
235        deviance,
236        reml_score,
237        stable_penalty_term,
238        penalized_objective,
239        used_device,
240        outer_iterations,
241        outer_converged,
242        outer_gradient_norm,
243        standard_deviation,
244        covariance_conditional,
245        covariance_corrected,
246        inference,
247        fitted_link,
248        geometry: _,
249        mut block_states,
250        beta: _,
251        pirls_status,
252        max_abs_eta,
253        constraint_kkt,
254        artifacts,
255        inner_cycles,
256        outer_cost_evals: _,
257        inner_pirls_solves: _,
258    } = fit_result;
259
260    // Slice block 0's coefficients (and matching block-state) to the raw p_m,
261    // dropping the trailing γ absorber columns.
262    blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
263    if let Some(state0) = block_states.first_mut() {
264        state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
265    }
266
267    // Marginalize γ out of the joint conditional covariance: keep every index
268    // except the contiguous γ block [p_marginal, p_marginal + p_influence).
269    let drop_gamma_block = |cov: Option<Array2<f64>>| -> Option<Array2<f64>> {
270        cov.map(|cov| {
271            let total = cov.nrows();
272            let kept: Vec<usize> = (0..p_marginal)
273                .chain((p_marginal + p_influence)..total)
274                .collect();
275            let mut out = Array2::<f64>::zeros((kept.len(), kept.len()));
276            for (ri, &r) in kept.iter().enumerate() {
277                for (ci, &c) in kept.iter().enumerate() {
278                    out[[ri, ci]] = cov[[r, c]];
279                }
280            }
281            out
282        })
283    };
284    let covariance_conditional = drop_gamma_block(covariance_conditional);
285    let covariance_corrected = drop_gamma_block(covariance_corrected);
286
287    UnifiedFitResult::try_from_parts(gam_solve::estimate::UnifiedFitResultParts {
288        blocks,
289        log_lambdas,
290        lambdas,
291        likelihood_family,
292        likelihood_scale,
293        log_likelihood_normalization,
294        log_likelihood,
295        deviance,
296        reml_score,
297        stable_penalty_term,
298        penalized_objective,
299        // Preserve the GPU-execution flag across the absorber-column
300        // truncation: dropping the trailing γ columns does not change which
301        // device ran the solve.
302        used_device,
303        outer_iterations,
304        outer_converged,
305        outer_gradient_norm,
306        standard_deviation,
307        covariance_conditional,
308        covariance_corrected,
309        inference,
310        fitted_link,
311        // Drop the widened-joint penalized Hessian: see the doc comment.
312        geometry: None,
313        block_states,
314        pirls_status,
315        max_abs_eta,
316        constraint_kkt,
317        artifacts,
318        inner_cycles,
319    })
320    .map_err(|e| {
321        format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
322    })
323}
324
325/// Assemble the canonical spline-scan payload (#1030/#1034): a standard
326/// Gaussian-identity model whose fit representation is the exact O(n)
327/// smoothing-spline smoother state instead of a dense `fit_result`. The CLI
328/// and FFI save paths both route through here so the scan on-disk contract
329/// cannot diverge between sources.
330pub fn assemble_spline_scan_payload(
331    formula: String,
332    feature_column: String,
333    fit: &gam_solve::spline_scan::SplineScanFit,
334    data_schema: DataSchema,
335    training_headers: Vec<String>,
336    training_feature_ranges: Vec<(f64, f64)>,
337) -> FittedModelPayload {
338    let mut payload = FittedModelPayload::new(
339        MODEL_PAYLOAD_VERSION,
340        formula,
341        ModelKind::Standard,
342        FittedFamily::Standard {
343            likelihood: LikelihoodSpec::gaussian_identity(),
344            link: None,
345            latent_cloglog_state: None,
346            mixture_state: None,
347            sas_state: None,
348        },
349        "gaussian".to_string(),
350    );
351    payload.spline_scan = Some(SavedSplineScan {
352        feature_column,
353        state: fit.to_state(),
354    });
355    payload.data_schema = Some(data_schema);
356    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
357    payload
358}
359
360/// Assemble the canonical residual-cascade payload (#1032).
361///
362/// The CLI and FFI save paths both route through here so the cascade on-disk
363/// contract cannot diverge between sources.  Mirrors `assemble_spline_scan_payload`
364/// but for d ∈ {2,3} scattered coordinates (the Wendland multilevel-frame state).
365pub fn assemble_residual_cascade_payload(
366    formula: String,
367    feature_columns: Vec<String>,
368    fit: &gam_solve::residual_cascade::ResidualCascadeFit,
369    data_schema: DataSchema,
370    training_headers: Vec<String>,
371    training_feature_ranges: Vec<(f64, f64)>,
372) -> Result<FittedModelPayload, String> {
373    let mut payload = FittedModelPayload::new(
374        MODEL_PAYLOAD_VERSION,
375        formula,
376        ModelKind::Standard,
377        FittedFamily::Standard {
378            likelihood: gam_problem::types::LikelihoodSpec::gaussian_identity(),
379            link: None,
380            latent_cloglog_state: None,
381            mixture_state: None,
382            sas_state: None,
383        },
384        "gaussian".to_string(),
385    );
386    payload.residual_cascade = Some(SavedResidualCascade {
387        feature_columns,
388        state: fit.to_state().map_err(|e| {
389            format!("residual-cascade to_state failed during payload assembly: {e}")
390        })?,
391    });
392    payload.data_schema = Some(data_schema);
393    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
394    Ok(payload)
395}
396
397/// Assemble the canonical Bernoulli marginal-slope payload.
398///
399/// This is the single place that decides which payload fields a marginal-slope
400/// model carries and how the singular/vector mirror fields
401/// (`formula_logslope(s)`, `z_column(s)`, `logslope_baseline(s)`,
402/// `resolved_termspec_logslope(s)`) are kept consistent — so the CLI and FFI
403/// saved models are byte-equivalent for identical semantic content.
404pub fn assemble_bernoulli_marginal_slope_payload(
405    inputs: BernoulliMarginalSlopeInputs<'_>,
406    source: SavedModelSourceMetadata,
407) -> Result<FittedModelPayload, String> {
408    let BernoulliMarginalSlopeInputs {
409        formula,
410        data_schema,
411        logslope_formula,
412        z_column,
413        resolved_marginalspec,
414        resolved_logslopespec,
415        fit_result,
416        p_marginal,
417        baseline_marginal,
418        baseline_logslope,
419        latent_z_normalization,
420        latent_measure,
421        latent_z_rank_int_calibration,
422        latent_z_conditional_calibration,
423        score_warp_runtime,
424        link_dev_runtime,
425        base_link,
426        frailty,
427    } = inputs;
428
429    // #461 predict seam: drop the training-only influence-absorber γ (and
430    // marginalize it out of the covariance) so the persisted model matches the
431    // raw p_m marginal design at predict. No-op when the absorber is inactive.
432    let fit_result = truncate_marginal_slope_influence_absorber(fit_result, p_marginal)?;
433
434    let marginal_likelihood_spec =
435        inverse_link_to_binomial_spec(&base_link).map_err(|e| e.to_string())?;
436
437    let mut payload = FittedModelPayload::new(
438        MODEL_PAYLOAD_VERSION,
439        formula,
440        ModelKind::MarginalSlope,
441        FittedFamily::MarginalSlope {
442            likelihood: marginal_likelihood_spec,
443            base_link: base_link.clone(),
444            frailty,
445        },
446        FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
447    );
448    payload.unified = Some(fit_result.clone());
449    payload.fit_result = Some(fit_result);
450    payload.data_schema = Some(data_schema);
451    payload.formula_logslope = Some(logslope_formula.clone());
452    payload.z_column = Some(z_column.clone());
453    payload.formula_logslopes = Some(vec![logslope_formula]);
454    payload.z_columns = Some(vec![z_column]);
455    payload.latent_z_normalization = Some(latent_z_normalization);
456    payload.latent_measure = Some(latent_measure);
457    payload.latent_z_rank_int_calibration = latent_z_rank_int_calibration;
458    payload.latent_z_conditional_calibration = latent_z_conditional_calibration;
459    payload.marginal_baseline = Some(baseline_marginal);
460    payload.logslope_baseline = Some(baseline_logslope);
461    payload.logslope_baselines = Some(vec![baseline_logslope]);
462    payload.link = Some(base_link);
463    payload.resolved_termspec = Some(resolved_marginalspec);
464    payload.resolved_termspec_logslopes = Some(vec![resolved_logslopespec.clone()]);
465    payload.resolved_termspec_logslope = Some(resolved_logslopespec);
466    payload.score_warp_runtime = score_warp_runtime.map(serialize_anchored_deviation_runtime);
467    payload.link_deviation_runtime = link_dev_runtime.map(serialize_anchored_deviation_runtime);
468    source.apply_to(&mut payload);
469    Ok(payload)
470}
471
472/// The resolved, source-agnostic semantic content of a transformation-normal
473/// saved model.
474///
475/// As with the marginal-slope inputs, the CLI threads the family and resolved
476/// covariate spec straight from its fit pipeline while the FFI reads them off
477/// its fit-result struct (freezing the covariate spec from its design first).
478pub struct TransformationNormalInputs<'a> {
479    pub formula: String,
480    pub data_schema: DataSchema,
481    pub resolved_covariate_spec: TermCollectionSpec,
482    pub fit_result: UnifiedFitResult,
483    pub family: &'a TransformationNormalFamily,
484    pub score_calibration: TransformationScoreCalibration,
485}
486
487/// Assemble the canonical transformation-normal payload.
488///
489/// Centralizing the response-transform snapshot (`knots`, `transform`,
490/// `degree`, `median`) and the fixed Gaussian-identity likelihood means the CLI
491/// and FFI cannot encode a transformation-normal model two different ways.
492pub fn assemble_transformation_normal_payload(
493    inputs: TransformationNormalInputs<'_>,
494    source: SavedModelSourceMetadata,
495) -> FittedModelPayload {
496    let TransformationNormalInputs {
497        formula,
498        data_schema,
499        resolved_covariate_spec,
500        fit_result,
501        family,
502        score_calibration,
503    } = inputs;
504
505    let mut payload = FittedModelPayload::new(
506        MODEL_PAYLOAD_VERSION,
507        formula,
508        ModelKind::TransformationNormal,
509        FittedFamily::TransformationNormal {
510            likelihood: LikelihoodSpec::new(
511                ResponseFamily::Gaussian,
512                InverseLink::Standard(StandardLink::Identity),
513            ),
514        },
515        FAMILY_TRANSFORMATION_NORMAL.to_string(),
516    );
517    payload.unified = Some(fit_result.clone());
518    payload.fit_result = Some(fit_result);
519    payload.data_schema = Some(data_schema);
520    payload.resolved_termspec = Some(resolved_covariate_spec);
521    payload.transformation_response_knots = Some(family.response_knots().to_vec());
522    payload.transformation_response_transform = Some(
523        family
524            .response_transform()
525            .rows()
526            .into_iter()
527            .map(|row| row.to_vec())
528            .collect(),
529    );
530    payload.transformation_response_degree = Some(family.response_degree());
531    payload.transformation_response_median = Some(family.response_median());
532    payload.transformation_score_calibration = Some(score_calibration);
533    source.apply_to(&mut payload);
534    payload
535}
536
537/// Which likelihood a (non-survival) location-scale model carries: Gaussian
538/// (residual response scale) or binomial (noise scale-deviation transform whose
539/// likelihood is resolved from the inverse link). The assembler resolves the
540/// `FittedFamily` from this once, rather than each save path stamping a
541/// (potentially wrong) likelihood and patching it afterwards.
542pub enum LocationScaleResponse<'a> {
543    /// Gaussian identity; `base_link` is the optional resolved base link the CLI
544    /// may pass through from `link(...)` (the FFI leaves it `None`).
545    Gaussian {
546        response_scale: f64,
547        base_link: Option<InverseLink>,
548    },
549    /// Binomial under `link`, with the encoded noise scale-deviation transform.
550    Binomial {
551        link: InverseLink,
552        noise_transform: &'a ScaleDeviationTransform,
553    },
554    /// A genuine-dispersion mean family (NegativeBinomial / Gamma / Beta /
555    /// Tweedie) whose log-precision channel carries `noise_formula` (#913). The
556    /// `likelihood` is the family's own [`LikelihoodSpec`]; `base_link` is the
557    /// mean inverse link (log, or logit for Beta). The log-precision block
558    /// coefficients ride in [`LocationScaleInputs::beta_noise`].
559    Dispersion {
560        likelihood: LikelihoodSpec,
561        base_link: InverseLink,
562        family_tag: &'static str,
563    },
564}
565
/// Link-wiggle metadata that may accompany a saved location-scale model.
///
/// Knots and coefficients arrive already expressed in raw response units (the
/// Gaussian standardization and its inverse remap happen inside
/// `fit_gaussian_location_scale_model`), so the save path stores them as-is.
pub struct LocationScaleWiggle {
    /// Link-wiggle knots, persisted as `linkwiggle_knots`.
    pub knots: Vec<f64>,
    /// Link-wiggle degree, persisted as `linkwiggle_degree`.
    pub degree: usize,
    /// Link-wiggle coefficients, persisted as `beta_link_wiggle`.
    pub beta_link_wiggle: Vec<f64>,
}
575
576/// Source-agnostic semantic content of a (non-survival) location-scale saved
577/// model — the shared core behind the CLI's Gaussian/binomial save paths and
578/// the FFI's two location-scale builders.
579pub struct LocationScaleInputs {
580    pub formula: String,
581    pub data_schema: DataSchema,
582    pub noise_formula: String,
583    pub resolved_termspec: TermCollectionSpec,
584    pub resolved_termspec_noise: TermCollectionSpec,
585    pub fit_result: UnifiedFitResult,
586    pub beta_noise: Option<Vec<f64>>,
587    pub wiggle: Option<LocationScaleWiggle>,
588}
589
590/// Assemble the canonical (non-survival) location-scale payload — single source
591/// of truth for that on-disk contract. The family/likelihood is resolved from
592/// the [`LocationScaleResponse`] so the binomial branch never persists a wrong
593/// probit likelihood that a caller must patch afterwards.
594pub fn assemble_location_scale_payload(
595    inputs: LocationScaleInputs,
596    response: LocationScaleResponse<'_>,
597    source: SavedModelSourceMetadata,
598) -> Result<FittedModelPayload, String> {
599    let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
600    {
601        LocationScaleResponse::Gaussian {
602            response_scale,
603            base_link,
604        } => (
605            "gaussian-location-scale".to_string(),
606            LikelihoodSpec::gaussian_identity(),
607            // Gaussian location-scale does not carry a base link in its family
608            // state; the resolved link is persisted in `payload.link` below so
609            // prediction can recover it.
610            None,
611            Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
612            Some(response_scale),
613            None,
614        ),
615        LocationScaleResponse::Binomial {
616            link,
617            noise_transform,
618        } => {
619            let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
620                format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
621            })?;
622            (
623                "binomial-location-scale".to_string(),
624                likelihood,
625                Some(link.clone()),
626                Some(link),
627                None,
628                Some(noise_transform),
629            )
630        }
631        LocationScaleResponse::Dispersion {
632            likelihood,
633            base_link,
634            family_tag,
635        } => (
636            family_tag.to_string(),
637            likelihood,
638            Some(base_link.clone()),
639            Some(base_link),
640            None,
641            None,
642        ),
643    };
644
645    let mut payload = FittedModelPayload::new(
646        MODEL_PAYLOAD_VERSION,
647        inputs.formula,
648        ModelKind::LocationScale,
649        FittedFamily::LocationScale {
650            likelihood,
651            base_link,
652        },
653        family_tag,
654    );
655    payload.unified = Some(inputs.fit_result.clone());
656    payload.fit_result = Some(inputs.fit_result);
657    payload.data_schema = Some(inputs.data_schema);
658    payload.link = link;
659    payload.formula_noise = Some(inputs.noise_formula);
660    payload.beta_noise = inputs.beta_noise;
661    payload.gaussian_response_scale = response_scale;
662    if let Some(transform) = noise_transform {
663        payload.noise_projection = Some(
664            transform
665                .projection_coef
666                .rows()
667                .into_iter()
668                .map(|row| row.to_vec())
669                .collect(),
670        );
671        payload.noise_center = Some(transform.weighted_column_mean.to_vec());
672        payload.noise_scale = Some(transform.rescale.to_vec());
673        payload.noise_non_intercept_start = Some(transform.non_intercept_start);
674        payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
675    }
676    payload.resolved_termspec = Some(inputs.resolved_termspec);
677    payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
678    if let Some(wiggle) = inputs.wiggle {
679        payload.linkwiggle_knots = Some(wiggle.knots);
680        payload.linkwiggle_degree = Some(wiggle.degree);
681        payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
682    }
683    source.apply_to(&mut payload);
684    Ok(payload)
685}
686
687/// Source-agnostic semantic content of a survival marginal-slope
688/// (Royston-Parmar net) saved model. Centralizing assembly also fixes the
689/// FFI's prior omission of the `*_logslopes`/`*_columns`/`formula_logslopes`
690/// vector mirrors the CLI wrote.
691pub struct SurvivalMarginalSlopeInputs<'a> {
692    pub formula: String,
693    pub data_schema: DataSchema,
694    pub fit_result: UnifiedFitResult,
695    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
696    pub survival_entry: Option<String>,
697    pub survival_exit: String,
698    pub survival_event: String,
699    pub survivalspec: String,
700    pub baseline_cfg: SurvivalBaselineConfig,
701    pub time_basis: SavedSurvivalTimeBasis,
702    pub ridge_lambda: f64,
703    pub survival_likelihood_label: String,
704    pub resolved_marginalspec: TermCollectionSpec,
705    pub resolved_logslopespec: TermCollectionSpec,
706    pub logslope_formula: String,
707    pub z_column: String,
708    pub latent_z_normalization: SavedLatentZNormalization,
709    pub baseline_logslope: f64,
710    pub score_warp_runtime: Option<&'a DeviationRuntime>,
711    pub link_dev_runtime: Option<&'a DeviationRuntime>,
712    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the fit
713    /// hosted a dedicated additive absorber. Predict drops the absorber's `γ`;
714    /// this is persisted only so the predictor accounts for the extra trailing
715    /// block in the saved block count.
716    pub influence_absorber_width: Option<usize>,
717}
718
719/// Construct a Royston-Parmar survival [`FittedModelPayload`] through the
720/// canonical `Survival` family scaffold shared by every RP on-disk contract
721/// (marginal-slope, transformation, location-scale): the identity-link
722/// `RoystonParmar` likelihood, the persisted likelihood label, and the
723/// `fit_result` / `data_schema` install. Callers supply the two variants that
724/// differ — `survival_distribution` and `frailty` — and then set their own
725/// family-specific fields on the returned payload.
726fn new_royston_parmar_survival_payload(
727    formula: String,
728    fit_result: UnifiedFitResult,
729    data_schema: DataSchema,
730    survival_likelihood_label: &str,
731    survival_distribution: Option<ResidualDistribution>,
732    frailty: crate::survival::lognormal_kernel::FrailtySpec,
733) -> FittedModelPayload {
734    let mut payload = FittedModelPayload::new(
735        MODEL_PAYLOAD_VERSION,
736        formula,
737        ModelKind::Survival,
738        FittedFamily::Survival {
739            likelihood: LikelihoodSpec::new(
740                ResponseFamily::RoystonParmar,
741                InverseLink::Standard(StandardLink::Identity),
742            ),
743            survival_likelihood: Some(survival_likelihood_label.to_string()),
744            survival_distribution,
745            frailty,
746        },
747        ResponseFamily::RoystonParmar.name().to_string(),
748    );
749    payload.unified = Some(fit_result.clone());
750    payload.fit_result = Some(fit_result);
751    payload.data_schema = Some(data_schema);
752    payload
753}
754
755/// Assemble the canonical survival marginal-slope payload — single source of
756/// truth for that Royston-Parmar / Gaussian-residual on-disk contract.
757pub fn assemble_survival_marginal_slope_payload(
758    inputs: SurvivalMarginalSlopeInputs<'_>,
759    source: SavedModelSourceMetadata,
760) -> FittedModelPayload {
761    let mut payload = new_royston_parmar_survival_payload(
762        inputs.formula,
763        inputs.fit_result,
764        inputs.data_schema,
765        &inputs.survival_likelihood_label,
766        Some(ResidualDistribution::Gaussian),
767        inputs.frailty,
768    );
769    payload.survival_entry = inputs.survival_entry;
770    payload.survival_exit = Some(inputs.survival_exit);
771    payload.survival_event = Some(inputs.survival_event);
772    payload.survivalspec = Some(inputs.survivalspec);
773    payload.survival_baseline_target =
774        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
775    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
776    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
777    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
778    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
779    payload.apply_survival_time_basis(&inputs.time_basis);
780    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
781    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
782    payload.survival_distribution = Some(ResidualDistribution::Gaussian);
783    payload.link = Some(InverseLink::Standard(StandardLink::Probit));
784    payload.resolved_termspec = Some(inputs.resolved_marginalspec);
785    payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
786    payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
787    payload.formula_logslope = Some(inputs.logslope_formula.clone());
788    payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
789    payload.z_column = Some(inputs.z_column.clone());
790    payload.z_columns = Some(vec![inputs.z_column]);
791    payload.latent_z_normalization = Some(inputs.latent_z_normalization);
792    payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
793    payload.logslope_baseline = Some(inputs.baseline_logslope);
794    payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
795    payload.score_warp_runtime = inputs
796        .score_warp_runtime
797        .map(serialize_anchored_deviation_runtime);
798    payload.link_deviation_runtime = inputs
799        .link_dev_runtime
800        .map(serialize_anchored_deviation_runtime);
801    payload.influence_absorber_width = inputs.influence_absorber_width;
802    source.apply_to(&mut payload);
803    payload
804}
805
/// Fitted baseline-timewiggle coefficients: a single block (net) or one per
/// cause (joint cause-specific). Callers pass already-sliced coefficients.
///
/// Derives `Debug`, `Clone`, and `PartialEq` so values can be logged, copied,
/// and compared directly in tests; the payload is plain `f64` data, so the
/// derived implementations are structural.
#[derive(Debug, Clone, PartialEq)]
pub enum SurvivalTimewiggleBeta {
    /// One coefficient block; routed to `beta_baseline_timewiggle`.
    Single(Vec<f64>),
    /// One coefficient block per cause; routed to
    /// `beta_baseline_timewiggle_by_cause`.
    ByCause(Vec<Vec<f64>>),
}
812
813/// Route the fitted baseline-timewiggle coefficients into the matching payload
814/// slot. Both survival payload assemblers funnel through this ONE exhaustive
815/// `match` so a new [`SurvivalTimewiggleBeta`] variant is a compile error rather
816/// than a silent drop (the location-scale assembler previously `if let`-matched
817/// only `Single` and silently discarded `ByCause`).
818fn apply_timewiggle_beta(payload: &mut FittedModelPayload, beta: SurvivalTimewiggleBeta) {
819    match beta {
820        SurvivalTimewiggleBeta::Single(beta) => {
821            payload.beta_baseline_timewiggle = Some(beta);
822        }
823        SurvivalTimewiggleBeta::ByCause(by_cause) => {
824            payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
825        }
826    }
827}
828
829/// Snapshot of the baseline-timewiggle block persisted with a survival model.
830pub struct SurvivalTimewiggle {
831    pub degree: usize,
832    pub knots: Vec<f64>,
833    pub penalty_orders: Option<Vec<usize>>,
834    pub double_penalty: Option<bool>,
835    pub beta: SurvivalTimewiggleBeta,
836}
837
838/// Source-agnostic semantic content of a survival transformation
839/// (Royston-Parmar) saved model — net single-cause or joint cause-specific.
840pub struct SurvivalTransformationInputs {
841    pub formula: String,
842    pub data_schema: DataSchema,
843    pub fit_result: UnifiedFitResult,
844    pub survival_entry: Option<String>,
845    pub survival_exit: String,
846    pub survival_event: String,
847    pub survivalspec: String,
848    /// `None` = net single-cause; `Some(n)` persists `survival_cause_count` and
849    /// `cause_1..cause_n` endpoint names.
850    pub cause_count: Option<usize>,
851    pub baseline_cfg: SurvivalBaselineConfig,
852    pub time_basis: SavedSurvivalTimeBasis,
853    pub ridge_lambda: f64,
854    pub survival_likelihood_label: String,
855    pub resolved_termspec: TermCollectionSpec,
856    /// Rigid time-block beta, persisted only by the cause-specific CLI path.
857    pub survival_beta_time: Option<Vec<f64>>,
858    pub timewiggle: Option<SurvivalTimewiggle>,
859}
860
861/// Assemble the canonical survival transformation payload — single source of
862/// truth for the Royston-Parmar transformation on-disk contract.
863pub fn assemble_survival_transformation_payload(
864    inputs: SurvivalTransformationInputs,
865    source: SavedModelSourceMetadata,
866) -> FittedModelPayload {
867    let mut payload = new_royston_parmar_survival_payload(
868        inputs.formula,
869        inputs.fit_result,
870        inputs.data_schema,
871        &inputs.survival_likelihood_label,
872        None,
873        crate::survival::lognormal_kernel::FrailtySpec::None,
874    );
875    payload.survival_entry = inputs.survival_entry;
876    payload.survival_exit = Some(inputs.survival_exit);
877    payload.survival_event = Some(inputs.survival_event);
878    payload.survivalspec = Some(inputs.survivalspec);
879    if let Some(cause_count) = inputs.cause_count {
880        payload.survival_cause_count = Some(cause_count);
881        payload.survival_endpoint_names = Some(
882            (1..=cause_count)
883                .map(|idx| format!("cause_{idx}"))
884                .collect(),
885        );
886    }
887    payload.survival_baseline_target =
888        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
889    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
890    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
891    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
892    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
893    payload.apply_survival_time_basis(&inputs.time_basis);
894    if let Some(timewiggle) = inputs.timewiggle {
895        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
896        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
897        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
898        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
899        apply_timewiggle_beta(&mut payload, timewiggle.beta);
900    }
901    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
902    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
903    payload.survival_beta_time = inputs.survival_beta_time;
904    payload.resolved_termspec = Some(inputs.resolved_termspec);
905    source.apply_to(&mut payload);
906    payload
907}
908
909/// Source-agnostic semantic content of a survival location-scale
910/// (Royston-Parmar with a learned residual link) saved model. Centralizing
911/// fixes the drift where CLI and FFI disagreed on `formula_noise`,
912/// `baseline_timewiggle_*`, and `survival_noise_projection_ridge_alpha`.
913pub struct SurvivalLocationScaleInputs<'a> {
914    pub formula: String,
915    pub data_schema: DataSchema,
916    /// Fit result with the fitted inverse-link state and link-wiggle artifacts
917    /// already applied by the caller.
918    pub fit_result: UnifiedFitResult,
919    pub fitted_inverse_link: InverseLink,
920    // Independent `Option`s (not an all-or-nothing group) so the assembler
921    // reproduces exactly what the CLI and FFI each persist independently.
922    pub linkwiggle_degree: Option<usize>,
923    pub linkwiggle_knots: Option<Vec<f64>>,
924    pub beta_link_wiggle: Option<Vec<f64>>,
925    pub baseline_timewiggle: Option<SurvivalTimewiggle>,
926    pub survival_entry: Option<String>,
927    pub survival_exit: String,
928    pub survival_event: String,
929    pub survivalspec: String,
930    pub baseline_cfg: SurvivalBaselineConfig,
931    pub time_basis: SavedSurvivalTimeBasis,
932    pub ridge_lambda: f64,
933    pub survival_likelihood_label: String,
934    pub formula_noise: Option<String>,
935    pub survival_beta_time: Vec<f64>,
936    pub survival_beta_threshold: Vec<f64>,
937    pub survival_beta_log_sigma: Vec<f64>,
938    pub noise_transform: &'a ScaleDeviationTransform,
939    pub resolved_thresholdspec: TermCollectionSpec,
940    pub resolved_log_sigmaspec: TermCollectionSpec,
941}
942
943/// Assemble the canonical survival location-scale payload (the single source of
944/// truth for that on-disk contract).
945pub fn assemble_survival_location_scale_payload(
946    inputs: SurvivalLocationScaleInputs<'_>,
947    source: SavedModelSourceMetadata,
948) -> FittedModelPayload {
949    let survival_distribution =
950        residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
951    let mut payload = new_royston_parmar_survival_payload(
952        inputs.formula,
953        inputs.fit_result,
954        inputs.data_schema,
955        &inputs.survival_likelihood_label,
956        survival_distribution,
957        crate::survival::lognormal_kernel::FrailtySpec::None,
958    );
959    payload.link = Some(inputs.fitted_inverse_link);
960    payload.linkwiggle_degree = inputs.linkwiggle_degree;
961    payload.linkwiggle_knots = inputs.linkwiggle_knots;
962    payload.beta_link_wiggle = inputs.beta_link_wiggle;
963    if let Some(timewiggle) = inputs.baseline_timewiggle {
964        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
965        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
966        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
967        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
968        apply_timewiggle_beta(&mut payload, timewiggle.beta);
969    }
970    payload.survival_entry = inputs.survival_entry;
971    payload.survival_exit = Some(inputs.survival_exit);
972    payload.survival_event = Some(inputs.survival_event);
973    payload.survivalspec = Some(inputs.survivalspec);
974    payload.survival_baseline_target =
975        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
976    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
977    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
978    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
979    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
980    payload.apply_survival_time_basis(&inputs.time_basis);
981    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
982    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
983    payload.formula_noise = inputs.formula_noise;
984    payload.survival_beta_time = Some(inputs.survival_beta_time);
985    payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
986    payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
987    payload.survival_noise_projection = Some(
988        inputs
989            .noise_transform
990            .projection_coef
991            .rows()
992            .into_iter()
993            .map(|row| row.to_vec())
994            .collect(),
995    );
996    payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
997    payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
998    payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
999    payload.survival_noise_projection_ridge_alpha =
1000        Some(inputs.noise_transform.projection_ridge_alpha);
1001    payload.survival_distribution = survival_distribution;
1002    payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
1003    payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
1004    source.apply_to(&mut payload);
1005    payload
1006}
1007
1008/// Source-agnostic semantic content of a latent survival / latent binary saved
1009/// model. The caller resolves the family (splicing the learned latent SD into
1010/// the persisted frailty for survival) and the model-class / likelihood labels.
1011pub struct LatentWindowInputs {
1012    pub formula: String,
1013    pub data_schema: DataSchema,
1014    pub fit_result: UnifiedFitResult,
1015    pub family: FittedFamily,
1016    pub model_class_label: String,
1017    pub likelihood_label: String,
1018    pub survival_entry: Option<String>,
1019    pub survival_exit: String,
1020    pub survival_event: String,
1021    pub baseline_cfg: SurvivalBaselineConfig,
1022    pub time_basis: SavedSurvivalTimeBasis,
1023    pub ridge_lambda: f64,
1024    pub beta_time: Vec<f64>,
1025    pub resolved_termspec: TermCollectionSpec,
1026}
1027
1028/// Assemble the canonical latent survival / latent binary payload.
1029pub fn assemble_latent_window_payload(
1030    inputs: LatentWindowInputs,
1031    source: SavedModelSourceMetadata,
1032) -> FittedModelPayload {
1033    let mut payload = FittedModelPayload::new(
1034        MODEL_PAYLOAD_VERSION,
1035        inputs.formula,
1036        ModelKind::Survival,
1037        inputs.family,
1038        inputs.model_class_label,
1039    );
1040    payload.unified = Some(inputs.fit_result.clone());
1041    payload.fit_result = Some(inputs.fit_result);
1042    payload.data_schema = Some(inputs.data_schema);
1043    payload.survival_entry = inputs.survival_entry;
1044    payload.survival_exit = Some(inputs.survival_exit);
1045    payload.survival_event = Some(inputs.survival_event);
1046    payload.survivalspec = Some("net".to_string());
1047    payload.survival_baseline_target =
1048        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
1049    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
1050    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
1051    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
1052    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
1053    payload.apply_survival_time_basis(&inputs.time_basis);
1054    payload.survival_likelihood = Some(inputs.likelihood_label);
1055    payload.survival_beta_time = Some(inputs.beta_time);
1056    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
1057    payload.resolved_termspec = Some(inputs.resolved_termspec);
1058    source.apply_to(&mut payload);
1059    payload
1060}
1061
#[cfg(test)]
mod apply_timewiggle_beta_tests {
    use super::*;

    /// Bare payload in which neither baseline-timewiggle slot has been set.
    /// The `LatentBinary` family is used because it needs no fixture and no
    /// `LikelihoodSpec`.
    fn empty_payload() -> FittedModelPayload {
        let family = FittedFamily::LatentBinary {
            frailty: crate::survival::lognormal_kernel::FrailtySpec::None,
        };
        FittedModelPayload::new(
            MODEL_PAYLOAD_VERSION,
            "y ~ 1".to_string(),
            ModelKind::Survival,
            family,
            "test".to_string(),
        )
    }

    /// Per-cause coefficients go to the by-cause slot and nowhere else.
    #[test]
    fn by_cause_beta_populates_only_the_by_cause_slot() {
        let per_cause = vec![vec![1.0, 2.0], vec![3.0]];
        let mut payload = empty_payload();

        apply_timewiggle_beta(
            &mut payload,
            SurvivalTimewiggleBeta::ByCause(per_cause.clone()),
        );

        assert_eq!(
            payload.beta_baseline_timewiggle_by_cause,
            Some(per_cause),
            "ByCause coefficients must land in the by-cause slot (regression: the \
             location-scale assembler used to silently drop them)"
        );
        assert!(
            payload.beta_baseline_timewiggle.is_none(),
            "ByCause must not populate the single-block slot"
        );
    }

    /// Single-block coefficients go to the flat slot and nowhere else.
    #[test]
    fn single_beta_populates_only_the_flat_slot() {
        let flat = vec![4.0, 5.0];
        let mut payload = empty_payload();

        apply_timewiggle_beta(&mut payload, SurvivalTimewiggleBeta::Single(flat.clone()));

        assert_eq!(payload.beta_baseline_timewiggle, Some(flat));
        assert!(payload.beta_baseline_timewiggle_by_cause.is_none());
    }
}