// gam_models/fit_orchestration/materialize/survival_time.rs

1use super::*;
2
3pub struct PreparedSurvivalTimeStack {
4    pub eta_offset_entry: Array1<f64>,
5    pub eta_offset_exit: Array1<f64>,
6    pub derivative_offset_exit: Array1<f64>,
7    pub unloaded_mass_entry: Array1<f64>,
8    pub unloaded_mass_exit: Array1<f64>,
9    pub unloaded_hazard_exit: Array1<f64>,
10    pub time_design_entry: gam_linalg::matrix::DesignMatrix,
11    pub time_design_exit: gam_linalg::matrix::DesignMatrix,
12    pub time_design_derivative_exit: gam_linalg::matrix::DesignMatrix,
13    pub time_penalties: Vec<Array2<f64>>,
14    pub time_nullspace_dims: Vec<usize>,
15    pub timewiggle_build: Option<crate::survival::construction::SurvivalTimeWiggleBuild>,
16    pub timewiggle_block: Option<TimeWiggleBlockInput>,
17}
18
19pub fn prepare_survival_time_stack(
20    age_entry: &Array1<f64>,
21    age_exit: &Array1<f64>,
22    baseline_cfg: &crate::survival::construction::SurvivalBaselineConfig,
23    likelihood_mode: SurvivalLikelihoodMode,
24    inverse_link: Option<&InverseLink>,
25    time_anchor: f64,
26    derivative_guard: f64,
27    time_build: &crate::survival::construction::SurvivalTimeBuildOutput,
28    effective_timewiggle: Option<&LinkWiggleFormulaSpec>,
29    latent_loading: Option<crate::survival::lognormal_kernel::HazardLoading>,
30) -> Result<PreparedSurvivalTimeStack, String> {
31    let (
32        mut eta_offset_entry,
33        mut eta_offset_exit,
34        mut derivative_offset_exit,
35        unloaded_mass_entry,
36        unloaded_mass_exit,
37        unloaded_hazard_exit,
38    ) = if let Some(loading) = latent_loading {
39        let offsets =
40            build_latent_survival_baseline_offsets(age_entry, age_exit, baseline_cfg, loading)?;
41        (
42            offsets.loaded_eta_entry,
43            offsets.loaded_eta_exit,
44            offsets.loaded_derivative_exit,
45            offsets.unloaded_mass_entry,
46            offsets.unloaded_mass_exit,
47            offsets.unloaded_hazard_exit,
48        )
49    } else {
50        // Baseline-hazard barrier conditioning for the marginal-slope likelihood
51        // (gam#797). That likelihood carries `-d·log(qd1)`, a log-barrier on the
52        // baseline-hazard time derivative `qd1 = X_d·β_time + derivative_offset`.
53        // The default `baseline-target=linear` is DEGENERATE for this barrier:
54        // `evaluate_survival_baseline` returns `(0, 0)` for Linear, so the offset
55        // collapses to `derivative_guard` (1e-6) and the I-spline time seed starts
56        // at `qd1 ≈ 1e-6` — exactly ON the barrier boundary, where the
57        // self-concordant Newton step is `∝ qd1` (intrinsically ~1e-4), the
58        // barrier gradient/Hessian are ~1e6 / ~1e12, and the inner joint-Newton
59        // crawls and never reaches the data-scale baseline within the cycle
60        // budget — every outer seed is rejected and the fit hard-fails.
61        //
62        // Condition the COLD START by building the baseline OFFSET from a fixed,
63        // data-seeded Weibull (scale = mean positive exit time, shape = 1) instead
64        // of the zero-derivative Linear baseline, but ONLY for the offset: the
65        // outer `baseline_cfg.target` stays `Linear`, so the
66        // `baseline_cfg.target != Linear` optimize gate
67        // (the gradient baseline optimizers) never fires and no baseline-shape
68        // search is introduced. With shape = 1 the Weibull baseline-hazard
69        // derivative is `1/age_exit` (the natural data hazard scale), so the seed
70        // starts with `qd1` at O(1/T) interior — barrier gradient O(10-10²),
71        // comparable to the marginal/logslope blocks — and `β_time ≈ 0`. This
72        // changes only the STARTING point / offset split: the I-spline still learns
73        // the data-driven deviation from this parametric baseline (the converged
74        // fitted hazard is the same flexible family), so the fix is a pure
75        // preconditioning of the cold start. Gated to MarginalSlope with a Linear
76        // target so every other Linear-baseline survival path is byte-unchanged.
77        let conditioning_cfg;
78        let offset_cfg = if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
79            && baseline_cfg.target == SurvivalBaselineTarget::Linear
80        {
81            let scale =
82                crate::survival::construction::positive_survival_time_seed(age_exit);
83            conditioning_cfg = crate::survival::construction::SurvivalBaselineConfig {
84                target: SurvivalBaselineTarget::Weibull,
85                scale: Some(scale),
86                shape: Some(1.0),
87                rate: None,
88                makeham: None,
89            };
90            &conditioning_cfg
91        } else {
92            baseline_cfg
93        };
94        let (eta_offset_entry, eta_offset_exit, derivative_offset_exit) =
95            build_survival_time_offsets_for_likelihood(
96                age_entry,
97                age_exit,
98                offset_cfg,
99                likelihood_mode,
100                inverse_link,
101            )?;
102        let n = age_entry.len();
103        (
104            eta_offset_entry,
105            eta_offset_exit,
106            derivative_offset_exit,
107            Array1::zeros(n),
108            Array1::zeros(n),
109            Array1::zeros(n),
110        )
111    };
112    add_survival_time_derivative_guard_offset(
113        age_entry,
114        age_exit,
115        time_anchor,
116        derivative_guard,
117        &mut eta_offset_entry,
118        &mut eta_offset_exit,
119        &mut derivative_offset_exit,
120    )?;
121    let timewiggle_build = if let Some(cfg) = effective_timewiggle {
122        Some(build_survival_timewiggle_from_baseline(
123            &eta_offset_entry,
124            &eta_offset_exit,
125            &derivative_offset_exit,
126            cfg,
127        )?)
128    } else {
129        None
130    };
131    let mut time_design_entry = time_build.x_entry_time.clone();
132    let mut time_design_exit = time_build.x_exit_time.clone();
133    let mut time_design_derivative_exit = time_build.x_derivative_time.clone();
134    let mut time_penalties = time_build.penalties.clone();
135    let mut time_nullspace_dims = time_build.nullspace_dims.clone();
136    let mut timewiggle_block = None;
137    if let Some(wiggle) = timewiggle_build.as_ref() {
138        let p_base = time_design_exit.ncols();
139        append_zero_tail_columns(
140            &mut time_design_entry,
141            &mut time_design_exit,
142            &mut time_design_derivative_exit,
143            wiggle.ncols,
144        );
145        for (idx, penalty) in wiggle.penalties.iter().enumerate() {
146            let mut embedded = Array2::<f64>::zeros((p_base + wiggle.ncols, p_base + wiggle.ncols));
147            embedded
148                .slice_mut(s![
149                    p_base..p_base + wiggle.ncols,
150                    p_base..p_base + wiggle.ncols
151                ])
152                .assign(penalty);
153            time_penalties.push(embedded);
154            time_nullspace_dims.push(wiggle.nullspace_dims.get(idx).copied().unwrap_or(0));
155        }
156        timewiggle_block = Some(TimeWiggleBlockInput {
157            knots: wiggle.knots.clone(),
158            degree: wiggle.degree,
159            ncols: wiggle.ncols,
160        });
161    }
162    Ok(PreparedSurvivalTimeStack {
163        eta_offset_entry,
164        eta_offset_exit,
165        derivative_offset_exit,
166        unloaded_mass_entry,
167        unloaded_mass_exit,
168        unloaded_hazard_exit,
169        time_design_entry,
170        time_design_exit,
171        time_design_derivative_exit,
172        time_penalties,
173        time_nullspace_dims,
174        timewiggle_build,
175        timewiggle_block,
176    })
177}