1use gam_terms::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14 BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15 create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18 DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19 SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24 WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25 monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::{InverseLink, StandardLink};
31use ndarray::{Array1, Array2, array, s};
32use rayon::prelude::*;
33
/// Errors raised while validating survival inputs and assembling survival
/// model components (baseline configs, time bases, likelihood modes).
///
/// Every variant carries a human-readable `reason`. Functions in this module
/// that return `Result<_, String>` turn these into `String` with `.into()`.
// NOTE(review): the `From<SurvivalConstructionError> for String` conversion is
// presumably generated by `impl_reason_error_boilerplate!` — confirm there.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// An invalid configuration or outcome, e.g. a baseline optimizer that did
    /// not converge.
    InvalidConfig { reason: String },
    /// A required input column is absent.
    MissingColumn { reason: String },
    /// Input lengths or parameter-vector dimensions do not agree.
    IncompatibleDimensions { reason: String },
    /// Input data failed validation (e.g. non-finite or negative times).
    DataValidationFailed { reason: String },
    /// A basis could not be constructed.
    BasisConstructionFailed { reason: String },
    /// An unrecognized or unsupported option name (distribution, likelihood
    /// mode, baseline target, or non-structural time basis).
    UnsupportedDistribution { reason: String },
}
73
// Generates the shared "reason"-style error boilerplate for each listed
// variant of `SurvivalConstructionError`.
// NOTE(review): presumably this includes the Display/Error impls and the
// conversion into `String` that the `.into()` calls in this module rely on —
// confirm against the macro definition.
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
84
/// Parametric family used for the survival baseline.
///
/// Parsed from `--baseline-target` by `parse_survival_baseline_config`; the
/// canonical name of each variant is given by `survival_baseline_targetname`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// `linear`: no parametric baseline parameters.
    Linear,
    /// `weibull`: requires a positive `scale` and a positive `shape`.
    Weibull,
    /// `gompertz`: positive `rate` and a finite (possibly negative) `shape`.
    Gompertz,
    /// `gompertz-makeham`: Gompertz parameters plus a positive `makeham` term.
    GompertzMakeham,
}
111
/// Validated baseline configuration.
///
/// Only the fields used by `target` are `Some`:
/// * `Linear` — all `None`;
/// * `Weibull` — `scale`, `shape`;
/// * `Gompertz` — `rate`, `shape`;
/// * `GompertzMakeham` — `rate`, `shape`, `makeham`.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    /// Which baseline family the parameters belong to.
    pub target: SurvivalBaselineTarget,
    /// Weibull scale (positive).
    pub scale: Option<f64>,
    /// Weibull shape (positive) or Gompertz shape (finite).
    pub shape: Option<f64>,
    /// Gompertz / Gompertz–Makeham rate (positive).
    pub rate: Option<f64>,
    /// Gompertz–Makeham Makeham term (positive).
    pub makeham: Option<f64>,
}
120
/// Requested survival time-basis, consumed by `build_survival_time_basis`.
///
/// `parse_survival_time_basis_config` only produces `None` and `ISpline`; it
/// rejects `linear` and `bspline` because they are not structurally monotone.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No time-basis columns (zero-width designs).
    None,
    /// Intercept column plus a log-time column.
    Linear,
    /// B-spline basis in log time.
    BSpline {
        /// Spline degree.
        degree: usize,
        /// Full knot vector; when empty, knots are inferred from the data.
        knots: Array1<f64>,
        /// Smoothing parameter recorded alongside the basis.
        smooth_lambda: f64,
    },
    /// I-spline basis, whose basis functions enforce a monotone cumulative
    /// time effect by construction.
    ISpline {
        /// Spline degree (at least 1).
        degree: usize,
        /// Knot vector; left empty by `parse_survival_time_basis_config`.
        knots: Array1<f64>,
        /// Retained basis columns; left empty by
        /// `parse_survival_time_basis_config`.
        // NOTE(review): how an empty list is interpreted is decided outside
        // this view — confirm in the ispline branch of the builder.
        keep_cols: Vec<usize>,
        /// Smoothing parameter (finite, >= 0).
        smooth_lambda: f64,
    },
}
174
/// Metadata snapshot of a built survival time basis, without its design
/// matrices or penalties. Created by `SavedSurvivalTimeBasis::from_build`.
#[derive(Clone, Debug, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Basis name, e.g. `"none"`, `"linear"`, `"bspline"`.
    pub basisname: String,
    /// Spline degree, if the basis is a spline.
    pub degree: Option<usize>,
    /// Knot vector, if the basis is a spline.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns, if the basis records them.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing parameter, if the basis records one.
    pub smooth_lambda: Option<f64>,
    /// Caller-supplied anchor passed to `from_build`.
    // NOTE(review): presumably a reference time for the basis — confirm with
    // callers.
    pub anchor: f64,
}
197
198impl SavedSurvivalTimeBasis {
199 pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
202 Self {
203 basisname: build.basisname.clone(),
204 degree: build.degree,
205 knots: build.knots.clone(),
206 keep_cols: build.keep_cols.clone(),
207 smooth_lambda: build.smooth_lambda,
208 anchor,
209 }
210 }
211}
212
/// Design matrices, penalties and metadata produced by
/// `build_survival_time_basis`.
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time-basis design evaluated at each row's entry time.
    pub x_entry_time: DesignMatrix,
    /// Time-basis design evaluated at each row's exit time.
    pub x_exit_time: DesignMatrix,
    /// Derivative of the exit-time design with respect to raw (not log) time,
    /// i.e. the log-time derivative multiplied by `1 / exit_time`.
    pub x_derivative_time: DesignMatrix,
    /// Penalty matrices for the time coefficients (empty for `none`/`linear`).
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions reported by the basis builder, presumably one per
    /// penalty.
    pub nullspace_dims: Vec<usize>,
    /// Basis name, e.g. `"none"`, `"linear"`, `"bspline"`.
    pub basisname: String,
    /// Spline degree (`None` for non-spline bases).
    pub degree: Option<usize>,
    /// Knot vector actually used (`None` for non-spline bases).
    pub knots: Option<Vec<f64>>,
    /// Retained columns (`None` for `none`, `linear` and `bspline`).
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing parameter recorded for spline bases.
    pub smooth_lambda: Option<f64>,
}
227
/// Lower clamp for survival times.
///
/// Times are raised to at least this value before taking logs or
/// reciprocals. It is also the minimum gap enforced between a normalized entry
/// time and its exit time.
pub const SURVIVAL_TIME_FLOOR: f64 = 1e-9;

/// Presumably the initial smoothing lambda for survival time bases (inferred
/// from the name).
// NOTE(review): not referenced in the visible part of this file — confirm its
// consumers.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1e-2;

/// Default Gompertz / Gompertz–Makeham shape used when no
/// `--baseline-shape` is supplied.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
247
/// Survival likelihood family selected with `--survival-likelihood`.
///
/// Parsed by `parse_survival_likelihood_mode`; the canonical CLI name of each
/// variant is given by `survival_likelihood_modename`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `transformation`
    Transformation,
    /// `weibull`
    Weibull,
    /// `location-scale`
    LocationScale,
    /// `marginal-slope`
    MarginalSlope,
    /// `latent`
    Latent,
    /// `latent-binary`
    LatentBinary,
}
257
/// Components of a time "wiggle" basis.
// NOTE(review): not constructed in the visible part of this file; the field
// descriptions below are inferred from names — confirm with the builder.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices for the wiggle coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions, presumably one per penalty.
    pub nullspace_dims: Vec<usize>,
    /// Knot vector of the wiggle basis.
    pub knots: Array1<f64>,
    /// Spline degree of the wiggle basis.
    pub degree: usize,
    /// Number of basis columns.
    pub ncols: usize,
}
265
266pub fn normalize_survival_time_pair(
271 entry_raw: f64,
272 exit_raw: f64,
273 row_index: usize,
274) -> Result<(f64, f64), String> {
275 if !entry_raw.is_finite() || !exit_raw.is_finite() {
276 return Err(SurvivalConstructionError::DataValidationFailed {
277 reason: format!("non-finite survival times at row {}", row_index + 1),
278 }
279 .into());
280 }
281 if entry_raw < 0.0 || exit_raw < 0.0 {
282 return Err(SurvivalConstructionError::DataValidationFailed {
283 reason: format!("negative survival times at row {}", row_index + 1),
284 }
285 .into());
286 }
287
288 let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
289 let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
290 Ok((entry, exit))
291}
292
/// Reports whether `basisname` (ASCII case-insensitive) names a time basis
/// whose monotonicity holds by construction. Only `ispline` qualifies.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURAL_BASIS: &str = "ispline";
    STRUCTURAL_BASIS.eq_ignore_ascii_case(basisname)
}
300
301pub fn require_structural_survival_time_basis(
302 basisname: &str,
303 context: &str,
304) -> Result<(), String> {
305 if survival_basis_supports_structural_monotonicity(basisname) {
306 return Ok(());
307 }
308 Err(SurvivalConstructionError::UnsupportedDistribution {
309 reason: format!(
310 "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
311Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
312`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
313Re-run with `--time-basis ispline`."
314 ),
315 }
316 .into())
317}
318
319pub fn parse_survival_baseline_config(
324 target_raw: &str,
325 scale: Option<f64>,
326 shape: Option<f64>,
327 rate: Option<f64>,
328 makeham: Option<f64>,
329) -> Result<SurvivalBaselineConfig, String> {
330 let target = match target_raw.to_ascii_lowercase().as_str() {
331 "linear" => SurvivalBaselineTarget::Linear,
332 "weibull" => SurvivalBaselineTarget::Weibull,
333 "gompertz" => SurvivalBaselineTarget::Gompertz,
334 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
335 other => {
336 return Err(SurvivalConstructionError::UnsupportedDistribution {
337 reason: format!(
338 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
339 ),
340 }
341 .into());
342 }
343 };
344
345 match target {
346 SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
347 target,
348 scale: None,
349 shape: None,
350 rate: None,
351 makeham: None,
352 }),
353 SurvivalBaselineTarget::Weibull => {
354 let scale = scale.ok_or_else(|| {
355 "--baseline-target weibull requires --baseline-scale > 0".to_string()
356 })?;
357 let shape = shape.ok_or_else(|| {
358 "--baseline-target weibull requires --baseline-shape > 0".to_string()
359 })?;
360 if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
361 return Err(
362 "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
363 .to_string(),
364 );
365 }
366 Ok(SurvivalBaselineConfig {
367 target,
368 scale: Some(scale),
369 shape: Some(shape),
370 rate: None,
371 makeham: None,
372 })
373 }
374 SurvivalBaselineTarget::Gompertz => {
375 let rate = rate.unwrap_or(1.0);
376 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
377 if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
378 return Err(
379 "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
380 .to_string(),
381 );
382 }
383 Ok(SurvivalBaselineConfig {
384 target,
385 scale: None,
386 shape: Some(shape),
387 rate: Some(rate),
388 makeham: None,
389 })
390 }
391 SurvivalBaselineTarget::GompertzMakeham => {
392 let rate = rate.unwrap_or(0.5);
393 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
394 let makeham = makeham.unwrap_or(0.5);
395 if !rate.is_finite()
396 || rate <= 0.0
397 || !shape.is_finite()
398 || !makeham.is_finite()
399 || makeham <= 0.0
400 {
401 return Err(
402 "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
403 .to_string(),
404 );
405 }
406 Ok(SurvivalBaselineConfig {
407 target,
408 scale: None,
409 shape: Some(shape),
410 rate: Some(rate),
411 makeham: Some(makeham),
412 })
413 }
414 }
415}
416
417pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
422 match raw.to_ascii_lowercase().as_str() {
423 "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
424 "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
425 "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
426 "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
427 "latent" => Ok(SurvivalLikelihoodMode::Latent),
428 "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
429 other => Err(SurvivalConstructionError::UnsupportedDistribution {
430 reason: format!(
431 "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
432 ),
433 }
434 .into()),
435 }
436}
437
438pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
439 match mode {
440 SurvivalLikelihoodMode::Transformation => "transformation",
441 SurvivalLikelihoodMode::Weibull => "weibull",
442 SurvivalLikelihoodMode::LocationScale => "location-scale",
443 SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
444 SurvivalLikelihoodMode::Latent => "latent",
445 SurvivalLikelihoodMode::LatentBinary => "latent-binary",
446 }
447}
448
449pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
450 match raw.to_ascii_lowercase().as_str() {
451 "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
452 "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
453 "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
454 other => Err(SurvivalConstructionError::UnsupportedDistribution {
455 reason: format!(
456 "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
457 ),
458 }
459 .into()),
460 }
461}
462
463pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
464 match target {
465 SurvivalBaselineTarget::Linear => "linear",
466 SurvivalBaselineTarget::Weibull => "weibull",
467 SurvivalBaselineTarget::Gompertz => "gompertz",
468 SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
469 }
470}
471
472pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
473 let sum = age_exit
474 .iter()
475 .copied()
476 .filter(|value| value.is_finite() && *value > 0.0)
477 .sum::<f64>();
478 let count = age_exit
479 .iter()
480 .filter(|value| value.is_finite() && **value > 0.0)
481 .count()
482 .max(1);
483 (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
484}
485
486pub fn initial_survival_baseline_config_for_fit(
487 target_raw: &str,
488 scale: Option<f64>,
489 shape: Option<f64>,
490 rate: Option<f64>,
491 makeham: Option<f64>,
492 age_exit: &Array1<f64>,
493) -> Result<SurvivalBaselineConfig, String> {
494 let target = match target_raw.trim().to_ascii_lowercase().as_str() {
495 "linear" => SurvivalBaselineTarget::Linear,
496 "weibull" => SurvivalBaselineTarget::Weibull,
497 "gompertz" => SurvivalBaselineTarget::Gompertz,
498 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
499 other => {
500 return Err(SurvivalConstructionError::UnsupportedDistribution {
501 reason: format!(
502 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
503 ),
504 }
505 .into());
506 }
507 };
508 let time_scale_seed = positive_survival_time_seed(age_exit);
509 let cfg = match target {
510 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
511 target,
512 scale: None,
513 shape: None,
514 rate: None,
515 makeham: None,
516 },
517 SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
518 target,
519 scale: Some(scale.unwrap_or(time_scale_seed)),
520 shape: Some(shape.unwrap_or(1.0)),
521 rate: None,
522 makeham: None,
523 },
524 SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
525 target,
526 scale: None,
527 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
528 rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
529 makeham: None,
530 },
531 SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
532 target,
533 scale: None,
534 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
535 rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
536 makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
537 },
538 };
539 parse_survival_baseline_config(
540 survival_baseline_targetname(cfg.target),
541 cfg.scale,
542 cfg.shape,
543 cfg.rate,
544 cfg.makeham,
545 )
546}
547
548fn survival_baseline_theta_from_config(
549 cfg: &SurvivalBaselineConfig,
550) -> Result<Option<Array1<f64>>, String> {
551 Ok(match cfg.target {
552 SurvivalBaselineTarget::Linear => None,
553 SurvivalBaselineTarget::Weibull => Some(array![
554 cfg.scale
555 .ok_or_else(|| "missing weibull baseline scale".to_string())?
556 .ln(),
557 cfg.shape
558 .ok_or_else(|| "missing weibull baseline shape".to_string())?
559 .ln(),
560 ]),
561 SurvivalBaselineTarget::Gompertz => Some(array![
562 cfg.rate
563 .ok_or_else(|| "missing gompertz baseline rate".to_string())?
564 .ln(),
565 cfg.shape
566 .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
567 ]),
568 SurvivalBaselineTarget::GompertzMakeham => Some(array![
569 cfg.rate
570 .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
571 .ln(),
572 cfg.shape
573 .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
574 cfg.makeham
575 .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
576 .ln(),
577 ]),
578 })
579}
580
581fn survival_baseline_config_from_theta(
582 target: SurvivalBaselineTarget,
583 theta: &Array1<f64>,
584) -> Result<SurvivalBaselineConfig, String> {
585 let cfg = match target {
586 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
587 target,
588 scale: None,
589 shape: None,
590 rate: None,
591 makeham: None,
592 },
593 SurvivalBaselineTarget::Weibull => {
594 if theta.len() != 2 {
595 return Err(SurvivalConstructionError::IncompatibleDimensions {
596 reason: format!(
597 "weibull baseline parameter dimension mismatch: expected 2, got {}",
598 theta.len()
599 ),
600 }
601 .into());
602 }
603 SurvivalBaselineConfig {
604 target,
605 scale: Some(theta[0].exp()),
606 shape: Some(theta[1].exp()),
607 rate: None,
608 makeham: None,
609 }
610 }
611 SurvivalBaselineTarget::Gompertz => {
612 if theta.len() != 2 {
613 return Err(SurvivalConstructionError::IncompatibleDimensions {
614 reason: format!(
615 "gompertz baseline parameter dimension mismatch: expected 2, got {}",
616 theta.len()
617 ),
618 }
619 .into());
620 }
621 SurvivalBaselineConfig {
622 target,
623 scale: None,
624 shape: Some(theta[1]),
625 rate: Some(theta[0].exp()),
626 makeham: None,
627 }
628 }
629 SurvivalBaselineTarget::GompertzMakeham => {
630 if theta.len() != 3 {
631 return Err(SurvivalConstructionError::IncompatibleDimensions {
632 reason: format!(
633 "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
634 theta.len()
635 ),
636 }
637 .into());
638 }
639 SurvivalBaselineConfig {
640 target,
641 scale: None,
642 shape: Some(theta[1]),
643 rate: Some(theta[0].exp()),
644 makeham: Some(theta[2].exp()),
645 }
646 }
647 };
648 parse_survival_baseline_config(
649 survival_baseline_targetname(cfg.target),
650 cfg.scale,
651 cfg.shape,
652 cfg.rate,
653 cfg.makeham,
654 )
655}
656
/// Which derivatives a baseline objective supplies to the outer optimizer.
/// Applied to an `OuterProblem` by `BaselineDerivativeContract::configure`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Analytic gradient only; the Hessian is declared unavailable.
    GradientOnly,
    /// Analytic gradient plus a Hessian (declared as `DeclaredHessianForm::Either`).
    GradientHessian,
}
679
680impl BaselineDerivativeContract {
681 fn configure(
686 self,
687 problem: gam_solve::rho_optimizer::OuterProblem,
688 ) -> gam_solve::rho_optimizer::OuterProblem {
689 use gam_problem::{DeclaredHessianForm, Derivative};
690 match self {
691 BaselineDerivativeContract::GradientOnly => problem
694 .with_gradient(Derivative::Analytic)
695 .with_hessian(DeclaredHessianForm::Unavailable)
696 .with_tolerance(1e-4)
697 .with_max_iter(240),
698 BaselineDerivativeContract::GradientHessian => problem
699 .with_gradient(Derivative::Analytic)
700 .with_hessian(DeclaredHessianForm::Either)
701 .with_tolerance(1e-4)
702 .with_max_iter(240),
703 }
704 }
705}
706
/// Runs the bounded outer optimizer over the baseline parameter vector
/// `theta` seeded from `initial`, and returns the optimized configuration.
///
/// `theta` is the parameterization produced by
/// `survival_baseline_theta_from_config`. A `Linear` baseline has no
/// parameters, so `initial` is returned unchanged and the optimizer is never
/// invoked.
///
/// * `context` — label passed to the optimizer and used as the prefix of
///   error messages.
/// * `contract` — which derivatives `eval_fn` provides (see
///   `BaselineDerivativeContract::configure`).
/// * `cost_fn` / `eval_fn` — cost-only and full (`OuterEval`) evaluations at
///   a given `theta`. The `&mut ()` state argument carries no data.
///
/// # Errors
/// Returns a message when:
/// * `initial` lacks a parameter its target requires;
/// * the optimizer run fails;
/// * the optimizer reports non-convergence;
/// * the optimized `theta` does not map to a valid configuration.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    // `None` means the target has no free parameters (Linear).
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // Box constraints: each theta coordinate may move at most 6 units from its
    // seed. For log-parameterized entries that is a factor of e^6 (~403)
    // either way.
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    // Start only from the seed (one seed, budget 1), i.e. no multi-start.
    // NOTE(review): `num_auxiliary_trailing: dim` presumably marks every
    // coordinate as an auxiliary (non-smoothing) parameter — confirm against
    // the `SeedConfig` documentation.
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    // The two `None` arguments opt out of optional callbacks; their roles are
    // defined by `OuterProblem::build_objective`.
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    // Non-convergence is a hard error; the last iterate is not accepted.
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    survival_baseline_config_from_theta(target, &result.rho)
}
777
/// Adapts a configuration-level `objective` into the theta-level closures
/// expected by `run_baseline_theta_optimizer`.
///
/// At each trial `theta`:
/// 1. `theta` is decoded into a `SurvivalBaselineConfig`, which also
///    re-validates it.
/// 2. `objective` is evaluated on that config.
/// 3. The gradient length and the size of any analytic Hessian are checked
///    against `theta.len()`.
///
/// # Errors
/// Failures in any step become `EstimationError::InvalidInput` inside the
/// optimizer. They are presumably propagated by `problem.run` and reported
/// via `run_baseline_theta_optimizer`'s `"{context} failed: ..."` message.
fn run_baseline_theta_optimizer_with_eval<F>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    objective: F,
) -> Result<SurvivalBaselineConfig, String>
where
    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
{
    let target = initial.target;
    let engine_context = context.to_string();
    // The cost-only and full-evaluation closures both need mutable access to
    // the same `FnMut`, so it is shared through `Rc<RefCell<_>>`.
    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
                        theta: &Array1<f64>|
          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
        let cfg = survival_baseline_config_from_theta(target, theta)
            .map_err(crate::model_types::EstimationError::InvalidInput)?;
        let eval =
            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
        if eval.gradient.len() != theta.len() {
            return Err(crate::model_types::EstimationError::InvalidInput(format!(
                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
                eval.gradient.len(),
                theta.len()
            )));
        }
        // Only analytic Hessians carry a matrix to check; other variants pass through.
        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
            if h.nrows() != theta.len() || h.ncols() != theta.len() {
                return Err(crate::model_types::EstimationError::InvalidInput(format!(
                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
                    h.nrows(),
                    h.ncols(),
                    theta.len(),
                    theta.len()
                )));
            }
        }
        Ok(eval)
    };
    // The cost closure runs the full evaluation and keeps only `cost`.
    let cost_objective = std::rc::Rc::clone(&objective);
    let cost_eval = eval_at.clone();
    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
    };
    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
}
837
838pub fn optimize_survival_baseline_config_with_gradient_only<F>(
849 initial: &SurvivalBaselineConfig,
850 context: &str,
851 mut objective: F,
852) -> Result<SurvivalBaselineConfig, String>
853where
854 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
855{
856 use gam_problem::{HessianResult, OuterEval};
857 run_baseline_theta_optimizer_with_eval(
858 initial,
859 context,
860 BaselineDerivativeContract::GradientOnly,
861 move |cfg| {
862 let (cost, gradient) = objective(cfg)?;
863 Ok(OuterEval {
864 cost,
865 gradient,
866 hessian: HessianResult::Unavailable,
867 inner_beta_hint: None,
868 })
869 },
870 )
871}
872
873pub fn optimize_survival_baseline_config_with_gradient<F>(
878 initial: &SurvivalBaselineConfig,
879 context: &str,
880 mut objective: F,
881) -> Result<SurvivalBaselineConfig, String>
882where
883 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
884{
885 use gam_problem::{HessianResult, OuterEval};
886 run_baseline_theta_optimizer_with_eval(
887 initial,
888 context,
889 BaselineDerivativeContract::GradientHessian,
890 move |cfg| {
891 let (cost, gradient, hessian) = objective(cfg)?;
892 Ok(OuterEval {
893 cost,
894 gradient,
895 hessian: HessianResult::Analytic(hessian),
896 inner_beta_hint: None,
897 })
898 },
899 )
900}
901
902pub fn parse_survival_time_basis_config(
907 time_basis: &str,
908 time_degree: usize,
909 time_num_internal_knots: usize,
910 time_smooth_lambda: f64,
911) -> Result<SurvivalTimeBasisConfig, String> {
912 match time_basis.to_ascii_lowercase().as_str() {
913 "none" => Ok(SurvivalTimeBasisConfig::None),
914 "ispline" => {
915 if time_degree < 1 {
916 return Err(
917 "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
918 .to_string(),
919 );
920 }
921 if time_num_internal_knots == 0 {
922 return Err(
923 "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
924 .to_string(),
925 );
926 }
927 if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
928 return Err(
929 "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
930 .to_string(),
931 );
932 }
933 Ok(SurvivalTimeBasisConfig::ISpline {
934 degree: time_degree,
935 knots: Array1::zeros(0),
936 keep_cols: Vec::new(),
937 smooth_lambda: time_smooth_lambda,
938 })
939 }
940 "linear" | "bspline" => {
941 match require_structural_survival_time_basis(time_basis, "survival model configuration")
948 {
949 Err(e) => Err(e),
950 Ok(()) => Err(format!(
951 "internal: structural-basis check accepted non-structural \
952 survival time basis '{time_basis}'"
953 )),
954 }
955 }
956 other => Err(format!(
957 "unsupported --time-basis '{other}'; accepted values: ispline, none"
958 )),
959 }
960}
961
962pub fn build_survival_time_basis(
967 age_entry: &Array1<f64>,
968 age_exit: &Array1<f64>,
969 cfg: SurvivalTimeBasisConfig,
970 infer_knots_if_needed: Option<(usize, f64)>,
971) -> Result<SurvivalTimeBuildOutput, String> {
972 fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
973 if let Some(row) = times.iter().position(|t| !t.is_finite()) {
974 return Err(SurvivalConstructionError::DataValidationFailed {
975 reason: format!(
976 "survival time basis requires finite {label} times (row {})",
977 row + 1
978 ),
979 }
980 .into());
981 }
982 if let Some(row) = times.iter().position(|t| *t < 0.0) {
983 return Err(SurvivalConstructionError::DataValidationFailed {
984 reason: format!(
985 "survival time basis requires non-negative {label} times (row {})",
986 row + 1
987 ),
988 }
989 .into());
990 }
991 Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
992 }
993
994 let n = age_entry.len();
995 if n != age_exit.len() {
996 return Err(SurvivalConstructionError::IncompatibleDimensions {
997 reason: "survival time basis requires matching entry/exit lengths".to_string(),
998 }
999 .into());
1000 }
1001 for i in 0..n {
1002 if age_exit[i] < age_entry[i] {
1003 return Err(format!(
1004 "survival time basis requires exit times >= entry times (row {})",
1005 i + 1
1006 ));
1007 }
1008 }
1009 let log_entry = checked_log_survival_times(age_entry, "entry")?;
1010 let log_exit = checked_log_survival_times(age_exit, "exit")?;
1011
1012 fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1013 let n = log_entry.len();
1014 let entry_range = log_entry
1015 .iter()
1016 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1017 (lo.min(v), hi.max(v))
1018 });
1019 let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1020 if entry_degenerate {
1021 log_exit.clone()
1022 } else {
1023 let mut combined = Array1::<f64>::zeros(2 * n);
1024 for i in 0..n {
1025 combined[i] = log_entry[i];
1026 combined[n + i] = log_exit[i];
1027 }
1028 combined
1029 }
1030 }
1031
1032 fn data_capped_internal_knots(
1055 combined: &Array1<f64>,
1056 degree: usize,
1057 requested_internal_knots: usize,
1058 ) -> usize {
1059 if requested_internal_knots == 0 {
1060 return 0;
1061 }
1062 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1063 sorted.sort_by(f64::total_cmp);
1064 let minval = sorted.first().copied().unwrap_or(0.0);
1065 let maxval = sorted.last().copied().unwrap_or(minval);
1066 if minval == maxval {
1067 return 1.min(requested_internal_knots);
1069 }
1070 let scale = (maxval - minval).abs().max(1.0);
1071 let tol = 1e-12 * scale;
1072 let mut distinct_interior = 0usize;
1075 let mut last: Option<f64> = None;
1076 for &x in &sorted {
1077 if x <= minval + tol || x >= maxval - tol {
1078 continue;
1079 }
1080 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1081 continue;
1082 }
1083 distinct_interior += 1;
1084 last = Some(x);
1085 }
1086 let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1089 let n_distinct = {
1095 let mut count = 0usize;
1096 let mut last: Option<f64> = None;
1097 for &x in &sorted {
1098 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1099 continue;
1100 }
1101 count += 1;
1102 last = Some(x);
1103 }
1104 count
1105 };
1106 let dim_budget = n_distinct / 4;
1107 let dim_cap = dim_budget.saturating_sub(degree);
1108 cap = cap.min(dim_cap.max(1));
1109 cap.max(1)
1110 }
1111
1112 fn infer_survival_time_knots(
1113 combined: &Array1<f64>,
1114 knot_degree: usize,
1115 validation_degree: usize,
1116 num_internal_knots: usize,
1117 basis_options: BasisOptions,
1118 ) -> Result<Array1<f64>, String> {
1119 let num_internal_knots =
1125 data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1126
1127 fn quantile_knot_inference_needs_uniform_fallback(
1128 combined: &Array1<f64>,
1129 num_internal_knots: usize,
1130 ) -> bool {
1131 if num_internal_knots == 0 || combined.is_empty() {
1132 return false;
1133 }
1134
1135 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1136 sorted.sort_by(f64::total_cmp);
1137 let minval = sorted[0];
1138 let maxval = *sorted.last().unwrap_or(&minval);
1139 if minval == maxval {
1140 return false;
1141 }
1142
1143 let scale = (maxval - minval).abs().max(1.0);
1144 let tol = 1e-12 * scale;
1145 let mut support = Vec::with_capacity(sorted.len());
1146 let mut last: Option<f64> = None;
1147 for &x in &sorted {
1148 if x <= minval + tol || x >= maxval - tol {
1149 continue;
1150 }
1151 if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1152 continue;
1153 }
1154 support.push(x);
1155 last = Some(x);
1156 }
1157 if support.is_empty() {
1158 return true;
1159 }
1160
1161 let n = support.len();
1162 let mut prev_q = minval;
1163 for j in 1..=num_internal_knots {
1164 let p = j as f64 / (num_internal_knots + 1) as f64;
1165 let pos = p * (n.saturating_sub(1) as f64);
1166 let lo = pos.floor() as usize;
1167 let hi = pos.ceil() as usize;
1168 let frac = pos - lo as f64;
1169 let q = if lo == hi {
1170 support[lo]
1171 } else {
1172 support[lo] * (1.0 - frac) + support[hi] * frac
1173 }
1174 .clamp(minval, maxval);
1175 if q <= prev_q + tol || q >= maxval - tol {
1176 return true;
1177 }
1178 prev_q = q;
1179 }
1180
1181 false
1182 }
1183
1184 let inferwith =
1185 |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1186 let built = build_bspline_basis_1d(
1187 combined.view(),
1188 &BSplineBasisSpec {
1189 degree: knot_degree,
1190 penalty_order: 2,
1191 knotspec: BSplineKnotSpec::Automatic {
1192 num_internal_knots: Some(num_internal_knots),
1193 placement,
1194 },
1195 double_penalty: false,
1196 identifiability: BSplineIdentifiability::None,
1197 boundary: OneDimensionalBoundary::Open,
1198 boundary_conditions: BSplineBoundaryConditions::default(),
1199 },
1200 )
1201 .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1202 let knots = match built.metadata {
1203 BasisMetadata::BSpline1D { knots, .. } => knots,
1204 _ => {
1205 return Err(
1206 "internal error: expected BSpline1D metadata for survival time basis"
1207 .to_string(),
1208 );
1209 }
1210 };
1211 create_basis::<Dense>(
1220 combined.view(),
1221 KnotSource::Provided(knots.view()),
1222 validation_degree,
1223 basis_options,
1224 )
1225 .map_err(|e| e.to_string())?;
1226 Ok(knots)
1227 };
1228
1229 if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1230 inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1231 } else {
1232 inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1233 }
1234 }
1235
1236 match cfg {
1237 SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1238 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1239 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1240 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1241 penalties: Vec::new(),
1242 nullspace_dims: Vec::new(),
1243 basisname: "none".to_string(),
1244 degree: None,
1245 knots: None,
1246 keep_cols: None,
1247 smooth_lambda: None,
1248 }),
1249 SurvivalTimeBasisConfig::Linear => {
1250 let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1251 let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1252 let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1253 for i in 0..n {
1254 x_entry_time[[i, 0]] = 1.0;
1255 x_exit_time[[i, 0]] = 1.0;
1256 x_entry_time[[i, 1]] = log_entry[i];
1257 x_exit_time[[i, 1]] = log_exit[i];
1258 x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1259 }
1260 Ok(SurvivalTimeBuildOutput {
1261 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1262 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1263 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1264 penalties: Vec::new(),
1265 nullspace_dims: Vec::new(),
1266 basisname: "linear".to_string(),
1267 degree: None,
1268 knots: None,
1269 keep_cols: None,
1270 smooth_lambda: None,
1271 })
1272 }
1273 SurvivalTimeBasisConfig::BSpline {
1274 degree,
1275 knots,
1276 smooth_lambda,
1277 } => {
1278 let knotvec = if knots.is_empty() {
1279 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1280 "internal error: bspline time basis requested without knot source".to_string()
1281 })?;
1282 let combined = survival_time_knot_input(&log_entry, &log_exit);
1283 infer_survival_time_knots(
1284 &combined,
1285 degree,
1286 degree,
1287 num_internal_knots,
1288 BasisOptions::value(),
1289 )?
1290 } else {
1291 knots
1292 };
1293
1294 let entry_basis = build_bspline_basis_1d(
1295 log_entry.view(),
1296 &BSplineBasisSpec {
1297 degree,
1298 penalty_order: 2,
1299 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1300 double_penalty: false,
1301 identifiability: BSplineIdentifiability::None,
1302 boundary: OneDimensionalBoundary::Open,
1303 boundary_conditions: BSplineBoundaryConditions::default(),
1304 },
1305 )
1306 .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1307 let exit_basis = build_bspline_basis_1d(
1308 log_exit.view(),
1309 &BSplineBasisSpec {
1310 degree,
1311 penalty_order: 2,
1312 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1313 double_penalty: false,
1314 identifiability: BSplineIdentifiability::None,
1315 boundary: OneDimensionalBoundary::Open,
1316 boundary_conditions: BSplineBoundaryConditions::default(),
1317 },
1318 )
1319 .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1320
1321 let p_time = exit_basis.design.ncols();
1322 let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1326 let mut deriv_buf = vec![0.0_f64; p_time];
1327 for i in 0..n {
1328 deriv_buf.fill(0.0);
1329 evaluate_bspline_derivative_scalar(
1330 log_exit[i],
1331 knotvec.view(),
1332 degree,
1333 &mut deriv_buf,
1334 )
1335 .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1336 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1337 for j in 0..p_time {
1338 let v = deriv_buf[j] * chain;
1339 if v.abs() > 1e-15 {
1340 deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1341 }
1342 }
1343 }
1344 let x_derivative_time =
1345 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1346 {
1347 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1348 Err(_) => {
1349 let mut dense = Array2::<f64>::zeros((n, p_time));
1351 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1352 dense[[row, col]] = val;
1353 }
1354 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1355 }
1356 };
1357
1358 Ok(SurvivalTimeBuildOutput {
1359 x_entry_time: entry_basis.design,
1360 x_exit_time: exit_basis.design,
1361 x_derivative_time,
1362 nullspace_dims: entry_basis.nullspace_dims,
1363 penalties: entry_basis.penalties,
1364 basisname: "bspline".to_string(),
1365 degree: Some(degree),
1366 knots: Some(knotvec.to_vec()),
1367 keep_cols: None,
1368 smooth_lambda: Some(smooth_lambda),
1369 })
1370 }
1371 SurvivalTimeBasisConfig::ISpline {
1372 degree,
1373 knots,
1374 keep_cols,
1375 smooth_lambda,
1376 } => {
1377 let bspline_degree = degree
1378 .checked_add(1)
1379 .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1380 let knotvec = if knots.is_empty() {
1381 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1382 "internal error: ispline time basis requested without knot source".to_string()
1383 })?;
1384 let combined = survival_time_knot_input(&log_entry, &log_exit);
1385 infer_survival_time_knots(
1386 &combined,
1387 bspline_degree,
1388 degree,
1389 num_internal_knots,
1390 BasisOptions::i_spline(),
1391 )?
1392 } else {
1393 knots
1394 };
1395
1396 let (db_exit_arc, _) = create_basis::<Dense>(
1397 log_exit.view(),
1398 KnotSource::Provided(knotvec.view()),
1399 bspline_degree,
1400 BasisOptions::first_derivative(),
1401 )
1402 .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1403
1404 let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1407 let (entry_arc, _) = create_basis::<Dense>(
1408 log_entry.view(),
1409 KnotSource::Provided(knotvec.view()),
1410 degree,
1411 BasisOptions::i_spline(),
1412 )
1413 .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1414 let (exit_arc, _) = create_basis::<Dense>(
1415 log_exit.view(),
1416 KnotSource::Provided(knotvec.view()),
1417 degree,
1418 BasisOptions::i_spline(),
1419 )
1420 .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1421
1422 let x_entry_full = entry_arc.as_ref();
1423 let x_exit_full = exit_arc.as_ref();
1424 let p_time_full = x_exit_full.ncols();
1425 if p_time_full == 0 {
1426 return Err(SurvivalConstructionError::BasisConstructionFailed {
1427 reason: "internal error: empty ispline time basis".to_string(),
1428 }
1429 .into());
1430 }
1431 let db_exit = db_exit_arc.as_ref();
1432 if db_exit.ncols() != p_time_full + 1 {
1433 return Err(
1434 "internal error: ispline derivative basis width must exceed basis width by one"
1435 .to_string(),
1436 );
1437 }
1438
1439 let keep_cols = if keep_cols.is_empty() {
1440 let constant_tol = 1e-12_f64;
1441 let mut inferred_keep_cols: Vec<usize> = Vec::new();
1442 for j in 0..p_time_full {
1443 let mut minv = f64::INFINITY;
1444 let mut maxv = f64::NEG_INFINITY;
1445 for i in 0..n {
1446 let ve = x_exit_full[[i, j]];
1447 let vs = x_entry_full[[i, j]];
1448 minv = minv.min(ve.min(vs));
1449 maxv = maxv.max(ve.max(vs));
1450 }
1451 if (maxv - minv) > constant_tol {
1452 inferred_keep_cols.push(j);
1453 }
1454 }
1455 inferred_keep_cols
1456 } else {
1457 keep_cols
1458 };
1459 if keep_cols.is_empty() {
1460 return Err(
1461 "internal error: ispline basis has no shape-varying time columns"
1462 .to_string(),
1463 );
1464 }
1465 if keep_cols.iter().any(|&j| j >= p_time_full) {
1466 return Err(SurvivalConstructionError::MissingColumn {
1467 reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1468 }
1469 .into());
1470 }
1471
1472 let p_time = keep_cols.len();
1473 let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1474 let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1475 (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1478 };
1479 let db_exit = db_exit_arc.as_ref();
1480
1481 let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1486 let mut found_nonfinite: Option<(usize, usize)> = None;
1487 for i in 0..n {
1488 let mut running = 0.0_f64;
1489 let mut d_i_log_full = vec![0.0_f64; p_time_full];
1490 for j in (1..db_exit.ncols()).rev() {
1491 let term = db_exit[[i, j]];
1492 if term.is_finite() {
1493 running += term;
1494 }
1495 d_i_log_full[j - 1] = running;
1496 }
1497 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1498 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1499 let raw_v = d_i_log_full[j_old] * chain;
1500 let v = if (-1e-12..0.0).contains(&raw_v) {
1501 0.0
1502 } else {
1503 raw_v
1504 };
1505 if !v.is_finite() {
1506 found_nonfinite = Some((i, j_new));
1507 }
1508 if v < -1e-12 {
1509 return Err(format!(
1510 "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1511 i + 1,
1512 j_new + 1,
1513 v
1514 ));
1515 }
1516 if v.abs() > 1e-15 {
1517 deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1518 }
1519 }
1520 }
1521 if let Some((row, col)) = found_nonfinite {
1522 return Err(format!(
1523 "survival ispline derivative basis produced non-finite value at row {}, column {}",
1524 row + 1,
1525 col + 1
1526 ));
1527 }
1528 let x_derivative_time =
1529 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1530 {
1531 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1532 Err(_) => {
1533 let mut dense = Array2::<f64>::zeros((n, p_time));
1534 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1535 dense[[row, col]] = val;
1536 }
1537 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1538 }
1539 };
1540
1541 let penalty_basis = build_bspline_basis_1d(
1542 log_exit.view(),
1543 &BSplineBasisSpec {
1544 degree: bspline_degree,
1545 penalty_order: 2,
1546 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1547 double_penalty: false,
1548 identifiability: BSplineIdentifiability::None,
1549 boundary: OneDimensionalBoundary::Open,
1550 boundary_conditions: BSplineBoundaryConditions::default(),
1551 },
1552 )
1553 .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1554 if penalty_basis.design.ncols() != p_time_full + 1 {
1555 return Err("internal error: ispline penalty dimension mismatch".to_string());
1556 }
1557 let mut penalties = Vec::<Array2<f64>>::new();
1591 for s_mat in &penalty_basis.penalties {
1592 if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1593 continue;
1594 }
1595 let s_increment = s_mat.slice(s![1.., 1..]);
1624 if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1625 return Err(format!(
1626 "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1627 s_increment.nrows(),
1628 s_increment.ncols(),
1629 ));
1630 }
1631 let mut s_full = s_increment.to_owned();
1636 symmetrize_in_place(&mut s_full);
1637 let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1641 for i in 0..p_time_full {
1642 for j in 0..p_time_full {
1643 let mut v = 0.0;
1644 for k in j..p_time_full {
1645 v += s_full[[i, k]];
1646 }
1647 s_mid_full[[i, j]] = v;
1648 }
1649 }
1650 let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1654 for i in 0..p_time_full {
1655 for j in 0..p_time_full {
1656 let mut v = 0.0;
1657 for k in i..p_time_full {
1658 v += s_mid_full[[k, j]];
1659 }
1660 s_full_congruent[[i, j]] = v;
1661 }
1662 }
1663 let mut local = Array2::<f64>::zeros((p_time, p_time));
1665 for (i_new, &i_old) in keep_cols.iter().enumerate() {
1666 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1667 local[[i_new, j_new]] = 0.5
1670 * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1671 }
1672 }
1673 penalties.push(local);
1674 }
1675
1676 for (idx, s_mat) in penalties.iter().enumerate() {
1686 let p = s_mat.nrows();
1687 if p == 0 {
1688 continue;
1689 }
1690 if let Ok((evals, _)) =
1691 gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1692 {
1693 let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1694 "internal error: ispline penalty eigenvalues not contiguous".to_string()
1695 })?;
1696 let max_ev = evals_slice
1697 .iter()
1698 .copied()
1699 .fold(0.0_f64, |a, b| a.max(b.abs()))
1700 .max(1.0);
1701 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1702 let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1703 if min_ev < neg_tol {
1704 return Err(format!(
1705 "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1706 indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1707 {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1708 ));
1709 }
1710 }
1711 }
1712
1713 let nullspace_dims: Vec<usize> = penalties
1717 .iter()
1718 .map(|s_mat| {
1719 let p = s_mat.nrows();
1720 if p == 0 {
1721 return 0;
1722 }
1723 match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1724 Ok((evals, _)) => {
1725 let evals_slice: &[f64] = evals.as_slice().unwrap();
1726 let max_ev = evals_slice
1727 .iter()
1728 .copied()
1729 .fold(0.0_f64, |a, b| a.max(b.abs()))
1730 .max(1.0);
1731 let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1732 evals_slice.iter().filter(|&&e| e <= threshold).count()
1733 }
1734 Err(_) => 0,
1735 }
1736 })
1737 .collect();
1738 Ok(SurvivalTimeBuildOutput {
1739 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1740 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1741 x_derivative_time,
1742 penalties,
1743 nullspace_dims,
1744 basisname: "ispline".to_string(),
1745 degree: Some(degree),
1746 knots: Some(knotvec.to_vec()),
1747 keep_cols: Some(keep_cols),
1748 smooth_lambda: Some(smooth_lambda),
1749 })
1750 }
1751 }
1752}
1753
1754pub fn resolved_survival_time_basis_config_from_build(
1755 basisname: &str,
1756 degree: Option<usize>,
1757 knots: Option<&Vec<f64>>,
1758 keep_cols: Option<&Vec<usize>>,
1759 smooth_lambda: Option<f64>,
1760) -> Result<SurvivalTimeBasisConfig, String> {
1761 match basisname {
1762 "none" => Ok(SurvivalTimeBasisConfig::None),
1763 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1764 "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1765 degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1766 knots: Array1::from_vec(
1767 knots
1768 .cloned()
1769 .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1770 ),
1771 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1772 }),
1773 "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1774 degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1775 knots: Array1::from_vec(
1776 knots
1777 .cloned()
1778 .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1779 ),
1780 keep_cols: keep_cols
1781 .cloned()
1782 .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1783 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1784 }),
1785 other => Err(format!("unsupported survival time basis '{other}'")),
1786 }
1787}
1788
1789pub fn resolve_survival_time_anchor_value(
1790 age_entry: &Array1<f64>,
1791 time_anchor: Option<f64>,
1792) -> Result<f64, String> {
1793 if age_entry.is_empty() {
1794 return Err("survival time anchor requires non-empty entry times".to_string());
1795 }
1796 let anchor = match time_anchor {
1797 Some(t_anchor) => {
1798 if !t_anchor.is_finite() || t_anchor < 0.0 {
1799 return Err(format!(
1800 "survival time anchor must be finite and non-negative, got {t_anchor}"
1801 ));
1802 }
1803 t_anchor
1804 }
1805 None => age_entry
1806 .iter()
1807 .copied()
1808 .min_by(f64::total_cmp)
1809 .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1810 };
1811 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1812}
1813
1814pub fn resolve_survival_marginal_slope_time_anchor_value(
1846 age_entry: &Array1<f64>,
1847 age_exit: &Array1<f64>,
1848 time_anchor: Option<f64>,
1849) -> Result<f64, String> {
1850 if age_entry.is_empty() || age_exit.is_empty() {
1851 return Err(
1852 "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1853 );
1854 }
1855 let anchor = match time_anchor {
1856 Some(t_anchor) => {
1857 if !t_anchor.is_finite() || t_anchor < 0.0 {
1858 return Err(format!(
1859 "survival time anchor must be finite and non-negative, got {t_anchor}"
1860 ));
1861 }
1862 t_anchor
1863 }
1864 None => {
1865 let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1866 sorted.sort_by(f64::total_cmp);
1867 let m = sorted.len();
1868 if m % 2 == 1 {
1869 sorted[m / 2]
1870 } else {
1871 0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1872 }
1873 }
1874 };
1875 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1876}
1877
1878pub fn evaluate_survival_time_basis_row(
1879 age: f64,
1880 cfg: &SurvivalTimeBasisConfig,
1881) -> Result<Array1<f64>, String> {
1882 if !age.is_finite() || age < 0.0 {
1883 return Err(format!(
1884 "survival time basis row requires finite non-negative age, got {age}"
1885 ));
1886 }
1887 let age = age.max(SURVIVAL_TIME_FLOOR);
1888 let log_age = array![age.ln()];
1889 match cfg {
1890 SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1891 SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1892 SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1893 if knots.is_empty() {
1894 return Err(
1895 "survival BSpline anchor evaluation requires resolved knot metadata"
1896 .to_string(),
1897 );
1898 }
1899 let built = build_bspline_basis_1d(
1900 log_age.view(),
1901 &BSplineBasisSpec {
1902 degree: *degree,
1903 penalty_order: 2,
1904 knotspec: BSplineKnotSpec::Provided(knots.clone()),
1905 double_penalty: false,
1906 identifiability: BSplineIdentifiability::None,
1907 boundary: OneDimensionalBoundary::Open,
1908 boundary_conditions: BSplineBoundaryConditions::default(),
1909 },
1910 )
1911 .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1912 Ok(built.design.to_dense().row(0).to_owned())
1913 }
1914 SurvivalTimeBasisConfig::ISpline {
1915 degree,
1916 knots,
1917 keep_cols,
1918 ..
1919 } => {
1920 if knots.is_empty() {
1921 return Err(
1922 "survival ISpline anchor evaluation requires resolved knot metadata"
1923 .to_string(),
1924 );
1925 }
1926 let (basis_arc, _) = create_basis::<Dense>(
1927 log_age.view(),
1928 KnotSource::Provided(knots.view()),
1929 *degree,
1930 BasisOptions::i_spline(),
1931 )
1932 .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1933 let basis = basis_arc.as_ref();
1934 let row = basis.row(0);
1935 if keep_cols.is_empty() {
1936 return Ok(row.to_owned());
1937 }
1938 if keep_cols.iter().any(|&j| j >= row.len()) {
1939 return Err(SurvivalConstructionError::MissingColumn {
1940 reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
1941 }
1942 .into());
1943 }
1944 Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
1945 }
1946 }
1947}
1948
1949pub fn center_survival_time_designs_at_anchor(
1950 design_entry: &mut DesignMatrix,
1951 design_exit: &mut DesignMatrix,
1952 anchor_row: &Array1<f64>,
1953) -> Result<(), String> {
1954 if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
1955 return Err(format!(
1956 "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
1957 design_entry.ncols(),
1958 design_exit.ncols(),
1959 anchor_row.len()
1960 ));
1961 }
1962 fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
1965 let mut dense = dm.to_dense();
1966 for mut row in dense.rows_mut() {
1967 row -= &anchor.view();
1968 }
1969 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
1970 }
1971 center_dense(design_entry, anchor_row);
1972 center_dense(design_exit, anchor_row);
1973 Ok(())
1974}
1975
1976pub fn baseline_offset_theta_partials(
2006 age: f64,
2007 cfg: &SurvivalBaselineConfig,
2008) -> Result<Option<Vec<(f64, f64)>>, String> {
2009 let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2010 else {
2011 return Ok(None);
2012 };
2013
2014 match params {
2015 ValidatedBaselineTarget::Weibull { scale, shape } => {
2016 let eta = shape * (age.ln() - scale.ln());
2025 let o_d = shape / age;
2026 let d_eta_d_log_scale = -shape;
2027 let d_od_d_log_scale = 0.0;
2028 let d_eta_d_log_shape = eta;
2029 let d_od_d_log_shape = o_d;
2030 Ok(Some(vec![
2031 (d_eta_d_log_scale, d_od_d_log_scale),
2032 (d_eta_d_log_shape, d_od_d_log_shape),
2033 ]))
2034 }
2035 ValidatedBaselineTarget::Gompertz { shape, .. } => {
2036 let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2046 Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2047 }
2048 ValidatedBaselineTarget::GompertzMakeham {
2049 rate,
2050 shape,
2051 makeham,
2052 } => {
2053 let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2068 let cum_total = makeham * age + cum_g;
2069 if cum_total <= 0.0 || !cum_total.is_finite() {
2070 return Err(SurvivalConstructionError::DataValidationFailed {
2071 reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2072 }
2073 .into());
2074 }
2075 let inst_total = makeham + inst_g;
2076 let o_d = inst_total / cum_total;
2077 let inv_cum = 1.0 / cum_total;
2078 let d_cum_dlr = cum_g;
2083 let d_inst_dlr = inst_g;
2084 let d_eta_dlr = d_cum_dlr * inv_cum;
2085 let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2086 let (d_cum_dshape, d_inst_dshape) =
2088 gompertz_cumulative_shape_derivative(age, rate, shape);
2089 let d_eta_dshape = d_cum_dshape * inv_cum;
2090 let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2091 let d_cum_dlm = makeham * age;
2094 let d_inst_dlm = makeham;
2095 let d_eta_dlm = d_cum_dlm * inv_cum;
2096 let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2097 Ok(Some(vec![
2098 (d_eta_dlr, d_od_dlr),
2099 (d_eta_dshape, d_od_dshape),
2100 (d_eta_dlm, d_od_dlm),
2101 ]))
2102 }
2103 }
2104}
2105
2106fn baseline_chain_rule_gradient_with_partials<F>(
2134 label: &'static str,
2135 age_entry: ndarray::ArrayView1<'_, f64>,
2136 age_exit: ndarray::ArrayView1<'_, f64>,
2137 age_right: ndarray::ArrayView1<'_, f64>,
2138 cfg: &SurvivalBaselineConfig,
2139 residuals: &crate::survival::OffsetChannelResiduals,
2140 partials: F,
2141) -> Result<Option<Array1<f64>>, String>
2142where
2143 F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2144{
2145 let n = age_exit.len();
2146 if age_entry.len() != n
2147 || age_right.len() != n
2148 || residuals.exit.len() != n
2149 || residuals.entry.len() != n
2150 || residuals.derivative.len() != n
2151 || residuals.right.len() != n
2152 {
2153 return Err(format!(
2154 "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2155 age_entry.len(),
2156 n,
2157 age_right.len(),
2158 residuals.exit.len(),
2159 residuals.entry.len(),
2160 residuals.derivative.len(),
2161 residuals.right.len(),
2162 ));
2163 }
2164 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2167 let theta_dim = match probe_age {
2168 Some(t) => match partials(t, cfg)? {
2169 None => return Ok(None),
2170 Some(v) => v.len(),
2171 },
2172 None => {
2173 return Err(format!("{label}: no valid positive age for dim probe"));
2174 }
2175 };
2176 let mut grad = Array1::<f64>::zeros(theta_dim);
2187 for i in 0..n {
2188 let partials_exit = partials(age_exit[i], cfg)?
2190 .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2191 if partials_exit.len() != theta_dim {
2192 return Err(format!(
2193 "{label}: theta_dim drifted ({} != {})",
2194 partials_exit.len(),
2195 theta_dim
2196 ));
2197 }
2198 let r_x = residuals.exit[i];
2199 let r_d = residuals.derivative[i];
2200 for k in 0..theta_dim {
2201 let (d_eta_dk, d_od_dk) = partials_exit[k];
2202 grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2203 }
2204 let r_e = residuals.entry[i];
2208 if r_e != 0.0 {
2209 let partials_entry = partials(age_entry[i], cfg)?
2210 .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2211 for k in 0..theta_dim {
2212 grad[k] += r_e * partials_entry[k].0;
2213 }
2214 }
2215 let r_r = residuals.right[i];
2224 if r_r != 0.0 {
2225 let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2226 format!("{label}: unexpected None from partials at right boundary")
2227 })?;
2228 if partials_right.len() != theta_dim {
2229 return Err(format!(
2230 "{label}: theta_dim drifted at right boundary ({} != {})",
2231 partials_right.len(),
2232 theta_dim
2233 ));
2234 }
2235 for k in 0..theta_dim {
2236 grad[k] += r_r * partials_right[k].0;
2237 }
2238 }
2239 }
2240 Ok(Some(grad))
2241}
2242
2243pub fn baseline_chain_rule_gradient(
2277 age_entry: ndarray::ArrayView1<'_, f64>,
2278 age_exit: ndarray::ArrayView1<'_, f64>,
2279 age_right: ndarray::ArrayView1<'_, f64>,
2280 cfg: &SurvivalBaselineConfig,
2281 residuals: &crate::survival::OffsetChannelResiduals,
2282) -> Result<Option<Array1<f64>>, String> {
2283 baseline_chain_rule_gradient_with_partials(
2284 "baseline_chain_rule_gradient",
2285 age_entry,
2286 age_exit,
2287 age_right,
2288 cfg,
2289 residuals,
2290 baseline_offset_theta_partials,
2291 )
2292}
2293
2294pub fn marginal_slope_baseline_chain_rule_gradient(
2301 age_entry: ndarray::ArrayView1<'_, f64>,
2302 age_exit: ndarray::ArrayView1<'_, f64>,
2303 cfg: &SurvivalBaselineConfig,
2304 residuals: &crate::survival::OffsetChannelResiduals,
2305) -> Result<Option<Array1<f64>>, String> {
2306 baseline_chain_rule_gradient_with_partials(
2310 "marginal_slope_baseline_chain_rule_gradient",
2311 age_entry,
2312 age_exit,
2313 age_exit,
2314 cfg,
2315 residuals,
2316 marginal_slope_baseline_offset_theta_partials,
2317 )
2318}
2319
/// Gompertz cumulative and instantaneous hazard at `age`.
///
/// Returns `(H, h)` with `H = (rate / shape) * expm1(shape * age)` and
/// `h = rate * exp(shape * age)`. When `|shape| < 1e-10`, both quantities are
/// instead computed from their second-order Taylor expansions in
/// `x = shape * age`, which avoids dividing by a (near-)zero shape.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let near_zero_shape = shape.abs() < 1e-10;
    if !near_zero_shape {
        return ((rate / shape) * x.exp_m1(), rate * x.exp());
    }
    // H ≈ rate·t·(1 + x/2 + x²/6), h ≈ rate·(1 + x + x²/2)
    let cumulative = rate * age * (1.0 + 0.5 * x + x * x / 6.0);
    let instant = rate * (1.0 + x + 0.5 * x * x);
    (cumulative, instant)
}
2340
2341#[inline]
2357fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
2358 let x = shape * age;
2359 let dinstg_dshape = rate * age * x.exp();
2360 let dhg_dshape = if x.abs() < 1e-4 {
2369 let t = age;
2370 rate * t * t * (0.5 + x / 3.0 + x * x / 8.0)
2372 } else {
2373 let e = x.exp();
2375 let em1 = x.exp_m1();
2376 let numerator = age * e * shape - em1;
2377 rate * numerator / (shape * shape)
2378 };
2379 (dhg_dshape, dinstg_dshape)
2380}
2381
/// Derivatives of the Gompertz log-cumulative-hazard offset with respect to `shape`.
///
/// With `x = shape * age`, the offset is `eta = ln(rate) + ln(expm1(x) / shape)`.
/// Its time derivative is `o_d = shape * e^x / expm1(x)`, which does not depend
/// on `rate`. Returns `(d eta / d shape, d o_d / d shape)`.
///
/// The closed forms `-1/shape + age·e^x/expm1(x)` and `1/shape - age/expm1(x)`
/// subtract two terms of size ~`1/|shape|`. Their relative accuracy therefore
/// degrades like `eps · 2/|x|` as `x -> 0`: at `shape = 1e-9, age = 1` the closed
/// form loses about nine digits. For `|x| < 1e-4`, the same cutoff that
/// `gompertz_cumulative_shape_derivative` uses, the Bernoulli-series expansions
///
/// ```text
/// e^x / expm1(x) = 1/x + 1/2 + x/12 - x^3/720 + O(x^5)
/// 1   / expm1(x) = 1/x - 1/2 + x/12 - x^3/720 + O(x^5)
/// ```
///
/// are used instead. Their truncation error is below `f64` resolution there.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-4 {
        let t = age;
        let x2 = x * x;
        let x3 = x2 * x;
        // d eta / d shape = t * (1/2 + x/12 - x^3/720)
        let d_eta = t * (0.5 + x / 12.0 - x3 / 720.0);
        // d ln(o_d) / d shape = t * (1/2 - x/12 + x^3/720)
        let dlog_od = t * (0.5 - x / 12.0 + x3 / 720.0);
        // o_d = (1/t) * (1 + x/2 + x^2/12 - x^4/720)
        let o_d = (1.0 + 0.5 * x + x2 / 12.0 - x2 * x2 / 720.0) / t;
        (d_eta, o_d * dlog_od)
    } else {
        let e = x.exp();
        let em1 = x.exp_m1();
        let d_eta = -1.0 / shape + age * e / em1;
        let o_d = shape * e / em1;
        // ln|o_d| = ln|shape| + x - ln|expm1(x)|, so d/dshape = 1/shape - age/expm1(x).
        let dlog_od = 1.0 / shape - age / em1;
        (d_eta, o_d * dlog_od)
    }
}
2414
/// Parameters of a non-linear baseline target.
///
/// When produced by `validated_baseline_params`, every field is finite, and:
/// - Weibull `scale` and `shape` are positive;
/// - Gompertz and Gompertz–Makeham `rate` is positive;
/// - `makeham` is positive;
/// - `shape` in the Gompertz variants may take any finite value.
///
/// The linear target carries no parameters; that function reports it as `None`.
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    Weibull { scale: f64, shape: f64 },
    Gompertz { rate: f64, shape: f64 },
    GompertzMakeham { rate: f64, shape: f64, makeham: f64 },
}
2429
2430fn validated_baseline_params(
2436 age: f64,
2437 cfg: &SurvivalBaselineConfig,
2438 context: &str,
2439) -> Result<Option<ValidatedBaselineTarget>, String> {
2440 if !age.is_finite() || age <= 0.0 {
2441 return Err(format!(
2442 "survival ages must be finite and positive for {context}"
2443 ));
2444 }
2445
2446 match cfg.target {
2447 SurvivalBaselineTarget::Linear => Ok(None),
2448 SurvivalBaselineTarget::Weibull => {
2449 let scale = cfg
2450 .scale
2451 .ok_or_else(|| "weibull missing scale".to_string())?;
2452 let shape = cfg
2453 .shape
2454 .ok_or_else(|| "weibull missing shape".to_string())?;
2455 if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2456 return Err(SurvivalConstructionError::InvalidConfig {
2457 reason: "weibull baseline requires finite positive scale and shape".to_string(),
2458 }
2459 .into());
2460 }
2461 Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2462 }
2463 SurvivalBaselineTarget::Gompertz => {
2464 let rate = cfg
2465 .rate
2466 .ok_or_else(|| "gompertz missing rate".to_string())?;
2467 let shape = cfg
2468 .shape
2469 .ok_or_else(|| "gompertz missing shape".to_string())?;
2470 if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2471 return Err(
2472 "gompertz baseline requires finite positive rate and finite shape".to_string(),
2473 );
2474 }
2475 Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2476 }
2477 SurvivalBaselineTarget::GompertzMakeham => {
2478 let rate = cfg
2479 .rate
2480 .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2481 let shape = cfg
2482 .shape
2483 .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2484 let makeham = cfg
2485 .makeham
2486 .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2487 if !(rate.is_finite()
2488 && shape.is_finite()
2489 && makeham.is_finite()
2490 && rate > 0.0
2491 && makeham > 0.0)
2492 {
2493 return Err(
2494 "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2495 .to_string(),
2496 );
2497 }
2498 Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2499 rate,
2500 shape,
2501 makeham,
2502 }))
2503 }
2504 }
2505}
2506
2507fn survival_hazard_theta_partials(
2508 age: f64,
2509 cfg: &SurvivalBaselineConfig,
2510) -> Result<Option<Vec<(f64, f64)>>, String> {
2511 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2512 return Ok(None);
2513 };
2514
2515 match params {
2516 ValidatedBaselineTarget::Weibull { scale, shape } => {
2517 let log_time_ratio = age.ln() - scale.ln();
2518 let cumulative_hazard = (age / scale).powf(shape);
2519 let instant_hazard = shape * cumulative_hazard / age;
2520 let eta = shape * log_time_ratio;
2521 Ok(Some(vec![
2522 (-shape * cumulative_hazard, -shape * instant_hazard),
2523 (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2524 ]))
2525 }
2526 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2527 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2528 let (d_cum_dshape, d_inst_dshape) =
2529 gompertz_cumulative_shape_derivative(age, rate, shape);
2530 Ok(Some(vec![
2531 (cumulative_hazard, instant_hazard),
2532 (d_cum_dshape, d_inst_dshape),
2533 ]))
2534 }
2535 ValidatedBaselineTarget::GompertzMakeham {
2536 rate,
2537 shape,
2538 makeham,
2539 } => {
2540 let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2541 let (d_cum_dshape, d_inst_dshape) =
2542 gompertz_cumulative_shape_derivative(age, rate, shape);
2543 Ok(Some(vec![
2544 (cum_gompertz, inst_gompertz),
2545 (d_cum_dshape, d_inst_dshape),
2546 (makeham * age, makeham),
2547 ]))
2548 }
2549 }
2550}
2551
2552fn survival_cumulative_and_instant_hazard(
2553 age: f64,
2554 cfg: &SurvivalBaselineConfig,
2555) -> Result<Option<(f64, f64)>, String> {
2556 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2557 return Ok(None);
2558 };
2559
2560 match params {
2561 ValidatedBaselineTarget::Weibull { scale, shape } => {
2562 let cumulative_hazard = (age / scale).powf(shape);
2563 let instant_hazard = shape * cumulative_hazard / age;
2564 Ok(Some((cumulative_hazard, instant_hazard)))
2565 }
2566 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2567 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2568 Ok(Some((cumulative_hazard, instant_hazard)))
2569 }
2570 ValidatedBaselineTarget::GompertzMakeham {
2571 rate,
2572 shape,
2573 makeham,
2574 } => {
2575 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2576 Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2577 }
2578 }
2579}
2580
/// Baseline evaluated at one time point on the probit scale used by the
/// marginal-slope model. Built by `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Baseline instantaneous hazard `h(t)`.
    instant_hazard: f64,
    /// `q = -standard_normal_quantile(S)` with survival `S = exp(-H(t))`.
    q: f64,
    /// Time derivative of `q`, computed as `h(t) * S / normal_pdf(q)`.
    q_t: f64,
}
2587
2588fn evaluate_marginal_slope_baseline_point(
2589 age: f64,
2590 cfg: &SurvivalBaselineConfig,
2591) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2592 let Some((cumulative_hazard, instant_hazard)) =
2593 survival_cumulative_and_instant_hazard(age, cfg)?
2594 else {
2595 return Ok(None);
2596 };
2597 if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2598 return Err(format!(
2599 "{} marginal-slope baseline produced non-positive cumulative hazard",
2600 survival_baseline_targetname(cfg.target)
2601 ));
2602 }
2603 if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2604 return Err(format!(
2605 "{} marginal-slope baseline produced non-positive instant hazard",
2606 survival_baseline_targetname(cfg.target)
2607 ));
2608 }
2609 let survival = (-cumulative_hazard).exp();
2610 if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2611 return Err(format!(
2612 "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2613 survival_baseline_targetname(cfg.target)
2614 ));
2615 }
2616 let q = -standard_normal_quantile(survival).map_err(|e| {
2617 format!(
2618 "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2619 survival_baseline_targetname(cfg.target)
2620 )
2621 })?;
2622 let phi_q = normal_pdf(q);
2623 if !(phi_q.is_finite() && phi_q > 0.0) {
2624 return Err(format!(
2625 "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2626 survival_baseline_targetname(cfg.target)
2627 ));
2628 }
2629 Ok(Some(MarginalSlopeBaselinePoint {
2630 instant_hazard,
2631 q,
2632 q_t: instant_hazard * survival / phi_q,
2633 }))
2634}
2635
2636pub fn evaluate_survival_baseline(
2639 age: f64,
2640 cfg: &SurvivalBaselineConfig,
2641) -> Result<(f64, f64), String> {
2642 if !age.is_finite() || age < 0.0 {
2643 return Err(
2644 "survival ages must be finite and non-negative for baseline target evaluation"
2645 .to_string(),
2646 );
2647 }
2648
2649 if age == 0.0 {
2660 return match cfg.target {
2661 SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2662 SurvivalBaselineTarget::Weibull
2663 | SurvivalBaselineTarget::Gompertz
2664 | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2665 };
2666 }
2667
2668 let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2669 return Ok((0.0, 0.0));
2670 };
2671
2672 match params {
2673 ValidatedBaselineTarget::Weibull { scale, shape } => {
2674 let eta = shape * (age.ln() - scale.ln());
2675 let derivative = shape / age;
2676 Ok((eta, derivative))
2677 }
2678 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2679 let (h, inst) = gompertz_hazard_components(age, rate, shape);
2680 if h <= 0.0 || !h.is_finite() {
2681 return Err(if shape.abs() < 1e-10 {
2682 "invalid gompertz baseline at near-zero shape".to_string()
2683 } else {
2684 "gompertz baseline produced non-positive cumulative hazard".to_string()
2685 });
2686 }
2687 let derivative = inst / h;
2688 Ok((h.ln(), derivative))
2689 }
2690 ValidatedBaselineTarget::GompertzMakeham {
2691 rate,
2692 shape,
2693 makeham,
2694 } => {
2695 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2696 let h = makeham * age + h_gompertz;
2697 if h <= 0.0 || !h.is_finite() {
2698 return Err(
2699 "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2700 );
2701 }
2702 let inst = makeham + inst_gompertz;
2703 let derivative = inst / h;
2704 Ok((h.ln(), derivative))
2705 }
2706 }
2707}
2708
2709pub fn evaluate_survival_marginal_slope_baseline(
2715 age: f64,
2716 cfg: &SurvivalBaselineConfig,
2717) -> Result<(f64, f64), String> {
2718 if age == 0.0 {
2730 return Ok((0.0, 0.0));
2731 }
2732 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2733 return Ok((0.0, 0.0));
2734 };
2735 Ok((point.q, point.q_t))
2736}
2737
2738pub fn marginal_slope_baseline_offset_theta_partials(
2751 age: f64,
2752 cfg: &SurvivalBaselineConfig,
2753) -> Result<Option<Vec<(f64, f64)>>, String> {
2754 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2755 return Ok(None);
2756 };
2757 let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2758 .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2759 let a = point.q_t / point.instant_hazard;
2760 let a_log_derivative_factor = point.q * a - 1.0;
2761 Ok(Some(
2762 hazard_partials
2763 .into_iter()
2764 .map(|(d_h_cum, d_h_inst)| {
2765 (
2766 a * d_h_cum,
2767 a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2768 )
2769 })
2770 .collect(),
2771 ))
2772}
2773
2774pub fn marginal_slope_baseline_chain_rule_hessian(
2777 age_entry: ndarray::ArrayView1<'_, f64>,
2778 age_exit: ndarray::ArrayView1<'_, f64>,
2779 cfg: &SurvivalBaselineConfig,
2780 residuals: &crate::survival::OffsetChannelResiduals,
2781 curvatures: &crate::survival::OffsetChannelCurvatures,
2782) -> Result<Option<Array2<f64>>, String> {
2783 let n = age_exit.len();
2784 if age_entry.len() != n
2785 || residuals.exit.len() != n
2786 || residuals.entry.len() != n
2787 || residuals.derivative.len() != n
2788 || curvatures.rows.len() != n
2789 {
2790 return Err(format!(
2791 "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2792 age_entry.len(),
2793 n,
2794 residuals.exit.len(),
2795 residuals.entry.len(),
2796 residuals.derivative.len(),
2797 curvatures.rows.len(),
2798 ));
2799 }
2800 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2801 let dim = match probe_age {
2802 Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2803 None => return Ok(None),
2804 Some(parts) => parts.first.len(),
2805 },
2806 None => {
2807 return Err(
2808 "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2809 .to_string(),
2810 );
2811 }
2812 };
2813 let hessian = (0..n)
2818 .into_par_iter()
2819 .try_fold(
2820 || Array2::<f64>::zeros((dim, dim)),
2821 |mut acc, i| -> Result<Array2<f64>, String> {
2822 let exit_parts =
2823 marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2824 .ok_or_else(|| {
2825 "unexpected None from marginal-slope second partials at exit"
2826 .to_string()
2827 })?;
2828 if exit_parts.first.len() != dim {
2829 return Err(
2830 "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2831 );
2832 }
2833 let mut entry_parts = None;
2834 if residuals.entry[i] != 0.0 {
2835 entry_parts = Some(
2836 marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2837 .ok_or_else(|| {
2838 "unexpected None from marginal-slope second partials at entry"
2839 .to_string()
2840 })?,
2841 );
2842 }
2843 for a in 0..dim {
2844 for b in 0..dim {
2845 let j_exit_a = exit_parts.first[a].0;
2846 let j_exit_b = exit_parts.first[b].0;
2847 let j_deriv_a = exit_parts.first[a].1;
2848 let j_deriv_b = exit_parts.first[b].1;
2849 let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2850 + residuals.derivative[i] * exit_parts.second[a][b].1;
2851 if let Some(parts) = entry_parts.as_ref() {
2852 value += residuals.entry[i] * parts.second[a][b].0;
2853 }
2854 let curv = curvatures.rows[i];
2855 let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2856 let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2857 let ja = [j_entry_a, j_exit_a, j_deriv_a];
2858 let jb = [j_entry_b, j_exit_b, j_deriv_b];
2859 for u in 0..3 {
2860 for v in 0..3 {
2861 value += ja[u] * curv[u][v] * jb[v];
2862 }
2863 }
2864 acc[[a, b]] += value;
2865 }
2866 }
2867 Ok(acc)
2868 },
2869 )
2870 .try_reduce(|| Array2::<f64>::zeros((dim, dim)), |a, b| Ok(a + b))?;
2871 Ok(Some(hessian))
2872}
2873
/// First and second θ-partials of the marginal-slope offsets `(q, q_t)` at a single age.
struct MarginalSlopeThetaSecondPartials {
    /// `first[i] = (∂q/∂θ_i, ∂q_t/∂θ_i)`.
    first: Vec<(f64, f64)>,
    /// `second[i][j] = (∂²q/∂θ_i∂θ_j, ∂²q_t/∂θ_i∂θ_j)`.
    second: Vec<Vec<(f64, f64)>>,
}
2878
/// First and second θ-partials of the marginal-slope offsets `q = -Φ⁻¹(S)` (with
/// `S = exp(-H)`) and `q_t = h · S / φ(q)` at `age`.
///
/// These are obtained by chaining the hazard partials from
/// `survival_hazard_theta_first_second` through the probit transform.
///
/// Returns `Ok(None)` when either the baseline point or the hazard partials are unavailable.
fn marginal_slope_baseline_offset_theta_second_partials(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
        return Ok(None);
    };
    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
        return Ok(None);
    };
    let (cum_hazard, instant_hazard) = hazard;
    let survival = (-cum_hazard).exp();
    // a = ∂q/∂H = S / φ(q).
    let a = survival / normal_pdf(point.q);
    // b = q·a − 1, so that ∂a/∂H = a·b (uses φ'(q) = −q φ(q)).
    let b = point.q * a - 1.0;
    // ∂b/∂H = a·(a + q·b); `b_factor` is the bracket.
    let b_factor = a + point.q * b;
    let dim = first.len();
    let mut first_out = Vec::with_capacity(dim);
    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
    for i in 0..dim {
        let (h_i, inst_i) = first[i];
        // (∂q/∂θ_i, ∂q_t/∂θ_i) with q_t = a·h.
        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
    }
    for i in 0..dim {
        for j in 0..dim {
            let (h_i, inst_i) = first[i];
            let (h_j, inst_j) = first[j];
            let (h_ij, inst_ij) = second[i][j];
            // ∂a/∂θ_j and ∂b/∂θ_j via ∂H/∂θ_j.
            let a_j = a * b * h_j;
            let b_j = a * h_j * b_factor;
            // ∂/∂θ_j (a · ∂H/∂θ_i).
            let q_ij = a * h_ij + a * b * h_i * h_j;
            // ∂/∂θ_j [a · (∂h/∂θ_i + h·b·∂H/∂θ_i)], product rule over a, h, b, ∂H/∂θ_i.
            let qt_inner_i = inst_i + instant_hazard * b * h_i;
            let qt_ij = a_j * qt_inner_i
                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
            second_out[i][j] = (q_ij, qt_ij);
        }
    }
    Ok(Some(MarginalSlopeThetaSecondPartials {
        first: first_out,
        second: second_out,
    }))
}
2920
/// `((H, h), first, second)`: cumulative and instantaneous hazard at one age, their first
/// θ-partials `first[i] = (∂H/∂θ_i, ∂h/∂θ_i)`, and second θ-partials
/// `second[i][j] = (∂²H/∂θ_i∂θ_j, ∂²h/∂θ_i∂θ_j)`.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2922
/// Cumulative/instantaneous hazard at `age` together with its first and second θ-partials.
///
/// Returns `Ok(None)` for a linear target, or when `survival_cumulative_and_instant_hazard`
/// yields nothing.
///
/// The first partials come from `survival_hazard_theta_partials`. The second partials below are
/// closed forms whose θ ordering must match those first partials. NOTE(review): the formulas
/// are consistent with these parameterizations; confirm against
/// `survival_hazard_theta_partials`:
/// * Weibull: θ = (log scale, log shape).
/// * Gompertz: θ = (log rate, shape).
/// * Gompertz-Makeham: θ = (log rate, shape, log makeham).
fn survival_hazard_theta_first_second(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<HazardFirstSecond>, String> {
    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
        return Ok(None);
    };
    let first = survival_hazard_theta_partials(age, cfg)?
        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
    let dim = first.len();
    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
    match cfg.target {
        SurvivalBaselineTarget::Linear => return Ok(None),
        SurvivalBaselineTarget::Weibull => {
            let scale = cfg
                .scale
                .ok_or_else(|| "weibull missing scale".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "weibull missing shape".to_string())?;
            let log_time_ratio = age.ln() - scale.ln();
            let cumulative_hazard = hazard.0;
            let instant_hazard = hazard.1;
            // eta = shape · log(t / scale), so H = exp(eta).
            let eta = shape * log_time_ratio;
            second[0][0] = (
                shape * shape * cumulative_hazard,
                shape * shape * instant_hazard,
            );
            second[0][1] = (
                -shape * cumulative_hazard * (1.0 + eta),
                -shape * instant_hazard * (2.0 + eta),
            );
            // Mixed partials are symmetric.
            second[1][0] = second[0][1];
            second[1][1] = (
                eta * cumulative_hazard * (1.0 + eta),
                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
            );
        }
        SurvivalBaselineTarget::Gompertz => {
            let rate = cfg
                .rate
                .ok_or_else(|| "gompertz missing rate".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "gompertz missing shape".to_string())?;
            // Entries involving θ_0 reuse first partials: both H and h scale with `rate`, so
            // differentiating again in θ_0 reproduces the first partial.
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        }
        SurvivalBaselineTarget::GompertzMakeham => {
            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
            // The Makeham term is additive and separate from the Gompertz part, so its only
            // non-zero second partial is on the diagonal.
            second[2][2] = first[2];
        }
    }
    Ok(Some((hazard, first, second)))
}
2985
/// Second `shape`-derivatives of the Gompertz cumulative hazard
/// `H(t) = rate · (exp(shape·t) − 1) / shape` and instantaneous hazard
/// `h(t) = rate · exp(shape·t)`, returned as `(∂²H/∂shape², ∂²h/∂shape²)`.
///
/// With `x = shape·t`, `∂²H/∂shape² = rate · t³ · Σ_{j≥0} x^j / (j!·(j+3))`.
///
/// The closed form `rate · (t²·e^x/shape − 2·(x·e^x − expm1(x))/shape³)` subtracts two terms of
/// size `t²/shape`, so its relative error grows like `12·ε/x²` (about 1e-9 at |x| = 1e-3).
/// The series is therefore summed adaptively (to machine precision) for |x| < 0.5, and the
/// closed form is used only beyond that. `∂²h/∂shape² = rate · t² · e^x` is exact everywhere.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let e = x.exp();
    let instant_second = rate * age * age * e;
    let cumulative_second = if x.abs() < 0.5 {
        // term = x^j / j!, contribution = term / (j + 3). For |x| < 0.5 the terms decay
        // factorially, so this stops after roughly a dozen iterations; 40 is a safety cap.
        let mut term = 1.0_f64;
        let mut series = 0.0_f64;
        for j in 0..40_u32 {
            let contribution = term / (f64::from(j) + 3.0);
            series += contribution;
            if contribution.abs() <= f64::EPSILON * series.abs() {
                break;
            }
            term *= x / (f64::from(j) + 1.0);
        }
        rate * age * age * age * series
    } else {
        // |x| >= 0.5 implies shape != 0, so the divisions are well defined.
        let n = x * e - x.exp_m1();
        rate * (age * age * e / shape - 2.0 * n / (shape * shape * shape))
    };
    (cumulative_second, instant_second)
}
3016
/// Selects which scalar time transform the survival offset builders evaluate per age.
#[derive(Clone, Copy)]
enum BaselineOffsetEvaluator {
    /// Log-cumulative-hazard offsets, evaluated by `evaluate_survival_baseline`.
    LogCumulativeHazard,
    /// Probit-survival (marginal-slope) offsets, evaluated by
    /// `evaluate_survival_marginal_slope_baseline`.
    ProbitSurvival,
}
3026
3027impl BaselineOffsetEvaluator {
3028 fn length_error(self) -> String {
3029 match self {
3030 Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3031 reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3032 }
3033 .into(),
3034 Self::ProbitSurvival => {
3035 "survival probit baseline offsets require matching entry/exit lengths".to_string()
3036 }
3037 }
3038 }
3039
3040 fn finite_error(self) -> &'static str {
3041 match self {
3042 Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3043 Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3044 }
3045 }
3046
3047 fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3048 match self {
3049 Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3050 Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3051 }
3052 }
3053
3054 fn exit_is_finite(self, value: f64, age: f64) -> bool {
3055 match self {
3056 Self::LogCumulativeHazard => {
3057 value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3058 }
3059 Self::ProbitSurvival => value.is_finite(),
3060 }
3061 }
3062}
3063
3064fn build_survival_offsets_with_evaluator(
3065 age_entry: &Array1<f64>,
3066 age_exit: &Array1<f64>,
3067 cfg: &SurvivalBaselineConfig,
3068 evaluator: BaselineOffsetEvaluator,
3069) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3070 if age_entry.len() != age_exit.len() {
3071 return Err(evaluator.length_error());
3072 }
3073 let n = age_entry.len();
3074 let triples: Vec<(f64, f64, f64)> = (0..n)
3077 .into_par_iter()
3078 .map(|i| -> Result<(f64, f64, f64), String> {
3079 let entry_age = age_entry[i];
3083 let e0 = if !entry_age.is_finite() {
3084 return Err(SurvivalConstructionError::DataValidationFailed {
3085 reason: format!("non-finite entry age at row {i}"),
3086 }
3087 .into());
3088 } else if entry_age <= 0.0 {
3089 0.0
3090 } else {
3091 evaluator.evaluate(entry_age, cfg)?.0
3092 };
3093 let exit_age = age_exit[i];
3094 let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3095 if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3096 return Err(SurvivalConstructionError::DataValidationFailed {
3097 reason: evaluator.finite_error().to_string(),
3098 }
3099 .into());
3100 }
3101 Ok((e0, e1, d1))
3102 })
3103 .collect::<Result<Vec<_>, String>>()?;
3104 let mut eta_entry = Array1::<f64>::zeros(n);
3105 let mut eta_exit = Array1::<f64>::zeros(n);
3106 let mut derivative_exit = Array1::<f64>::zeros(n);
3107 for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3108 eta_entry[i] = e0;
3109 eta_exit[i] = e1;
3110 derivative_exit[i] = d1;
3111 }
3112 Ok((eta_entry, eta_exit, derivative_exit))
3113}
3114
3115pub fn build_survival_baseline_offsets(
3118 age_entry: &Array1<f64>,
3119 age_exit: &Array1<f64>,
3120 cfg: &SurvivalBaselineConfig,
3121) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3122 build_survival_offsets_with_evaluator(
3123 age_entry,
3124 age_exit,
3125 cfg,
3126 BaselineOffsetEvaluator::LogCumulativeHazard,
3127 )
3128}
3129
3130pub fn build_survival_marginal_slope_baseline_offsets(
3133 age_entry: &Array1<f64>,
3134 age_exit: &Array1<f64>,
3135 cfg: &SurvivalBaselineConfig,
3136) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3137 build_survival_offsets_with_evaluator(
3138 age_entry,
3139 age_exit,
3140 cfg,
3141 BaselineOffsetEvaluator::ProbitSurvival,
3142 )
3143}
3144
3145pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3146 matches!(
3147 inverse_link,
3148 Some(
3149 InverseLink::Standard(StandardLink::Probit)
3150 | InverseLink::LatentCLogLog(_)
3151 | InverseLink::Sas(_)
3152 | InverseLink::BetaLogistic(_)
3153 | InverseLink::Mixture(_)
3154 )
3155 )
3156}
3157
3158pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3159 match likelihood_mode {
3160 SurvivalLikelihoodMode::LocationScale
3161 | SurvivalLikelihoodMode::Latent
3162 | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3163 SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3164 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3165 }
3166}
3167
3168pub fn build_survival_time_offsets_for_likelihood(
3169 age_entry: &Array1<f64>,
3170 age_exit: &Array1<f64>,
3171 baseline_cfg: &SurvivalBaselineConfig,
3172 likelihood_mode: SurvivalLikelihoodMode,
3173 inverse_link: Option<&InverseLink>,
3174) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3175 if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3176 || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3177 && location_scale_uses_probit_survival_baseline(inverse_link))
3178 {
3179 build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3180 } else {
3181 build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3182 }
3183}
3184
3185pub fn add_survival_time_derivative_guard_offset(
3186 age_entry: &Array1<f64>,
3187 age_exit: &Array1<f64>,
3188 anchor_time: f64,
3189 derivative_guard: f64,
3190 eta_offset_entry: &mut Array1<f64>,
3191 eta_offset_exit: &mut Array1<f64>,
3192 derivative_offset_exit: &mut Array1<f64>,
3193) -> Result<(), String> {
3194 if derivative_guard <= 0.0 {
3195 return Ok(());
3196 }
3197 let n = age_entry.len();
3198 if age_exit.len() != n
3199 || eta_offset_entry.len() != n
3200 || eta_offset_exit.len() != n
3201 || derivative_offset_exit.len() != n
3202 {
3203 return Err(SurvivalConstructionError::IncompatibleDimensions {
3204 reason: "survival derivative-guard offset lengths must match".to_string(),
3205 }
3206 .into());
3207 }
3208 for i in 0..n {
3209 eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3210 eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3211 derivative_offset_exit[i] += derivative_guard;
3212 }
3213 Ok(())
3214}
3215
/// Per-row baseline offsets for latent survival models, produced by
/// `build_latent_survival_baseline_offsets`.
///
/// The offsets are split into two components:
/// * a "loaded" component, stored on the log-cumulative-hazard scale;
/// * an "unloaded" component, stored as raw cumulative mass and hazard.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    /// Log cumulative hazard of the loaded component at the entry age.
    pub loaded_eta_entry: Array1<f64>,
    /// Log cumulative hazard of the loaded component at the exit age.
    pub loaded_eta_exit: Array1<f64>,
    /// Time derivative of `loaded_eta_exit` (instantaneous / cumulative hazard) at exit.
    pub loaded_derivative_exit: Array1<f64>,
    /// Cumulative hazard of the unloaded component at the entry age.
    pub unloaded_mass_entry: Array1<f64>,
    /// Cumulative hazard of the unloaded component at the exit age.
    pub unloaded_mass_exit: Array1<f64>,
    /// Instantaneous hazard of the unloaded component at the exit age.
    pub unloaded_hazard_exit: Array1<f64>,
}
3225
3226pub fn build_latent_survival_baseline_offsets(
3227 age_entry: &Array1<f64>,
3228 age_exit: &Array1<f64>,
3229 cfg: &SurvivalBaselineConfig,
3230 loading: HazardLoading,
3231) -> Result<LatentSurvivalBaselineOffsets, String> {
3232 if age_entry.len() != age_exit.len() {
3233 return Err(
3234 "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3235 );
3236 }
3237
3238 fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3239 if shape.abs() < 1e-10 {
3240 let x = shape * age;
3247 return (
3248 rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3249 rate * (1.0 + x + 0.5 * x * x),
3250 );
3251 }
3252 let shape_age = shape * age;
3253 let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3254 let instant_hazard = rate * shape_age.exp();
3255 (cumulative_hazard, instant_hazard)
3256 }
3257
3258 let n = age_entry.len();
3259
3260 let rows: Vec<[f64; 6]> = (0..n)
3263 .into_par_iter()
3264 .map(|i| -> Result<[f64; 6], String> {
3265 let entry = age_entry[i];
3266 let exit = age_exit[i];
3267 if !entry.is_finite()
3268 || !exit.is_finite()
3269 || entry <= 0.0
3270 || exit <= 0.0
3271 || exit < entry
3272 {
3273 return Err(format!(
3274 "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3275 i + 1
3276 ));
3277 }
3278 match loading {
3279 HazardLoading::Full => {
3280 let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3281 let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3282 Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3283 }
3284 HazardLoading::LoadedVsUnloaded => {
3285 if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3286 return Err(format!(
3287 "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3288 survival_baseline_targetname(cfg.target)
3289 ));
3290 }
3291 let rate = cfg.rate.ok_or_else(|| {
3292 "gompertz-makeham latent survival is missing baseline rate".to_string()
3293 })?;
3294 let shape = cfg.shape.ok_or_else(|| {
3295 "gompertz-makeham latent survival is missing baseline shape".to_string()
3296 })?;
3297 let makeham = cfg.makeham.ok_or_else(|| {
3298 "gompertz-makeham latent survival is missing baseline makeham".to_string()
3299 })?;
3300 let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3301 let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3302 if !(loaded_entry.is_finite()
3303 && loaded_entry > 0.0
3304 && loaded_exit.is_finite()
3305 && loaded_exit > 0.0
3306 && loaded_hazard.is_finite()
3307 && loaded_hazard > 0.0)
3308 {
3309 return Err(format!(
3310 "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3311 i + 1
3312 ));
3313 }
3314 Ok([
3315 loaded_entry.ln(),
3316 loaded_exit.ln(),
3317 loaded_hazard / loaded_exit,
3318 makeham * entry,
3319 makeham * exit,
3320 makeham,
3321 ])
3322 }
3323 }
3324 })
3325 .collect::<Result<Vec<_>, String>>()?;
3326
3327 let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3328 let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3329 let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3330 let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3331 let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3332 let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3333 for (i, row) in rows.into_iter().enumerate() {
3334 loaded_eta_entry[i] = row[0];
3335 loaded_eta_exit[i] = row[1];
3336 loaded_derivative_exit[i] = row[2];
3337 unloaded_mass_entry[i] = row[3];
3338 unloaded_mass_exit[i] = row[4];
3339 unloaded_hazard_exit[i] = row[5];
3340 }
3341
3342 Ok(LatentSurvivalBaselineOffsets {
3343 loaded_eta_entry,
3344 loaded_eta_exit,
3345 loaded_derivative_exit,
3346 unloaded_mass_entry,
3347 unloaded_mass_exit,
3348 unloaded_hazard_exit,
3349 })
3350}
3351
3352pub fn build_survival_timewiggle_derivative_design(
3357 eta_exit: &Array1<f64>,
3358 derivative_exit: &Array1<f64>,
3359 knots: &Array1<f64>,
3360 degree: usize,
3361) -> Result<Array2<f64>, String> {
3362 let mut design_derivative_exit =
3363 monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3364 for i in 0..design_derivative_exit.nrows() {
3365 let chain = derivative_exit[i];
3366 for j in 0..design_derivative_exit.ncols() {
3367 design_derivative_exit[[i, j]] *= chain;
3368 }
3369 }
3370 Ok(design_derivative_exit)
3371}
3372
3373pub fn build_survival_timewiggle_from_baseline(
3383 eta_entry: &Array1<f64>,
3384 eta_exit: &Array1<f64>,
3385 derivative_exit: &Array1<f64>,
3386 cfg: &LinkWiggleFormulaSpec,
3387) -> Result<SurvivalTimeWiggleBuild, String> {
3388 if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3389 return Err(
3390 "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3391 );
3392 }
3393 let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3396 && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3397 && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3398 if all_zero {
3399 return Err(
3400 "timewiggle requires a non-linear scalar survival baseline target; \
3401 the provided baseline offsets are all zero (linear baseline)"
3402 .to_string(),
3403 );
3404 }
3405 let n = eta_exit.len();
3406 let mut seed = Array1::<f64>::zeros(2 * n);
3407 for i in 0..n {
3408 seed[i] = eta_entry[i];
3409 seed[n + i] = eta_exit[i];
3410 }
3411 let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3415 let wiggle_cfg = WiggleBlockConfig {
3416 degree: cfg.degree,
3417 num_internal_knots: cfg.num_internal_knots,
3418 penalty_order: primary_order,
3419 double_penalty: cfg.double_penalty,
3420 };
3421 let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3422 append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3423 let ncols = combined_block.design.ncols();
3424 Ok(SurvivalTimeWiggleBuild {
3425 nullspace_dims: combined_block.nullspace_dims.clone(),
3426 penalties: {
3427 combined_block
3428 .penalties
3429 .into_iter()
3430 .map(|ps| ps.to_global(ncols))
3431 .collect()
3432 },
3433 knots,
3434 degree: cfg.degree,
3435 ncols,
3436 })
3437}
3438
3439pub fn append_zero_tail_columns(
3440 x_entry: &mut DesignMatrix,
3441 x_exit: &mut DesignMatrix,
3442 x_derivative: &mut DesignMatrix,
3443 tail_cols: usize,
3444) {
3445 if tail_cols == 0 {
3446 return;
3447 }
3448 fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3451 let old = dm.to_dense();
3452 let n = old.nrows();
3453 let p_base = old.ncols();
3454 let mut out = Array2::<f64>::zeros((n, p_base + tail));
3455 out.slice_mut(s![.., 0..p_base]).assign(&old);
3456 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3457 }
3458 append_dense(x_entry, tail_cols);
3459 append_dense(x_exit, tail_cols);
3460 append_dense(x_derivative, tail_cols);
3461}
3462
3463pub fn build_time_varying_survival_covariate_template(
3474 age_entry: &Array1<f64>,
3475 age_exit: &Array1<f64>,
3476 time_k: usize,
3477 time_degree: usize,
3478 block_name: &str,
3479) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3480 if time_k < time_degree + 1 {
3481 return Err(format!(
3482 "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3483 time_degree + 1
3484 ));
3485 }
3486 let num_internal_knots = time_k - (time_degree + 1);
3487
3488 let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3489 let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3490
3491 let time_spec = BSplineBasisSpec {
3492 degree: time_degree,
3493 penalty_order: 2,
3494 knotspec: BSplineKnotSpec::Automatic {
3495 num_internal_knots: Some(num_internal_knots),
3496 placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3497 },
3498 double_penalty: false,
3499 identifiability: BSplineIdentifiability::None,
3500 boundary: OneDimensionalBoundary::Open,
3501 boundary_conditions: BSplineBoundaryConditions::default(),
3502 };
3503
3504 let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3505 .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3506 let time_design_exit = time_build.design.to_dense();
3507
3508 let knots = match &time_build.metadata {
3509 BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3510 _ => {
3511 return Err(format!(
3512 "{block_name} time-margin basis returned unexpected metadata type"
3513 ));
3514 }
3515 };
3516
3517 let time_build_entry = build_bspline_basis_1d(
3518 log_entry.view(),
3519 &BSplineBasisSpec {
3520 degree: time_degree,
3521 penalty_order: 2,
3522 knotspec: BSplineKnotSpec::Provided(knots.clone()),
3523 double_penalty: false,
3524 identifiability: BSplineIdentifiability::None,
3525 boundary: OneDimensionalBoundary::Open,
3526 boundary_conditions: BSplineBoundaryConditions::default(),
3527 },
3528 )
3529 .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3530 let time_design_entry = time_build_entry.design.to_dense();
3531 let p_time = time_design_exit.ncols();
3532 let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3533 time_design_derivative_exit
3537 .as_slice_mut()
3538 .expect("zeros are contiguous")
3539 .par_chunks_mut(p_time)
3540 .enumerate()
3541 .try_for_each(|(i, row_out)| -> Result<(), String> {
3542 let mut deriv_buf = vec![0.0_f64; p_time];
3543 evaluate_bspline_derivative_scalar(
3544 log_exit[i],
3545 knots.view(),
3546 time_degree,
3547 &mut deriv_buf,
3548 )
3549 .map_err(|e| {
3550 format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3551 })?;
3552 let chain = 1.0 / age_exit[i].max(1e-12);
3553 for j in 0..p_time {
3554 row_out[j] = deriv_buf[j] * chain;
3555 }
3556 Ok(())
3557 })?;
3558
3559 Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3560 time_basis_entry: time_design_entry,
3561 time_basis_exit: time_design_exit,
3562 time_basis_derivative_exit: time_design_derivative_exit,
3563 time_penalties: time_build.penalties,
3564 })
3565}
3566
3567#[cfg(test)]
3568mod tests {
3569 use super::{
3570 SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3571 baseline_chain_rule_gradient, baseline_offset_theta_partials,
3572 build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3573 build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3574 evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3575 gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3576 marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3577 marginal_slope_baseline_offset_theta_partials,
3578 optimize_survival_baseline_config_with_gradient,
3579 optimize_survival_baseline_config_with_gradient_only,
3580 resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3581 survival_baseline_theta_from_config,
3582 };
3583 use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3584 use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3585 use crate::probability::normal_cdf;
3586 use ndarray::{Array1, Array2, array};
3587
3588 #[test]
3589 fn survival_timewiggle_keeps_requested_order_one_penalty() {
3590 let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3591 let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3592 let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3593 let cfg = LinkWiggleFormulaSpec {
3594 degree: 3,
3595 num_internal_knots: 4,
3596 penalty_orders: vec![1, 2, 3],
3597 double_penalty: false,
3598 };
3599
3600 let build =
3601 build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3602 .expect("build survival timewiggle");
3603
3604 assert_eq!(build.penalties.len(), 3);
3605 assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3606 assert!(build.ncols > 0);
3607 }
3608
3609 #[test]
3610 fn marginal_slope_time_anchor_defaults_to_median_exit() {
3611 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3612 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3613 let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3614 .expect("resolve marginal-slope default time anchor");
3615
3616 assert!(
3617 (anchor - 19.0).abs() <= 1e-12,
3618 "marginal-slope default anchor should be median exit, got {anchor}"
3619 );
3620 }
3621
3622 #[test]
3623 fn marginal_slope_time_anchor_honors_explicit_value() {
3624 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3625 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3626 let anchor =
3627 resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3628 .expect("resolve explicit marginal-slope time anchor");
3629
3630 assert!(
3631 (anchor - 7.5).abs() <= 1e-12,
3632 "explicit marginal-slope anchor should round-trip, got {anchor}"
3633 );
3634 }
3635
3636 #[test]
3647 fn baseline_optimizer_contracts_agree_on_shared_surface() {
3648 let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3653 let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3654
3655 let initial = SurvivalBaselineConfig {
3658 target: SurvivalBaselineTarget::Weibull,
3659 scale: Some(1.0),
3660 shape: Some(1.0),
3661 rate: None,
3662 makeham: None,
3663 };
3664
3665 let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3668 survival_baseline_theta_from_config(cfg)
3669 .expect("config→θ")
3670 .expect("Weibull config has a θ")
3671 };
3672
3673 let curvature_cost = curvature.clone();
3676 let star_cost = theta_star.clone();
3677 let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3678 let theta = survival_baseline_theta_from_config(cfg)?
3679 .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3680 let d = &theta - &star_cost;
3681 let ad = curvature_cost.dot(&d);
3682 Ok(0.5 * d.dot(&ad))
3683 };
3684
3685 let curvature_grad = curvature.clone();
3686 let star_grad = theta_star.clone();
3687 let cost_for_grad = cost_at.clone();
3688 let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3689 &initial,
3690 "baseline parity (gradient-only)",
3691 move |cfg| {
3692 let cost = cost_for_grad(cfg)?;
3693 let theta = survival_baseline_theta_from_config(cfg)?
3694 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3695 let gradient = curvature_grad.dot(&(&theta - &star_grad));
3696 Ok((cost, gradient))
3697 },
3698 )
3699 .expect("gradient-only baseline optimization converges");
3700
3701 let curvature_hess = curvature.clone();
3702 let star_hess = theta_star.clone();
3703 let cost_for_hess = cost_at.clone();
3704 let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3705 &initial,
3706 "baseline parity (gradient+Hessian)",
3707 move |cfg| {
3708 let cost = cost_for_hess(cfg)?;
3709 let theta = survival_baseline_theta_from_config(cfg)?
3710 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3711 let gradient = curvature_hess.dot(&(&theta - &star_hess));
3712 Ok((cost, gradient, curvature_hess.clone()))
3713 },
3714 )
3715 .expect("gradient+Hessian baseline optimization converges");
3716
3717 let theta_grad_only = recovered_theta(&result_grad_only);
3718 let theta_grad_hess = recovered_theta(&result_grad_hess);
3719
3720 for (label, theta) in [
3723 ("gradient-only", &theta_grad_only),
3724 ("gradient+Hessian", &theta_grad_hess),
3725 ] {
3726 let err = (theta - &theta_star)
3727 .mapv(f64::abs)
3728 .fold(0.0_f64, |a, &v| a.max(v));
3729 assert!(
3730 err <= 2e-3,
3731 "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3732 );
3733 }
3734
3735 let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3739 (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3740 };
3741 assert!(
3742 pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3743 "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3744 );
3745 }
3746
/// Automatic knot placement for an I-spline time basis must size the clamped knot
/// vector for the working B-spline degree (`requested_degree + 1`), not the requested one.
#[test]
fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
    // Six subjects all entering at t = 1 with Fibonacci-spaced exits.
    let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
    let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
    let requested_degree = 3;
    let num_internal_knots = 1;

    // An empty knot vector plus `Some((count, lambda))` asks the builder to place knots.
    let config = SurvivalTimeBasisConfig::ISpline {
        degree: requested_degree,
        knots: Array1::zeros(0),
        keep_cols: Vec::new(),
        smooth_lambda: 1e-2,
    };
    let built = build_survival_time_basis(
        &age_entry,
        &age_exit,
        config,
        Some((num_internal_knots, 1e-2)),
    )
    .expect("automatic cubic ispline with one interior knot builds");

    // A clamped vector repeats each boundary knot (working_degree + 1) times.
    let working_degree = requested_degree + 1;
    let expected_knot_count = num_internal_knots + 2 * (working_degree + 1);
    let knots = built.knots.expect("resolved ispline knots");
    assert_eq!(
        knots.len(),
        expected_knot_count,
        "I-spline automatic knots must be clamped for the working B-spline degree"
    );
    // The user-facing degree is still the one that was requested.
    assert_eq!(built.degree, Some(requested_degree));

    // Entry, exit and derivative designs must share one non-empty column space.
    let p_exit = built.x_exit_time.ncols();
    assert!(p_exit > 0);
    assert_eq!(built.x_entry_time.ncols(), p_exit);
    assert_eq!(built.x_derivative_time.ncols(), p_exit);
}
3779
/// When every subject exits exactly at the right boundary knot, the I-spline time
/// derivative design must still be non-degenerate: each row needs a positive entry.
#[test]
fn ispline_time_derivative_is_nonzero_at_right_boundary() {
    let age_entry = array![1.0_f64, 1.0, 1.0];
    let age_exit = array![4.0_f64, 4.0, 4.0];
    // Knots sit at ln(age) over [ln 1, ln 4], with a single interior knot at the midpoint.
    let (left, right) = (1.0_f64.ln(), 4.0_f64.ln());
    let mid = left + 0.5 * (right - left);
    let knots = array![left, left, left, left, mid, right, right, right, right];

    let built = build_survival_time_basis(
        &age_entry,
        &age_exit,
        SurvivalTimeBasisConfig::ISpline {
            degree: 2,
            knots,
            keep_cols: Vec::new(),
            smooth_lambda: 1e-2,
        },
        None,
    )
    .expect("build right-boundary ispline time basis");

    let derivative = built.x_derivative_time.as_dense_cow();
    let max_abs = derivative.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_abs > 1e-8,
        "right-boundary I-spline derivative must use the left-hand endpoint slope"
    );
    for row in derivative.rows() {
        let has_positive = row.iter().any(|&v| v > 1e-8);
        assert!(
            has_positive,
            "each row at the right boundary needs a positive hazard derivative"
        );
    }
}
3815
/// Regression test for gam#979: restricting the I-spline time basis with a `keep_cols`
/// that drops an interior column must still produce a positive-semidefinite penalty.
///
/// The basis is built twice on the same knots. The first build is full width and only
/// measures how many columns exist; the second removes column 1. The reduced penalty's
/// smallest eigenvalue is then compared against a size- and scale-aware round-off tolerance.
#[test]
fn ispline_time_penalty_is_psd_under_nontrivial_keep_cols() {
    let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
    let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
    // Knots at ln(age): quartiles of [ln 1, ln 21] plus four-fold boundary knots.
    let left = 1.0_f64.ln();
    let right = 21.0_f64.ln();
    let q1 = left + 0.25 * (right - left);
    let mid = left + 0.5 * (right - left);
    let q3 = left + 0.75 * (right - left);
    let knots = array![
        left, left, left, left, q1, mid, q3, right, right, right, right
    ];

    // Full-width build, used only to learn the number of selectable columns.
    let full = build_survival_time_basis(
        &age_entry,
        &age_exit,
        SurvivalTimeBasisConfig::ISpline {
            degree: 2,
            knots: knots.clone(),
            keep_cols: Vec::new(),
            smooth_lambda: 1e-2,
        },
        None,
    )
    .expect("build full-width ispline time basis");
    // Prefer the builder's resolved keep_cols; fall back to the raw design width.
    let p_time_full = full
        .keep_cols
        .as_ref()
        .map(|k| k.len())
        .unwrap_or_else(|| full.x_exit_time.ncols());
    assert!(
        p_time_full >= 3,
        "test needs at least 3 shape-varying columns to drop an interior one; got {p_time_full}"
    );

    // Drop interior column 1 so the selection is a non-trivial sub-block.
    let keep_cols: Vec<usize> = (0..p_time_full).filter(|&j| j != 1).collect();

    let built = build_survival_time_basis(
        &age_entry,
        &age_exit,
        SurvivalTimeBasisConfig::ISpline {
            degree: 2,
            knots,
            keep_cols: keep_cols.clone(),
            smooth_lambda: 1e-2,
        },
        None,
    )
    .expect(
        "reduced ispline penalty must build (PSD contract must accept the \
         congruence-first / select-second ordering)",
    );

    assert_eq!(
        built.penalties.len(),
        1,
        "the ispline time basis should carry exactly one curvature penalty"
    );
    // The penalty must be square in the reduced column space.
    let s = &built.penalties[0];
    assert_eq!(s.nrows(), keep_cols.len());
    assert_eq!(s.ncols(), keep_cols.len());

    let (evals, _) =
        gam_linalg::faer_ndarray::FaerEigh::eigh(s, faer::Side::Lower).expect("eigh of penalty");
    let evals_slice = evals.as_slice().expect("contiguous eigenvalues");
    // Spectral magnitude, floored at 1 so tiny penalties still get a meaningful tolerance.
    let max_abs = evals_slice
        .iter()
        .copied()
        .fold(0.0_f64, |a, b| a.max(b.abs()))
        .max(1.0);
    let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
    // Allow eigenvalues slightly below zero, by 100·n·ε·max|λ| (round-off slack).
    let tol = -100.0 * (s.nrows() as f64) * f64::EPSILON * max_abs;
    assert!(
        min_ev >= tol,
        "reduced I-spline time penalty must be PSD (gam#979): min eigenvalue \
         {min_ev:.3e} < tol {tol:.3e}, max|eig| {max_abs:.3e}"
    );
}
3918
/// Checks that the marginal-slope (probit) baseline reproduces the Gompertz–Makeham
/// survival curve through `S(t) = Φ(-q(t))` and that the reported slope `q'(t)` is correct.
///
/// The slope is checked two ways:
/// - against the closed form from differentiating `S = Φ(-q)`: `-φ(q)·q' = S' = -h·S`,
///   so `q' = h·S / φ(q)`;
/// - against a central finite difference of `q` in age.
#[test]
fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.07),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    let age = 11.5;
    let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
        .expect("evaluate marginal-slope gompertz-makeham baseline");
    let shape = cfg.shape.expect("shape");
    let rate = cfg.rate.expect("rate");
    let makeham = cfg.makeham.expect("makeham");
    // Closed-form Gompertz–Makeham cumulative hazard H(t) and hazard h(t).
    let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
    let instant_hazard = makeham + rate * (shape * age).exp();
    let expected_survival = (-cumulative_hazard).exp();
    let actual_survival = normal_cdf(-q);
    assert!((actual_survival - expected_survival).abs() <= 1e-12);

    // `instant_hazard > 0` holds by construction and never exercises `q_derivative`
    // on its own. Tie the reported slope to the hazard via q' = h·S/φ(q).
    assert!(instant_hazard > 0.0);
    let standard_normal_density_at_q =
        (-0.5 * q * q).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let expected_q_derivative =
        instant_hazard * expected_survival / standard_normal_density_at_q;
    assert!(
        (q_derivative - expected_q_derivative).abs() <= 1e-8,
        "q' = {q_derivative:e} disagrees with h·S/φ(q) = {expected_q_derivative:e}"
    );

    // Independent check: central difference of q in age.
    let h = 1e-5;
    let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
        .expect("q plus")
        .0;
    let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
        .expect("q minus")
        .0;
    let fd = (q_plus - q_minus) / (2.0 * h);
    assert!((q_derivative - fd).abs() <= 1e-7);
}
3951
/// Every baseline family must be evaluable at age 0 (the start of the survival curve).
///
/// For each family this checks three things:
/// - the probit marginal-slope baseline returns exactly `(0.0, 0.0)`;
/// - the log-cumulative-hazard baseline returns a finite slope and either a finite value
///   or −∞;
/// - the entry/exit/derivative offset builder produces only finite values when rows
///   start (and, for row 0, also end) at the origin.
#[test]
fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
    let configs = [
        SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Linear,
            scale: None,
            shape: None,
            rate: None,
            makeham: None,
        },
        SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Weibull,
            scale: Some(2.5),
            shape: Some(1.3),
            rate: None,
            makeham: None,
        },
        SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Gompertz,
            scale: None,
            shape: Some(0.05),
            rate: Some(0.01),
            makeham: None,
        },
        SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(0.07),
            rate: Some(0.012),
            makeham: Some(0.003),
        },
    ];
    for cfg in &configs {
        // NOTE(review): with S = Φ(-q), S(0) = 1 would imply q = -∞. These asserts pin a
        // finite (0, 0) stand-in at the origin instead; confirm this sentinel is intended.
        let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
            .expect("marginal-slope baseline must be evaluable at the origin");
        assert_eq!(q0, 0.0);
        assert_eq!(q0_derivative, 0.0);

        // The log-cumulative-hazard may be −∞ at the origin (e.g. when H(0) = 0),
        // but its slope channel must stay finite.
        let (eta0, eta0_derivative) =
            evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
        assert!(eta0_derivative.is_finite());
        assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);

        // Row 0 enters and exits at 0; row 1 enters at 0 and exits later.
        let age_entry = array![0.0, 0.0];
        let age_exit = array![0.0, 1.5];
        let (entry, exit, derivative) =
            build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
                .expect("probit baseline offsets must build through the origin");
        assert!(entry.iter().all(|v| v.is_finite()));
        assert!(exit.iter().all(|v| v.is_finite()));
        assert!(derivative.iter().all(|v| v.is_finite()));
        assert_eq!(exit[0], 0.0);
    }
}
4023
/// The probit baseline offsets must encode the exact Gompertz–Makeham survival at both
/// entry and exit (`Φ(-offset) = exp(-H(t))`), and the derivative channel must be positive.
#[test]
fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.03),
        rate: Some(0.01),
        makeham: Some(0.002),
    };
    let age_entry = array![2.0, 4.0];
    let age_exit = array![5.0, 9.0];
    let (entry, exit, derivative) =
        build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
            .expect("marginal-slope baseline offsets");

    // Closed-form cumulative hazard H(t) = makeham·t + (rate/shape)·(e^{shape·t} − 1).
    let makeham = cfg.makeham.expect("makeham");
    let rate = cfg.rate.expect("rate");
    let shape = cfg.shape.expect("shape");
    let cumulative_hazard = |t: f64| makeham * t + (rate / shape) * ((shape * t).exp() - 1.0);

    for (i, (&t_entry, &t_exit)) in age_entry.iter().zip(age_exit.iter()).enumerate() {
        let survival_at_entry = (-cumulative_hazard(t_entry)).exp();
        let survival_at_exit = (-cumulative_hazard(t_exit)).exp();
        assert!((normal_cdf(-entry[i]) - survival_at_entry).abs() <= 1e-12);
        assert!((normal_cdf(-exit[i]) - survival_at_exit).abs() <= 1e-12);
        assert!(derivative[i].is_finite() && derivative[i] > 0.0);
    }
}
4050
4051 fn fd_marginal_slope_baseline_offset(
4052 age: f64,
4053 cfg: &SurvivalBaselineConfig,
4054 steps: &[f64],
4055 ) -> Vec<(f64, f64)> {
4056 let theta = survival_baseline_theta_from_config(cfg)
4057 .expect("theta")
4058 .expect("non-linear baseline");
4059 assert_eq!(
4060 steps.len(),
4061 theta.len(),
4062 "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
4063 );
4064 (0..theta.len())
4065 .map(|k| {
4066 let h = steps[k];
4067 let mut theta_plus = theta.clone();
4068 theta_plus[k] += h;
4069 let mut theta_minus = theta.clone();
4070 theta_minus[k] -= h;
4071 let cfg_plus =
4072 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4073 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4074 .expect("minus cfg");
4075 let (q_p, qt_p) =
4076 evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
4077 let (q_m, qt_m) =
4078 evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
4079 ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
4080 })
4081 .collect()
4082 }
4083
/// Analytic θ-partials of the Gompertz–Makeham probit offset `(q, q')` must agree with
/// central finite differences at a generic interior point.
#[test]
fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.04),
        rate: Some(0.013),
        makeham: Some(0.002),
    };
    let age = 17.0;
    let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
        .expect("partials")
        .expect("nonlinear");
    // The same step size is used for all three θ coordinates.
    let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5; 3]);
    assert_eq!(analytic.len(), fd.len());
    for (k, (&(aq, aqt), &(fq, fqt))) in analytic.iter().zip(&fd).enumerate() {
        assert_close(aq, fq, 1e-6, &format!("gm-probit q theta[{k}]"));
        assert_close(aqt, fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
    }
}
4104
/// Same FD check as the generic Gompertz–Makeham case, but with a near-zero Gompertz
/// shape (1e-14), where the analytic partials must remain accurate.
#[test]
fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(1e-14),
        rate: Some(0.013),
        makeham: Some(0.002),
    };
    let age = 17.0;
    let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
        .expect("partials")
        .expect("nonlinear");
    // θ[1] gets a much smaller step. Presumably θ[1] is the raw shape, so the
    // perturbation stays near zero (TODO confirm the θ layout).
    let steps = [1e-5, 1e-11, 1e-5];
    let fd = fd_marginal_slope_baseline_offset(age, &cfg, &steps);
    assert_eq!(analytic.len(), fd.len());
    for (k, (&(aq, aqt), &(fq, fqt))) in analytic.iter().zip(&fd).enumerate() {
        assert_close(aq, fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
        assert_close(aqt, fqt, 1e-5, &format!("near-zero gm-probit q' theta[{k}]"));
    }
}
4130
4131 fn shifted_quadratic_offset_residuals(
4132 age_entry: ndarray::ArrayView1<'_, f64>,
4133 age_exit: ndarray::ArrayView1<'_, f64>,
4134 base_cfg: &SurvivalBaselineConfig,
4135 candidate_cfg: &SurvivalBaselineConfig,
4136 base: &OffsetChannelResiduals,
4137 curvatures: &OffsetChannelCurvatures,
4138 ) -> OffsetChannelResiduals {
4139 let n = age_exit.len();
4140 let mut entry = base.entry.clone();
4141 let mut exit = base.exit.clone();
4142 let mut derivative = base.derivative.clone();
4143 for row in 0..n {
4144 let (_, base_exit, base_deriv) =
4145 baseline_marginal_slope_channels(age_exit[row], base_cfg);
4146 let (_, cand_exit, cand_deriv) =
4147 baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
4148 let base_entry = if base.entry[row] == 0.0 {
4149 0.0
4150 } else {
4151 baseline_marginal_slope_channels(age_entry[row], base_cfg).1
4152 };
4153 let cand_entry = if base.entry[row] == 0.0 {
4154 0.0
4155 } else {
4156 baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
4157 };
4158 let delta = [
4159 cand_entry - base_entry,
4160 cand_exit - base_exit,
4161 cand_deriv - base_deriv,
4162 ];
4163 let mut shift = [0.0; 3];
4164 for i in 0..3 {
4165 for j in 0..3 {
4166 shift[i] += curvatures.rows[row][i][j] * delta[j];
4167 }
4168 }
4169 if base.entry[row] != 0.0 {
4170 entry[row] += shift[0];
4171 }
4172 exit[row] += shift[1];
4173 derivative[row] += shift[2];
4174 }
4175 OffsetChannelResiduals {
4176 entry,
4177 exit,
4178 derivative,
4179 right: base.right.clone(),
4180 }
4181 }
4182
4183 fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4184 let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4185 (q, q, q_t)
4186 }
4187
/// The analytic θ-Hessian of the probit baseline chain rule must match central finite
/// differences of the chain-rule gradient.
///
/// Residuals are not held fixed. They are modelled as moving linearly with the channel
/// values, `r(θ) = r₀ + C·(channels(θ) − channels(θ₀))`, via
/// `shifted_quadratic_offset_residuals`. Differentiating the gradient therefore exercises
/// both the curvature (`C`) term and the residual-weighted second-derivative term.
#[test]
fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.025),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    let theta = survival_baseline_theta_from_config(&cfg)
        .expect("theta")
        .expect("nonlinear");
    // Row 1 enters at age 0 with a zero entry residual, so it has no entry channel.
    let age_entry = array![2.5, 0.0, 5.0];
    let age_exit = array![7.5, 11.0, 15.0];
    let base_residuals = OffsetChannelResiduals {
        entry: array![0.2, 0.0, -0.1],
        exit: array![0.6, -0.3, 0.4],
        derivative: array![-0.5, 0.25, 0.15],
        right: Array1::<f64>::zeros(3),
    };
    // One symmetric 3×3 curvature per row, over the (entry, exit, derivative) channels.
    let curvatures = OffsetChannelCurvatures {
        rows: vec![
            [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
            [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
            [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
        ],
    };
    let analytic = marginal_slope_baseline_chain_rule_hessian(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &base_residuals,
        &curvatures,
    )
    .expect("hessian")
    .expect("nonlinear");

    // Chain-rule gradient at a perturbed θ, with residuals shifted consistently.
    let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
        let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
            .expect("candidate cfg");
        let residuals = shifted_quadratic_offset_residuals(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &candidate,
            &base_residuals,
            &curvatures,
        );
        marginal_slope_baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            &candidate,
            &residuals,
        )
        .expect("gradient")
        .expect("nonlinear")
    };

    // Column j of the Hessian = central difference of the gradient along θ_j.
    for j in 0..theta.len() {
        // θ[1] uses a slightly larger step (2e-5); presumably tuned for that coordinate's
        // scale. TODO confirm.
        let step = if j == 1 { 2e-5 } else { 1e-5 };
        let mut plus = theta.clone();
        plus[j] += step;
        let mut minus = theta.clone();
        minus[j] -= step;
        let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
        for i in 0..theta.len() {
            assert_close(
                analytic[[i, j]],
                fd_col[i],
                2e-5,
                &format!("baseline Hessian ({i},{j})"),
            );
        }
    }
}
4263
/// The probit chain-rule gradient must equal the residual-weighted contraction of the
/// per-age θ-partials. For each coordinate k, summed over rows, the expected value is
/// `r_exit·∂q(exit)/∂θ_k + r_deriv·∂q'(exit)/∂θ_k + r_entry·∂q(entry)/∂θ_k`.
/// The comparison uses a 1e-12 tolerance, so the reference sum order below matters.
#[test]
fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.03),
        rate: Some(0.01),
        makeham: Some(0.002),
    };
    let age_entry = array![3.0, 6.0];
    let age_exit = array![8.0, 12.0];
    // All entry residuals are non-zero, so every row contributes an entry term.
    let residuals = OffsetChannelResiduals {
        exit: array![0.7, -0.2],
        entry: array![0.1, 0.4],
        derivative: array![1.3, -0.6],
        right: Array1::<f64>::zeros(2),
    };
    let grad = marginal_slope_baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("gradient")
    .expect("nonlinear");

    // Hand-built reference contraction over the three Gompertz–Makeham θ coordinates.
    let mut expected = Array1::<f64>::zeros(3);
    for i in 0..age_exit.len() {
        let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
            .expect("exit partials")
            .expect("nonlinear");
        let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
            .expect("entry partials")
            .expect("nonlinear");
        for k in 0..3 {
            // `.0` is ∂q/∂θ_k; `.1` is ∂q'/∂θ_k.
            expected[k] += residuals.exit[i] * exit_partials[k].0
                + residuals.derivative[i] * exit_partials[k].1
                + residuals.entry[i] * entry_partials[k].0;
        }
    }
    for k in 0..3 {
        assert_close(
            grad[k],
            expected[k],
            1e-12,
            &format!("gm-probit chain gradient theta[{k}]"),
        );
    }
}
4313
/// Both chain-rule gradient engines must reproduce a straightforward inline reference
/// contraction exactly (tolerance 0):
/// - `baseline_chain_rule_gradient`, the log-cumulative-hazard engine labelled "rp" here;
/// - `marginal_slope_baseline_chain_rule_gradient`, the probit engine.
///
/// The reference is parameterised by the per-age θ-partials function. The same
/// accumulation code therefore checks both engines against their own partials.
#[test]
fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.028),
        rate: Some(0.011),
        makeham: Some(0.0025),
    };
    // Row 1 has a zero entry residual, which exercises the skipped-entry path.
    let age_entry = array![3.0, 0.0, 5.5];
    let age_exit = array![8.0, 12.0, 16.0];
    let residuals = OffsetChannelResiduals {
        exit: array![0.7, -0.2, 0.45],
        entry: array![0.1, 0.0, -0.3],
        derivative: array![1.3, -0.6, 0.2],
        right: Array1::<f64>::zeros(3),
    };

    // For each row, accumulates r_exit·∂v/∂θ + r_deriv·∂v'/∂θ at the exit age. It then adds
    // r_entry·∂v/∂θ at the entry age, but only when the entry residual is non-zero.
    let reference_gradient = |partials: &dyn Fn(
        f64,
        &SurvivalBaselineConfig,
    )
        -> Result<Option<Vec<(f64, f64)>>, String>|
     -> Array1<f64> {
        // θ dimension is probed from the first exit age.
        let theta_dim = partials(age_exit[0], &cfg)
            .expect("probe partials")
            .expect("nonlinear")
            .len();
        let mut acc = Array1::<f64>::zeros(theta_dim);
        for i in 0..age_exit.len() {
            let p_exit = partials(age_exit[i], &cfg)
                .expect("exit partials")
                .expect("nonlinear");
            let r_x = residuals.exit[i];
            let r_d = residuals.derivative[i];
            for k in 0..theta_dim {
                acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
            }
            let r_e = residuals.entry[i];
            if r_e != 0.0 {
                let p_entry = partials(age_entry[i], &cfg)
                    .expect("entry partials")
                    .expect("nonlinear");
                for k in 0..theta_dim {
                    acc[k] += r_e * p_entry[k].0;
                }
            }
        }
        acc
    };

    // Log-cumulative-hazard engine; exit ages are also passed as the third age view.
    let rp_engine = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("rp gradient")
    .expect("rp nonlinear");
    let rp_reference = reference_gradient(&baseline_offset_theta_partials);
    assert_eq!(rp_engine.len(), rp_reference.len());
    for k in 0..rp_engine.len() {
        // tol = 0.0: the engine must be bit-for-bit equal to the reference sum.
        assert_close(
            rp_engine[k],
            rp_reference[k],
            0.0,
            &format!("rp engine vs inline reference theta[{k}]"),
        );
    }

    // Probit marginal-slope engine.
    let probit_engine = marginal_slope_baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("probit gradient")
    .expect("probit nonlinear");
    let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
    assert_eq!(probit_engine.len(), probit_reference.len());
    for k in 0..probit_engine.len() {
        assert_close(
            probit_engine[k],
            probit_reference[k],
            0.0,
            &format!("probit engine vs inline reference theta[{k}]"),
        );
    }
}
4419
4420 #[test]
4441 fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4442 let cfg = SurvivalBaselineConfig {
4443 target: SurvivalBaselineTarget::GompertzMakeham,
4444 scale: None,
4445 shape: Some(0.05),
4446 rate: Some(0.012),
4447 makeham: Some(0.003),
4448 };
4449 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4451 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4452 let residuals = OffsetChannelResiduals {
4455 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4456 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4457 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4458 right: Array1::<f64>::zeros(8),
4459 };
4460
4461 let analytic = baseline_chain_rule_gradient(
4462 age_entry.view(),
4463 age_exit.view(),
4464 age_exit.view(),
4465 &cfg,
4466 &residuals,
4467 )
4468 .expect("analytic gradient ok")
4469 .expect("GM baseline has a θ-gradient");
4470 assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4471
4472 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4478 let mut acc = 0.0;
4479 for i in 0..age_exit.len() {
4480 let (eta_exit_i, od_exit_i) =
4481 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4482 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4483 if residuals.entry[i] != 0.0 {
4484 let (eta_entry_i, _) =
4485 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4486 acc += residuals.entry[i] * eta_entry_i;
4487 }
4488 }
4489 acc
4490 };
4491
4492 let theta0 = survival_baseline_theta_from_config(&cfg)
4493 .expect("theta seed")
4494 .expect("GM has θ");
4495 let delta = 1e-4;
4497 let mut fd = Array1::<f64>::zeros(analytic.len());
4498 for k in 0..analytic.len() {
4499 let mut theta_plus = theta0.clone();
4500 theta_plus[k] += delta;
4501 let mut theta_minus = theta0.clone();
4502 theta_minus[k] -= delta;
4503 let cfg_plus =
4504 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4505 let cfg_minus =
4506 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4507 let lp = loss_at_cfg(&cfg_plus);
4508 let lm = loss_at_cfg(&cfg_minus);
4509 fd[k] = (lp - lm) / (2.0 * delta);
4510 }
4511
4512 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4513 let max_err = analytic
4514 .iter()
4515 .zip(fd.iter())
4516 .map(|(a, b)| (a - b).abs())
4517 .fold(0.0_f64, f64::max);
4518 let rel = max_err / (analytic_norm + 1e-12);
4519 eprintln!(
4521 "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4522 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4523 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4524 );
4525 assert!(
4526 rel < 1e-2,
4527 "analytic θ-gradient disagrees with central FD beyond 1%: \
4528 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4529 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4530 );
4531 }
4532
4533 #[test]
4548 fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4549 let cfg = SurvivalBaselineConfig {
4550 target: SurvivalBaselineTarget::Weibull,
4551 scale: Some(11.0),
4552 shape: Some(1.4),
4553 rate: None,
4554 makeham: None,
4555 };
4556 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4557 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4558 let residuals = OffsetChannelResiduals {
4559 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4560 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4561 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4562 right: Array1::<f64>::zeros(8),
4563 };
4564
4565 let analytic = baseline_chain_rule_gradient(
4566 age_entry.view(),
4567 age_exit.view(),
4568 age_exit.view(),
4569 &cfg,
4570 &residuals,
4571 )
4572 .expect("analytic gradient ok")
4573 .expect("Weibull baseline has a θ-gradient");
4574 assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4575
4576 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4577 let mut acc = 0.0;
4578 for i in 0..age_exit.len() {
4579 let (eta_exit_i, od_exit_i) =
4580 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4581 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4582 if residuals.entry[i] != 0.0 {
4583 let (eta_entry_i, _) =
4584 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4585 acc += residuals.entry[i] * eta_entry_i;
4586 }
4587 }
4588 acc
4589 };
4590
4591 let theta0 = survival_baseline_theta_from_config(&cfg)
4592 .expect("theta seed")
4593 .expect("Weibull has θ");
4594 let delta = 1e-4;
4595 let mut fd = Array1::<f64>::zeros(analytic.len());
4596 for k in 0..analytic.len() {
4597 let mut theta_plus = theta0.clone();
4598 theta_plus[k] += delta;
4599 let mut theta_minus = theta0.clone();
4600 theta_minus[k] -= delta;
4601 let cfg_plus =
4602 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4603 let cfg_minus =
4604 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4605 let lp = loss_at_cfg(&cfg_plus);
4606 let lm = loss_at_cfg(&cfg_minus);
4607 fd[k] = (lp - lm) / (2.0 * delta);
4608 }
4609
4610 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4611 let max_err = analytic
4612 .iter()
4613 .zip(fd.iter())
4614 .map(|(a, b)| (a - b).abs())
4615 .fold(0.0_f64, f64::max);
4616 let rel = max_err / (analytic_norm + 1e-12);
4617 eprintln!(
4618 "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4619 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4620 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4621 );
4622 assert!(
4623 rel < 1e-2,
4624 "analytic θ-gradient disagrees with central FD beyond 1%: \
4625 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4626 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4627 );
4628 }
4629
4630 fn fd_baseline_offset(
4643 age: f64,
4644 cfg: &SurvivalBaselineConfig,
4645 steps: &[f64],
4646 ) -> Vec<(f64, f64)> {
4647 let theta = survival_baseline_theta_from_config(cfg)
4648 .expect("theta")
4649 .expect("non-linear baseline");
4650 assert_eq!(
4651 steps.len(),
4652 theta.len(),
4653 "fd_baseline_offset: step vector length must match θ dimension"
4654 );
4655 (0..theta.len())
4656 .map(|k| {
4657 let h = steps[k];
4658 let mut theta_plus = theta.clone();
4659 theta_plus[k] += h;
4660 let mut theta_minus = theta.clone();
4661 theta_minus[k] -= h;
4662 let cfg_plus =
4663 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4664 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4665 .expect("minus cfg");
4666 let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4667 let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4668 ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4669 })
4670 .collect()
4671 }
4672
/// Panics unless `actual` is within `tol` of `expected`.
///
/// The allowed error is `tol · max(|expected|, 1)`. That is absolute when `|expected| < 1`
/// and relative to `|expected|` otherwise.
/// - A NaN in `actual`, `expected` or `tol` always fails, because NaN comparisons are false.
/// - `tol = 0.0` demands exact equality of finite values.
///
/// `what` labels the panic message.
fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
    let scale = expected.abs().max(1.0);
    let within = (actual - expected).abs() <= tol * scale;
    assert!(
        within,
        "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
    );
}
4687
/// Analytic Gompertz θ-partials of `(η, o_D)` must match central finite differences
/// across a spread of `(rate, shape, age)` cases.
///
/// The cases include tiny positive and negative shapes (±5e-11), presumably on either
/// side of a small-shape expansion, as well as ordinary negative and positive shapes.
/// θ[0] is the log-rate coordinate and θ[1] the shape coordinate, per the message labels.
#[test]
fn gompertz_offset_partials_match_central_diff() {
    // (rate, shape, age)
    let cases = [
        (0.5_f64, 0.01_f64, 30.0_f64),
        (0.2, 0.05, 60.0),
        (1.0, 0.001, 10.0),
        (0.4, 5e-11, 25.0),
        (0.4, -5e-11, 25.0),
        (0.3, -0.02, 40.0),
        (0.8, 0.2, 5.0),
    ];
    for &(rate, shape, age) in &cases {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Gompertz,
            scale: None,
            shape: Some(shape),
            rate: Some(rate),
            makeham: None,
        };
        let analytic = baseline_offset_theta_partials(age, &cfg)
            .expect("ok")
            .expect("non-linear");
        // For near-zero shapes, shrink the shape step so the perturbed shape stays tiny.
        let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
        let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
        assert_eq!(analytic.len(), 2);
        // Log-rate partials: tight tolerance.
        assert_close(
            analytic[0].0,
            fd[0].0,
            1e-7,
            &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
        );
        assert_close(
            analytic[0].1,
            fd[0].1,
            1e-7,
            &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
        );
        // Shape partials: looser tolerance, which absorbs FD round-off at tiny shape steps.
        assert_close(
            analytic[1].0,
            fd[1].0,
            1e-5,
            &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
        );
        assert_close(
            analytic[1].1,
            fd[1].1,
            1e-5,
            &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
        );
    }
}
4750
/// The Gompertz log-rate partials are exact constants: `∂η/∂log_rate == 1` and
/// `∂o_D/∂log_rate == 0`. The cumulative hazard is proportional to the rate, so its log
/// shifts by log_rate while hazard ratios are unchanged.
#[test]
fn gompertz_offset_partials_log_rate_channel_is_trivial() {
    let gompertz = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.3),
        makeham: None,
    };
    let partials = baseline_offset_theta_partials(42.0, &gompertz)
        .expect("ok")
        .expect("non-linear");
    let (d_eta_d_log_rate, d_od_d_log_rate) = partials[0];
    assert_eq!(d_eta_d_log_rate, 1.0);
    assert_eq!(d_od_d_log_rate, 0.0);
}
4769
/// Shape partials just below and just above a small-shape switch point must both hit the
/// shape → 0 limits. Per the config names, 0.5e-10 takes the Taylor path and 2e-10 the
/// direct path. The asserted limits are `∂η/∂shape → 12.5 (= age/2)` and `∂o_D/∂shape → 0.5`.
#[test]
fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
    let age = 25.0;
    let rate = 0.4;
    let gompertz_with_shape = |shape: f64| SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(shape),
        rate: Some(rate),
        makeham: None,
    };
    let partials_at = |shape: f64| {
        baseline_offset_theta_partials(age, &gompertz_with_shape(shape))
            .expect("ok")
            .expect("nl")
    };
    let p_t = partials_at(0.5e-10);
    let p_d = partials_at(2.0e-10);
    assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
    assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
    assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
    assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
}
4807
/// First and second shape derivatives of the Gompertz cumulative hazard `H_G` and
/// hazard `h_G` must match central finite differences.
///
/// Per the message labels, `gompertz_cumulative_shape_derivative` returns
/// `(∂H_G/∂shape, ∂h_G/∂shape)`.
/// - The first derivatives are checked by differencing `gompertz_hazard_components`.
/// - The second derivatives are checked by differencing the first-derivative function.
#[test]
fn gompertz_hazard_shape_derivatives_match_central_diff() {
    // (age, rate, shape)
    let cases = [
        (10.0_f64, 0.012_f64, 0.05_f64),
        (2.5, 0.5, 0.2),
        (15.0, 0.003, 0.01),
        (40.0, 0.3, 0.001),
    ];
    let h = 1e-6;
    for &(age, rate, shape) in &cases {
        // First derivatives vs FD of (H_G, h_G).
        let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
        let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
        let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
        assert_close(
            d_cum,
            (cum_p - cum_m) / (2.0 * h),
            1e-6,
            &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
        );
        assert_close(
            d_inst,
            (inst_p - inst_m) / (2.0 * h),
            1e-6,
            &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
        );

        // Second derivatives vs FD of the analytic first derivatives (looser tolerance).
        let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
        let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
        assert_close(
            d2_cum,
            (dcum_p - dcum_m) / (2.0 * h),
            1e-5,
            &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
        );
        assert_close(
            d2_inst,
            (dinst_p - dinst_m) / (2.0 * h),
            1e-5,
            &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
        );
    }
}
4868
4869 #[test]
4870 fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4871 let cases = [
4883 (25.0_f64, 0.4_f64, 1e-9_f64),
4884 (100.0, 0.4, 1e-6), (100.0, 0.012, 1e-6), (50.0, 1.2, 1e-8),
4887 ];
4888 for &(age, rate, shape) in &cases {
4899 let t = age;
4900 let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4901 assert_close(
4902 d_cum,
4903 rate * t * t / 2.0,
4904 1e-3,
4905 &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4906 );
4907 assert_close(
4908 d_inst,
4909 rate * t,
4910 1e-3,
4911 &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4912 );
4913
4914 let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4915 assert_close(
4916 d2_cum,
4917 rate * t * t * t / 3.0,
4918 1e-3,
4919 &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4920 );
4921 assert_close(
4922 d2_inst,
4923 rate * t * t,
4924 1e-3,
4925 &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4926 );
4927 }
4928 }
4929
4930 #[test]
4931 fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4932 let age = 100.0;
4939 let rate = 0.4;
4940 let t = age;
4941 let truth = rate * t * t * t / 3.0; for k in 5..=12 {
4948 let shape = 10f64.powi(-(k as i32)); let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4950 assert_close(
4951 d2_cum,
4952 truth,
4953 1e-3,
4954 &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
4955 );
4956 }
4957 }
4958
4959 #[test]
4960 fn weibull_offset_partials_match_central_diff() {
4961 let cases = [
4962 (0.5_f64, 1.2_f64, 25.0_f64),
4963 (2.0, 0.8, 60.0),
4964 (0.1, 3.0, 10.0),
4965 ];
4966 for &(scale, shape, age) in &cases {
4967 let cfg = SurvivalBaselineConfig {
4968 target: SurvivalBaselineTarget::Weibull,
4969 scale: Some(scale),
4970 shape: Some(shape),
4971 rate: None,
4972 makeham: None,
4973 };
4974 let analytic = baseline_offset_theta_partials(age, &cfg)
4975 .expect("ok")
4976 .expect("nl");
4977 let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
4978 assert_eq!(analytic.len(), 2);
4979 for k in 0..2 {
4980 assert_close(
4981 analytic[k].0,
4982 fd[k].0,
4983 1e-7,
4984 &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
4985 );
4986 assert_close(
4987 analytic[k].1,
4988 fd[k].1,
4989 1e-7,
4990 &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
4991 );
4992 }
4993 assert_eq!(analytic[0].1, 0.0);
4995 }
4996 }
4997
4998 #[test]
4999 fn gompertz_makeham_offset_partials_match_central_diff() {
5000 let cases = [
5001 (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
5002 (0.5, 0.01, 0.01, 25.0),
5003 (0.2, 0.001, 0.005, 60.0),
5004 (0.4, 5e-11, 0.01, 25.0),
5005 (0.4, -5e-11, 0.01, 25.0),
5006 (0.8, 0.2, 0.05, 5.0),
5007 ];
5008 for &(rate, shape, makeham, age) in &cases {
5009 let cfg = SurvivalBaselineConfig {
5010 target: SurvivalBaselineTarget::GompertzMakeham,
5011 scale: None,
5012 shape: Some(shape),
5013 rate: Some(rate),
5014 makeham: Some(makeham),
5015 };
5016 let analytic = baseline_offset_theta_partials(age, &cfg)
5017 .expect("ok")
5018 .expect("nl");
5019 let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
5023 let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
5024 assert_eq!(analytic.len(), 3);
5025 for k in 0..3 {
5026 assert_close(
5027 analytic[k].0,
5028 fd[k].0,
5029 1e-5,
5030 &format!(
5031 "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5032 ),
5033 );
5034 assert_close(
5035 analytic[k].1,
5036 fd[k].1,
5037 1e-5,
5038 &format!(
5039 "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5040 ),
5041 );
5042 }
5043 }
5044 }
5045
5046 #[test]
5047 fn linear_baseline_has_no_theta_partials() {
5048 let cfg = SurvivalBaselineConfig {
5049 target: SurvivalBaselineTarget::Linear,
5050 scale: None,
5051 shape: None,
5052 rate: None,
5053 makeham: None,
5054 };
5055 assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
5056 }
5057
5058 #[test]
5059 fn baseline_offset_partials_reject_non_positive_ages() {
5060 let cfg = SurvivalBaselineConfig {
5061 target: SurvivalBaselineTarget::Gompertz,
5062 scale: None,
5063 shape: Some(0.01),
5064 rate: Some(0.5),
5065 makeham: None,
5066 };
5067 assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
5068 assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
5069 assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
5070 }
5071
5072 #[test]
5078 fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
5079 let cfg = SurvivalBaselineConfig {
5080 target: SurvivalBaselineTarget::Gompertz,
5081 scale: None,
5082 shape: Some(0.05),
5083 rate: Some(0.3),
5084 makeham: None,
5085 };
5086 let age_entry = array![10.0_f64];
5087 let age_exit = array![25.0_f64];
5088 let residuals = OffsetChannelResiduals {
5089 exit: array![0.7_f64],
5090 entry: array![-0.2_f64],
5091 derivative: array![-0.4_f64],
5092 right: Array1::<f64>::zeros(1),
5093 };
5094 let grad = baseline_chain_rule_gradient(
5095 age_entry.view(),
5096 age_exit.view(),
5097 age_exit.view(),
5098 &cfg,
5099 &residuals,
5100 )
5101 .expect("ok")
5102 .expect("non-linear");
5103 let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5105 .unwrap()
5106 .unwrap();
5107 let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5108 .unwrap()
5109 .unwrap();
5110 for k in 0..p_exit.len() {
5111 let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5112 assert!(
5113 (grad[k] - expected).abs() < 1e-12,
5114 "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5115 grad[k],
5116 expected
5117 );
5118 }
5119 }
5120
5121 #[test]
5124 fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5125 let cfg = SurvivalBaselineConfig {
5126 target: SurvivalBaselineTarget::Gompertz,
5127 scale: None,
5128 shape: Some(0.05),
5129 rate: Some(0.3),
5130 makeham: None,
5131 };
5132 let age_entry = array![0.0_f64, 5.0_f64];
5133 let age_exit = array![10.0_f64, 20.0_f64];
5134 let residuals = OffsetChannelResiduals {
5135 exit: array![0.5_f64, 0.3_f64],
5136 entry: array![0.0_f64, -0.1_f64], derivative: array![-0.2_f64, 0.0_f64],
5138 right: Array1::<f64>::zeros(2),
5139 };
5140 let grad = baseline_chain_rule_gradient(
5142 age_entry.view(),
5143 age_exit.view(),
5144 age_exit.view(),
5145 &cfg,
5146 &residuals,
5147 )
5148 .expect("must not fail on origin-entry row with r_entry=0")
5149 .expect("non-linear");
5150 assert_eq!(grad.len(), 2);
5151 let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5153 let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5154 let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5155 for k in 0..2 {
5156 let expected = 0.5 * p_exit_0[k].0
5157 + (-0.2) * p_exit_0[k].1
5158 + 0.3 * p_exit_1[k].0
5159 + (-0.1) * p_entry_1[k].0;
5160 assert!(
5161 (grad[k] - expected).abs() < 1e-12,
5162 "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5163 grad[k],
5164 expected
5165 );
5166 }
5167 }
5168
5169 #[test]
5171 fn chain_rule_gradient_linear_target_returns_none() {
5172 let cfg = SurvivalBaselineConfig {
5173 target: SurvivalBaselineTarget::Linear,
5174 scale: None,
5175 shape: None,
5176 rate: None,
5177 makeham: None,
5178 };
5179 let age_entry = array![1.0_f64];
5180 let age_exit = array![2.0_f64];
5181 let residuals = OffsetChannelResiduals {
5182 exit: array![0.1_f64],
5183 entry: array![0.0_f64],
5184 derivative: array![0.0_f64],
5185 right: Array1::<f64>::zeros(1),
5186 };
5187 let grad = baseline_chain_rule_gradient(
5188 age_entry.view(),
5189 age_exit.view(),
5190 age_exit.view(),
5191 &cfg,
5192 &residuals,
5193 )
5194 .expect("ok");
5195 assert!(grad.is_none());
5196 }
5197
5198 #[test]
5217 fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5218 let cfg = SurvivalBaselineConfig {
5221 target: SurvivalBaselineTarget::Gompertz,
5222 scale: None,
5223 shape: Some(0.03),
5224 rate: Some(0.25),
5225 makeham: None,
5226 };
5227 let age_entry = array![0.0_f64, 5.0, 8.0];
5228 let age_exit = array![4.0_f64, 12.0, 20.0];
5229 let weights = array![1.0_f64, 2.0, 0.5];
5232 let events = [1.0_f64, 1.0, 0.0];
5233 let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5238 let s_vals = [0.7_f64, 1.1, 1.5];
5239 let (r_x, r_e, r_d) = {
5240 let mut rx = Array1::<f64>::zeros(3);
5241 let mut re = Array1::<f64>::zeros(3);
5242 let mut rd = Array1::<f64>::zeros(3);
5243 for i in 0..3 {
5244 let w = weights[i];
5245 let d = events[i];
5246 rx[i] = w * (eta_exit_vals[i].exp() - d);
5247 re[i] = if i == 0 {
5248 0.0 } else {
5250 -w * eta_entry_vals[i].exp()
5251 };
5252 rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5253 }
5254 (rx, re, rd)
5255 };
5256 let residuals = OffsetChannelResiduals {
5257 exit: r_x.clone(),
5258 entry: r_e.clone(),
5259 derivative: r_d.clone(),
5260 right: Array1::<f64>::zeros(3),
5261 };
5262 let grad = baseline_chain_rule_gradient(
5263 age_entry.view(),
5264 age_exit.view(),
5265 age_exit.view(),
5266 &cfg,
5267 &residuals,
5268 )
5269 .expect("ok")
5270 .expect("non-linear");
5271
5272 let nll = |theta_plus: &Array1<f64>| -> f64 {
5277 let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5278 let mut sum = 0.0_f64;
5279 for i in 0..3 {
5280 let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5281 let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5282 let d_eta_x = eta_x_p - base.0;
5283 let d_d_x = d_x_p - base.1;
5284 let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5285 let s_new = s_vals[i] + d_d_x;
5286 let interval_entry = if i == 0 {
5287 0.0_f64
5288 } else {
5289 let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5290 let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5291 let d_eta_e = eta_e_p - base_e.0;
5292 let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5293 eta_entry_new.exp()
5294 };
5295 let w = weights[i];
5296 let d = events[i];
5297 let nll_i =
5298 w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5299 sum += nll_i;
5300 }
5301 sum
5302 };
5303
5304 let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5305 let h = 1e-6;
5306 for k in 0..theta_base.len() {
5307 let mut tp = theta_base.clone();
5308 let mut tm = theta_base.clone();
5309 tp[k] += h;
5310 tm[k] -= h;
5311 let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5312 assert!(
5313 (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5314 "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5315 grad[k],
5316 fd
5317 );
5318 }
5319 }
5320
5321 #[test]
5323 fn chain_rule_gradient_rejects_length_mismatch() {
5324 let cfg = SurvivalBaselineConfig {
5325 target: SurvivalBaselineTarget::Gompertz,
5326 scale: None,
5327 shape: Some(0.05),
5328 rate: Some(0.3),
5329 makeham: None,
5330 };
5331 let age_entry = array![1.0_f64, 2.0]; let age_exit = array![5.0_f64, 6.0, 7.0]; let residuals = OffsetChannelResiduals {
5334 exit: array![0.1_f64, 0.2, 0.3],
5335 entry: array![0.0_f64, 0.0, 0.0],
5336 derivative: array![0.0_f64, 0.0, 0.0],
5337 right: Array1::<f64>::zeros(3),
5338 };
5339 let err = baseline_chain_rule_gradient(
5340 age_entry.view(),
5341 age_exit.view(),
5342 age_exit.view(),
5343 &cfg,
5344 &residuals,
5345 )
5346 .expect_err("length mismatch must error");
5347 assert!(err.contains("length mismatch"), "err={err}");
5348 }
5349}