Skip to main content

gam_models/
protocol.rs

1use crate::bms::{
2    DEFAULT_EMPIRICAL_LATENT_GRID_SIZE, DeviationBlockConfig, LatentMeasureSpec, LatentZCheckMode,
3    LatentZNormalizationMode, LatentZPolicy,
4};
5use crate::survival::construction::SurvivalBaselineTarget;
6use gam_problem::{InverseLink, StandardLink};
7
8/// Calibration semantics for the latent score `z` consumed by marginal-slope
9/// families. Every variant is fully effective — there are no silently-ignored
10/// metadata fields.
11#[derive(Clone, Debug)]
12pub enum LatentScoreSemantics {
13    /// z is already on a frozen latent scale and the calibration law is
14    /// assumed (approximately) standard normal. `check_mode` controls whether
15    /// the fit aborts (`Strict`), only warns (`WarnOnly`), or skips the
16    /// normality diagnostics entirely (`Off`).
17    FrozenConditionalNormal { check_mode: LatentZCheckMode },
18    /// z will be centered/scaled inside the fit.
19    FitWeightedNormalization,
20    /// z is carried by its observed empirical latent measure instead of
21    /// pretending the downstream calibration law is standard normal.
22    EmpiricalLatentMeasure { normalize_location_scale: bool },
23}
24
25impl LatentScoreSemantics {
26    pub fn into_policy(self) -> LatentZPolicy {
27        match self {
28            Self::FrozenConditionalNormal { check_mode } => LatentZPolicy {
29                check_mode,
30                ..LatentZPolicy::frozen_transformation_normal()
31            },
32            Self::FitWeightedNormalization => LatentZPolicy::exploratory_fit_weighted(),
33            Self::EmpiricalLatentMeasure {
34                normalize_location_scale,
35            } => LatentZPolicy {
36                normalization: if normalize_location_scale {
37                    LatentZNormalizationMode::FitWeighted
38                } else {
39                    LatentZNormalizationMode::None
40                },
41                latent_measure: LatentMeasureSpec::GlobalEmpirical {
42                    grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
43                },
44                ..LatentZPolicy::exploratory_fit_weighted()
45            },
46        }
47    }
48}
49
50#[derive(Clone, Debug)]
51pub struct MarginalSlopeCalibrationProtocol {
52    pub base_link: InverseLink,
53    /// Optional cubic score-warp block. `None` selects the rigid
54    /// (algebraic closed-form) path for the score-warp axis.
55    pub score_warp: Option<DeviationBlockConfig>,
56    /// Optional cubic link-deviation block. `None` selects the rigid
57    /// (algebraic closed-form) path for the link-deviation axis.
58    pub link_deviation: Option<DeviationBlockConfig>,
59    pub latent_score: LatentScoreSemantics,
60}
61
62impl MarginalSlopeCalibrationProtocol {
63    fn default_latent_score() -> LatentScoreSemantics {
64        // WarnOnly mirrors `LatentZPolicy::frozen_transformation_normal`'s
65        // own default: at large-scale dimensionality the upstream conditional
66        // transformation-normal preprocessor can leave the global latent z
67        // mildly heavy-tailed without violating per-strata calibration.
68        LatentScoreSemantics::FrozenConditionalNormal {
69            check_mode: LatentZCheckMode::WarnOnly,
70        }
71    }
72
73    /// Construct a probit-link marginal-slope protocol with caller-supplied
74    /// optional score-warp / link-deviation blocks and explicit latent-score
75    /// semantics. Pass `None` for either block to select the rigid algebraic
76    /// closed-form path on that axis.
77    pub fn probit(
78        score_warp: Option<DeviationBlockConfig>,
79        link_deviation: Option<DeviationBlockConfig>,
80        latent_score: LatentScoreSemantics,
81    ) -> Self {
82        Self {
83            base_link: InverseLink::Standard(StandardLink::Probit),
84            score_warp,
85            link_deviation,
86            latent_score,
87        }
88    }
89
90    /// Rigid probit marginal-slope: no score-warp, no link-deviation.
91    pub fn probit_rigid() -> Self {
92        Self::probit(None, None, Self::default_latent_score())
93    }
94
95    /// Probit marginal-slope with both cubic blocks at their triple-penalty
96    /// defaults.
97    pub fn probit_with_score_and_link_wiggle() -> Self {
98        let wiggle = DeviationBlockConfig::triple_penalty_default();
99        Self::probit(
100            Some(wiggle.clone()),
101            Some(wiggle),
102            Self::default_latent_score(),
103        )
104    }
105}
106
107#[derive(Clone, Debug)]
108pub struct SurvivalMarginalSlopeProtocol {
109    pub marginal: MarginalSlopeCalibrationProtocol,
110    pub baseline_target: SurvivalBaselineTarget,
111}
112
113impl SurvivalMarginalSlopeProtocol {
114    /// Survival marginal-slope on a Gompertz-Makeham baseline with the
115    /// supplied marginal-calibration protocol. Score-warp, link-deviation,
116    /// and latent-score semantics all flow through from `marginal` —
117    /// nothing is baked in.
118    pub fn gompertz_makeham_probit(marginal: MarginalSlopeCalibrationProtocol) -> Self {
119        Self {
120            marginal,
121            baseline_target: SurvivalBaselineTarget::GompertzMakeham,
122        }
123    }
124}