Skip to main content

gam_models/inference/
model_payload_builders.rs

//! Shared, source-agnostic builders for saved-model payloads.
//!
//! The CLI (`src/main.rs`) and the Python FFI (`crates/gam-pyffi/src/lib.rs`)
//! both persist fitted models, and both used to assemble the serialized
//! [`FittedModelPayload`] independently. That meant the on-disk contract for a
//! given model kind could silently drift depending on whether the model was
//! created through the CLI or through Python — exactly the failure mode that
//! repeatedly bit the marginal-slope save→load path.
//!
//! This module assembles the *semantic* payload exactly once. Each caller is
//! responsible only for the source-specific work of producing the resolved
//! semantic inputs (the CLI threads them through from its argument parsing and
//! fit pipeline; the FFI freezes term collections from designs and re-derives
//! metadata from its `FitConfig`). Because both sides hand the same semantic
//! content to the same assembler, payload drift between them is ruled out by
//! construction.
17
18use gam_solve::estimate::UnifiedFitResult;
19use crate::bms::deviation_runtime::AnchorComponentTag;
20use crate::bms::{
21    DeviationRuntime, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
22};
23use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
24use crate::scale_design::ScaleDeviationTransform;
25use crate::survival::construction::{
26    SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
27};
28use crate::survival::location_scale::{
29    ResidualDistribution, residual_distribution_from_inverse_link,
30};
31use crate::transformation_normal::TransformationNormalFamily;
32use crate::inference::model::{
33    FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind, SavedAnchorComponent,
34    SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization, SavedResidualCascade,
35    SavedSplineScan, TransformationScoreCalibration,
36};
37use gam_terms::smooth::TermCollectionSpec;
38use gam_problem::types::{
39    InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
40};
41use gam_data::DataSchema;
42use ndarray::Array2;
43
/// Family-name string handed to `FittedModelPayload::new` when persisting a
/// Bernoulli marginal-slope model.
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";

/// Family-name string handed to `FittedModelPayload::new` when persisting a
/// transformation-normal model.
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
49
50/// Serialize an anchored-deviation [`DeviationRuntime`] (score-warp or
51/// link-deviation block) into its persistable [`SavedCompiledFlexBlock`] form.
52///
53/// This is the single source of truth for that conversion; the CLI and FFI
54/// payload builders both route through it so the serialized flex contract
55/// cannot diverge between the two save paths.
56pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
57    let mut anchor_correction: Option<Vec<Vec<f64>>> = None;
58    let mut anchor_components: Vec<SavedAnchorComponent> = Vec::new();
59    if let Some(installed) = runtime.installed_flex_block() {
60        anchor_correction = Some(
61            installed
62                .anchor_correction
63                .rows()
64                .into_iter()
65                .map(|row| row.to_vec())
66                .collect::<Vec<Vec<f64>>>(),
67        );
68        for component in &installed.anchor_components {
69            anchor_components.push(SavedAnchorComponent {
70                kind: match component {
71                    AnchorComponentTag::Parametric { block, ncols } => {
72                        SavedAnchorKind::Parametric {
73                            block: *block,
74                            ncols: *ncols,
75                        }
76                    }
77                    AnchorComponentTag::FlexEvaluation { ncols } => {
78                        SavedAnchorKind::FlexEvaluation { ncols: *ncols }
79                    }
80                },
81            });
82        }
83    }
84    SavedCompiledFlexBlock {
85        kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
86        breakpoints: runtime.breakpoints().to_vec(),
87        basis_dim: runtime.basis_dim(),
88        span_c0: runtime
89            .span_c0()
90            .rows()
91            .into_iter()
92            .map(|row| row.to_vec())
93            .collect(),
94        span_c1: runtime
95            .span_c1()
96            .rows()
97            .into_iter()
98            .map(|row| row.to_vec())
99            .collect(),
100        span_c2: runtime
101            .span_c2()
102            .rows()
103            .into_iter()
104            .map(|row| row.to_vec())
105            .collect(),
106        span_c3: runtime
107            .span_c3()
108            .rows()
109            .into_iter()
110            .map(|row| row.to_vec())
111            .collect(),
112        anchor_correction,
113        anchor_components,
114    }
115}
116
/// Source-specific metadata that the CLI and FFI populate differently but that
/// every saved payload carries.
///
/// `training_feature_ranges` is an `Option` because the FFI path persists
/// headers without per-feature ranges. `None` records "ranges unknown"
/// explicitly, so an empty vector is never mistaken for known ranges.
///
/// The type derives `Debug` and `Clone` so callers can log it or reuse it
/// across several payload builds without copying fields by hand.
#[derive(Debug, Clone)]
pub struct SavedModelSourceMetadata {
    /// Column headers of the training data.
    pub training_headers: Vec<String>,
    /// Per-feature range pairs when the source knows them, `None` otherwise.
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Name of the offset column, if any.
    pub offset_column: Option<String>,
    /// Name of the noise-channel offset column, if any.
    pub noise_offset_column: Option<String>,
}
130
131impl SavedModelSourceMetadata {
132    fn apply_to(self, payload: &mut FittedModelPayload) {
133        match self.training_feature_ranges {
134            Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
135            None => payload.training_headers = Some(self.training_headers),
136        }
137        payload.offset_column = self.offset_column;
138        payload.noise_offset_column = self.noise_offset_column;
139    }
140}
141
/// The resolved, source-agnostic semantic content of a Bernoulli
/// marginal-slope saved model.
///
/// The CLI threads these in directly from its fit pipeline; the FFI produces
/// them by freezing its term collections and reading its `FitConfig`. Either
/// way, [`assemble_bernoulli_marginal_slope_payload`] turns them into the
/// canonical payload.
pub struct BernoulliMarginalSlopeInputs<'a> {
    /// Primary model formula, passed to `FittedModelPayload::new`.
    pub formula: String,
    /// Training-data schema, persisted as `data_schema`.
    pub data_schema: DataSchema,
    /// Log-slope formula; persisted both as `formula_logslope` and as the
    /// one-element `formula_logslopes` mirror.
    pub logslope_formula: String,
    /// Latent `z` column name; persisted both as `z_column` and as the
    /// one-element `z_columns` mirror.
    pub z_column: String,
    /// Resolved marginal term spec, persisted as `resolved_termspec`.
    pub resolved_marginalspec: TermCollectionSpec,
    /// Resolved log-slope term spec; persisted both singly and as a
    /// one-element vector mirror.
    pub resolved_logslopespec: TermCollectionSpec,
    /// Fitted result. Block 0 is the marginal block, possibly widened by the
    /// #461 influence absorber (see `p_marginal`); it is truncated before
    /// being persisted.
    pub fit_result: UnifiedFitResult,
    /// Number of *raw* marginal design columns `p_m` (= the term-collection
    /// marginal design's `ncols()` BEFORE any #461 influence-absorber widening).
    ///
    /// When the Stage-1 influence absorber is active (A2), the fitted marginal
    /// block carries the widened coefficient `[β_m; γ]` (length `p_m + p₁`) and
    /// the joint covariance is dimensioned over the widened block. The absorbed
    /// influence columns `Z̃_infl` are a TRAINING-only leakage absorber that does
    /// not exist at predict rows, so the persisted model must drop `γ` and the
    /// marginalized-out covariance sub-block to stay self-consistent against the
    /// raw `p_m` marginal design at predict. The assembler uses this to truncate
    /// the fit result once (shared CLI + FFI). With no absorber it equals the
    /// fitted block width and the truncation is a no-op.
    pub p_marginal: usize,
    /// Marginal baseline, persisted as `marginal_baseline`.
    pub baseline_marginal: f64,
    /// Log-slope baseline; persisted both singly and as a one-element vector.
    pub baseline_logslope: f64,
    /// Latent-`z` normalization, persisted as `latent_z_normalization`.
    pub latent_z_normalization: SavedLatentZNormalization,
    /// Latent measure kind, persisted as `latent_measure`.
    pub latent_measure: LatentMeasureKind,
    /// Optional rank-INT calibration, persisted verbatim.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Optional conditional calibration, persisted verbatim.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    /// Optional score-warp runtime, serialized via
    /// [`serialize_anchored_deviation_runtime`].
    pub score_warp_runtime: Option<&'a DeviationRuntime>,
    /// Optional link-deviation runtime, serialized the same way.
    pub link_dev_runtime: Option<&'a DeviationRuntime>,
    /// Base inverse link. The marginal binomial likelihood is derived from it,
    /// and it is also persisted as `payload.link`.
    pub base_link: InverseLink,
    /// Frailty specification stored on the `MarginalSlope` family.
    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
}
180
181/// Drop the #461 training-only influence-absorber coefficients `γ` from a fitted
182/// Bernoulli marginal-slope result so the persisted model is self-consistent
183/// against the raw `p_m`-column marginal design at predict.
184///
185/// When the A2 influence absorber is active the marginal block (block 0) is the
186/// widened `[β_m; γ]` (length `p_m + p₁`, with `γ` the contiguous trailing `p₁`
187/// columns — see bms `widen_marginal_dense_with_influence`) and the joint
188/// conditional covariance is dimensioned over the widened joint coefficient
189/// vector. The absorbed columns `Z̃_infl` exist only at training rows; predict
190/// reconstructs the marginal index from the raw `p_m` design and the
191/// orthogonalized `β̂_m` is a property of the training fit. So this:
192///
193///  * slices `blocks[0].beta` and `block_states[0].beta` to their first `p_m`
194///    entries (the flat `beta` is recomputed from the blocks by
195///    `try_from_parts`),
196///  * **marginalizes** `γ` out of the joint Gaussian by dropping the `γ`
197///    rows/cols from the conditional covariance — taking the corresponding
198///    SUB-BLOCK of `Σ` is the exact marginal of a joint Gaussian (no
199///    re-inversion), so the kept `[β_m | β_logslope | …]` covariance is the
200///    correct predictive uncertainty accounting for the fitted absorber,
201///  * drops the persisted joint penalized-Hessian geometry: it is a precision
202///    over the *widened* joint coefficient vector, so a sub-block would be the
203///    wrong marginalization, and the only predict path that consumes it is the
204///    covariance-fallback that re-inverts `H` — which post-truncation would have
205///    the wrong dimension anyway. With the dense (already-marginalized) `Σ`
206///    matching the predict dimension, that fallback is never taken, so dropping
207///    the geometry removes a stale, wrong-dimension path rather than a used one.
208///
209/// Block-level `edf` / `lambdas` are left untouched: they are fitted scalars
210/// that legitimately reflect the full model (the absorber consumed real dof at
211/// fit time) and are persisted as-is. With no absorber (`block0.len() == p_m`)
212/// this is a no-op clone.
213fn truncate_marginal_slope_influence_absorber(
214    fit_result: UnifiedFitResult,
215    p_marginal: usize,
216) -> Result<UnifiedFitResult, String> {
217    let Some(block0) = fit_result.blocks.first() else {
218        return Err("marginal-slope fit result has no coefficient blocks".to_string());
219    };
220    let widened_len = block0.beta.len();
221    if widened_len <= p_marginal {
222        // No influence absorber installed (or already raw width): nothing to drop.
223        return Ok(fit_result);
224    }
225    let p_influence = widened_len - p_marginal;
226
227    let UnifiedFitResult {
228        mut blocks,
229        log_lambdas,
230        lambdas,
231        likelihood_family,
232        likelihood_scale,
233        log_likelihood_normalization,
234        log_likelihood,
235        deviance,
236        reml_score,
237        stable_penalty_term,
238        penalized_objective,
239        used_device,
240        outer_iterations,
241        outer_converged,
242        outer_gradient_norm,
243        standard_deviation,
244        covariance_conditional,
245        covariance_corrected,
246        inference,
247        fitted_link,
248        geometry: _,
249        mut block_states,
250        beta: _,
251        pirls_status,
252        max_abs_eta,
253        constraint_kkt,
254        artifacts,
255        inner_cycles,
256        outer_cost_evals: _,
257    } = fit_result;
258
259    // Slice block 0's coefficients (and matching block-state) to the raw p_m,
260    // dropping the trailing γ absorber columns.
261    blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
262    if let Some(state0) = block_states.first_mut() {
263        state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
264    }
265
266    // Marginalize γ out of the joint conditional covariance: keep every index
267    // except the contiguous γ block [p_marginal, p_marginal + p_influence).
268    let drop_gamma_block = |cov: Option<Array2<f64>>| -> Option<Array2<f64>> {
269        cov.map(|cov| {
270            let total = cov.nrows();
271            let kept: Vec<usize> = (0..p_marginal)
272                .chain((p_marginal + p_influence)..total)
273                .collect();
274            let mut out = Array2::<f64>::zeros((kept.len(), kept.len()));
275            for (ri, &r) in kept.iter().enumerate() {
276                for (ci, &c) in kept.iter().enumerate() {
277                    out[[ri, ci]] = cov[[r, c]];
278                }
279            }
280            out
281        })
282    };
283    let covariance_conditional = drop_gamma_block(covariance_conditional);
284    let covariance_corrected = drop_gamma_block(covariance_corrected);
285
286    UnifiedFitResult::try_from_parts(gam_solve::estimate::UnifiedFitResultParts {
287        blocks,
288        log_lambdas,
289        lambdas,
290        likelihood_family,
291        likelihood_scale,
292        log_likelihood_normalization,
293        log_likelihood,
294        deviance,
295        reml_score,
296        stable_penalty_term,
297        penalized_objective,
298        // Preserve the GPU-execution flag across the absorber-column
299        // truncation: dropping the trailing γ columns does not change which
300        // device ran the solve.
301        used_device,
302        outer_iterations,
303        outer_converged,
304        outer_gradient_norm,
305        standard_deviation,
306        covariance_conditional,
307        covariance_corrected,
308        inference,
309        fitted_link,
310        // Drop the widened-joint penalized Hessian: see the doc comment.
311        geometry: None,
312        block_states,
313        pirls_status,
314        max_abs_eta,
315        constraint_kkt,
316        artifacts,
317        inner_cycles,
318    })
319    .map_err(|e| {
320        format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
321    })
322}
323
324/// Assemble the canonical spline-scan payload (#1030/#1034): a standard
325/// Gaussian-identity model whose fit representation is the exact O(n)
326/// smoothing-spline smoother state instead of a dense `fit_result`. The CLI
327/// and FFI save paths both route through here so the scan on-disk contract
328/// cannot diverge between sources.
329pub fn assemble_spline_scan_payload(
330    formula: String,
331    feature_column: String,
332    fit: &gam_solve::spline_scan::SplineScanFit,
333    data_schema: DataSchema,
334    training_headers: Vec<String>,
335    training_feature_ranges: Vec<(f64, f64)>,
336) -> FittedModelPayload {
337    let mut payload = FittedModelPayload::new(
338        MODEL_PAYLOAD_VERSION,
339        formula,
340        ModelKind::Standard,
341        FittedFamily::Standard {
342            likelihood: LikelihoodSpec::gaussian_identity(),
343            link: None,
344            latent_cloglog_state: None,
345            mixture_state: None,
346            sas_state: None,
347        },
348        "gaussian".to_string(),
349    );
350    payload.spline_scan = Some(SavedSplineScan {
351        feature_column,
352        state: fit.to_state(),
353    });
354    payload.data_schema = Some(data_schema);
355    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
356    payload
357}
358
359/// Assemble the canonical residual-cascade payload (#1032).
360///
361/// The CLI and FFI save paths both route through here so the cascade on-disk
362/// contract cannot diverge between sources.  Mirrors `assemble_spline_scan_payload`
363/// but for d ∈ {2,3} scattered coordinates (the Wendland multilevel-frame state).
364pub fn assemble_residual_cascade_payload(
365    formula: String,
366    feature_columns: Vec<String>,
367    fit: &gam_solve::residual_cascade::ResidualCascadeFit,
368    data_schema: DataSchema,
369    training_headers: Vec<String>,
370    training_feature_ranges: Vec<(f64, f64)>,
371) -> Result<FittedModelPayload, String> {
372    let mut payload = FittedModelPayload::new(
373        MODEL_PAYLOAD_VERSION,
374        formula,
375        ModelKind::Standard,
376        FittedFamily::Standard {
377            likelihood: gam_problem::types::LikelihoodSpec::gaussian_identity(),
378            link: None,
379            latent_cloglog_state: None,
380            mixture_state: None,
381            sas_state: None,
382        },
383        "gaussian".to_string(),
384    );
385    payload.residual_cascade = Some(SavedResidualCascade {
386        feature_columns,
387        state: fit.to_state().map_err(|e| {
388            format!("residual-cascade to_state failed during payload assembly: {e}")
389        })?,
390    });
391    payload.data_schema = Some(data_schema);
392    payload.set_training_feature_metadata(training_headers, training_feature_ranges);
393    Ok(payload)
394}
395
396/// Assemble the canonical Bernoulli marginal-slope payload.
397///
398/// This is the single place that decides which payload fields a marginal-slope
399/// model carries and how the singular/vector mirror fields
400/// (`formula_logslope(s)`, `z_column(s)`, `logslope_baseline(s)`,
401/// `resolved_termspec_logslope(s)`) are kept consistent — so the CLI and FFI
402/// saved models are byte-equivalent for identical semantic content.
403pub fn assemble_bernoulli_marginal_slope_payload(
404    inputs: BernoulliMarginalSlopeInputs<'_>,
405    source: SavedModelSourceMetadata,
406) -> Result<FittedModelPayload, String> {
407    let BernoulliMarginalSlopeInputs {
408        formula,
409        data_schema,
410        logslope_formula,
411        z_column,
412        resolved_marginalspec,
413        resolved_logslopespec,
414        fit_result,
415        p_marginal,
416        baseline_marginal,
417        baseline_logslope,
418        latent_z_normalization,
419        latent_measure,
420        latent_z_rank_int_calibration,
421        latent_z_conditional_calibration,
422        score_warp_runtime,
423        link_dev_runtime,
424        base_link,
425        frailty,
426    } = inputs;
427
428    // #461 predict seam: drop the training-only influence-absorber γ (and
429    // marginalize it out of the covariance) so the persisted model matches the
430    // raw p_m marginal design at predict. No-op when the absorber is inactive.
431    let fit_result = truncate_marginal_slope_influence_absorber(fit_result, p_marginal)?;
432
433    let marginal_likelihood_spec =
434        inverse_link_to_binomial_spec(&base_link).map_err(|e| e.to_string())?;
435
436    let mut payload = FittedModelPayload::new(
437        MODEL_PAYLOAD_VERSION,
438        formula,
439        ModelKind::MarginalSlope,
440        FittedFamily::MarginalSlope {
441            likelihood: marginal_likelihood_spec,
442            base_link: base_link.clone(),
443            frailty,
444        },
445        FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
446    );
447    payload.unified = Some(fit_result.clone());
448    payload.fit_result = Some(fit_result);
449    payload.data_schema = Some(data_schema);
450    payload.formula_logslope = Some(logslope_formula.clone());
451    payload.z_column = Some(z_column.clone());
452    payload.formula_logslopes = Some(vec![logslope_formula]);
453    payload.z_columns = Some(vec![z_column]);
454    payload.latent_z_normalization = Some(latent_z_normalization);
455    payload.latent_measure = Some(latent_measure);
456    payload.latent_z_rank_int_calibration = latent_z_rank_int_calibration;
457    payload.latent_z_conditional_calibration = latent_z_conditional_calibration;
458    payload.marginal_baseline = Some(baseline_marginal);
459    payload.logslope_baseline = Some(baseline_logslope);
460    payload.logslope_baselines = Some(vec![baseline_logslope]);
461    payload.link = Some(base_link);
462    payload.resolved_termspec = Some(resolved_marginalspec);
463    payload.resolved_termspec_logslopes = Some(vec![resolved_logslopespec.clone()]);
464    payload.resolved_termspec_logslope = Some(resolved_logslopespec);
465    payload.score_warp_runtime = score_warp_runtime.map(serialize_anchored_deviation_runtime);
466    payload.link_deviation_runtime = link_dev_runtime.map(serialize_anchored_deviation_runtime);
467    source.apply_to(&mut payload);
468    Ok(payload)
469}
470
/// The resolved, source-agnostic semantic content of a transformation-normal
/// saved model.
///
/// As with the marginal-slope inputs, the CLI threads the family and resolved
/// covariate spec straight from its fit pipeline while the FFI reads them off
/// its fit-result struct (freezing the covariate spec from its design first).
pub struct TransformationNormalInputs<'a> {
    /// Primary model formula, passed to `FittedModelPayload::new`.
    pub formula: String,
    /// Training-data schema, persisted as `data_schema`.
    pub data_schema: DataSchema,
    /// Resolved covariate term spec, persisted as `resolved_termspec`.
    pub resolved_covariate_spec: TermCollectionSpec,
    /// Fitted result, persisted into both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Fitted family, the source of the response-transform snapshot
    /// (knots, transform matrix, degree, median).
    pub family: &'a TransformationNormalFamily,
    /// Score calibration, persisted as `transformation_score_calibration`.
    pub score_calibration: TransformationScoreCalibration,
}
485
486/// Assemble the canonical transformation-normal payload.
487///
488/// Centralizing the response-transform snapshot (`knots`, `transform`,
489/// `degree`, `median`) and the fixed Gaussian-identity likelihood means the CLI
490/// and FFI cannot encode a transformation-normal model two different ways.
491pub fn assemble_transformation_normal_payload(
492    inputs: TransformationNormalInputs<'_>,
493    source: SavedModelSourceMetadata,
494) -> FittedModelPayload {
495    let TransformationNormalInputs {
496        formula,
497        data_schema,
498        resolved_covariate_spec,
499        fit_result,
500        family,
501        score_calibration,
502    } = inputs;
503
504    let mut payload = FittedModelPayload::new(
505        MODEL_PAYLOAD_VERSION,
506        formula,
507        ModelKind::TransformationNormal,
508        FittedFamily::TransformationNormal {
509            likelihood: LikelihoodSpec::new(
510                ResponseFamily::Gaussian,
511                InverseLink::Standard(StandardLink::Identity),
512            ),
513        },
514        FAMILY_TRANSFORMATION_NORMAL.to_string(),
515    );
516    payload.unified = Some(fit_result.clone());
517    payload.fit_result = Some(fit_result);
518    payload.data_schema = Some(data_schema);
519    payload.resolved_termspec = Some(resolved_covariate_spec);
520    payload.transformation_response_knots = Some(family.response_knots().to_vec());
521    payload.transformation_response_transform = Some(
522        family
523            .response_transform()
524            .rows()
525            .into_iter()
526            .map(|row| row.to_vec())
527            .collect(),
528    );
529    payload.transformation_response_degree = Some(family.response_degree());
530    payload.transformation_response_median = Some(family.response_median());
531    payload.transformation_score_calibration = Some(score_calibration);
532    source.apply_to(&mut payload);
533    payload
534}
535
/// Which likelihood a (non-survival) location-scale model carries, plus the
/// response-specific data persisted with it.
///
/// The assembler resolves the `FittedFamily` from this once, rather than each
/// save path stamping a (potentially wrong) likelihood and patching it
/// afterwards.
pub enum LocationScaleResponse<'a> {
    /// Gaussian identity likelihood. `base_link` is the optional resolved base
    /// link the CLI may pass through from `link(...)` (the FFI leaves it
    /// `None`). When it is `None`, the identity link is persisted in
    /// `payload.link`.
    Gaussian {
        /// Persisted as `gaussian_response_scale`.
        response_scale: f64,
        base_link: Option<InverseLink>,
    },
    /// Binomial likelihood under `link`, resolved via
    /// `inverse_link_to_binomial_spec`. The encoded noise scale-deviation
    /// transform is persisted as the `noise_*` projection/center/scale fields.
    Binomial {
        link: InverseLink,
        noise_transform: &'a ScaleDeviationTransform,
    },
    /// A genuine-dispersion mean family (NegativeBinomial / Gamma / Beta /
    /// Tweedie) whose log-precision channel carries `noise_formula` (#913). The
    /// `likelihood` is the family's own [`LikelihoodSpec`]; `base_link` is the
    /// mean inverse link (log, or logit for Beta). The log-precision block
    /// coefficients ride in [`LocationScaleInputs::beta_noise`].
    Dispersion {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        /// Persisted verbatim as the payload's family tag.
        family_tag: &'static str,
    },
}
564
/// Optional link-wiggle metadata persisted alongside a location-scale model.
///
/// Knots and coefficients are already in raw response units. The Gaussian
/// standardization and its inverse remap live inside
/// `fit_gaussian_location_scale_model`, so the save path persists them
/// verbatim.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so it can be logged,
/// reused and compared in round-trip checks.
#[derive(Debug, Clone, PartialEq)]
pub struct LocationScaleWiggle {
    /// Link-wiggle knots, persisted as `linkwiggle_knots`.
    pub knots: Vec<f64>,
    /// Link-wiggle degree, persisted as `linkwiggle_degree`.
    pub degree: usize,
    /// Link-wiggle coefficients, persisted as `beta_link_wiggle`.
    pub beta_link_wiggle: Vec<f64>,
}
574
/// Source-agnostic semantic content of a (non-survival) location-scale saved
/// model — the shared core behind the CLI's Gaussian/binomial save paths and
/// the FFI's two location-scale builders. The response-specific part is
/// supplied separately as a [`LocationScaleResponse`].
pub struct LocationScaleInputs {
    /// Primary model formula, passed to `FittedModelPayload::new`.
    pub formula: String,
    /// Training-data schema, persisted as `data_schema`.
    pub data_schema: DataSchema,
    /// Noise-channel formula, persisted as `formula_noise`.
    pub noise_formula: String,
    /// Resolved primary term spec, persisted as `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
    /// Resolved noise term spec, persisted as `resolved_termspec_noise`.
    pub resolved_termspec_noise: TermCollectionSpec,
    /// Fitted result, persisted into both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Noise-block coefficients, persisted verbatim as `beta_noise`.
    pub beta_noise: Option<Vec<f64>>,
    /// Optional link-wiggle metadata; when present its knots, degree and
    /// coefficients are persisted.
    pub wiggle: Option<LocationScaleWiggle>,
}
588
589/// Assemble the canonical (non-survival) location-scale payload — single source
590/// of truth for that on-disk contract. The family/likelihood is resolved from
591/// the [`LocationScaleResponse`] so the binomial branch never persists a wrong
592/// probit likelihood that a caller must patch afterwards.
593pub fn assemble_location_scale_payload(
594    inputs: LocationScaleInputs,
595    response: LocationScaleResponse<'_>,
596    source: SavedModelSourceMetadata,
597) -> Result<FittedModelPayload, String> {
598    let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
599    {
600        LocationScaleResponse::Gaussian {
601            response_scale,
602            base_link,
603        } => (
604            "gaussian-location-scale".to_string(),
605            LikelihoodSpec::gaussian_identity(),
606            // Gaussian location-scale does not carry a base link in its family
607            // state; the resolved link is persisted in `payload.link` below so
608            // prediction can recover it.
609            None,
610            Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
611            Some(response_scale),
612            None,
613        ),
614        LocationScaleResponse::Binomial {
615            link,
616            noise_transform,
617        } => {
618            let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
619                format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
620            })?;
621            (
622                "binomial-location-scale".to_string(),
623                likelihood,
624                Some(link.clone()),
625                Some(link),
626                None,
627                Some(noise_transform),
628            )
629        }
630        LocationScaleResponse::Dispersion {
631            likelihood,
632            base_link,
633            family_tag,
634        } => (
635            family_tag.to_string(),
636            likelihood,
637            Some(base_link.clone()),
638            Some(base_link),
639            None,
640            None,
641        ),
642    };
643
644    let mut payload = FittedModelPayload::new(
645        MODEL_PAYLOAD_VERSION,
646        inputs.formula,
647        ModelKind::LocationScale,
648        FittedFamily::LocationScale {
649            likelihood,
650            base_link,
651        },
652        family_tag,
653    );
654    payload.unified = Some(inputs.fit_result.clone());
655    payload.fit_result = Some(inputs.fit_result);
656    payload.data_schema = Some(inputs.data_schema);
657    payload.link = link;
658    payload.formula_noise = Some(inputs.noise_formula);
659    payload.beta_noise = inputs.beta_noise;
660    payload.gaussian_response_scale = response_scale;
661    if let Some(transform) = noise_transform {
662        payload.noise_projection = Some(
663            transform
664                .projection_coef
665                .rows()
666                .into_iter()
667                .map(|row| row.to_vec())
668                .collect(),
669        );
670        payload.noise_center = Some(transform.weighted_column_mean.to_vec());
671        payload.noise_scale = Some(transform.rescale.to_vec());
672        payload.noise_non_intercept_start = Some(transform.non_intercept_start);
673        payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
674    }
675    payload.resolved_termspec = Some(inputs.resolved_termspec);
676    payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
677    if let Some(wiggle) = inputs.wiggle {
678        payload.linkwiggle_knots = Some(wiggle.knots);
679        payload.linkwiggle_degree = Some(wiggle.degree);
680        payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
681    }
682    source.apply_to(&mut payload);
683    Ok(payload)
684}
685
/// Source-agnostic semantic content of a survival marginal-slope
/// (Royston-Parmar net) saved model.
///
/// Centralizing assembly also fixes the FFI's earlier omission of the
/// `*_logslopes` / `*_columns` / `formula_logslopes` vector mirrors that the
/// CLI wrote.
pub struct SurvivalMarginalSlopeInputs<'a> {
    /// Primary model formula for the Royston-Parmar payload scaffold.
    pub formula: String,
    /// Training-data schema, persisted as `data_schema`.
    pub data_schema: DataSchema,
    /// Fitted result, persisted into both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Frailty specification stored on the `Survival` family.
    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
    /// Optional entry-time column name, persisted as `survival_entry`.
    pub survival_entry: Option<String>,
    /// Exit-time column name, persisted as `survival_exit`.
    pub survival_exit: String,
    /// Event-indicator column name, persisted as `survival_event`.
    pub survival_event: String,
    /// Survival spec string, persisted as `survivalspec`.
    pub survivalspec: String,
    /// Baseline configuration; its `target` is persisted by name as
    /// `survival_baseline_target`.
    pub baseline_cfg: SurvivalBaselineConfig,
    pub time_basis: SavedSurvivalTimeBasis,
    pub ridge_lambda: f64,
    /// Survival-likelihood label stored on the `Survival` family.
    pub survival_likelihood_label: String,
    pub resolved_marginalspec: TermCollectionSpec,
    pub resolved_logslopespec: TermCollectionSpec,
    pub logslope_formula: String,
    pub z_column: String,
    pub latent_z_normalization: SavedLatentZNormalization,
    pub baseline_logslope: f64,
    /// Optional score-warp deviation runtime (borrowed for the assembly).
    pub score_warp_runtime: Option<&'a DeviationRuntime>,
    /// Optional link-deviation runtime (borrowed for the assembly).
    pub link_dev_runtime: Option<&'a DeviationRuntime>,
    /// Width `p₁` of the absorbed Stage-1 influence block (#461) when the fit
    /// hosted a dedicated additive absorber. Predict drops the absorber's `γ`;
    /// this is persisted only so the predictor accounts for the extra trailing
    /// block in the saved block count.
    pub influence_absorber_width: Option<usize>,
}
717
718/// Construct a Royston-Parmar survival [`FittedModelPayload`] through the
719/// canonical `Survival` family scaffold shared by every RP on-disk contract
720/// (marginal-slope, transformation, location-scale): the identity-link
721/// `RoystonParmar` likelihood, the persisted likelihood label, and the
722/// `fit_result` / `data_schema` install. Callers supply the two variants that
723/// differ — `survival_distribution` and `frailty` — and then set their own
724/// family-specific fields on the returned payload.
725fn new_royston_parmar_survival_payload(
726    formula: String,
727    fit_result: UnifiedFitResult,
728    data_schema: DataSchema,
729    survival_likelihood_label: &str,
730    survival_distribution: Option<ResidualDistribution>,
731    frailty: crate::survival::lognormal_kernel::FrailtySpec,
732) -> FittedModelPayload {
733    let mut payload = FittedModelPayload::new(
734        MODEL_PAYLOAD_VERSION,
735        formula,
736        ModelKind::Survival,
737        FittedFamily::Survival {
738            likelihood: LikelihoodSpec::new(
739                ResponseFamily::RoystonParmar,
740                InverseLink::Standard(StandardLink::Identity),
741            ),
742            survival_likelihood: Some(survival_likelihood_label.to_string()),
743            survival_distribution,
744            frailty,
745        },
746        ResponseFamily::RoystonParmar.name().to_string(),
747    );
748    payload.unified = Some(fit_result.clone());
749    payload.fit_result = Some(fit_result);
750    payload.data_schema = Some(data_schema);
751    payload
752}
753
754/// Assemble the canonical survival marginal-slope payload — single source of
755/// truth for that Royston-Parmar / Gaussian-residual on-disk contract.
756pub fn assemble_survival_marginal_slope_payload(
757    inputs: SurvivalMarginalSlopeInputs<'_>,
758    source: SavedModelSourceMetadata,
759) -> FittedModelPayload {
760    let mut payload = new_royston_parmar_survival_payload(
761        inputs.formula,
762        inputs.fit_result,
763        inputs.data_schema,
764        &inputs.survival_likelihood_label,
765        Some(ResidualDistribution::Gaussian),
766        inputs.frailty,
767    );
768    payload.survival_entry = inputs.survival_entry;
769    payload.survival_exit = Some(inputs.survival_exit);
770    payload.survival_event = Some(inputs.survival_event);
771    payload.survivalspec = Some(inputs.survivalspec);
772    payload.survival_baseline_target =
773        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
774    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
775    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
776    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
777    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
778    payload.apply_survival_time_basis(&inputs.time_basis);
779    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
780    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
781    payload.survival_distribution = Some(ResidualDistribution::Gaussian);
782    payload.link = Some(InverseLink::Standard(StandardLink::Probit));
783    payload.resolved_termspec = Some(inputs.resolved_marginalspec);
784    payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
785    payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
786    payload.formula_logslope = Some(inputs.logslope_formula.clone());
787    payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
788    payload.z_column = Some(inputs.z_column.clone());
789    payload.z_columns = Some(vec![inputs.z_column]);
790    payload.latent_z_normalization = Some(inputs.latent_z_normalization);
791    payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
792    payload.logslope_baseline = Some(inputs.baseline_logslope);
793    payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
794    payload.score_warp_runtime = inputs
795        .score_warp_runtime
796        .map(serialize_anchored_deviation_runtime);
797    payload.link_deviation_runtime = inputs
798        .link_dev_runtime
799        .map(serialize_anchored_deviation_runtime);
800    payload.influence_absorber_width = inputs.influence_absorber_width;
801    source.apply_to(&mut payload);
802    payload
803}
804
/// Fitted baseline-timewiggle coefficients: a single block (net) or one per
/// cause (joint cause-specific). Callers pass already-sliced coefficients.
///
/// Derives the common value traits so snapshots can be inspected, duplicated,
/// and compared (e.g. in save→load round-trip checks).
#[derive(Debug, Clone, PartialEq)]
pub enum SurvivalTimewiggleBeta {
    /// One coefficient vector for the whole (net) timewiggle block.
    Single(Vec<f64>),
    /// One coefficient vector per cause, in cause order.
    ByCause(Vec<Vec<f64>>),
}
811
/// Snapshot of the baseline-timewiggle block persisted with a survival model.
///
/// The survival assemblers copy `degree`, `knots`, `penalty_orders`, and
/// `double_penalty` into the payload's `baseline_timewiggle_*` fields.
pub struct SurvivalTimewiggle {
    /// Persisted as `baseline_timewiggle_degree`.
    pub degree: usize,
    /// Persisted as `baseline_timewiggle_knots`.
    pub knots: Vec<f64>,
    /// Persisted as `baseline_timewiggle_penalty_orders` (absent when `None`).
    pub penalty_orders: Option<Vec<usize>>,
    /// Persisted as `baseline_timewiggle_double_penalty` (absent when `None`).
    pub double_penalty: Option<bool>,
    /// Fitted coefficients: a single block or one vector per cause.
    pub beta: SurvivalTimewiggleBeta,
}
820
/// Source-agnostic semantic content of a survival transformation
/// (Royston-Parmar) saved model — net single-cause or joint cause-specific.
pub struct SurvivalTransformationInputs {
    /// Model formula; becomes the payload's top-level formula.
    pub formula: String,
    /// Data schema, installed as `payload.data_schema`.
    pub data_schema: DataSchema,
    /// Fit result, stored as both `payload.unified` and `payload.fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Persisted verbatim as `survival_entry` (absent when `None`).
    pub survival_entry: Option<String>,
    /// Persisted as `survival_exit`.
    pub survival_exit: String,
    /// Persisted as `survival_event`.
    pub survival_event: String,
    /// Persisted as `survivalspec`.
    pub survivalspec: String,
    /// `None` = net single-cause; `Some(n)` persists `survival_cause_count` and
    /// `cause_1..cause_n` endpoint names.
    pub cause_count: Option<usize>,
    /// Baseline target / scale / shape / rate / makeham, copied into the
    /// `survival_baseline_*` payload fields.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis, applied through `FittedModelPayload::apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Likelihood label, recorded both inside the fitted family and as the
    /// top-level `survival_likelihood`.
    pub survival_likelihood_label: String,
    /// Persisted as `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
    /// Rigid time-block beta, persisted only by the cause-specific CLI path.
    pub survival_beta_time: Option<Vec<f64>>,
    /// Optional baseline-timewiggle snapshot; its metadata and coefficients are
    /// persisted when present.
    pub timewiggle: Option<SurvivalTimewiggle>,
}
843
844/// Assemble the canonical survival transformation payload — single source of
845/// truth for the Royston-Parmar transformation on-disk contract.
846pub fn assemble_survival_transformation_payload(
847    inputs: SurvivalTransformationInputs,
848    source: SavedModelSourceMetadata,
849) -> FittedModelPayload {
850    let mut payload = new_royston_parmar_survival_payload(
851        inputs.formula,
852        inputs.fit_result,
853        inputs.data_schema,
854        &inputs.survival_likelihood_label,
855        None,
856        crate::survival::lognormal_kernel::FrailtySpec::None,
857    );
858    payload.survival_entry = inputs.survival_entry;
859    payload.survival_exit = Some(inputs.survival_exit);
860    payload.survival_event = Some(inputs.survival_event);
861    payload.survivalspec = Some(inputs.survivalspec);
862    if let Some(cause_count) = inputs.cause_count {
863        payload.survival_cause_count = Some(cause_count);
864        payload.survival_endpoint_names = Some(
865            (1..=cause_count)
866                .map(|idx| format!("cause_{idx}"))
867                .collect(),
868        );
869    }
870    payload.survival_baseline_target =
871        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
872    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
873    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
874    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
875    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
876    payload.apply_survival_time_basis(&inputs.time_basis);
877    if let Some(timewiggle) = inputs.timewiggle {
878        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
879        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
880        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
881        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
882        match timewiggle.beta {
883            SurvivalTimewiggleBeta::Single(beta) => {
884                payload.beta_baseline_timewiggle = Some(beta);
885            }
886            SurvivalTimewiggleBeta::ByCause(by_cause) => {
887                payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
888            }
889        }
890    }
891    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
892    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
893    payload.survival_beta_time = inputs.survival_beta_time;
894    payload.resolved_termspec = Some(inputs.resolved_termspec);
895    source.apply_to(&mut payload);
896    payload
897}
898
/// Source-agnostic semantic content of a survival location-scale
/// (Royston-Parmar with a learned residual link) saved model. Centralizing
/// fixes the drift where CLI and FFI disagreed on `formula_noise`,
/// `baseline_timewiggle_*`, and `survival_noise_projection_ridge_alpha`.
pub struct SurvivalLocationScaleInputs<'a> {
    /// Model formula; becomes the payload's top-level formula.
    pub formula: String,
    /// Data schema, installed as `payload.data_schema`.
    pub data_schema: DataSchema,
    /// Fit result with the fitted inverse-link state and link-wiggle artifacts
    /// already applied by the caller.
    pub fit_result: UnifiedFitResult,
    /// Persisted as `payload.link`. It also determines the residual
    /// distribution, via `residual_distribution_from_inverse_link`.
    pub fitted_inverse_link: InverseLink,
    // Independent `Option`s (not an all-or-nothing group) so the assembler
    // reproduces exactly what the CLI and FFI each persist independently.
    /// Persisted as `linkwiggle_degree`.
    pub linkwiggle_degree: Option<usize>,
    /// Persisted as `linkwiggle_knots`.
    pub linkwiggle_knots: Option<Vec<f64>>,
    /// Persisted as `beta_link_wiggle`.
    pub beta_link_wiggle: Option<Vec<f64>>,
    /// Optional baseline-timewiggle snapshot. Its degree, knots, penalty orders
    /// and double-penalty flag are persisted when present.
    pub baseline_timewiggle: Option<SurvivalTimewiggle>,
    /// Persisted verbatim as `survival_entry` (absent when `None`).
    pub survival_entry: Option<String>,
    /// Persisted as `survival_exit`.
    pub survival_exit: String,
    /// Persisted as `survival_event`.
    pub survival_event: String,
    /// Persisted as `survivalspec`.
    pub survivalspec: String,
    /// Baseline target / scale / shape / rate / makeham, copied into the
    /// `survival_baseline_*` payload fields.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis, applied through `FittedModelPayload::apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Likelihood label, recorded both inside the fitted family and as the
    /// top-level `survival_likelihood`.
    pub survival_likelihood_label: String,
    /// Persisted as-is as `formula_noise`.
    pub formula_noise: Option<String>,
    /// Persisted as `survival_beta_time`.
    pub survival_beta_time: Vec<f64>,
    /// Persisted as `survival_beta_threshold`.
    pub survival_beta_threshold: Vec<f64>,
    /// Persisted as `survival_beta_log_sigma`.
    pub survival_beta_log_sigma: Vec<f64>,
    /// Noise transform. Its projection rows, weighted column means, rescale
    /// factors, non-intercept start and projection ridge alpha are copied into
    /// the `survival_noise_*` fields.
    pub noise_transform: &'a ScaleDeviationTransform,
    /// Threshold term spec, persisted as `resolved_termspec`.
    pub resolved_thresholdspec: TermCollectionSpec,
    /// Log-sigma term spec, persisted as `resolved_termspec_noise`.
    pub resolved_log_sigmaspec: TermCollectionSpec,
}
932
933/// Assemble the canonical survival location-scale payload (the single source of
934/// truth for that on-disk contract).
935pub fn assemble_survival_location_scale_payload(
936    inputs: SurvivalLocationScaleInputs<'_>,
937    source: SavedModelSourceMetadata,
938) -> FittedModelPayload {
939    let survival_distribution =
940        residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
941    let mut payload = new_royston_parmar_survival_payload(
942        inputs.formula,
943        inputs.fit_result,
944        inputs.data_schema,
945        &inputs.survival_likelihood_label,
946        survival_distribution,
947        crate::survival::lognormal_kernel::FrailtySpec::None,
948    );
949    payload.link = Some(inputs.fitted_inverse_link);
950    payload.linkwiggle_degree = inputs.linkwiggle_degree;
951    payload.linkwiggle_knots = inputs.linkwiggle_knots;
952    payload.beta_link_wiggle = inputs.beta_link_wiggle;
953    if let Some(timewiggle) = inputs.baseline_timewiggle {
954        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
955        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
956        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
957        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
958        if let SurvivalTimewiggleBeta::Single(beta) = timewiggle.beta {
959            payload.beta_baseline_timewiggle = Some(beta);
960        }
961    }
962    payload.survival_entry = inputs.survival_entry;
963    payload.survival_exit = Some(inputs.survival_exit);
964    payload.survival_event = Some(inputs.survival_event);
965    payload.survivalspec = Some(inputs.survivalspec);
966    payload.survival_baseline_target =
967        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
968    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
969    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
970    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
971    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
972    payload.apply_survival_time_basis(&inputs.time_basis);
973    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
974    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
975    payload.formula_noise = inputs.formula_noise;
976    payload.survival_beta_time = Some(inputs.survival_beta_time);
977    payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
978    payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
979    payload.survival_noise_projection = Some(
980        inputs
981            .noise_transform
982            .projection_coef
983            .rows()
984            .into_iter()
985            .map(|row| row.to_vec())
986            .collect(),
987    );
988    payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
989    payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
990    payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
991    payload.survival_noise_projection_ridge_alpha =
992        Some(inputs.noise_transform.projection_ridge_alpha);
993    payload.survival_distribution = survival_distribution;
994    payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
995    payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
996    source.apply_to(&mut payload);
997    payload
998}
999
/// Source-agnostic semantic content of a latent survival / latent binary saved
/// model. The caller resolves the family (splicing the learned latent SD into
/// the persisted frailty for survival) and the model-class / likelihood labels.
pub struct LatentWindowInputs {
    /// Model formula; becomes the payload's top-level formula.
    pub formula: String,
    /// Data schema, installed as `payload.data_schema`.
    pub data_schema: DataSchema,
    /// Fit result, stored as both `payload.unified` and `payload.fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Fully resolved fitted family, passed straight to `FittedModelPayload::new`.
    pub family: FittedFamily,
    /// Model-class label, passed as the final argument of `FittedModelPayload::new`.
    pub model_class_label: String,
    /// Persisted as `survival_likelihood`.
    pub likelihood_label: String,
    /// Persisted verbatim as `survival_entry` (absent when `None`).
    pub survival_entry: Option<String>,
    /// Persisted as `survival_exit`.
    pub survival_exit: String,
    /// Persisted as `survival_event`.
    pub survival_event: String,
    /// Baseline target / scale / shape / rate / makeham, copied into the
    /// `survival_baseline_*` payload fields.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis, applied through `FittedModelPayload::apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Persisted as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Persisted as `survival_beta_time`.
    pub beta_time: Vec<f64>,
    /// Persisted as `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
}
1019
1020/// Assemble the canonical latent survival / latent binary payload.
1021pub fn assemble_latent_window_payload(
1022    inputs: LatentWindowInputs,
1023    source: SavedModelSourceMetadata,
1024) -> FittedModelPayload {
1025    let mut payload = FittedModelPayload::new(
1026        MODEL_PAYLOAD_VERSION,
1027        inputs.formula,
1028        ModelKind::Survival,
1029        inputs.family,
1030        inputs.model_class_label,
1031    );
1032    payload.unified = Some(inputs.fit_result.clone());
1033    payload.fit_result = Some(inputs.fit_result);
1034    payload.data_schema = Some(inputs.data_schema);
1035    payload.survival_entry = inputs.survival_entry;
1036    payload.survival_exit = Some(inputs.survival_exit);
1037    payload.survival_event = Some(inputs.survival_event);
1038    payload.survivalspec = Some("net".to_string());
1039    payload.survival_baseline_target =
1040        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
1041    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
1042    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
1043    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
1044    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
1045    payload.apply_survival_time_basis(&inputs.time_basis);
1046    payload.survival_likelihood = Some(inputs.likelihood_label);
1047    payload.survival_beta_time = Some(inputs.beta_time);
1048    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
1049    payload.resolved_termspec = Some(inputs.resolved_termspec);
1050    source.apply_to(&mut payload);
1051    payload
1052}