gam_models/fit_orchestration/materialize/
survival_time.rs1use super::*;
2
/// Time-axis inputs for a survival fit, as assembled by [`prepare_survival_time_stack`].
///
/// Each design matrix starts as a copy of the corresponding design in the
/// `SurvivalTimeBuildOutput`. When a time-wiggle is built, `wiggle.ncols` tail
/// columns are appended to all three designs.
pub struct PreparedSurvivalTimeStack {
    /// Entry-time eta offset, after the derivative-guard offset has been added.
    pub eta_offset_entry: Array1<f64>,
    /// Exit-time eta offset, after the derivative-guard offset has been added.
    pub eta_offset_exit: Array1<f64>,
    /// Exit-time derivative offset, after the derivative-guard offset has been added.
    pub derivative_offset_exit: Array1<f64>,
    /// Unloaded baseline mass at entry; all zeros when no latent hazard loading is supplied.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded baseline mass at exit; all zeros when no latent hazard loading is supplied.
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded baseline hazard at exit; all zeros when no latent hazard loading is supplied.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Entry-time design (base time columns, then any wiggle tail columns).
    pub time_design_entry: gam_linalg::matrix::DesignMatrix,
    /// Exit-time design (base time columns, then any wiggle tail columns).
    pub time_design_exit: gam_linalg::matrix::DesignMatrix,
    /// Exit-time derivative design (base time columns, then any wiggle tail columns).
    pub time_design_derivative_exit: gam_linalg::matrix::DesignMatrix,
    /// Base time penalties, followed by one square matrix per wiggle penalty.
    /// Each of those matrices is sized to the widened design and holds the wiggle
    /// penalty in its bottom-right block.
    pub time_penalties: Vec<Array2<f64>>,
    /// Null-space dimension for each entry of `time_penalties`. Wiggle entries fall
    /// back to 0 when the wiggle build supplies fewer dimensions than penalties.
    pub time_nullspace_dims: Vec<usize>,
    /// The time-wiggle build; `Some` only when a wiggle spec was supplied.
    pub timewiggle_build: Option<crate::survival::construction::SurvivalTimeWiggleBuild>,
    /// Knots, degree, and column count copied from `timewiggle_build`.
    pub timewiggle_block: Option<TimeWiggleBlockInput>,
}
18
19pub fn prepare_survival_time_stack(
20 age_entry: &Array1<f64>,
21 age_exit: &Array1<f64>,
22 baseline_cfg: &crate::survival::construction::SurvivalBaselineConfig,
23 likelihood_mode: SurvivalLikelihoodMode,
24 inverse_link: Option<&InverseLink>,
25 time_anchor: f64,
26 derivative_guard: f64,
27 time_build: &crate::survival::construction::SurvivalTimeBuildOutput,
28 effective_timewiggle: Option<&LinkWiggleFormulaSpec>,
29 latent_loading: Option<crate::survival::lognormal_kernel::HazardLoading>,
30) -> Result<PreparedSurvivalTimeStack, String> {
31 let (
32 mut eta_offset_entry,
33 mut eta_offset_exit,
34 mut derivative_offset_exit,
35 unloaded_mass_entry,
36 unloaded_mass_exit,
37 unloaded_hazard_exit,
38 ) = if let Some(loading) = latent_loading {
39 let offsets =
40 build_latent_survival_baseline_offsets(age_entry, age_exit, baseline_cfg, loading)?;
41 (
42 offsets.loaded_eta_entry,
43 offsets.loaded_eta_exit,
44 offsets.loaded_derivative_exit,
45 offsets.unloaded_mass_entry,
46 offsets.unloaded_mass_exit,
47 offsets.unloaded_hazard_exit,
48 )
49 } else {
50 let conditioning_cfg;
78 let offset_cfg = if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
79 && baseline_cfg.target == SurvivalBaselineTarget::Linear
80 {
81 let scale =
82 crate::survival::construction::positive_survival_time_seed(age_exit);
83 conditioning_cfg = crate::survival::construction::SurvivalBaselineConfig {
84 target: SurvivalBaselineTarget::Weibull,
85 scale: Some(scale),
86 shape: Some(1.0),
87 rate: None,
88 makeham: None,
89 };
90 &conditioning_cfg
91 } else {
92 baseline_cfg
93 };
94 let (eta_offset_entry, eta_offset_exit, derivative_offset_exit) =
95 build_survival_time_offsets_for_likelihood(
96 age_entry,
97 age_exit,
98 offset_cfg,
99 likelihood_mode,
100 inverse_link,
101 )?;
102 let n = age_entry.len();
103 (
104 eta_offset_entry,
105 eta_offset_exit,
106 derivative_offset_exit,
107 Array1::zeros(n),
108 Array1::zeros(n),
109 Array1::zeros(n),
110 )
111 };
112 add_survival_time_derivative_guard_offset(
113 age_entry,
114 age_exit,
115 time_anchor,
116 derivative_guard,
117 &mut eta_offset_entry,
118 &mut eta_offset_exit,
119 &mut derivative_offset_exit,
120 )?;
121 let timewiggle_build = if let Some(cfg) = effective_timewiggle {
122 Some(build_survival_timewiggle_from_baseline(
123 &eta_offset_entry,
124 &eta_offset_exit,
125 &derivative_offset_exit,
126 cfg,
127 )?)
128 } else {
129 None
130 };
131 let mut time_design_entry = time_build.x_entry_time.clone();
132 let mut time_design_exit = time_build.x_exit_time.clone();
133 let mut time_design_derivative_exit = time_build.x_derivative_time.clone();
134 let mut time_penalties = time_build.penalties.clone();
135 let mut time_nullspace_dims = time_build.nullspace_dims.clone();
136 let mut timewiggle_block = None;
137 if let Some(wiggle) = timewiggle_build.as_ref() {
138 let p_base = time_design_exit.ncols();
139 append_zero_tail_columns(
140 &mut time_design_entry,
141 &mut time_design_exit,
142 &mut time_design_derivative_exit,
143 wiggle.ncols,
144 );
145 for (idx, penalty) in wiggle.penalties.iter().enumerate() {
146 let mut embedded = Array2::<f64>::zeros((p_base + wiggle.ncols, p_base + wiggle.ncols));
147 embedded
148 .slice_mut(s![
149 p_base..p_base + wiggle.ncols,
150 p_base..p_base + wiggle.ncols
151 ])
152 .assign(penalty);
153 time_penalties.push(embedded);
154 time_nullspace_dims.push(wiggle.nullspace_dims.get(idx).copied().unwrap_or(0));
155 }
156 timewiggle_block = Some(TimeWiggleBlockInput {
157 knots: wiggle.knots.clone(),
158 degree: wiggle.degree,
159 ncols: wiggle.ncols,
160 });
161 }
162 Ok(PreparedSurvivalTimeStack {
163 eta_offset_entry,
164 eta_offset_exit,
165 derivative_offset_exit,
166 unloaded_mass_entry,
167 unloaded_mass_exit,
168 unloaded_hazard_exit,
169 time_design_entry,
170 time_design_exit,
171 time_design_derivative_exit,
172 time_penalties,
173 time_nullspace_dims,
174 timewiggle_build,
175 timewiggle_block,
176 })
177}