1use std::collections::HashMap;
11
12use ndarray::{Array1, Array2, ArrayView2, s};
13
14use crate::scale_design::scale_transform_from_payload;
15use crate::survival::construction::{
16 SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalLikelihoodMode,
17 SurvivalTimeBuildOutput, add_survival_time_derivative_guard_offset, build_survival_time_basis,
18 build_survival_time_offsets_for_likelihood, build_survival_timewiggle_derivative_design,
19 center_survival_time_designs_at_anchor, evaluate_survival_time_basis_row,
20 normalize_survival_time_pair, parse_survival_likelihood_mode,
21 require_structural_survival_time_basis, resolved_survival_time_basis_config_from_build,
22 survival_derivative_guard_for_likelihood, survival_likelihood_modename,
23};
24use crate::survival::lognormal_kernel::FrailtySpec;
25use crate::survival::{
26 CompetingRisksCifResult, assemble_competing_risks_cif_from_endpoints,
27};
28use crate::wiggle::buildwiggle_block_input_from_knots;
29use crate::inference::model::{
30 FittedFamily, FittedModel as SavedModel, SavedBaselineTimeWiggleRuntime,
31 load_survival_time_basis_config_from_model, survival_baseline_config_from_model,
32};
33use crate::inference::predict_io::{BernoulliMarginalSlopePredictor, PredictInput};
34use gam_linalg::matrix::DesignMatrix;
35use crate::model_types::{BlockRole, FittedBlock, FittedLinkState, UnifiedFitResult};
36use crate::probability::signed_probit_logcdf_and_mills_ratio;
37use gam_solve::mixture_link::inverse_link_jet_for_inverse_link;
38use gam_terms::term_builder::resolve_role_col;
39use gam_terms::smooth::build_term_collection_design;
40use gam_terms::smooth::TermCollectionSpec;
41use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
42
43pub struct SurvivalTimeColumns {
51 pub entry_col: Option<usize>,
52 pub exit_col: usize,
53}
54
55impl SurvivalTimeColumns {
56 #[inline]
59 pub fn row_entry_time(&self, data: ArrayView2<'_, f64>, i: usize) -> f64 {
60 self.entry_col.map_or(0.0, |idx| data[[i, idx]])
61 }
62}
63
64pub fn resolve_saved_survival_time_columns(
68 model: &SavedModel,
69 col_map: &HashMap<String, usize>,
70) -> Result<SurvivalTimeColumns, String> {
71 let entry_col: Option<usize> = model
72 .survival_entry
73 .as_deref()
74 .map(|name| resolve_role_col(col_map, name, "entry"))
75 .transpose()?;
76 let exitname = model
77 .survival_exit
78 .as_ref()
79 .ok_or_else(|| "survival model missing exit column metadata".to_string())?;
80 let exit_col = resolve_role_col(col_map, exitname, "exit")?;
81 Ok(SurvivalTimeColumns {
82 entry_col,
83 exit_col,
84 })
85}
86
/// Lower bound for a probability before it is passed to a logarithm.
/// NOTE(review): meaning inferred from the name only; no use is visible in this
/// section of the file — confirm against its call sites.
const SURVIVAL_PROB_MIN_FOR_LOG: f64 = 1e-300;
93
/// Errors returned by the survival prediction entry points in this module.
#[derive(Debug, Clone)]
pub enum SurvivalPredictError {
    /// The request or its data is invalid. This is also the variant produced by the
    /// `From<String>` and `From<gam_data::DataError>` conversions, so plain string errors
    /// from helpers surface here.
    InvalidInput { reason: String },
    /// The saved model lacks metadata that prediction needs.
    MissingFitMetadata { reason: String },
    /// Pieces of the saved model disagree with each other (e.g. block count vs. cause count).
    IncompatibleSchema { reason: String },
    /// The requested combination of mode and options is not supported.
    UnsupportedConfiguration { reason: String },
    /// A numerical evaluation failed, or an internal invariant was violated.
    NumericalFailure { reason: String },
    /// An error raised while reading or interpreting the saved model payload.
    ModelPayload {
        /// Static description of the operation that failed; used as the message prefix.
        context: &'static str,
        /// Underlying model error, exposed through `Error::source`.
        source: crate::inference::model::FittedModelError,
    },
}
128
129impl std::fmt::Display for SurvivalPredictError {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 SurvivalPredictError::InvalidInput { reason }
133 | SurvivalPredictError::MissingFitMetadata { reason }
134 | SurvivalPredictError::IncompatibleSchema { reason }
135 | SurvivalPredictError::UnsupportedConfiguration { reason }
136 | SurvivalPredictError::NumericalFailure { reason } => f.write_str(reason),
137 SurvivalPredictError::ModelPayload { context, source } => {
138 write!(f, "{context}: {source}")
139 }
140 }
141 }
142}
143
144impl std::error::Error for SurvivalPredictError {
145 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
146 match self {
147 SurvivalPredictError::ModelPayload { source, .. } => Some(source),
148 SurvivalPredictError::InvalidInput { .. }
149 | SurvivalPredictError::MissingFitMetadata { .. }
150 | SurvivalPredictError::IncompatibleSchema { .. }
151 | SurvivalPredictError::UnsupportedConfiguration { .. }
152 | SurvivalPredictError::NumericalFailure { .. } => None,
153 }
154 }
155}
156
157impl From<SurvivalPredictError> for String {
158 fn from(err: SurvivalPredictError) -> String {
159 err.to_string()
160 }
161}
162
163impl From<String> for SurvivalPredictError {
164 fn from(reason: String) -> SurvivalPredictError {
170 SurvivalPredictError::InvalidInput { reason }
171 }
172}
173
174impl From<gam_data::DataError> for SurvivalPredictError {
175 fn from(err: gam_data::DataError) -> SurvivalPredictError {
179 SurvivalPredictError::InvalidInput {
180 reason: err.to_string(),
181 }
182 }
183}
184
/// Inputs shared by [`predict_survival`] and [`predict_competing_risks_survival`].
pub struct SurvivalPredictRequest<'a> {
    /// Saved survival model to predict from.
    pub model: &'a SavedModel,
    /// Prediction data, one row per subject. Columns are located through `col_map`.
    pub data: ArrayView2<'a, f64>,
    /// Column name -> column index in `data`.
    pub col_map: &'a HashMap<String, usize>,
    /// Optional training-time headers, forwarded to term-spec resolution and to the
    /// location-scale / marginal-slope helpers.
    pub training_headers: Option<&'a Vec<String>>,
    /// Per-row primary offset handed to the row evaluators. Its length must equal
    /// `data.nrows()`.
    pub primary_offset: &'a Array1<f64>,
    /// Per-row secondary ("noise") offset handed to the location-scale / marginal-slope
    /// helpers. Its length must equal `data.nrows()`.
    pub noise_offset: &'a Array1<f64>,
    /// Common evaluation times. `None` evaluates each row once at its own exit time.
    /// `Some` must be non-empty with finite, non-negative entries.
    pub time_grid: Option<&'a [f64]>,
    /// Request standard errors. Currently honoured only by the location-scale mode;
    /// other modes reject it.
    pub with_uncertainty: bool,
}
203
/// Output of [`predict_survival`].
pub struct SurvivalPredictResult {
    /// Evaluation times. This is the supplied `time_grid`, or, when no grid was given,
    /// each row's exit time (one entry per row).
    pub times: Vec<f64>,
    /// Hazard, shaped `(rows, evaluation columns)`. There is a single column when no grid
    /// was supplied.
    pub hazard: Array2<f64>,
    /// Survival probability, same shape as `hazard`. On the non-location-scale path this is
    /// `exp(-cumulative_hazard)` clamped to `[0, 1]`.
    pub survival: Array2<f64>,
    /// Cumulative hazard, same shape as `hazard`.
    pub cumulative_hazard: Array2<f64>,
    /// Per-row linear predictor evaluated at that row's exit time.
    pub linear_predictor: Array1<f64>,
    /// Likelihood mode recorded in the saved model.
    pub likelihood_mode: SurvivalLikelihoodMode,
    /// Standard errors of `survival`. Expected only from the location-scale path with
    /// `with_uncertainty`; the other paths always leave it `None`.
    pub survival_se: Option<Array2<f64>>,
    /// Standard errors of `linear_predictor`, under the same conditions as `survival_se`.
    pub eta_se: Option<Array1<f64>>,
}
221
222fn restricted_mean_survival_time_from_curve(
238 times: &[f64],
239 survival_row: ndarray::ArrayView1<'_, f64>,
240 tau: f64,
241) -> Option<f64> {
242 if times.is_empty() || !(tau > 0.0) || !tau.is_finite() {
243 return None;
244 }
245 if times.len() != survival_row.len() {
246 return None;
247 }
248
249 let mut prev_t = 0.0_f64;
251 let mut prev_s = 1.0_f64;
252 let mut area = 0.0_f64;
253
254 for (idx, &t) in times.iter().enumerate() {
255 if !t.is_finite() || t < prev_t {
256 return None;
257 }
258 let s = survival_row[idx];
259 if !s.is_finite() {
260 return None;
261 }
262 if t >= tau {
263 let span = t - prev_t;
265 let s_tau = if span > 0.0 {
266 let w = (tau - prev_t) / span;
267 prev_s + w * (s - prev_s)
268 } else {
269 prev_s
270 };
271 area += 0.5 * (prev_s + s_tau) * (tau - prev_t);
272 return Some(area);
273 }
274 area += 0.5 * (prev_s + s) * (t - prev_t);
275 prev_t = t;
276 prev_s = s;
277 }
278
279 area += prev_s * (tau - prev_t);
283 Some(area)
284}
285
286impl SurvivalPredictResult {
287 pub fn restricted_mean_survival_time(&self, tau: f64) -> Option<Array1<f64>> {
294 let n = self.survival.nrows();
295 let mut out = Array1::<f64>::zeros(n);
296 for i in 0..n {
297 let rmst =
298 restricted_mean_survival_time_from_curve(&self.times, self.survival.row(i), tau)?;
299 out[i] = rmst;
300 }
301 Some(out)
302 }
303}
304
305impl CompetingRisksPredictResult {
306 pub fn restricted_mean_overall_survival_time(&self, tau: f64) -> Option<Array1<f64>> {
312 let n = self.overall_survival.nrows();
313 let mut out = Array1::<f64>::zeros(n);
314 for i in 0..n {
315 let rmst = restricted_mean_survival_time_from_curve(
316 &self.times,
317 self.overall_survival.row(i),
318 tau,
319 )?;
320 out[i] = rmst;
321 }
322 Some(out)
323 }
324}
325
/// Harrell's concordance index between observed times and risk scores.
///
/// Pairs with distinct times:
/// - A pair is comparable unless the earlier subject's event indicator is below 0.5
///   (i.e. censored).
/// - A comparable pair scores 1 when the earlier subject has the strictly larger risk,
///   and 1/2 when the risks are equal.
///
/// Pairs with tied (or unordered) times count as comparable, with a 1/2 contribution,
/// only when both indicators exceed 0.5.
///
/// Returns `None` on a length mismatch or when no pair is comparable.
pub fn harrell_concordance(time: &[f64], event: &[f64], risk: &[f64]) -> Option<f64> {
    use std::cmp::Ordering;

    let n = time.len();
    if event.len() != n || risk.len() != n {
        return None;
    }

    let mut comparable = 0.0_f64;
    let mut concordant = 0.0_f64;
    for a in 0..n {
        for b in (a + 1)..n {
            let (early, late) = match time[a].partial_cmp(&time[b]) {
                Some(Ordering::Less) => (a, b),
                Some(Ordering::Greater) => (b, a),
                // Tied or NaN-involving times.
                Some(Ordering::Equal) | None => {
                    if event[a] > 0.5 && event[b] > 0.5 {
                        comparable += 1.0;
                        concordant += 0.5;
                    }
                    continue;
                }
            };
            // A censored earlier subject carries no ordering information.
            if event[early] < 0.5 {
                continue;
            }
            comparable += 1.0;
            concordant += match risk[early].partial_cmp(&risk[late]) {
                Some(Ordering::Greater) => 1.0,
                Some(Ordering::Equal) => 0.5,
                Some(Ordering::Less) | None => 0.0,
            };
        }
    }

    (comparable > 0.0).then(|| concordant / comparable)
}
377
/// Inverse-probability-of-censoring weighted Brier score at horizon `tau`.
///
/// Each row contributes as follows:
/// - **Event at or before `tau`:** `s_pred^2 / G(time)`.
/// - **Still under observation after `tau`:** `(1 - s_pred)^2 / G(tau)`.
/// - **Censored at or before `tau`:** nothing.
///
/// `G` is the censoring survival function `g_cens`. Rows whose weight would use a
/// non-positive or NaN `G` also contribute nothing. The sum is divided by the number of
/// usable rows, i.e. rows with finite time and event indicator and `time > 0`.
///
/// Returns `None` on a length mismatch or when no row is usable.
pub fn ipcw_brier_score(
    s_pred: &[f64],
    time: &[f64],
    event: &[f64],
    tau: f64,
    g_cens: impl Fn(f64) -> f64,
) -> Option<f64> {
    if s_pred.len() != time.len() || s_pred.len() != event.len() {
        return None;
    }

    let mut usable_rows = 0usize;
    let mut weighted_sum = 0.0_f64;
    for ((&s, &t), &e) in s_pred.iter().zip(time).zip(event) {
        let usable = t.is_finite() && e.is_finite() && t > 0.0;
        if !usable {
            continue;
        }
        usable_rows += 1;

        let (target, g) = if t <= tau && e > 0.5 {
            (0.0, g_cens(t))
        } else if t > tau {
            (1.0, g_cens(tau))
        } else {
            // Censored before the horizon: counted in the denominator only.
            continue;
        };
        if g > 0.0 {
            let weight = 1.0 / g;
            let resid = target - s;
            weighted_sum += weight * resid * resid;
        }
    }

    if usable_rows == 0 {
        None
    } else {
        Some(weighted_sum / usable_rows as f64)
    }
}
452
453pub fn integrated_ipcw_brier_score(
472 s_pred: ArrayView2<f64>,
473 time: &[f64],
474 event: &[f64],
475 grid: &[f64],
476 horizon: f64,
477 g_cens: impl Fn(f64) -> f64,
478) -> Option<f64> {
479 let m = grid.len();
480 if m < 2 || s_pred.ncols() != m || s_pred.nrows() != time.len() {
481 return None;
482 }
483 if grid.windows(2).any(|pair| !(pair[1] > pair[0])) {
484 return None;
485 }
486 let mut pts: Vec<(f64, f64)> = Vec::with_capacity(m);
488 for k in 0..m {
489 if grid[k] > horizon {
490 break;
491 }
492 let col = s_pred.column(k);
493 let col_slice: Vec<f64> = col.to_vec();
494 if let Some(bs) = ipcw_brier_score(&col_slice, time, event, grid[k], &g_cens) {
495 pts.push((grid[k], bs));
496 }
497 }
498 if pts.len() < 2 {
499 return None;
500 }
501 let span = pts[pts.len() - 1].0 - pts[0].0;
502 if !(span > 0.0) {
503 return None;
504 }
505 let mut integral = 0.0_f64;
506 for w in pts.windows(2) {
507 integral += 0.5 * (w[1].1 + w[0].1) * (w[1].0 - w[0].0);
508 }
509 Some(integral / span)
510}
511
/// Right-continuous Kaplan–Meier step function.
///
/// Holds one `(time, survival)` step per distinct time at which at least one event
/// occurred, in strictly increasing time order. Before the first step the estimate is 1.
#[derive(Clone, Debug, Default)]
pub struct KaplanMeier {
    // Strictly increasing event times, each paired with the survival estimate from that
    // time onward.
    steps: Vec<(f64, f64)>,
}

impl KaplanMeier {
    /// Fits the product-limit estimator to `(time, event)` pairs.
    ///
    /// Input handling:
    /// - Rows with a non-finite time or event indicator, or a time `<= 0`, are ignored.
    /// - `event > 0.5` marks an event; anything else is a censoring.
    /// - The slices are zipped, so extra elements in the longer one are ignored.
    ///
    /// Censorings tied with events at the same time are still counted at risk for those
    /// events.
    pub fn fit(time: &[f64], event: &[f64]) -> Self {
        let mut rows: Vec<(f64, bool)> = time
            .iter()
            .zip(event)
            .filter_map(|(&t, &e)| {
                (t.is_finite() && e.is_finite() && t > 0.0).then_some((t, e > 0.5))
            })
            .collect();
        // Stability is irrelevant: rows sharing a time are consumed as one group below.
        rows.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));

        let mut steps = Vec::new();
        let mut at_risk = rows.len() as f64;
        let mut survival = 1.0_f64;
        let mut i = 0usize;
        while i < rows.len() {
            let t = rows[i].0;
            let mut j = i;
            let mut deaths = 0usize;
            while j < rows.len() && rows[j].0 == t {
                deaths += usize::from(rows[j].1);
                j += 1;
            }
            if deaths > 0 && at_risk > 0.0 {
                survival *= ((at_risk - deaths as f64) / at_risk).max(0.0);
                steps.push((t, survival));
            }
            // Everyone observed at `t` (events and censorings) leaves the risk set.
            at_risk -= (j - i) as f64;
            i = j;
        }
        Self { steps }
    }

    /// Fits the Kaplan–Meier estimator of the censoring distribution ("reverse" KM).
    ///
    /// Events and censorings swap roles. Rows whose event indicator is non-finite are
    /// excluded, exactly as [`KaplanMeier::fit`] excludes them. They are not treated as
    /// censoring events.
    pub fn fit_censoring(time: &[f64], event: &[f64]) -> Self {
        let flipped: Vec<f64> = event
            .iter()
            .map(|&e| {
                if !e.is_finite() {
                    // Keep an invalid indicator invalid so `fit` drops the row; flipping it
                    // would silently turn it into a censoring event (`NaN > 0.5` is false).
                    f64::NAN
                } else if e > 0.5 {
                    0.0
                } else {
                    1.0
                }
            })
            .collect();
        Self::fit(time, &flipped)
    }

    /// Survival estimate at `t`.
    ///
    /// Returns the value of the last step whose time is `<= t`. Returns `1.0` before the
    /// first step, and also for a NaN `t`.
    pub fn at(&self, t: f64) -> f64 {
        // `steps` is strictly increasing in time, so `time <= t` holds on a prefix and a
        // binary search finds its end in O(log n).
        let idx = self.steps.partition_point(|&(time, _)| time <= t);
        idx.checked_sub(1).map_or(1.0, |k| self.steps[k].1)
    }
}
582
/// Output of [`predict_competing_risks_survival`]. Per-cause fields are `Vec`s indexed by
/// cause, in the same order as `endpoint_names`.
pub struct CompetingRisksPredictResult {
    /// Evaluation times: the supplied `time_grid`, or per-row exit times when none was given.
    pub times: Vec<f64>,
    /// Display name of each cause.
    pub endpoint_names: Vec<String>,
    /// Cause-specific hazard, one `(rows, evaluation columns)` matrix per cause.
    pub hazard: Vec<Array2<f64>>,
    /// Cause-specific survival `exp(-cumulative_hazard)` clamped to `[0, 1]`, one matrix
    /// per cause.
    pub survival: Vec<Array2<f64>>,
    /// Cause-specific cumulative hazard, one matrix per cause.
    pub cumulative_hazard: Vec<Array2<f64>>,
    /// Cumulative incidence per cause, as assembled by
    /// `assemble_competing_risks_cif_from_endpoints`.
    pub cif: Vec<Array2<f64>>,
    /// All-cause survival from the same assembly, `(rows, evaluation columns)`.
    pub overall_survival: Array2<f64>,
    /// Per-cause linear predictor at each row's exit time.
    pub linear_predictor: Vec<Array1<f64>>,
    /// Likelihood mode recorded in the saved model.
    pub likelihood_mode: SurvivalLikelihoodMode,
}
601
602pub fn predict_survival(
608 req: SurvivalPredictRequest<'_>,
609) -> Result<SurvivalPredictResult, SurvivalPredictError> {
610 let SurvivalPredictRequest {
611 model,
612 data,
613 col_map,
614 training_headers,
615 primary_offset,
616 noise_offset,
617 time_grid,
618 with_uncertainty,
619 } = req;
620
621 let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
630 let exit_col = time_cols.exit_col;
631
632 let termspec = resolve_termspec_for_prediction(
633 &model.resolved_termspec,
634 training_headers,
635 col_map,
636 "resolved_termspec",
637 )?;
638 let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
644 let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
645 let cov_design = build_term_collection_design(cov_input, &termspec)
646 .map_err(|e| format!("failed to build survival prediction design: {e}"))?;
647
648 let n = data.nrows();
649 if primary_offset.len() != n || noise_offset.len() != n {
650 return Err(SurvivalPredictError::InvalidInput {
651 reason: format!(
652 "survival prediction offset length mismatch: rows={n}, offset={}, noise_offset={}",
653 primary_offset.len(),
654 noise_offset.len()
655 ),
656 });
657 }
658
659 use rayon::iter::{IntoParallelIterator, ParallelIterator};
660 let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
661 .into_par_iter()
662 .map(|i| {
663 normalize_survival_time_pair(time_cols.row_entry_time(data, i), data[[i, exit_col]], i)
664 })
665 .collect();
666 let pairs = pairs?;
667 let mut age_entry = Array1::<f64>::zeros(n);
668 let mut age_exit = Array1::<f64>::zeros(n);
669 for (i, (t0, t1)) in pairs.into_iter().enumerate() {
670 age_entry[i] = t0;
671 age_exit[i] = t1;
672 }
673
674 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
675
676 if matches!(
680 saved_likelihood_mode,
681 SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
682 ) {
683 return Err(SurvivalPredictError::UnsupportedConfiguration {
684 reason: format!(
685 "survival prediction via predict_survival does not support likelihood_mode={} yet; \
686 latent window prediction lives in the CLI's run_predict_saved_latent_window_impl \
687 pipeline and has not yet been ported to the library. Use the CLI predict command.",
688 survival_likelihood_modename(saved_likelihood_mode)
689 ),
690 });
691 }
692 if saved_likelihood_mode == SurvivalLikelihoodMode::LocationScale {
695 return predict_survival_location_scale_batch(
696 model,
697 &age_entry,
698 &age_exit,
699 &cov_design,
700 primary_offset,
701 noise_offset,
702 training_headers,
703 col_map,
704 data,
705 time_grid,
706 with_uncertainty,
707 )
708 .map_err(SurvivalPredictError::from);
709 }
710 if with_uncertainty {
711 return Err(SurvivalPredictError::from(format!(
712 "predict_survival: with_uncertainty is currently supported only for the \
713 location-scale likelihood mode; got {}",
714 survival_likelihood_modename(saved_likelihood_mode)
715 )));
716 }
717
718 let time_cfg = load_survival_time_basis_config_from_model(model)?;
721 let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
722 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
723 &time_build.basisname,
724 time_build.degree,
725 time_build.knots.as_ref(),
726 time_build.keep_cols.as_ref(),
727 time_build.smooth_lambda,
728 )?;
729 let weibull_baseline_in_beta = saved_likelihood_mode == SurvivalLikelihoodMode::Weibull
746 && !model.has_baseline_time_wiggle();
747 let mut time_anchor: Option<f64> = None;
748 let mut time_anchor_row_cached: Option<Array1<f64>> = None;
749 if matches!(
750 saved_likelihood_mode,
751 SurvivalLikelihoodMode::LocationScale | SurvivalLikelihoodMode::MarginalSlope
752 ) || weibull_baseline_in_beta
753 {
754 let anchor = model
755 .survival_time_anchor
756 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
757 let time_anchor_row = evaluate_survival_time_basis_row(anchor, &resolved_time_cfg)?;
758 center_survival_time_designs_at_anchor(
759 &mut time_build.x_entry_time,
760 &mut time_build.x_exit_time,
761 &time_anchor_row,
762 )?;
763 time_anchor = Some(anchor);
764 time_anchor_row_cached = Some(time_anchor_row);
765 }
766 if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull && !model.has_baseline_time_wiggle()
767 {
768 require_structural_survival_time_basis(&time_build.basisname, "saved survival sampling")?;
769 }
770 let mut baseline_cfg = saved_survival_runtime_baseline_config(model)?;
771 if weibull_baseline_in_beta {
772 baseline_cfg = SurvivalBaselineConfig {
773 target: SurvivalBaselineTarget::Linear,
774 scale: None,
775 shape: None,
776 rate: None,
777 makeham: None,
778 };
779 }
780
781 let per_row_eval = time_grid.is_none();
784 let eval_times: Vec<f64> = match time_grid {
785 Some(grid) => {
786 if grid.is_empty() {
787 return Err(SurvivalPredictError::InvalidInput {
788 reason: "survival time_grid must contain at least one time".to_string(),
789 });
790 }
791 for (idx, &t) in grid.iter().enumerate() {
792 if !t.is_finite() || t < 0.0 {
793 return Err(SurvivalPredictError::InvalidInput {
794 reason: format!(
795 "survival time_grid requires finite non-negative times (index {idx})",
796 ),
797 });
798 }
799 }
800 grid.to_vec()
801 }
802 None => Vec::new(),
803 };
804
805 let t_cols = if per_row_eval { 1 } else { eval_times.len() };
806 let mut hazard = Array2::<f64>::zeros((n, t_cols));
807 let mut survival = Array2::<f64>::zeros((n, t_cols));
808 let mut cumulative_hazard = Array2::<f64>::zeros((n, t_cols));
809 let mut linear_predictor = Array1::<f64>::zeros(n);
810
811 let marginal_slope_ctx = if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
817 let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
821 build_survival_time_offsets_for_likelihood(
822 &age_entry,
823 &age_exit,
824 &baseline_cfg,
825 saved_likelihood_mode,
826 None,
827 )?;
828 add_survival_time_derivative_guard_offset(
829 &age_entry,
830 &age_exit,
831 time_anchor.ok_or_else(|| {
832 "saved survival marginal-slope model missing survival_time_anchor".to_string()
833 })?,
834 survival_derivative_guard_for_likelihood(saved_likelihood_mode),
835 &mut eta_offset_entry,
836 &mut eta_offset_exit,
837 &mut derivative_offset_exit,
838 )?;
839 Some(build_marginal_slope_predict_context(
840 model,
841 data,
842 col_map,
843 training_headers,
844 &cov_design.design,
845 primary_offset,
846 noise_offset,
847 &time_build,
848 &eta_offset_entry,
849 &eta_offset_exit,
850 &derivative_offset_exit,
851 )?)
852 } else {
853 None
854 };
855
856 struct SurvivalPredictionRow {
860 hazard: Vec<f64>,
861 survival: Vec<f64>,
862 cumulative_hazard: Vec<f64>,
863 linear_predictor: f64,
864 }
865
866 let row_results: Result<Vec<SurvivalPredictionRow>, SurvivalPredictError> = (0..n)
867 .into_par_iter()
868 .map(|i| {
869 let cov_row = if matches!(
870 saved_likelihood_mode,
871 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
872 ) {
873 Some(design_row_owned(
874 &cov_design.design,
875 i,
876 "survival predict covariate row",
877 )?)
878 } else {
879 None
880 };
881 let evaluate_at = |t_query: f64| -> Result<(f64, f64, f64), SurvivalPredictError> {
882 let t_entry = age_entry[i].min(t_query);
883 let single_entry = Array1::from_elem(1, t_entry);
884 let single_exit = Array1::from_elem(1, t_query);
885 let mut row_time =
886 build_survival_time_basis(&single_entry, &single_exit, time_cfg.clone(), None)?;
887 if let Some(anchor_row) = time_anchor_row_cached.as_ref() {
888 center_survival_time_designs_at_anchor(
889 &mut row_time.x_entry_time,
890 &mut row_time.x_exit_time,
891 anchor_row,
892 )?;
893 }
894 let (mut r_eta_entry, mut r_eta_exit, mut r_deriv_exit) =
895 build_survival_time_offsets_for_likelihood(
896 &single_entry,
897 &single_exit,
898 &baseline_cfg,
899 saved_likelihood_mode,
900 None,
901 )?;
902 if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
903 add_survival_time_derivative_guard_offset(
904 &single_entry,
905 &single_exit,
906 time_anchor.ok_or_else(|| {
907 "saved survival marginal-slope model missing survival_time_anchor"
908 .to_string()
909 })?,
910 survival_derivative_guard_for_likelihood(saved_likelihood_mode),
911 &mut r_eta_entry,
912 &mut r_eta_exit,
913 &mut r_deriv_exit,
914 )?;
915 }
916
917 match saved_likelihood_mode {
918 SurvivalLikelihoodMode::MarginalSlope => {
919 let ctx = marginal_slope_ctx.as_ref().ok_or_else(|| {
920 "internal error: marginal-slope context missing for marginal-slope mode"
921 .to_string()
922 })?;
923 evaluate_marginal_slope_row(
924 i,
925 ctx,
926 &row_time,
927 &r_eta_exit,
928 &r_deriv_exit,
929 primary_offset[i],
930 )
931 }
932 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
933 let cov_row = cov_row.as_ref().ok_or_else(|| {
934 "internal error: covariate row missing for Royston-Parmar prediction"
935 .to_string()
936 })?;
937 evaluate_rp_row(
938 model,
939 &row_time,
940 cov_row,
941 r_eta_exit[0],
942 r_deriv_exit[0],
943 primary_offset[i],
944 )
945 }
946 SurvivalLikelihoodMode::Latent
947 | SurvivalLikelihoodMode::LatentBinary
948 | SurvivalLikelihoodMode::LocationScale => {
949 Err(SurvivalPredictError::NumericalFailure {
950 reason: "unreachable: unsupported likelihood_mode filtered earlier"
951 .to_string(),
952 })
953 }
954 }
955 };
956
957 let mut row = SurvivalPredictionRow {
958 hazard: vec![0.0; t_cols],
959 survival: vec![0.0; t_cols],
960 cumulative_hazard: vec![0.0; t_cols],
961 linear_predictor: 0.0,
962 };
963 if per_row_eval {
964 let (eta_t, cum_t, haz_t) = evaluate_at(age_exit[i])?;
965 row.linear_predictor = eta_t;
966 row.hazard[0] = haz_t;
967 row.cumulative_hazard[0] = cum_t;
968 row.survival[0] = (-cum_t).exp().clamp(0.0, 1.0);
969 } else {
970 for (j, &t_query) in eval_times.iter().enumerate() {
971 if t_query <= 0.0 {
972 row.hazard[j] = 0.0;
973 row.cumulative_hazard[j] = 0.0;
974 row.survival[j] = 1.0;
975 } else {
976 let (_eta_t, cum_t, haz_t) = evaluate_at(t_query)?;
977 row.hazard[j] = haz_t;
978 row.cumulative_hazard[j] = cum_t;
979 row.survival[j] = (-cum_t).exp().clamp(0.0, 1.0);
980 }
981 }
982 let (eta_t, _, _) = evaluate_at(age_exit[i])?;
983 row.linear_predictor = eta_t;
984 }
985 Ok(row)
986 })
987 .collect();
988
989 for (i, row) in row_results?.into_iter().enumerate() {
990 linear_predictor[i] = row.linear_predictor;
991 for j in 0..t_cols {
992 hazard[[i, j]] = row.hazard[j];
993 cumulative_hazard[[i, j]] = row.cumulative_hazard[j];
994 survival[[i, j]] = row.survival[j];
995 }
996 }
997
998 let times_out: Vec<f64> = if per_row_eval {
999 age_exit.to_vec()
1000 } else {
1001 eval_times
1002 };
1003
1004 Ok(SurvivalPredictResult {
1005 times: times_out,
1006 hazard,
1007 survival,
1008 cumulative_hazard,
1009 linear_predictor,
1010 likelihood_mode: saved_likelihood_mode,
1011 survival_se: None,
1012 eta_se: None,
1013 })
1014}
1015
1016pub fn predict_competing_risks_survival(
1017 req: SurvivalPredictRequest<'_>,
1018) -> Result<CompetingRisksPredictResult, SurvivalPredictError> {
1019 let SurvivalPredictRequest {
1020 model,
1021 data,
1022 col_map,
1023 training_headers,
1024 primary_offset,
1025 noise_offset,
1026 time_grid,
1027 with_uncertainty,
1028 } = req;
1029
1030 if with_uncertainty {
1031 return Err(SurvivalPredictError::UnsupportedConfiguration {
1032 reason: "competing-risks survival prediction does not yet support with_uncertainty"
1033 .to_string(),
1034 });
1035 }
1036
1037 let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
1038 if !matches!(
1039 saved_likelihood_mode,
1040 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
1041 ) {
1042 return Err(SurvivalPredictError::UnsupportedConfiguration {
1043 reason: format!(
1044 "joint cause-specific prediction supports transformation/weibull survival only; got {}",
1045 survival_likelihood_modename(saved_likelihood_mode)
1046 ),
1047 });
1048 }
1049
1050 let fit = fit_result_from_saved_model_for_prediction(model)?;
1051 let cause_count = model
1052 .survival_cause_count
1053 .unwrap_or(fit.blocks.len())
1054 .max(1);
1055 if cause_count <= 1 {
1056 return Err(SurvivalPredictError::MissingFitMetadata {
1057 reason: "competing-risks survival prediction requires a saved model with at least two causes"
1058 .to_string(),
1059 });
1060 }
1061 if fit.blocks.len() != cause_count {
1062 return Err(SurvivalPredictError::IncompatibleSchema {
1063 reason: format!(
1064 "saved competing-risks survival fit has {} coefficient blocks but metadata says {cause_count} causes",
1065 fit.blocks.len()
1066 ),
1067 });
1068 }
1069 let endpoint_names = model.survival_endpoint_names.clone().unwrap_or_else(|| {
1070 (1..=cause_count)
1071 .map(|idx| format!("cause_{idx}"))
1072 .collect()
1073 });
1074 if endpoint_names.len() != cause_count {
1075 return Err(SurvivalPredictError::IncompatibleSchema {
1076 reason: format!(
1077 "saved competing-risks survival endpoint_names has length {}, expected {cause_count}",
1078 endpoint_names.len()
1079 ),
1080 });
1081 }
1082
1083 let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
1087 let exit_col = time_cols.exit_col;
1088
1089 let termspec = resolve_termspec_for_prediction(
1090 &model.resolved_termspec,
1091 training_headers,
1092 col_map,
1093 "resolved_termspec",
1094 )?;
1095 let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
1096 let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
1097 let cov_design = build_term_collection_design(cov_input, &termspec)
1098 .map_err(|e| format!("failed to build competing-risks prediction design: {e}"))?;
1099
1100 let n = data.nrows();
1101 if primary_offset.len() != n || noise_offset.len() != n {
1102 return Err(SurvivalPredictError::InvalidInput {
1103 reason: format!(
1104 "competing-risks prediction offset length mismatch: rows={n}, offset={}, noise_offset={}",
1105 primary_offset.len(),
1106 noise_offset.len()
1107 ),
1108 });
1109 }
1110
1111 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1112 let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
1113 .into_par_iter()
1114 .map(|i| {
1115 normalize_survival_time_pair(time_cols.row_entry_time(data, i), data[[i, exit_col]], i)
1116 })
1117 .collect();
1118 let pairs = pairs?;
1119 let mut age_entry = Array1::<f64>::zeros(n);
1120 let mut age_exit = Array1::<f64>::zeros(n);
1121 for (i, (t0, t1)) in pairs.into_iter().enumerate() {
1122 age_entry[i] = t0;
1123 age_exit[i] = t1;
1124 }
1125
1126 let time_cfg = load_survival_time_basis_config_from_model(model)?;
1127 let time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
1128 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
1129 &time_build.basisname,
1130 time_build.degree,
1131 time_build.knots.as_ref(),
1132 time_build.keep_cols.as_ref(),
1133 time_build.smooth_lambda,
1134 )?;
1135 let weibull_baseline_in_beta = saved_likelihood_mode == SurvivalLikelihoodMode::Weibull
1144 && !model.has_baseline_time_wiggle();
1145 let cr_time_anchor_row: Option<Array1<f64>> = if weibull_baseline_in_beta {
1146 let anchor = model
1147 .survival_time_anchor
1148 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1149 Some(evaluate_survival_time_basis_row(
1150 anchor,
1151 &resolved_time_cfg,
1152 )?)
1153 } else {
1154 None
1155 };
1156 if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull && !model.has_baseline_time_wiggle()
1157 {
1158 require_structural_survival_time_basis(
1159 &time_build.basisname,
1160 "saved competing-risks survival prediction",
1161 )?;
1162 }
1163 let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
1164
1165 let per_row_eval = time_grid.is_none();
1166 let eval_times: Vec<f64> = match time_grid {
1167 Some(grid) => {
1168 if grid.is_empty() {
1169 return Err(SurvivalPredictError::InvalidInput {
1170 reason: "survival time_grid must contain at least one time".to_string(),
1171 });
1172 }
1173 for (idx, &t) in grid.iter().enumerate() {
1174 if !t.is_finite() || t < 0.0 {
1175 return Err(SurvivalPredictError::InvalidInput {
1176 reason: format!(
1177 "survival time_grid requires finite non-negative times (index {idx})",
1178 ),
1179 });
1180 }
1181 }
1182 grid.to_vec()
1183 }
1184 None => Vec::new(),
1185 };
1186 let t_cols = if per_row_eval { 1 } else { eval_times.len() };
1187
1188 const CIF_REFINE_SUBINTERVALS: usize = 32;
1207 let (refined_times, user_time_to_refined_index): (Vec<f64>, Vec<usize>) = if per_row_eval {
1208 (Vec::new(), Vec::new())
1209 } else {
1210 let mut refined: Vec<f64> = Vec::new();
1211 let mut user_index: Vec<usize> = Vec::with_capacity(eval_times.len());
1212 let mut prev = 0.0_f64;
1213 for &t_user in &eval_times {
1214 let gap = t_user - prev;
1219 if gap > 0.0 {
1220 for s in 1..CIF_REFINE_SUBINTERVALS {
1221 let t_mid = prev + gap * (s as f64) / (CIF_REFINE_SUBINTERVALS as f64);
1222 if refined.last().is_none_or(|&last| t_mid > last) {
1224 refined.push(t_mid);
1225 }
1226 }
1227 }
1228 if refined.last().is_none_or(|&last| t_user > last) {
1229 refined.push(t_user);
1230 }
1231 user_index.push(refined.len() - 1);
1232 prev = t_user;
1233 }
1234 (refined, user_index)
1235 };
1236 let refined_cols = refined_times.len();
1237
1238 let saved_timewiggle_by_cause = saved_cause_specific_timewiggles(model, &fit, cause_count)?;
1239 let cov_rows = (0..n)
1240 .map(|i| design_row_owned(&cov_design.design, i, "competing-risks covariate row"))
1241 .collect::<Result<Vec<_>, _>>()?;
1242
1243 let mut hazard = (0..cause_count)
1244 .map(|_| Array2::<f64>::zeros((n, t_cols)))
1245 .collect::<Vec<_>>();
1246 let mut survival = (0..cause_count)
1247 .map(|_| Array2::<f64>::zeros((n, t_cols)))
1248 .collect::<Vec<_>>();
1249 let mut cumulative_hazard = (0..cause_count)
1250 .map(|_| Array2::<f64>::zeros((n, t_cols)))
1251 .collect::<Vec<_>>();
1252 let mut cumulative_hazard_refined = (0..cause_count)
1255 .map(|_| Array2::<f64>::zeros((n, refined_cols)))
1256 .collect::<Vec<_>>();
1257 let mut linear_predictor = (0..cause_count)
1258 .map(|_| Array1::<f64>::zeros(n))
1259 .collect::<Vec<_>>();
1260
1261 struct CauseRow {
1262 cause: usize,
1263 row: usize,
1264 hazard: Vec<f64>,
1265 survival: Vec<f64>,
1266 cumulative: Vec<f64>,
1267 cumulative_refined: Vec<f64>,
1270 eta_exit: f64,
1271 }
1272
1273 let rows: Result<Vec<CauseRow>, SurvivalPredictError> = (0..cause_count * n)
1274 .into_par_iter()
1275 .map(|flat| {
1276 let cause = flat / n;
1277 let i = flat % n;
1278 let block = &fit.blocks[cause];
1279 let timewiggle = saved_timewiggle_by_cause[cause].as_ref();
1280 let evaluate_at = |t_query: f64| -> Result<(f64, f64, f64), SurvivalPredictError> {
1281 let t_entry = age_entry[i].min(t_query);
1282 let single_entry = Array1::from_elem(1, t_entry);
1283 let single_exit = Array1::from_elem(1, t_query);
1284 let mut row_time =
1285 build_survival_time_basis(&single_entry, &single_exit, time_cfg.clone(), None)?;
1286 if let Some(anchor_row) = cr_time_anchor_row.as_ref() {
1287 center_survival_time_designs_at_anchor(
1288 &mut row_time.x_entry_time,
1289 &mut row_time.x_exit_time,
1290 anchor_row,
1291 )?;
1292 }
1293 let (r_eta_exit, r_deriv_exit) = if weibull_baseline_in_beta {
1294 (0.0, 0.0)
1295 } else {
1296 let (_, eta_exit, deriv_exit) = build_survival_time_offsets_for_likelihood(
1297 &single_entry,
1298 &single_exit,
1299 &baseline_cfg,
1300 saved_likelihood_mode,
1301 None,
1302 )?;
1303 (eta_exit[0], deriv_exit[0])
1304 };
1305 evaluate_rp_row_with_beta(
1306 &block.beta,
1307 timewiggle,
1308 &row_time,
1309 &cov_rows[i],
1310 r_eta_exit,
1311 r_deriv_exit,
1312 primary_offset[i],
1313 )
1314 };
1315
1316 let mut out = CauseRow {
1317 cause,
1318 row: i,
1319 hazard: vec![0.0; t_cols],
1320 survival: vec![0.0; t_cols],
1321 cumulative: vec![0.0; t_cols],
1322 cumulative_refined: vec![0.0; refined_cols],
1323 eta_exit: 0.0,
1324 };
1325 if per_row_eval {
1326 let (eta_t, cum_t, haz_t) = evaluate_at(age_exit[i])?;
1327 out.eta_exit = eta_t;
1328 out.hazard[0] = haz_t;
1329 out.cumulative[0] = cum_t;
1330 out.survival[0] = (-cum_t).exp().clamp(0.0, 1.0);
1331 } else {
1332 for (j, &t_query) in eval_times.iter().enumerate() {
1333 if t_query <= 0.0 {
1340 out.hazard[j] = 0.0;
1341 out.cumulative[j] = 0.0;
1342 out.survival[j] = 1.0;
1343 } else {
1344 let (_eta_t, cum_t, haz_t) = evaluate_at(t_query)?;
1345 out.hazard[j] = haz_t;
1346 out.cumulative[j] = cum_t;
1347 out.survival[j] = (-cum_t).exp().clamp(0.0, 1.0);
1348 }
1349 }
1350 for (jr, &t_query) in refined_times.iter().enumerate() {
1356 out.cumulative_refined[jr] = if t_query <= 0.0 {
1357 0.0
1358 } else {
1359 evaluate_at(t_query)?.1
1360 };
1361 }
1362 let (eta_t, _, _) = evaluate_at(age_exit[i])?;
1363 out.eta_exit = eta_t;
1364 }
1365 Ok(out)
1366 })
1367 .collect();
1368
1369 for row in rows? {
1370 linear_predictor[row.cause][row.row] = row.eta_exit;
1371 for j in 0..t_cols {
1372 hazard[row.cause][[row.row, j]] = row.hazard[j];
1373 survival[row.cause][[row.row, j]] = row.survival[j];
1374 cumulative_hazard[row.cause][[row.row, j]] = row.cumulative[j];
1375 }
1376 for jr in 0..refined_cols {
1377 cumulative_hazard_refined[row.cause][[row.row, jr]] = row.cumulative_refined[jr];
1378 }
1379 }
1380
1381 let assembled = if per_row_eval {
1385 let assembly_times = Array1::from_elem(1, 0.0);
1386 assemble_competing_risks_cif_from_endpoints(assembly_times.view(), &cumulative_hazard)
1387 .map_err(|err| err.to_string())?
1388 } else {
1389 let assembly_times = Array1::from_vec(refined_times.clone());
1390 let refined_assembled = assemble_competing_risks_cif_from_endpoints(
1391 assembly_times.view(),
1392 &cumulative_hazard_refined,
1393 )
1394 .map_err(|err| err.to_string())?;
1395 let mut cif_user = (0..cause_count)
1397 .map(|_| Array2::<f64>::zeros((n, t_cols)))
1398 .collect::<Vec<_>>();
1399 let mut overall_user = Array2::<f64>::zeros((n, t_cols));
1400 for (j_user, &jr) in user_time_to_refined_index.iter().enumerate() {
1401 for cause in 0..cause_count {
1402 for row in 0..n {
1403 cif_user[cause][[row, j_user]] = refined_assembled.cif[cause][[row, jr]];
1404 }
1405 }
1406 for row in 0..n {
1407 overall_user[[row, j_user]] = refined_assembled.overall_survival[[row, jr]];
1408 }
1409 }
1410 CompetingRisksCifResult {
1411 cif: cif_user,
1412 overall_survival: overall_user,
1413 }
1414 };
1415 if assembled.cif.len() != cause_count {
1416 return Err(format!(
1417 "competing-risks CIF assembly produced {} endpoint matrices, expected {cause_count}",
1418 assembled.cif.len()
1419 )
1420 .into());
1421 }
1422 let cif = assembled.cif;
1423 let overall_survival = assembled.overall_survival;
1424 let times_out = if per_row_eval {
1425 age_exit.to_vec()
1426 } else {
1427 eval_times
1428 };
1429 Ok(CompetingRisksPredictResult {
1430 times: times_out,
1431 endpoint_names,
1432 hazard,
1433 survival,
1434 cumulative_hazard,
1435 cif,
1436 overall_survival,
1437 linear_predictor,
1438 likelihood_mode: saved_likelihood_mode,
1439 })
1440}
1441
1442fn saved_cause_specific_timewiggles(
1443 model: &SavedModel,
1444 fit: &UnifiedFitResult,
1445 cause_count: usize,
1446) -> Result<Vec<Option<SavedBaselineTimeWiggleRuntime>>, SurvivalPredictError> {
1447 let has_metadata = model.baseline_timewiggle_knots.is_some()
1448 || model.baseline_timewiggle_degree.is_some()
1449 || model.baseline_timewiggle_penalty_orders.is_some()
1450 || model.baseline_timewiggle_double_penalty.is_some()
1451 || model.beta_baseline_timewiggle_by_cause.is_some();
1452 if !has_metadata {
1453 return Ok(vec![None; cause_count]);
1454 }
1455 let knots = model.baseline_timewiggle_knots.clone().ok_or_else(|| {
1456 "joint cause-specific survival missing baseline_timewiggle_knots".to_string()
1457 })?;
1458 let degree = model.baseline_timewiggle_degree.ok_or_else(|| {
1459 "joint cause-specific survival missing baseline_timewiggle_degree".to_string()
1460 })?;
1461 let penalty_orders = model
1462 .baseline_timewiggle_penalty_orders
1463 .clone()
1464 .ok_or_else(|| {
1465 "joint cause-specific survival missing baseline_timewiggle_penalty_orders".to_string()
1466 })?;
1467 let double_penalty = model.baseline_timewiggle_double_penalty.ok_or_else(|| {
1468 "joint cause-specific survival missing baseline_timewiggle_double_penalty".to_string()
1469 })?;
1470 let by_cause = model
1471 .beta_baseline_timewiggle_by_cause
1472 .as_ref()
1473 .ok_or_else(|| {
1474 "joint cause-specific survival missing beta_baseline_timewiggle_by_cause".to_string()
1475 })?;
1476 if by_cause.len() != cause_count {
1477 return Err(SurvivalPredictError::IncompatibleSchema {
1478 reason: format!(
1479 "joint cause-specific survival has {} timewiggle coefficient blocks, expected {cause_count}",
1480 by_cause.len()
1481 ),
1482 });
1483 }
1484 for (cause, (block, beta_w)) in fit.blocks.iter().zip(by_cause).enumerate() {
1485 if beta_w.len() > block.beta.len() {
1486 return Err(SurvivalPredictError::IncompatibleSchema {
1487 reason: format!(
1488 "joint cause-specific survival cause {} timewiggle beta has length {}, but endpoint beta has {} coefficients",
1489 cause + 1,
1490 beta_w.len(),
1491 block.beta.len()
1492 ),
1493 });
1494 }
1495 }
1496 Ok(by_cause
1497 .iter()
1498 .map(|beta| {
1499 Some(SavedBaselineTimeWiggleRuntime {
1500 knots: knots.clone(),
1501 degree,
1502 penalty_orders: penalty_orders.clone(),
1503 double_penalty,
1504 beta: beta.clone(),
1505 })
1506 })
1507 .collect())
1508}
1509
/// Per-batch state for evaluating a saved survival marginal-slope model one
/// row (and one query time) at a time.
///
/// It is built once by `build_marginal_slope_predict_context`. Every
/// per-row array below is indexed by the same row index that
/// `evaluate_marginal_slope_row` receives.
struct MarginalSlopePredictContext {
    /// Saved marginal-slope predictor. It maps the assembled single-row design
    /// to the survival index `eta` and its sensitivity to `q`.
    predictor: BernoulliMarginalSlopePredictor,
    /// Coefficients of saved block 0 (time). The leading entries pair with the
    /// time-basis columns; any trailing entries pair with the baseline
    /// time-wiggle columns.
    beta_time: Array1<f64>,
    /// Coefficients of saved block 1 (marginal covariates).
    beta_marginal: Array1<f64>,
    /// Baseline time-wiggle runtime from the saved prediction runtime, if any.
    saved_timewiggle: Option<SavedBaselineTimeWiggleRuntime>,
    /// Marginal covariate design for all prediction rows.
    cov_design: DesignMatrix,
    /// Log-slope design for all prediction rows.
    logslope_design: DesignMatrix,
    /// Precomputed `cov_design · beta_marginal`, one entry per row.
    cov_eta: Array1<f64>,
    /// Raw values of the saved `z` column, one entry per row.
    z_raw: Array1<f64>,
    /// Offset passed alongside the log-slope design as the noise offset, one
    /// entry per row.
    noise_offset: Array1<f64>,
}
1538
1539fn design_row_owned(
1540 design: &DesignMatrix,
1541 row: usize,
1542 context: &str,
1543) -> Result<Array1<f64>, SurvivalPredictError> {
1544 let chunk = design
1545 .try_row_chunk(row..row + 1)
1546 .map_err(|e| format!("{context}: {e}"))?;
1547 Ok(chunk.row(0).to_owned())
1548}
1549
1550fn build_marginal_slope_predict_context(
1551 model: &SavedModel,
1552 data: ArrayView2<'_, f64>,
1553 col_map: &HashMap<String, usize>,
1554 training_headers: Option<&Vec<String>>,
1555 cov_design: &DesignMatrix,
1556 primary_offset: &Array1<f64>,
1557 noise_offset: &Array1<f64>,
1558 time_build: &SurvivalTimeBuildOutput,
1559 eta_offset_entry: &Array1<f64>,
1560 eta_offset_exit: &Array1<f64>,
1561 derivative_offset_exit: &Array1<f64>,
1562) -> Result<MarginalSlopePredictContext, SurvivalPredictError> {
1563 let z_name = model
1564 .z_column
1565 .as_ref()
1566 .ok_or_else(|| "saved survival marginal-slope model missing z_column".to_string())?;
1567 let z_col = resolve_role_col(col_map, z_name, "z")?;
1568 let z_raw = data.column(z_col).to_owned();
1569
1570 let logslopespec = resolve_termspec_for_prediction(
1571 &model.resolved_termspec_logslope.as_ref().cloned(),
1572 training_headers,
1573 col_map,
1574 "resolved_termspec_logslope",
1575 )?;
1576 let logslope_clipped = model.axis_clip_to_training_ranges(data, col_map);
1577 let logslope_input = logslope_clipped.as_ref().map_or(data, |arr| arr.view());
1578 let logslope_design = build_term_collection_design(logslope_input, &logslopespec)
1579 .map_err(|e| format!("failed to build survival marginal-slope logslope design: {e}"))?;
1580
1581 let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1582 let (predictor, _pred_input, _predictor_fit) = build_saved_survival_marginal_slope_predictor(
1583 model,
1584 &fit_saved,
1585 z_name,
1586 &z_raw,
1587 cov_design,
1588 &logslope_design.design,
1589 time_build,
1590 eta_offset_entry,
1591 eta_offset_exit,
1592 derivative_offset_exit,
1593 primary_offset,
1594 noise_offset,
1595 )?;
1596
1597 let blocks = &fit_saved.blocks;
1598 if blocks.len() < 3 {
1599 return Err(SurvivalPredictError::IncompatibleSchema {
1600 reason: format!(
1601 "saved survival marginal-slope model requires at least 3 blocks [time, marginal, slope], got {}",
1602 blocks.len()
1603 ),
1604 });
1605 }
1606 let beta_time = blocks[0].beta.clone();
1607 let beta_marginal = blocks[1].beta.clone();
1608 let saved_runtime = model.saved_prediction_runtime()?;
1609 let saved_timewiggle = saved_runtime.baseline_time_wiggle.clone();
1610
1611 let cov_eta = cov_design.dot(&beta_marginal);
1614
1615 Ok(MarginalSlopePredictContext {
1616 predictor,
1617 beta_time,
1618 beta_marginal,
1619 saved_timewiggle,
1620 cov_design: cov_design.clone(),
1621 logslope_design: logslope_design.design.clone(),
1622 cov_eta,
1623 z_raw,
1624 noise_offset: noise_offset.clone(),
1625 })
1626}
1627
1628fn evaluate_marginal_slope_row(
1639 row_index: usize,
1640 ctx: &MarginalSlopePredictContext,
1641 row_time: &SurvivalTimeBuildOutput,
1642 r_eta_exit: &Array1<f64>,
1643 r_deriv_exit: &Array1<f64>,
1644 primary_offset_row: f64,
1645) -> Result<(f64, f64, f64), SurvivalPredictError> {
1646 let beta_time = &ctx.beta_time;
1647 let p_time_base = row_time.x_exit_time.ncols();
1648 let p_timewiggle = ctx
1649 .saved_timewiggle
1650 .as_ref()
1651 .map_or(0, |runtime| runtime.beta.len());
1652 if beta_time.len() != p_time_base + p_timewiggle {
1653 return Err(SurvivalPredictError::IncompatibleSchema {
1654 reason: format!(
1655 "saved survival marginal-slope time coefficient mismatch: beta has {} entries but expected base={} plus timewiggle={}",
1656 beta_time.len(),
1657 p_time_base,
1658 p_timewiggle
1659 ),
1660 });
1661 }
1662 let beta_time_base = beta_time.slice(s![..p_time_base]).to_owned();
1663
1664 let q_exit_base = row_time.x_exit_time.dot(&beta_time_base)[0]
1669 + ctx.cov_eta[row_index]
1670 + r_eta_exit[0]
1671 + primary_offset_row;
1672 let qd_exit_base = row_time.x_derivative_time.dot(&beta_time_base)[0] + r_deriv_exit[0];
1673
1674 let (qd_with_wiggle, exit_wiggle_design) = if let Some(runtime) = ctx.saved_timewiggle.as_ref()
1678 {
1679 let knots = Array1::from_vec(runtime.knots.clone());
1680 let beta_w = beta_time.slice(s![p_time_base..]).to_owned();
1681 let eta_exit_row = Array1::from_elem(1, q_exit_base);
1682 let deriv_row = Array1::from_elem(1, qd_exit_base);
1683 let exit_design = match buildwiggle_block_input_from_knots(
1684 eta_exit_row.view(),
1685 &knots,
1686 runtime.degree,
1687 2,
1688 false,
1689 )?
1690 .design
1691 {
1692 DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
1693 _ => {
1694 return Err(SurvivalPredictError::IncompatibleSchema {
1695 reason: "saved baseline-timewiggle exit design must be dense".to_string(),
1696 });
1697 }
1698 };
1699 let derivative_design = build_survival_timewiggle_derivative_design(
1700 &eta_exit_row,
1701 &deriv_row,
1702 &knots,
1703 runtime.degree,
1704 )?;
1705 (
1706 qd_exit_base + derivative_design.dot(&beta_w)[0],
1707 Some(exit_design),
1708 )
1709 } else {
1710 (qd_exit_base, None)
1711 };
1712
1713 let cov_dim = ctx.beta_marginal.len();
1722 let q_design_ncols = p_time_base + p_timewiggle + cov_dim;
1723 let mut q_design_full = Array2::<f64>::zeros((1, q_design_ncols));
1724 q_design_full
1725 .slice_mut(s![.., ..p_time_base])
1726 .assign(&row_time.x_exit_time.to_dense());
1727 if let Some(exit_w) = exit_wiggle_design.as_ref() {
1728 q_design_full
1729 .slice_mut(s![.., p_time_base..p_time_base + p_timewiggle])
1730 .assign(exit_w);
1731 }
1732 if cov_dim > 0 {
1733 let cov_row = design_row_owned(
1734 &ctx.cov_design,
1735 row_index,
1736 "survival marginal covariate row",
1737 )?;
1738 q_design_full
1739 .slice_mut(s![.., p_time_base + p_timewiggle..])
1740 .row_mut(0)
1741 .assign(&cov_row);
1742 }
1743
1744 let logslope_row = design_row_owned(
1751 &ctx.logslope_design,
1752 row_index,
1753 "survival marginal logslope row",
1754 )?;
1755 let mut logslope_design_2d = Array2::<f64>::zeros((1, logslope_row.len()));
1756 logslope_design_2d.row_mut(0).assign(&logslope_row);
1757
1758 let pred_input = PredictInput {
1759 design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(q_design_full)),
1760 offset: Array1::from_elem(1, r_eta_exit[0] + primary_offset_row),
1761 design_noise: Some(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1762 logslope_design_2d,
1763 ))),
1764 offset_noise: Some(Array1::from_elem(1, ctx.noise_offset[row_index])),
1765 auxiliary_scalar: Some(Array1::from_elem(1, ctx.z_raw[row_index])),
1766 auxiliary_matrix: None,
1767 };
1768
1769 let (eta_arr, deta_dq_arr) = ctx
1773 .predictor
1774 .predict_eta_and_q_chain(&pred_input)
1775 .map_err(|e| format!("saved survival marginal-slope predictor eta failed: {e}"))?;
1776 let eta = eta_arr[0];
1777 let eta_derivative = marginal_slope_index_derivative_at_horizon(deta_dq_arr[0], qd_with_wiggle);
1797 let (cum, haz) = probit_survival_hazard_components(eta, eta_derivative)?;
1798 Ok((eta, cum, haz))
1799}
1800
/// Chains the predictor's `d eta / d q` sensitivity with the time derivative
/// of `q` (including any baseline time-wiggle contribution).
///
/// A finite product is floored at zero. A non-finite product (NaN or ±inf) is
/// returned unchanged so the downstream validity check can report it;
/// `f64::max` would otherwise map NaN to `0.0`.
#[inline]
fn marginal_slope_index_derivative_at_horizon(deta_dq: f64, qd_with_wiggle: f64) -> f64 {
    let chained = deta_dq * qd_with_wiggle;
    if !chained.is_finite() {
        return chained;
    }
    chained.max(0.0)
}
1822
1823#[inline]
1824fn probit_survival_hazard_components(
1825 eta: f64,
1826 eta_derivative: f64,
1827) -> Result<(f64, f64), SurvivalPredictError> {
1828 if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative >= 0.0) {
1829 return Err(SurvivalPredictError::NumericalFailure {
1830 reason: format!(
1831 "saved survival marginal-slope prediction produced invalid survival index derivative: eta={eta}, eta_t={eta_derivative}"
1832 ),
1833 });
1834 }
1835
1836 let (log_survival, mills_ratio) = signed_probit_logcdf_and_mills_ratio(-eta);
1841 let cumulative_hazard = -log_survival;
1842 let hazard = if eta_derivative == 0.0 {
1843 0.0
1844 } else {
1845 mills_ratio * eta_derivative
1846 };
1847 if !(cumulative_hazard >= 0.0 && hazard >= 0.0) {
1854 return Err(SurvivalPredictError::NumericalFailure {
1855 reason: format!(
1856 "saved survival marginal-slope prediction produced invalid survival components: eta={eta}, eta_t={eta_derivative}, log_survival={log_survival}, hazard={hazard}"
1857 ),
1858 });
1859 }
1860 Ok((cumulative_hazard, hazard))
1861}
1862
1863fn evaluate_rp_row(
1864 model: &SavedModel,
1865 row_time: &SurvivalTimeBuildOutput,
1866 cov_row: &Array1<f64>,
1867 eta_time_offset_row: f64,
1868 derivative_time_offset_row: f64,
1869 primary_offset_row: f64,
1870) -> Result<(f64, f64, f64), SurvivalPredictError> {
1871 let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1872 let saved_runtime = model.saved_prediction_runtime()?;
1873 evaluate_rp_row_with_beta(
1874 &fit_saved.beta,
1875 saved_runtime.baseline_time_wiggle.as_ref(),
1876 row_time,
1877 cov_row,
1878 eta_time_offset_row,
1879 derivative_time_offset_row,
1880 primary_offset_row,
1881 )
1882}
1883
1884fn evaluate_rp_row_with_beta(
1885 beta: &Array1<f64>,
1886 saved_timewiggle: Option<&SavedBaselineTimeWiggleRuntime>,
1887 row_time: &SurvivalTimeBuildOutput,
1888 cov_row: &Array1<f64>,
1889 eta_time_offset_row: f64,
1890 derivative_time_offset_row: f64,
1891 primary_offset_row: f64,
1892) -> Result<(f64, f64, f64), SurvivalPredictError> {
1893 let p_time = row_time.x_exit_time.ncols();
1894 let p_timewiggle = saved_timewiggle.map_or(0, |runtime| runtime.beta.len());
1895 let p_cov = cov_row.len();
1896 let p = p_time + p_timewiggle + p_cov;
1897 if beta.len() != p {
1898 return Err(SurvivalPredictError::IncompatibleSchema {
1899 reason: format!(
1900 "survival RP coefficient mismatch: beta has {} entries but design has {} columns",
1901 beta.len(),
1902 p
1903 ),
1904 });
1905 }
1906 let mut x_exit = Array2::<f64>::zeros((1, p));
1907 if p_time > 0 {
1908 x_exit
1909 .slice_mut(s![.., ..p_time])
1910 .assign(&row_time.x_exit_time.to_dense());
1911 }
1912 let mut eta_derivative = derivative_time_offset_row;
1913 if p_time > 0 {
1914 eta_derivative += row_time
1915 .x_derivative_time
1916 .dot(&beta.slice(s![..p_time]).to_owned())[0];
1917 }
1918 if let Some(runtime) = saved_timewiggle {
1919 let knots = Array1::from_vec(runtime.knots.clone());
1920 let beta_w = beta.slice(s![p_time..p_time + p_timewiggle]).to_owned();
1921 let eta_exit_row = Array1::from_elem(1, eta_time_offset_row);
1922 let derivative_exit_row = Array1::from_elem(1, derivative_time_offset_row);
1923 let exit_design = match buildwiggle_block_input_from_knots(
1924 eta_exit_row.view(),
1925 &knots,
1926 runtime.degree,
1927 2,
1928 false,
1929 )?
1930 .design
1931 {
1932 DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
1933 _ => {
1934 return Err(SurvivalPredictError::IncompatibleSchema {
1935 reason: "saved baseline-timewiggle exit design must be dense".to_string(),
1936 });
1937 }
1938 };
1939 if exit_design.ncols() != p_timewiggle {
1940 return Err(SurvivalPredictError::IncompatibleSchema {
1941 reason: format!(
1942 "survival RP timewiggle design mismatch: rebuilt {} columns but runtime expects {}",
1943 exit_design.ncols(),
1944 p_timewiggle
1945 ),
1946 });
1947 }
1948 x_exit
1949 .slice_mut(s![.., p_time..p_time + p_timewiggle])
1950 .assign(&exit_design);
1951 let derivative_design = build_survival_timewiggle_derivative_design(
1952 &eta_exit_row,
1953 &derivative_exit_row,
1954 &knots,
1955 runtime.degree,
1956 )?;
1957 eta_derivative += derivative_design.dot(&beta_w)[0];
1958 }
1959 if p_cov > 0 {
1960 x_exit
1961 .slice_mut(s![
1962 ..,
1963 (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov)
1964 ])
1965 .row_mut(0)
1966 .assign(cov_row);
1967 }
1968 let offset_view = Array1::from_elem(1, eta_time_offset_row + primary_offset_row);
1969 let likelihood = LikelihoodSpec::new(
1970 ResponseFamily::RoystonParmar,
1971 InverseLink::Standard(StandardLink::Identity),
1972 );
1973 let eta =
1974 predict_royston_parmar_eta(x_exit.view(), beta.view(), offset_view.view(), &likelihood)?[0];
1975 let (cum, haz) = royston_parmar_survival_hazard_components(eta, eta_derivative)?;
1976 Ok((eta, cum, haz))
1977}
1978
1979fn predict_royston_parmar_eta<X>(
1980 x: X,
1981 beta: ndarray::ArrayView1<'_, f64>,
1982 offset: ndarray::ArrayView1<'_, f64>,
1983 likelihood: &LikelihoodSpec,
1984) -> Result<Array1<f64>, SurvivalPredictError>
1985where
1986 X: Into<DesignMatrix>,
1987{
1988 if !matches!(likelihood.response, ResponseFamily::RoystonParmar)
1989 || !matches!(
1990 likelihood.link,
1991 InverseLink::Standard(StandardLink::Identity)
1992 )
1993 {
1994 return Err(SurvivalPredictError::UnsupportedConfiguration {
1995 reason: "survival prediction requires RoystonParmar with identity link".to_string(),
1996 });
1997 }
1998 let x = x.into();
1999 if x.nrows() != offset.len() || x.ncols() != beta.len() {
2000 return Err(SurvivalPredictError::IncompatibleSchema {
2001 reason: format!(
2002 "survival prediction design dimensions disagree: design is {}x{}, beta has length {}, offset has length {}",
2003 x.nrows(),
2004 x.ncols(),
2005 beta.len(),
2006 offset.len()
2007 ),
2008 });
2009 }
2010 let mut eta = x.matrixvectormultiply(&beta.to_owned());
2011 eta += &offset;
2012 Ok(eta)
2013}
2014
2015#[inline]
2016fn royston_parmar_survival_hazard_components(
2017 eta: f64,
2018 eta_derivative: f64,
2019) -> Result<(f64, f64), SurvivalPredictError> {
2020 if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative >= 0.0) {
2036 return Err(SurvivalPredictError::NumericalFailure {
2037 reason: format!(
2038 "saved Royston-Parmar survival prediction produced invalid log-cumulative-hazard derivative: eta={eta}, eta_t={eta_derivative}"
2039 ),
2040 });
2041 }
2042 let cumulative_hazard = eta.exp();
2043 let hazard = if eta_derivative == 0.0 {
2049 0.0
2050 } else {
2051 cumulative_hazard * eta_derivative
2052 };
2053 if !(cumulative_hazard >= 0.0 && hazard >= 0.0) {
2062 return Err(SurvivalPredictError::NumericalFailure {
2063 reason: format!(
2064 "saved Royston-Parmar survival prediction produced invalid survival components: eta={eta}, eta_t={eta_derivative}, cumulative_hazard={cumulative_hazard}, hazard={hazard}"
2065 ),
2066 });
2067 }
2068 Ok((cumulative_hazard, hazard))
2069}
2070
2071fn predict_survival_location_scale_batch(
2081 model: &SavedModel,
2082 age_entry: &Array1<f64>,
2083 age_exit: &Array1<f64>,
2084 cov_design: &gam_terms::smooth::TermCollectionDesign,
2085 primary_offset: &Array1<f64>,
2086 noise_offset: &Array1<f64>,
2087 training_headers: Option<&Vec<String>>,
2088 col_map: &HashMap<String, usize>,
2089 data: ArrayView2<'_, f64>,
2090 time_grid: Option<&[f64]>,
2091 with_uncertainty: bool,
2092) -> Result<SurvivalPredictResult, String> {
2093 use crate::scale_design::build_scale_deviation_operator;
2094 use crate::survival::construction::evaluate_survival_time_basis_row;
2095 use crate::survival::location_scale::{
2096 SurvivalLocationScalePredictInput, predict_survival_location_scale,
2097 predict_survival_location_scale_from_linear_components,
2098 predict_survival_location_scalewith_uncertainty,
2099 };
2100 use gam_linalg::matrix::DesignMatrix;
2101
2102 let n = age_entry.len();
2103 let per_row_eval = time_grid.is_none();
2104 let eval_times: Vec<f64> = match time_grid {
2105 Some(grid) => {
2106 if grid.is_empty() {
2107 return Err("survival time_grid must contain at least one time".to_string());
2108 }
2109 for (idx, &t) in grid.iter().enumerate() {
2110 if !t.is_finite() || t < 0.0 {
2111 return Err(format!(
2112 "survival time_grid requires finite non-negative times (index {idx})",
2113 ));
2114 }
2115 }
2116 grid.to_vec()
2117 }
2118 None => Vec::new(),
2119 };
2120 let t_cols = if per_row_eval { 1 } else { eval_times.len() };
2121 let eval_width = if per_row_eval { 1 } else { t_cols + 1 };
2122 let saved_likelihood_mode = SurvivalLikelihoodMode::LocationScale;
2123 let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
2124 let saved_fit = saved_survival_location_scale_fit_result(model)?;
2125 let reduced_parametric_aft =
2144 !model.has_baseline_time_wiggle() && saved_fit.beta_time().iter().all(|&b| b == 0.0);
2145 let time_cfg = load_survival_time_basis_config_from_model(model)?;
2146 let mut time_build = build_survival_time_basis(age_entry, age_exit, time_cfg.clone(), None)?;
2147 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
2148 &time_build.basisname,
2149 time_build.degree,
2150 time_build.knots.as_ref(),
2151 time_build.keep_cols.as_ref(),
2152 time_build.smooth_lambda,
2153 )?;
2154 let time_anchor = model
2155 .survival_time_anchor
2156 .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
2157 let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
2158 center_survival_time_designs_at_anchor(
2159 &mut time_build.x_entry_time,
2160 &mut time_build.x_exit_time,
2161 &time_anchor_row,
2162 )?;
2163 if !model.has_baseline_time_wiggle() && !reduced_parametric_aft {
2167 require_structural_survival_time_basis(&time_build.basisname, "saved survival sampling")?;
2168 }
2169 let saved_inverse_link = resolve_survival_inverse_link_from_saved(model)?;
2170 let (eval_entry, eval_exit) = if per_row_eval {
2171 (age_entry.clone(), age_exit.clone())
2172 } else {
2173 let total = n * eval_width;
2174 let mut entry = Array1::<f64>::zeros(total);
2175 let mut exit = Array1::<f64>::zeros(total);
2176 {
2177 use rayon::iter::{IntoParallelIterator, ParallelIterator};
2178 let pairs: Vec<(f64, f64)> = (0..total)
2179 .into_par_iter()
2180 .map(|k| {
2181 let i = k / eval_width;
2182 let col = k % eval_width;
2183 let t = if col < t_cols {
2184 eval_times[col]
2185 } else {
2186 age_exit[i]
2187 };
2188 (age_entry[i].min(t), t)
2189 })
2190 .collect();
2191 for (k, (t0, t1)) in pairs.into_iter().enumerate() {
2192 entry[k] = t0;
2193 exit[k] = t1;
2194 }
2195 }
2196 (entry, exit)
2197 };
2198 let mut time_build =
2199 build_survival_time_basis(&eval_entry, &eval_exit, time_cfg.clone(), None)?;
2200 center_survival_time_designs_at_anchor(
2201 &mut time_build.x_entry_time,
2202 &mut time_build.x_exit_time,
2203 &time_anchor_row,
2204 )?;
2205 let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
2206 build_survival_time_offsets_for_likelihood(
2207 &eval_entry,
2208 &eval_exit,
2209 &baseline_cfg,
2210 saved_likelihood_mode,
2211 Some(&saved_inverse_link),
2212 )?;
2213 add_survival_time_derivative_guard_offset(
2214 &eval_entry,
2215 &eval_exit,
2216 time_anchor,
2217 survival_derivative_guard_for_likelihood(saved_likelihood_mode),
2218 &mut eta_offset_entry,
2219 &mut eta_offset_exit,
2220 &mut derivative_offset_exit,
2221 )?;
2222 if reduced_parametric_aft {
2223 eta_offset_exit = Array1::<f64>::zeros(eval_exit.len());
2235 }
2236
2237 let saved_timewiggle_runtime = model.saved_baseline_time_wiggle()?;
2238
2239 let threshold_design = cov_design;
2245 let log_sigmaspec = resolve_termspec_for_prediction(
2246 &model.resolved_termspec_noise,
2247 training_headers,
2248 col_map,
2249 "resolved_termspec_noise",
2250 )?;
2251 let sigma_clipped = model.axis_clip_to_training_ranges(data, col_map);
2252 let sigma_input = sigma_clipped.as_ref().map_or(data, |arr| arr.view());
2253 let raw_sigma_design =
2254 gam_terms::smooth::build_term_collection_design(sigma_input, &log_sigmaspec)
2255 .map_err(|err| format!("failed to build survival log-sigma design: {err}"))?;
2256 let survival_noise_transform = scale_transform_from_payload(
2257 &model.survival_noise_projection,
2258 &model.survival_noise_center,
2259 &model.survival_noise_scale,
2260 model.survival_noise_non_intercept_start,
2261 model.survival_noise_projection_ridge_alpha,
2262 )?;
2263
2264 let x_time_exit_dense = time_build
2265 .x_exit_time
2266 .try_to_dense_by_chunks("survival location-scale prediction time-exit design")?;
2267 let total_rows = eval_exit.len();
2268 let x_time_exit = if let Some(runtime) = saved_timewiggle_runtime.as_ref() {
2269 let mut full =
2270 Array2::<f64>::zeros((total_rows, x_time_exit_dense.ncols() + runtime.beta.len()));
2271 full.slice_mut(s![.., 0..x_time_exit_dense.ncols()])
2272 .assign(&x_time_exit_dense);
2273 full
2274 } else {
2275 x_time_exit_dense
2276 };
2277
2278 let repeat_rows =
2279 |matrix: &DesignMatrix, label: &str| -> Result<DesignMatrix, SurvivalPredictError> {
2280 if per_row_eval {
2281 return Ok(matrix.clone());
2282 }
2283 let dense = matrix.try_to_dense_by_chunks(label)?;
2284 let mut repeated = Array2::<f64>::zeros((total_rows, dense.ncols()));
2285 use rayon::iter::{IntoParallelIterator, ParallelIterator};
2286 let rows: Vec<Vec<f64>> = (0..total_rows)
2287 .into_par_iter()
2288 .map(|k| dense.row(k / eval_width).to_vec())
2289 .collect();
2290 for (k, row) in rows.into_iter().enumerate() {
2291 for (j, value) in row.into_iter().enumerate() {
2292 repeated[[k, j]] = value;
2293 }
2294 }
2295 Ok(DesignMatrix::from(repeated))
2296 };
2297 let threshold_matrix = repeat_rows(
2298 &threshold_design.design,
2299 "survival location-scale prediction threshold design",
2300 )?;
2301 let raw_sigma_matrix = repeat_rows(
2302 &raw_sigma_design.design,
2303 "survival location-scale prediction log-sigma design",
2304 )?;
2305
2306 let time_design = DesignMatrix::from(x_time_exit.clone());
2311 let survival_primary_design =
2312 DesignMatrix::hstack(vec![time_design, threshold_matrix.clone()])?;
2313 let prepared_sigma_design = if let Some(transform) = survival_noise_transform.as_ref() {
2314 build_scale_deviation_operator(survival_primary_design, raw_sigma_matrix, transform)?
2315 } else {
2316 raw_sigma_matrix
2317 };
2318 let link_wiggle_knots = model
2319 .linkwiggle_knots
2320 .as_ref()
2321 .map(|k| Array1::from_vec(k.clone()));
2322 let link_wiggle_degree = model.linkwiggle_degree;
2323 let time_wiggle_knots = saved_timewiggle_runtime
2324 .as_ref()
2325 .map(|w| Array1::from_vec(w.knots.clone()));
2326 let time_wiggle_degree = saved_timewiggle_runtime.as_ref().map(|w| w.degree);
2327 let time_wiggle_ncols = saved_timewiggle_runtime
2328 .as_ref()
2329 .map_or(0, |w| w.beta.len());
2330
2331 let expand_vector = |values: &Array1<f64>| -> Array1<f64> {
2332 if per_row_eval {
2333 values.clone()
2334 } else {
2335 Array1::from_shape_fn(total_rows, |k| values[k / eval_width])
2336 }
2337 };
2338 let eta_threshold_offset = {
2347 let mut offset = expand_vector(primary_offset);
2348 if reduced_parametric_aft {
2349 for (slot, &t) in offset.iter_mut().zip(eval_exit.iter()) {
2350 *slot -= t
2351 .max(crate::survival::construction::SURVIVAL_TIME_FLOOR)
2352 .ln();
2353 }
2354 }
2355 offset
2356 };
2357 let pred_input = SurvivalLocationScalePredictInput {
2362 x_time_exit,
2363 eta_time_offset_exit: eta_offset_exit.clone(),
2364 time_wiggle_knots: time_wiggle_knots.clone(),
2365 time_wiggle_degree,
2366 time_wiggle_ncols,
2367 x_threshold: threshold_matrix,
2368 eta_threshold_offset,
2369 x_log_sigma: prepared_sigma_design,
2370 eta_log_sigma_offset: expand_vector(noise_offset),
2371 x_link_wiggle: None,
2372 link_wiggle_knots: link_wiggle_knots.clone(),
2373 link_wiggle_degree,
2374 inverse_link: saved_inverse_link.clone(),
2375 };
2376
2377 let (eta_full, survival_prob_full, response_se_full, eta_se_full): (
2380 Array1<f64>,
2381 Array1<f64>,
2382 Option<Array1<f64>>,
2383 Option<Array1<f64>>,
2384 ) = if with_uncertainty {
2385 let cov = saved_fit.beta_covariance().ok_or_else(|| {
2386 "survival location-scale uncertainty: saved fit is missing the \
2387 posterior covariance; refit with the current CLI / library to \
2388 populate beta_covariance"
2389 .to_string()
2390 })?;
2391 let unc = predict_survival_location_scalewith_uncertainty(
2392 &pred_input,
2393 &saved_fit,
2394 cov,
2395 false,
2396 true,
2397 )
2398 .map_err(|err| format!("survival location-scale uncertainty predict failed: {err}"))?;
2399 let response_se = unc.response_standard_error.ok_or_else(|| {
2400 "survival location-scale uncertainty: response_standard_error \
2401 missing despite include_response_sd=true"
2402 .to_string()
2403 })?;
2404 (
2405 unc.eta,
2406 unc.survival_prob,
2407 Some(response_se),
2408 Some(unc.eta_standard_error),
2409 )
2410 } else if per_row_eval {
2411 let pred = predict_survival_location_scale(&pred_input, &saved_fit)
2412 .map_err(|err| format!("survival location-scale predict failed: {err}"))?;
2413 (pred.eta, pred.survival_prob, None, None)
2414 } else {
2415 let beta_threshold = saved_fit.beta_threshold();
2416 let beta_log_sigma = saved_fit.beta_log_sigma();
2417 let eta_t_subject =
2418 cov_design.design.matrixvectormultiply(&beta_threshold) + primary_offset;
2419 let eta_ls_subject = prepared_sigma_design_view(&pred_input)
2423 .matrixvectormultiply(&beta_log_sigma)
2424 + &pred_input.eta_log_sigma_offset;
2425 let mut eta_t = expand_vector(&eta_t_subject);
2431 if reduced_parametric_aft {
2432 for (slot, &t) in eta_t.iter_mut().zip(eval_exit.iter()) {
2433 *slot -= t
2434 .max(crate::survival::construction::SURVIVAL_TIME_FLOOR)
2435 .ln();
2436 }
2437 }
2438 let pred = predict_survival_location_scale_from_linear_components(
2439 &pred_input.x_time_exit,
2440 &eta_offset_exit,
2441 time_wiggle_knots.as_ref(),
2442 time_wiggle_degree,
2443 time_wiggle_ncols,
2444 &eta_t,
2445 &eta_ls_subject,
2446 link_wiggle_knots.as_ref(),
2447 link_wiggle_degree,
2448 &saved_inverse_link,
2449 &saved_fit,
2450 )
2451 .map_err(|err| format!("survival location-scale predict failed: {err}"))?;
2452 (pred.eta, pred.survival_prob, None, None)
2453 };
2454
2455 let eta_derivative_full = if reduced_parametric_aft {
2456 use crate::sigma_link::exp_sigma_inverse_from_eta_scalar;
2463 let beta_log_sigma = saved_fit.beta_log_sigma();
2464 let eta_ls = prepared_sigma_design_view(&pred_input).matrixvectormultiply(&beta_log_sigma)
2465 + &pred_input.eta_log_sigma_offset;
2466 let mut deriv = Array1::<f64>::zeros(eval_exit.len());
2467 for (k, slot) in deriv.iter_mut().enumerate() {
2468 let inv_sigma = exp_sigma_inverse_from_eta_scalar(eta_ls[k]);
2469 let t = eval_exit[k].max(crate::survival::construction::SURVIVAL_TIME_FLOOR);
2470 *slot = inv_sigma / t;
2471 }
2472 deriv
2473 } else {
2474 let x_time_derivative = time_build
2475 .x_derivative_time
2476 .try_to_dense_by_chunks("survival location-scale prediction time-derivative design")?;
2477 location_scale_eta_derivative_components(
2478 &x_time_derivative,
2479 &derivative_offset_exit,
2480 &pred_input.x_time_exit,
2481 &pred_input.eta_time_offset_exit,
2482 time_wiggle_knots.as_ref(),
2483 time_wiggle_degree,
2484 time_wiggle_ncols,
2485 &saved_fit,
2486 )?
2487 };
2488 let hazard_full = location_scale_hazard_from_eta_derivative(
2489 &eta_full,
2490 &eta_derivative_full,
2491 &saved_inverse_link,
2492 )?;
2493
2494 let mut survival = Array2::<f64>::zeros((n, t_cols));
2495 let mut cumulative_hazard = Array2::<f64>::zeros((n, t_cols));
2496 let mut hazard = Array2::<f64>::zeros((n, t_cols));
2497 ndarray::Zip::indexed(&mut survival)
2498 .and(&mut cumulative_hazard)
2499 .and(&mut hazard)
2500 .par_for_each(|(i, j), s, ch, h| {
2501 let query_time = if per_row_eval {
2510 age_exit[i]
2511 } else {
2512 eval_times[j]
2513 };
2514 if query_time <= 0.0 {
2515 *s = 1.0;
2516 *ch = 0.0;
2517 *h = 0.0;
2518 return;
2519 }
2520 let k = if per_row_eval { i } else { i * eval_width + j };
2521 let surv = survival_prob_full[k].clamp(SURVIVAL_PROB_MIN_FOR_LOG, 1.0);
2522 *s = surv;
2523 *ch = -surv.ln();
2524 *h = hazard_full[k];
2525 });
2526
2527 let linear_predictor = if per_row_eval {
2528 eta_full.clone()
2529 } else {
2530 Array1::from_shape_fn(n, |i| eta_full[i * eval_width + t_cols])
2531 };
2532 let times = if per_row_eval {
2533 age_exit.to_vec()
2534 } else {
2535 eval_times.clone()
2538 };
2539
2540 let survival_se = response_se_full.as_ref().map(|response_se| {
2541 let mut out = Array2::<f64>::zeros((n, t_cols));
2542 ndarray::Zip::indexed(&mut out).par_for_each(|(i, j), slot| {
2543 let query_time = if per_row_eval {
2546 age_exit[i]
2547 } else {
2548 eval_times[j]
2549 };
2550 if query_time <= 0.0 {
2551 *slot = 0.0;
2552 return;
2553 }
2554 let k = if per_row_eval { i } else { i * eval_width + j };
2555 *slot = response_se[k].max(0.0);
2556 });
2557 out
2558 });
2559 let eta_se_per_row = eta_se_full.as_ref().map(|eta_se| {
2560 if per_row_eval {
2561 eta_se.clone()
2562 } else {
2563 Array1::from_shape_fn(n, |i| eta_se[i * eval_width + t_cols])
2564 }
2565 });
2566
2567 Ok(SurvivalPredictResult {
2568 times,
2569 hazard,
2570 survival,
2571 cumulative_hazard,
2572 linear_predictor,
2573 likelihood_mode: saved_likelihood_mode,
2574 survival_se,
2575 eta_se: eta_se_per_row,
2576 })
2577}
2578
2579fn prepared_sigma_design_view(
2583 input: &crate::survival::location_scale::SurvivalLocationScalePredictInput,
2584) -> &gam_linalg::matrix::DesignMatrix {
2585 &input.x_log_sigma
2586}
2587
/// Per-row linear-predictor pieces of the location-scale survival model at the
/// exit rows, as assembled by `location_scale_eta_components`.
pub(crate) struct LocationScaleEtaComponents {
    /// Time index `h` per exit row (includes the time-wiggle warp when active).
    pub h: Array1<f64>,
    /// Jacobian of `h` with respect to the time coefficients (`n × p_time`).
    pub time_jac: Array2<f64>,
    /// Threshold linear predictor: `x_threshold · beta_threshold + offset`.
    pub eta_t: Array1<f64>,
    /// Log-sigma linear predictor: `x_log_sigma · beta_log_sigma + offset`.
    pub eta_ls: Array1<f64>,
    /// `exp_sigma_inverse_from_eta_scalar` applied elementwise to `eta_ls`
    /// (presumably `1/sigma`; confirm against `crate::sigma_link`).
    pub inv_sigma: Array1<f64>,
}
2595
/// Time-warp evaluation produced by `location_scale_time_warp_components`.
pub(crate) struct LocationScaleTimeWarpComponents {
    /// Warped time index `h` per exit row.
    pub(crate) h: Array1<f64>,
    /// Derivative of `h` with respect to the time coefficients, one row per exit row.
    pub(crate) time_jac: Array2<f64>,
    /// Row-wise chain factor `1 + B'(h_base) · beta_wiggle` when the time block
    /// carries wiggle columns; `None` when it does not.
    pub(crate) time_wiggle_dq: Option<Array1<f64>>,
}
2601
2602pub(crate) fn location_scale_time_warp_components(
2603 x_time_exit: &Array2<f64>,
2604 eta_time_offset_exit: &Array1<f64>,
2605 time_wiggle_knots: Option<&Array1<f64>>,
2606 time_wiggle_degree: Option<usize>,
2607 time_wiggle_ncols: usize,
2608 fit: &UnifiedFitResult,
2609) -> Result<LocationScaleTimeWarpComponents, String> {
2610 let n = x_time_exit.nrows();
2611 if eta_time_offset_exit.len() != n {
2612 return Err("survival location-scale time-warp row mismatch across inputs".to_string());
2613 }
2614 let beta_time = fit.beta_time();
2615 if x_time_exit.ncols() != beta_time.len() {
2616 return Err(format!(
2617 "survival location-scale time-warp design mismatch: x_exit={} beta_time={}",
2618 x_time_exit.ncols(),
2619 beta_time.len()
2620 ));
2621 }
2622
2623 let p_time_total = beta_time.len();
2624 let p_wiggle = time_wiggle_ncols.min(p_time_total);
2625 let p_base = p_time_total - p_wiggle;
2626 let beta_base = beta_time.slice(s![..p_base]).to_owned();
2627 let h_base = if p_base > 0 {
2628 x_time_exit.slice(s![.., ..p_base]).dot(&beta_base) + eta_time_offset_exit
2629 } else {
2630 eta_time_offset_exit.clone()
2631 };
2632 let mut h = h_base.clone();
2633 let mut time_jac = x_time_exit.clone();
2634 let mut time_wiggle_dq = None;
2635 if p_wiggle > 0 {
2636 if x_time_exit
2637 .slice(s![.., p_base..p_time_total])
2638 .iter()
2639 .any(|&value| value != 0.0)
2640 {
2641 return Err(
2642 "survival location-scale timewiggle prediction requires zero placeholder tail columns"
2643 .to_string(),
2644 );
2645 }
2646 let knots = time_wiggle_knots.ok_or_else(|| {
2647 "survival location-scale time-warp: timewiggle coefficients are missing knot metadata"
2648 .to_string()
2649 })?;
2650 let degree = time_wiggle_degree.ok_or_else(|| {
2651 "survival location-scale time-warp: timewiggle coefficients are missing degree metadata"
2652 .to_string()
2653 })?;
2654 let beta_w = beta_time.slice(s![p_base..p_time_total]).to_owned();
2655 let time_basis = crate::wiggle::monotone_wiggle_basis_with_derivative_order(
2656 h_base.view(),
2657 knots,
2658 degree,
2659 0,
2660 )?;
2661 let time_basis_d1 = crate::wiggle::monotone_wiggle_basis_with_derivative_order(
2662 h_base.view(),
2663 knots,
2664 degree,
2665 1,
2666 )?;
2667 if time_basis.ncols() != p_wiggle || time_basis_d1.ncols() != p_wiggle {
2668 return Err(format!(
2669 "survival location-scale time-warp timewiggle mismatch: value basis has {} columns, derivative basis has {}, beta has {}",
2670 time_basis.ncols(),
2671 time_basis_d1.ncols(),
2672 p_wiggle
2673 ));
2674 }
2675 let dq = time_basis_d1.dot(&beta_w) + 1.0;
2676 h = &h_base + &time_basis.dot(&beta_w);
2677 time_jac = Array2::<f64>::zeros((n, p_time_total));
2678 if p_base > 0 {
2679 let scaled_base = crate::survival::location_scale::scale_dense_rows(
2680 &x_time_exit.slice(s![.., ..p_base]).to_owned(),
2681 &dq,
2682 )?;
2683 time_jac.slice_mut(s![.., ..p_base]).assign(&scaled_base);
2684 }
2685 time_jac
2686 .slice_mut(s![.., p_base..p_time_total])
2687 .assign(&time_basis);
2688 time_wiggle_dq = Some(dq);
2689 }
2690
2691 Ok(LocationScaleTimeWarpComponents {
2692 h,
2693 time_jac,
2694 time_wiggle_dq,
2695 })
2696}
2697
2698pub(crate) fn location_scale_eta_components(
2699 x_time_exit: &Array2<f64>,
2700 eta_time_offset_exit: &Array1<f64>,
2701 time_wiggle_knots: Option<&Array1<f64>>,
2702 time_wiggle_degree: Option<usize>,
2703 time_wiggle_ncols: usize,
2704 x_threshold: &gam_linalg::matrix::DesignMatrix,
2705 eta_threshold_offset: &Array1<f64>,
2706 x_log_sigma: &gam_linalg::matrix::DesignMatrix,
2707 eta_log_sigma_offset: &Array1<f64>,
2708 fit: &UnifiedFitResult,
2709) -> Result<LocationScaleEtaComponents, String> {
2710 let n = x_time_exit.nrows();
2711 if x_threshold.nrows() != n
2712 || eta_threshold_offset.len() != n
2713 || x_log_sigma.nrows() != n
2714 || eta_log_sigma_offset.len() != n
2715 {
2716 return Err("survival location-scale eta component row mismatch across inputs".to_string());
2717 }
2718 let time_components = location_scale_time_warp_components(
2719 x_time_exit,
2720 eta_time_offset_exit,
2721 time_wiggle_knots,
2722 time_wiggle_degree,
2723 time_wiggle_ncols,
2724 fit,
2725 )?;
2726 let beta_threshold = fit.beta_threshold();
2727 let beta_log_sigma = fit.beta_log_sigma();
2728 let eta_t = x_threshold.matrixvectormultiply(&beta_threshold) + eta_threshold_offset;
2729 let eta_ls = x_log_sigma.matrixvectormultiply(&beta_log_sigma) + eta_log_sigma_offset;
2730 let inv_sigma = eta_ls.mapv(crate::sigma_link::exp_sigma_inverse_from_eta_scalar);
2731 Ok(LocationScaleEtaComponents {
2732 h: time_components.h,
2733 time_jac: time_components.time_jac,
2734 eta_t,
2735 eta_ls,
2736 inv_sigma,
2737 })
2738}
2739
2740fn location_scale_eta_derivative_components(
2741 x_time_derivative: &Array2<f64>,
2742 derivative_offset_exit: &Array1<f64>,
2743 x_time_exit: &Array2<f64>,
2744 eta_time_offset_exit: &Array1<f64>,
2745 time_wiggle_knots: Option<&Array1<f64>>,
2746 time_wiggle_degree: Option<usize>,
2747 time_wiggle_ncols: usize,
2748 fit: &UnifiedFitResult,
2749) -> Result<Array1<f64>, String> {
2750 let n = x_time_exit.nrows();
2751 if x_time_derivative.nrows() != n
2752 || derivative_offset_exit.len() != n
2753 || eta_time_offset_exit.len() != n
2754 {
2755 return Err(
2756 "survival location-scale hazard derivative row mismatch across inputs".to_string(),
2757 );
2758 }
2759 let beta_time = fit.beta_time();
2760 let p_time_total = beta_time.len();
2761 let p_wiggle = time_wiggle_ncols.min(p_time_total);
2762 let p_base = p_time_total - p_wiggle;
2763 if x_time_exit.ncols() != p_time_total || x_time_derivative.ncols() != p_base {
2764 return Err(format!(
2765 "survival location-scale hazard derivative design mismatch: x_exit={} beta_time={} x_derivative={} base={}",
2766 x_time_exit.ncols(),
2767 p_time_total,
2768 x_time_derivative.ncols(),
2769 p_base
2770 ));
2771 }
2772
2773 let time_components = location_scale_time_warp_components(
2774 x_time_exit,
2775 eta_time_offset_exit,
2776 time_wiggle_knots,
2777 time_wiggle_degree,
2778 time_wiggle_ncols,
2779 fit,
2780 )?;
2781 let beta_base = beta_time.slice(s![..p_base]).to_owned();
2782 let mut eta_derivative = if p_base > 0 {
2783 x_time_derivative.dot(&beta_base) + derivative_offset_exit
2784 } else {
2785 derivative_offset_exit.clone()
2786 };
2787 if let Some(dq) = time_components.time_wiggle_dq.as_ref() {
2788 eta_derivative *= dq;
2789 }
2790 if eta_derivative
2791 .iter()
2792 .any(|value| !(value.is_finite() && *value > 0.0))
2793 {
2794 return Err(
2795 "survival location-scale hazard derivative must be finite and positive".to_string(),
2796 );
2797 }
2798 Ok(eta_derivative)
2799}
2800
2801fn location_scale_hazard_from_eta_derivative(
2802 eta: &Array1<f64>,
2803 eta_derivative: &Array1<f64>,
2804 inverse_link: &InverseLink,
2805) -> Result<Array1<f64>, String> {
2806 if eta.len() != eta_derivative.len() {
2807 return Err(format!(
2808 "survival location-scale hazard row mismatch: eta={} eta_derivative={}",
2809 eta.len(),
2810 eta_derivative.len()
2811 ));
2812 }
2813 let values = eta
2814 .iter()
2815 .zip(eta_derivative.iter())
2816 .map(|(&q, &q_t)| location_scale_hazard_component(q, q_t, inverse_link))
2817 .collect::<Result<Vec<_>, _>>()?;
2818 Ok(Array1::from_vec(values))
2819}
2820
2821fn location_scale_hazard_component(
2822 eta: f64,
2823 eta_derivative: f64,
2824 inverse_link: &InverseLink,
2825) -> Result<f64, String> {
2826 if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative > 0.0) {
2827 return Err(format!(
2828 "survival location-scale hazard requires finite eta and positive eta_t, got eta={eta}, eta_t={eta_derivative}"
2829 ));
2830 }
2831 match inverse_link {
2832 InverseLink::Standard(StandardLink::Probit) => {
2833 let (_, hazard) = probit_survival_hazard_components(eta, eta_derivative)?;
2834 Ok(hazard)
2835 }
2836 InverseLink::Standard(StandardLink::CLogLog) => {
2837 let (_, hazard) = royston_parmar_survival_hazard_components(eta, eta_derivative)?;
2838 Ok(hazard)
2839 }
2840 InverseLink::Standard(StandardLink::Logit) => {
2841 let failure = if eta >= 0.0 {
2842 1.0 / (1.0 + (-eta).exp())
2843 } else {
2844 let exp_eta = eta.exp();
2845 exp_eta / (1.0 + exp_eta)
2846 };
2847 Ok(failure * eta_derivative)
2848 }
2849 InverseLink::Standard(StandardLink::Identity) => {
2850 let survival = 1.0 - eta;
2851 if !(survival.is_finite() && survival > 0.0) {
2852 return Err(format!(
2853 "survival location-scale identity link produced invalid survival={survival} at eta={eta}"
2854 ));
2855 }
2856 Ok(eta_derivative / survival)
2857 }
2858 _ => {
2859 let jet = inverse_link_jet_for_inverse_link(inverse_link, eta)
2860 .map_err(|err| format!("survival location-scale inverse-link jet failed: {err}"))?;
2861 let survival = 1.0 - jet.mu;
2862 let hazard = jet.d1 * eta_derivative / survival;
2863 if !(survival.is_finite() && survival > 0.0 && hazard.is_finite() && hazard >= 0.0) {
2864 return Err(format!(
2865 "survival location-scale inverse link produced invalid hazard components: eta={eta}, eta_t={eta_derivative}, failure={}, d_failure={}, survival={survival}, hazard={hazard}",
2866 jet.mu, jet.d1
2867 ));
2868 }
2869 Ok(hazard)
2870 }
2871 }
2872}
2873
2874pub fn require_saved_survival_likelihood_mode(
2880 model: &SavedModel,
2881) -> Result<SurvivalLikelihoodMode, SurvivalPredictError> {
2882 if matches!(&model.family_state, FittedFamily::LatentSurvival { .. }) {
2883 return match model.survival_likelihood.as_deref() {
2884 Some("latent") => Ok(SurvivalLikelihoodMode::Latent),
2885 Some(other) => Err(SurvivalPredictError::MissingFitMetadata { reason: format!(
2886 "saved latent survival model has contradictory survival_likelihood metadata: expected 'latent', got '{other}'"
2887 ) }),
2888 None => Err(SurvivalPredictError::MissingFitMetadata {
2889 reason:
2890 "saved latent survival model is missing survival_likelihood=latent metadata; refit"
2891 .to_string(),
2892 }),
2893 };
2894 }
2895 if matches!(&model.family_state, FittedFamily::LatentBinary { .. }) {
2896 return match model.survival_likelihood.as_deref() {
2897 Some("latent-binary") => Ok(SurvivalLikelihoodMode::LatentBinary),
2898 Some(other) => Err(SurvivalPredictError::MissingFitMetadata { reason: format!(
2899 "saved latent binary model has contradictory survival_likelihood metadata: expected 'latent-binary', got '{other}'"
2900 ) }),
2901 None => Err(SurvivalPredictError::MissingFitMetadata {
2902 reason:
2903 "saved latent binary model is missing survival_likelihood=latent-binary metadata; refit"
2904 .to_string(),
2905 }),
2906 };
2907 }
2908 let raw = model.survival_likelihood.as_deref().ok_or_else(|| {
2909 "saved survival model is missing survival_likelihood metadata; refit".to_string()
2910 })?;
2911 parse_survival_likelihood_mode(raw).map_err(SurvivalPredictError::from)
2912}
2913
2914pub fn saved_survival_runtime_baseline_config(
2916 model: &SavedModel,
2917) -> Result<SurvivalBaselineConfig, SurvivalPredictError> {
2918 survival_baseline_config_from_model(model).map_err(SurvivalPredictError::from)
2919}
2920
2921pub fn resolve_termspec_for_prediction(
2924 modelspec: &Option<TermCollectionSpec>,
2925 training_headers: Option<&Vec<String>>,
2926 col_map: &HashMap<String, usize>,
2927 spec_label: &str,
2928) -> Result<TermCollectionSpec, SurvivalPredictError> {
2929 let saved = modelspec.as_ref().ok_or_else(|| {
2930 format!(
2931 "model is missing {spec_label}; refit to guarantee train/predict design consistency"
2932 )
2933 })?;
2934 saved.validate_frozen(spec_label)?;
2935 let headers = training_headers.ok_or_else(|| {
2936 "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
2937 .to_string()
2938 })?;
2939 let remapped = remap_term_collectionspec_columns(saved, headers, col_map)?;
2940 remapped.validate_frozen(spec_label)?;
2941 Ok(remapped)
2942}
2943
2944fn remap_term_collectionspec_columns(
2945 spec: &TermCollectionSpec,
2946 training_headers: &[String],
2947 prediction_column_map: &HashMap<String, usize>,
2948) -> Result<TermCollectionSpec, SurvivalPredictError> {
2949 spec.remap_feature_columns(|index| -> Result<usize, SurvivalPredictError> {
2953 let name = training_headers
2954 .get(index)
2955 .ok_or_else(|| format!("saved training column index {index} is out of bounds"))?;
2956 resolve_role_col(prediction_column_map, name, "prediction")
2957 .map_err(SurvivalPredictError::from)
2958 })
2959}
2960
2961pub fn fit_result_from_saved_model_for_prediction(
2963 model: &SavedModel,
2964) -> Result<UnifiedFitResult, String> {
2965 model
2966 .fit_result
2967 .clone()
2968 .ok_or_else(|| "model is missing canonical fit_result payload; refit".to_string())
2969}
2970
2971pub fn saved_survival_location_scale_fit_result(
2977 model: &SavedModel,
2978) -> Result<UnifiedFitResult, SurvivalPredictError> {
2979 model.saved_prediction_runtime()?;
2980 let mut fit = model.fit_result.clone().ok_or_else(|| {
2981 "saved location-scale survival model missing canonical fit_result; refit".to_string()
2982 })?;
2983 let inverse_link = resolve_survival_inverse_link_from_saved(model)?;
2984 apply_inverse_link_state_to_fit_result(&mut fit, &inverse_link);
2985 Ok(fit)
2986}
2987
2988pub fn apply_inverse_link_state_to_fit_result(
2989 fit_result: &mut UnifiedFitResult,
2990 inverse_link: &InverseLink,
2991) {
2992 fit_result.fitted_link = match inverse_link {
2993 InverseLink::LatentCLogLog(state) => FittedLinkState::LatentCLogLog { state: *state },
2994 InverseLink::Sas(state) => FittedLinkState::Sas {
2995 state: *state,
2996 covariance: None,
2997 },
2998 InverseLink::BetaLogistic(state) => FittedLinkState::BetaLogistic {
2999 state: *state,
3000 covariance: None,
3001 },
3002 InverseLink::Mixture(state) => FittedLinkState::Mixture {
3003 state: state.clone(),
3004 covariance: None,
3005 },
3006 InverseLink::Standard(_) => FittedLinkState::Standard(None),
3007 };
3008}
3009
3010pub fn resolve_survival_inverse_link_from_saved(
3013 model: &SavedModel,
3014) -> Result<InverseLink, SurvivalPredictError> {
3015 if let Some(link) = model.link.as_ref() {
3016 return Ok(link.clone());
3017 }
3018 Err(SurvivalPredictError::MissingFitMetadata {
3019 reason: "saved survival model is missing link metadata; refit".to_string(),
3020 })
3021}
3022
3023pub fn concat_array1_refs(parts: &[&Array1<f64>]) -> Array1<f64> {
3025 let total: usize = parts.iter().map(|part| part.len()).sum();
3026 let mut out = Array1::<f64>::zeros(total);
3027 let mut offset = 0usize;
3028 for part in parts {
3029 let width = part.len();
3030 out.slice_mut(s![offset..offset + width]).assign(part);
3031 offset += width;
3032 }
3033 out
3034}
3035
3036pub fn saved_baseline_timewiggle_components(
3040 eta_entry: &Array1<f64>,
3041 eta_exit: &Array1<f64>,
3042 derivative_exit: &Array1<f64>,
3043 model: &SavedModel,
3044) -> Result<Option<(Array2<f64>, Array2<f64>, Array2<f64>)>, SurvivalPredictError> {
3045 match model.saved_baseline_time_wiggle()? {
3046 None => Ok(None),
3047 Some(runtime) => {
3048 runtime.validate_global_monotonicity()?;
3049 let SavedBaselineTimeWiggleRuntime {
3050 knots,
3051 degree,
3052 beta,
3053 ..
3054 } = runtime;
3055 let knots = Array1::from_vec(knots);
3056 let entry = match buildwiggle_block_input_from_knots(
3057 eta_entry.view(),
3058 &knots,
3059 degree,
3060 2,
3061 false,
3062 )?
3063 .design
3064 {
3065 DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
3066 _ => {
3067 return Err(SurvivalPredictError::IncompatibleSchema {
3068 reason: "saved baseline-timewiggle entry design must be dense".to_string(),
3069 });
3070 }
3071 };
3072 let exit = match buildwiggle_block_input_from_knots(
3073 eta_exit.view(),
3074 &knots,
3075 degree,
3076 2,
3077 false,
3078 )?
3079 .design
3080 {
3081 DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
3082 _ => {
3083 return Err(SurvivalPredictError::IncompatibleSchema {
3084 reason: "saved baseline-timewiggle exit design must be dense".to_string(),
3085 });
3086 }
3087 };
3088 let betaw = beta;
3089 if entry.ncols() != betaw.len() || exit.ncols() != betaw.len() {
3090 return Err(SurvivalPredictError::IncompatibleSchema {
3091 reason: format!(
3092 "saved baseline-timewiggle dimension mismatch: coefficients have {} entries but basis has entry={} exit={}",
3093 betaw.len(),
3094 entry.ncols(),
3095 exit.ncols()
3096 ),
3097 });
3098 }
3099 let derivative = build_survival_timewiggle_derivative_design(
3100 eta_exit,
3101 derivative_exit,
3102 &knots,
3103 degree,
3104 )
3105 .map_err(|e| {
3106 e.replace(
3107 "build baseline-timewiggle",
3108 "evaluate saved baseline-timewiggle",
3109 )
3110 })?;
3111 if derivative.ncols() != betaw.len() {
3112 return Err(SurvivalPredictError::IncompatibleSchema {
3113 reason: format!(
3114 "saved baseline-timewiggle derivative dimension mismatch: coefficients have {} entries but derivative basis has {} columns",
3115 betaw.len(),
3116 derivative.ncols()
3117 ),
3118 });
3119 }
3120 Ok(Some((entry, exit, derivative)))
3121 }
3122 }
3123}
3124
3125pub fn build_saved_survival_marginal_slope_predictor(
3134 model: &SavedModel,
3135 fit_saved: &UnifiedFitResult,
3136 z_name: &str,
3137 z: &Array1<f64>,
3138 cov_design: &DesignMatrix,
3139 logslope_design: &DesignMatrix,
3140 time_build: &SurvivalTimeBuildOutput,
3141 eta_offset_entry: &Array1<f64>,
3142 eta_offset_exit: &Array1<f64>,
3143 derivative_offset_exit: &Array1<f64>,
3144 primary_offset: &Array1<f64>,
3145 noise_offset: &Array1<f64>,
3146) -> Result<
3147 (
3148 BernoulliMarginalSlopePredictor,
3149 PredictInput,
3150 UnifiedFitResult,
3151 ),
3152 SurvivalPredictError,
3153> {
3154 let saved_runtime = model.saved_prediction_runtime()?;
3155 if saved_runtime.link_wiggle.is_some() {
3156 return Err(SurvivalPredictError::MissingFitMetadata {
3157 reason:
3158 "saved survival marginal-slope model contains legacy linkwiggle metadata; refit with the anchored link-deviation runtime"
3159 .to_string(),
3160 });
3161 }
3162
3163 let saved_score_runtime = saved_runtime.score_warp;
3164 let saved_link_runtime = saved_runtime.link_deviation;
3165 let influence_absorber_width = saved_runtime.influence_absorber_width;
3170 let blocks = &fit_saved.blocks;
3171 let expected_blocks = 3
3172 + usize::from(saved_score_runtime.is_some())
3173 + usize::from(saved_link_runtime.is_some())
3174 + usize::from(influence_absorber_width.is_some());
3175 if blocks.len() != expected_blocks {
3176 return Err(SurvivalPredictError::IncompatibleSchema {
3177 reason: format!(
3178 "saved survival marginal-slope model requires {} blocks [time, marginal, slope{}{}{}], got {}",
3179 expected_blocks,
3180 if saved_score_runtime.is_some() {
3181 ", score-warp"
3182 } else {
3183 ""
3184 },
3185 if saved_link_runtime.is_some() {
3186 ", link-deviation"
3187 } else {
3188 ""
3189 },
3190 if influence_absorber_width.is_some() {
3191 ", influence-absorber(dropped)"
3192 } else {
3193 ""
3194 },
3195 blocks.len(),
3196 ),
3197 });
3198 }
3199
3200 let beta_time = &blocks[0].beta;
3201 let beta_marginal = &blocks[1].beta;
3202 let beta_logslope = &blocks[2].beta;
3203 if let Some(runtime) = saved_score_runtime.as_ref() {
3204 let beta = &blocks[3].beta;
3205 if beta.len() != runtime.basis_dim {
3206 return Err(SurvivalPredictError::IncompatibleSchema {
3207 reason: format!(
3208 "saved survival marginal-slope score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
3209 beta.len(),
3210 runtime.basis_dim
3211 ),
3212 });
3213 }
3214 }
3215 if let Some(runtime) = saved_link_runtime.as_ref() {
3216 let idx = 3 + usize::from(saved_score_runtime.is_some());
3217 let beta = &blocks[idx].beta;
3218 if beta.len() != runtime.basis_dim {
3219 return Err(SurvivalPredictError::IncompatibleSchema {
3220 reason: format!(
3221 "saved survival marginal-slope link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
3222 beta.len(),
3223 runtime.basis_dim
3224 ),
3225 });
3226 }
3227 }
3228
3229 if beta_marginal.len() != cov_design.ncols() {
3230 return Err(SurvivalPredictError::IncompatibleSchema {
3231 reason: format!(
3232 "saved survival marginal-slope marginal coefficient mismatch: beta has {} entries but baseline design has {} columns",
3233 beta_marginal.len(),
3234 cov_design.ncols()
3235 ),
3236 });
3237 }
3238 if beta_logslope.len() != logslope_design.ncols() {
3239 return Err(SurvivalPredictError::IncompatibleSchema {
3240 reason: format!(
3241 "saved survival marginal-slope slope coefficient mismatch: beta has {} entries but slope design has {} columns",
3242 beta_logslope.len(),
3243 logslope_design.ncols()
3244 ),
3245 });
3246 }
3247
3248 let p_time_base = time_build.x_exit_time.ncols();
3249 let saved_timewiggle = saved_runtime.baseline_time_wiggle;
3250 let p_timewiggle = saved_timewiggle
3251 .as_ref()
3252 .map_or(0, |runtime| runtime.beta.len());
3253 if beta_time.len() != p_time_base + p_timewiggle {
3254 return Err(SurvivalPredictError::IncompatibleSchema {
3255 reason: format!(
3256 "saved survival marginal-slope time coefficient mismatch: beta has {} entries but expected base={} plus timewiggle={}",
3257 beta_time.len(),
3258 p_time_base,
3259 p_timewiggle
3260 ),
3261 });
3262 }
3263
3264 let beta_time_base = beta_time.slice(s![..p_time_base]).to_owned();
3265 let cov_eta_marginal = cov_design.dot(beta_marginal);
3269 let q_entry_base = time_build.x_entry_time.dot(&beta_time_base)
3270 + &cov_eta_marginal
3271 + eta_offset_entry
3272 + primary_offset;
3273 let q_exit_base = time_build.x_exit_time.dot(&beta_time_base)
3274 + &cov_eta_marginal
3275 + eta_offset_exit
3276 + primary_offset;
3277 let qd_exit_base = time_build.x_derivative_time.dot(&beta_time_base) + derivative_offset_exit;
3278
3279 let mut q_design_parts = vec![time_build.x_exit_time.clone()];
3280 if saved_timewiggle.is_some() {
3281 let (_, exit_w, _) = saved_baseline_timewiggle_components(
3282 &q_entry_base,
3283 &q_exit_base,
3284 &qd_exit_base,
3285 model,
3286 )?
3287 .ok_or_else(|| {
3288 "saved survival marginal-slope model is missing baseline-timewiggle runtime metadata"
3289 .to_string()
3290 })?;
3291 if exit_w.ncols() != p_timewiggle {
3292 return Err(SurvivalPredictError::IncompatibleSchema {
3293 reason: format!(
3294 "saved survival marginal-slope timewiggle design mismatch: rebuilt {} columns but runtime expects {}",
3295 exit_w.ncols(),
3296 p_timewiggle
3297 ),
3298 });
3299 }
3300 q_design_parts.push(DesignMatrix::from(exit_w));
3301 }
3302 q_design_parts.push(cov_design.clone());
3303 let q_design = DesignMatrix::hstack(q_design_parts)?;
3304
3305 let combined_q_beta = concat_array1_refs(&[beta_time, beta_marginal]);
3306 let combined_q_lambdas = concat_array1_refs(&[&blocks[0].lambdas, &blocks[1].lambdas]);
3307 let mut predictor_blocks = Vec::with_capacity(
3308 2 + usize::from(saved_score_runtime.is_some()) + usize::from(saved_link_runtime.is_some()),
3309 );
3310 predictor_blocks.push(FittedBlock {
3311 beta: combined_q_beta.clone(),
3312 role: BlockRole::Mean,
3313 edf: blocks[0].edf + blocks[1].edf,
3314 lambdas: combined_q_lambdas,
3315 });
3316 predictor_blocks.push(FittedBlock {
3317 beta: beta_logslope.clone(),
3318 role: BlockRole::Scale,
3319 edf: blocks[2].edf,
3320 lambdas: blocks[2].lambdas.clone(),
3321 });
3322 if saved_score_runtime.is_some() {
3323 let mut block = blocks[3].clone();
3324 block.role = BlockRole::Mean;
3325 predictor_blocks.push(block);
3326 }
3327 if saved_link_runtime.is_some() {
3328 let idx = 3 + usize::from(saved_score_runtime.is_some());
3329 let mut block = blocks[idx].clone();
3330 block.role = BlockRole::LinkWiggle;
3331 predictor_blocks.push(block);
3332 }
3333
3334 let mut predictor_fit = fit_saved.clone();
3335 predictor_fit.blocks = predictor_blocks;
3336 predictor_fit.beta = concat_array1_refs(
3337 &predictor_fit
3338 .blocks
3339 .iter()
3340 .map(|block| &block.beta)
3341 .collect::<Vec<_>>(),
3342 );
3343 predictor_fit.block_states.clear();
3344
3345 let predictor = BernoulliMarginalSlopePredictor::from_unified(
3346 &predictor_fit,
3347 z_name.to_string(),
3348 model.latent_z_normalization.ok_or_else(|| {
3349 "saved survival marginal-slope model missing latent_z_normalization".to_string()
3350 })?,
3351 model.latent_measure.clone().ok_or_else(|| {
3352 "saved survival marginal-slope model missing latent_measure".to_string()
3353 })?,
3354 0.0,
3355 model.logslope_baseline.ok_or_else(|| {
3356 "saved survival marginal-slope model missing logslope_baseline".to_string()
3357 })?,
3358 model
3359 .resolved_inverse_link()?
3360 .unwrap_or(InverseLink::Standard(StandardLink::Probit)),
3361 model
3362 .family_state
3363 .frailty()
3364 .cloned()
3365 .unwrap_or(FrailtySpec::None),
3366 saved_score_runtime,
3367 saved_link_runtime,
3368 model.latent_z_rank_int_calibration.clone(),
3369 model.latent_z_conditional_calibration.clone(),
3372 )?;
3373
3374 let pred_input = PredictInput {
3375 design: q_design,
3376 offset: eta_offset_exit + primary_offset,
3377 design_noise: Some(logslope_design.clone()),
3378 offset_noise: Some(noise_offset.clone()),
3379 auxiliary_scalar: Some(z.clone()),
3380 auxiliary_matrix: None,
3381 };
3382
3383 Ok((predictor, pred_input, predictor_fit))
3384}
3385
3386#[cfg(test)]
3387mod tests {
3388 use super::*;
3389 use crate::probability::{normal_cdf, normal_pdf};
3390
3391 #[test]
3392 fn probit_survival_hazard_uses_density_over_survival() {
3393 let eta = 2.0;
3394 let eta_t = 0.3;
3395
3396 let (cum, hazard) =
3397 probit_survival_hazard_components(eta, eta_t).expect("valid components");
3398
3399 let survival = normal_cdf(-eta);
3400 let expected_cum = -survival.ln();
3401 let expected_hazard = normal_pdf(eta) * eta_t / survival;
3402 assert!((cum - expected_cum).abs() <= 1e-14);
3403 assert!((hazard - expected_hazard).abs() <= 1e-14);
3404 }
3405
3406 #[test]
3407 fn probit_survival_hazard_stays_finite_in_right_tail() {
3408 let eta = 40.0;
3409 let eta_t = 9.694_340_360_912_401e-5;
3410
3411 let event_density =
3412 (-0.5_f64 * eta * eta).exp() / (2.0 * std::f64::consts::PI).sqrt() * eta_t;
3413 assert_eq!(event_density, 0.0);
3414
3415 let (cum, hazard) =
3416 probit_survival_hazard_components(eta, eta_t).expect("valid tail components");
3417 assert!(cum > 800.0, "right-tail cumulative hazard was {cum}");
3418 assert!(
3419 (3.87e-3..3.89e-3).contains(&hazard),
3420 "right-tail hazard was {hazard}"
3421 );
3422 }
3423
3424 #[test]
3425 fn probit_survival_hazard_accepts_zero_time_derivative_as_flat_hazard() {
3426 let (cum, hazard) =
3427 probit_survival_hazard_components(1.0, 0.0).expect("zero derivative is flat hazard");
3428 assert!(cum > 0.0);
3429 assert_eq!(hazard, 0.0);
3430 }
3431
3432 #[test]
3433 fn marginal_slope_index_derivative_clamps_extrapolation_negative_to_flat_hazard() {
3434 let deta_dq = (1.0_f64 + 0.4 * 0.4).sqrt(); let qd_with_wiggle = -1.35e-3;
3442 let eta_t = marginal_slope_index_derivative_at_horizon(deta_dq, qd_with_wiggle);
3443 assert_eq!(
3444 eta_t, 0.0,
3445 "negative extrapolation derivative must clamp to 0"
3446 );
3447 let (cum, hazard) = probit_survival_hazard_components(-0.563, eta_t)
3449 .expect("clamped flat-hazard prediction must validate");
3450 assert!(
3451 cum >= 0.0,
3452 "cumulative hazard must be well-posed, got {cum}"
3453 );
3454 assert_eq!(
3455 hazard, 0.0,
3456 "clamped derivative gives zero instantaneous hazard"
3457 );
3458 }
3459
3460 #[test]
3461 fn marginal_slope_index_derivative_preserves_positive_and_nonfinite() {
3462 let positive = marginal_slope_index_derivative_at_horizon(1.25, 0.8);
3466 assert!(
3467 (positive - 1.0).abs() <= 1e-15,
3468 "positive derivative scaled by chain factor"
3469 );
3470 let nonfinite = marginal_slope_index_derivative_at_horizon(1.25, f64::NAN);
3471 assert!(
3472 nonfinite.is_nan(),
3473 "non-finite derivative passes through unclamped"
3474 );
3475 assert!(
3476 probit_survival_hazard_components(0.5, nonfinite).is_err(),
3477 "non-finite derivative must still be rejected by the validator"
3478 );
3479 }
3480
3481 #[test]
3482 fn probit_survival_hazard_rejects_infinite_time_derivative() {
3483 let err = probit_survival_hazard_components(1.0, f64::INFINITY)
3484 .expect_err("infinite derivative should be invalid");
3485 assert!(
3486 err.to_string()
3487 .contains("invalid survival index derivative")
3488 );
3489 }
3490
3491 #[test]
3492 fn probit_survival_hazard_rejects_nan_inputs() {
3493 let err_eta =
3499 probit_survival_hazard_components(f64::NAN, 0.5).expect_err("NaN eta must be rejected");
3500 assert!(
3501 err_eta
3502 .to_string()
3503 .contains("invalid survival index derivative")
3504 );
3505 let err_dt = probit_survival_hazard_components(1.0, f64::NAN)
3506 .expect_err("NaN eta_derivative must be rejected");
3507 assert!(
3508 err_dt
3509 .to_string()
3510 .contains("invalid survival index derivative")
3511 );
3512 }
3513
3514 #[test]
3515 fn probit_survival_hazard_rejects_negative_time_derivative() {
3516 let err = probit_survival_hazard_components(1.0, -0.5)
3520 .expect_err("negative derivative should be invalid");
3521 assert!(
3522 err.to_string()
3523 .contains("invalid survival index derivative")
3524 );
3525 }
3526
3527 #[test]
3528 fn royston_parmar_hazard_is_cumulative_hazard_derivative() {
3529 let eta = 2.0_f64.ln();
3530 let eta_t = 0.25;
3531
3532 let (cum, hazard) =
3533 royston_parmar_survival_hazard_components(eta, eta_t).expect("valid components");
3534
3535 assert!((cum - 2.0).abs() <= 1e-14);
3536 assert!((hazard - 0.5).abs() <= 1e-14);
3537 assert_ne!(hazard, cum);
3538 }
3539
3540 #[test]
3541 fn royston_parmar_hazard_rejects_negative_log_hazard_derivative() {
3542 let err = royston_parmar_survival_hazard_components(0.0, -0.5)
3546 .expect_err("negative derivative should be invalid");
3547 assert!(
3548 err.to_string()
3549 .contains("invalid log-cumulative-hazard derivative")
3550 );
3551 }
3552
3553 #[test]
3554 fn royston_parmar_hazard_accepts_zero_derivative_as_flat_boundary() {
3555 let eta = 1.9909019457445971_f64; let (cum, hazard) = royston_parmar_survival_hazard_components(eta, 0.0)
3562 .expect("zero derivative is a valid flat boundary, not an error");
3563 assert!((cum - eta.exp()).abs() <= 1e-12, "cum = Λ(t) = exp(η)");
3564 assert_eq!(
3565 hazard, 0.0,
3566 "flat cumulative hazard ⇒ zero instantaneous hazard"
3567 );
3568 let survival = (-cum).exp().clamp(0.0, 1.0);
3570 assert!(survival.is_finite() && (0.0..=1.0).contains(&survival));
3571 }
3572
3573 #[test]
3574 fn royston_parmar_hazard_zero_derivative_in_saturated_tail_is_zero_not_nan() {
3575 let eta = 1000.0_f64;
3581 assert!(
3582 eta.exp().is_infinite(),
3583 "test premise: exp(1000) overflows to +∞"
3584 );
3585 assert!(
3586 (f64::INFINITY * 0.0).is_nan(),
3587 "test premise: the naive product is NaN"
3588 );
3589 let (cum, hazard) = royston_parmar_survival_hazard_components(eta, 0.0)
3590 .expect("saturated + flat boundary must be valid");
3591 assert!(cum.is_infinite() && cum > 0.0, "cum saturates to +∞");
3592 assert_eq!(hazard, 0.0, "hazard at a flat boundary is 0, never NaN");
3593 }
3594
3595 #[test]
3596 fn royston_parmar_hazard_propagates_saturation_as_infinity() {
3597 let eta = 1000.0_f64;
3602 let eta_t = 0.5_f64;
3603 assert!(eta.exp().is_infinite(), "test premise: exp(1000) overflows");
3604
3605 let (cum, hazard) = royston_parmar_survival_hazard_components(eta, eta_t)
3606 .expect("saturated RP fit must yield a result, not an error");
3607 assert!(cum.is_infinite() && cum > 0.0, "expected +∞ cum, got {cum}");
3608 assert!(
3609 hazard.is_infinite() && hazard > 0.0,
3610 "expected +∞ hazard, got {hazard}"
3611 );
3612
3613 let survival = (-cum).exp().clamp(0.0, 1.0);
3615 assert_eq!(survival, 0.0, "saturated cum_hazard must give survival 0");
3616 }
3617
3618 #[test]
3619 fn royston_parmar_hazard_rejects_nan_eta() {
3620 let err = royston_parmar_survival_hazard_components(f64::NAN, 0.5)
3621 .expect_err("NaN eta should be invalid");
3622 assert!(
3623 err.to_string()
3624 .contains("invalid log-cumulative-hazard derivative")
3625 );
3626 }
3627
3628 #[test]
3629 fn royston_parmar_hazard_left_tail_collapses_to_zero() {
3630 let eta = -1000.0_f64;
3633 let eta_t = 2.0_f64;
3634 assert_eq!(eta.exp(), 0.0, "test premise: exp(-1000) underflows to 0");
3635
3636 let (cum, hazard) = royston_parmar_survival_hazard_components(eta, eta_t)
3637 .expect("RP left tail must remain valid");
3638 assert_eq!(
3639 cum, 0.0,
3640 "left-tail cum_hazard should underflow to 0, got {cum}"
3641 );
3642 assert_eq!(
3643 hazard, 0.0,
3644 "left-tail hazard should underflow to 0, got {hazard}"
3645 );
3646
3647 let survival = (-cum).exp().clamp(0.0, 1.0);
3649 assert_eq!(survival, 1.0);
3650 }
3651
3652 #[test]
3653 fn probit_survival_hazard_left_tail_collapses_to_zero() {
3654 let eta = -40.0_f64;
3658 let eta_t = 1.5_f64;
3659
3660 let (cum, hazard) =
3661 probit_survival_hazard_components(eta, eta_t).expect("left tail must remain valid");
3662 assert!(
3663 (0.0..1e-300).contains(&cum),
3664 "left-tail cum should be ~0, got {cum}"
3665 );
3666 assert_eq!(
3667 hazard, 0.0,
3668 "left-tail hazard should underflow to 0, got {hazard}"
3669 );
3670 }
3671
3672 #[test]
3673 fn location_scale_logit_hazard_is_failure_slope_over_survival() {
3674 let eta = 0.7;
3675 let eta_t = 0.4;
3676
3677 let hazard = location_scale_hazard_component(
3678 eta,
3679 eta_t,
3680 &InverseLink::Standard(StandardLink::Logit),
3681 )
3682 .expect("valid logit hazard");
3683
3684 let failure = 1.0 / (1.0 + (-eta).exp());
3685 assert!((hazard - failure * eta_t).abs() <= 1e-14);
3686 }
3687
3688 #[test]
3689 fn location_scale_cloglog_hazard_matches_log_cumulative_hazard_derivative() {
3690 let eta = 1.5;
3691 let eta_t = 0.2;
3692
3693 let hazard = location_scale_hazard_component(
3694 eta,
3695 eta_t,
3696 &InverseLink::Standard(StandardLink::CLogLog),
3697 )
3698 .expect("valid cloglog hazard");
3699
3700 assert!((hazard - eta.exp() * eta_t).abs() <= 1e-14);
3701 }
3702
3703 #[test]
3706 fn kaplan_meier_censoring_is_right_continuous_step() {
3707 let time = [2.0, 4.0, 6.0, 8.0];
3709 let event = [1.0, 0.0, 1.0, 0.0];
3710 let g = KaplanMeier::fit_censoring(&time, &event);
3711 assert!((g.at(0.0) - 1.0).abs() <= 1e-15);
3713 assert!((g.at(2.0) - 1.0).abs() <= 1e-15);
3714 assert!((g.at(3.999) - 1.0).abs() <= 1e-15);
3715 assert!((g.at(4.0) - 2.0 / 3.0).abs() <= 1e-12);
3717 assert!((g.at(5.0) - 2.0 / 3.0).abs() <= 1e-12);
3718 assert!((g.at(6.0) - 2.0 / 3.0).abs() <= 1e-12);
3720 assert!(g.at(8.0).abs() <= 1e-15);
3722 }
3723
3724 #[test]
3725 fn ipcw_brier_no_censoring_reduces_to_plain_brier() {
3726 let s_pred = [0.3, 0.7, 0.6, 0.2];
3729 let time = [2.0, 8.0, 10.0, 3.0];
3730 let event = [1.0, 1.0, 0.0, 1.0];
3731 let tau = 5.0;
3732 let g = KaplanMeier::fit_censoring(&time, &event);
3733 let bs = ipcw_brier_score(&s_pred, &time, &event, tau, |t| g.at(t)).unwrap();
3734 let expected =
3736 (0.3f64.powi(2) + (1.0 - 0.7f64).powi(2) + (1.0 - 0.6f64).powi(2) + 0.2f64.powi(2))
3737 / 4.0;
3738 assert!(
3739 (bs - expected).abs() <= 1e-12,
3740 "bs={bs} expected={expected}"
3741 );
3742 }
3743
3744 #[test]
3745 fn ipcw_brier_reweights_by_inverse_censoring_probability() {
3746 let s_pred = [0.4, 0.5, 0.7, 0.8];
3750 let time = [2.0, 4.0, 6.0, 8.0];
3751 let event = [1.0, 0.0, 1.0, 0.0];
3752 let tau = 5.0;
3753 let g = KaplanMeier::fit_censoring(&time, &event);
3754 let bs = ipcw_brier_score(&s_pred, &time, &event, tau, |t| g.at(t)).unwrap();
3755 let expected = (0.16 + 0.0 + 0.135 + 0.06) / 4.0;
3760 assert!(
3761 (bs - expected).abs() <= 1e-12,
3762 "bs={bs} expected={expected}"
3763 );
3764 }
3765
3766 #[test]
3767 fn ipcw_brier_drops_invalid_rows_from_both_numerator_and_denominator() {
3768 let s_pred = [0.3, 0.7, 0.5, 0.5];
3770 let time = [2.0, 8.0, f64::NAN, -1.0];
3771 let event = [1.0, 1.0, 1.0, 0.0];
3772 let g = KaplanMeier::fit_censoring(&time, &event);
3773 let bs = ipcw_brier_score(&s_pred, &time, &event, 5.0, |t| g.at(t)).unwrap();
3774 let expected = (0.3f64.powi(2) + (1.0 - 0.7f64).powi(2)) / 2.0;
3777 assert!(
3778 (bs - expected).abs() <= 1e-12,
3779 "bs={bs} expected={expected}"
3780 );
3781 }
3782
3783 #[test]
3784 fn integrated_ipcw_brier_of_constant_brier_is_that_constant() {
3785 let time = [2.0, 8.0, 10.0, 3.0];
3788 let event = [1.0, 1.0, 0.0, 1.0];
3789 let grid = [0.0, 1.0, 2.5, 4.0, 6.0];
3790 let col = [0.3, 0.7, 0.6, 0.2];
3794 let mut surv = Array2::<f64>::zeros((4, grid.len()));
3795 for k in 0..grid.len() {
3796 for i in 0..4 {
3797 surv[[i, k]] = col[i];
3798 }
3799 }
3800 let g = KaplanMeier::fit_censoring(&time, &event);
3801 let per_time = ipcw_brier_score(&col, &time, &event, grid[2], |t| g.at(t)).unwrap();
3802 let mut oracle_pts = Vec::new();
3806 for k in 0..grid.len() {
3807 oracle_pts.push((
3808 grid[k],
3809 ipcw_brier_score(&col, &time, &event, grid[k], |t| g.at(t)).unwrap(),
3810 ));
3811 }
3812 let mut integral = 0.0;
3813 for w in oracle_pts.windows(2) {
3814 integral += 0.5 * (w[0].1 + w[1].1) * (w[1].0 - w[0].0);
3815 }
3816 let oracle = integral / (grid[grid.len() - 1] - grid[0]);
3817 let ibs =
3818 integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, f64::INFINITY, |t| {
3819 g.at(t)
3820 })
3821 .unwrap();
3822 assert!((ibs - oracle).abs() <= 1e-12, "ibs={ibs} oracle={oracle}");
3823 assert!(per_time >= 0.0);
3825 }
3826
3827 #[test]
3828 fn integrated_ipcw_brier_respects_the_horizon_cutoff() {
3829 let time = [2.0, 8.0, 10.0, 3.0];
3830 let event = [1.0, 1.0, 0.0, 1.0];
3831 let grid = [0.0, 2.0, 4.0, 100.0];
3832 let col = [0.3, 0.7, 0.6, 0.2];
3833 let mut surv = Array2::<f64>::zeros((4, grid.len()));
3834 for k in 0..grid.len() {
3835 for i in 0..4 {
3836 surv[[i, k]] = col[i];
3837 }
3838 }
3839 let g = KaplanMeier::fit_censoring(&time, &event);
3840 let restricted =
3842 integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, 5.0, |t| g.at(t))
3843 .unwrap();
3844 let full =
3845 integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, f64::INFINITY, |t| {
3846 g.at(t)
3847 })
3848 .unwrap();
3849 assert!(
3852 (restricted - full).abs() > 1e-3,
3853 "horizon cutoff had no effect: restricted={restricted} full={full}"
3854 );
3855 }
3856
3857 #[test]
3858 fn integrated_ipcw_brier_rejects_malformed_grids() {
3859 let time = [2.0, 8.0];
3860 let event = [1.0, 0.0];
3861 let surv = Array2::<f64>::from_elem((2, 3), 0.5);
3862 let g = KaplanMeier::fit_censoring(&time, &event);
3863 let bad = [0.0, 2.0, 1.0];
3865 assert!(
3866 integrated_ipcw_brier_score(surv.view(), &time, &event, &bad, f64::INFINITY, |t| g
3867 .at(t))
3868 .is_none()
3869 );
3870 let short = [0.0, 1.0];
3872 assert!(
3873 integrated_ipcw_brier_score(surv.view(), &time, &event, &short, f64::INFINITY, |t| g
3874 .at(t))
3875 .is_none()
3876 );
3877 }
3878}