1use crate::bms::{
2 DEFAULT_EMPIRICAL_LATENT_GRID_SIZE, DeviationBlockConfig, LatentMeasureSpec, LatentZCheckMode,
3 LatentZNormalizationMode, LatentZPolicy,
4};
5use crate::survival::construction::SurvivalBaselineTarget;
6use gam_problem::{InverseLink, StandardLink};
7
8#[derive(Clone, Debug)]
12pub enum LatentScoreSemantics {
13 FrozenConditionalNormal { check_mode: LatentZCheckMode },
18 FitWeightedNormalization,
20 EmpiricalLatentMeasure { normalize_location_scale: bool },
23}
24
25impl LatentScoreSemantics {
26 pub fn into_policy(self) -> LatentZPolicy {
27 match self {
28 Self::FrozenConditionalNormal { check_mode } => LatentZPolicy {
29 check_mode,
30 ..LatentZPolicy::frozen_transformation_normal()
31 },
32 Self::FitWeightedNormalization => LatentZPolicy::exploratory_fit_weighted(),
33 Self::EmpiricalLatentMeasure {
34 normalize_location_scale,
35 } => LatentZPolicy {
36 normalization: if normalize_location_scale {
37 LatentZNormalizationMode::FitWeighted
38 } else {
39 LatentZNormalizationMode::None
40 },
41 latent_measure: LatentMeasureSpec::GlobalEmpirical {
42 grid_size: DEFAULT_EMPIRICAL_LATENT_GRID_SIZE,
43 },
44 ..LatentZPolicy::exploratory_fit_weighted()
45 },
46 }
47 }
48}
49
50#[derive(Clone, Debug)]
51pub struct MarginalSlopeCalibrationProtocol {
52 pub base_link: InverseLink,
53 pub score_warp: Option<DeviationBlockConfig>,
56 pub link_deviation: Option<DeviationBlockConfig>,
59 pub latent_score: LatentScoreSemantics,
60}
61
62impl MarginalSlopeCalibrationProtocol {
63 fn default_latent_score() -> LatentScoreSemantics {
64 LatentScoreSemantics::FrozenConditionalNormal {
69 check_mode: LatentZCheckMode::WarnOnly,
70 }
71 }
72
73 pub fn probit(
78 score_warp: Option<DeviationBlockConfig>,
79 link_deviation: Option<DeviationBlockConfig>,
80 latent_score: LatentScoreSemantics,
81 ) -> Self {
82 Self {
83 base_link: InverseLink::Standard(StandardLink::Probit),
84 score_warp,
85 link_deviation,
86 latent_score,
87 }
88 }
89
90 pub fn probit_rigid() -> Self {
92 Self::probit(None, None, Self::default_latent_score())
93 }
94
95 pub fn probit_with_score_and_link_wiggle() -> Self {
98 let wiggle = DeviationBlockConfig::triple_penalty_default();
99 Self::probit(
100 Some(wiggle.clone()),
101 Some(wiggle),
102 Self::default_latent_score(),
103 )
104 }
105}
106
107#[derive(Clone, Debug)]
108pub struct SurvivalMarginalSlopeProtocol {
109 pub marginal: MarginalSlopeCalibrationProtocol,
110 pub baseline_target: SurvivalBaselineTarget,
111}
112
113impl SurvivalMarginalSlopeProtocol {
114 pub fn gompertz_makeham_probit(marginal: MarginalSlopeCalibrationProtocol) -> Self {
119 Self {
120 marginal,
121 baseline_target: SurvivalBaselineTarget::GompertzMakeham,
122 }
123 }
124}