Skip to main content

gam_models/inference/
model_payload_builders.rs

1//! Shared, source-agnostic builders for saved-model payloads.
2//!
3//! The CLI (`src/main.rs`) and the Python FFI (`crates/gam-pyffi/src/lib.rs`)
4//! both persist fitted models, and both used to assemble the serialized
5//! [`FittedModelPayload`] independently. That meant the on-disk contract for a
6//! given model kind could silently drift depending on whether the model was
7//! created through the CLI or through Python — exactly the failure mode that
8//! repeatedly bit the marginal-slope save→load path.
9//!
10//! This module assembles the *semantic* payload exactly once. Each caller is
11//! responsible only for the source-specific work of producing the resolved
12//! semantic inputs (the CLI threads them through from its argument parsing and
13//! fit pipeline; the FFI freezes term collections from designs and re-derives
14//! metadata from the [`FitConfig`]). Once both sides hand the same semantic
15//! content to the same assembler, payload drift becomes impossible by
16//! construction.
17
18use gam_solve::estimate::UnifiedFitResult;
19use crate::bms::deviation_runtime::AnchorComponentTag;
20use crate::bms::{
21    DeviationRuntime, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
22};
23use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
24use crate::scale_design::ScaleDeviationTransform;
25use crate::survival::construction::{
26    SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
27};
28use crate::survival::location_scale::{
29    ResidualDistribution, residual_distribution_from_inverse_link,
30};
31use crate::transformation_normal::TransformationNormalFamily;
32use crate::inference::model::{
33    FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind, SavedAnchorComponent,
34    SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization, SavedResidualCascade,
35    SavedSplineScan, TransformationScoreCalibration,
36};
37use gam_terms::smooth::TermCollectionSpec;
38use gam_problem::types::{
39    InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
40};
41use gam_data::DataSchema;
42use ndarray::Array2;
43
/// Family tag string written into saved Bernoulli marginal-slope payloads.
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";

/// Family tag string written into saved transformation-normal payloads.
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
49
50/// Serialize an anchored-deviation [`DeviationRuntime`] (score-warp or
51/// link-deviation block) into its persistable [`SavedCompiledFlexBlock`] form.
52///
53/// This is the single source of truth for that conversion; the CLI and FFI
54/// payload builders both route through it so the serialized flex contract
55/// cannot diverge between the two save paths.
56pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
57    let mut anchor_correction: Option<Vec<Vec<f64>>> = None;
58    let mut anchor_components: Vec<SavedAnchorComponent> = Vec::new();
59    if let Some(installed) = runtime.installed_flex_block() {
60        anchor_correction = Some(
61            installed
62                .anchor_correction
63                .rows()
64                .into_iter()
65                .map(|row| row.to_vec())
66                .collect::<Vec<Vec<f64>>>(),
67        );
68        for component in &installed.anchor_components {
69            anchor_components.push(SavedAnchorComponent {
70                kind: match component {
71                    AnchorComponentTag::Parametric { block, ncols } => {
72                        SavedAnchorKind::Parametric {
73                            block: *block,
74                            ncols: *ncols,
75                        }
76                    }
77                    AnchorComponentTag::FlexEvaluation { ncols } => {
78                        SavedAnchorKind::FlexEvaluation { ncols: *ncols }
79                    }
80                },
81            });
82        }
83    }
84    SavedCompiledFlexBlock {
85        kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
86        breakpoints: runtime.breakpoints().to_vec(),
87        basis_dim: runtime.basis_dim(),
88        span_c0: runtime
89            .span_c0()
90            .rows()
91            .into_iter()
92            .map(|row| row.to_vec())
93            .collect(),
94        span_c1: runtime
95            .span_c1()
96            .rows()
97            .into_iter()
98            .map(|row| row.to_vec())
99            .collect(),
100        span_c2: runtime
101            .span_c2()
102            .rows()
103            .into_iter()
104            .map(|row| row.to_vec())
105            .collect(),
106        span_c3: runtime
107            .span_c3()
108            .rows()
109            .into_iter()
110            .map(|row| row.to_vec())
111            .collect(),
112        anchor_correction,
113        anchor_components,
114    }
115}
116
/// Metadata carried by every saved payload whose values come from the calling
/// source (CLI or FFI) rather than from the fitted model itself.
///
/// The FFI path currently persists training headers without per-feature
/// ranges, so `training_feature_ranges` is an `Option`: `None` means "ranges
/// unknown", which is deliberately different from an empty vector of ranges.
pub struct SavedModelSourceMetadata {
    /// Column headers of the training data.
    pub training_headers: Vec<String>,
    /// Per-feature `(f64, f64)` ranges, when the source can supply them.
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Name of the offset column, if any.
    pub offset_column: Option<String>,
    /// Name of the noise-offset column, if any.
    pub noise_offset_column: Option<String>,
}
130
131impl SavedModelSourceMetadata {
132    fn apply_to(self, payload: &mut FittedModelPayload) {
133        match self.training_feature_ranges {
134            Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
135            None => payload.training_headers = Some(self.training_headers),
136        }
137        payload.offset_column = self.offset_column;
138        payload.noise_offset_column = self.noise_offset_column;
139    }
140}
141
142/// The resolved, source-agnostic semantic content of a Bernoulli
143/// marginal-slope saved model.
144///
145/// The CLI threads these in directly from its fit pipeline; the FFI produces
146/// them by freezing its term collections and reading the [`FitConfig`]. Either
147/// way, the assembler below turns them into the canonical payload.
148pub struct BernoulliMarginalSlopeInputs<'a> {
149    pub formula: String,
150    pub data_schema: DataSchema,
151    pub logslope_formula: String,
152    pub z_column: String,
153    pub resolved_marginalspec: TermCollectionSpec,
154    pub resolved_logslopespec: TermCollectionSpec,
155    pub fit_result: UnifiedFitResult,
156    /// Number of *raw* marginal design columns `p_m` (= the term-collection
157    /// marginal design's `ncols()` BEFORE any #461 influence-absorber widening).
158    ///
159    /// When the Stage-1 influence absorber is active (A2), the fitted marginal
160    /// block carries the widened coefficient `[β_m; γ]` (length `p_m + p₁`) and
161    /// the joint covariance is dimensioned over the widened block. The absorbed
162    /// influence columns `Z̃_infl` are a TRAINING-only leakage absorber that does
163    /// not exist at predict rows, so the persisted model must drop `γ` and the
164    /// marginalized-out covariance sub-block to stay self-consistent against the
165    /// raw `p_m` marginal design at predict. The assembler uses this to truncate
166    /// the fit result once (shared CLI + FFI). With no absorber it equals the
167    /// fitted block width and the truncation is a no-op.
168    pub p_marginal: usize,
169    pub baseline_marginal: f64,
170    pub baseline_logslope: f64,
171    pub latent_z_normalization: SavedLatentZNormalization,
172    pub latent_measure: LatentMeasureKind,
173    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
174    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
175    pub score_warp_runtime: Option<&'a DeviationRuntime>,
176    pub link_dev_runtime: Option<&'a DeviationRuntime>,
177    pub base_link: InverseLink,
178    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
179}
180
181/// Drop the #461 training-only influence-absorber coefficients `γ` from a fitted
182/// Bernoulli marginal-slope result so the persisted model is self-consistent
183/// against the raw `p_m`-column marginal design at predict.
184///
185/// When the A2 influence absorber is active the marginal block (block 0) is the
186/// widened `[β_m; γ]` (length `p_m + p₁`, with `γ` the contiguous trailing `p₁`
187/// columns — see bms `widen_marginal_dense_with_influence`) and the joint
188/// conditional covariance is dimensioned over the widened joint coefficient
189/// vector. The absorbed columns `Z̃_infl` exist only at training rows; predict
190/// reconstructs the marginal index from the raw `p_m` design and the
191/// orthogonalized `β̂_m` is a property of the training fit. So this:
192///
193///  * slices `blocks[0].beta` and `block_states[0].beta` to their first `p_m`
194///    entries (the flat `beta` is recomputed from the blocks by
195///    `try_from_parts`),
196///  * **marginalizes** `γ` out of the joint Gaussian by dropping the `γ`
197///    rows/cols from the conditional covariance — taking the corresponding
198///    SUB-BLOCK of `Σ` is the exact marginal of a joint Gaussian (no
199///    re-inversion), so the kept `[β_m | β_logslope | …]` covariance is the
200///    correct predictive uncertainty accounting for the fitted absorber,
201///  * drops the persisted joint penalized-Hessian geometry: it is a precision
202///    over the *widened* joint coefficient vector, so a sub-block would be the
203///    wrong marginalization, and the only predict path that consumes it is the
204///    covariance-fallback that re-inverts `H` — which post-truncation would have
205///    the wrong dimension anyway. With the dense (already-marginalized) `Σ`
206///    matching the predict dimension, that fallback is never taken, so dropping
207///    the geometry removes a stale, wrong-dimension path rather than a used one.
208///
209/// Block-level `edf` / `lambdas` are left untouched: they are fitted scalars
210/// that legitimately reflect the full model (the absorber consumed real dof at
211/// fit time) and are persisted as-is. With no absorber (`block0.len() == p_m`)
212/// this is a no-op clone.
213fn truncate_marginal_slope_influence_absorber(
214    fit_result: UnifiedFitResult,
215    p_marginal: usize,
216) -> Result<UnifiedFitResult, String> {
217    let Some(block0) = fit_result.blocks.first() else {
218        return Err("marginal-slope fit result has no coefficient blocks".to_string());
219    };
220    let widened_len = block0.beta.len();
221    if widened_len <= p_marginal {
222        // No influence absorber installed (or already raw width): nothing to drop.
223        return Ok(fit_result);
224    }
225    let p_influence = widened_len - p_marginal;
226
227    let UnifiedFitResult {
228        mut blocks,
229        log_lambdas,
230        lambdas,
231        likelihood_family,
232        likelihood_scale,
233        log_likelihood_normalization,
234        log_likelihood,
235        deviance,
236        reml_score,
237        stable_penalty_term,
238        penalized_objective,
239        used_device,
240        outer_iterations,
241        outer_converged,
242        outer_gradient_norm,
243        standard_deviation,
244        covariance_conditional,
245        covariance_corrected,
246        inference,
247        fitted_link,
248        geometry: _,
249        mut block_states,
250        beta: _,
251        pirls_status,
252        max_abs_eta,
253        constraint_kkt,
254        artifacts,
255        inner_cycles,
256        outer_cost_evals: _,
257        inner_pirls_solves: _,
258    } = fit_result;
259
260    // Slice block 0's coefficients (and matching block-state) to the raw p_m,
261    // dropping the trailing γ absorber columns.
262    blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
263    if let Some(state0) = block_states.first_mut() {
264        state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
265    }
266
267    // Marginalize γ out of the joint conditional covariance: keep every index
268    // except the contiguous γ block [p_marginal, p_marginal + p_influence).
269    let drop_gamma_block = |cov: Option<Array2<f64>>| -> Option<Array2<f64>> {
270        cov.map(|cov| {
271            let total = cov.nrows();
272            let kept: Vec<usize> = (0..p_marginal)
273                .chain((p_marginal + p_influence)..total)
274                .collect();
275            let mut out = Array2::<f64>::zeros((kept.len(), kept.len()));
276            for (ri, &r) in kept.iter().enumerate() {
277                for (ci, &c) in kept.iter().enumerate() {
278                    out[[ri, ci]] = cov[[r, c]];
279                }
280            }
281            out
282        })
283    };
284    let covariance_conditional = drop_gamma_block(covariance_conditional);
285    let covariance_corrected = drop_gamma_block(covariance_corrected);
286
287    UnifiedFitResult::try_from_parts(gam_solve::estimate::UnifiedFitResultParts {
288        blocks,
289        log_lambdas,
290        lambdas,
291        likelihood_family,
292        likelihood_scale,
293        log_likelihood_normalization,
294        log_likelihood,
295        deviance,
296        reml_score,
297        stable_penalty_term,
298        penalized_objective,
299        // Preserve the GPU-execution flag across the absorber-column
300        // truncation: dropping the trailing γ columns does not change which
301        // device ran the solve.
302        used_device,
303        outer_iterations,
304        outer_converged,
305        outer_gradient_norm,
306        standard_deviation,
307        covariance_conditional,
308        covariance_corrected,
309        inference,
310        fitted_link,
311        // Drop the widened-joint penalized Hessian: see the doc comment.
312        geometry: None,
313        block_states,
314        pirls_status,
315        max_abs_eta,
316        constraint_kkt,
317        artifacts,
318        inner_cycles,
319    })
320    .map_err(|e| {
321        format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
322    })
323}
324
325/// Assemble the canonical spline-scan payload (#1030/#1034): a standard
326/// Gaussian-identity model whose fit representation is the exact O(n)
327/// smoothing-spline smoother state instead of a dense `fit_result`. The CLI
328/// and FFI save paths both route through here so the scan on-disk contract
329/// cannot diverge between sources.
330pub fn assemble_spline_scan_payload(
331    formula: String,
332    feature_column: String,
333    fit: &gam_solve::spline_scan::SplineScanFit,
334    data_schema: DataSchema,
335    training_headers: Vec<String>,
336    training_feature_ranges: Vec<(f64, f64)>,
337) -> FittedModelPayload {
338    let mut payload = FittedModelPayload::new(
339        MODEL_PAYLOAD_VERSION,
340        formula,
341        ModelKind::Standard,
342        FittedFamily::Standard {
343            likelihood: LikelihoodSpec::gaussian_identity(),
344            link: None,
345            latent_cloglog_state: None,
346            mixture_state: None,
347            sas_state: None,
348        },
349        "gaussian".to_string(),
350    );
351    payload.spline_scan = Some(SavedSplineScan {
352        feature_column,
353        state: fit.to_state(),
354    });
355    payload.data_schema = Some(data_schema);
356    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
357    payload
358}
359
360/// Assemble the canonical residual-cascade payload (#1032).
361///
362/// The CLI and FFI save paths both route through here so the cascade on-disk
363/// contract cannot diverge between sources.  Mirrors `assemble_spline_scan_payload`
364/// but for d ∈ {2,3} scattered coordinates (the Wendland multilevel-frame state).
365pub fn assemble_residual_cascade_payload(
366    formula: String,
367    feature_columns: Vec<String>,
368    fit: &gam_solve::residual_cascade::ResidualCascadeFit,
369    data_schema: DataSchema,
370    training_headers: Vec<String>,
371    training_feature_ranges: Vec<(f64, f64)>,
372) -> Result<FittedModelPayload, String> {
373    let mut payload = FittedModelPayload::new(
374        MODEL_PAYLOAD_VERSION,
375        formula,
376        ModelKind::Standard,
377        FittedFamily::Standard {
378            likelihood: gam_problem::types::LikelihoodSpec::gaussian_identity(),
379            link: None,
380            latent_cloglog_state: None,
381            mixture_state: None,
382            sas_state: None,
383        },
384        "gaussian".to_string(),
385    );
386    payload.residual_cascade = Some(SavedResidualCascade {
387        feature_columns,
388        state: fit.to_state().map_err(|e| {
389            format!("residual-cascade to_state failed during payload assembly: {e}")
390        })?,
391    });
392    payload.data_schema = Some(data_schema);
393    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
394    Ok(payload)
395}
396
397/// Assemble the canonical Bernoulli marginal-slope payload.
398///
399/// This is the single place that decides which payload fields a marginal-slope
400/// model carries and how the singular/vector mirror fields
401/// (`formula_logslope(s)`, `z_column(s)`, `logslope_baseline(s)`,
402/// `resolved_termspec_logslope(s)`) are kept consistent — so the CLI and FFI
403/// saved models are byte-equivalent for identical semantic content.
404pub fn assemble_bernoulli_marginal_slope_payload(
405    inputs: BernoulliMarginalSlopeInputs<'_>,
406    source: SavedModelSourceMetadata,
407) -> Result<FittedModelPayload, String> {
408    let BernoulliMarginalSlopeInputs {
409        formula,
410        data_schema,
411        logslope_formula,
412        z_column,
413        resolved_marginalspec,
414        resolved_logslopespec,
415        fit_result,
416        p_marginal,
417        baseline_marginal,
418        baseline_logslope,
419        latent_z_normalization,
420        latent_measure,
421        latent_z_rank_int_calibration,
422        latent_z_conditional_calibration,
423        score_warp_runtime,
424        link_dev_runtime,
425        base_link,
426        frailty,
427    } = inputs;
428
429    // #461 predict seam: drop the training-only influence-absorber γ (and
430    // marginalize it out of the covariance) so the persisted model matches the
431    // raw p_m marginal design at predict. No-op when the absorber is inactive.
432    let fit_result = truncate_marginal_slope_influence_absorber(fit_result, p_marginal)?;
433
434    let marginal_likelihood_spec =
435        inverse_link_to_binomial_spec(&base_link).map_err(|e| e.to_string())?;
436
437    let mut payload = FittedModelPayload::new(
438        MODEL_PAYLOAD_VERSION,
439        formula,
440        ModelKind::MarginalSlope,
441        FittedFamily::MarginalSlope {
442            likelihood: marginal_likelihood_spec,
443            base_link: base_link.clone(),
444            frailty,
445        },
446        FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
447    );
448    payload.unified = Some(fit_result.clone());
449    payload.fit_result = Some(fit_result);
450    payload.data_schema = Some(data_schema);
451    payload.formula_logslope = Some(logslope_formula.clone());
452    payload.z_column = Some(z_column.clone());
453    payload.formula_logslopes = Some(vec![logslope_formula]);
454    payload.z_columns = Some(vec![z_column]);
455    payload.latent_z_normalization = Some(latent_z_normalization);
456    payload.latent_measure = Some(latent_measure);
457    payload.latent_z_rank_int_calibration = latent_z_rank_int_calibration;
458    payload.latent_z_conditional_calibration = latent_z_conditional_calibration;
459    payload.marginal_baseline = Some(baseline_marginal);
460    payload.logslope_baseline = Some(baseline_logslope);
461    payload.logslope_baselines = Some(vec![baseline_logslope]);
462    payload.link = Some(base_link);
463    payload.resolved_termspec = Some(resolved_marginalspec);
464    payload.resolved_termspec_logslopes = Some(vec![resolved_logslopespec.clone()]);
465    payload.resolved_termspec_logslope = Some(resolved_logslopespec);
466    payload.score_warp_runtime = score_warp_runtime.map(serialize_anchored_deviation_runtime);
467    payload.link_deviation_runtime = link_dev_runtime.map(serialize_anchored_deviation_runtime);
468    source.apply_to(&mut payload);
469    Ok(payload)
470}
471
472/// The resolved, source-agnostic semantic content of a transformation-normal
473/// saved model.
474///
475/// As with the marginal-slope inputs, the CLI threads the family and resolved
476/// covariate spec straight from its fit pipeline while the FFI reads them off
477/// its fit-result struct (freezing the covariate spec from its design first).
478pub struct TransformationNormalInputs<'a> {
479    pub formula: String,
480    pub data_schema: DataSchema,
481    pub resolved_covariate_spec: TermCollectionSpec,
482    pub fit_result: UnifiedFitResult,
483    pub family: &'a TransformationNormalFamily,
484    pub score_calibration: TransformationScoreCalibration,
485}
486
487/// Assemble the canonical transformation-normal payload.
488///
489/// Centralizing the response-transform snapshot (`knots`, `transform`,
490/// `degree`, `median`) and the fixed Gaussian-identity likelihood means the CLI
491/// and FFI cannot encode a transformation-normal model two different ways.
492pub fn assemble_transformation_normal_payload(
493    inputs: TransformationNormalInputs<'_>,
494    source: SavedModelSourceMetadata,
495) -> FittedModelPayload {
496    let TransformationNormalInputs {
497        formula,
498        data_schema,
499        resolved_covariate_spec,
500        fit_result,
501        family,
502        score_calibration,
503    } = inputs;
504
505    let mut payload = FittedModelPayload::new(
506        MODEL_PAYLOAD_VERSION,
507        formula,
508        ModelKind::TransformationNormal,
509        FittedFamily::TransformationNormal {
510            likelihood: LikelihoodSpec::new(
511                ResponseFamily::Gaussian,
512                InverseLink::Standard(StandardLink::Identity),
513            ),
514        },
515        FAMILY_TRANSFORMATION_NORMAL.to_string(),
516    );
517    payload.unified = Some(fit_result.clone());
518    payload.fit_result = Some(fit_result);
519    payload.data_schema = Some(data_schema);
520    payload.resolved_termspec = Some(resolved_covariate_spec);
521    payload.transformation_response_knots = Some(family.response_knots().to_vec());
522    payload.transformation_response_transform = Some(
523        family
524            .response_transform()
525            .rows()
526            .into_iter()
527            .map(|row| row.to_vec())
528            .collect(),
529    );
530    payload.transformation_response_degree = Some(family.response_degree());
531    payload.transformation_response_median = Some(family.response_median());
532    payload.transformation_score_calibration = Some(score_calibration);
533    source.apply_to(&mut payload);
534    payload
535}
536
537/// Which likelihood a (non-survival) location-scale model carries: Gaussian
538/// (residual response scale) or binomial (noise scale-deviation transform whose
539/// likelihood is resolved from the inverse link). The assembler resolves the
540/// `FittedFamily` from this once, rather than each save path stamping a
541/// (potentially wrong) likelihood and patching it afterwards.
542pub enum LocationScaleResponse<'a> {
543    /// Gaussian identity; `base_link` is the optional resolved base link the CLI
544    /// may pass through from `link(...)` (the FFI leaves it `None`).
545    Gaussian {
546        response_scale: f64,
547        base_link: Option<InverseLink>,
548    },
549    /// Binomial under `link`, with the encoded noise scale-deviation transform.
550    Binomial {
551        link: InverseLink,
552        noise_transform: &'a ScaleDeviationTransform,
553    },
554    /// A genuine-dispersion mean family (NegativeBinomial / Gamma / Beta /
555    /// Tweedie) whose log-precision channel carries `noise_formula` (#913). The
556    /// `likelihood` is the family's own [`LikelihoodSpec`]; `base_link` is the
557    /// mean inverse link (log, or logit for Beta). The log-precision block
558    /// coefficients ride in [`LocationScaleInputs::beta_noise`].
559    Dispersion {
560        likelihood: LikelihoodSpec,
561        base_link: InverseLink,
562        family_tag: &'static str,
563    },
564}
565
/// Link-wiggle metadata that may be saved next to a location-scale model.
///
/// The knots and coefficients arrive already expressed in raw response units:
/// the Gaussian standardization and its inverse remap happen inside
/// `fit_gaussian_location_scale_model`, so the save path stores them verbatim.
pub struct LocationScaleWiggle {
    pub knots: Vec<f64>,
    pub degree: usize,
    pub beta_link_wiggle: Vec<f64>,
}
575
576/// Source-agnostic semantic content of a (non-survival) location-scale saved
577/// model — the shared core behind the CLI's Gaussian/binomial save paths and
578/// the FFI's two location-scale builders.
579pub struct LocationScaleInputs {
580    pub formula: String,
581    pub data_schema: DataSchema,
582    pub noise_formula: String,
583    pub resolved_termspec: TermCollectionSpec,
584    pub resolved_termspec_noise: TermCollectionSpec,
585    pub fit_result: UnifiedFitResult,
586    pub beta_noise: Option<Vec<f64>>,
587    pub wiggle: Option<LocationScaleWiggle>,
588}
589
590/// Assemble the canonical (non-survival) location-scale payload — single source
591/// of truth for that on-disk contract. The family/likelihood is resolved from
592/// the [`LocationScaleResponse`] so the binomial branch never persists a wrong
593/// probit likelihood that a caller must patch afterwards.
594pub fn assemble_location_scale_payload(
595    inputs: LocationScaleInputs,
596    response: LocationScaleResponse<'_>,
597    source: SavedModelSourceMetadata,
598) -> Result<FittedModelPayload, String> {
599    let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
600    {
601        LocationScaleResponse::Gaussian {
602            response_scale,
603            base_link,
604        } => (
605            "gaussian-location-scale".to_string(),
606            LikelihoodSpec::gaussian_identity(),
607            // Gaussian location-scale does not carry a base link in its family
608            // state; the resolved link is persisted in `payload.link` below so
609            // prediction can recover it.
610            None,
611            Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
612            Some(response_scale),
613            None,
614        ),
615        LocationScaleResponse::Binomial {
616            link,
617            noise_transform,
618        } => {
619            let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
620                format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
621            })?;
622            (
623                "binomial-location-scale".to_string(),
624                likelihood,
625                Some(link.clone()),
626                Some(link),
627                None,
628                Some(noise_transform),
629            )
630        }
631        LocationScaleResponse::Dispersion {
632            likelihood,
633            base_link,
634            family_tag,
635        } => (
636            family_tag.to_string(),
637            likelihood,
638            Some(base_link.clone()),
639            Some(base_link),
640            None,
641            None,
642        ),
643    };
644
645    let mut payload = FittedModelPayload::new(
646        MODEL_PAYLOAD_VERSION,
647        inputs.formula,
648        ModelKind::LocationScale,
649        FittedFamily::LocationScale {
650            likelihood,
651            base_link,
652        },
653        family_tag,
654    );
655    payload.unified = Some(inputs.fit_result.clone());
656    payload.fit_result = Some(inputs.fit_result);
657    payload.data_schema = Some(inputs.data_schema);
658    payload.link = link;
659    payload.formula_noise = Some(inputs.noise_formula);
660    payload.beta_noise = inputs.beta_noise;
661    payload.gaussian_response_scale = response_scale;
662    if let Some(transform) = noise_transform {
663        payload.noise_projection = Some(
664            transform
665                .projection_coef
666                .rows()
667                .into_iter()
668                .map(|row| row.to_vec())
669                .collect(),
670        );
671        payload.noise_center = Some(transform.weighted_column_mean.to_vec());
672        payload.noise_scale = Some(transform.rescale.to_vec());
673        payload.noise_non_intercept_start = Some(transform.non_intercept_start);
674        payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
675    }
676    payload.resolved_termspec = Some(inputs.resolved_termspec);
677    payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
678    if let Some(wiggle) = inputs.wiggle {
679        payload.linkwiggle_knots = Some(wiggle.knots);
680        payload.linkwiggle_degree = Some(wiggle.degree);
681        payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
682    }
683    source.apply_to(&mut payload);
684    Ok(payload)
685}
686
687/// Source-agnostic semantic content of a survival marginal-slope
688/// (Royston-Parmar net) saved model. Centralizing assembly also fixes the
689/// FFI's prior omission of the `*_logslopes`/`*_columns`/`formula_logslopes`
690/// vector mirrors the CLI wrote.
691pub struct SurvivalMarginalSlopeInputs<'a> {
692    pub formula: String,
693    pub data_schema: DataSchema,
694    pub fit_result: UnifiedFitResult,
695    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
696    pub survival_entry: Option<String>,
697    pub survival_exit: String,
698    pub survival_event: String,
699    pub survivalspec: String,
700    pub baseline_cfg: SurvivalBaselineConfig,
701    pub time_basis: SavedSurvivalTimeBasis,
702    pub ridge_lambda: f64,
703    pub survival_likelihood_label: String,
704    pub resolved_marginalspec: TermCollectionSpec,
705    pub resolved_logslopespec: TermCollectionSpec,
706    pub logslope_formula: String,
707    pub z_column: String,
708    pub latent_z_normalization: SavedLatentZNormalization,
709    pub baseline_logslope: f64,
710    pub score_warp_runtime: Option<&'a DeviationRuntime>,
711    pub link_dev_runtime: Option<&'a DeviationRuntime>,
712    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the fit
713    /// hosted a dedicated additive absorber. Predict drops the absorber's `γ`;
714    /// this is persisted only so the predictor accounts for the extra trailing
715    /// block in the saved block count.
716    pub influence_absorber_width: Option<usize>,
717}
718
719/// Construct a Royston-Parmar survival [`FittedModelPayload`] through the
720/// canonical `Survival` family scaffold shared by every RP on-disk contract
721/// (marginal-slope, transformation, location-scale): the identity-link
722/// `RoystonParmar` likelihood, the persisted likelihood label, and the
723/// `fit_result` / `data_schema` install. Callers supply the two variants that
724/// differ — `survival_distribution` and `frailty` — and then set their own
725/// family-specific fields on the returned payload.
726fn new_royston_parmar_survival_payload(
727    formula: String,
728    fit_result: UnifiedFitResult,
729    data_schema: DataSchema,
730    survival_likelihood_label: &str,
731    survival_distribution: Option<ResidualDistribution>,
732    frailty: crate::survival::lognormal_kernel::FrailtySpec,
733) -> FittedModelPayload {
734    let mut payload = FittedModelPayload::new(
735        MODEL_PAYLOAD_VERSION,
736        formula,
737        ModelKind::Survival,
738        FittedFamily::Survival {
739            likelihood: LikelihoodSpec::new(
740                ResponseFamily::RoystonParmar,
741                InverseLink::Standard(StandardLink::Identity),
742            ),
743            survival_likelihood: Some(survival_likelihood_label.to_string()),
744            survival_distribution,
745            frailty,
746        },
747        ResponseFamily::RoystonParmar.name().to_string(),
748    );
749    payload.unified = Some(fit_result.clone());
750    payload.fit_result = Some(fit_result);
751    payload.data_schema = Some(data_schema);
752    payload
753}
754
755/// Assemble the canonical survival marginal-slope payload — single source of
756/// truth for that Royston-Parmar / Gaussian-residual on-disk contract.
757pub fn assemble_survival_marginal_slope_payload(
758    inputs: SurvivalMarginalSlopeInputs<'_>,
759    source: SavedModelSourceMetadata,
760) -> FittedModelPayload {
761    let mut payload = new_royston_parmar_survival_payload(
762        inputs.formula,
763        inputs.fit_result,
764        inputs.data_schema,
765        &inputs.survival_likelihood_label,
766        Some(ResidualDistribution::Gaussian),
767        inputs.frailty,
768    );
769    payload.survival_entry = inputs.survival_entry;
770    payload.survival_exit = Some(inputs.survival_exit);
771    payload.survival_event = Some(inputs.survival_event);
772    payload.survivalspec = Some(inputs.survivalspec);
773    payload.survival_baseline_target =
774        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
775    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
776    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
777    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
778    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
779    payload.apply_survival_time_basis(&inputs.time_basis);
780    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
781    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
782    payload.survival_distribution = Some(ResidualDistribution::Gaussian);
783    payload.link = Some(InverseLink::Standard(StandardLink::Probit));
784    payload.resolved_termspec = Some(inputs.resolved_marginalspec);
785    payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
786    payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
787    payload.formula_logslope = Some(inputs.logslope_formula.clone());
788    payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
789    payload.z_column = Some(inputs.z_column.clone());
790    payload.z_columns = Some(vec![inputs.z_column]);
791    payload.latent_z_normalization = Some(inputs.latent_z_normalization);
792    payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
793    payload.logslope_baseline = Some(inputs.baseline_logslope);
794    payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
795    payload.score_warp_runtime = inputs
796        .score_warp_runtime
797        .map(serialize_anchored_deviation_runtime);
798    payload.link_deviation_runtime = inputs
799        .link_dev_runtime
800        .map(serialize_anchored_deviation_runtime);
801    payload.influence_absorber_width = inputs.influence_absorber_width;
802    source.apply_to(&mut payload);
803    payload
804}
805
/// Fitted baseline-timewiggle coefficients: a single block (net) or one per
/// cause (joint cause-specific). Callers pass already-sliced coefficients.
pub enum SurvivalTimewiggleBeta {
    /// One coefficient block; the assemblers persist it as
    /// `beta_baseline_timewiggle`.
    Single(Vec<f64>),
    /// One coefficient block per cause; the transformation assembler persists
    /// it as `beta_baseline_timewiggle_by_cause`. The location-scale assembler
    /// does not persist this variant.
    ByCause(Vec<Vec<f64>>),
}
812
/// Snapshot of the baseline-timewiggle block persisted with a survival model.
///
/// The assemblers copy each field into the matching `baseline_timewiggle_*`
/// payload field.
pub struct SurvivalTimewiggle {
    /// Persisted as `baseline_timewiggle_degree`.
    pub degree: usize,
    /// Persisted as `baseline_timewiggle_knots`.
    pub knots: Vec<f64>,
    /// Persisted as-is (including `None`) as
    /// `baseline_timewiggle_penalty_orders`.
    pub penalty_orders: Option<Vec<usize>>,
    /// Persisted as-is (including `None`) as
    /// `baseline_timewiggle_double_penalty`.
    pub double_penalty: Option<bool>,
    /// Fitted coefficients, single-block or per-cause.
    pub beta: SurvivalTimewiggleBeta,
}
821
/// Source-agnostic semantic content of a survival transformation
/// (Royston-Parmar) saved model — net single-cause or joint cause-specific.
///
/// Consumed by value by [`assemble_survival_transformation_payload`].
pub struct SurvivalTransformationInputs {
    /// Model formula, forwarded to the payload constructor.
    pub formula: String,
    /// Installed as the payload's `data_schema`.
    pub data_schema: DataSchema,
    /// Fit result; persisted in both the `unified` and `fit_result` payload
    /// slots.
    pub fit_result: UnifiedFitResult,
    /// Persisted verbatim as the payload's `survival_entry` (may be `None`).
    pub survival_entry: Option<String>,
    /// Persisted as the payload's `survival_exit`.
    pub survival_exit: String,
    /// Persisted as the payload's `survival_event`.
    pub survival_event: String,
    /// Persisted as the payload's `survivalspec`.
    pub survivalspec: String,
    /// `None` = net single-cause; `Some(n)` persists `survival_cause_count` and
    /// `cause_1..cause_n` endpoint names.
    pub cause_count: Option<usize>,
    /// Baseline configuration: `target` is persisted by name (via
    /// `survival_baseline_targetname`), while `scale`, `shape`, `rate`, and
    /// `makeham` are copied through unchanged.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Saved time basis, applied to the payload through
    /// `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as the payload's `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Likelihood label, written both into the persisted family and into the
    /// payload's top-level `survival_likelihood`.
    pub survival_likelihood_label: String,
    /// Persisted as the payload's `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
    /// Rigid time-block beta, persisted only by the cause-specific CLI path.
    pub survival_beta_time: Option<Vec<f64>>,
    /// Optional baseline-timewiggle snapshot; when `None` the assembler writes
    /// none of the `baseline_timewiggle_*` fields.
    pub timewiggle: Option<SurvivalTimewiggle>,
}
844
845/// Assemble the canonical survival transformation payload — single source of
846/// truth for the Royston-Parmar transformation on-disk contract.
847pub fn assemble_survival_transformation_payload(
848    inputs: SurvivalTransformationInputs,
849    source: SavedModelSourceMetadata,
850) -> FittedModelPayload {
851    let mut payload = new_royston_parmar_survival_payload(
852        inputs.formula,
853        inputs.fit_result,
854        inputs.data_schema,
855        &inputs.survival_likelihood_label,
856        None,
857        crate::survival::lognormal_kernel::FrailtySpec::None,
858    );
859    payload.survival_entry = inputs.survival_entry;
860    payload.survival_exit = Some(inputs.survival_exit);
861    payload.survival_event = Some(inputs.survival_event);
862    payload.survivalspec = Some(inputs.survivalspec);
863    if let Some(cause_count) = inputs.cause_count {
864        payload.survival_cause_count = Some(cause_count);
865        payload.survival_endpoint_names = Some(
866            (1..=cause_count)
867                .map(|idx| format!("cause_{idx}"))
868                .collect(),
869        );
870    }
871    payload.survival_baseline_target =
872        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
873    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
874    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
875    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
876    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
877    payload.apply_survival_time_basis(&inputs.time_basis);
878    if let Some(timewiggle) = inputs.timewiggle {
879        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
880        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
881        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
882        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
883        match timewiggle.beta {
884            SurvivalTimewiggleBeta::Single(beta) => {
885                payload.beta_baseline_timewiggle = Some(beta);
886            }
887            SurvivalTimewiggleBeta::ByCause(by_cause) => {
888                payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
889            }
890        }
891    }
892    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
893    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
894    payload.survival_beta_time = inputs.survival_beta_time;
895    payload.resolved_termspec = Some(inputs.resolved_termspec);
896    source.apply_to(&mut payload);
897    payload
898}
899
/// Source-agnostic semantic content of a survival location-scale
/// (Royston-Parmar with a learned residual link) saved model. Centralizing
/// fixes the drift where CLI and FFI disagreed on `formula_noise`,
/// `baseline_timewiggle_*`, and `survival_noise_projection_ridge_alpha`.
///
/// Consumed by value by [`assemble_survival_location_scale_payload`]. The
/// lifetime `'a` covers the borrowed noise transform, whose contents are
/// copied into the payload.
pub struct SurvivalLocationScaleInputs<'a> {
    /// Model formula, forwarded to the payload constructor.
    pub formula: String,
    /// Installed as the payload's `data_schema`.
    pub data_schema: DataSchema,
    /// Fit result with the fitted inverse-link state and link-wiggle artifacts
    /// already applied by the caller.
    pub fit_result: UnifiedFitResult,
    /// Persisted as the payload's `link`; also determines the persisted
    /// `survival_distribution` through
    /// `residual_distribution_from_inverse_link`.
    pub fitted_inverse_link: InverseLink,
    // Independent `Option`s (not an all-or-nothing group) so the assembler
    // reproduces exactly what the CLI and FFI each persist independently.
    /// Persisted as-is as the payload's `linkwiggle_degree`.
    pub linkwiggle_degree: Option<usize>,
    /// Persisted as-is as the payload's `linkwiggle_knots`.
    pub linkwiggle_knots: Option<Vec<f64>>,
    /// Persisted as-is as the payload's `beta_link_wiggle`.
    pub beta_link_wiggle: Option<Vec<f64>>,
    /// Optional baseline-timewiggle snapshot. Only
    /// `SurvivalTimewiggleBeta::Single` coefficients are persisted by the
    /// assembler; a `ByCause` beta is not written.
    pub baseline_timewiggle: Option<SurvivalTimewiggle>,
    /// Persisted verbatim as the payload's `survival_entry` (may be `None`).
    pub survival_entry: Option<String>,
    /// Persisted as the payload's `survival_exit`.
    pub survival_exit: String,
    /// Persisted as the payload's `survival_event`.
    pub survival_event: String,
    /// Persisted as the payload's `survivalspec`.
    pub survivalspec: String,
    /// Baseline configuration: `target` is persisted by name (via
    /// `survival_baseline_targetname`), while `scale`, `shape`, `rate`, and
    /// `makeham` are copied through unchanged.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Saved time basis, applied to the payload through
    /// `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as the payload's `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Likelihood label, written both into the persisted family and into the
    /// payload's top-level `survival_likelihood`.
    pub survival_likelihood_label: String,
    /// Persisted as-is as the payload's `formula_noise`.
    pub formula_noise: Option<String>,
    /// Persisted as the payload's `survival_beta_time`.
    pub survival_beta_time: Vec<f64>,
    /// Persisted as the payload's `survival_beta_threshold`.
    pub survival_beta_threshold: Vec<f64>,
    /// Persisted as the payload's `survival_beta_log_sigma`.
    pub survival_beta_log_sigma: Vec<f64>,
    /// Noise-design transform; its `projection_coef` rows,
    /// `weighted_column_mean`, `rescale`, `non_intercept_start`, and
    /// `projection_ridge_alpha` are copied into the `survival_noise_*`
    /// payload fields.
    pub noise_transform: &'a ScaleDeviationTransform,
    /// Persisted as the payload's `resolved_termspec`.
    pub resolved_thresholdspec: TermCollectionSpec,
    /// Persisted as the payload's `resolved_termspec_noise`.
    pub resolved_log_sigmaspec: TermCollectionSpec,
}
933
934/// Assemble the canonical survival location-scale payload (the single source of
935/// truth for that on-disk contract).
936pub fn assemble_survival_location_scale_payload(
937    inputs: SurvivalLocationScaleInputs<'_>,
938    source: SavedModelSourceMetadata,
939) -> FittedModelPayload {
940    let survival_distribution =
941        residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
942    let mut payload = new_royston_parmar_survival_payload(
943        inputs.formula,
944        inputs.fit_result,
945        inputs.data_schema,
946        &inputs.survival_likelihood_label,
947        survival_distribution,
948        crate::survival::lognormal_kernel::FrailtySpec::None,
949    );
950    payload.link = Some(inputs.fitted_inverse_link);
951    payload.linkwiggle_degree = inputs.linkwiggle_degree;
952    payload.linkwiggle_knots = inputs.linkwiggle_knots;
953    payload.beta_link_wiggle = inputs.beta_link_wiggle;
954    if let Some(timewiggle) = inputs.baseline_timewiggle {
955        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
956        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
957        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
958        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
959        if let SurvivalTimewiggleBeta::Single(beta) = timewiggle.beta {
960            payload.beta_baseline_timewiggle = Some(beta);
961        }
962    }
963    payload.survival_entry = inputs.survival_entry;
964    payload.survival_exit = Some(inputs.survival_exit);
965    payload.survival_event = Some(inputs.survival_event);
966    payload.survivalspec = Some(inputs.survivalspec);
967    payload.survival_baseline_target =
968        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
969    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
970    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
971    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
972    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
973    payload.apply_survival_time_basis(&inputs.time_basis);
974    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
975    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
976    payload.formula_noise = inputs.formula_noise;
977    payload.survival_beta_time = Some(inputs.survival_beta_time);
978    payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
979    payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
980    payload.survival_noise_projection = Some(
981        inputs
982            .noise_transform
983            .projection_coef
984            .rows()
985            .into_iter()
986            .map(|row| row.to_vec())
987            .collect(),
988    );
989    payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
990    payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
991    payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
992    payload.survival_noise_projection_ridge_alpha =
993        Some(inputs.noise_transform.projection_ridge_alpha);
994    payload.survival_distribution = survival_distribution;
995    payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
996    payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
997    source.apply_to(&mut payload);
998    payload
999}
1000
/// Source-agnostic semantic content of a latent survival / latent binary saved
/// model. The caller resolves the family (splicing the learned latent SD into
/// the persisted frailty for survival) and the model-class / likelihood labels.
///
/// Consumed by value by [`assemble_latent_window_payload`].
pub struct LatentWindowInputs {
    /// Model formula, forwarded to the payload constructor.
    pub formula: String,
    /// Installed as the payload's `data_schema`.
    pub data_schema: DataSchema,
    /// Fit result; persisted in both the `unified` and `fit_result` payload
    /// slots.
    pub fit_result: UnifiedFitResult,
    /// Fully resolved family, forwarded to the payload constructor unchanged.
    pub family: FittedFamily,
    /// Model-class label, forwarded to the payload constructor.
    pub model_class_label: String,
    /// Persisted as the payload's `survival_likelihood`.
    pub likelihood_label: String,
    /// Persisted verbatim as the payload's `survival_entry` (may be `None`).
    pub survival_entry: Option<String>,
    /// Persisted as the payload's `survival_exit`.
    pub survival_exit: String,
    /// Persisted as the payload's `survival_event`.
    pub survival_event: String,
    /// Baseline configuration: `target` is persisted by name (via
    /// `survival_baseline_targetname`), while `scale`, `shape`, `rate`, and
    /// `makeham` are copied through unchanged.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Saved time basis, applied to the payload through
    /// `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as the payload's `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Persisted as the payload's `survival_beta_time`.
    pub beta_time: Vec<f64>,
    /// Persisted as the payload's `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
}
1020
1021/// Assemble the canonical latent survival / latent binary payload.
1022pub fn assemble_latent_window_payload(
1023    inputs: LatentWindowInputs,
1024    source: SavedModelSourceMetadata,
1025) -> FittedModelPayload {
1026    let mut payload = FittedModelPayload::new(
1027        MODEL_PAYLOAD_VERSION,
1028        inputs.formula,
1029        ModelKind::Survival,
1030        inputs.family,
1031        inputs.model_class_label,
1032    );
1033    payload.unified = Some(inputs.fit_result.clone());
1034    payload.fit_result = Some(inputs.fit_result);
1035    payload.data_schema = Some(inputs.data_schema);
1036    payload.survival_entry = inputs.survival_entry;
1037    payload.survival_exit = Some(inputs.survival_exit);
1038    payload.survival_event = Some(inputs.survival_event);
1039    payload.survivalspec = Some("net".to_string());
1040    payload.survival_baseline_target =
1041        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
1042    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
1043    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
1044    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
1045    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
1046    payload.apply_survival_time_basis(&inputs.time_basis);
1047    payload.survival_likelihood = Some(inputs.likelihood_label);
1048    payload.survival_beta_time = Some(inputs.beta_time);
1049    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
1050    payload.resolved_termspec = Some(inputs.resolved_termspec);
1051    source.apply_to(&mut payload);
1052    payload
1053}