1use gam_terms::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14 BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15 create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18 DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19 SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24 WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25 monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::{InverseLink, StandardLink};
31use ndarray::{Array1, Array2, array, s};
32use rayon::prelude::*;
33
/// Errors raised while validating survival inputs and building survival-model pieces.
///
/// Every variant carries a human-readable `reason`. Within this module the values are
/// converted into `String` errors with `.into()` before being returned.
#[derive(Debug, Clone)]
pub enum SurvivalConstructionError {
    /// A configuration could not be used as given (e.g. a baseline optimizer that did
    /// not converge).
    InvalidConfig { reason: String },
    /// A required input column was not found.
    MissingColumn { reason: String },
    /// Two inputs disagree in length or dimension.
    IncompatibleDimensions { reason: String },
    /// Input data failed validation (e.g. non-finite or negative survival times).
    DataValidationFailed { reason: String },
    /// A basis could not be constructed.
    BasisConstructionFailed { reason: String },
    /// An unsupported name was requested (baseline target, likelihood mode, residual
    /// distribution, or non-structural time basis).
    UnsupportedDistribution { reason: String },
}
73
// Generate the shared reason-carrying error boilerplate for every variant of
// `SurvivalConstructionError`.
// NOTE(review): the macro is defined outside this view. This module turns
// `SurvivalConstructionError` values into `String` errors via `.into()`, so the
// macro presumably supplies that conversion (and likely `Display`/`Error` impls);
// confirm against the macro definition. Every variant must be listed here.
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
84
/// Parametric family used for the survival baseline.
///
/// The optimizer parameterizes each family as follows:
/// * `Linear`: no free parameters.
/// * `Weibull`: `[ln scale, ln shape]`.
/// * `Gompertz`: `[ln rate, shape]`.
/// * `GompertzMakeham`: `[ln rate, shape, ln makeham]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// No parametric baseline parameters.
    Linear,
    /// Weibull baseline with `scale` and `shape`.
    Weibull,
    /// Gompertz baseline with `rate` and `shape`.
    Gompertz,
    /// Gompertz-Makeham baseline with `rate`, `shape` and `makeham`.
    GompertzMakeham,
}
111
112#[derive(Clone, Debug)]
113pub struct SurvivalBaselineConfig {
114 pub target: SurvivalBaselineTarget,
115 pub scale: Option<f64>,
116 pub shape: Option<f64>,
117 pub rate: Option<f64>,
118 pub makeham: Option<f64>,
119}
120
121#[derive(Clone, Debug)]
122pub enum SurvivalTimeBasisConfig {
123 None,
124 Linear,
125 BSpline {
126 degree: usize,
127 knots: Array1<f64>,
128 smooth_lambda: f64,
129 },
130 ISpline {
168 degree: usize,
169 knots: Array1<f64>,
170 keep_cols: Vec<usize>,
171 smooth_lambda: f64,
172 },
173}
174
/// Matrix-free snapshot of a built survival time basis.
///
/// Holds only descriptive metadata (no design matrices or penalties), filled in by
/// `SavedSurvivalTimeBasis::from_build`.
#[derive(Debug, Clone, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Basis name as reported by the builder (e.g. `"none"`, `"linear"`, `"bspline"`).
    pub basisname: String,
    /// Spline degree, when the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, when the basis has one.
    pub knots: Option<Vec<f64>>,
    /// Retained column indices, when the basis records them.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing lambda recorded for the basis, if any.
    pub smooth_lambda: Option<f64>,
    /// Caller-supplied anchor value.
    /// NOTE(review): semantics are defined by callers of `from_build`; confirm.
    pub anchor: f64,
}
197
198impl SavedSurvivalTimeBasis {
199 pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
202 Self {
203 basisname: build.basisname.clone(),
204 degree: build.degree,
205 knots: build.knots.clone(),
206 keep_cols: build.keep_cols.clone(),
207 smooth_lambda: build.smooth_lambda,
208 anchor,
209 }
210 }
211}
212
213#[derive(Clone)]
214pub struct SurvivalTimeBuildOutput {
215 pub x_entry_time: DesignMatrix,
216 pub x_exit_time: DesignMatrix,
217 pub x_derivative_time: DesignMatrix,
218 pub penalties: Vec<Array2<f64>>,
219 pub nullspace_dims: Vec<usize>,
221 pub basisname: String,
222 pub degree: Option<usize>,
223 pub knots: Option<Vec<f64>>,
224 pub keep_cols: Option<Vec<usize>>,
225 pub smooth_lambda: Option<f64>,
226}
227
/// Lower bound applied to survival times.
///
/// Times are clamped up to this value before taking logarithms. It is also the
/// minimum gap enforced between a normalized entry time and its exit time.
pub const SURVIVAL_TIME_FLOOR: f64 = 1.0e-9;

/// Entry-time threshold for survival rows.
/// NOTE(review): not referenced in this part of the file; presumably the entry time
/// above which a row counts as delayed entry (left truncation). Confirm with its users.
pub const SURVIVAL_DELAYED_ENTRY_THRESHOLD: f64 = 1.0e-8;

/// Seed value for the survival time-basis smoothing parameter.
/// NOTE(review): not referenced in this part of the file; confirm its users.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1.0e-2;

/// Default Gompertz / Gompertz-Makeham shape used when none is supplied.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
254
/// Survival likelihood selected by `--survival-likelihood`.
///
/// The CLI spellings are produced by `survival_likelihood_modename` and accepted by
/// `parse_survival_likelihood_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `"transformation"`
    Transformation,
    /// `"weibull"`
    Weibull,
    /// `"location-scale"`
    LocationScale,
    /// `"marginal-slope"`
    MarginalSlope,
    /// `"latent"`
    Latent,
    /// `"latent-binary"`
    LatentBinary,
}
264
265pub struct SurvivalTimeWiggleBuild {
266 pub penalties: Vec<Array2<f64>>,
267 pub nullspace_dims: Vec<usize>,
268 pub knots: Array1<f64>,
269 pub degree: usize,
270 pub ncols: usize,
271}
272
273pub fn normalize_survival_time_pair(
278 entry_raw: f64,
279 exit_raw: f64,
280 row_index: usize,
281) -> Result<(f64, f64), String> {
282 if !entry_raw.is_finite() || !exit_raw.is_finite() {
283 return Err(SurvivalConstructionError::DataValidationFailed {
284 reason: format!("non-finite survival times at row {}", row_index + 1),
285 }
286 .into());
287 }
288 if entry_raw < 0.0 || exit_raw < 0.0 {
289 return Err(SurvivalConstructionError::DataValidationFailed {
290 reason: format!("negative survival times at row {}", row_index + 1),
291 }
292 .into());
293 }
294
295 let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
296 let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
297 Ok((entry, exit))
298}
299
/// Reports whether `basisname` names a time basis that is monotone by construction.
///
/// Only the I-spline basis qualifies; the comparison ignores ASCII case.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURAL_BASIS: &str = "ispline";
    STRUCTURAL_BASIS.eq_ignore_ascii_case(basisname)
}
307
308pub fn require_structural_survival_time_basis(
309 basisname: &str,
310 context: &str,
311) -> Result<(), String> {
312 if survival_basis_supports_structural_monotonicity(basisname) {
313 return Ok(());
314 }
315 Err(SurvivalConstructionError::UnsupportedDistribution {
316 reason: format!(
317 "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
318Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
319`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
320Re-run with `--time-basis ispline`."
321 ),
322 }
323 .into())
324}
325
326pub fn parse_survival_baseline_config(
331 target_raw: &str,
332 scale: Option<f64>,
333 shape: Option<f64>,
334 rate: Option<f64>,
335 makeham: Option<f64>,
336) -> Result<SurvivalBaselineConfig, String> {
337 let target = match target_raw.to_ascii_lowercase().as_str() {
338 "linear" => SurvivalBaselineTarget::Linear,
339 "weibull" => SurvivalBaselineTarget::Weibull,
340 "gompertz" => SurvivalBaselineTarget::Gompertz,
341 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
342 other => {
343 return Err(SurvivalConstructionError::UnsupportedDistribution {
344 reason: format!(
345 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
346 ),
347 }
348 .into());
349 }
350 };
351
352 match target {
353 SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
354 target,
355 scale: None,
356 shape: None,
357 rate: None,
358 makeham: None,
359 }),
360 SurvivalBaselineTarget::Weibull => {
361 let scale = scale.ok_or_else(|| {
362 "--baseline-target weibull requires --baseline-scale > 0".to_string()
363 })?;
364 let shape = shape.ok_or_else(|| {
365 "--baseline-target weibull requires --baseline-shape > 0".to_string()
366 })?;
367 if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
368 return Err(
369 "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
370 .to_string(),
371 );
372 }
373 Ok(SurvivalBaselineConfig {
374 target,
375 scale: Some(scale),
376 shape: Some(shape),
377 rate: None,
378 makeham: None,
379 })
380 }
381 SurvivalBaselineTarget::Gompertz => {
382 let rate = rate.unwrap_or(1.0);
383 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
384 if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
385 return Err(
386 "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
387 .to_string(),
388 );
389 }
390 Ok(SurvivalBaselineConfig {
391 target,
392 scale: None,
393 shape: Some(shape),
394 rate: Some(rate),
395 makeham: None,
396 })
397 }
398 SurvivalBaselineTarget::GompertzMakeham => {
399 let rate = rate.unwrap_or(0.5);
400 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
401 let makeham = makeham.unwrap_or(0.5);
402 if !rate.is_finite()
403 || rate <= 0.0
404 || !shape.is_finite()
405 || !makeham.is_finite()
406 || makeham <= 0.0
407 {
408 return Err(
409 "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
410 .to_string(),
411 );
412 }
413 Ok(SurvivalBaselineConfig {
414 target,
415 scale: None,
416 shape: Some(shape),
417 rate: Some(rate),
418 makeham: Some(makeham),
419 })
420 }
421 }
422}
423
424pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
429 match raw.to_ascii_lowercase().as_str() {
430 "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
431 "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
432 "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
433 "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
434 "latent" => Ok(SurvivalLikelihoodMode::Latent),
435 "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
436 other => Err(SurvivalConstructionError::UnsupportedDistribution {
437 reason: format!(
438 "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
439 ),
440 }
441 .into()),
442 }
443}
444
445pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
446 match mode {
447 SurvivalLikelihoodMode::Transformation => "transformation",
448 SurvivalLikelihoodMode::Weibull => "weibull",
449 SurvivalLikelihoodMode::LocationScale => "location-scale",
450 SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
451 SurvivalLikelihoodMode::Latent => "latent",
452 SurvivalLikelihoodMode::LatentBinary => "latent-binary",
453 }
454}
455
456pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
457 match raw.to_ascii_lowercase().as_str() {
458 "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
459 "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
460 "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
461 other => Err(SurvivalConstructionError::UnsupportedDistribution {
462 reason: format!(
463 "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
464 ),
465 }
466 .into()),
467 }
468}
469
470pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
471 match target {
472 SurvivalBaselineTarget::Linear => "linear",
473 SurvivalBaselineTarget::Weibull => "weibull",
474 SurvivalBaselineTarget::Gompertz => "gompertz",
475 SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
476 }
477}
478
479pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
480 let sum = age_exit
481 .iter()
482 .copied()
483 .filter(|value| value.is_finite() && *value > 0.0)
484 .sum::<f64>();
485 let count = age_exit
486 .iter()
487 .filter(|value| value.is_finite() && **value > 0.0)
488 .count()
489 .max(1);
490 (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
491}
492
493pub fn initial_survival_baseline_config_for_fit(
494 target_raw: &str,
495 scale: Option<f64>,
496 shape: Option<f64>,
497 rate: Option<f64>,
498 makeham: Option<f64>,
499 age_exit: &Array1<f64>,
500) -> Result<SurvivalBaselineConfig, String> {
501 let target = match target_raw.trim().to_ascii_lowercase().as_str() {
502 "linear" => SurvivalBaselineTarget::Linear,
503 "weibull" => SurvivalBaselineTarget::Weibull,
504 "gompertz" => SurvivalBaselineTarget::Gompertz,
505 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
506 other => {
507 return Err(SurvivalConstructionError::UnsupportedDistribution {
508 reason: format!(
509 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
510 ),
511 }
512 .into());
513 }
514 };
515 let time_scale_seed = positive_survival_time_seed(age_exit);
516 let cfg = match target {
517 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
518 target,
519 scale: None,
520 shape: None,
521 rate: None,
522 makeham: None,
523 },
524 SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
525 target,
526 scale: Some(scale.unwrap_or(time_scale_seed)),
527 shape: Some(shape.unwrap_or(1.0)),
528 rate: None,
529 makeham: None,
530 },
531 SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
532 target,
533 scale: None,
534 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
535 rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
536 makeham: None,
537 },
538 SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
539 target,
540 scale: None,
541 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
542 rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
543 makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
544 },
545 };
546 parse_survival_baseline_config(
547 survival_baseline_targetname(cfg.target),
548 cfg.scale,
549 cfg.shape,
550 cfg.rate,
551 cfg.makeham,
552 )
553}
554
555fn survival_baseline_theta_from_config(
556 cfg: &SurvivalBaselineConfig,
557) -> Result<Option<Array1<f64>>, String> {
558 Ok(match cfg.target {
559 SurvivalBaselineTarget::Linear => None,
560 SurvivalBaselineTarget::Weibull => Some(array![
561 cfg.scale
562 .ok_or_else(|| "missing weibull baseline scale".to_string())?
563 .ln(),
564 cfg.shape
565 .ok_or_else(|| "missing weibull baseline shape".to_string())?
566 .ln(),
567 ]),
568 SurvivalBaselineTarget::Gompertz => Some(array![
569 cfg.rate
570 .ok_or_else(|| "missing gompertz baseline rate".to_string())?
571 .ln(),
572 cfg.shape
573 .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
574 ]),
575 SurvivalBaselineTarget::GompertzMakeham => Some(array![
576 cfg.rate
577 .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
578 .ln(),
579 cfg.shape
580 .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
581 cfg.makeham
582 .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
583 .ln(),
584 ]),
585 })
586}
587
588fn survival_baseline_config_from_theta(
589 target: SurvivalBaselineTarget,
590 theta: &Array1<f64>,
591) -> Result<SurvivalBaselineConfig, String> {
592 let cfg = match target {
593 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
594 target,
595 scale: None,
596 shape: None,
597 rate: None,
598 makeham: None,
599 },
600 SurvivalBaselineTarget::Weibull => {
601 if theta.len() != 2 {
602 return Err(SurvivalConstructionError::IncompatibleDimensions {
603 reason: format!(
604 "weibull baseline parameter dimension mismatch: expected 2, got {}",
605 theta.len()
606 ),
607 }
608 .into());
609 }
610 SurvivalBaselineConfig {
611 target,
612 scale: Some(theta[0].exp()),
613 shape: Some(theta[1].exp()),
614 rate: None,
615 makeham: None,
616 }
617 }
618 SurvivalBaselineTarget::Gompertz => {
619 if theta.len() != 2 {
620 return Err(SurvivalConstructionError::IncompatibleDimensions {
621 reason: format!(
622 "gompertz baseline parameter dimension mismatch: expected 2, got {}",
623 theta.len()
624 ),
625 }
626 .into());
627 }
628 SurvivalBaselineConfig {
629 target,
630 scale: None,
631 shape: Some(theta[1]),
632 rate: Some(theta[0].exp()),
633 makeham: None,
634 }
635 }
636 SurvivalBaselineTarget::GompertzMakeham => {
637 if theta.len() != 3 {
638 return Err(SurvivalConstructionError::IncompatibleDimensions {
639 reason: format!(
640 "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
641 theta.len()
642 ),
643 }
644 .into());
645 }
646 SurvivalBaselineConfig {
647 target,
648 scale: None,
649 shape: Some(theta[1]),
650 rate: Some(theta[0].exp()),
651 makeham: Some(theta[2].exp()),
652 }
653 }
654 };
655 parse_survival_baseline_config(
656 survival_baseline_targetname(cfg.target),
657 cfg.scale,
658 cfg.shape,
659 cfg.rate,
660 cfg.makeham,
661 )
662}
663
/// Derivative information that a baseline objective promises to the outer optimizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Analytic gradient only; the Hessian is declared unavailable.
    GradientOnly,
    /// Analytic gradient plus a Hessian, declared as `DeclaredHessianForm::Either`.
    GradientHessian,
}
686
687impl BaselineDerivativeContract {
688 fn configure(
693 self,
694 problem: gam_solve::rho_optimizer::OuterProblem,
695 ) -> gam_solve::rho_optimizer::OuterProblem {
696 use gam_problem::{DeclaredHessianForm, Derivative};
697 match self {
698 BaselineDerivativeContract::GradientOnly => problem
701 .with_gradient(Derivative::Analytic)
702 .with_hessian(DeclaredHessianForm::Unavailable)
703 .with_tolerance(1e-4)
704 .with_max_iter(240),
705 BaselineDerivativeContract::GradientHessian => problem
706 .with_gradient(Derivative::Analytic)
707 .with_hessian(DeclaredHessianForm::Either)
708 .with_tolerance(1e-4)
709 .with_max_iter(240),
710 }
711 }
712}
713
/// Runs the bounded outer optimizer over the baseline's theta vector and converts the
/// optimum back into a validated config.
///
/// `cost_fn` and `eval_fn` are evaluated at theta (see
/// `survival_baseline_theta_from_config` for the parameterization). For `Linear`
/// baselines there is nothing to optimize, so `initial` is returned unchanged.
///
/// # Errors
/// Returns a message if the optimizer itself fails, if it does not converge (an
/// `InvalidConfig` error), or if the final theta does not map back to a valid config.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    // `None` means the target has no free parameters (Linear).
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // Box constraints of +/-6 around the seed in theta space. For the log-parameterized
    // entries (everything except the Gompertz shape) this allows roughly a factor of
    // e^6 change in either direction.
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    // Start from the seed only (single seed, budget 1).
    // NOTE(review): `num_auxiliary_trailing: dim` presumably marks every coordinate as
    // an auxiliary (non-rho) parameter; confirm in `crate::seeding::SeedConfig`.
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    // The two trailing optional callbacks are left unset (typed `None`s only fix the
    // generic parameters).
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    // A non-converged baseline is treated as an error rather than silently used.
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    survival_baseline_config_from_theta(target, &result.rho)
}
784
/// Adapts a config-level objective into the theta-level callbacks expected by
/// `run_baseline_theta_optimizer`.
///
/// Steps performed for each optimizer request:
/// 1. Map theta back to a validated `SurvivalBaselineConfig`.
/// 2. Call `objective` on that config.
/// 3. Check that the returned gradient (and analytic Hessian, if present) match
///    `theta.len()`.
///
/// Config and dimension failures are reported as `EstimationError::InvalidInput`.
/// A cost-only request runs the full `objective` and keeps only `cost`.
fn run_baseline_theta_optimizer_with_eval<F>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    objective: F,
) -> Result<SurvivalBaselineConfig, String>
where
    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
{
    let target = initial.target;
    let engine_context = context.to_string();
    // The cost and full-eval callbacks both need to call the same `FnMut` objective,
    // so it is shared behind `Rc<RefCell<_>>`.
    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
                        theta: &Array1<f64>|
          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
        let cfg = survival_baseline_config_from_theta(target, theta)
            .map_err(crate::model_types::EstimationError::InvalidInput)?;
        let eval =
            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
        // Derivatives must be taken with respect to theta, so their shape must match it.
        if eval.gradient.len() != theta.len() {
            return Err(crate::model_types::EstimationError::InvalidInput(format!(
                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
                eval.gradient.len(),
                theta.len()
            )));
        }
        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
            if h.nrows() != theta.len() || h.ncols() != theta.len() {
                return Err(crate::model_types::EstimationError::InvalidInput(format!(
                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
                    h.nrows(),
                    h.ncols(),
                    theta.len(),
                    theta.len()
                )));
            }
        }
        Ok(eval)
    };
    let cost_objective = std::rc::Rc::clone(&objective);
    let cost_eval = eval_at.clone();
    // Cost-only requests still run the full evaluation (including the shape checks)
    // and discard everything but the cost.
    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
    };
    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
}
844
845pub fn optimize_survival_baseline_config_with_gradient_only<F>(
856 initial: &SurvivalBaselineConfig,
857 context: &str,
858 mut objective: F,
859) -> Result<SurvivalBaselineConfig, String>
860where
861 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
862{
863 use gam_problem::{HessianResult, OuterEval};
864 run_baseline_theta_optimizer_with_eval(
865 initial,
866 context,
867 BaselineDerivativeContract::GradientOnly,
868 move |cfg| {
869 let (cost, gradient) = objective(cfg)?;
870 Ok(OuterEval {
871 cost,
872 gradient,
873 hessian: HessianResult::Unavailable,
874 inner_beta_hint: None,
875 })
876 },
877 )
878}
879
880pub fn optimize_survival_baseline_config_with_gradient<F>(
885 initial: &SurvivalBaselineConfig,
886 context: &str,
887 mut objective: F,
888) -> Result<SurvivalBaselineConfig, String>
889where
890 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
891{
892 use gam_problem::{HessianResult, OuterEval};
893 run_baseline_theta_optimizer_with_eval(
894 initial,
895 context,
896 BaselineDerivativeContract::GradientHessian,
897 move |cfg| {
898 let (cost, gradient, hessian) = objective(cfg)?;
899 Ok(OuterEval {
900 cost,
901 gradient,
902 hessian: HessianResult::Analytic(hessian),
903 inner_beta_hint: None,
904 })
905 },
906 )
907}
908
909pub fn parse_survival_time_basis_config(
914 time_basis: &str,
915 time_degree: usize,
916 time_num_internal_knots: usize,
917 time_smooth_lambda: f64,
918) -> Result<SurvivalTimeBasisConfig, String> {
919 match time_basis.to_ascii_lowercase().as_str() {
920 "none" => Ok(SurvivalTimeBasisConfig::None),
921 "ispline" => {
922 if time_degree < 1 {
923 return Err(
924 "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
925 .to_string(),
926 );
927 }
928 if time_num_internal_knots == 0 {
929 return Err(
930 "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
931 .to_string(),
932 );
933 }
934 if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
935 return Err(
936 "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
937 .to_string(),
938 );
939 }
940 Ok(SurvivalTimeBasisConfig::ISpline {
941 degree: time_degree,
942 knots: Array1::zeros(0),
943 keep_cols: Vec::new(),
944 smooth_lambda: time_smooth_lambda,
945 })
946 }
947 "linear" | "bspline" => {
948 match require_structural_survival_time_basis(time_basis, "survival model configuration")
955 {
956 Err(e) => Err(e),
957 Ok(()) => Err(format!(
958 "internal: structural-basis check accepted non-structural \
959 survival time basis '{time_basis}'"
960 )),
961 }
962 }
963 other => Err(format!(
964 "unsupported --time-basis '{other}'; accepted values: ispline, none"
965 )),
966 }
967}
968
969pub fn build_survival_time_basis(
974 age_entry: &Array1<f64>,
975 age_exit: &Array1<f64>,
976 cfg: SurvivalTimeBasisConfig,
977 infer_knots_if_needed: Option<(usize, f64)>,
978) -> Result<SurvivalTimeBuildOutput, String> {
979 fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
980 if let Some(row) = times.iter().position(|t| !t.is_finite()) {
981 return Err(SurvivalConstructionError::DataValidationFailed {
982 reason: format!(
983 "survival time basis requires finite {label} times (row {})",
984 row + 1
985 ),
986 }
987 .into());
988 }
989 if let Some(row) = times.iter().position(|t| *t < 0.0) {
990 return Err(SurvivalConstructionError::DataValidationFailed {
991 reason: format!(
992 "survival time basis requires non-negative {label} times (row {})",
993 row + 1
994 ),
995 }
996 .into());
997 }
998 Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
999 }
1000
1001 let n = age_entry.len();
1002 if n != age_exit.len() {
1003 return Err(SurvivalConstructionError::IncompatibleDimensions {
1004 reason: "survival time basis requires matching entry/exit lengths".to_string(),
1005 }
1006 .into());
1007 }
1008 for i in 0..n {
1009 if age_exit[i] < age_entry[i] {
1010 return Err(format!(
1011 "survival time basis requires exit times >= entry times (row {})",
1012 i + 1
1013 ));
1014 }
1015 }
1016 let log_entry = checked_log_survival_times(age_entry, "entry")?;
1017 let log_exit = checked_log_survival_times(age_exit, "exit")?;
1018
1019 fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1020 let n = log_entry.len();
1021 let entry_range = log_entry
1022 .iter()
1023 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1024 (lo.min(v), hi.max(v))
1025 });
1026 let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1027 if entry_degenerate {
1028 log_exit.clone()
1029 } else {
1030 let mut combined = Array1::<f64>::zeros(2 * n);
1031 for i in 0..n {
1032 combined[i] = log_entry[i];
1033 combined[n + i] = log_exit[i];
1034 }
1035 combined
1036 }
1037 }
1038
1039 fn data_capped_internal_knots(
1062 combined: &Array1<f64>,
1063 degree: usize,
1064 requested_internal_knots: usize,
1065 ) -> usize {
1066 if requested_internal_knots == 0 {
1067 return 0;
1068 }
1069 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1070 sorted.sort_by(f64::total_cmp);
1071 let minval = sorted.first().copied().unwrap_or(0.0);
1072 let maxval = sorted.last().copied().unwrap_or(minval);
1073 if minval == maxval {
1074 return 1.min(requested_internal_knots);
1076 }
1077 let scale = (maxval - minval).abs().max(1.0);
1078 let tol = 1e-12 * scale;
1079 let mut distinct_interior = 0usize;
1082 let mut last: Option<f64> = None;
1083 for &x in &sorted {
1084 if x <= minval + tol || x >= maxval - tol {
1085 continue;
1086 }
1087 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1088 continue;
1089 }
1090 distinct_interior += 1;
1091 last = Some(x);
1092 }
1093 let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1096 let n_distinct = {
1102 let mut count = 0usize;
1103 let mut last: Option<f64> = None;
1104 for &x in &sorted {
1105 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1106 continue;
1107 }
1108 count += 1;
1109 last = Some(x);
1110 }
1111 count
1112 };
1113 let dim_budget = n_distinct / 4;
1114 let dim_cap = dim_budget.saturating_sub(degree);
1115 cap = cap.min(dim_cap.max(1));
1116 cap.max(1)
1117 }
1118
1119 fn infer_survival_time_knots(
1120 combined: &Array1<f64>,
1121 knot_degree: usize,
1122 validation_degree: usize,
1123 num_internal_knots: usize,
1124 basis_options: BasisOptions,
1125 ) -> Result<Array1<f64>, String> {
1126 let num_internal_knots =
1132 data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1133
1134 fn quantile_knot_inference_needs_uniform_fallback(
1135 combined: &Array1<f64>,
1136 num_internal_knots: usize,
1137 ) -> bool {
1138 if num_internal_knots == 0 || combined.is_empty() {
1139 return false;
1140 }
1141
1142 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1143 sorted.sort_by(f64::total_cmp);
1144 let minval = sorted[0];
1145 let maxval = *sorted.last().unwrap_or(&minval);
1146 if minval == maxval {
1147 return false;
1148 }
1149
1150 let scale = (maxval - minval).abs().max(1.0);
1151 let tol = 1e-12 * scale;
1152 let mut support = Vec::with_capacity(sorted.len());
1153 let mut last: Option<f64> = None;
1154 for &x in &sorted {
1155 if x <= minval + tol || x >= maxval - tol {
1156 continue;
1157 }
1158 if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1159 continue;
1160 }
1161 support.push(x);
1162 last = Some(x);
1163 }
1164 if support.is_empty() {
1165 return true;
1166 }
1167
1168 let n = support.len();
1169 let mut prev_q = minval;
1170 for j in 1..=num_internal_knots {
1171 let p = j as f64 / (num_internal_knots + 1) as f64;
1172 let pos = p * (n.saturating_sub(1) as f64);
1173 let lo = pos.floor() as usize;
1174 let hi = pos.ceil() as usize;
1175 let frac = pos - lo as f64;
1176 let q = if lo == hi {
1177 support[lo]
1178 } else {
1179 support[lo] * (1.0 - frac) + support[hi] * frac
1180 }
1181 .clamp(minval, maxval);
1182 if q <= prev_q + tol || q >= maxval - tol {
1183 return true;
1184 }
1185 prev_q = q;
1186 }
1187
1188 false
1189 }
1190
1191 let inferwith =
1192 |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1193 let built = build_bspline_basis_1d(
1194 combined.view(),
1195 &BSplineBasisSpec {
1196 degree: knot_degree,
1197 penalty_order: 2,
1198 knotspec: BSplineKnotSpec::Automatic {
1199 num_internal_knots: Some(num_internal_knots),
1200 placement,
1201 },
1202 double_penalty: false,
1203 identifiability: BSplineIdentifiability::None,
1204 boundary: OneDimensionalBoundary::Open,
1205 boundary_conditions: BSplineBoundaryConditions::default(),
1206 },
1207 )
1208 .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1209 let knots = match built.metadata {
1210 BasisMetadata::BSpline1D { knots, .. } => knots,
1211 _ => {
1212 return Err(
1213 "internal error: expected BSpline1D metadata for survival time basis"
1214 .to_string(),
1215 );
1216 }
1217 };
1218 create_basis::<Dense>(
1227 combined.view(),
1228 KnotSource::Provided(knots.view()),
1229 validation_degree,
1230 basis_options,
1231 )
1232 .map_err(|e| e.to_string())?;
1233 Ok(knots)
1234 };
1235
1236 if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1237 inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1238 } else {
1239 inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1240 }
1241 }
1242
1243 match cfg {
1244 SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1245 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1246 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1247 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1248 penalties: Vec::new(),
1249 nullspace_dims: Vec::new(),
1250 basisname: "none".to_string(),
1251 degree: None,
1252 knots: None,
1253 keep_cols: None,
1254 smooth_lambda: None,
1255 }),
1256 SurvivalTimeBasisConfig::Linear => {
1257 let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1258 let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1259 let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1260 for i in 0..n {
1261 x_entry_time[[i, 0]] = 1.0;
1262 x_exit_time[[i, 0]] = 1.0;
1263 x_entry_time[[i, 1]] = log_entry[i];
1264 x_exit_time[[i, 1]] = log_exit[i];
1265 x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1266 }
1267 Ok(SurvivalTimeBuildOutput {
1268 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1269 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1270 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1271 penalties: Vec::new(),
1272 nullspace_dims: Vec::new(),
1273 basisname: "linear".to_string(),
1274 degree: None,
1275 knots: None,
1276 keep_cols: None,
1277 smooth_lambda: None,
1278 })
1279 }
1280 SurvivalTimeBasisConfig::BSpline {
1281 degree,
1282 knots,
1283 smooth_lambda,
1284 } => {
1285 let knotvec = if knots.is_empty() {
1286 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1287 "internal error: bspline time basis requested without knot source".to_string()
1288 })?;
1289 let combined = survival_time_knot_input(&log_entry, &log_exit);
1290 infer_survival_time_knots(
1291 &combined,
1292 degree,
1293 degree,
1294 num_internal_knots,
1295 BasisOptions::value(),
1296 )?
1297 } else {
1298 knots
1299 };
1300
1301 let entry_basis = build_bspline_basis_1d(
1302 log_entry.view(),
1303 &BSplineBasisSpec {
1304 degree,
1305 penalty_order: 2,
1306 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1307 double_penalty: false,
1308 identifiability: BSplineIdentifiability::None,
1309 boundary: OneDimensionalBoundary::Open,
1310 boundary_conditions: BSplineBoundaryConditions::default(),
1311 },
1312 )
1313 .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1314 let exit_basis = build_bspline_basis_1d(
1315 log_exit.view(),
1316 &BSplineBasisSpec {
1317 degree,
1318 penalty_order: 2,
1319 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1320 double_penalty: false,
1321 identifiability: BSplineIdentifiability::None,
1322 boundary: OneDimensionalBoundary::Open,
1323 boundary_conditions: BSplineBoundaryConditions::default(),
1324 },
1325 )
1326 .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1327
1328 let p_time = exit_basis.design.ncols();
1329 let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1333 let mut deriv_buf = vec![0.0_f64; p_time];
1334 for i in 0..n {
1335 deriv_buf.fill(0.0);
1336 evaluate_bspline_derivative_scalar(
1337 log_exit[i],
1338 knotvec.view(),
1339 degree,
1340 &mut deriv_buf,
1341 )
1342 .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1343 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1344 for j in 0..p_time {
1345 let v = deriv_buf[j] * chain;
1346 if v.abs() > 1e-15 {
1347 deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1348 }
1349 }
1350 }
1351 let x_derivative_time =
1352 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1353 {
1354 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1355 Err(_) => {
1356 let mut dense = Array2::<f64>::zeros((n, p_time));
1358 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1359 dense[[row, col]] = val;
1360 }
1361 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1362 }
1363 };
1364
1365 Ok(SurvivalTimeBuildOutput {
1366 x_entry_time: entry_basis.design,
1367 x_exit_time: exit_basis.design,
1368 x_derivative_time,
1369 nullspace_dims: entry_basis.nullspace_dims,
1370 penalties: entry_basis.penalties,
1371 basisname: "bspline".to_string(),
1372 degree: Some(degree),
1373 knots: Some(knotvec.to_vec()),
1374 keep_cols: None,
1375 smooth_lambda: Some(smooth_lambda),
1376 })
1377 }
1378 SurvivalTimeBasisConfig::ISpline {
1379 degree,
1380 knots,
1381 keep_cols,
1382 smooth_lambda,
1383 } => {
1384 let bspline_degree = degree
1385 .checked_add(1)
1386 .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1387 let knotvec = if knots.is_empty() {
1388 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1389 "internal error: ispline time basis requested without knot source".to_string()
1390 })?;
1391 let combined = survival_time_knot_input(&log_entry, &log_exit);
1392 infer_survival_time_knots(
1393 &combined,
1394 bspline_degree,
1395 degree,
1396 num_internal_knots,
1397 BasisOptions::i_spline(),
1398 )?
1399 } else {
1400 knots
1401 };
1402
1403 let (db_exit_arc, _) = create_basis::<Dense>(
1404 log_exit.view(),
1405 KnotSource::Provided(knotvec.view()),
1406 bspline_degree,
1407 BasisOptions::first_derivative(),
1408 )
1409 .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1410
1411 let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1414 let (entry_arc, _) = create_basis::<Dense>(
1415 log_entry.view(),
1416 KnotSource::Provided(knotvec.view()),
1417 degree,
1418 BasisOptions::i_spline(),
1419 )
1420 .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1421 let (exit_arc, _) = create_basis::<Dense>(
1422 log_exit.view(),
1423 KnotSource::Provided(knotvec.view()),
1424 degree,
1425 BasisOptions::i_spline(),
1426 )
1427 .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1428
1429 let x_entry_full = entry_arc.as_ref();
1430 let x_exit_full = exit_arc.as_ref();
1431 let p_time_full = x_exit_full.ncols();
1432 if p_time_full == 0 {
1433 return Err(SurvivalConstructionError::BasisConstructionFailed {
1434 reason: "internal error: empty ispline time basis".to_string(),
1435 }
1436 .into());
1437 }
1438 let db_exit = db_exit_arc.as_ref();
1439 if db_exit.ncols() != p_time_full + 1 {
1440 return Err(
1441 "internal error: ispline derivative basis width must exceed basis width by one"
1442 .to_string(),
1443 );
1444 }
1445
1446 let keep_cols = if keep_cols.is_empty() {
1447 let constant_tol = 1e-12_f64;
1448 let mut inferred_keep_cols: Vec<usize> = Vec::new();
1449 for j in 0..p_time_full {
1450 let mut minv = f64::INFINITY;
1451 let mut maxv = f64::NEG_INFINITY;
1452 for i in 0..n {
1453 let ve = x_exit_full[[i, j]];
1454 let vs = x_entry_full[[i, j]];
1455 minv = minv.min(ve.min(vs));
1456 maxv = maxv.max(ve.max(vs));
1457 }
1458 if (maxv - minv) > constant_tol {
1459 inferred_keep_cols.push(j);
1460 }
1461 }
1462 inferred_keep_cols
1463 } else {
1464 keep_cols
1465 };
1466 if keep_cols.is_empty() {
1467 return Err(
1468 "internal error: ispline basis has no shape-varying time columns"
1469 .to_string(),
1470 );
1471 }
1472 if keep_cols.iter().any(|&j| j >= p_time_full) {
1473 return Err(SurvivalConstructionError::MissingColumn {
1474 reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1475 }
1476 .into());
1477 }
1478
1479 let p_time = keep_cols.len();
1480 let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1481 let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1482 (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1485 };
1486 let db_exit = db_exit_arc.as_ref();
1487
1488 let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1493 let mut found_nonfinite: Option<(usize, usize)> = None;
1494 for i in 0..n {
1495 let mut running = 0.0_f64;
1496 let mut d_i_log_full = vec![0.0_f64; p_time_full];
1497 for j in (1..db_exit.ncols()).rev() {
1498 let term = db_exit[[i, j]];
1499 if term.is_finite() {
1500 running += term;
1501 }
1502 d_i_log_full[j - 1] = running;
1503 }
1504 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1505 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1506 let raw_v = d_i_log_full[j_old] * chain;
1507 let v = if (-1e-12..0.0).contains(&raw_v) {
1508 0.0
1509 } else {
1510 raw_v
1511 };
1512 if !v.is_finite() {
1513 found_nonfinite = Some((i, j_new));
1514 }
1515 if v < -1e-12 {
1516 return Err(format!(
1517 "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1518 i + 1,
1519 j_new + 1,
1520 v
1521 ));
1522 }
1523 if v.abs() > 1e-15 {
1524 deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1525 }
1526 }
1527 }
1528 if let Some((row, col)) = found_nonfinite {
1529 return Err(format!(
1530 "survival ispline derivative basis produced non-finite value at row {}, column {}",
1531 row + 1,
1532 col + 1
1533 ));
1534 }
1535 let x_derivative_time =
1536 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1537 {
1538 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1539 Err(_) => {
1540 let mut dense = Array2::<f64>::zeros((n, p_time));
1541 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1542 dense[[row, col]] = val;
1543 }
1544 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1545 }
1546 };
1547
1548 let penalty_basis = build_bspline_basis_1d(
1549 log_exit.view(),
1550 &BSplineBasisSpec {
1551 degree: bspline_degree,
1552 penalty_order: 2,
1553 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1554 double_penalty: false,
1555 identifiability: BSplineIdentifiability::None,
1556 boundary: OneDimensionalBoundary::Open,
1557 boundary_conditions: BSplineBoundaryConditions::default(),
1558 },
1559 )
1560 .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1561 if penalty_basis.design.ncols() != p_time_full + 1 {
1562 return Err("internal error: ispline penalty dimension mismatch".to_string());
1563 }
1564 let mut penalties = Vec::<Array2<f64>>::new();
1598 for s_mat in &penalty_basis.penalties {
1599 if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1600 continue;
1601 }
1602 let s_increment = s_mat.slice(s![1.., 1..]);
1631 if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1632 return Err(format!(
1633 "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1634 s_increment.nrows(),
1635 s_increment.ncols(),
1636 ));
1637 }
1638 let mut s_full = s_increment.to_owned();
1643 symmetrize_in_place(&mut s_full);
1644 let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1648 for i in 0..p_time_full {
1649 for j in 0..p_time_full {
1650 let mut v = 0.0;
1651 for k in j..p_time_full {
1652 v += s_full[[i, k]];
1653 }
1654 s_mid_full[[i, j]] = v;
1655 }
1656 }
1657 let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1661 for i in 0..p_time_full {
1662 for j in 0..p_time_full {
1663 let mut v = 0.0;
1664 for k in i..p_time_full {
1665 v += s_mid_full[[k, j]];
1666 }
1667 s_full_congruent[[i, j]] = v;
1668 }
1669 }
1670 let mut local = Array2::<f64>::zeros((p_time, p_time));
1672 for (i_new, &i_old) in keep_cols.iter().enumerate() {
1673 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1674 local[[i_new, j_new]] = 0.5
1677 * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1678 }
1679 }
1680 penalties.push(local);
1681 }
1682
1683 for (idx, s_mat) in penalties.iter().enumerate() {
1693 let p = s_mat.nrows();
1694 if p == 0 {
1695 continue;
1696 }
1697 if let Ok((evals, _)) =
1698 gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1699 {
1700 let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1701 "internal error: ispline penalty eigenvalues not contiguous".to_string()
1702 })?;
1703 let max_ev = evals_slice
1704 .iter()
1705 .copied()
1706 .fold(0.0_f64, |a, b| a.max(b.abs()))
1707 .max(1.0);
1708 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1709 let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1710 if min_ev < neg_tol {
1711 return Err(format!(
1712 "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1713 indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1714 {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1715 ));
1716 }
1717 }
1718 }
1719
1720 let nullspace_dims: Vec<usize> = penalties
1724 .iter()
1725 .map(|s_mat| {
1726 let p = s_mat.nrows();
1727 if p == 0 {
1728 return 0;
1729 }
1730 match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1731 Ok((evals, _)) => {
1732 let evals_slice: &[f64] = evals.as_slice().unwrap();
1733 let max_ev = evals_slice
1734 .iter()
1735 .copied()
1736 .fold(0.0_f64, |a, b| a.max(b.abs()))
1737 .max(1.0);
1738 let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1739 evals_slice.iter().filter(|&&e| e <= threshold).count()
1740 }
1741 Err(_) => 0,
1742 }
1743 })
1744 .collect();
1745 Ok(SurvivalTimeBuildOutput {
1746 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1747 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1748 x_derivative_time,
1749 penalties,
1750 nullspace_dims,
1751 basisname: "ispline".to_string(),
1752 degree: Some(degree),
1753 knots: Some(knotvec.to_vec()),
1754 keep_cols: Some(keep_cols),
1755 smooth_lambda: Some(smooth_lambda),
1756 })
1757 }
1758 }
1759}
1760
1761pub fn resolved_survival_time_basis_config_from_build(
1762 basisname: &str,
1763 degree: Option<usize>,
1764 knots: Option<&Vec<f64>>,
1765 keep_cols: Option<&Vec<usize>>,
1766 smooth_lambda: Option<f64>,
1767) -> Result<SurvivalTimeBasisConfig, String> {
1768 match basisname {
1769 "none" => Ok(SurvivalTimeBasisConfig::None),
1770 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1771 "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1772 degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1773 knots: Array1::from_vec(
1774 knots
1775 .cloned()
1776 .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1777 ),
1778 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1779 }),
1780 "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1781 degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1782 knots: Array1::from_vec(
1783 knots
1784 .cloned()
1785 .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1786 ),
1787 keep_cols: keep_cols
1788 .cloned()
1789 .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1790 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1791 }),
1792 other => Err(format!("unsupported survival time basis '{other}'")),
1793 }
1794}
1795
1796pub fn resolve_survival_time_anchor_value(
1797 age_entry: &Array1<f64>,
1798 time_anchor: Option<f64>,
1799) -> Result<f64, String> {
1800 if age_entry.is_empty() {
1801 return Err("survival time anchor requires non-empty entry times".to_string());
1802 }
1803 let anchor = match time_anchor {
1804 Some(t_anchor) => {
1805 if !t_anchor.is_finite() || t_anchor < 0.0 {
1806 return Err(format!(
1807 "survival time anchor must be finite and non-negative, got {t_anchor}"
1808 ));
1809 }
1810 t_anchor
1811 }
1812 None => age_entry
1813 .iter()
1814 .copied()
1815 .min_by(f64::total_cmp)
1816 .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1817 };
1818 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1819}
1820
1821pub fn resolve_survival_marginal_slope_time_anchor_value(
1853 age_entry: &Array1<f64>,
1854 age_exit: &Array1<f64>,
1855 time_anchor: Option<f64>,
1856) -> Result<f64, String> {
1857 if age_entry.is_empty() || age_exit.is_empty() {
1858 return Err(
1859 "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1860 );
1861 }
1862 let anchor = match time_anchor {
1863 Some(t_anchor) => {
1864 if !t_anchor.is_finite() || t_anchor < 0.0 {
1865 return Err(format!(
1866 "survival time anchor must be finite and non-negative, got {t_anchor}"
1867 ));
1868 }
1869 t_anchor
1870 }
1871 None => robust_interior_exit_anchor(age_exit),
1872 };
1873 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1874}
1875
1876fn robust_interior_exit_anchor(age_exit: &Array1<f64>) -> f64 {
1883 let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1884 sorted.sort_by(f64::total_cmp);
1885 let m = sorted.len();
1886 if m == 0 {
1887 return SURVIVAL_TIME_FLOOR;
1888 }
1889 if m % 2 == 1 {
1890 sorted[m / 2]
1891 } else {
1892 0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1893 }
1894}
1895
1896pub fn resolve_survival_transformation_time_anchor_value(
1917 age_entry: &Array1<f64>,
1918 age_exit: &Array1<f64>,
1919 time_anchor: Option<f64>,
1920) -> Result<f64, String> {
1921 if time_anchor.is_some() {
1922 return resolve_survival_time_anchor_value(age_entry, time_anchor);
1923 }
1924 if age_exit.is_empty() {
1925 return Err(
1926 "survival transformation time anchor requires non-empty exit times".to_string(),
1927 );
1928 }
1929 let min_entry = age_entry
1930 .iter()
1931 .copied()
1932 .fold(f64::INFINITY, f64::min);
1933 if min_entry > SURVIVAL_DELAYED_ENTRY_THRESHOLD {
1934 Ok(robust_interior_exit_anchor(age_exit).max(SURVIVAL_TIME_FLOOR))
1935 } else {
1936 resolve_survival_time_anchor_value(age_entry, None)
1937 }
1938}
1939
1940pub fn evaluate_survival_time_basis_row(
1941 age: f64,
1942 cfg: &SurvivalTimeBasisConfig,
1943) -> Result<Array1<f64>, String> {
1944 if !age.is_finite() || age < 0.0 {
1945 return Err(format!(
1946 "survival time basis row requires finite non-negative age, got {age}"
1947 ));
1948 }
1949 let age = age.max(SURVIVAL_TIME_FLOOR);
1950 let log_age = array![age.ln()];
1951 match cfg {
1952 SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1953 SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1954 SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1955 if knots.is_empty() {
1956 return Err(
1957 "survival BSpline anchor evaluation requires resolved knot metadata"
1958 .to_string(),
1959 );
1960 }
1961 let built = build_bspline_basis_1d(
1962 log_age.view(),
1963 &BSplineBasisSpec {
1964 degree: *degree,
1965 penalty_order: 2,
1966 knotspec: BSplineKnotSpec::Provided(knots.clone()),
1967 double_penalty: false,
1968 identifiability: BSplineIdentifiability::None,
1969 boundary: OneDimensionalBoundary::Open,
1970 boundary_conditions: BSplineBoundaryConditions::default(),
1971 },
1972 )
1973 .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1974 Ok(built.design.to_dense().row(0).to_owned())
1975 }
1976 SurvivalTimeBasisConfig::ISpline {
1977 degree,
1978 knots,
1979 keep_cols,
1980 ..
1981 } => {
1982 if knots.is_empty() {
1983 return Err(
1984 "survival ISpline anchor evaluation requires resolved knot metadata"
1985 .to_string(),
1986 );
1987 }
1988 let (basis_arc, _) = create_basis::<Dense>(
1989 log_age.view(),
1990 KnotSource::Provided(knots.view()),
1991 *degree,
1992 BasisOptions::i_spline(),
1993 )
1994 .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1995 let basis = basis_arc.as_ref();
1996 let row = basis.row(0);
1997 if keep_cols.is_empty() {
1998 return Ok(row.to_owned());
1999 }
2000 if keep_cols.iter().any(|&j| j >= row.len()) {
2001 return Err(SurvivalConstructionError::MissingColumn {
2002 reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
2003 }
2004 .into());
2005 }
2006 Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
2007 }
2008 }
2009}
2010
2011pub fn center_survival_time_designs_at_anchor(
2012 design_entry: &mut DesignMatrix,
2013 design_exit: &mut DesignMatrix,
2014 anchor_row: &Array1<f64>,
2015) -> Result<(), String> {
2016 if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
2017 return Err(format!(
2018 "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
2019 design_entry.ncols(),
2020 design_exit.ncols(),
2021 anchor_row.len()
2022 ));
2023 }
2024 fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
2027 let mut dense = dm.to_dense();
2028 for mut row in dense.rows_mut() {
2029 row -= &anchor.view();
2030 }
2031 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
2032 }
2033 center_dense(design_entry, anchor_row);
2034 center_dense(design_exit, anchor_row);
2035 Ok(())
2036}
2037
2038pub fn baseline_offset_theta_partials(
2068 age: f64,
2069 cfg: &SurvivalBaselineConfig,
2070) -> Result<Option<Vec<(f64, f64)>>, String> {
2071 let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2072 else {
2073 return Ok(None);
2074 };
2075
2076 match params {
2077 ValidatedBaselineTarget::Weibull { scale, shape } => {
2078 let eta = shape * (age.ln() - scale.ln());
2087 let o_d = shape / age;
2088 let d_eta_d_log_scale = -shape;
2089 let d_od_d_log_scale = 0.0;
2090 let d_eta_d_log_shape = eta;
2091 let d_od_d_log_shape = o_d;
2092 Ok(Some(vec![
2093 (d_eta_d_log_scale, d_od_d_log_scale),
2094 (d_eta_d_log_shape, d_od_d_log_shape),
2095 ]))
2096 }
2097 ValidatedBaselineTarget::Gompertz { shape, .. } => {
2098 let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2108 Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2109 }
2110 ValidatedBaselineTarget::GompertzMakeham {
2111 rate,
2112 shape,
2113 makeham,
2114 } => {
2115 let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2130 let cum_total = makeham * age + cum_g;
2131 if cum_total <= 0.0 || !cum_total.is_finite() {
2132 return Err(SurvivalConstructionError::DataValidationFailed {
2133 reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2134 }
2135 .into());
2136 }
2137 let inst_total = makeham + inst_g;
2138 let o_d = inst_total / cum_total;
2139 let inv_cum = 1.0 / cum_total;
2140 let d_cum_dlr = cum_g;
2145 let d_inst_dlr = inst_g;
2146 let d_eta_dlr = d_cum_dlr * inv_cum;
2147 let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2148 let (d_cum_dshape, d_inst_dshape) =
2150 gompertz_cumulative_shape_derivative(age, rate, shape);
2151 let d_eta_dshape = d_cum_dshape * inv_cum;
2152 let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2153 let d_cum_dlm = makeham * age;
2156 let d_inst_dlm = makeham;
2157 let d_eta_dlm = d_cum_dlm * inv_cum;
2158 let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2159 Ok(Some(vec![
2160 (d_eta_dlr, d_od_dlr),
2161 (d_eta_dshape, d_od_dshape),
2162 (d_eta_dlm, d_od_dlm),
2163 ]))
2164 }
2165 }
2166}
2167
2168fn baseline_chain_rule_gradient_with_partials<F>(
2196 label: &'static str,
2197 age_entry: ndarray::ArrayView1<'_, f64>,
2198 age_exit: ndarray::ArrayView1<'_, f64>,
2199 age_right: ndarray::ArrayView1<'_, f64>,
2200 cfg: &SurvivalBaselineConfig,
2201 residuals: &crate::survival::OffsetChannelResiduals,
2202 partials: F,
2203) -> Result<Option<Array1<f64>>, String>
2204where
2205 F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2206{
2207 let n = age_exit.len();
2208 if age_entry.len() != n
2209 || age_right.len() != n
2210 || residuals.exit.len() != n
2211 || residuals.entry.len() != n
2212 || residuals.derivative.len() != n
2213 || residuals.right.len() != n
2214 {
2215 return Err(format!(
2216 "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2217 age_entry.len(),
2218 n,
2219 age_right.len(),
2220 residuals.exit.len(),
2221 residuals.entry.len(),
2222 residuals.derivative.len(),
2223 residuals.right.len(),
2224 ));
2225 }
2226 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2229 let theta_dim = match probe_age {
2230 Some(t) => match partials(t, cfg)? {
2231 None => return Ok(None),
2232 Some(v) => v.len(),
2233 },
2234 None => {
2235 return Err(format!("{label}: no valid positive age for dim probe"));
2236 }
2237 };
2238 let mut grad = Array1::<f64>::zeros(theta_dim);
2249 for i in 0..n {
2250 let partials_exit = partials(age_exit[i], cfg)?
2252 .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2253 if partials_exit.len() != theta_dim {
2254 return Err(format!(
2255 "{label}: theta_dim drifted ({} != {})",
2256 partials_exit.len(),
2257 theta_dim
2258 ));
2259 }
2260 let r_x = residuals.exit[i];
2261 let r_d = residuals.derivative[i];
2262 for k in 0..theta_dim {
2263 let (d_eta_dk, d_od_dk) = partials_exit[k];
2264 grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2265 }
2266 let r_e = residuals.entry[i];
2270 if r_e != 0.0 {
2271 let partials_entry = partials(age_entry[i], cfg)?
2272 .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2273 for k in 0..theta_dim {
2274 grad[k] += r_e * partials_entry[k].0;
2275 }
2276 }
2277 let r_r = residuals.right[i];
2286 if r_r != 0.0 {
2287 let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2288 format!("{label}: unexpected None from partials at right boundary")
2289 })?;
2290 if partials_right.len() != theta_dim {
2291 return Err(format!(
2292 "{label}: theta_dim drifted at right boundary ({} != {})",
2293 partials_right.len(),
2294 theta_dim
2295 ));
2296 }
2297 for k in 0..theta_dim {
2298 grad[k] += r_r * partials_right[k].0;
2299 }
2300 }
2301 }
2302 Ok(Some(grad))
2303}
2304
2305pub fn baseline_chain_rule_gradient(
2339 age_entry: ndarray::ArrayView1<'_, f64>,
2340 age_exit: ndarray::ArrayView1<'_, f64>,
2341 age_right: ndarray::ArrayView1<'_, f64>,
2342 cfg: &SurvivalBaselineConfig,
2343 residuals: &crate::survival::OffsetChannelResiduals,
2344) -> Result<Option<Array1<f64>>, String> {
2345 baseline_chain_rule_gradient_with_partials(
2346 "baseline_chain_rule_gradient",
2347 age_entry,
2348 age_exit,
2349 age_right,
2350 cfg,
2351 residuals,
2352 baseline_offset_theta_partials,
2353 )
2354}
2355
2356pub fn marginal_slope_baseline_chain_rule_gradient(
2363 age_entry: ndarray::ArrayView1<'_, f64>,
2364 age_exit: ndarray::ArrayView1<'_, f64>,
2365 cfg: &SurvivalBaselineConfig,
2366 residuals: &crate::survival::OffsetChannelResiduals,
2367) -> Result<Option<Array1<f64>>, String> {
2368 baseline_chain_rule_gradient_with_partials(
2372 "marginal_slope_baseline_chain_rule_gradient",
2373 age_entry,
2374 age_exit,
2375 age_exit,
2376 cfg,
2377 residuals,
2378 marginal_slope_baseline_offset_theta_partials,
2379 )
2380}
2381
/// Gompertz cumulative and instantaneous hazard at `age`:
/// `H = (rate / shape) * (exp(shape * age) - 1)` and `h = rate * exp(shape * age)`.
///
/// For `|shape| < 1e-10`, a second-order expansion in `x = shape * age` is used instead.
/// It reduces to `H = rate * age` and `h = rate` at `shape = 0`.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let near_zero_shape = shape.abs() < 1e-10;
    if !near_zero_shape {
        // exp_m1 keeps H accurate even when shape * age is small.
        return ((rate / shape) * x.exp_m1(), rate * x.exp());
    }
    let cumulative = rate * age * (1.0 + 0.5 * x + x * x / 6.0);
    let instant = rate * (1.0 + x + 0.5 * x * x);
    (cumulative, instant)
}
2402
/// Shape derivatives of the Gompertz hazard components, returned as `(dH/dshape, dh/dshape)`.
///
/// With `x = shape * age`:
/// - `dh/dshape = rate * age * exp(x)`;
/// - `dH/dshape = rate * (shape * age * exp(x) - expm1(x)) / shape^2`.
///
/// When `|x| < 1e-4`, `dH/dshape` uses the series `rate * age^2 * (1/2 + x/3 + x^2/8)`.
/// This avoids cancellation in the closed-form numerator.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let scaled = shape * age;
    let growth = scaled.exp();
    let d_instant = rate * age * growth;
    if scaled.abs() < 1e-4 {
        let d_cumulative = rate * age * age * (0.5 + scaled / 3.0 + scaled * scaled / 8.0);
        return (d_cumulative, d_instant);
    }
    let numerator = age * growth * shape - scaled.exp_m1();
    (rate * numerator / (shape * shape), d_instant)
}
2443
/// Shape derivatives of the Gompertz offset channels, returned as
/// `(d eta / d shape, d o_d / d shape)`.
///
/// The channels are defined with `x = shape * age`:
/// - `eta = ln H(age)`, where `H = (rate / shape) * expm1(x)`;
/// - `o_d = d eta / d age = shape * e^x / (e^x - 1)`.
///
/// Neither derivative depends on `rate`. In closed form:
/// - `d eta / d shape = -1/shape + age * e^x / (e^x - 1)`;
/// - `d ln o_d / d shape = 1/shape - age / expm1(x)`.
///
/// Numerical strategy:
/// - The series/closed-form switch is on `|x|`, not on `|shape|`. A tiny shape with a
///   moderate `x` has no cancellation problem. A moderately small shape with a tiny `x` would
///   otherwise subtract two numbers of size `1/shape`.
/// - For `|x| < 1e-4` the Bernoulli expansion
///   `x e^x / (e^x - 1) = 1 + x/2 + x^2/12 - x^4/720 + ...` is used. Its truncation error is
///   far below f64 precision at this cut-off, which matches the one used by
///   `gompertz_cumulative_shape_derivative`.
/// - The closed form writes `e^x / (e^x - 1)` as `-1 / expm1(-x)`. This stays finite for
///   large positive `x`, where `exp(x)` overflows and the naive ratio is `inf / inf = NaN`.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    const SERIES_CUTOFF: f64 = 1e-4;
    let t = age;
    let x = shape * t;
    if x.abs() < SERIES_CUTOFF {
        let x2 = x * x;
        let x3 = x2 * x;
        // Each expression divides the Bernoulli series by `shape` analytically (x / shape = t).
        let d_eta = t * (0.5 + x / 12.0 - x3 / 720.0);
        let dlog_od = t * (0.5 - x / 12.0 + x3 / 720.0);
        let o_d = (1.0 + 0.5 * x + x2 / 12.0 - x2 * x2 / 720.0) / t;
        (d_eta, o_d * dlog_od)
    } else {
        // e^x / (e^x - 1) == 1 / (1 - e^{-x}), evaluated without overflow for large x.
        let growth_ratio = -1.0 / (-x).exp_m1();
        let d_eta = -1.0 / shape + t * growth_ratio;
        let o_d = shape * growth_ratio;
        let dlog_od = 1.0 / shape - t / x.exp_m1();
        (d_eta, o_d * dlog_od)
    }
}
2476
/// Baseline parameters that passed `validated_baseline_params`.
///
/// Every value is finite. Weibull `scale` and `shape` are positive. Gompertz(-Makeham)
/// `rate` and `makeham` are positive, while `shape` may have either sign.
#[derive(Debug, Clone, Copy)]
enum ValidatedBaselineTarget {
    /// Weibull baseline with cumulative hazard `(age / scale)^shape`.
    Weibull {
        scale: f64,
        shape: f64,
    },
    /// Gompertz baseline with cumulative hazard `(rate / shape) * (exp(shape * age) - 1)`.
    Gompertz {
        rate: f64,
        shape: f64,
    },
    /// Gompertz baseline plus a constant Makeham hazard `makeham`.
    GompertzMakeham {
        rate: f64,
        shape: f64,
        makeham: f64,
    },
}
2491
2492fn validated_baseline_params(
2498 age: f64,
2499 cfg: &SurvivalBaselineConfig,
2500 context: &str,
2501) -> Result<Option<ValidatedBaselineTarget>, String> {
2502 if !age.is_finite() || age <= 0.0 {
2503 return Err(format!(
2504 "survival ages must be finite and positive for {context}"
2505 ));
2506 }
2507
2508 match cfg.target {
2509 SurvivalBaselineTarget::Linear => Ok(None),
2510 SurvivalBaselineTarget::Weibull => {
2511 let scale = cfg
2512 .scale
2513 .ok_or_else(|| "weibull missing scale".to_string())?;
2514 let shape = cfg
2515 .shape
2516 .ok_or_else(|| "weibull missing shape".to_string())?;
2517 if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2518 return Err(SurvivalConstructionError::InvalidConfig {
2519 reason: "weibull baseline requires finite positive scale and shape".to_string(),
2520 }
2521 .into());
2522 }
2523 Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2524 }
2525 SurvivalBaselineTarget::Gompertz => {
2526 let rate = cfg
2527 .rate
2528 .ok_or_else(|| "gompertz missing rate".to_string())?;
2529 let shape = cfg
2530 .shape
2531 .ok_or_else(|| "gompertz missing shape".to_string())?;
2532 if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2533 return Err(
2534 "gompertz baseline requires finite positive rate and finite shape".to_string(),
2535 );
2536 }
2537 Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2538 }
2539 SurvivalBaselineTarget::GompertzMakeham => {
2540 let rate = cfg
2541 .rate
2542 .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2543 let shape = cfg
2544 .shape
2545 .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2546 let makeham = cfg
2547 .makeham
2548 .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2549 if !(rate.is_finite()
2550 && shape.is_finite()
2551 && makeham.is_finite()
2552 && rate > 0.0
2553 && makeham > 0.0)
2554 {
2555 return Err(
2556 "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2557 .to_string(),
2558 );
2559 }
2560 Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2561 rate,
2562 shape,
2563 makeham,
2564 }))
2565 }
2566 }
2567}
2568
2569fn survival_hazard_theta_partials(
2570 age: f64,
2571 cfg: &SurvivalBaselineConfig,
2572) -> Result<Option<Vec<(f64, f64)>>, String> {
2573 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2574 return Ok(None);
2575 };
2576
2577 match params {
2578 ValidatedBaselineTarget::Weibull { scale, shape } => {
2579 let log_time_ratio = age.ln() - scale.ln();
2580 let cumulative_hazard = (age / scale).powf(shape);
2581 let instant_hazard = shape * cumulative_hazard / age;
2582 let eta = shape * log_time_ratio;
2583 Ok(Some(vec![
2584 (-shape * cumulative_hazard, -shape * instant_hazard),
2585 (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2586 ]))
2587 }
2588 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2589 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2590 let (d_cum_dshape, d_inst_dshape) =
2591 gompertz_cumulative_shape_derivative(age, rate, shape);
2592 Ok(Some(vec![
2593 (cumulative_hazard, instant_hazard),
2594 (d_cum_dshape, d_inst_dshape),
2595 ]))
2596 }
2597 ValidatedBaselineTarget::GompertzMakeham {
2598 rate,
2599 shape,
2600 makeham,
2601 } => {
2602 let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2603 let (d_cum_dshape, d_inst_dshape) =
2604 gompertz_cumulative_shape_derivative(age, rate, shape);
2605 Ok(Some(vec![
2606 (cum_gompertz, inst_gompertz),
2607 (d_cum_dshape, d_inst_dshape),
2608 (makeham * age, makeham),
2609 ]))
2610 }
2611 }
2612}
2613
2614fn survival_cumulative_and_instant_hazard(
2615 age: f64,
2616 cfg: &SurvivalBaselineConfig,
2617) -> Result<Option<(f64, f64)>, String> {
2618 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2619 return Ok(None);
2620 };
2621
2622 match params {
2623 ValidatedBaselineTarget::Weibull { scale, shape } => {
2624 let cumulative_hazard = (age / scale).powf(shape);
2625 let instant_hazard = shape * cumulative_hazard / age;
2626 Ok(Some((cumulative_hazard, instant_hazard)))
2627 }
2628 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2629 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2630 Ok(Some((cumulative_hazard, instant_hazard)))
2631 }
2632 ValidatedBaselineTarget::GompertzMakeham {
2633 rate,
2634 shape,
2635 makeham,
2636 } => {
2637 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2638 Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2639 }
2640 }
2641}
2642
/// Baseline quantities at one age for the marginal-slope model, as assembled by
/// `evaluate_marginal_slope_baseline_point`.
#[derive(Debug, Clone, Copy)]
struct MarginalSlopeBaselinePoint {
    /// Instantaneous baseline hazard `h(age)`.
    instant_hazard: f64,
    /// Probit score of baseline survival, `q = -Phi^{-1}(S(age))`.
    q: f64,
    /// Time derivative of `q`, computed as `h(age) * S(age) / phi(q)`.
    q_t: f64,
}
2649
2650fn evaluate_marginal_slope_baseline_point(
2651 age: f64,
2652 cfg: &SurvivalBaselineConfig,
2653) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2654 let Some((cumulative_hazard, instant_hazard)) =
2655 survival_cumulative_and_instant_hazard(age, cfg)?
2656 else {
2657 return Ok(None);
2658 };
2659 if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2660 return Err(format!(
2661 "{} marginal-slope baseline produced non-positive cumulative hazard",
2662 survival_baseline_targetname(cfg.target)
2663 ));
2664 }
2665 if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2666 return Err(format!(
2667 "{} marginal-slope baseline produced non-positive instant hazard",
2668 survival_baseline_targetname(cfg.target)
2669 ));
2670 }
2671 let survival = (-cumulative_hazard).exp();
2672 if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2673 return Err(format!(
2674 "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2675 survival_baseline_targetname(cfg.target)
2676 ));
2677 }
2678 let q = -standard_normal_quantile(survival).map_err(|e| {
2679 format!(
2680 "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2681 survival_baseline_targetname(cfg.target)
2682 )
2683 })?;
2684 let phi_q = normal_pdf(q);
2685 if !(phi_q.is_finite() && phi_q > 0.0) {
2686 return Err(format!(
2687 "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2688 survival_baseline_targetname(cfg.target)
2689 ));
2690 }
2691 Ok(Some(MarginalSlopeBaselinePoint {
2692 instant_hazard,
2693 q,
2694 q_t: instant_hazard * survival / phi_q,
2695 }))
2696}
2697
2698pub fn evaluate_survival_baseline(
2701 age: f64,
2702 cfg: &SurvivalBaselineConfig,
2703) -> Result<(f64, f64), String> {
2704 if !age.is_finite() || age < 0.0 {
2705 return Err(
2706 "survival ages must be finite and non-negative for baseline target evaluation"
2707 .to_string(),
2708 );
2709 }
2710
2711 if age == 0.0 {
2722 return match cfg.target {
2723 SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2724 SurvivalBaselineTarget::Weibull
2725 | SurvivalBaselineTarget::Gompertz
2726 | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2727 };
2728 }
2729
2730 let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2731 return Ok((0.0, 0.0));
2732 };
2733
2734 match params {
2735 ValidatedBaselineTarget::Weibull { scale, shape } => {
2736 let eta = shape * (age.ln() - scale.ln());
2737 let derivative = shape / age;
2738 Ok((eta, derivative))
2739 }
2740 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2741 let (h, inst) = gompertz_hazard_components(age, rate, shape);
2742 if h <= 0.0 || !h.is_finite() {
2743 return Err(if shape.abs() < 1e-10 {
2744 "invalid gompertz baseline at near-zero shape".to_string()
2745 } else {
2746 "gompertz baseline produced non-positive cumulative hazard".to_string()
2747 });
2748 }
2749 let derivative = inst / h;
2750 Ok((h.ln(), derivative))
2751 }
2752 ValidatedBaselineTarget::GompertzMakeham {
2753 rate,
2754 shape,
2755 makeham,
2756 } => {
2757 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2758 let h = makeham * age + h_gompertz;
2759 if h <= 0.0 || !h.is_finite() {
2760 return Err(
2761 "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2762 );
2763 }
2764 let inst = makeham + inst_gompertz;
2765 let derivative = inst / h;
2766 Ok((h.ln(), derivative))
2767 }
2768 }
2769}
2770
2771pub fn evaluate_survival_marginal_slope_baseline(
2777 age: f64,
2778 cfg: &SurvivalBaselineConfig,
2779) -> Result<(f64, f64), String> {
2780 if age == 0.0 {
2792 return Ok((0.0, 0.0));
2793 }
2794 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2795 return Ok((0.0, 0.0));
2796 };
2797 Ok((point.q, point.q_t))
2798}
2799
2800pub fn marginal_slope_baseline_offset_theta_partials(
2813 age: f64,
2814 cfg: &SurvivalBaselineConfig,
2815) -> Result<Option<Vec<(f64, f64)>>, String> {
2816 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2817 return Ok(None);
2818 };
2819 let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2820 .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2821 let a = point.q_t / point.instant_hazard;
2822 let a_log_derivative_factor = point.q * a - 1.0;
2823 Ok(Some(
2824 hazard_partials
2825 .into_iter()
2826 .map(|(d_h_cum, d_h_inst)| {
2827 (
2828 a * d_h_cum,
2829 a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2830 )
2831 })
2832 .collect(),
2833 ))
2834}
2835
2836pub fn marginal_slope_baseline_chain_rule_hessian(
2839 age_entry: ndarray::ArrayView1<'_, f64>,
2840 age_exit: ndarray::ArrayView1<'_, f64>,
2841 cfg: &SurvivalBaselineConfig,
2842 residuals: &crate::survival::OffsetChannelResiduals,
2843 curvatures: &crate::survival::OffsetChannelCurvatures,
2844) -> Result<Option<Array2<f64>>, String> {
2845 let n = age_exit.len();
2846 if age_entry.len() != n
2847 || residuals.exit.len() != n
2848 || residuals.entry.len() != n
2849 || residuals.derivative.len() != n
2850 || curvatures.rows.len() != n
2851 {
2852 return Err(format!(
2853 "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2854 age_entry.len(),
2855 n,
2856 residuals.exit.len(),
2857 residuals.entry.len(),
2858 residuals.derivative.len(),
2859 curvatures.rows.len(),
2860 ));
2861 }
2862 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2863 let dim = match probe_age {
2864 Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2865 None => return Ok(None),
2866 Some(parts) => parts.first.len(),
2867 },
2868 None => {
2869 return Err(
2870 "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2871 .to_string(),
2872 );
2873 }
2874 };
2875 let hessian = (0..n)
2880 .into_par_iter()
2881 .try_fold(
2882 || Array2::<f64>::zeros((dim, dim)),
2883 |mut acc, i| -> Result<Array2<f64>, String> {
2884 let exit_parts =
2885 marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2886 .ok_or_else(|| {
2887 "unexpected None from marginal-slope second partials at exit"
2888 .to_string()
2889 })?;
2890 if exit_parts.first.len() != dim {
2891 return Err(
2892 "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2893 );
2894 }
2895 let mut entry_parts = None;
2896 if residuals.entry[i] != 0.0 {
2897 entry_parts = Some(
2898 marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2899 .ok_or_else(|| {
2900 "unexpected None from marginal-slope second partials at entry"
2901 .to_string()
2902 })?,
2903 );
2904 }
2905 for a in 0..dim {
2906 for b in 0..dim {
2907 let j_exit_a = exit_parts.first[a].0;
2908 let j_exit_b = exit_parts.first[b].0;
2909 let j_deriv_a = exit_parts.first[a].1;
2910 let j_deriv_b = exit_parts.first[b].1;
2911 let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2912 + residuals.derivative[i] * exit_parts.second[a][b].1;
2913 if let Some(parts) = entry_parts.as_ref() {
2914 value += residuals.entry[i] * parts.second[a][b].0;
2915 }
2916 let curv = curvatures.rows[i];
2917 let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2918 let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2919 let ja = [j_entry_a, j_exit_a, j_deriv_a];
2920 let jb = [j_entry_b, j_exit_b, j_deriv_b];
2921 for u in 0..3 {
2922 for v in 0..3 {
2923 value += ja[u] * curv[u][v] * jb[v];
2924 }
2925 }
2926 acc[[a, b]] += value;
2927 }
2928 }
2929 Ok(acc)
2930 },
2931 )
2932 .try_reduce(|| Array2::<f64>::zeros((dim, dim)), |a, b| Ok(a + b))?;
2933 Ok(Some(hessian))
2934}
2935
/// First and second θ-partials of the probit-scale baseline offsets at one age.
struct MarginalSlopeThetaSecondPartials {
    /// `first[a] = (∂q/∂θ_a, ∂q_t/∂θ_a)`.
    first: Vec<(f64, f64)>,
    /// `second[a][b] = (∂²q/∂θ_a∂θ_b, ∂²q_t/∂θ_a∂θ_b)`.
    second: Vec<Vec<(f64, f64)>>,
}
2940
/// First and second θ-partials of the probit-scale offsets `(q, q_t)` at `age`.
///
/// Chains the hazard partials from `survival_hazard_theta_first_second` through
/// `q = -Φ⁻¹(S)` with `S = exp(-H)`. Returns `Ok(None)` when either the probit
/// baseline point or the hazard partials are unavailable.
///
/// Notation: `H` is the cumulative hazard and `h` the instantaneous hazard. In
/// the code, the `h_*` names are cumulative-hazard partials and the `inst_*`
/// names are instantaneous-hazard partials.
/// - `a = dq/dH = S/φ(q)`, so `∂q/∂θ = a·∂H/∂θ` and `q_t = a·h`;
/// - `b = q·a - 1` satisfies `da/dH = a·b`;
/// - `b_factor = a + q·b` satisfies `db/dH = a·b_factor`.
fn marginal_slope_baseline_offset_theta_second_partials(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
        return Ok(None);
    };
    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
        return Ok(None);
    };
    let (cum_hazard, instant_hazard) = hazard;
    let survival = (-cum_hazard).exp();
    // dq/dH = S / φ(q).
    let a = survival / normal_pdf(point.q);
    // d(ln a)/dH.
    let b = point.q * a - 1.0;
    // (db/dH) / a.
    let b_factor = a + point.q * b;
    let dim = first.len();
    let mut first_out = Vec::with_capacity(dim);
    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
    for i in 0..dim {
        let (h_i, inst_i) = first[i];
        // (∂q/∂θ_i, ∂q_t/∂θ_i) = (a·H_i, a·(h_i + h·b·H_i)).
        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
    }
    for i in 0..dim {
        for j in 0..dim {
            let (h_i, inst_i) = first[i];
            let (h_j, inst_j) = first[j];
            let (h_ij, inst_ij) = second[i][j];
            // ∂a/∂θ_j = (da/dH)·H_j.
            let a_j = a * b * h_j;
            // ∂b/∂θ_j = (db/dH)·H_j.
            let b_j = a * h_j * b_factor;
            // ∂²q/∂θ_i∂θ_j = a·H_ij + (∂a/∂θ_j)·H_i.
            let q_ij = a * h_ij + a * b * h_i * h_j;
            // Product rule applied to ∂q_t/∂θ_i = a·(h_i + h·b·H_i).
            let qt_inner_i = inst_i + instant_hazard * b * h_i;
            let qt_ij = a_j * qt_inner_i
                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
            second_out[i][j] = (q_ij, qt_ij);
        }
    }
    Ok(Some(MarginalSlopeThetaSecondPartials {
        first: first_out,
        second: second_out,
    }))
}
2982
/// `((H, h), first, second)` at one age:
/// - the cumulative and instantaneous hazard;
/// - their θ-gradient as `(∂H, ∂h)` pairs;
/// - the θ-Hessian as a nested matrix of `(∂²H, ∂²h)` pairs.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2984
/// Hazard values plus their first and second θ-partials at `age`.
///
/// The first partials come from `survival_hazard_theta_partials`. The second
/// partials are filled in per target below; entries that are never written stay
/// `(0.0, 0.0)`. Returns `Ok(None)` when no hazard pair is available, and always
/// returns `Ok(None)` for a linear target.
///
/// NOTE(review): the closed forms below are consistent with these θ layouts;
/// confirm against `survival_hazard_theta_partials`:
/// - Weibull: `(ln scale, ln shape)`;
/// - Gompertz: `(ln rate, shape)`;
/// - Gompertz–Makeham: `(ln rate, shape, ln makeham)`.
fn survival_hazard_theta_first_second(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<HazardFirstSecond>, String> {
    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
        return Ok(None);
    };
    let first = survival_hazard_theta_partials(age, cfg)?
        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
    let dim = first.len();
    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
    match cfg.target {
        SurvivalBaselineTarget::Linear => return Ok(None),
        SurvivalBaselineTarget::Weibull => {
            let scale = cfg
                .scale
                .ok_or_else(|| "weibull missing scale".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "weibull missing shape".to_string())?;
            let log_time_ratio = age.ln() - scale.ln();
            let cumulative_hazard = hazard.0;
            let instant_hazard = hazard.1;
            // eta = shape·ln(age/scale) = ln H for the Weibull cumulative hazard.
            let eta = shape * log_time_ratio;
            second[0][0] = (
                shape * shape * cumulative_hazard,
                shape * shape * instant_hazard,
            );
            second[0][1] = (
                -shape * cumulative_hazard * (1.0 + eta),
                -shape * instant_hazard * (2.0 + eta),
            );
            // The Hessian is symmetric.
            second[1][0] = second[0][1];
            second[1][1] = (
                eta * cumulative_hazard * (1.0 + eta),
                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
            );
        }
        SurvivalBaselineTarget::Gompertz => {
            let rate = cfg
                .rate
                .ok_or_else(|| "gompertz missing rate".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "gompertz missing shape".to_string())?;
            // The hazard is linear in `rate`, so differentiating again in the
            // first coordinate reproduces the first partials.
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        }
        SurvivalBaselineTarget::GompertzMakeham => {
            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
            // The Makeham term is additive, so its cross terms with the Gompertz
            // coordinates stay zero. This indexing assumes `first` has at least
            // three entries (it would panic otherwise).
            second[2][2] = first[2];
        }
    }
    Ok(Some((hazard, first, second)))
}
3047
/// Second `shape`-derivatives of the Gompertz hazard pair.
///
/// The pair is `H = rate·(exp(shape·age) - 1)/shape` and `h = rate·exp(shape·age)`.
/// Returns `(∂²H/∂shape², ∂²h/∂shape²)`.
///
/// When `|shape·age| < 1e-3` a truncated Taylor series in `x = shape·age`
/// replaces the closed form. The closed form's `1/shape³` term cancels
/// catastrophically near zero shape.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    // Written as a negated `<` so that a NaN `x` takes the closed-form path.
    if !(x.abs() < 1e-3) {
        let growth = x.exp();
        let growth_minus_one = x.exp_m1();
        let numer = shape * age * growth - growth_minus_one;
        let d2_cumulative =
            rate * (age * age * growth / shape - 2.0 * numer / (shape * shape * shape));
        let d2_instant = rate * age * age * growth;
        return (d2_cumulative, d2_instant);
    }
    let t = age;
    let series_cumulative = 1.0 / 3.0 + x / 4.0 + x * x / 10.0;
    let series_instant = 1.0 + x + 0.5 * x * x;
    (
        rate * t * t * t * series_cumulative,
        rate * t * t * series_instant,
    )
}
3078
3079#[derive(Clone, Copy)]
3084enum BaselineOffsetEvaluator {
3085 LogCumulativeHazard,
3086 ProbitSurvival,
3087}
3088
3089impl BaselineOffsetEvaluator {
3090 fn length_error(self) -> String {
3091 match self {
3092 Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3093 reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3094 }
3095 .into(),
3096 Self::ProbitSurvival => {
3097 "survival probit baseline offsets require matching entry/exit lengths".to_string()
3098 }
3099 }
3100 }
3101
3102 fn finite_error(self) -> &'static str {
3103 match self {
3104 Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3105 Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3106 }
3107 }
3108
3109 fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3110 match self {
3111 Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3112 Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3113 }
3114 }
3115
3116 fn exit_is_finite(self, value: f64, age: f64) -> bool {
3117 match self {
3118 Self::LogCumulativeHazard => {
3119 value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3120 }
3121 Self::ProbitSurvival => value.is_finite(),
3122 }
3123 }
3124}
3125
3126fn build_survival_offsets_with_evaluator(
3127 age_entry: &Array1<f64>,
3128 age_exit: &Array1<f64>,
3129 cfg: &SurvivalBaselineConfig,
3130 evaluator: BaselineOffsetEvaluator,
3131) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3132 if age_entry.len() != age_exit.len() {
3133 return Err(evaluator.length_error());
3134 }
3135 let n = age_entry.len();
3136 let triples: Vec<(f64, f64, f64)> = (0..n)
3139 .into_par_iter()
3140 .map(|i| -> Result<(f64, f64, f64), String> {
3141 let entry_age = age_entry[i];
3145 let e0 = if !entry_age.is_finite() {
3146 return Err(SurvivalConstructionError::DataValidationFailed {
3147 reason: format!("non-finite entry age at row {i}"),
3148 }
3149 .into());
3150 } else if entry_age <= 0.0 {
3151 0.0
3152 } else {
3153 evaluator.evaluate(entry_age, cfg)?.0
3154 };
3155 let exit_age = age_exit[i];
3156 let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3157 if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3158 return Err(SurvivalConstructionError::DataValidationFailed {
3159 reason: evaluator.finite_error().to_string(),
3160 }
3161 .into());
3162 }
3163 Ok((e0, e1, d1))
3164 })
3165 .collect::<Result<Vec<_>, String>>()?;
3166 let mut eta_entry = Array1::<f64>::zeros(n);
3167 let mut eta_exit = Array1::<f64>::zeros(n);
3168 let mut derivative_exit = Array1::<f64>::zeros(n);
3169 for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3170 eta_entry[i] = e0;
3171 eta_exit[i] = e1;
3172 derivative_exit[i] = d1;
3173 }
3174 Ok((eta_entry, eta_exit, derivative_exit))
3175}
3176
3177pub fn build_survival_baseline_offsets(
3180 age_entry: &Array1<f64>,
3181 age_exit: &Array1<f64>,
3182 cfg: &SurvivalBaselineConfig,
3183) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3184 build_survival_offsets_with_evaluator(
3185 age_entry,
3186 age_exit,
3187 cfg,
3188 BaselineOffsetEvaluator::LogCumulativeHazard,
3189 )
3190}
3191
3192pub fn build_survival_marginal_slope_baseline_offsets(
3195 age_entry: &Array1<f64>,
3196 age_exit: &Array1<f64>,
3197 cfg: &SurvivalBaselineConfig,
3198) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3199 build_survival_offsets_with_evaluator(
3200 age_entry,
3201 age_exit,
3202 cfg,
3203 BaselineOffsetEvaluator::ProbitSurvival,
3204 )
3205}
3206
3207pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3208 matches!(
3209 inverse_link,
3210 Some(
3211 InverseLink::Standard(StandardLink::Probit)
3212 | InverseLink::LatentCLogLog(_)
3213 | InverseLink::Sas(_)
3214 | InverseLink::BetaLogistic(_)
3215 | InverseLink::Mixture(_)
3216 )
3217 )
3218}
3219
3220pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3221 match likelihood_mode {
3222 SurvivalLikelihoodMode::LocationScale
3223 | SurvivalLikelihoodMode::Latent
3224 | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3225 SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3226 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3227 }
3228}
3229
3230pub fn build_survival_time_offsets_for_likelihood(
3231 age_entry: &Array1<f64>,
3232 age_exit: &Array1<f64>,
3233 baseline_cfg: &SurvivalBaselineConfig,
3234 likelihood_mode: SurvivalLikelihoodMode,
3235 inverse_link: Option<&InverseLink>,
3236) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3237 if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3238 || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3239 && location_scale_uses_probit_survival_baseline(inverse_link))
3240 {
3241 build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3242 } else {
3243 build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3244 }
3245}
3246
3247pub fn add_survival_time_derivative_guard_offset(
3248 age_entry: &Array1<f64>,
3249 age_exit: &Array1<f64>,
3250 anchor_time: f64,
3251 derivative_guard: f64,
3252 eta_offset_entry: &mut Array1<f64>,
3253 eta_offset_exit: &mut Array1<f64>,
3254 derivative_offset_exit: &mut Array1<f64>,
3255) -> Result<(), String> {
3256 if derivative_guard <= 0.0 {
3257 return Ok(());
3258 }
3259 let n = age_entry.len();
3260 if age_exit.len() != n
3261 || eta_offset_entry.len() != n
3262 || eta_offset_exit.len() != n
3263 || derivative_offset_exit.len() != n
3264 {
3265 return Err(SurvivalConstructionError::IncompatibleDimensions {
3266 reason: "survival derivative-guard offset lengths must match".to_string(),
3267 }
3268 .into());
3269 }
3270 for i in 0..n {
3271 eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3272 eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3273 derivative_offset_exit[i] += derivative_guard;
3274 }
3275 Ok(())
3276}
3277
/// Baseline offsets for latent survival models, split by hazard loading.
///
/// Built by `build_latent_survival_baseline_offsets`.
/// - The `loaded_*` fields hold log-scale offsets of the loaded hazard component.
/// - The `unloaded_*` fields hold the additive unloaded component on the natural
///   scale; they are all zeros when the whole hazard is loaded.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    /// Log-scale cumulative-hazard offset of the loaded component at entry.
    pub loaded_eta_entry: Array1<f64>,
    /// Log-scale cumulative-hazard offset of the loaded component at exit.
    pub loaded_eta_exit: Array1<f64>,
    /// Derivative of `loaded_eta_exit` with respect to age at exit.
    pub loaded_derivative_exit: Array1<f64>,
    /// Unloaded cumulative hazard mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded cumulative hazard mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded instantaneous hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
}
3287
3288pub fn build_latent_survival_baseline_offsets(
3289 age_entry: &Array1<f64>,
3290 age_exit: &Array1<f64>,
3291 cfg: &SurvivalBaselineConfig,
3292 loading: HazardLoading,
3293) -> Result<LatentSurvivalBaselineOffsets, String> {
3294 if age_entry.len() != age_exit.len() {
3295 return Err(
3296 "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3297 );
3298 }
3299
3300 fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3301 if shape.abs() < 1e-10 {
3302 let x = shape * age;
3309 return (
3310 rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3311 rate * (1.0 + x + 0.5 * x * x),
3312 );
3313 }
3314 let shape_age = shape * age;
3315 let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3316 let instant_hazard = rate * shape_age.exp();
3317 (cumulative_hazard, instant_hazard)
3318 }
3319
3320 let n = age_entry.len();
3321
3322 let rows: Vec<[f64; 6]> = (0..n)
3325 .into_par_iter()
3326 .map(|i| -> Result<[f64; 6], String> {
3327 let entry = age_entry[i];
3328 let exit = age_exit[i];
3329 if !entry.is_finite()
3330 || !exit.is_finite()
3331 || entry <= 0.0
3332 || exit <= 0.0
3333 || exit < entry
3334 {
3335 return Err(format!(
3336 "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3337 i + 1
3338 ));
3339 }
3340 match loading {
3341 HazardLoading::Full => {
3342 let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3343 let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3344 Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3345 }
3346 HazardLoading::LoadedVsUnloaded => {
3347 if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3348 return Err(format!(
3349 "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3350 survival_baseline_targetname(cfg.target)
3351 ));
3352 }
3353 let rate = cfg.rate.ok_or_else(|| {
3354 "gompertz-makeham latent survival is missing baseline rate".to_string()
3355 })?;
3356 let shape = cfg.shape.ok_or_else(|| {
3357 "gompertz-makeham latent survival is missing baseline shape".to_string()
3358 })?;
3359 let makeham = cfg.makeham.ok_or_else(|| {
3360 "gompertz-makeham latent survival is missing baseline makeham".to_string()
3361 })?;
3362 let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3363 let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3364 if !(loaded_entry.is_finite()
3365 && loaded_entry > 0.0
3366 && loaded_exit.is_finite()
3367 && loaded_exit > 0.0
3368 && loaded_hazard.is_finite()
3369 && loaded_hazard > 0.0)
3370 {
3371 return Err(format!(
3372 "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3373 i + 1
3374 ));
3375 }
3376 Ok([
3377 loaded_entry.ln(),
3378 loaded_exit.ln(),
3379 loaded_hazard / loaded_exit,
3380 makeham * entry,
3381 makeham * exit,
3382 makeham,
3383 ])
3384 }
3385 }
3386 })
3387 .collect::<Result<Vec<_>, String>>()?;
3388
3389 let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3390 let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3391 let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3392 let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3393 let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3394 let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3395 for (i, row) in rows.into_iter().enumerate() {
3396 loaded_eta_entry[i] = row[0];
3397 loaded_eta_exit[i] = row[1];
3398 loaded_derivative_exit[i] = row[2];
3399 unloaded_mass_entry[i] = row[3];
3400 unloaded_mass_exit[i] = row[4];
3401 unloaded_hazard_exit[i] = row[5];
3402 }
3403
3404 Ok(LatentSurvivalBaselineOffsets {
3405 loaded_eta_entry,
3406 loaded_eta_exit,
3407 loaded_derivative_exit,
3408 unloaded_mass_entry,
3409 unloaded_mass_exit,
3410 unloaded_hazard_exit,
3411 })
3412}
3413
3414pub fn build_survival_timewiggle_derivative_design(
3419 eta_exit: &Array1<f64>,
3420 derivative_exit: &Array1<f64>,
3421 knots: &Array1<f64>,
3422 degree: usize,
3423) -> Result<Array2<f64>, String> {
3424 let mut design_derivative_exit =
3425 monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3426 for i in 0..design_derivative_exit.nrows() {
3427 let chain = derivative_exit[i];
3428 for j in 0..design_derivative_exit.ncols() {
3429 design_derivative_exit[[i, j]] *= chain;
3430 }
3431 }
3432 Ok(design_derivative_exit)
3433}
3434
3435pub fn build_survival_timewiggle_from_baseline(
3445 eta_entry: &Array1<f64>,
3446 eta_exit: &Array1<f64>,
3447 derivative_exit: &Array1<f64>,
3448 cfg: &LinkWiggleFormulaSpec,
3449) -> Result<SurvivalTimeWiggleBuild, String> {
3450 if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3451 return Err(
3452 "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3453 );
3454 }
3455 let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3458 && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3459 && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3460 if all_zero {
3461 return Err(
3462 "timewiggle requires a non-linear scalar survival baseline target; \
3463 the provided baseline offsets are all zero (linear baseline)"
3464 .to_string(),
3465 );
3466 }
3467 let n = eta_exit.len();
3468 let mut seed = Array1::<f64>::zeros(2 * n);
3469 for i in 0..n {
3470 seed[i] = eta_entry[i];
3471 seed[n + i] = eta_exit[i];
3472 }
3473 let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3477 let wiggle_cfg = WiggleBlockConfig {
3478 degree: cfg.degree,
3479 num_internal_knots: cfg.num_internal_knots,
3480 penalty_order: primary_order,
3481 double_penalty: cfg.double_penalty,
3482 };
3483 let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3484 append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3485 let ncols = combined_block.design.ncols();
3486 Ok(SurvivalTimeWiggleBuild {
3487 nullspace_dims: combined_block.nullspace_dims.clone(),
3488 penalties: {
3489 combined_block
3490 .penalties
3491 .into_iter()
3492 .map(|ps| ps.to_global(ncols))
3493 .collect()
3494 },
3495 knots,
3496 degree: cfg.degree,
3497 ncols,
3498 })
3499}
3500
3501pub fn append_zero_tail_columns(
3502 x_entry: &mut DesignMatrix,
3503 x_exit: &mut DesignMatrix,
3504 x_derivative: &mut DesignMatrix,
3505 tail_cols: usize,
3506) {
3507 if tail_cols == 0 {
3508 return;
3509 }
3510 fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3513 let old = dm.to_dense();
3514 let n = old.nrows();
3515 let p_base = old.ncols();
3516 let mut out = Array2::<f64>::zeros((n, p_base + tail));
3517 out.slice_mut(s![.., 0..p_base]).assign(&old);
3518 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3519 }
3520 append_dense(x_entry, tail_cols);
3521 append_dense(x_exit, tail_cols);
3522 append_dense(x_derivative, tail_cols);
3523}
3524
3525pub fn build_time_varying_survival_covariate_template(
3536 age_entry: &Array1<f64>,
3537 age_exit: &Array1<f64>,
3538 time_k: usize,
3539 time_degree: usize,
3540 block_name: &str,
3541) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3542 if time_k < time_degree + 1 {
3543 return Err(format!(
3544 "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3545 time_degree + 1
3546 ));
3547 }
3548 let num_internal_knots = time_k - (time_degree + 1);
3549
3550 let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3551 let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3552
3553 let time_spec = BSplineBasisSpec {
3554 degree: time_degree,
3555 penalty_order: 2,
3556 knotspec: BSplineKnotSpec::Automatic {
3557 num_internal_knots: Some(num_internal_knots),
3558 placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3559 },
3560 double_penalty: false,
3561 identifiability: BSplineIdentifiability::None,
3562 boundary: OneDimensionalBoundary::Open,
3563 boundary_conditions: BSplineBoundaryConditions::default(),
3564 };
3565
3566 let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3567 .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3568 let time_design_exit = time_build.design.to_dense();
3569
3570 let knots = match &time_build.metadata {
3571 BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3572 _ => {
3573 return Err(format!(
3574 "{block_name} time-margin basis returned unexpected metadata type"
3575 ));
3576 }
3577 };
3578
3579 let time_build_entry = build_bspline_basis_1d(
3580 log_entry.view(),
3581 &BSplineBasisSpec {
3582 degree: time_degree,
3583 penalty_order: 2,
3584 knotspec: BSplineKnotSpec::Provided(knots.clone()),
3585 double_penalty: false,
3586 identifiability: BSplineIdentifiability::None,
3587 boundary: OneDimensionalBoundary::Open,
3588 boundary_conditions: BSplineBoundaryConditions::default(),
3589 },
3590 )
3591 .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3592 let time_design_entry = time_build_entry.design.to_dense();
3593 let p_time = time_design_exit.ncols();
3594 let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3595 time_design_derivative_exit
3599 .as_slice_mut()
3600 .expect("zeros are contiguous")
3601 .par_chunks_mut(p_time)
3602 .enumerate()
3603 .try_for_each(|(i, row_out)| -> Result<(), String> {
3604 let mut deriv_buf = vec![0.0_f64; p_time];
3605 evaluate_bspline_derivative_scalar(
3606 log_exit[i],
3607 knots.view(),
3608 time_degree,
3609 &mut deriv_buf,
3610 )
3611 .map_err(|e| {
3612 format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3613 })?;
3614 let chain = 1.0 / age_exit[i].max(1e-12);
3615 for j in 0..p_time {
3616 row_out[j] = deriv_buf[j] * chain;
3617 }
3618 Ok(())
3619 })?;
3620
3621 Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3622 time_basis_entry: time_design_entry,
3623 time_basis_exit: time_design_exit,
3624 time_basis_derivative_exit: time_design_derivative_exit,
3625 time_penalties: time_build.penalties,
3626 })
3627}
3628
3629#[cfg(test)]
3630mod tests {
3631 use super::{
3632 SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3633 baseline_chain_rule_gradient, baseline_offset_theta_partials,
3634 build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3635 build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3636 evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3637 gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3638 marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3639 marginal_slope_baseline_offset_theta_partials,
3640 optimize_survival_baseline_config_with_gradient,
3641 optimize_survival_baseline_config_with_gradient_only,
3642 resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3643 survival_baseline_theta_from_config,
3644 };
3645 use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3646 use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3647 use crate::probability::normal_cdf;
3648 use ndarray::{Array1, Array2, array};
3649
3650 #[test]
3651 fn survival_timewiggle_keeps_requested_order_one_penalty() {
3652 let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3653 let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3654 let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3655 let cfg = LinkWiggleFormulaSpec {
3656 degree: 3,
3657 num_internal_knots: 4,
3658 penalty_orders: vec![1, 2, 3],
3659 double_penalty: false,
3660 };
3661
3662 let build =
3663 build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3664 .expect("build survival timewiggle");
3665
3666 assert_eq!(build.penalties.len(), 3);
3667 assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3668 assert!(build.ncols > 0);
3669 }
3670
3671 #[test]
3672 fn marginal_slope_time_anchor_defaults_to_median_exit() {
3673 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3674 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3675 let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3676 .expect("resolve marginal-slope default time anchor");
3677
3678 assert!(
3679 (anchor - 19.0).abs() <= 1e-12,
3680 "marginal-slope default anchor should be median exit, got {anchor}"
3681 );
3682 }
3683
3684 #[test]
3685 fn marginal_slope_time_anchor_honors_explicit_value() {
3686 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3687 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3688 let anchor =
3689 resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3690 .expect("resolve explicit marginal-slope time anchor");
3691
3692 assert!(
3693 (anchor - 7.5).abs() <= 1e-12,
3694 "explicit marginal-slope anchor should round-trip, got {anchor}"
3695 );
3696 }
3697
    /// The gradient-only and gradient+Hessian baseline optimizers must recover
    /// the same minimiser on a shared synthetic surface.
    ///
    /// The surface is the quadratic `½ (θ - θ*)ᵀ A (θ - θ*)` with the symmetric
    /// positive-definite `A = [[3, 0.5], [0.5, 2]]` and known minimiser
    /// `θ* = (ln 2.5, ln 1.3)`. Both optimizers start from a Weibull config with
    /// `scale = shape = 1`. Each must land within 2e-3 (max-abs) of `θ*`, and
    /// the two results must agree with each other to the same tolerance.
    #[test]
    fn baseline_optimizer_contracts_agree_on_shared_surface() {
        let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
        // NOTE(review): presumably Weibull θ = (ln scale, ln shape); the
        // ordering is defined by `survival_baseline_theta_from_config` — confirm.
        let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];

        let initial = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Weibull,
            scale: Some(1.0),
            shape: Some(1.0),
            rate: None,
            makeham: None,
        };

        // Maps an optimizer result back to θ so it can be compared with θ*.
        let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
            survival_baseline_theta_from_config(cfg)
                .expect("config→θ")
                .expect("Weibull config has a θ")
        };

        // Every `move` closure below owns its own copies of A and θ*.
        let curvature_cost = curvature.clone();
        let star_cost = theta_star.clone();
        // Quadratic cost ½ dᵀ A d with d = θ(cfg) - θ*.
        let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
            let theta = survival_baseline_theta_from_config(cfg)?
                .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
            let d = &theta - &star_cost;
            let ad = curvature_cost.dot(&d);
            Ok(0.5 * d.dot(&ad))
        };

        // Gradient-only contract: the callback returns (cost, ∇ = A (θ - θ*)).
        let curvature_grad = curvature.clone();
        let star_grad = theta_star.clone();
        let cost_for_grad = cost_at.clone();
        let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
            &initial,
            "baseline parity (gradient-only)",
            move |cfg| {
                let cost = cost_for_grad(cfg)?;
                let theta = survival_baseline_theta_from_config(cfg)?
                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
                let gradient = curvature_grad.dot(&(&theta - &star_grad));
                Ok((cost, gradient))
            },
        )
        .expect("gradient-only baseline optimization converges");

        // Gradient+Hessian contract: the same cost and gradient, plus the exact,
        // constant Hessian A.
        let curvature_hess = curvature.clone();
        let star_hess = theta_star.clone();
        let cost_for_hess = cost_at.clone();
        let result_grad_hess = optimize_survival_baseline_config_with_gradient(
            &initial,
            "baseline parity (gradient+Hessian)",
            move |cfg| {
                let cost = cost_for_hess(cfg)?;
                let theta = survival_baseline_theta_from_config(cfg)?
                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
                let gradient = curvature_hess.dot(&(&theta - &star_hess));
                Ok((cost, gradient, curvature_hess.clone()))
            },
        )
        .expect("gradient+Hessian baseline optimization converges");

        let theta_grad_only = recovered_theta(&result_grad_only);
        let theta_grad_hess = recovered_theta(&result_grad_hess);

        // Each contract must independently recover θ* (max-abs error).
        for (label, theta) in [
            ("gradient-only", &theta_grad_only),
            ("gradient+Hessian", &theta_grad_hess),
        ] {
            let err = (theta - &theta_star)
                .mapv(f64::abs)
                .fold(0.0_f64, |a, &v| a.max(v));
            assert!(
                err <= 2e-3,
                "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
            );
        }

        // ...and the two contracts must agree with each other.
        let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
            (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
        };
        assert!(
            pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
            "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
        );
    }
3808
3809 #[test]
3810 fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
3811 let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3812 let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3813 let requested_degree = 3;
3814 let num_internal_knots = 1;
3815
3816 let built = build_survival_time_basis(
3817 &age_entry,
3818 &age_exit,
3819 SurvivalTimeBasisConfig::ISpline {
3820 degree: requested_degree,
3821 knots: Array1::zeros(0),
3822 keep_cols: Vec::new(),
3823 smooth_lambda: 1e-2,
3824 },
3825 Some((num_internal_knots, 1e-2)),
3826 )
3827 .expect("automatic cubic ispline with one interior knot builds");
3828
3829 let working_degree = requested_degree + 1;
3830 let knots = built.knots.expect("resolved ispline knots");
3831 assert_eq!(
3832 knots.len(),
3833 num_internal_knots + 2 * (working_degree + 1),
3834 "I-spline automatic knots must be clamped for the working B-spline degree"
3835 );
3836 assert_eq!(built.degree, Some(requested_degree));
3837 assert!(built.x_exit_time.ncols() > 0);
3838 assert_eq!(built.x_entry_time.ncols(), built.x_exit_time.ncols());
3839 assert_eq!(built.x_derivative_time.ncols(), built.x_exit_time.ncols());
3840 }
3841
3842 #[test]
3843 fn ispline_time_derivative_is_nonzero_at_right_boundary() {
3844 let age_entry = array![1.0_f64, 1.0, 1.0];
3845 let age_exit = array![4.0_f64, 4.0, 4.0];
3846 let left = 1.0_f64.ln();
3847 let right = 4.0_f64.ln();
3848 let mid = left + 0.5 * (right - left);
3849 let knots = array![left, left, left, left, mid, right, right, right, right];
3850
3851 let built = build_survival_time_basis(
3852 &age_entry,
3853 &age_exit,
3854 SurvivalTimeBasisConfig::ISpline {
3855 degree: 2,
3856 knots,
3857 keep_cols: Vec::new(),
3858 smooth_lambda: 1e-2,
3859 },
3860 None,
3861 )
3862 .expect("build right-boundary ispline time basis");
3863
3864 let derivative = built.x_derivative_time.as_dense_cow();
3865 let max_abs = derivative.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
3866 assert!(
3867 max_abs > 1e-8,
3868 "right-boundary I-spline derivative must use the left-hand endpoint slope"
3869 );
3870 for row in derivative.rows() {
3871 assert!(
3872 row.iter().any(|v| *v > 1e-8),
3873 "each row at the right boundary needs a positive hazard derivative"
3874 );
3875 }
3876 }
3877
3878 #[test]
3879 fn ispline_time_penalty_is_psd_under_nontrivial_keep_cols() {
3880 let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3899 let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3900 let left = 1.0_f64.ln();
3901 let right = 21.0_f64.ln();
3902 let q1 = left + 0.25 * (right - left);
3903 let mid = left + 0.5 * (right - left);
3904 let q3 = left + 0.75 * (right - left);
3905 let knots = array![
3909 left, left, left, left, q1, mid, q3, right, right, right, right
3910 ];
3911
3912 let full = build_survival_time_basis(
3914 &age_entry,
3915 &age_exit,
3916 SurvivalTimeBasisConfig::ISpline {
3917 degree: 2,
3918 knots: knots.clone(),
3919 keep_cols: Vec::new(),
3920 smooth_lambda: 1e-2,
3921 },
3922 None,
3923 )
3924 .expect("build full-width ispline time basis");
3925 let p_time_full = full
3926 .keep_cols
3927 .as_ref()
3928 .map(|k| k.len())
3929 .unwrap_or_else(|| full.x_exit_time.ncols());
3930 assert!(
3931 p_time_full >= 3,
3932 "test needs at least 3 shape-varying columns to drop an interior one; got {p_time_full}"
3933 );
3934
3935 let keep_cols: Vec<usize> = (0..p_time_full).filter(|&j| j != 1).collect();
3938
3939 let built = build_survival_time_basis(
3940 &age_entry,
3941 &age_exit,
3942 SurvivalTimeBasisConfig::ISpline {
3943 degree: 2,
3944 knots,
3945 keep_cols: keep_cols.clone(),
3946 smooth_lambda: 1e-2,
3947 },
3948 None,
3949 )
3950 .expect(
3951 "reduced ispline penalty must build (PSD contract must accept the \
3952 congruence-first / select-second ordering)",
3953 );
3954
3955 assert_eq!(
3956 built.penalties.len(),
3957 1,
3958 "the ispline time basis should carry exactly one curvature penalty"
3959 );
3960 let s = &built.penalties[0];
3961 assert_eq!(s.nrows(), keep_cols.len());
3962 assert_eq!(s.ncols(), keep_cols.len());
3963
3964 let (evals, _) =
3965 gam_linalg::faer_ndarray::FaerEigh::eigh(s, faer::Side::Lower).expect("eigh of penalty");
3966 let evals_slice = evals.as_slice().expect("contiguous eigenvalues");
3967 let max_abs = evals_slice
3968 .iter()
3969 .copied()
3970 .fold(0.0_f64, |a, b| a.max(b.abs()))
3971 .max(1.0);
3972 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
3973 let tol = -100.0 * (s.nrows() as f64) * f64::EPSILON * max_abs;
3974 assert!(
3975 min_ev >= tol,
3976 "reduced I-spline time penalty must be PSD (gam#979): min eigenvalue \
3977 {min_ev:.3e} < tol {tol:.3e}, max|eig| {max_abs:.3e}"
3978 );
3979 }
3980
3981 #[test]
3982 fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
3983 let cfg = SurvivalBaselineConfig {
3984 target: SurvivalBaselineTarget::GompertzMakeham,
3985 scale: None,
3986 shape: Some(0.07),
3987 rate: Some(0.012),
3988 makeham: Some(0.003),
3989 };
3990 let age = 11.5;
3991 let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
3992 .expect("evaluate marginal-slope gompertz-makeham baseline");
3993 let shape = cfg.shape.expect("shape");
3994 let rate = cfg.rate.expect("rate");
3995 let makeham = cfg.makeham.expect("makeham");
3996 let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
3997 let instant_hazard = makeham + rate * (shape * age).exp();
3998 let expected_survival = (-cumulative_hazard).exp();
3999 let actual_survival = normal_cdf(-q);
4000 assert!((actual_survival - expected_survival).abs() <= 1e-12);
4001
4002 let h = 1e-5;
4003 let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
4004 .expect("q plus")
4005 .0;
4006 let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
4007 .expect("q minus")
4008 .0;
4009 let fd = (q_plus - q_minus) / (2.0 * h);
4010 assert!((q_derivative - fd).abs() <= 1e-7);
4011 assert!(instant_hazard > 0.0);
4012 }
4013
4014 #[test]
4015 fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
4016 let configs = [
4025 SurvivalBaselineConfig {
4026 target: SurvivalBaselineTarget::Linear,
4027 scale: None,
4028 shape: None,
4029 rate: None,
4030 makeham: None,
4031 },
4032 SurvivalBaselineConfig {
4033 target: SurvivalBaselineTarget::Weibull,
4034 scale: Some(2.5),
4035 shape: Some(1.3),
4036 rate: None,
4037 makeham: None,
4038 },
4039 SurvivalBaselineConfig {
4040 target: SurvivalBaselineTarget::Gompertz,
4041 scale: None,
4042 shape: Some(0.05),
4043 rate: Some(0.01),
4044 makeham: None,
4045 },
4046 SurvivalBaselineConfig {
4047 target: SurvivalBaselineTarget::GompertzMakeham,
4048 scale: None,
4049 shape: Some(0.07),
4050 rate: Some(0.012),
4051 makeham: Some(0.003),
4052 },
4053 ];
4054 for cfg in &configs {
4055 let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
4058 .expect("marginal-slope baseline must be evaluable at the origin");
4059 assert_eq!(q0, 0.0);
4060 assert_eq!(q0_derivative, 0.0);
4061
4062 let (eta0, eta0_derivative) =
4066 evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
4067 assert!(eta0_derivative.is_finite());
4068 assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);
4069
4070 let age_entry = array![0.0, 0.0];
4074 let age_exit = array![0.0, 1.5];
4075 let (entry, exit, derivative) =
4076 build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
4077 .expect("probit baseline offsets must build through the origin");
4078 assert!(entry.iter().all(|v| v.is_finite()));
4079 assert!(exit.iter().all(|v| v.is_finite()));
4080 assert!(derivative.iter().all(|v| v.is_finite()));
4081 assert_eq!(exit[0], 0.0);
4083 }
4084 }
4085
4086 #[test]
4087 fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
4088 let cfg = SurvivalBaselineConfig {
4089 target: SurvivalBaselineTarget::GompertzMakeham,
4090 scale: None,
4091 shape: Some(0.03),
4092 rate: Some(0.01),
4093 makeham: Some(0.002),
4094 };
4095 let age_entry = array![2.0, 4.0];
4096 let age_exit = array![5.0, 9.0];
4097 let (entry, exit, derivative) =
4098 build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
4099 .expect("marginal-slope baseline offsets");
4100 for i in 0..age_entry.len() {
4101 let entry_h = cfg.makeham.expect("makeham") * age_entry[i]
4102 + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4103 * ((cfg.shape.expect("shape") * age_entry[i]).exp() - 1.0);
4104 let exit_h = cfg.makeham.expect("makeham") * age_exit[i]
4105 + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4106 * ((cfg.shape.expect("shape") * age_exit[i]).exp() - 1.0);
4107 assert!((normal_cdf(-entry[i]) - (-entry_h).exp()).abs() <= 1e-12);
4108 assert!((normal_cdf(-exit[i]) - (-exit_h).exp()).abs() <= 1e-12);
4109 assert!(derivative[i].is_finite() && derivative[i] > 0.0);
4110 }
4111 }
4112
4113 fn fd_marginal_slope_baseline_offset(
4114 age: f64,
4115 cfg: &SurvivalBaselineConfig,
4116 steps: &[f64],
4117 ) -> Vec<(f64, f64)> {
4118 let theta = survival_baseline_theta_from_config(cfg)
4119 .expect("theta")
4120 .expect("non-linear baseline");
4121 assert_eq!(
4122 steps.len(),
4123 theta.len(),
4124 "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
4125 );
4126 (0..theta.len())
4127 .map(|k| {
4128 let h = steps[k];
4129 let mut theta_plus = theta.clone();
4130 theta_plus[k] += h;
4131 let mut theta_minus = theta.clone();
4132 theta_minus[k] -= h;
4133 let cfg_plus =
4134 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4135 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4136 .expect("minus cfg");
4137 let (q_p, qt_p) =
4138 evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
4139 let (q_m, qt_m) =
4140 evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
4141 ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
4142 })
4143 .collect()
4144 }
4145
4146 #[test]
4147 fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
4148 let cfg = SurvivalBaselineConfig {
4149 target: SurvivalBaselineTarget::GompertzMakeham,
4150 scale: None,
4151 shape: Some(0.04),
4152 rate: Some(0.013),
4153 makeham: Some(0.002),
4154 };
4155 let age = 17.0;
4156 let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4157 .expect("partials")
4158 .expect("nonlinear");
4159 let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-5, 1e-5]);
4160 assert_eq!(analytic.len(), fd.len());
4161 for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4162 assert_close(*aq, *fq, 1e-6, &format!("gm-probit q theta[{k}]"));
4163 assert_close(*aqt, *fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
4164 }
4165 }
4166
4167 #[test]
4168 fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
4169 let cfg = SurvivalBaselineConfig {
4170 target: SurvivalBaselineTarget::GompertzMakeham,
4171 scale: None,
4172 shape: Some(1e-14),
4173 rate: Some(0.013),
4174 makeham: Some(0.002),
4175 };
4176 let age = 17.0;
4177 let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4178 .expect("partials")
4179 .expect("nonlinear");
4180 let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-11, 1e-5]);
4181 assert_eq!(analytic.len(), fd.len());
4182 for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4183 assert_close(*aq, *fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
4184 assert_close(
4185 *aqt,
4186 *fqt,
4187 1e-5,
4188 &format!("near-zero gm-probit q' theta[{k}]"),
4189 );
4190 }
4191 }
4192
    /// Builds the channel residuals of a locally quadratic test objective at
    /// `candidate_cfg`, expanded around `base_cfg`.
    ///
    /// For each row, the three probit channels `(entry q, exit q, exit q')` are
    /// evaluated under both configs. Their difference `Δ` is pushed through
    /// that row's 3×3 block in `curvatures`:
    /// `residual(candidate) = residual(base) + C_row · Δ`.
    ///
    /// Rows whose base entry residual is exactly 0.0 are treated as having no
    /// entry channel: their entry `Δ` is forced to 0 and their entry residual
    /// is left untouched. The `right` channel is copied unchanged.
    ///
    /// Used by the chain-rule Hessian FD test to make residuals move with θ.
    fn shifted_quadratic_offset_residuals(
        age_entry: ndarray::ArrayView1<'_, f64>,
        age_exit: ndarray::ArrayView1<'_, f64>,
        base_cfg: &SurvivalBaselineConfig,
        candidate_cfg: &SurvivalBaselineConfig,
        base: &OffsetChannelResiduals,
        curvatures: &OffsetChannelCurvatures,
    ) -> OffsetChannelResiduals {
        let n = age_exit.len();
        let mut entry = base.entry.clone();
        let mut exit = base.exit.clone();
        let mut derivative = base.derivative.clone();
        for row in 0..n {
            // Exit-age channels (q, q') under the base and candidate configs.
            let (_, base_exit, base_deriv) =
                baseline_marginal_slope_channels(age_exit[row], base_cfg);
            let (_, cand_exit, cand_deriv) =
                baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
            // Entry channel only exists for rows with a non-zero base entry
            // residual; otherwise it contributes a zero Δ.
            let base_entry = if base.entry[row] == 0.0 {
                0.0
            } else {
                baseline_marginal_slope_channels(age_entry[row], base_cfg).1
            };
            let cand_entry = if base.entry[row] == 0.0 {
                0.0
            } else {
                baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
            };
            let delta = [
                cand_entry - base_entry,
                cand_exit - base_exit,
                cand_deriv - base_deriv,
            ];
            // shift = C_row · Δ (dense 3×3 matrix-vector product).
            let mut shift = [0.0; 3];
            for i in 0..3 {
                for j in 0..3 {
                    shift[i] += curvatures.rows[row][i][j] * delta[j];
                }
            }
            // Keep gated (no-entry) rows' entry residual exactly at its base value.
            if base.entry[row] != 0.0 {
                entry[row] += shift[0];
            }
            exit[row] += shift[1];
            derivative[row] += shift[2];
        }
        OffsetChannelResiduals {
            entry,
            exit,
            derivative,
            right: base.right.clone(),
        }
    }
4244
4245 fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4246 let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4247 (q, q, q_t)
4248 }
4249
    /// The analytic θ-Hessian from `marginal_slope_baseline_chain_rule_hessian`
    /// must match a central finite difference of the chain-rule gradient.
    ///
    /// The FD runs on a synthetic objective whose channel residuals move with θ
    /// according to `shifted_quadratic_offset_residuals`: the base residuals
    /// plus `C_row · Δchannels(θ)`, using the per-row 3×3 curvature blocks
    /// below. Row 1 has entry age 0 and a zero entry residual, so its entry
    /// channel is skipped.
    #[test]
    fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(0.025),
            rate: Some(0.012),
            makeham: Some(0.003),
        };
        let theta = survival_baseline_theta_from_config(&cfg)
            .expect("theta")
            .expect("nonlinear");
        let age_entry = array![2.5, 0.0, 5.0];
        let age_exit = array![7.5, 11.0, 15.0];
        let base_residuals = OffsetChannelResiduals {
            entry: array![0.2, 0.0, -0.1],
            exit: array![0.6, -0.3, 0.4],
            derivative: array![-0.5, 0.25, 0.15],
            right: Array1::<f64>::zeros(3),
        };
        // One symmetric 3×3 block per row, ordered (entry, exit, derivative).
        let curvatures = OffsetChannelCurvatures {
            rows: vec![
                [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
                [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
                [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
            ],
        };
        let analytic = marginal_slope_baseline_chain_rule_hessian(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &base_residuals,
            &curvatures,
        )
        .expect("hessian")
        .expect("nonlinear");

        // Chain-rule gradient at a candidate θ, with residuals shifted
        // consistently with the quadratic model around the base config.
        let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
            let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
                .expect("candidate cfg");
            let residuals = shifted_quadratic_offset_residuals(
                age_entry.view(),
                age_exit.view(),
                &cfg,
                &candidate,
                &base_residuals,
                &curvatures,
            );
            marginal_slope_baseline_chain_rule_gradient(
                age_entry.view(),
                age_exit.view(),
                &candidate,
                &residuals,
            )
            .expect("gradient")
            .expect("nonlinear")
        };

        // Column j of the Hessian ≈ central difference of the gradient along θ[j];
        // θ[1] uses a slightly larger step (2e-5) than the others (1e-5).
        for j in 0..theta.len() {
            let step = if j == 1 { 2e-5 } else { 1e-5 };
            let mut plus = theta.clone();
            plus[j] += step;
            let mut minus = theta.clone();
            minus[j] -= step;
            let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
            for i in 0..theta.len() {
                assert_close(
                    analytic[[i, j]],
                    fd_col[i],
                    2e-5,
                    &format!("baseline Hessian ({i},{j})"),
                );
            }
        }
    }
4325
4326 #[test]
4327 fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
4328 let cfg = SurvivalBaselineConfig {
4329 target: SurvivalBaselineTarget::GompertzMakeham,
4330 scale: None,
4331 shape: Some(0.03),
4332 rate: Some(0.01),
4333 makeham: Some(0.002),
4334 };
4335 let age_entry = array![3.0, 6.0];
4336 let age_exit = array![8.0, 12.0];
4337 let residuals = OffsetChannelResiduals {
4338 exit: array![0.7, -0.2],
4339 entry: array![0.1, 0.4],
4340 derivative: array![1.3, -0.6],
4341 right: Array1::<f64>::zeros(2),
4342 };
4343 let grad = marginal_slope_baseline_chain_rule_gradient(
4344 age_entry.view(),
4345 age_exit.view(),
4346 &cfg,
4347 &residuals,
4348 )
4349 .expect("gradient")
4350 .expect("nonlinear");
4351
4352 let mut expected = Array1::<f64>::zeros(3);
4353 for i in 0..age_exit.len() {
4354 let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
4355 .expect("exit partials")
4356 .expect("nonlinear");
4357 let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
4358 .expect("entry partials")
4359 .expect("nonlinear");
4360 for k in 0..3 {
4361 expected[k] += residuals.exit[i] * exit_partials[k].0
4362 + residuals.derivative[i] * exit_partials[k].1
4363 + residuals.entry[i] * entry_partials[k].0;
4364 }
4365 }
4366 for k in 0..3 {
4367 assert_close(
4368 grad[k],
4369 expected[k],
4370 1e-12,
4371 &format!("gm-probit chain gradient theta[{k}]"),
4372 );
4373 }
4374 }
4375
    /// Both chain-rule gradient engines must reproduce exactly (`tol = 0.0`) a
    /// straightforward inline contraction of per-age θ-partials against the
    /// channel residuals. The two engines are:
    /// - `baseline_chain_rule_gradient` for the log-cumulative-hazard offsets
    ///   (labelled "rp" in the messages);
    /// - `marginal_slope_baseline_chain_rule_gradient` for the probit offsets.
    ///
    /// Row 1 has entry age 0 with a zero entry residual, so the reference
    /// skips its entry partials.
    #[test]
    fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(0.028),
            rate: Some(0.011),
            makeham: Some(0.0025),
        };
        let age_entry = array![3.0, 0.0, 5.5];
        let age_exit = array![8.0, 12.0, 16.0];
        let residuals = OffsetChannelResiduals {
            exit: array![0.7, -0.2, 0.45],
            entry: array![0.1, 0.0, -0.3],
            derivative: array![1.3, -0.6, 0.2],
            right: Array1::<f64>::zeros(3),
        };

        // Reference contraction, generic over the per-age θ-partials function
        // (which returns one (∂value, ∂derivative) pair per θ component).
        let reference_gradient = |partials: &dyn Fn(
            f64,
            &SurvivalBaselineConfig,
        )
            -> Result<Option<Vec<(f64, f64)>>, String>|
         -> Array1<f64> {
            // θ dimension probed from the first exit age.
            let theta_dim = partials(age_exit[0], &cfg)
                .expect("probe partials")
                .expect("nonlinear")
                .len();
            let mut acc = Array1::<f64>::zeros(theta_dim);
            for i in 0..age_exit.len() {
                // Exit-age terms: value and derivative channels.
                let p_exit = partials(age_exit[i], &cfg)
                    .expect("exit partials")
                    .expect("nonlinear");
                let r_x = residuals.exit[i];
                let r_d = residuals.derivative[i];
                for k in 0..theta_dim {
                    acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
                }
                // Entry-age term, only for rows with a non-zero entry residual.
                let r_e = residuals.entry[i];
                if r_e != 0.0 {
                    let p_entry = partials(age_entry[i], &cfg)
                        .expect("entry partials")
                        .expect("nonlinear");
                    for k in 0..theta_dim {
                        acc[k] += r_e * p_entry[k].0;
                    }
                }
            }
            acc
        };

        // Log-cumulative-hazard ("rp") engine vs. the reference.
        let rp_engine = baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            age_exit.view(),
            &cfg,
            &residuals,
        )
        .expect("rp gradient")
        .expect("rp nonlinear");
        let rp_reference = reference_gradient(&baseline_offset_theta_partials);
        assert_eq!(rp_engine.len(), rp_reference.len());
        for k in 0..rp_engine.len() {
            assert_close(
                rp_engine[k],
                rp_reference[k],
                0.0,
                &format!("rp engine vs inline reference theta[{k}]"),
            );
        }

        // Probit (marginal-slope) engine vs. the reference.
        let probit_engine = marginal_slope_baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &residuals,
        )
        .expect("probit gradient")
        .expect("probit nonlinear");
        let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
        assert_eq!(probit_engine.len(), probit_reference.len());
        for k in 0..probit_engine.len() {
            assert_close(
                probit_engine[k],
                probit_reference[k],
                0.0,
                &format!("probit engine vs inline reference theta[{k}]"),
            );
        }
    }
4481
4482 #[test]
4503 fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4504 let cfg = SurvivalBaselineConfig {
4505 target: SurvivalBaselineTarget::GompertzMakeham,
4506 scale: None,
4507 shape: Some(0.05),
4508 rate: Some(0.012),
4509 makeham: Some(0.003),
4510 };
4511 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4513 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4514 let residuals = OffsetChannelResiduals {
4517 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4518 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4519 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4520 right: Array1::<f64>::zeros(8),
4521 };
4522
4523 let analytic = baseline_chain_rule_gradient(
4524 age_entry.view(),
4525 age_exit.view(),
4526 age_exit.view(),
4527 &cfg,
4528 &residuals,
4529 )
4530 .expect("analytic gradient ok")
4531 .expect("GM baseline has a θ-gradient");
4532 assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4533
4534 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4540 let mut acc = 0.0;
4541 for i in 0..age_exit.len() {
4542 let (eta_exit_i, od_exit_i) =
4543 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4544 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4545 if residuals.entry[i] != 0.0 {
4546 let (eta_entry_i, _) =
4547 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4548 acc += residuals.entry[i] * eta_entry_i;
4549 }
4550 }
4551 acc
4552 };
4553
4554 let theta0 = survival_baseline_theta_from_config(&cfg)
4555 .expect("theta seed")
4556 .expect("GM has θ");
4557 let delta = 1e-4;
4559 let mut fd = Array1::<f64>::zeros(analytic.len());
4560 for k in 0..analytic.len() {
4561 let mut theta_plus = theta0.clone();
4562 theta_plus[k] += delta;
4563 let mut theta_minus = theta0.clone();
4564 theta_minus[k] -= delta;
4565 let cfg_plus =
4566 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4567 let cfg_minus =
4568 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4569 let lp = loss_at_cfg(&cfg_plus);
4570 let lm = loss_at_cfg(&cfg_minus);
4571 fd[k] = (lp - lm) / (2.0 * delta);
4572 }
4573
4574 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4575 let max_err = analytic
4576 .iter()
4577 .zip(fd.iter())
4578 .map(|(a, b)| (a - b).abs())
4579 .fold(0.0_f64, f64::max);
4580 let rel = max_err / (analytic_norm + 1e-12);
4581 eprintln!(
4583 "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4584 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4585 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4586 );
4587 assert!(
4588 rel < 1e-2,
4589 "analytic θ-gradient disagrees with central FD beyond 1%: \
4590 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4591 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4592 );
4593 }
4594
4595 #[test]
4610 fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4611 let cfg = SurvivalBaselineConfig {
4612 target: SurvivalBaselineTarget::Weibull,
4613 scale: Some(11.0),
4614 shape: Some(1.4),
4615 rate: None,
4616 makeham: None,
4617 };
4618 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4619 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4620 let residuals = OffsetChannelResiduals {
4621 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4622 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4623 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4624 right: Array1::<f64>::zeros(8),
4625 };
4626
4627 let analytic = baseline_chain_rule_gradient(
4628 age_entry.view(),
4629 age_exit.view(),
4630 age_exit.view(),
4631 &cfg,
4632 &residuals,
4633 )
4634 .expect("analytic gradient ok")
4635 .expect("Weibull baseline has a θ-gradient");
4636 assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4637
4638 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4639 let mut acc = 0.0;
4640 for i in 0..age_exit.len() {
4641 let (eta_exit_i, od_exit_i) =
4642 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4643 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4644 if residuals.entry[i] != 0.0 {
4645 let (eta_entry_i, _) =
4646 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4647 acc += residuals.entry[i] * eta_entry_i;
4648 }
4649 }
4650 acc
4651 };
4652
4653 let theta0 = survival_baseline_theta_from_config(&cfg)
4654 .expect("theta seed")
4655 .expect("Weibull has θ");
4656 let delta = 1e-4;
4657 let mut fd = Array1::<f64>::zeros(analytic.len());
4658 for k in 0..analytic.len() {
4659 let mut theta_plus = theta0.clone();
4660 theta_plus[k] += delta;
4661 let mut theta_minus = theta0.clone();
4662 theta_minus[k] -= delta;
4663 let cfg_plus =
4664 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4665 let cfg_minus =
4666 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4667 let lp = loss_at_cfg(&cfg_plus);
4668 let lm = loss_at_cfg(&cfg_minus);
4669 fd[k] = (lp - lm) / (2.0 * delta);
4670 }
4671
4672 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4673 let max_err = analytic
4674 .iter()
4675 .zip(fd.iter())
4676 .map(|(a, b)| (a - b).abs())
4677 .fold(0.0_f64, f64::max);
4678 let rel = max_err / (analytic_norm + 1e-12);
4679 eprintln!(
4680 "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4681 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4682 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4683 );
4684 assert!(
4685 rel < 1e-2,
4686 "analytic θ-gradient disagrees with central FD beyond 1%: \
4687 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4688 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4689 );
4690 }
4691
4692 fn fd_baseline_offset(
4705 age: f64,
4706 cfg: &SurvivalBaselineConfig,
4707 steps: &[f64],
4708 ) -> Vec<(f64, f64)> {
4709 let theta = survival_baseline_theta_from_config(cfg)
4710 .expect("theta")
4711 .expect("non-linear baseline");
4712 assert_eq!(
4713 steps.len(),
4714 theta.len(),
4715 "fd_baseline_offset: step vector length must match θ dimension"
4716 );
4717 (0..theta.len())
4718 .map(|k| {
4719 let h = steps[k];
4720 let mut theta_plus = theta.clone();
4721 theta_plus[k] += h;
4722 let mut theta_minus = theta.clone();
4723 theta_minus[k] -= h;
4724 let cfg_plus =
4725 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4726 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4727 .expect("minus cfg");
4728 let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4729 let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4730 ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4731 })
4732 .collect()
4733 }
4734
4735 fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
4736 let ok = if expected.abs() < 1.0 {
4740 (actual - expected).abs() <= tol
4741 } else {
4742 (actual - expected).abs() <= tol * expected.abs().max(1.0)
4743 };
4744 assert!(
4745 ok,
4746 "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
4747 );
4748 }
4749
4750 #[test]
4751 fn gompertz_offset_partials_match_central_diff() {
4752 let cases = [
4756 (0.5_f64, 0.01_f64, 30.0_f64),
4757 (0.2, 0.05, 60.0),
4758 (1.0, 0.001, 10.0),
4759 (0.4, 5e-11, 25.0),
4760 (0.4, -5e-11, 25.0),
4761 (0.3, -0.02, 40.0),
4762 (0.8, 0.2, 5.0),
4763 ];
4764 for &(rate, shape, age) in &cases {
4765 let cfg = SurvivalBaselineConfig {
4766 target: SurvivalBaselineTarget::Gompertz,
4767 scale: None,
4768 shape: Some(shape),
4769 rate: Some(rate),
4770 makeham: None,
4771 };
4772 let analytic = baseline_offset_theta_partials(age, &cfg)
4773 .expect("ok")
4774 .expect("non-linear");
4775 let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4781 let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
4782 assert_eq!(analytic.len(), 2);
4783 assert_close(
4785 analytic[0].0,
4786 fd[0].0,
4787 1e-7,
4788 &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4789 );
4790 assert_close(
4791 analytic[0].1,
4792 fd[0].1,
4793 1e-7,
4794 &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4795 );
4796 assert_close(
4799 analytic[1].0,
4800 fd[1].0,
4801 1e-5,
4802 &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
4803 );
4804 assert_close(
4805 analytic[1].1,
4806 fd[1].1,
4807 1e-5,
4808 &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
4809 );
4810 }
4811 }
4812
4813 #[test]
4814 fn gompertz_offset_partials_log_rate_channel_is_trivial() {
4815 let cfg = SurvivalBaselineConfig {
4819 target: SurvivalBaselineTarget::Gompertz,
4820 scale: None,
4821 shape: Some(0.05),
4822 rate: Some(0.3),
4823 makeham: None,
4824 };
4825 let partials = baseline_offset_theta_partials(42.0, &cfg)
4826 .expect("ok")
4827 .expect("non-linear");
4828 assert_eq!(partials[0].0, 1.0);
4829 assert_eq!(partials[0].1, 0.0);
4830 }
4831
4832 #[test]
4833 fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
4834 let age = 25.0;
4841 let rate = 0.4;
4842 let cfg_taylor = SurvivalBaselineConfig {
4843 target: SurvivalBaselineTarget::Gompertz,
4844 scale: None,
4845 shape: Some(0.5e-10),
4846 rate: Some(rate),
4847 makeham: None,
4848 };
4849 let cfg_direct = SurvivalBaselineConfig {
4850 target: SurvivalBaselineTarget::Gompertz,
4851 scale: None,
4852 shape: Some(2.0e-10),
4853 rate: Some(rate),
4854 makeham: None,
4855 };
4856 let p_t = baseline_offset_theta_partials(age, &cfg_taylor)
4857 .expect("ok")
4858 .expect("nl");
4859 let p_d = baseline_offset_theta_partials(age, &cfg_direct)
4860 .expect("ok")
4861 .expect("nl");
4862 assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
4864 assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
4865 assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
4867 assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
4868 }
4869
4870 #[test]
4882 fn gompertz_hazard_shape_derivatives_match_central_diff() {
4883 let cases = [
4888 (10.0_f64, 0.012_f64, 0.05_f64),
4889 (2.5, 0.5, 0.2),
4890 (15.0, 0.003, 0.01),
4891 (40.0, 0.3, 0.001),
4892 ];
4893 let h = 1e-6;
4894 for &(age, rate, shape) in &cases {
4895 let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4897 let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
4898 let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
4899 assert_close(
4900 d_cum,
4901 (cum_p - cum_m) / (2.0 * h),
4902 1e-6,
4903 &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4904 );
4905 assert_close(
4906 d_inst,
4907 (inst_p - inst_m) / (2.0 * h),
4908 1e-6,
4909 &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4910 );
4911
4912 let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4914 let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
4915 let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
4916 assert_close(
4917 d2_cum,
4918 (dcum_p - dcum_m) / (2.0 * h),
4919 1e-5,
4920 &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4921 );
4922 assert_close(
4923 d2_inst,
4924 (dinst_p - dinst_m) / (2.0 * h),
4925 1e-5,
4926 &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4927 );
4928 }
4929 }
4930
4931 #[test]
4932 fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4933 let cases = [
4945 (25.0_f64, 0.4_f64, 1e-9_f64),
4946 (100.0, 0.4, 1e-6), (100.0, 0.012, 1e-6), (50.0, 1.2, 1e-8),
4949 ];
4950 for &(age, rate, shape) in &cases {
4961 let t = age;
4962 let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4963 assert_close(
4964 d_cum,
4965 rate * t * t / 2.0,
4966 1e-3,
4967 &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4968 );
4969 assert_close(
4970 d_inst,
4971 rate * t,
4972 1e-3,
4973 &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4974 );
4975
4976 let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4977 assert_close(
4978 d2_cum,
4979 rate * t * t * t / 3.0,
4980 1e-3,
4981 &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4982 );
4983 assert_close(
4984 d2_inst,
4985 rate * t * t,
4986 1e-3,
4987 &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4988 );
4989 }
4990 }
4991
4992 #[test]
4993 fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4994 let age = 100.0;
5001 let rate = 0.4;
5002 let t = age;
5003 let truth = rate * t * t * t / 3.0; for k in 5..=12 {
5010 let shape = 10f64.powi(-(k as i32)); let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
5012 assert_close(
5013 d2_cum,
5014 truth,
5015 1e-3,
5016 &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
5017 );
5018 }
5019 }
5020
5021 #[test]
5022 fn weibull_offset_partials_match_central_diff() {
5023 let cases = [
5024 (0.5_f64, 1.2_f64, 25.0_f64),
5025 (2.0, 0.8, 60.0),
5026 (0.1, 3.0, 10.0),
5027 ];
5028 for &(scale, shape, age) in &cases {
5029 let cfg = SurvivalBaselineConfig {
5030 target: SurvivalBaselineTarget::Weibull,
5031 scale: Some(scale),
5032 shape: Some(shape),
5033 rate: None,
5034 makeham: None,
5035 };
5036 let analytic = baseline_offset_theta_partials(age, &cfg)
5037 .expect("ok")
5038 .expect("nl");
5039 let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
5040 assert_eq!(analytic.len(), 2);
5041 for k in 0..2 {
5042 assert_close(
5043 analytic[k].0,
5044 fd[k].0,
5045 1e-7,
5046 &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5047 );
5048 assert_close(
5049 analytic[k].1,
5050 fd[k].1,
5051 1e-7,
5052 &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5053 );
5054 }
5055 assert_eq!(analytic[0].1, 0.0);
5057 }
5058 }
5059
5060 #[test]
5061 fn gompertz_makeham_offset_partials_match_central_diff() {
5062 let cases = [
5063 (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
5064 (0.5, 0.01, 0.01, 25.0),
5065 (0.2, 0.001, 0.005, 60.0),
5066 (0.4, 5e-11, 0.01, 25.0),
5067 (0.4, -5e-11, 0.01, 25.0),
5068 (0.8, 0.2, 0.05, 5.0),
5069 ];
5070 for &(rate, shape, makeham, age) in &cases {
5071 let cfg = SurvivalBaselineConfig {
5072 target: SurvivalBaselineTarget::GompertzMakeham,
5073 scale: None,
5074 shape: Some(shape),
5075 rate: Some(rate),
5076 makeham: Some(makeham),
5077 };
5078 let analytic = baseline_offset_theta_partials(age, &cfg)
5079 .expect("ok")
5080 .expect("nl");
5081 let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
5085 let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
5086 assert_eq!(analytic.len(), 3);
5087 for k in 0..3 {
5088 assert_close(
5089 analytic[k].0,
5090 fd[k].0,
5091 1e-5,
5092 &format!(
5093 "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5094 ),
5095 );
5096 assert_close(
5097 analytic[k].1,
5098 fd[k].1,
5099 1e-5,
5100 &format!(
5101 "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5102 ),
5103 );
5104 }
5105 }
5106 }
5107
5108 #[test]
5109 fn linear_baseline_has_no_theta_partials() {
5110 let cfg = SurvivalBaselineConfig {
5111 target: SurvivalBaselineTarget::Linear,
5112 scale: None,
5113 shape: None,
5114 rate: None,
5115 makeham: None,
5116 };
5117 assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
5118 }
5119
5120 #[test]
5121 fn baseline_offset_partials_reject_non_positive_ages() {
5122 let cfg = SurvivalBaselineConfig {
5123 target: SurvivalBaselineTarget::Gompertz,
5124 scale: None,
5125 shape: Some(0.01),
5126 rate: Some(0.5),
5127 makeham: None,
5128 };
5129 assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
5130 assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
5131 assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
5132 }
5133
5134 #[test]
5140 fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
5141 let cfg = SurvivalBaselineConfig {
5142 target: SurvivalBaselineTarget::Gompertz,
5143 scale: None,
5144 shape: Some(0.05),
5145 rate: Some(0.3),
5146 makeham: None,
5147 };
5148 let age_entry = array![10.0_f64];
5149 let age_exit = array![25.0_f64];
5150 let residuals = OffsetChannelResiduals {
5151 exit: array![0.7_f64],
5152 entry: array![-0.2_f64],
5153 derivative: array![-0.4_f64],
5154 right: Array1::<f64>::zeros(1),
5155 };
5156 let grad = baseline_chain_rule_gradient(
5157 age_entry.view(),
5158 age_exit.view(),
5159 age_exit.view(),
5160 &cfg,
5161 &residuals,
5162 )
5163 .expect("ok")
5164 .expect("non-linear");
5165 let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5167 .unwrap()
5168 .unwrap();
5169 let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5170 .unwrap()
5171 .unwrap();
5172 for k in 0..p_exit.len() {
5173 let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5174 assert!(
5175 (grad[k] - expected).abs() < 1e-12,
5176 "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5177 grad[k],
5178 expected
5179 );
5180 }
5181 }
5182
5183 #[test]
5186 fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5187 let cfg = SurvivalBaselineConfig {
5188 target: SurvivalBaselineTarget::Gompertz,
5189 scale: None,
5190 shape: Some(0.05),
5191 rate: Some(0.3),
5192 makeham: None,
5193 };
5194 let age_entry = array![0.0_f64, 5.0_f64];
5195 let age_exit = array![10.0_f64, 20.0_f64];
5196 let residuals = OffsetChannelResiduals {
5197 exit: array![0.5_f64, 0.3_f64],
5198 entry: array![0.0_f64, -0.1_f64], derivative: array![-0.2_f64, 0.0_f64],
5200 right: Array1::<f64>::zeros(2),
5201 };
5202 let grad = baseline_chain_rule_gradient(
5204 age_entry.view(),
5205 age_exit.view(),
5206 age_exit.view(),
5207 &cfg,
5208 &residuals,
5209 )
5210 .expect("must not fail on origin-entry row with r_entry=0")
5211 .expect("non-linear");
5212 assert_eq!(grad.len(), 2);
5213 let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5215 let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5216 let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5217 for k in 0..2 {
5218 let expected = 0.5 * p_exit_0[k].0
5219 + (-0.2) * p_exit_0[k].1
5220 + 0.3 * p_exit_1[k].0
5221 + (-0.1) * p_entry_1[k].0;
5222 assert!(
5223 (grad[k] - expected).abs() < 1e-12,
5224 "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5225 grad[k],
5226 expected
5227 );
5228 }
5229 }
5230
5231 #[test]
5233 fn chain_rule_gradient_linear_target_returns_none() {
5234 let cfg = SurvivalBaselineConfig {
5235 target: SurvivalBaselineTarget::Linear,
5236 scale: None,
5237 shape: None,
5238 rate: None,
5239 makeham: None,
5240 };
5241 let age_entry = array![1.0_f64];
5242 let age_exit = array![2.0_f64];
5243 let residuals = OffsetChannelResiduals {
5244 exit: array![0.1_f64],
5245 entry: array![0.0_f64],
5246 derivative: array![0.0_f64],
5247 right: Array1::<f64>::zeros(1),
5248 };
5249 let grad = baseline_chain_rule_gradient(
5250 age_entry.view(),
5251 age_exit.view(),
5252 age_exit.view(),
5253 &cfg,
5254 &residuals,
5255 )
5256 .expect("ok");
5257 assert!(grad.is_none());
5258 }
5259
5260 #[test]
5279 fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5280 let cfg = SurvivalBaselineConfig {
5283 target: SurvivalBaselineTarget::Gompertz,
5284 scale: None,
5285 shape: Some(0.03),
5286 rate: Some(0.25),
5287 makeham: None,
5288 };
5289 let age_entry = array![0.0_f64, 5.0, 8.0];
5290 let age_exit = array![4.0_f64, 12.0, 20.0];
5291 let weights = array![1.0_f64, 2.0, 0.5];
5294 let events = [1.0_f64, 1.0, 0.0];
5295 let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5300 let s_vals = [0.7_f64, 1.1, 1.5];
5301 let (r_x, r_e, r_d) = {
5302 let mut rx = Array1::<f64>::zeros(3);
5303 let mut re = Array1::<f64>::zeros(3);
5304 let mut rd = Array1::<f64>::zeros(3);
5305 for i in 0..3 {
5306 let w = weights[i];
5307 let d = events[i];
5308 rx[i] = w * (eta_exit_vals[i].exp() - d);
5309 re[i] = if i == 0 {
5310 0.0 } else {
5312 -w * eta_entry_vals[i].exp()
5313 };
5314 rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5315 }
5316 (rx, re, rd)
5317 };
5318 let residuals = OffsetChannelResiduals {
5319 exit: r_x.clone(),
5320 entry: r_e.clone(),
5321 derivative: r_d.clone(),
5322 right: Array1::<f64>::zeros(3),
5323 };
5324 let grad = baseline_chain_rule_gradient(
5325 age_entry.view(),
5326 age_exit.view(),
5327 age_exit.view(),
5328 &cfg,
5329 &residuals,
5330 )
5331 .expect("ok")
5332 .expect("non-linear");
5333
5334 let nll = |theta_plus: &Array1<f64>| -> f64 {
5339 let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5340 let mut sum = 0.0_f64;
5341 for i in 0..3 {
5342 let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5343 let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5344 let d_eta_x = eta_x_p - base.0;
5345 let d_d_x = d_x_p - base.1;
5346 let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5347 let s_new = s_vals[i] + d_d_x;
5348 let interval_entry = if i == 0 {
5349 0.0_f64
5350 } else {
5351 let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5352 let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5353 let d_eta_e = eta_e_p - base_e.0;
5354 let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5355 eta_entry_new.exp()
5356 };
5357 let w = weights[i];
5358 let d = events[i];
5359 let nll_i =
5360 w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5361 sum += nll_i;
5362 }
5363 sum
5364 };
5365
5366 let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5367 let h = 1e-6;
5368 for k in 0..theta_base.len() {
5369 let mut tp = theta_base.clone();
5370 let mut tm = theta_base.clone();
5371 tp[k] += h;
5372 tm[k] -= h;
5373 let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5374 assert!(
5375 (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5376 "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5377 grad[k],
5378 fd
5379 );
5380 }
5381 }
5382
5383 #[test]
5385 fn chain_rule_gradient_rejects_length_mismatch() {
5386 let cfg = SurvivalBaselineConfig {
5387 target: SurvivalBaselineTarget::Gompertz,
5388 scale: None,
5389 shape: Some(0.05),
5390 rate: Some(0.3),
5391 makeham: None,
5392 };
5393 let age_entry = array![1.0_f64, 2.0]; let age_exit = array![5.0_f64, 6.0, 7.0]; let residuals = OffsetChannelResiduals {
5396 exit: array![0.1_f64, 0.2, 0.3],
5397 entry: array![0.0_f64, 0.0, 0.0],
5398 derivative: array![0.0_f64, 0.0, 0.0],
5399 right: Array1::<f64>::zeros(3),
5400 };
5401 let err = baseline_chain_rule_gradient(
5402 age_entry.view(),
5403 age_exit.view(),
5404 age_exit.view(),
5405 &cfg,
5406 &residuals,
5407 )
5408 .expect_err("length mismatch must error");
5409 assert!(err.contains("length mismatch"), "err={err}");
5410 }
5411}