1use gam_terms::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14 BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15 create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18 DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19 SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24 WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25 monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::outer_subsample::RowSet;
31use gam_problem::{InverseLink, StandardLink};
32use ndarray::{Array1, Array2, array, s};
33use rayon::prelude::*;
34
/// Structured error raised while validating or assembling survival-model inputs.
///
/// Every variant carries a human-readable `reason`. Functions in this module
/// turn it into their `String` error type with `.into()`; that conversion (and
/// any `Display` text) comes from `impl_reason_error_boilerplate!` below.
/// NOTE(review): the macro is defined elsewhere — exact message formatting is
/// not visible here.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// Invalid configuration; also used when the baseline optimizer fails to converge.
    InvalidConfig { reason: String },
    /// NOTE(review): not raised in this section; presumably a required input column is absent.
    MissingColumn { reason: String },
    /// Lengths disagree (e.g. entry vs exit arrays, baseline theta dimension).
    IncompatibleDimensions { reason: String },
    /// Input data failed validation (non-finite or negative survival times).
    DataValidationFailed { reason: String },
    /// NOTE(review): not raised in this section; presumably basis construction failed.
    BasisConstructionFailed { reason: String },
    /// An option string was not recognised (distribution, likelihood mode,
    /// baseline target) or a non-structural time basis was requested.
    UnsupportedDistribution { reason: String },
}

impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
85
/// Parametric family used for the survival baseline.
///
/// Parsed from `--baseline-target` (`linear`, `weibull`, `gompertz`,
/// `gompertz-makeham`); see `parse_survival_baseline_config` for the
/// parameters each family requires.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// No parametric baseline parameters (scale/shape/rate/makeham all `None`).
    Linear,
    /// Weibull baseline: positive `scale` and `shape`.
    Weibull,
    /// Gompertz baseline: positive `rate`, finite (possibly non-positive) `shape`.
    Gompertz,
    /// Gompertz baseline with an additional positive `makeham` parameter.
    GompertzMakeham,
}
112
/// A baseline family together with its parameter values.
///
/// Which fields are `Some` depends on `target` (as produced by
/// `parse_survival_baseline_config`): Weibull sets `scale` and `shape`,
/// Gompertz sets `shape` and `rate`, Gompertz-Makeham additionally sets
/// `makeham`, and Linear leaves all four `None`.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    pub target: SurvivalBaselineTarget,
    /// Weibull scale (finite, > 0).
    pub scale: Option<f64>,
    /// Weibull shape (finite, > 0) or Gompertz shape (any finite value).
    pub shape: Option<f64>,
    /// Gompertz / Gompertz-Makeham rate (finite, > 0).
    pub rate: Option<f64>,
    /// Gompertz-Makeham `makeham` parameter (finite, > 0).
    pub makeham: Option<f64>,
}
121
/// Requested time basis for the survival model.
///
/// `parse_survival_time_basis_config` only ever yields `None` or `ISpline`;
/// it rejects `linear` / `bspline` because they are not monotone by
/// construction. `build_survival_time_basis` still has arms for `Linear` and
/// `BSpline`.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No time basis: zero-column design matrices.
    None,
    /// Intercept plus a log-time column.
    Linear,
    /// B-spline basis evaluated in log time.
    BSpline {
        degree: usize,
        /// Knot vector; when empty, knots are inferred from the data.
        knots: Array1<f64>,
        smooth_lambda: f64,
    },
    /// I-spline basis (monotone by construction).
    ISpline {
        degree: usize,
        /// Knot vector; `parse_survival_time_basis_config` leaves this empty.
        knots: Array1<f64>,
        /// NOTE(review): presumably indices of basis columns retained after construction — confirm.
        keep_cols: Vec<usize>,
        smooth_lambda: f64,
    },
}
175
176#[derive(Clone, Debug, PartialEq)]
190pub struct SavedSurvivalTimeBasis {
191 pub basisname: String,
192 pub degree: Option<usize>,
193 pub knots: Option<Vec<f64>>,
194 pub keep_cols: Option<Vec<usize>>,
195 pub smooth_lambda: Option<f64>,
196 pub anchor: f64,
197}
198
199impl SavedSurvivalTimeBasis {
200 pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
203 Self {
204 basisname: build.basisname.clone(),
205 degree: build.degree,
206 knots: build.knots.clone(),
207 keep_cols: build.keep_cols.clone(),
208 smooth_lambda: build.smooth_lambda,
209 anchor,
210 }
211 }
212}
213
/// Design matrices and metadata produced by `build_survival_time_basis`.
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time basis evaluated at each row's entry time.
    pub x_entry_time: DesignMatrix,
    /// Time basis evaluated at each row's exit time.
    pub x_exit_time: DesignMatrix,
    /// Derivative of the time basis with respect to (untransformed) time at exit,
    /// i.e. the log-time derivative multiplied by `1 / t`.
    pub x_derivative_time: DesignMatrix,
    /// Smoothing penalty matrices for the time coefficients (may be empty).
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions reported by the basis builder (presumably one per penalty).
    pub nullspace_dims: Vec<usize>,
    /// Name of the basis that was built, e.g. `"none"`, `"linear"`, `"bspline"`.
    pub basisname: String,
    /// Spline degree, when the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, when the basis has one.
    pub knots: Option<Vec<f64>>,
    /// Retained column indices, when applicable.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing lambda, when applicable.
    pub smooth_lambda: Option<f64>,
}
228
/// Lower clamp for survival times: times are raised to at least this value
/// before taking logs, and a normalized exit time is kept at least this far
/// past its entry time.
pub const SURVIVAL_TIME_FLOOR: f64 = 1e-9;

/// NOTE(review): not referenced in this section; presumably the entry time above
/// which a row is treated as delayed entry — confirm at the use sites.
pub const SURVIVAL_DELAYED_ENTRY_THRESHOLD: f64 = 1e-8;

/// NOTE(review): not referenced in this section; presumably the initial smoothing
/// lambda for the time basis — confirm at the use sites.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1e-2;

/// Default Gompertz / Gompertz-Makeham `shape` used when the caller supplies none.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
255
/// Likelihood family selected with `--survival-likelihood`.
///
/// Each variant maps one-to-one to the name accepted by
/// `parse_survival_likelihood_mode` and returned by `survival_likelihood_modename`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `transformation`
    Transformation,
    /// `weibull`
    Weibull,
    /// `location-scale`
    LocationScale,
    /// `marginal-slope`
    MarginalSlope,
    /// `latent`
    Latent,
    /// `latent-binary`
    LatentBinary,
}
265
/// Penalties and spline metadata for a time "wiggle" block.
///
/// NOTE(review): not constructed in this section; the field descriptions below
/// are inferred from names — confirm at the construction site.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices for the block.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions (presumably aligned with `penalties`).
    pub nullspace_dims: Vec<usize>,
    /// Knot vector of the wiggle basis.
    pub knots: Array1<f64>,
    /// Spline degree of the wiggle basis.
    pub degree: usize,
    /// Number of basis columns.
    pub ncols: usize,
}
273
274pub fn normalize_survival_time_pair(
279 entry_raw: f64,
280 exit_raw: f64,
281 row_index: usize,
282) -> Result<(f64, f64), String> {
283 if !entry_raw.is_finite() || !exit_raw.is_finite() {
284 return Err(SurvivalConstructionError::DataValidationFailed {
285 reason: format!("non-finite survival times at row {}", row_index + 1),
286 }
287 .into());
288 }
289 if entry_raw < 0.0 || exit_raw < 0.0 {
290 return Err(SurvivalConstructionError::DataValidationFailed {
291 reason: format!("negative survival times at row {}", row_index + 1),
292 }
293 .into());
294 }
295
296 let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
297 let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
298 Ok((entry, exit))
299}
300
/// Reports whether `basisname` names a time basis that is monotone by construction.
///
/// Only the I-spline basis qualifies. The comparison ignores ASCII case but not
/// surrounding whitespace.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const MONOTONE_BASIS: &str = "ispline";
    MONOTONE_BASIS.eq_ignore_ascii_case(basisname)
}
308
309pub fn require_structural_survival_time_basis(
310 basisname: &str,
311 context: &str,
312) -> Result<(), String> {
313 if survival_basis_supports_structural_monotonicity(basisname) {
314 return Ok(());
315 }
316 Err(SurvivalConstructionError::UnsupportedDistribution {
317 reason: format!(
318 "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
319Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
320`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
321Re-run with `--time-basis ispline`."
322 ),
323 }
324 .into())
325}
326
327pub fn parse_survival_baseline_config(
332 target_raw: &str,
333 scale: Option<f64>,
334 shape: Option<f64>,
335 rate: Option<f64>,
336 makeham: Option<f64>,
337) -> Result<SurvivalBaselineConfig, String> {
338 let target = match target_raw.to_ascii_lowercase().as_str() {
339 "linear" => SurvivalBaselineTarget::Linear,
340 "weibull" => SurvivalBaselineTarget::Weibull,
341 "gompertz" => SurvivalBaselineTarget::Gompertz,
342 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
343 other => {
344 return Err(SurvivalConstructionError::UnsupportedDistribution {
345 reason: format!(
346 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
347 ),
348 }
349 .into());
350 }
351 };
352
353 match target {
354 SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
355 target,
356 scale: None,
357 shape: None,
358 rate: None,
359 makeham: None,
360 }),
361 SurvivalBaselineTarget::Weibull => {
362 let scale = scale.ok_or_else(|| {
363 "--baseline-target weibull requires --baseline-scale > 0".to_string()
364 })?;
365 let shape = shape.ok_or_else(|| {
366 "--baseline-target weibull requires --baseline-shape > 0".to_string()
367 })?;
368 if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
369 return Err(
370 "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
371 .to_string(),
372 );
373 }
374 Ok(SurvivalBaselineConfig {
375 target,
376 scale: Some(scale),
377 shape: Some(shape),
378 rate: None,
379 makeham: None,
380 })
381 }
382 SurvivalBaselineTarget::Gompertz => {
383 let rate = rate.unwrap_or(1.0);
384 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
385 if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
386 return Err(
387 "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
388 .to_string(),
389 );
390 }
391 Ok(SurvivalBaselineConfig {
392 target,
393 scale: None,
394 shape: Some(shape),
395 rate: Some(rate),
396 makeham: None,
397 })
398 }
399 SurvivalBaselineTarget::GompertzMakeham => {
400 let rate = rate.unwrap_or(0.5);
401 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
402 let makeham = makeham.unwrap_or(0.5);
403 if !rate.is_finite()
404 || rate <= 0.0
405 || !shape.is_finite()
406 || !makeham.is_finite()
407 || makeham <= 0.0
408 {
409 return Err(
410 "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
411 .to_string(),
412 );
413 }
414 Ok(SurvivalBaselineConfig {
415 target,
416 scale: None,
417 shape: Some(shape),
418 rate: Some(rate),
419 makeham: Some(makeham),
420 })
421 }
422 }
423}
424
425pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
430 match raw.to_ascii_lowercase().as_str() {
431 "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
432 "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
433 "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
434 "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
435 "latent" => Ok(SurvivalLikelihoodMode::Latent),
436 "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
437 other => Err(SurvivalConstructionError::UnsupportedDistribution {
438 reason: format!(
439 "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
440 ),
441 }
442 .into()),
443 }
444}
445
446pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
447 match mode {
448 SurvivalLikelihoodMode::Transformation => "transformation",
449 SurvivalLikelihoodMode::Weibull => "weibull",
450 SurvivalLikelihoodMode::LocationScale => "location-scale",
451 SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
452 SurvivalLikelihoodMode::Latent => "latent",
453 SurvivalLikelihoodMode::LatentBinary => "latent-binary",
454 }
455}
456
457pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
458 match raw.to_ascii_lowercase().as_str() {
459 "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
460 "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
461 "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
462 other => Err(SurvivalConstructionError::UnsupportedDistribution {
463 reason: format!(
464 "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
465 ),
466 }
467 .into()),
468 }
469}
470
471pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
472 match target {
473 SurvivalBaselineTarget::Linear => "linear",
474 SurvivalBaselineTarget::Weibull => "weibull",
475 SurvivalBaselineTarget::Gompertz => "gompertz",
476 SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
477 }
478}
479
480pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
481 let sum = age_exit
482 .iter()
483 .copied()
484 .filter(|value| value.is_finite() && *value > 0.0)
485 .sum::<f64>();
486 let count = age_exit
487 .iter()
488 .filter(|value| value.is_finite() && **value > 0.0)
489 .count()
490 .max(1);
491 (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
492}
493
494pub fn initial_survival_baseline_config_for_fit(
495 target_raw: &str,
496 scale: Option<f64>,
497 shape: Option<f64>,
498 rate: Option<f64>,
499 makeham: Option<f64>,
500 age_exit: &Array1<f64>,
501) -> Result<SurvivalBaselineConfig, String> {
502 let target = match target_raw.trim().to_ascii_lowercase().as_str() {
503 "linear" => SurvivalBaselineTarget::Linear,
504 "weibull" => SurvivalBaselineTarget::Weibull,
505 "gompertz" => SurvivalBaselineTarget::Gompertz,
506 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
507 other => {
508 return Err(SurvivalConstructionError::UnsupportedDistribution {
509 reason: format!(
510 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
511 ),
512 }
513 .into());
514 }
515 };
516 let time_scale_seed = positive_survival_time_seed(age_exit);
517 let cfg = match target {
518 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
519 target,
520 scale: None,
521 shape: None,
522 rate: None,
523 makeham: None,
524 },
525 SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
526 target,
527 scale: Some(scale.unwrap_or(time_scale_seed)),
528 shape: Some(shape.unwrap_or(1.0)),
529 rate: None,
530 makeham: None,
531 },
532 SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
533 target,
534 scale: None,
535 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
536 rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
537 makeham: None,
538 },
539 SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
540 target,
541 scale: None,
542 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
543 rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
544 makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
545 },
546 };
547 parse_survival_baseline_config(
548 survival_baseline_targetname(cfg.target),
549 cfg.scale,
550 cfg.shape,
551 cfg.rate,
552 cfg.makeham,
553 )
554}
555
556fn survival_baseline_theta_from_config(
557 cfg: &SurvivalBaselineConfig,
558) -> Result<Option<Array1<f64>>, String> {
559 Ok(match cfg.target {
560 SurvivalBaselineTarget::Linear => None,
561 SurvivalBaselineTarget::Weibull => Some(array![
562 cfg.scale
563 .ok_or_else(|| "missing weibull baseline scale".to_string())?
564 .ln(),
565 cfg.shape
566 .ok_or_else(|| "missing weibull baseline shape".to_string())?
567 .ln(),
568 ]),
569 SurvivalBaselineTarget::Gompertz => Some(array![
570 cfg.rate
571 .ok_or_else(|| "missing gompertz baseline rate".to_string())?
572 .ln(),
573 cfg.shape
574 .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
575 ]),
576 SurvivalBaselineTarget::GompertzMakeham => Some(array![
577 cfg.rate
578 .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
579 .ln(),
580 cfg.shape
581 .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
582 cfg.makeham
583 .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
584 .ln(),
585 ]),
586 })
587}
588
589fn survival_baseline_config_from_theta(
590 target: SurvivalBaselineTarget,
591 theta: &Array1<f64>,
592) -> Result<SurvivalBaselineConfig, String> {
593 let cfg = match target {
594 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
595 target,
596 scale: None,
597 shape: None,
598 rate: None,
599 makeham: None,
600 },
601 SurvivalBaselineTarget::Weibull => {
602 if theta.len() != 2 {
603 return Err(SurvivalConstructionError::IncompatibleDimensions {
604 reason: format!(
605 "weibull baseline parameter dimension mismatch: expected 2, got {}",
606 theta.len()
607 ),
608 }
609 .into());
610 }
611 SurvivalBaselineConfig {
612 target,
613 scale: Some(theta[0].exp()),
614 shape: Some(theta[1].exp()),
615 rate: None,
616 makeham: None,
617 }
618 }
619 SurvivalBaselineTarget::Gompertz => {
620 if theta.len() != 2 {
621 return Err(SurvivalConstructionError::IncompatibleDimensions {
622 reason: format!(
623 "gompertz baseline parameter dimension mismatch: expected 2, got {}",
624 theta.len()
625 ),
626 }
627 .into());
628 }
629 SurvivalBaselineConfig {
630 target,
631 scale: None,
632 shape: Some(theta[1]),
633 rate: Some(theta[0].exp()),
634 makeham: None,
635 }
636 }
637 SurvivalBaselineTarget::GompertzMakeham => {
638 if theta.len() != 3 {
639 return Err(SurvivalConstructionError::IncompatibleDimensions {
640 reason: format!(
641 "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
642 theta.len()
643 ),
644 }
645 .into());
646 }
647 SurvivalBaselineConfig {
648 target,
649 scale: None,
650 shape: Some(theta[1]),
651 rate: Some(theta[0].exp()),
652 makeham: Some(theta[2].exp()),
653 }
654 }
655 };
656 parse_survival_baseline_config(
657 survival_baseline_targetname(cfg.target),
658 cfg.scale,
659 cfg.shape,
660 cfg.rate,
661 cfg.makeham,
662 )
663}
664
665#[derive(Clone, Copy, Debug, PartialEq, Eq)]
678enum BaselineDerivativeContract {
679 GradientOnly,
682 GradientHessian,
686}
687
688impl BaselineDerivativeContract {
689 fn configure(
694 self,
695 problem: gam_solve::rho_optimizer::OuterProblem,
696 ) -> gam_solve::rho_optimizer::OuterProblem {
697 use gam_problem::{DeclaredHessianForm, Derivative};
698 match self {
699 BaselineDerivativeContract::GradientOnly => problem
702 .with_gradient(Derivative::Analytic)
703 .with_hessian(DeclaredHessianForm::Unavailable)
704 .with_tolerance(1e-4)
705 .with_max_iter(240),
706 BaselineDerivativeContract::GradientHessian => problem
707 .with_gradient(Derivative::Analytic)
708 .with_hessian(DeclaredHessianForm::Either)
709 .with_tolerance(1e-4)
710 .with_max_iter(240),
711 }
712 }
713}
714
/// Runs the bounded outer optimizer over the baseline parameter vector `theta`.
///
/// `theta` is the encoding produced by `survival_baseline_theta_from_config`
/// (log-transformed positive parameters, raw Gompertz shape). A Linear baseline
/// has no parameters, so `initial` is returned unchanged.
///
/// The search box is `seed ± 6` in every coordinate, starting from the seed
/// derived from `initial`. `cost_fn` and `eval_fn` are passed through to
/// `OuterProblem::build_objective`; no EFS evaluator is supplied.
///
/// # Errors
/// Fails if `initial` lacks a parameter its target needs, if the optimizer
/// returns an error, if it reports non-convergence, or if the optimum decodes
/// to an invalid config.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        // Linear baseline: there is nothing to optimize.
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // Box of +/- 6 around the seed in theta space (for log-encoded parameters
    // this is a factor of e^6 in either direction).
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        // NOTE(review): presumably a single start with all `dim` coordinates
        // treated as auxiliary (non-smoothing) parameters — confirm against SeedConfig.
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        // No optional hook is supplied for this slot.
        None::<fn(&mut ())>,
        // No EFS evaluator for the baseline problem.
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    // Non-convergence is a hard error rather than returning a best-effort point.
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    survival_baseline_config_from_theta(target, &result.rho)
}
785
/// Adapts a config-level objective to the theta-level closures expected by
/// `run_baseline_theta_optimizer`.
///
/// Each evaluation decodes `theta` with `survival_baseline_config_from_theta`,
/// calls `objective`, and checks that the returned gradient (and the Hessian,
/// when it is `HessianResult::Analytic`) match `theta`'s dimension. String
/// errors are wrapped as `EstimationError::InvalidInput`.
///
/// The cost closure runs a full evaluation and keeps only `cost`.
fn run_baseline_theta_optimizer_with_eval<F>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    objective: F,
) -> Result<SurvivalBaselineConfig, String>
where
    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
{
    let target = initial.target;
    let engine_context = context.to_string();
    // Both the cost and the eval closure must call the same `FnMut`, so it is
    // shared behind `Rc<RefCell<_>>` and borrowed mutably per call.
    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
                        theta: &Array1<f64>|
          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
        let cfg = survival_baseline_config_from_theta(target, theta)
            .map_err(crate::model_types::EstimationError::InvalidInput)?;
        let eval =
            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
        // Guard against objectives that return derivatives of the wrong size.
        if eval.gradient.len() != theta.len() {
            return Err(crate::model_types::EstimationError::InvalidInput(format!(
                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
                eval.gradient.len(),
                theta.len()
            )));
        }
        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
            if h.nrows() != theta.len() || h.ncols() != theta.len() {
                return Err(crate::model_types::EstimationError::InvalidInput(format!(
                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
                    h.nrows(),
                    h.ncols(),
                    theta.len(),
                    theta.len()
                )));
            }
        }
        Ok(eval)
    };
    // The cost path gets its own handle to the shared objective and its own copy of `eval_at`.
    let cost_objective = std::rc::Rc::clone(&objective);
    let cost_eval = eval_at.clone();
    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
    };
    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
}
845
846pub fn optimize_survival_baseline_config_with_gradient_only<F>(
857 initial: &SurvivalBaselineConfig,
858 context: &str,
859 mut objective: F,
860) -> Result<SurvivalBaselineConfig, String>
861where
862 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
863{
864 use gam_problem::{HessianResult, OuterEval};
865 run_baseline_theta_optimizer_with_eval(
866 initial,
867 context,
868 BaselineDerivativeContract::GradientOnly,
869 move |cfg| {
870 let (cost, gradient) = objective(cfg)?;
871 Ok(OuterEval {
872 cost,
873 gradient,
874 hessian: HessianResult::Unavailable,
875 inner_beta_hint: None,
876 })
877 },
878 )
879}
880
881pub fn optimize_survival_baseline_config_with_gradient<F>(
886 initial: &SurvivalBaselineConfig,
887 context: &str,
888 mut objective: F,
889) -> Result<SurvivalBaselineConfig, String>
890where
891 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
892{
893 use gam_problem::{HessianResult, OuterEval};
894 run_baseline_theta_optimizer_with_eval(
895 initial,
896 context,
897 BaselineDerivativeContract::GradientHessian,
898 move |cfg| {
899 let (cost, gradient, hessian) = objective(cfg)?;
900 Ok(OuterEval {
901 cost,
902 gradient,
903 hessian: HessianResult::Analytic(hessian),
904 inner_beta_hint: None,
905 })
906 },
907 )
908}
909
910pub fn parse_survival_time_basis_config(
915 time_basis: &str,
916 time_degree: usize,
917 time_num_internal_knots: usize,
918 time_smooth_lambda: f64,
919) -> Result<SurvivalTimeBasisConfig, String> {
920 match time_basis.to_ascii_lowercase().as_str() {
921 "none" => Ok(SurvivalTimeBasisConfig::None),
922 "ispline" => {
923 if time_degree < 1 {
924 return Err(
925 "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
926 .to_string(),
927 );
928 }
929 if time_num_internal_knots == 0 {
930 return Err(
931 "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
932 .to_string(),
933 );
934 }
935 if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
936 return Err(
937 "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
938 .to_string(),
939 );
940 }
941 Ok(SurvivalTimeBasisConfig::ISpline {
942 degree: time_degree,
943 knots: Array1::zeros(0),
944 keep_cols: Vec::new(),
945 smooth_lambda: time_smooth_lambda,
946 })
947 }
948 "linear" | "bspline" => {
949 match require_structural_survival_time_basis(time_basis, "survival model configuration")
956 {
957 Err(e) => Err(e),
958 Ok(()) => Err(format!(
959 "internal: structural-basis check accepted non-structural \
960 survival time basis '{time_basis}'"
961 )),
962 }
963 }
964 other => Err(format!(
965 "unsupported --time-basis '{other}'; accepted values: ispline, none"
966 )),
967 }
968}
969
970pub fn build_survival_time_basis(
975 age_entry: &Array1<f64>,
976 age_exit: &Array1<f64>,
977 cfg: SurvivalTimeBasisConfig,
978 infer_knots_if_needed: Option<(usize, f64)>,
979) -> Result<SurvivalTimeBuildOutput, String> {
980 fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
981 if let Some(row) = times.iter().position(|t| !t.is_finite()) {
982 return Err(SurvivalConstructionError::DataValidationFailed {
983 reason: format!(
984 "survival time basis requires finite {label} times (row {})",
985 row + 1
986 ),
987 }
988 .into());
989 }
990 if let Some(row) = times.iter().position(|t| *t < 0.0) {
991 return Err(SurvivalConstructionError::DataValidationFailed {
992 reason: format!(
993 "survival time basis requires non-negative {label} times (row {})",
994 row + 1
995 ),
996 }
997 .into());
998 }
999 Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
1000 }
1001
1002 let n = age_entry.len();
1003 if n != age_exit.len() {
1004 return Err(SurvivalConstructionError::IncompatibleDimensions {
1005 reason: "survival time basis requires matching entry/exit lengths".to_string(),
1006 }
1007 .into());
1008 }
1009 for i in 0..n {
1010 if age_exit[i] < age_entry[i] {
1011 return Err(format!(
1012 "survival time basis requires exit times >= entry times (row {})",
1013 i + 1
1014 ));
1015 }
1016 }
1017 let log_entry = checked_log_survival_times(age_entry, "entry")?;
1018 let log_exit = checked_log_survival_times(age_exit, "exit")?;
1019
1020 fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1021 let n = log_entry.len();
1022 let entry_range = log_entry
1023 .iter()
1024 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1025 (lo.min(v), hi.max(v))
1026 });
1027 let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1028 if entry_degenerate {
1029 log_exit.clone()
1030 } else {
1031 let mut combined = Array1::<f64>::zeros(2 * n);
1032 for i in 0..n {
1033 combined[i] = log_entry[i];
1034 combined[n + i] = log_exit[i];
1035 }
1036 combined
1037 }
1038 }
1039
1040 fn data_capped_internal_knots(
1063 combined: &Array1<f64>,
1064 degree: usize,
1065 requested_internal_knots: usize,
1066 ) -> usize {
1067 if requested_internal_knots == 0 {
1068 return 0;
1069 }
1070 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1071 sorted.sort_by(f64::total_cmp);
1072 let minval = sorted.first().copied().unwrap_or(0.0);
1073 let maxval = sorted.last().copied().unwrap_or(minval);
1074 if minval == maxval {
1075 return 1.min(requested_internal_knots);
1077 }
1078 let scale = (maxval - minval).abs().max(1.0);
1079 let tol = 1e-12 * scale;
1080 let mut distinct_interior = 0usize;
1083 let mut last: Option<f64> = None;
1084 for &x in &sorted {
1085 if x <= minval + tol || x >= maxval - tol {
1086 continue;
1087 }
1088 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1089 continue;
1090 }
1091 distinct_interior += 1;
1092 last = Some(x);
1093 }
1094 let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1097 let n_distinct = {
1103 let mut count = 0usize;
1104 let mut last: Option<f64> = None;
1105 for &x in &sorted {
1106 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1107 continue;
1108 }
1109 count += 1;
1110 last = Some(x);
1111 }
1112 count
1113 };
1114 let dim_budget = n_distinct / 4;
1115 let dim_cap = dim_budget.saturating_sub(degree);
1116 cap = cap.min(dim_cap.max(1));
1117 cap.max(1)
1118 }
1119
1120 fn infer_survival_time_knots(
1121 combined: &Array1<f64>,
1122 knot_degree: usize,
1123 validation_degree: usize,
1124 num_internal_knots: usize,
1125 basis_options: BasisOptions,
1126 ) -> Result<Array1<f64>, String> {
1127 let num_internal_knots =
1133 data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1134
1135 fn quantile_knot_inference_needs_uniform_fallback(
1136 combined: &Array1<f64>,
1137 num_internal_knots: usize,
1138 ) -> bool {
1139 if num_internal_knots == 0 || combined.is_empty() {
1140 return false;
1141 }
1142
1143 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1144 sorted.sort_by(f64::total_cmp);
1145 let minval = sorted[0];
1146 let maxval = *sorted.last().unwrap_or(&minval);
1147 if minval == maxval {
1148 return false;
1149 }
1150
1151 let scale = (maxval - minval).abs().max(1.0);
1152 let tol = 1e-12 * scale;
1153 let mut support = Vec::with_capacity(sorted.len());
1154 let mut last: Option<f64> = None;
1155 for &x in &sorted {
1156 if x <= minval + tol || x >= maxval - tol {
1157 continue;
1158 }
1159 if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1160 continue;
1161 }
1162 support.push(x);
1163 last = Some(x);
1164 }
1165 if support.is_empty() {
1166 return true;
1167 }
1168
1169 let n = support.len();
1170 let mut prev_q = minval;
1171 for j in 1..=num_internal_knots {
1172 let p = j as f64 / (num_internal_knots + 1) as f64;
1173 let pos = p * (n.saturating_sub(1) as f64);
1174 let lo = pos.floor() as usize;
1175 let hi = pos.ceil() as usize;
1176 let frac = pos - lo as f64;
1177 let q = if lo == hi {
1178 support[lo]
1179 } else {
1180 support[lo] * (1.0 - frac) + support[hi] * frac
1181 }
1182 .clamp(minval, maxval);
1183 if q <= prev_q + tol || q >= maxval - tol {
1184 return true;
1185 }
1186 prev_q = q;
1187 }
1188
1189 false
1190 }
1191
1192 let inferwith =
1193 |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1194 let built = build_bspline_basis_1d(
1195 combined.view(),
1196 &BSplineBasisSpec {
1197 degree: knot_degree,
1198 penalty_order: 2,
1199 knotspec: BSplineKnotSpec::Automatic {
1200 num_internal_knots: Some(num_internal_knots),
1201 placement,
1202 },
1203 double_penalty: false,
1204 identifiability: BSplineIdentifiability::None,
1205 boundary: OneDimensionalBoundary::Open,
1206 boundary_conditions: BSplineBoundaryConditions::default(),
1207 },
1208 )
1209 .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1210 let knots = match built.metadata {
1211 BasisMetadata::BSpline1D { knots, .. } => knots,
1212 _ => {
1213 return Err(
1214 "internal error: expected BSpline1D metadata for survival time basis"
1215 .to_string(),
1216 );
1217 }
1218 };
1219 create_basis::<Dense>(
1228 combined.view(),
1229 KnotSource::Provided(knots.view()),
1230 validation_degree,
1231 basis_options,
1232 )
1233 .map_err(|e| e.to_string())?;
1234 Ok(knots)
1235 };
1236
1237 if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1238 inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1239 } else {
1240 inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1241 }
1242 }
1243
1244 match cfg {
1245 SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1246 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1247 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1248 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1249 penalties: Vec::new(),
1250 nullspace_dims: Vec::new(),
1251 basisname: "none".to_string(),
1252 degree: None,
1253 knots: None,
1254 keep_cols: None,
1255 smooth_lambda: None,
1256 }),
1257 SurvivalTimeBasisConfig::Linear => {
1258 let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1259 let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1260 let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1261 for i in 0..n {
1262 x_entry_time[[i, 0]] = 1.0;
1263 x_exit_time[[i, 0]] = 1.0;
1264 x_entry_time[[i, 1]] = log_entry[i];
1265 x_exit_time[[i, 1]] = log_exit[i];
1266 x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1267 }
1268 Ok(SurvivalTimeBuildOutput {
1269 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1270 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1271 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1272 penalties: Vec::new(),
1273 nullspace_dims: Vec::new(),
1274 basisname: "linear".to_string(),
1275 degree: None,
1276 knots: None,
1277 keep_cols: None,
1278 smooth_lambda: None,
1279 })
1280 }
1281 SurvivalTimeBasisConfig::BSpline {
1282 degree,
1283 knots,
1284 smooth_lambda,
1285 } => {
1286 let knotvec = if knots.is_empty() {
1287 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1288 "internal error: bspline time basis requested without knot source".to_string()
1289 })?;
1290 let combined = survival_time_knot_input(&log_entry, &log_exit);
1291 infer_survival_time_knots(
1292 &combined,
1293 degree,
1294 degree,
1295 num_internal_knots,
1296 BasisOptions::value(),
1297 )?
1298 } else {
1299 knots
1300 };
1301
1302 let entry_basis = build_bspline_basis_1d(
1303 log_entry.view(),
1304 &BSplineBasisSpec {
1305 degree,
1306 penalty_order: 2,
1307 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1308 double_penalty: false,
1309 identifiability: BSplineIdentifiability::None,
1310 boundary: OneDimensionalBoundary::Open,
1311 boundary_conditions: BSplineBoundaryConditions::default(),
1312 },
1313 )
1314 .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1315 let exit_basis = build_bspline_basis_1d(
1316 log_exit.view(),
1317 &BSplineBasisSpec {
1318 degree,
1319 penalty_order: 2,
1320 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1321 double_penalty: false,
1322 identifiability: BSplineIdentifiability::None,
1323 boundary: OneDimensionalBoundary::Open,
1324 boundary_conditions: BSplineBoundaryConditions::default(),
1325 },
1326 )
1327 .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1328
1329 let p_time = exit_basis.design.ncols();
1330 let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1334 let mut deriv_buf = vec![0.0_f64; p_time];
1335 for i in 0..n {
1336 deriv_buf.fill(0.0);
1337 evaluate_bspline_derivative_scalar(
1338 log_exit[i],
1339 knotvec.view(),
1340 degree,
1341 &mut deriv_buf,
1342 )
1343 .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1344 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1345 for j in 0..p_time {
1346 let v = deriv_buf[j] * chain;
1347 if v.abs() > 1e-15 {
1348 deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1349 }
1350 }
1351 }
1352 let x_derivative_time =
1353 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1354 {
1355 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1356 Err(_) => {
1357 let mut dense = Array2::<f64>::zeros((n, p_time));
1359 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1360 dense[[row, col]] = val;
1361 }
1362 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1363 }
1364 };
1365
1366 Ok(SurvivalTimeBuildOutput {
1367 x_entry_time: entry_basis.design,
1368 x_exit_time: exit_basis.design,
1369 x_derivative_time,
1370 nullspace_dims: entry_basis.nullspace_dims,
1371 penalties: entry_basis.penalties,
1372 basisname: "bspline".to_string(),
1373 degree: Some(degree),
1374 knots: Some(knotvec.to_vec()),
1375 keep_cols: None,
1376 smooth_lambda: Some(smooth_lambda),
1377 })
1378 }
1379 SurvivalTimeBasisConfig::ISpline {
1380 degree,
1381 knots,
1382 keep_cols,
1383 smooth_lambda,
1384 } => {
1385 let bspline_degree = degree
1386 .checked_add(1)
1387 .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1388 let knotvec = if knots.is_empty() {
1389 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1390 "internal error: ispline time basis requested without knot source".to_string()
1391 })?;
1392 let combined = survival_time_knot_input(&log_entry, &log_exit);
1393 infer_survival_time_knots(
1394 &combined,
1395 bspline_degree,
1396 degree,
1397 num_internal_knots,
1398 BasisOptions::i_spline(),
1399 )?
1400 } else {
1401 knots
1402 };
1403
1404 let (db_exit_arc, _) = create_basis::<Dense>(
1405 log_exit.view(),
1406 KnotSource::Provided(knotvec.view()),
1407 bspline_degree,
1408 BasisOptions::first_derivative(),
1409 )
1410 .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1411
1412 let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1415 let (entry_arc, _) = create_basis::<Dense>(
1416 log_entry.view(),
1417 KnotSource::Provided(knotvec.view()),
1418 degree,
1419 BasisOptions::i_spline(),
1420 )
1421 .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1422 let (exit_arc, _) = create_basis::<Dense>(
1423 log_exit.view(),
1424 KnotSource::Provided(knotvec.view()),
1425 degree,
1426 BasisOptions::i_spline(),
1427 )
1428 .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1429
1430 let x_entry_full = entry_arc.as_ref();
1431 let x_exit_full = exit_arc.as_ref();
1432 let p_time_full = x_exit_full.ncols();
1433 if p_time_full == 0 {
1434 return Err(SurvivalConstructionError::BasisConstructionFailed {
1435 reason: "internal error: empty ispline time basis".to_string(),
1436 }
1437 .into());
1438 }
1439 let db_exit = db_exit_arc.as_ref();
1440 if db_exit.ncols() != p_time_full + 1 {
1441 return Err(
1442 "internal error: ispline derivative basis width must exceed basis width by one"
1443 .to_string(),
1444 );
1445 }
1446
1447 let keep_cols = if keep_cols.is_empty() {
1448 let constant_tol = 1e-12_f64;
1449 let mut inferred_keep_cols: Vec<usize> = Vec::new();
1450 for j in 0..p_time_full {
1451 let mut minv = f64::INFINITY;
1452 let mut maxv = f64::NEG_INFINITY;
1453 for i in 0..n {
1454 let ve = x_exit_full[[i, j]];
1455 let vs = x_entry_full[[i, j]];
1456 minv = minv.min(ve.min(vs));
1457 maxv = maxv.max(ve.max(vs));
1458 }
1459 if (maxv - minv) > constant_tol {
1460 inferred_keep_cols.push(j);
1461 }
1462 }
1463 inferred_keep_cols
1464 } else {
1465 keep_cols
1466 };
1467 if keep_cols.is_empty() {
1468 return Err(
1469 "internal error: ispline basis has no shape-varying time columns"
1470 .to_string(),
1471 );
1472 }
1473 if keep_cols.iter().any(|&j| j >= p_time_full) {
1474 return Err(SurvivalConstructionError::MissingColumn {
1475 reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1476 }
1477 .into());
1478 }
1479
1480 let p_time = keep_cols.len();
1481 let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1482 let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1483 (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1486 };
1487 let db_exit = db_exit_arc.as_ref();
1488
1489 let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1494 let mut found_nonfinite: Option<(usize, usize)> = None;
1495 for i in 0..n {
1496 let mut running = 0.0_f64;
1497 let mut d_i_log_full = vec![0.0_f64; p_time_full];
1498 for j in (1..db_exit.ncols()).rev() {
1499 let term = db_exit[[i, j]];
1500 if term.is_finite() {
1501 running += term;
1502 }
1503 d_i_log_full[j - 1] = running;
1504 }
1505 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1506 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1507 let raw_v = d_i_log_full[j_old] * chain;
1508 let v = if (-1e-12..0.0).contains(&raw_v) {
1509 0.0
1510 } else {
1511 raw_v
1512 };
1513 if !v.is_finite() {
1514 found_nonfinite = Some((i, j_new));
1515 }
1516 if v < -1e-12 {
1517 return Err(format!(
1518 "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1519 i + 1,
1520 j_new + 1,
1521 v
1522 ));
1523 }
1524 if v.abs() > 1e-15 {
1525 deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1526 }
1527 }
1528 }
1529 if let Some((row, col)) = found_nonfinite {
1530 return Err(format!(
1531 "survival ispline derivative basis produced non-finite value at row {}, column {}",
1532 row + 1,
1533 col + 1
1534 ));
1535 }
1536 let x_derivative_time =
1537 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1538 {
1539 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1540 Err(_) => {
1541 let mut dense = Array2::<f64>::zeros((n, p_time));
1542 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1543 dense[[row, col]] = val;
1544 }
1545 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1546 }
1547 };
1548
1549 let penalty_basis = build_bspline_basis_1d(
1550 log_exit.view(),
1551 &BSplineBasisSpec {
1552 degree: bspline_degree,
1553 penalty_order: 2,
1554 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1555 double_penalty: false,
1556 identifiability: BSplineIdentifiability::None,
1557 boundary: OneDimensionalBoundary::Open,
1558 boundary_conditions: BSplineBoundaryConditions::default(),
1559 },
1560 )
1561 .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1562 if penalty_basis.design.ncols() != p_time_full + 1 {
1563 return Err("internal error: ispline penalty dimension mismatch".to_string());
1564 }
1565 let mut penalties = Vec::<Array2<f64>>::new();
1599 for s_mat in &penalty_basis.penalties {
1600 if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1601 continue;
1602 }
1603 let s_increment = s_mat.slice(s![1.., 1..]);
1632 if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1633 return Err(format!(
1634 "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1635 s_increment.nrows(),
1636 s_increment.ncols(),
1637 ));
1638 }
1639 let mut s_full = s_increment.to_owned();
1644 symmetrize_in_place(&mut s_full);
1645 let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1649 for i in 0..p_time_full {
1650 for j in 0..p_time_full {
1651 let mut v = 0.0;
1652 for k in j..p_time_full {
1653 v += s_full[[i, k]];
1654 }
1655 s_mid_full[[i, j]] = v;
1656 }
1657 }
1658 let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1662 for i in 0..p_time_full {
1663 for j in 0..p_time_full {
1664 let mut v = 0.0;
1665 for k in i..p_time_full {
1666 v += s_mid_full[[k, j]];
1667 }
1668 s_full_congruent[[i, j]] = v;
1669 }
1670 }
1671 let mut local = Array2::<f64>::zeros((p_time, p_time));
1673 for (i_new, &i_old) in keep_cols.iter().enumerate() {
1674 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1675 local[[i_new, j_new]] = 0.5
1678 * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1679 }
1680 }
1681 penalties.push(local);
1682 }
1683
1684 for (idx, s_mat) in penalties.iter().enumerate() {
1694 let p = s_mat.nrows();
1695 if p == 0 {
1696 continue;
1697 }
1698 if let Ok((evals, _)) =
1699 gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1700 {
1701 let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1702 "internal error: ispline penalty eigenvalues not contiguous".to_string()
1703 })?;
1704 let max_ev = evals_slice
1705 .iter()
1706 .copied()
1707 .fold(0.0_f64, |a, b| a.max(b.abs()))
1708 .max(1.0);
1709 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1710 let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1711 if min_ev < neg_tol {
1712 return Err(format!(
1713 "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1714 indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1715 {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1716 ));
1717 }
1718 }
1719 }
1720
1721 let nullspace_dims: Vec<usize> = penalties
1725 .iter()
1726 .map(|s_mat| {
1727 let p = s_mat.nrows();
1728 if p == 0 {
1729 return 0;
1730 }
1731 match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1732 Ok((evals, _)) => {
1733 let evals_slice: &[f64] = evals.as_slice().unwrap();
1734 let max_ev = evals_slice
1735 .iter()
1736 .copied()
1737 .fold(0.0_f64, |a, b| a.max(b.abs()))
1738 .max(1.0);
1739 let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1740 evals_slice.iter().filter(|&&e| e <= threshold).count()
1741 }
1742 Err(_) => 0,
1743 }
1744 })
1745 .collect();
1746 Ok(SurvivalTimeBuildOutput {
1747 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1748 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1749 x_derivative_time,
1750 penalties,
1751 nullspace_dims,
1752 basisname: "ispline".to_string(),
1753 degree: Some(degree),
1754 knots: Some(knotvec.to_vec()),
1755 keep_cols: Some(keep_cols),
1756 smooth_lambda: Some(smooth_lambda),
1757 })
1758 }
1759 }
1760}
1761
1762pub fn resolved_survival_time_basis_config_from_build(
1763 basisname: &str,
1764 degree: Option<usize>,
1765 knots: Option<&Vec<f64>>,
1766 keep_cols: Option<&Vec<usize>>,
1767 smooth_lambda: Option<f64>,
1768) -> Result<SurvivalTimeBasisConfig, String> {
1769 match basisname {
1770 "none" => Ok(SurvivalTimeBasisConfig::None),
1771 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1772 "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1773 degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1774 knots: Array1::from_vec(
1775 knots
1776 .cloned()
1777 .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1778 ),
1779 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1780 }),
1781 "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1782 degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1783 knots: Array1::from_vec(
1784 knots
1785 .cloned()
1786 .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1787 ),
1788 keep_cols: keep_cols
1789 .cloned()
1790 .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1791 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1792 }),
1793 other => Err(format!("unsupported survival time basis '{other}'")),
1794 }
1795}
1796
1797pub fn resolve_survival_time_anchor_value(
1798 age_entry: &Array1<f64>,
1799 time_anchor: Option<f64>,
1800) -> Result<f64, String> {
1801 if age_entry.is_empty() {
1802 return Err("survival time anchor requires non-empty entry times".to_string());
1803 }
1804 let anchor = match time_anchor {
1805 Some(t_anchor) => {
1806 if !t_anchor.is_finite() || t_anchor < 0.0 {
1807 return Err(format!(
1808 "survival time anchor must be finite and non-negative, got {t_anchor}"
1809 ));
1810 }
1811 t_anchor
1812 }
1813 None => age_entry
1814 .iter()
1815 .copied()
1816 .min_by(f64::total_cmp)
1817 .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1818 };
1819 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1820}
1821
1822pub fn resolve_survival_marginal_slope_time_anchor_value(
1854 age_entry: &Array1<f64>,
1855 age_exit: &Array1<f64>,
1856 time_anchor: Option<f64>,
1857) -> Result<f64, String> {
1858 if age_entry.is_empty() || age_exit.is_empty() {
1859 return Err(
1860 "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1861 );
1862 }
1863 let anchor = match time_anchor {
1864 Some(t_anchor) => {
1865 if !t_anchor.is_finite() || t_anchor < 0.0 {
1866 return Err(format!(
1867 "survival time anchor must be finite and non-negative, got {t_anchor}"
1868 ));
1869 }
1870 t_anchor
1871 }
1872 None => robust_interior_exit_anchor(age_exit),
1873 };
1874 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1875}
1876
1877fn robust_interior_exit_anchor(age_exit: &Array1<f64>) -> f64 {
1884 let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1885 sorted.sort_by(f64::total_cmp);
1886 let m = sorted.len();
1887 if m == 0 {
1888 return SURVIVAL_TIME_FLOOR;
1889 }
1890 if m % 2 == 1 {
1891 sorted[m / 2]
1892 } else {
1893 0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1894 }
1895}
1896
1897pub fn resolve_survival_transformation_time_anchor_value(
1918 age_entry: &Array1<f64>,
1919 age_exit: &Array1<f64>,
1920 time_anchor: Option<f64>,
1921) -> Result<f64, String> {
1922 if time_anchor.is_some() {
1923 return resolve_survival_time_anchor_value(age_entry, time_anchor);
1924 }
1925 if age_exit.is_empty() {
1926 return Err(
1927 "survival transformation time anchor requires non-empty exit times".to_string(),
1928 );
1929 }
1930 let min_entry = age_entry
1931 .iter()
1932 .copied()
1933 .fold(f64::INFINITY, f64::min);
1934 if min_entry > SURVIVAL_DELAYED_ENTRY_THRESHOLD {
1935 Ok(robust_interior_exit_anchor(age_exit).max(SURVIVAL_TIME_FLOOR))
1936 } else {
1937 resolve_survival_time_anchor_value(age_entry, None)
1938 }
1939}
1940
1941pub fn evaluate_survival_time_basis_row(
1942 age: f64,
1943 cfg: &SurvivalTimeBasisConfig,
1944) -> Result<Array1<f64>, String> {
1945 if !age.is_finite() || age < 0.0 {
1946 return Err(format!(
1947 "survival time basis row requires finite non-negative age, got {age}"
1948 ));
1949 }
1950 let age = age.max(SURVIVAL_TIME_FLOOR);
1951 let log_age = array![age.ln()];
1952 match cfg {
1953 SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1954 SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1955 SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1956 if knots.is_empty() {
1957 return Err(
1958 "survival BSpline anchor evaluation requires resolved knot metadata"
1959 .to_string(),
1960 );
1961 }
1962 let built = build_bspline_basis_1d(
1963 log_age.view(),
1964 &BSplineBasisSpec {
1965 degree: *degree,
1966 penalty_order: 2,
1967 knotspec: BSplineKnotSpec::Provided(knots.clone()),
1968 double_penalty: false,
1969 identifiability: BSplineIdentifiability::None,
1970 boundary: OneDimensionalBoundary::Open,
1971 boundary_conditions: BSplineBoundaryConditions::default(),
1972 },
1973 )
1974 .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1975 Ok(built.design.to_dense().row(0).to_owned())
1976 }
1977 SurvivalTimeBasisConfig::ISpline {
1978 degree,
1979 knots,
1980 keep_cols,
1981 ..
1982 } => {
1983 if knots.is_empty() {
1984 return Err(
1985 "survival ISpline anchor evaluation requires resolved knot metadata"
1986 .to_string(),
1987 );
1988 }
1989 let (basis_arc, _) = create_basis::<Dense>(
1990 log_age.view(),
1991 KnotSource::Provided(knots.view()),
1992 *degree,
1993 BasisOptions::i_spline(),
1994 )
1995 .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1996 let basis = basis_arc.as_ref();
1997 let row = basis.row(0);
1998 if keep_cols.is_empty() {
1999 return Ok(row.to_owned());
2000 }
2001 if keep_cols.iter().any(|&j| j >= row.len()) {
2002 return Err(SurvivalConstructionError::MissingColumn {
2003 reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
2004 }
2005 .into());
2006 }
2007 Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
2008 }
2009 }
2010}
2011
2012pub fn center_survival_time_designs_at_anchor(
2013 design_entry: &mut DesignMatrix,
2014 design_exit: &mut DesignMatrix,
2015 anchor_row: &Array1<f64>,
2016) -> Result<(), String> {
2017 if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
2018 return Err(format!(
2019 "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
2020 design_entry.ncols(),
2021 design_exit.ncols(),
2022 anchor_row.len()
2023 ));
2024 }
2025 fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
2028 let mut dense = dm.to_dense();
2029 for mut row in dense.rows_mut() {
2030 row -= &anchor.view();
2031 }
2032 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
2033 }
2034 center_dense(design_entry, anchor_row);
2035 center_dense(design_exit, anchor_row);
2036 Ok(())
2037}
2038
2039pub fn baseline_offset_theta_partials(
2069 age: f64,
2070 cfg: &SurvivalBaselineConfig,
2071) -> Result<Option<Vec<(f64, f64)>>, String> {
2072 let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2073 else {
2074 return Ok(None);
2075 };
2076
2077 match params {
2078 ValidatedBaselineTarget::Weibull { scale, shape } => {
2079 let eta = shape * (age.ln() - scale.ln());
2088 let o_d = shape / age;
2089 let d_eta_d_log_scale = -shape;
2090 let d_od_d_log_scale = 0.0;
2091 let d_eta_d_log_shape = eta;
2092 let d_od_d_log_shape = o_d;
2093 Ok(Some(vec![
2094 (d_eta_d_log_scale, d_od_d_log_scale),
2095 (d_eta_d_log_shape, d_od_d_log_shape),
2096 ]))
2097 }
2098 ValidatedBaselineTarget::Gompertz { shape, .. } => {
2099 let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2109 Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2110 }
2111 ValidatedBaselineTarget::GompertzMakeham {
2112 rate,
2113 shape,
2114 makeham,
2115 } => {
2116 let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2131 let cum_total = makeham * age + cum_g;
2132 if cum_total <= 0.0 || !cum_total.is_finite() {
2133 return Err(SurvivalConstructionError::DataValidationFailed {
2134 reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2135 }
2136 .into());
2137 }
2138 let inst_total = makeham + inst_g;
2139 let o_d = inst_total / cum_total;
2140 let inv_cum = 1.0 / cum_total;
2141 let d_cum_dlr = cum_g;
2146 let d_inst_dlr = inst_g;
2147 let d_eta_dlr = d_cum_dlr * inv_cum;
2148 let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2149 let (d_cum_dshape, d_inst_dshape) =
2151 gompertz_cumulative_shape_derivative(age, rate, shape);
2152 let d_eta_dshape = d_cum_dshape * inv_cum;
2153 let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2154 let d_cum_dlm = makeham * age;
2157 let d_inst_dlm = makeham;
2158 let d_eta_dlm = d_cum_dlm * inv_cum;
2159 let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2160 Ok(Some(vec![
2161 (d_eta_dlr, d_od_dlr),
2162 (d_eta_dshape, d_od_dshape),
2163 (d_eta_dlm, d_od_dlm),
2164 ]))
2165 }
2166 }
2167}
2168
2169fn baseline_chain_rule_gradient_with_partials<F>(
2197 label: &'static str,
2198 age_entry: ndarray::ArrayView1<'_, f64>,
2199 age_exit: ndarray::ArrayView1<'_, f64>,
2200 age_right: ndarray::ArrayView1<'_, f64>,
2201 cfg: &SurvivalBaselineConfig,
2202 residuals: &crate::survival::OffsetChannelResiduals,
2203 partials: F,
2204) -> Result<Option<Array1<f64>>, String>
2205where
2206 F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2207{
2208 let n = age_exit.len();
2209 if age_entry.len() != n
2210 || age_right.len() != n
2211 || residuals.exit.len() != n
2212 || residuals.entry.len() != n
2213 || residuals.derivative.len() != n
2214 || residuals.right.len() != n
2215 {
2216 return Err(format!(
2217 "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2218 age_entry.len(),
2219 n,
2220 age_right.len(),
2221 residuals.exit.len(),
2222 residuals.entry.len(),
2223 residuals.derivative.len(),
2224 residuals.right.len(),
2225 ));
2226 }
2227 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2230 let theta_dim = match probe_age {
2231 Some(t) => match partials(t, cfg)? {
2232 None => return Ok(None),
2233 Some(v) => v.len(),
2234 },
2235 None => {
2236 return Err(format!("{label}: no valid positive age for dim probe"));
2237 }
2238 };
2239 let mut grad = Array1::<f64>::zeros(theta_dim);
2250 for i in 0..n {
2251 let partials_exit = partials(age_exit[i], cfg)?
2253 .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2254 if partials_exit.len() != theta_dim {
2255 return Err(format!(
2256 "{label}: theta_dim drifted ({} != {})",
2257 partials_exit.len(),
2258 theta_dim
2259 ));
2260 }
2261 let r_x = residuals.exit[i];
2262 let r_d = residuals.derivative[i];
2263 for k in 0..theta_dim {
2264 let (d_eta_dk, d_od_dk) = partials_exit[k];
2265 grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2266 }
2267 let r_e = residuals.entry[i];
2271 if r_e != 0.0 {
2272 let partials_entry = partials(age_entry[i], cfg)?
2273 .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2274 for k in 0..theta_dim {
2275 grad[k] += r_e * partials_entry[k].0;
2276 }
2277 }
2278 let r_r = residuals.right[i];
2287 if r_r != 0.0 {
2288 let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2289 format!("{label}: unexpected None from partials at right boundary")
2290 })?;
2291 if partials_right.len() != theta_dim {
2292 return Err(format!(
2293 "{label}: theta_dim drifted at right boundary ({} != {})",
2294 partials_right.len(),
2295 theta_dim
2296 ));
2297 }
2298 for k in 0..theta_dim {
2299 grad[k] += r_r * partials_right[k].0;
2300 }
2301 }
2302 }
2303 Ok(Some(grad))
2304}
2305
2306pub fn baseline_chain_rule_gradient(
2340 age_entry: ndarray::ArrayView1<'_, f64>,
2341 age_exit: ndarray::ArrayView1<'_, f64>,
2342 age_right: ndarray::ArrayView1<'_, f64>,
2343 cfg: &SurvivalBaselineConfig,
2344 residuals: &crate::survival::OffsetChannelResiduals,
2345) -> Result<Option<Array1<f64>>, String> {
2346 baseline_chain_rule_gradient_with_partials(
2347 "baseline_chain_rule_gradient",
2348 age_entry,
2349 age_exit,
2350 age_right,
2351 cfg,
2352 residuals,
2353 baseline_offset_theta_partials,
2354 )
2355}
2356
2357pub fn marginal_slope_baseline_chain_rule_gradient(
2364 age_entry: ndarray::ArrayView1<'_, f64>,
2365 age_exit: ndarray::ArrayView1<'_, f64>,
2366 cfg: &SurvivalBaselineConfig,
2367 residuals: &crate::survival::OffsetChannelResiduals,
2368) -> Result<Option<Array1<f64>>, String> {
2369 baseline_chain_rule_gradient_with_partials(
2373 "marginal_slope_baseline_chain_rule_gradient",
2374 age_entry,
2375 age_exit,
2376 age_exit,
2377 cfg,
2378 residuals,
2379 marginal_slope_baseline_offset_theta_partials,
2380 )
2381}
2382
2383#[inline]
2387fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
2388 if shape.abs() < 1e-10 {
2389 let x = shape * age;
2392 (
2393 rate * age * (1.0 + 0.5 * x + x * x / 6.0),
2394 rate * (1.0 + x + 0.5 * x * x),
2395 )
2396 } else {
2397 let shape_age = shape * age;
2398 let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
2399 let instant_hazard = rate * shape_age.exp();
2400 (cumulative_hazard, instant_hazard)
2401 }
2402}
2403
/// Derivatives of the Gompertz cumulative and instantaneous hazards with
/// respect to `shape`, at `age`.
///
/// Closed forms, with `x = shape * age`:
/// `dH/dshape = rate * (age * e^x * shape - (e^x - 1)) / shape^2` and
/// `dh/dshape = rate * age * e^x`.
///
/// For `|x| < 1e-4`, `dH/dshape` uses the expansion
/// `rate * age^2 * (1/2 + x/3 + x^2/8)`. This avoids the cancellation in the
/// closed-form numerator and the division by `shape^2`.
///
/// Returns `(dH/dshape, dh/dshape)`.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let d_instant = rate * age * x.exp();
    let d_cumulative = if x.abs() < 1e-4 {
        rate * age * age * (0.5 + x / 3.0 + x * x / 8.0)
    } else {
        let growth = x.exp();
        let growth_minus_one = x.exp_m1();
        rate * (age * growth * shape - growth_minus_one) / (shape * shape)
    };
    (d_cumulative, d_instant)
}
2444
/// Derivatives of the Gompertz offset pair `(eta, o_d)` with respect to the
/// unconstrained `shape` parameter, at `age`.
///
/// Definitions, with `x = shape * age`:
/// * `H = (rate / shape) * (e^x - 1)`;
/// * `eta = ln H`;
/// * `o_d = d eta / d age = shape * e^x / (e^x - 1)`.
///
/// `rate` only shifts `eta` additively, so it does not enter either derivative.
///
/// Returns `(d eta / d shape, d o_d / d shape)`.
///
/// The closed forms `-1/shape + age·e^x/(e^x - 1)` and `1/shape - age/(e^x - 1)`
/// subtract two quantities of size ~`1/shape`. For small `x` this loses about
/// `eps / |x|` relative precision. So for `|x| < 1e-4` a Bernoulli-series
/// expansion is used instead. The switch is on `x`, as in
/// `gompertz_cumulative_shape_derivative`, because `x` is what governs the
/// cancellation.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    // Below this |x| the series truncation error is O(x^4) relative (far below
    // f64 epsilon), while the closed forms would still lose precision.
    const SERIES_THRESHOLD: f64 = 1e-4;
    let x = shape * age;
    if x.abs() < SERIES_THRESHOLD {
        // e^x / (e^x - 1) = 1/x + 1/2 + x/12 - x^3/720 + O(x^5), hence
        // d eta / d shape = age * (1/2 + x/12 - x^3/720) + O(age * x^5).
        let d_eta = age * (0.5 + x / 12.0 - x * x * x / 720.0);
        // o_d = g(x) / age with g(x) = x e^x / (e^x - 1)
        //     = 1 + x/2 + x^2/12 - x^4/720 + ...,
        // so d o_d / d shape = g'(x) = 1/2 + x/6 - x^3/180 + O(x^5).
        let d_od = 0.5 + x / 6.0 - x * x * x / 180.0;
        (d_eta, d_od)
    } else {
        let e = x.exp();
        let em1 = x.exp_m1();
        let d_eta = -1.0 / shape + age * e / em1;
        let o_d = shape * e / em1;
        let dlog_od = 1.0 / shape - age / em1;
        // d o_d / d shape = o_d * d ln(o_d) / d shape.
        (d_eta, o_d * dlog_od)
    }
}
2477
/// Baseline hazard parameters that have passed the checks in
/// `validated_baseline_params`.
///
/// Guarantees established there:
/// * Weibull: `scale` and `shape` are finite and strictly positive.
/// * Gompertz: `rate` is finite and strictly positive; `shape` is finite and
///   may be zero or negative.
/// * Gompertz-Makeham: as Gompertz, plus a finite, strictly positive `makeham`.
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    Weibull { scale: f64, shape: f64 },
    Gompertz { rate: f64, shape: f64 },
    GompertzMakeham { rate: f64, shape: f64, makeham: f64 },
}
2492
2493fn validated_baseline_params(
2499 age: f64,
2500 cfg: &SurvivalBaselineConfig,
2501 context: &str,
2502) -> Result<Option<ValidatedBaselineTarget>, String> {
2503 if !age.is_finite() || age <= 0.0 {
2504 return Err(format!(
2505 "survival ages must be finite and positive for {context}"
2506 ));
2507 }
2508
2509 match cfg.target {
2510 SurvivalBaselineTarget::Linear => Ok(None),
2511 SurvivalBaselineTarget::Weibull => {
2512 let scale = cfg
2513 .scale
2514 .ok_or_else(|| "weibull missing scale".to_string())?;
2515 let shape = cfg
2516 .shape
2517 .ok_or_else(|| "weibull missing shape".to_string())?;
2518 if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2519 return Err(SurvivalConstructionError::InvalidConfig {
2520 reason: "weibull baseline requires finite positive scale and shape".to_string(),
2521 }
2522 .into());
2523 }
2524 Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2525 }
2526 SurvivalBaselineTarget::Gompertz => {
2527 let rate = cfg
2528 .rate
2529 .ok_or_else(|| "gompertz missing rate".to_string())?;
2530 let shape = cfg
2531 .shape
2532 .ok_or_else(|| "gompertz missing shape".to_string())?;
2533 if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2534 return Err(
2535 "gompertz baseline requires finite positive rate and finite shape".to_string(),
2536 );
2537 }
2538 Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2539 }
2540 SurvivalBaselineTarget::GompertzMakeham => {
2541 let rate = cfg
2542 .rate
2543 .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2544 let shape = cfg
2545 .shape
2546 .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2547 let makeham = cfg
2548 .makeham
2549 .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2550 if !(rate.is_finite()
2551 && shape.is_finite()
2552 && makeham.is_finite()
2553 && rate > 0.0
2554 && makeham > 0.0)
2555 {
2556 return Err(
2557 "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2558 .to_string(),
2559 );
2560 }
2561 Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2562 rate,
2563 shape,
2564 makeham,
2565 }))
2566 }
2567 }
2568}
2569
2570fn survival_hazard_theta_partials(
2571 age: f64,
2572 cfg: &SurvivalBaselineConfig,
2573) -> Result<Option<Vec<(f64, f64)>>, String> {
2574 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2575 return Ok(None);
2576 };
2577
2578 match params {
2579 ValidatedBaselineTarget::Weibull { scale, shape } => {
2580 let log_time_ratio = age.ln() - scale.ln();
2581 let cumulative_hazard = (age / scale).powf(shape);
2582 let instant_hazard = shape * cumulative_hazard / age;
2583 let eta = shape * log_time_ratio;
2584 Ok(Some(vec![
2585 (-shape * cumulative_hazard, -shape * instant_hazard),
2586 (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2587 ]))
2588 }
2589 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2590 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2591 let (d_cum_dshape, d_inst_dshape) =
2592 gompertz_cumulative_shape_derivative(age, rate, shape);
2593 Ok(Some(vec![
2594 (cumulative_hazard, instant_hazard),
2595 (d_cum_dshape, d_inst_dshape),
2596 ]))
2597 }
2598 ValidatedBaselineTarget::GompertzMakeham {
2599 rate,
2600 shape,
2601 makeham,
2602 } => {
2603 let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2604 let (d_cum_dshape, d_inst_dshape) =
2605 gompertz_cumulative_shape_derivative(age, rate, shape);
2606 Ok(Some(vec![
2607 (cum_gompertz, inst_gompertz),
2608 (d_cum_dshape, d_inst_dshape),
2609 (makeham * age, makeham),
2610 ]))
2611 }
2612 }
2613}
2614
2615fn survival_cumulative_and_instant_hazard(
2616 age: f64,
2617 cfg: &SurvivalBaselineConfig,
2618) -> Result<Option<(f64, f64)>, String> {
2619 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2620 return Ok(None);
2621 };
2622
2623 match params {
2624 ValidatedBaselineTarget::Weibull { scale, shape } => {
2625 let cumulative_hazard = (age / scale).powf(shape);
2626 let instant_hazard = shape * cumulative_hazard / age;
2627 Ok(Some((cumulative_hazard, instant_hazard)))
2628 }
2629 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2630 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2631 Ok(Some((cumulative_hazard, instant_hazard)))
2632 }
2633 ValidatedBaselineTarget::GompertzMakeham {
2634 rate,
2635 shape,
2636 makeham,
2637 } => {
2638 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2639 Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2640 }
2641 }
2642}
2643
/// Marginal-slope baseline quantities at a single age, as produced by
/// `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Baseline instantaneous hazard h(t) at the evaluated age.
    instant_hazard: f64,
    /// Probit-scale baseline: the negated `standard_normal_quantile` of the
    /// baseline survival probability S(t) = exp(-H(t)).
    q: f64,
    /// Computed as h(t) * S(t) / phi(q), which is the derivative of `q` with
    /// respect to age.
    q_t: f64,
}
2650
2651fn evaluate_marginal_slope_baseline_point(
2652 age: f64,
2653 cfg: &SurvivalBaselineConfig,
2654) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2655 let Some((cumulative_hazard, instant_hazard)) =
2656 survival_cumulative_and_instant_hazard(age, cfg)?
2657 else {
2658 return Ok(None);
2659 };
2660 if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2661 return Err(format!(
2662 "{} marginal-slope baseline produced non-positive cumulative hazard",
2663 survival_baseline_targetname(cfg.target)
2664 ));
2665 }
2666 if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2667 return Err(format!(
2668 "{} marginal-slope baseline produced non-positive instant hazard",
2669 survival_baseline_targetname(cfg.target)
2670 ));
2671 }
2672 let survival = (-cumulative_hazard).exp();
2673 if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2674 return Err(format!(
2675 "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2676 survival_baseline_targetname(cfg.target)
2677 ));
2678 }
2679 let q = -standard_normal_quantile(survival).map_err(|e| {
2680 format!(
2681 "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2682 survival_baseline_targetname(cfg.target)
2683 )
2684 })?;
2685 let phi_q = normal_pdf(q);
2686 if !(phi_q.is_finite() && phi_q > 0.0) {
2687 return Err(format!(
2688 "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2689 survival_baseline_targetname(cfg.target)
2690 ));
2691 }
2692 Ok(Some(MarginalSlopeBaselinePoint {
2693 instant_hazard,
2694 q,
2695 q_t: instant_hazard * survival / phi_q,
2696 }))
2697}
2698
2699pub fn evaluate_survival_baseline(
2702 age: f64,
2703 cfg: &SurvivalBaselineConfig,
2704) -> Result<(f64, f64), String> {
2705 if !age.is_finite() || age < 0.0 {
2706 return Err(
2707 "survival ages must be finite and non-negative for baseline target evaluation"
2708 .to_string(),
2709 );
2710 }
2711
2712 if age == 0.0 {
2723 return match cfg.target {
2724 SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2725 SurvivalBaselineTarget::Weibull
2726 | SurvivalBaselineTarget::Gompertz
2727 | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2728 };
2729 }
2730
2731 let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2732 return Ok((0.0, 0.0));
2733 };
2734
2735 match params {
2736 ValidatedBaselineTarget::Weibull { scale, shape } => {
2737 let eta = shape * (age.ln() - scale.ln());
2738 let derivative = shape / age;
2739 Ok((eta, derivative))
2740 }
2741 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2742 let (h, inst) = gompertz_hazard_components(age, rate, shape);
2743 if h <= 0.0 || !h.is_finite() {
2744 return Err(if shape.abs() < 1e-10 {
2745 "invalid gompertz baseline at near-zero shape".to_string()
2746 } else {
2747 "gompertz baseline produced non-positive cumulative hazard".to_string()
2748 });
2749 }
2750 let derivative = inst / h;
2751 Ok((h.ln(), derivative))
2752 }
2753 ValidatedBaselineTarget::GompertzMakeham {
2754 rate,
2755 shape,
2756 makeham,
2757 } => {
2758 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2759 let h = makeham * age + h_gompertz;
2760 if h <= 0.0 || !h.is_finite() {
2761 return Err(
2762 "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2763 );
2764 }
2765 let inst = makeham + inst_gompertz;
2766 let derivative = inst / h;
2767 Ok((h.ln(), derivative))
2768 }
2769 }
2770}
2771
2772pub fn evaluate_survival_marginal_slope_baseline(
2778 age: f64,
2779 cfg: &SurvivalBaselineConfig,
2780) -> Result<(f64, f64), String> {
2781 if age == 0.0 {
2793 return Ok((0.0, 0.0));
2794 }
2795 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2796 return Ok((0.0, 0.0));
2797 };
2798 Ok((point.q, point.q_t))
2799}
2800
2801pub fn marginal_slope_baseline_offset_theta_partials(
2814 age: f64,
2815 cfg: &SurvivalBaselineConfig,
2816) -> Result<Option<Vec<(f64, f64)>>, String> {
2817 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2818 return Ok(None);
2819 };
2820 let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2821 .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2822 let a = point.q_t / point.instant_hazard;
2823 let a_log_derivative_factor = point.q * a - 1.0;
2824 Ok(Some(
2825 hazard_partials
2826 .into_iter()
2827 .map(|(d_h_cum, d_h_inst)| {
2828 (
2829 a * d_h_cum,
2830 a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2831 )
2832 })
2833 .collect(),
2834 ))
2835}
2836
2837pub fn marginal_slope_baseline_chain_rule_hessian(
2840 age_entry: ndarray::ArrayView1<'_, f64>,
2841 age_exit: ndarray::ArrayView1<'_, f64>,
2842 cfg: &SurvivalBaselineConfig,
2843 residuals: &crate::survival::OffsetChannelResiduals,
2844 curvatures: &crate::survival::OffsetChannelCurvatures,
2845) -> Result<Option<Array2<f64>>, String> {
2846 let n = age_exit.len();
2847 if age_entry.len() != n
2848 || residuals.exit.len() != n
2849 || residuals.entry.len() != n
2850 || residuals.derivative.len() != n
2851 || curvatures.rows.len() != n
2852 {
2853 return Err(format!(
2854 "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2855 age_entry.len(),
2856 n,
2857 residuals.exit.len(),
2858 residuals.entry.len(),
2859 residuals.derivative.len(),
2860 curvatures.rows.len(),
2861 ));
2862 }
2863 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2864 let dim = match probe_age {
2865 Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2866 None => return Ok(None),
2867 Some(parts) => parts.first.len(),
2868 },
2869 None => {
2870 return Err(
2871 "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2872 .to_string(),
2873 );
2874 }
2875 };
2876 let hessian = RowSet::All.par_try_reduce_fold(
2883 n,
2884 || Array2::<f64>::zeros((dim, dim)),
2885 |mut acc, i, _row_weight| -> Result<Array2<f64>, String> {
2886 let exit_parts =
2887 marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2888 .ok_or_else(|| {
2889 "unexpected None from marginal-slope second partials at exit".to_string()
2890 })?;
2891 if exit_parts.first.len() != dim {
2892 return Err(
2893 "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2894 );
2895 }
2896 let mut entry_parts = None;
2897 if residuals.entry[i] != 0.0 {
2898 entry_parts = Some(
2899 marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2900 .ok_or_else(|| {
2901 "unexpected None from marginal-slope second partials at entry"
2902 .to_string()
2903 })?,
2904 );
2905 }
2906 for a in 0..dim {
2907 for b in 0..dim {
2908 let j_exit_a = exit_parts.first[a].0;
2909 let j_exit_b = exit_parts.first[b].0;
2910 let j_deriv_a = exit_parts.first[a].1;
2911 let j_deriv_b = exit_parts.first[b].1;
2912 let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2913 + residuals.derivative[i] * exit_parts.second[a][b].1;
2914 if let Some(parts) = entry_parts.as_ref() {
2915 value += residuals.entry[i] * parts.second[a][b].0;
2916 }
2917 let curv = curvatures.rows[i];
2918 let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2919 let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2920 let ja = [j_entry_a, j_exit_a, j_deriv_a];
2921 let jb = [j_entry_b, j_exit_b, j_deriv_b];
2922 for u in 0..3 {
2923 for v in 0..3 {
2924 value += ja[u] * curv[u][v] * jb[v];
2925 }
2926 }
2927 acc[[a, b]] += value;
2928 }
2929 }
2930 Ok(acc)
2931 },
2932 |a, b| Ok(a + b),
2933 )?;
2934 Ok(Some(hessian))
2935}
2936
2937struct MarginalSlopeThetaSecondPartials {
2938 first: Vec<(f64, f64)>,
2939 second: Vec<Vec<(f64, f64)>>,
2940}
2941
2942fn marginal_slope_baseline_offset_theta_second_partials(
2943 age: f64,
2944 cfg: &SurvivalBaselineConfig,
2945) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
2946 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2947 return Ok(None);
2948 };
2949 let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
2950 return Ok(None);
2951 };
2952 let (cum_hazard, instant_hazard) = hazard;
2953 let survival = (-cum_hazard).exp();
2954 let a = survival / normal_pdf(point.q);
2955 let b = point.q * a - 1.0;
2956 let b_factor = a + point.q * b;
2957 let dim = first.len();
2958 let mut first_out = Vec::with_capacity(dim);
2959 let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
2960 for i in 0..dim {
2961 let (h_i, inst_i) = first[i];
2962 first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
2963 }
2964 for i in 0..dim {
2965 for j in 0..dim {
2966 let (h_i, inst_i) = first[i];
2967 let (h_j, inst_j) = first[j];
2968 let (h_ij, inst_ij) = second[i][j];
2969 let a_j = a * b * h_j;
2970 let b_j = a * h_j * b_factor;
2971 let q_ij = a * h_ij + a * b * h_i * h_j;
2972 let qt_inner_i = inst_i + instant_hazard * b * h_i;
2973 let qt_ij = a_j * qt_inner_i
2974 + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
2975 second_out[i][j] = (q_ij, qt_ij);
2976 }
2977 }
2978 Ok(Some(MarginalSlopeThetaSecondPartials {
2979 first: first_out,
2980 second: second_out,
2981 }))
2982}
2983
2984type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2985
2986fn survival_hazard_theta_first_second(
2987 age: f64,
2988 cfg: &SurvivalBaselineConfig,
2989) -> Result<Option<HazardFirstSecond>, String> {
2990 let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
2991 return Ok(None);
2992 };
2993 let first = survival_hazard_theta_partials(age, cfg)?
2994 .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
2995 let dim = first.len();
2996 let mut second = vec![vec![(0.0, 0.0); dim]; dim];
2997 match cfg.target {
2998 SurvivalBaselineTarget::Linear => return Ok(None),
2999 SurvivalBaselineTarget::Weibull => {
3000 let scale = cfg
3001 .scale
3002 .ok_or_else(|| "weibull missing scale".to_string())?;
3003 let shape = cfg
3004 .shape
3005 .ok_or_else(|| "weibull missing shape".to_string())?;
3006 let log_time_ratio = age.ln() - scale.ln();
3007 let cumulative_hazard = hazard.0;
3008 let instant_hazard = hazard.1;
3009 let eta = shape * log_time_ratio;
3010 second[0][0] = (
3011 shape * shape * cumulative_hazard,
3012 shape * shape * instant_hazard,
3013 );
3014 second[0][1] = (
3015 -shape * cumulative_hazard * (1.0 + eta),
3016 -shape * instant_hazard * (2.0 + eta),
3017 );
3018 second[1][0] = second[0][1];
3019 second[1][1] = (
3020 eta * cumulative_hazard * (1.0 + eta),
3021 (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
3022 );
3023 }
3024 SurvivalBaselineTarget::Gompertz => {
3025 let rate = cfg
3026 .rate
3027 .ok_or_else(|| "gompertz missing rate".to_string())?;
3028 let shape = cfg
3029 .shape
3030 .ok_or_else(|| "gompertz missing shape".to_string())?;
3031 second[0][0] = first[0];
3032 second[0][1] = first[1];
3033 second[1][0] = first[1];
3034 second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
3035 }
3036 SurvivalBaselineTarget::GompertzMakeham => {
3037 let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
3038 let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
3039 second[0][0] = first[0];
3040 second[0][1] = first[1];
3041 second[1][0] = first[1];
3042 second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
3043 second[2][2] = first[2];
3044 }
3045 }
3046 Ok(Some((hazard, first, second)))
3047}
3048
/// Second derivatives with respect to `shape` of the Gompertz cumulative hazard
/// `H = rate (e^{shape·age} − 1) / shape` and instant hazard
/// `h = rate e^{shape·age}`, returned as `(∂²H/∂shape², ∂²h/∂shape²)`.
///
/// For `|shape·age| < 1e-3` the closed form for `∂²H` cancels badly. In that
/// case a truncated Taylor series `rate t³ (1/3 + x/4 + x²/10)` is used instead
/// (and `rate t² (1 + x + x²/2)` for `∂²h`).
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-3 {
        return (
            rate * age * age * age * (1.0 / 3.0 + x / 4.0 + x * x / 10.0),
            rate * age * age * (1.0 + x + 0.5 * x * x),
        );
    }
    let growth = x.exp();
    let growth_minus_one = x.exp_m1();
    let numerator = shape * age * growth - growth_minus_one;
    let d2_cumulative =
        rate * (age * age * growth / shape - 2.0 * numerator / (shape * shape * shape));
    let d2_instant = rate * age * age * growth;
    (d2_cumulative, d2_instant)
}
3079
3080#[derive(Clone, Copy)]
3085enum BaselineOffsetEvaluator {
3086 LogCumulativeHazard,
3087 ProbitSurvival,
3088}
3089
3090impl BaselineOffsetEvaluator {
3091 fn length_error(self) -> String {
3092 match self {
3093 Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3094 reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3095 }
3096 .into(),
3097 Self::ProbitSurvival => {
3098 "survival probit baseline offsets require matching entry/exit lengths".to_string()
3099 }
3100 }
3101 }
3102
3103 fn finite_error(self) -> &'static str {
3104 match self {
3105 Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3106 Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3107 }
3108 }
3109
3110 fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3111 match self {
3112 Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3113 Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3114 }
3115 }
3116
3117 fn exit_is_finite(self, value: f64, age: f64) -> bool {
3118 match self {
3119 Self::LogCumulativeHazard => {
3120 value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3121 }
3122 Self::ProbitSurvival => value.is_finite(),
3123 }
3124 }
3125}
3126
3127fn build_survival_offsets_with_evaluator(
3128 age_entry: &Array1<f64>,
3129 age_exit: &Array1<f64>,
3130 cfg: &SurvivalBaselineConfig,
3131 evaluator: BaselineOffsetEvaluator,
3132) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3133 if age_entry.len() != age_exit.len() {
3134 return Err(evaluator.length_error());
3135 }
3136 let n = age_entry.len();
3137 let triples: Vec<(f64, f64, f64)> = (0..n)
3140 .into_par_iter()
3141 .map(|i| -> Result<(f64, f64, f64), String> {
3142 let entry_age = age_entry[i];
3146 let e0 = if !entry_age.is_finite() {
3147 return Err(SurvivalConstructionError::DataValidationFailed {
3148 reason: format!("non-finite entry age at row {i}"),
3149 }
3150 .into());
3151 } else if entry_age <= 0.0 {
3152 0.0
3153 } else {
3154 evaluator.evaluate(entry_age, cfg)?.0
3155 };
3156 let exit_age = age_exit[i];
3157 let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3158 if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3159 return Err(SurvivalConstructionError::DataValidationFailed {
3160 reason: evaluator.finite_error().to_string(),
3161 }
3162 .into());
3163 }
3164 Ok((e0, e1, d1))
3165 })
3166 .collect::<Result<Vec<_>, String>>()?;
3167 let mut eta_entry = Array1::<f64>::zeros(n);
3168 let mut eta_exit = Array1::<f64>::zeros(n);
3169 let mut derivative_exit = Array1::<f64>::zeros(n);
3170 for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3171 eta_entry[i] = e0;
3172 eta_exit[i] = e1;
3173 derivative_exit[i] = d1;
3174 }
3175 Ok((eta_entry, eta_exit, derivative_exit))
3176}
3177
3178pub fn build_survival_baseline_offsets(
3181 age_entry: &Array1<f64>,
3182 age_exit: &Array1<f64>,
3183 cfg: &SurvivalBaselineConfig,
3184) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3185 build_survival_offsets_with_evaluator(
3186 age_entry,
3187 age_exit,
3188 cfg,
3189 BaselineOffsetEvaluator::LogCumulativeHazard,
3190 )
3191}
3192
3193pub fn build_survival_marginal_slope_baseline_offsets(
3196 age_entry: &Array1<f64>,
3197 age_exit: &Array1<f64>,
3198 cfg: &SurvivalBaselineConfig,
3199) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3200 build_survival_offsets_with_evaluator(
3201 age_entry,
3202 age_exit,
3203 cfg,
3204 BaselineOffsetEvaluator::ProbitSurvival,
3205 )
3206}
3207
3208pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3209 matches!(
3210 inverse_link,
3211 Some(
3212 InverseLink::Standard(StandardLink::Probit)
3213 | InverseLink::LatentCLogLog(_)
3214 | InverseLink::Sas(_)
3215 | InverseLink::BetaLogistic(_)
3216 | InverseLink::Mixture(_)
3217 )
3218 )
3219}
3220
3221pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3222 match likelihood_mode {
3223 SurvivalLikelihoodMode::LocationScale
3224 | SurvivalLikelihoodMode::Latent
3225 | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3226 SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3227 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3228 }
3229}
3230
3231pub fn build_survival_time_offsets_for_likelihood(
3232 age_entry: &Array1<f64>,
3233 age_exit: &Array1<f64>,
3234 baseline_cfg: &SurvivalBaselineConfig,
3235 likelihood_mode: SurvivalLikelihoodMode,
3236 inverse_link: Option<&InverseLink>,
3237) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3238 if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3239 || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3240 && location_scale_uses_probit_survival_baseline(inverse_link))
3241 {
3242 build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3243 } else {
3244 build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3245 }
3246}
3247
3248pub fn add_survival_time_derivative_guard_offset(
3249 age_entry: &Array1<f64>,
3250 age_exit: &Array1<f64>,
3251 anchor_time: f64,
3252 derivative_guard: f64,
3253 eta_offset_entry: &mut Array1<f64>,
3254 eta_offset_exit: &mut Array1<f64>,
3255 derivative_offset_exit: &mut Array1<f64>,
3256) -> Result<(), String> {
3257 if derivative_guard <= 0.0 {
3258 return Ok(());
3259 }
3260 let n = age_entry.len();
3261 if age_exit.len() != n
3262 || eta_offset_entry.len() != n
3263 || eta_offset_exit.len() != n
3264 || derivative_offset_exit.len() != n
3265 {
3266 return Err(SurvivalConstructionError::IncompatibleDimensions {
3267 reason: "survival derivative-guard offset lengths must match".to_string(),
3268 }
3269 .into());
3270 }
3271 for i in 0..n {
3272 eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3273 eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3274 derivative_offset_exit[i] += derivative_guard;
3275 }
3276 Ok(())
3277}
3278
3279#[derive(Clone, Debug)]
3280pub struct LatentSurvivalBaselineOffsets {
3281 pub loaded_eta_entry: Array1<f64>,
3282 pub loaded_eta_exit: Array1<f64>,
3283 pub loaded_derivative_exit: Array1<f64>,
3284 pub unloaded_mass_entry: Array1<f64>,
3285 pub unloaded_mass_exit: Array1<f64>,
3286 pub unloaded_hazard_exit: Array1<f64>,
3287}
3288
3289pub fn build_latent_survival_baseline_offsets(
3290 age_entry: &Array1<f64>,
3291 age_exit: &Array1<f64>,
3292 cfg: &SurvivalBaselineConfig,
3293 loading: HazardLoading,
3294) -> Result<LatentSurvivalBaselineOffsets, String> {
3295 if age_entry.len() != age_exit.len() {
3296 return Err(
3297 "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3298 );
3299 }
3300
3301 fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3302 if shape.abs() < 1e-10 {
3303 let x = shape * age;
3310 return (
3311 rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3312 rate * (1.0 + x + 0.5 * x * x),
3313 );
3314 }
3315 let shape_age = shape * age;
3316 let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3317 let instant_hazard = rate * shape_age.exp();
3318 (cumulative_hazard, instant_hazard)
3319 }
3320
3321 let n = age_entry.len();
3322
3323 let rows: Vec<[f64; 6]> = (0..n)
3326 .into_par_iter()
3327 .map(|i| -> Result<[f64; 6], String> {
3328 let entry = age_entry[i];
3329 let exit = age_exit[i];
3330 if !entry.is_finite()
3331 || !exit.is_finite()
3332 || entry <= 0.0
3333 || exit <= 0.0
3334 || exit < entry
3335 {
3336 return Err(format!(
3337 "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3338 i + 1
3339 ));
3340 }
3341 match loading {
3342 HazardLoading::Full => {
3343 let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3344 let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3345 Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3346 }
3347 HazardLoading::LoadedVsUnloaded => {
3348 if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3349 return Err(format!(
3350 "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3351 survival_baseline_targetname(cfg.target)
3352 ));
3353 }
3354 let rate = cfg.rate.ok_or_else(|| {
3355 "gompertz-makeham latent survival is missing baseline rate".to_string()
3356 })?;
3357 let shape = cfg.shape.ok_or_else(|| {
3358 "gompertz-makeham latent survival is missing baseline shape".to_string()
3359 })?;
3360 let makeham = cfg.makeham.ok_or_else(|| {
3361 "gompertz-makeham latent survival is missing baseline makeham".to_string()
3362 })?;
3363 let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3364 let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3365 if !(loaded_entry.is_finite()
3366 && loaded_entry > 0.0
3367 && loaded_exit.is_finite()
3368 && loaded_exit > 0.0
3369 && loaded_hazard.is_finite()
3370 && loaded_hazard > 0.0)
3371 {
3372 return Err(format!(
3373 "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3374 i + 1
3375 ));
3376 }
3377 Ok([
3378 loaded_entry.ln(),
3379 loaded_exit.ln(),
3380 loaded_hazard / loaded_exit,
3381 makeham * entry,
3382 makeham * exit,
3383 makeham,
3384 ])
3385 }
3386 }
3387 })
3388 .collect::<Result<Vec<_>, String>>()?;
3389
3390 let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3391 let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3392 let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3393 let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3394 let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3395 let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3396 for (i, row) in rows.into_iter().enumerate() {
3397 loaded_eta_entry[i] = row[0];
3398 loaded_eta_exit[i] = row[1];
3399 loaded_derivative_exit[i] = row[2];
3400 unloaded_mass_entry[i] = row[3];
3401 unloaded_mass_exit[i] = row[4];
3402 unloaded_hazard_exit[i] = row[5];
3403 }
3404
3405 Ok(LatentSurvivalBaselineOffsets {
3406 loaded_eta_entry,
3407 loaded_eta_exit,
3408 loaded_derivative_exit,
3409 unloaded_mass_entry,
3410 unloaded_mass_exit,
3411 unloaded_hazard_exit,
3412 })
3413}
3414
3415pub fn build_survival_timewiggle_derivative_design(
3420 eta_exit: &Array1<f64>,
3421 derivative_exit: &Array1<f64>,
3422 knots: &Array1<f64>,
3423 degree: usize,
3424) -> Result<Array2<f64>, String> {
3425 let mut design_derivative_exit =
3426 monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3427 for i in 0..design_derivative_exit.nrows() {
3428 let chain = derivative_exit[i];
3429 for j in 0..design_derivative_exit.ncols() {
3430 design_derivative_exit[[i, j]] *= chain;
3431 }
3432 }
3433 Ok(design_derivative_exit)
3434}
3435
3436pub fn build_survival_timewiggle_from_baseline(
3446 eta_entry: &Array1<f64>,
3447 eta_exit: &Array1<f64>,
3448 derivative_exit: &Array1<f64>,
3449 cfg: &LinkWiggleFormulaSpec,
3450) -> Result<SurvivalTimeWiggleBuild, String> {
3451 if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3452 return Err(
3453 "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3454 );
3455 }
3456 let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3459 && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3460 && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3461 if all_zero {
3462 return Err(
3463 "timewiggle requires a non-linear scalar survival baseline target; \
3464 the provided baseline offsets are all zero (linear baseline)"
3465 .to_string(),
3466 );
3467 }
3468 let n = eta_exit.len();
3469 let mut seed = Array1::<f64>::zeros(2 * n);
3470 for i in 0..n {
3471 seed[i] = eta_entry[i];
3472 seed[n + i] = eta_exit[i];
3473 }
3474 let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3478 let wiggle_cfg = WiggleBlockConfig {
3479 degree: cfg.degree,
3480 num_internal_knots: cfg.num_internal_knots,
3481 penalty_order: primary_order,
3482 double_penalty: cfg.double_penalty,
3483 };
3484 let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3485 append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3486 let ncols = combined_block.design.ncols();
3487 Ok(SurvivalTimeWiggleBuild {
3488 nullspace_dims: combined_block.nullspace_dims.clone(),
3489 penalties: {
3490 combined_block
3491 .penalties
3492 .into_iter()
3493 .map(|ps| ps.to_global(ncols))
3494 .collect()
3495 },
3496 knots,
3497 degree: cfg.degree,
3498 ncols,
3499 })
3500}
3501
3502pub fn append_zero_tail_columns(
3503 x_entry: &mut DesignMatrix,
3504 x_exit: &mut DesignMatrix,
3505 x_derivative: &mut DesignMatrix,
3506 tail_cols: usize,
3507) {
3508 if tail_cols == 0 {
3509 return;
3510 }
3511 fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3514 let old = dm.to_dense();
3515 let n = old.nrows();
3516 let p_base = old.ncols();
3517 let mut out = Array2::<f64>::zeros((n, p_base + tail));
3518 out.slice_mut(s![.., 0..p_base]).assign(&old);
3519 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3520 }
3521 append_dense(x_entry, tail_cols);
3522 append_dense(x_exit, tail_cols);
3523 append_dense(x_derivative, tail_cols);
3524}
3525
3526pub fn build_time_varying_survival_covariate_template(
3537 age_entry: &Array1<f64>,
3538 age_exit: &Array1<f64>,
3539 time_k: usize,
3540 time_degree: usize,
3541 block_name: &str,
3542) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3543 if time_k < time_degree + 1 {
3544 return Err(format!(
3545 "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3546 time_degree + 1
3547 ));
3548 }
3549 let num_internal_knots = time_k - (time_degree + 1);
3550
3551 let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3552 let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3553
3554 let time_spec = BSplineBasisSpec {
3555 degree: time_degree,
3556 penalty_order: 2,
3557 knotspec: BSplineKnotSpec::Automatic {
3558 num_internal_knots: Some(num_internal_knots),
3559 placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3560 },
3561 double_penalty: false,
3562 identifiability: BSplineIdentifiability::None,
3563 boundary: OneDimensionalBoundary::Open,
3564 boundary_conditions: BSplineBoundaryConditions::default(),
3565 };
3566
3567 let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3568 .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3569 let time_design_exit = time_build.design.to_dense();
3570
3571 let knots = match &time_build.metadata {
3572 BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3573 _ => {
3574 return Err(format!(
3575 "{block_name} time-margin basis returned unexpected metadata type"
3576 ));
3577 }
3578 };
3579
3580 let time_build_entry = build_bspline_basis_1d(
3581 log_entry.view(),
3582 &BSplineBasisSpec {
3583 degree: time_degree,
3584 penalty_order: 2,
3585 knotspec: BSplineKnotSpec::Provided(knots.clone()),
3586 double_penalty: false,
3587 identifiability: BSplineIdentifiability::None,
3588 boundary: OneDimensionalBoundary::Open,
3589 boundary_conditions: BSplineBoundaryConditions::default(),
3590 },
3591 )
3592 .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3593 let time_design_entry = time_build_entry.design.to_dense();
3594 let p_time = time_design_exit.ncols();
3595 let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3596 time_design_derivative_exit
3600 .as_slice_mut()
3601 .expect("zeros are contiguous")
3602 .par_chunks_mut(p_time)
3603 .enumerate()
3604 .try_for_each(|(i, row_out)| -> Result<(), String> {
3605 let mut deriv_buf = vec![0.0_f64; p_time];
3606 evaluate_bspline_derivative_scalar(
3607 log_exit[i],
3608 knots.view(),
3609 time_degree,
3610 &mut deriv_buf,
3611 )
3612 .map_err(|e| {
3613 format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3614 })?;
3615 let chain = 1.0 / age_exit[i].max(1e-12);
3616 for j in 0..p_time {
3617 row_out[j] = deriv_buf[j] * chain;
3618 }
3619 Ok(())
3620 })?;
3621
3622 Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3623 time_basis_entry: time_design_entry,
3624 time_basis_exit: time_design_exit,
3625 time_basis_derivative_exit: time_design_derivative_exit,
3626 time_penalties: time_build.penalties,
3627 })
3628}
3629
3630#[cfg(test)]
3631mod tests {
3632 use super::{
3633 SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3634 baseline_chain_rule_gradient, baseline_offset_theta_partials,
3635 build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3636 build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3637 evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3638 gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3639 marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3640 marginal_slope_baseline_offset_theta_partials,
3641 optimize_survival_baseline_config_with_gradient,
3642 optimize_survival_baseline_config_with_gradient_only,
3643 resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3644 survival_baseline_theta_from_config,
3645 };
3646 use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3647 use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3648 use crate::probability::normal_cdf;
3649 use ndarray::{Array1, Array2, array};
3650
3651 #[test]
3652 fn survival_timewiggle_keeps_requested_order_one_penalty() {
3653 let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3654 let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3655 let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3656 let cfg = LinkWiggleFormulaSpec {
3657 degree: 3,
3658 num_internal_knots: 4,
3659 penalty_orders: vec![1, 2, 3],
3660 double_penalty: false,
3661 };
3662
3663 let build =
3664 build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3665 .expect("build survival timewiggle");
3666
3667 assert_eq!(build.penalties.len(), 3);
3668 assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3669 assert!(build.ncols > 0);
3670 }
3671
3672 #[test]
3673 fn marginal_slope_time_anchor_defaults_to_median_exit() {
3674 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3675 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3676 let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3677 .expect("resolve marginal-slope default time anchor");
3678
3679 assert!(
3680 (anchor - 19.0).abs() <= 1e-12,
3681 "marginal-slope default anchor should be median exit, got {anchor}"
3682 );
3683 }
3684
3685 #[test]
3686 fn marginal_slope_time_anchor_honors_explicit_value() {
3687 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3688 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3689 let anchor =
3690 resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3691 .expect("resolve explicit marginal-slope time anchor");
3692
3693 assert!(
3694 (anchor - 7.5).abs() <= 1e-12,
3695 "explicit marginal-slope anchor should round-trip, got {anchor}"
3696 );
3697 }
3698
3699 #[test]
3710 fn baseline_optimizer_contracts_agree_on_shared_surface() {
3711 let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3716 let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3717
3718 let initial = SurvivalBaselineConfig {
3721 target: SurvivalBaselineTarget::Weibull,
3722 scale: Some(1.0),
3723 shape: Some(1.0),
3724 rate: None,
3725 makeham: None,
3726 };
3727
3728 let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3731 survival_baseline_theta_from_config(cfg)
3732 .expect("config→θ")
3733 .expect("Weibull config has a θ")
3734 };
3735
3736 let curvature_cost = curvature.clone();
3739 let star_cost = theta_star.clone();
3740 let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3741 let theta = survival_baseline_theta_from_config(cfg)?
3742 .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3743 let d = &theta - &star_cost;
3744 let ad = curvature_cost.dot(&d);
3745 Ok(0.5 * d.dot(&ad))
3746 };
3747
3748 let curvature_grad = curvature.clone();
3749 let star_grad = theta_star.clone();
3750 let cost_for_grad = cost_at.clone();
3751 let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3752 &initial,
3753 "baseline parity (gradient-only)",
3754 move |cfg| {
3755 let cost = cost_for_grad(cfg)?;
3756 let theta = survival_baseline_theta_from_config(cfg)?
3757 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3758 let gradient = curvature_grad.dot(&(&theta - &star_grad));
3759 Ok((cost, gradient))
3760 },
3761 )
3762 .expect("gradient-only baseline optimization converges");
3763
3764 let curvature_hess = curvature.clone();
3765 let star_hess = theta_star.clone();
3766 let cost_for_hess = cost_at.clone();
3767 let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3768 &initial,
3769 "baseline parity (gradient+Hessian)",
3770 move |cfg| {
3771 let cost = cost_for_hess(cfg)?;
3772 let theta = survival_baseline_theta_from_config(cfg)?
3773 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3774 let gradient = curvature_hess.dot(&(&theta - &star_hess));
3775 Ok((cost, gradient, curvature_hess.clone()))
3776 },
3777 )
3778 .expect("gradient+Hessian baseline optimization converges");
3779
3780 let theta_grad_only = recovered_theta(&result_grad_only);
3781 let theta_grad_hess = recovered_theta(&result_grad_hess);
3782
3783 for (label, theta) in [
3786 ("gradient-only", &theta_grad_only),
3787 ("gradient+Hessian", &theta_grad_hess),
3788 ] {
3789 let err = (theta - &theta_star)
3790 .mapv(f64::abs)
3791 .fold(0.0_f64, |a, &v| a.max(v));
3792 assert!(
3793 err <= 2e-3,
3794 "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3795 );
3796 }
3797
3798 let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3802 (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3803 };
3804 assert!(
3805 pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3806 "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3807 );
3808 }
3809
3810 #[test]
3811 fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
3812 let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3813 let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3814 let requested_degree = 3;
3815 let num_internal_knots = 1;
3816
3817 let built = build_survival_time_basis(
3818 &age_entry,
3819 &age_exit,
3820 SurvivalTimeBasisConfig::ISpline {
3821 degree: requested_degree,
3822 knots: Array1::zeros(0),
3823 keep_cols: Vec::new(),
3824 smooth_lambda: 1e-2,
3825 },
3826 Some((num_internal_knots, 1e-2)),
3827 )
3828 .expect("automatic cubic ispline with one interior knot builds");
3829
3830 let working_degree = requested_degree + 1;
3831 let knots = built.knots.expect("resolved ispline knots");
3832 assert_eq!(
3833 knots.len(),
3834 num_internal_knots + 2 * (working_degree + 1),
3835 "I-spline automatic knots must be clamped for the working B-spline degree"
3836 );
3837 assert_eq!(built.degree, Some(requested_degree));
3838 assert!(built.x_exit_time.ncols() > 0);
3839 assert_eq!(built.x_entry_time.ncols(), built.x_exit_time.ncols());
3840 assert_eq!(built.x_derivative_time.ncols(), built.x_exit_time.ncols());
3841 }
3842
3843 #[test]
3844 fn ispline_time_derivative_is_nonzero_at_right_boundary() {
3845 let age_entry = array![1.0_f64, 1.0, 1.0];
3846 let age_exit = array![4.0_f64, 4.0, 4.0];
3847 let left = 1.0_f64.ln();
3848 let right = 4.0_f64.ln();
3849 let mid = left + 0.5 * (right - left);
3850 let knots = array![left, left, left, left, mid, right, right, right, right];
3851
3852 let built = build_survival_time_basis(
3853 &age_entry,
3854 &age_exit,
3855 SurvivalTimeBasisConfig::ISpline {
3856 degree: 2,
3857 knots,
3858 keep_cols: Vec::new(),
3859 smooth_lambda: 1e-2,
3860 },
3861 None,
3862 )
3863 .expect("build right-boundary ispline time basis");
3864
3865 let derivative = built.x_derivative_time.as_dense_cow();
3866 let max_abs = derivative.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
3867 assert!(
3868 max_abs > 1e-8,
3869 "right-boundary I-spline derivative must use the left-hand endpoint slope"
3870 );
3871 for row in derivative.rows() {
3872 assert!(
3873 row.iter().any(|v| *v > 1e-8),
3874 "each row at the right boundary needs a positive hazard derivative"
3875 );
3876 }
3877 }
3878
3879 #[test]
3880 fn ispline_time_penalty_is_psd_under_nontrivial_keep_cols() {
3881 let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3900 let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3901 let left = 1.0_f64.ln();
3902 let right = 21.0_f64.ln();
3903 let q1 = left + 0.25 * (right - left);
3904 let mid = left + 0.5 * (right - left);
3905 let q3 = left + 0.75 * (right - left);
3906 let knots = array![
3910 left, left, left, left, q1, mid, q3, right, right, right, right
3911 ];
3912
3913 let full = build_survival_time_basis(
3915 &age_entry,
3916 &age_exit,
3917 SurvivalTimeBasisConfig::ISpline {
3918 degree: 2,
3919 knots: knots.clone(),
3920 keep_cols: Vec::new(),
3921 smooth_lambda: 1e-2,
3922 },
3923 None,
3924 )
3925 .expect("build full-width ispline time basis");
3926 let p_time_full = full
3927 .keep_cols
3928 .as_ref()
3929 .map(|k| k.len())
3930 .unwrap_or_else(|| full.x_exit_time.ncols());
3931 assert!(
3932 p_time_full >= 3,
3933 "test needs at least 3 shape-varying columns to drop an interior one; got {p_time_full}"
3934 );
3935
3936 let keep_cols: Vec<usize> = (0..p_time_full).filter(|&j| j != 1).collect();
3939
3940 let built = build_survival_time_basis(
3941 &age_entry,
3942 &age_exit,
3943 SurvivalTimeBasisConfig::ISpline {
3944 degree: 2,
3945 knots,
3946 keep_cols: keep_cols.clone(),
3947 smooth_lambda: 1e-2,
3948 },
3949 None,
3950 )
3951 .expect(
3952 "reduced ispline penalty must build (PSD contract must accept the \
3953 congruence-first / select-second ordering)",
3954 );
3955
3956 assert_eq!(
3957 built.penalties.len(),
3958 1,
3959 "the ispline time basis should carry exactly one curvature penalty"
3960 );
3961 let s = &built.penalties[0];
3962 assert_eq!(s.nrows(), keep_cols.len());
3963 assert_eq!(s.ncols(), keep_cols.len());
3964
3965 let (evals, _) =
3966 gam_linalg::faer_ndarray::FaerEigh::eigh(s, faer::Side::Lower).expect("eigh of penalty");
3967 let evals_slice = evals.as_slice().expect("contiguous eigenvalues");
3968 let max_abs = evals_slice
3969 .iter()
3970 .copied()
3971 .fold(0.0_f64, |a, b| a.max(b.abs()))
3972 .max(1.0);
3973 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
3974 let tol = -100.0 * (s.nrows() as f64) * f64::EPSILON * max_abs;
3975 assert!(
3976 min_ev >= tol,
3977 "reduced I-spline time penalty must be PSD (gam#979): min eigenvalue \
3978 {min_ev:.3e} < tol {tol:.3e}, max|eig| {max_abs:.3e}"
3979 );
3980 }
3981
3982 #[test]
3983 fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
3984 let cfg = SurvivalBaselineConfig {
3985 target: SurvivalBaselineTarget::GompertzMakeham,
3986 scale: None,
3987 shape: Some(0.07),
3988 rate: Some(0.012),
3989 makeham: Some(0.003),
3990 };
3991 let age = 11.5;
3992 let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
3993 .expect("evaluate marginal-slope gompertz-makeham baseline");
3994 let shape = cfg.shape.expect("shape");
3995 let rate = cfg.rate.expect("rate");
3996 let makeham = cfg.makeham.expect("makeham");
3997 let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
3998 let instant_hazard = makeham + rate * (shape * age).exp();
3999 let expected_survival = (-cumulative_hazard).exp();
4000 let actual_survival = normal_cdf(-q);
4001 assert!((actual_survival - expected_survival).abs() <= 1e-12);
4002
4003 let h = 1e-5;
4004 let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
4005 .expect("q plus")
4006 .0;
4007 let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
4008 .expect("q minus")
4009 .0;
4010 let fd = (q_plus - q_minus) / (2.0 * h);
4011 assert!((q_derivative - fd).abs() <= 1e-7);
4012 assert!(instant_hazard > 0.0);
4013 }
4014
4015 #[test]
4016 fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
4017 let configs = [
4026 SurvivalBaselineConfig {
4027 target: SurvivalBaselineTarget::Linear,
4028 scale: None,
4029 shape: None,
4030 rate: None,
4031 makeham: None,
4032 },
4033 SurvivalBaselineConfig {
4034 target: SurvivalBaselineTarget::Weibull,
4035 scale: Some(2.5),
4036 shape: Some(1.3),
4037 rate: None,
4038 makeham: None,
4039 },
4040 SurvivalBaselineConfig {
4041 target: SurvivalBaselineTarget::Gompertz,
4042 scale: None,
4043 shape: Some(0.05),
4044 rate: Some(0.01),
4045 makeham: None,
4046 },
4047 SurvivalBaselineConfig {
4048 target: SurvivalBaselineTarget::GompertzMakeham,
4049 scale: None,
4050 shape: Some(0.07),
4051 rate: Some(0.012),
4052 makeham: Some(0.003),
4053 },
4054 ];
4055 for cfg in &configs {
4056 let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
4059 .expect("marginal-slope baseline must be evaluable at the origin");
4060 assert_eq!(q0, 0.0);
4061 assert_eq!(q0_derivative, 0.0);
4062
4063 let (eta0, eta0_derivative) =
4067 evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
4068 assert!(eta0_derivative.is_finite());
4069 assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);
4070
4071 let age_entry = array![0.0, 0.0];
4075 let age_exit = array![0.0, 1.5];
4076 let (entry, exit, derivative) =
4077 build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
4078 .expect("probit baseline offsets must build through the origin");
4079 assert!(entry.iter().all(|v| v.is_finite()));
4080 assert!(exit.iter().all(|v| v.is_finite()));
4081 assert!(derivative.iter().all(|v| v.is_finite()));
4082 assert_eq!(exit[0], 0.0);
4084 }
4085 }
4086
4087 #[test]
4088 fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
4089 let cfg = SurvivalBaselineConfig {
4090 target: SurvivalBaselineTarget::GompertzMakeham,
4091 scale: None,
4092 shape: Some(0.03),
4093 rate: Some(0.01),
4094 makeham: Some(0.002),
4095 };
4096 let age_entry = array![2.0, 4.0];
4097 let age_exit = array![5.0, 9.0];
4098 let (entry, exit, derivative) =
4099 build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
4100 .expect("marginal-slope baseline offsets");
4101 for i in 0..age_entry.len() {
4102 let entry_h = cfg.makeham.expect("makeham") * age_entry[i]
4103 + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4104 * ((cfg.shape.expect("shape") * age_entry[i]).exp() - 1.0);
4105 let exit_h = cfg.makeham.expect("makeham") * age_exit[i]
4106 + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4107 * ((cfg.shape.expect("shape") * age_exit[i]).exp() - 1.0);
4108 assert!((normal_cdf(-entry[i]) - (-entry_h).exp()).abs() <= 1e-12);
4109 assert!((normal_cdf(-exit[i]) - (-exit_h).exp()).abs() <= 1e-12);
4110 assert!(derivative[i].is_finite() && derivative[i] > 0.0);
4111 }
4112 }
4113
4114 fn fd_marginal_slope_baseline_offset(
4115 age: f64,
4116 cfg: &SurvivalBaselineConfig,
4117 steps: &[f64],
4118 ) -> Vec<(f64, f64)> {
4119 let theta = survival_baseline_theta_from_config(cfg)
4120 .expect("theta")
4121 .expect("non-linear baseline");
4122 assert_eq!(
4123 steps.len(),
4124 theta.len(),
4125 "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
4126 );
4127 (0..theta.len())
4128 .map(|k| {
4129 let h = steps[k];
4130 let mut theta_plus = theta.clone();
4131 theta_plus[k] += h;
4132 let mut theta_minus = theta.clone();
4133 theta_minus[k] -= h;
4134 let cfg_plus =
4135 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4136 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4137 .expect("minus cfg");
4138 let (q_p, qt_p) =
4139 evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
4140 let (q_m, qt_m) =
4141 evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
4142 ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
4143 })
4144 .collect()
4145 }
4146
4147 #[test]
4148 fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
4149 let cfg = SurvivalBaselineConfig {
4150 target: SurvivalBaselineTarget::GompertzMakeham,
4151 scale: None,
4152 shape: Some(0.04),
4153 rate: Some(0.013),
4154 makeham: Some(0.002),
4155 };
4156 let age = 17.0;
4157 let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4158 .expect("partials")
4159 .expect("nonlinear");
4160 let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-5, 1e-5]);
4161 assert_eq!(analytic.len(), fd.len());
4162 for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4163 assert_close(*aq, *fq, 1e-6, &format!("gm-probit q theta[{k}]"));
4164 assert_close(*aqt, *fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
4165 }
4166 }
4167
4168 #[test]
4169 fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
4170 let cfg = SurvivalBaselineConfig {
4171 target: SurvivalBaselineTarget::GompertzMakeham,
4172 scale: None,
4173 shape: Some(1e-14),
4174 rate: Some(0.013),
4175 makeham: Some(0.002),
4176 };
4177 let age = 17.0;
4178 let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4179 .expect("partials")
4180 .expect("nonlinear");
4181 let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-11, 1e-5]);
4182 assert_eq!(analytic.len(), fd.len());
4183 for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4184 assert_close(*aq, *fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
4185 assert_close(
4186 *aqt,
4187 *fqt,
4188 1e-5,
4189 &format!("near-zero gm-probit q' theta[{k}]"),
4190 );
4191 }
4192 }
4193
4194 fn shifted_quadratic_offset_residuals(
4195 age_entry: ndarray::ArrayView1<'_, f64>,
4196 age_exit: ndarray::ArrayView1<'_, f64>,
4197 base_cfg: &SurvivalBaselineConfig,
4198 candidate_cfg: &SurvivalBaselineConfig,
4199 base: &OffsetChannelResiduals,
4200 curvatures: &OffsetChannelCurvatures,
4201 ) -> OffsetChannelResiduals {
4202 let n = age_exit.len();
4203 let mut entry = base.entry.clone();
4204 let mut exit = base.exit.clone();
4205 let mut derivative = base.derivative.clone();
4206 for row in 0..n {
4207 let (_, base_exit, base_deriv) =
4208 baseline_marginal_slope_channels(age_exit[row], base_cfg);
4209 let (_, cand_exit, cand_deriv) =
4210 baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
4211 let base_entry = if base.entry[row] == 0.0 {
4212 0.0
4213 } else {
4214 baseline_marginal_slope_channels(age_entry[row], base_cfg).1
4215 };
4216 let cand_entry = if base.entry[row] == 0.0 {
4217 0.0
4218 } else {
4219 baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
4220 };
4221 let delta = [
4222 cand_entry - base_entry,
4223 cand_exit - base_exit,
4224 cand_deriv - base_deriv,
4225 ];
4226 let mut shift = [0.0; 3];
4227 for i in 0..3 {
4228 for j in 0..3 {
4229 shift[i] += curvatures.rows[row][i][j] * delta[j];
4230 }
4231 }
4232 if base.entry[row] != 0.0 {
4233 entry[row] += shift[0];
4234 }
4235 exit[row] += shift[1];
4236 derivative[row] += shift[2];
4237 }
4238 OffsetChannelResiduals {
4239 entry,
4240 exit,
4241 derivative,
4242 right: base.right.clone(),
4243 }
4244 }
4245
4246 fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4247 let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4248 (q, q, q_t)
4249 }
4250
4251 #[test]
4252 fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
4253 let cfg = SurvivalBaselineConfig {
4254 target: SurvivalBaselineTarget::GompertzMakeham,
4255 scale: None,
4256 shape: Some(0.025),
4257 rate: Some(0.012),
4258 makeham: Some(0.003),
4259 };
4260 let theta = survival_baseline_theta_from_config(&cfg)
4261 .expect("theta")
4262 .expect("nonlinear");
4263 let age_entry = array![2.5, 0.0, 5.0];
4264 let age_exit = array![7.5, 11.0, 15.0];
4265 let base_residuals = OffsetChannelResiduals {
4266 entry: array![0.2, 0.0, -0.1],
4267 exit: array![0.6, -0.3, 0.4],
4268 derivative: array![-0.5, 0.25, 0.15],
4269 right: Array1::<f64>::zeros(3),
4270 };
4271 let curvatures = OffsetChannelCurvatures {
4272 rows: vec![
4273 [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
4274 [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
4275 [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
4276 ],
4277 };
4278 let analytic = marginal_slope_baseline_chain_rule_hessian(
4279 age_entry.view(),
4280 age_exit.view(),
4281 &cfg,
4282 &base_residuals,
4283 &curvatures,
4284 )
4285 .expect("hessian")
4286 .expect("nonlinear");
4287
4288 let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
4289 let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
4290 .expect("candidate cfg");
4291 let residuals = shifted_quadratic_offset_residuals(
4292 age_entry.view(),
4293 age_exit.view(),
4294 &cfg,
4295 &candidate,
4296 &base_residuals,
4297 &curvatures,
4298 );
4299 marginal_slope_baseline_chain_rule_gradient(
4300 age_entry.view(),
4301 age_exit.view(),
4302 &candidate,
4303 &residuals,
4304 )
4305 .expect("gradient")
4306 .expect("nonlinear")
4307 };
4308
4309 for j in 0..theta.len() {
4310 let step = if j == 1 { 2e-5 } else { 1e-5 };
4311 let mut plus = theta.clone();
4312 plus[j] += step;
4313 let mut minus = theta.clone();
4314 minus[j] -= step;
4315 let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
4316 for i in 0..theta.len() {
4317 assert_close(
4318 analytic[[i, j]],
4319 fd_col[i],
4320 2e-5,
4321 &format!("baseline Hessian ({i},{j})"),
4322 );
4323 }
4324 }
4325 }
4326
4327 #[test]
4328 fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
4329 let cfg = SurvivalBaselineConfig {
4330 target: SurvivalBaselineTarget::GompertzMakeham,
4331 scale: None,
4332 shape: Some(0.03),
4333 rate: Some(0.01),
4334 makeham: Some(0.002),
4335 };
4336 let age_entry = array![3.0, 6.0];
4337 let age_exit = array![8.0, 12.0];
4338 let residuals = OffsetChannelResiduals {
4339 exit: array![0.7, -0.2],
4340 entry: array![0.1, 0.4],
4341 derivative: array![1.3, -0.6],
4342 right: Array1::<f64>::zeros(2),
4343 };
4344 let grad = marginal_slope_baseline_chain_rule_gradient(
4345 age_entry.view(),
4346 age_exit.view(),
4347 &cfg,
4348 &residuals,
4349 )
4350 .expect("gradient")
4351 .expect("nonlinear");
4352
4353 let mut expected = Array1::<f64>::zeros(3);
4354 for i in 0..age_exit.len() {
4355 let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
4356 .expect("exit partials")
4357 .expect("nonlinear");
4358 let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
4359 .expect("entry partials")
4360 .expect("nonlinear");
4361 for k in 0..3 {
4362 expected[k] += residuals.exit[i] * exit_partials[k].0
4363 + residuals.derivative[i] * exit_partials[k].1
4364 + residuals.entry[i] * entry_partials[k].0;
4365 }
4366 }
4367 for k in 0..3 {
4368 assert_close(
4369 grad[k],
4370 expected[k],
4371 1e-12,
4372 &format!("gm-probit chain gradient theta[{k}]"),
4373 );
4374 }
4375 }
4376
4377 #[test]
4387 fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
4388 let cfg = SurvivalBaselineConfig {
4389 target: SurvivalBaselineTarget::GompertzMakeham,
4390 scale: None,
4391 shape: Some(0.028),
4392 rate: Some(0.011),
4393 makeham: Some(0.0025),
4394 };
4395 let age_entry = array![3.0, 0.0, 5.5];
4398 let age_exit = array![8.0, 12.0, 16.0];
4399 let residuals = OffsetChannelResiduals {
4400 exit: array![0.7, -0.2, 0.45],
4401 entry: array![0.1, 0.0, -0.3],
4402 derivative: array![1.3, -0.6, 0.2],
4403 right: Array1::<f64>::zeros(3),
4404 };
4405
4406 let reference_gradient = |partials: &dyn Fn(
4409 f64,
4410 &SurvivalBaselineConfig,
4411 )
4412 -> Result<Option<Vec<(f64, f64)>>, String>|
4413 -> Array1<f64> {
4414 let theta_dim = partials(age_exit[0], &cfg)
4415 .expect("probe partials")
4416 .expect("nonlinear")
4417 .len();
4418 let mut acc = Array1::<f64>::zeros(theta_dim);
4419 for i in 0..age_exit.len() {
4420 let p_exit = partials(age_exit[i], &cfg)
4421 .expect("exit partials")
4422 .expect("nonlinear");
4423 let r_x = residuals.exit[i];
4424 let r_d = residuals.derivative[i];
4425 for k in 0..theta_dim {
4426 acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
4427 }
4428 let r_e = residuals.entry[i];
4429 if r_e != 0.0 {
4430 let p_entry = partials(age_entry[i], &cfg)
4431 .expect("entry partials")
4432 .expect("nonlinear");
4433 for k in 0..theta_dim {
4434 acc[k] += r_e * p_entry[k].0;
4435 }
4436 }
4437 }
4438 acc
4439 };
4440
4441 let rp_engine = baseline_chain_rule_gradient(
4443 age_entry.view(),
4444 age_exit.view(),
4445 age_exit.view(),
4446 &cfg,
4447 &residuals,
4448 )
4449 .expect("rp gradient")
4450 .expect("rp nonlinear");
4451 let rp_reference = reference_gradient(&baseline_offset_theta_partials);
4452 assert_eq!(rp_engine.len(), rp_reference.len());
4453 for k in 0..rp_engine.len() {
4454 assert_close(
4455 rp_engine[k],
4456 rp_reference[k],
4457 0.0,
4458 &format!("rp engine vs inline reference theta[{k}]"),
4459 );
4460 }
4461
4462 let probit_engine = marginal_slope_baseline_chain_rule_gradient(
4464 age_entry.view(),
4465 age_exit.view(),
4466 &cfg,
4467 &residuals,
4468 )
4469 .expect("probit gradient")
4470 .expect("probit nonlinear");
4471 let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
4472 assert_eq!(probit_engine.len(), probit_reference.len());
4473 for k in 0..probit_engine.len() {
4474 assert_close(
4475 probit_engine[k],
4476 probit_reference[k],
4477 0.0,
4478 &format!("probit engine vs inline reference theta[{k}]"),
4479 );
4480 }
4481 }
4482
4483 #[test]
4504 fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4505 let cfg = SurvivalBaselineConfig {
4506 target: SurvivalBaselineTarget::GompertzMakeham,
4507 scale: None,
4508 shape: Some(0.05),
4509 rate: Some(0.012),
4510 makeham: Some(0.003),
4511 };
4512 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4514 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4515 let residuals = OffsetChannelResiduals {
4518 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4519 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4520 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4521 right: Array1::<f64>::zeros(8),
4522 };
4523
4524 let analytic = baseline_chain_rule_gradient(
4525 age_entry.view(),
4526 age_exit.view(),
4527 age_exit.view(),
4528 &cfg,
4529 &residuals,
4530 )
4531 .expect("analytic gradient ok")
4532 .expect("GM baseline has a θ-gradient");
4533 assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4534
4535 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4541 let mut acc = 0.0;
4542 for i in 0..age_exit.len() {
4543 let (eta_exit_i, od_exit_i) =
4544 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4545 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4546 if residuals.entry[i] != 0.0 {
4547 let (eta_entry_i, _) =
4548 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4549 acc += residuals.entry[i] * eta_entry_i;
4550 }
4551 }
4552 acc
4553 };
4554
4555 let theta0 = survival_baseline_theta_from_config(&cfg)
4556 .expect("theta seed")
4557 .expect("GM has θ");
4558 let delta = 1e-4;
4560 let mut fd = Array1::<f64>::zeros(analytic.len());
4561 for k in 0..analytic.len() {
4562 let mut theta_plus = theta0.clone();
4563 theta_plus[k] += delta;
4564 let mut theta_minus = theta0.clone();
4565 theta_minus[k] -= delta;
4566 let cfg_plus =
4567 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4568 let cfg_minus =
4569 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4570 let lp = loss_at_cfg(&cfg_plus);
4571 let lm = loss_at_cfg(&cfg_minus);
4572 fd[k] = (lp - lm) / (2.0 * delta);
4573 }
4574
4575 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4576 let max_err = analytic
4577 .iter()
4578 .zip(fd.iter())
4579 .map(|(a, b)| (a - b).abs())
4580 .fold(0.0_f64, f64::max);
4581 let rel = max_err / (analytic_norm + 1e-12);
4582 eprintln!(
4584 "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4585 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4586 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4587 );
4588 assert!(
4589 rel < 1e-2,
4590 "analytic θ-gradient disagrees with central FD beyond 1%: \
4591 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4592 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4593 );
4594 }
4595
4596 #[test]
4611 fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4612 let cfg = SurvivalBaselineConfig {
4613 target: SurvivalBaselineTarget::Weibull,
4614 scale: Some(11.0),
4615 shape: Some(1.4),
4616 rate: None,
4617 makeham: None,
4618 };
4619 let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4620 let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4621 let residuals = OffsetChannelResiduals {
4622 exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4623 entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4624 derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4625 right: Array1::<f64>::zeros(8),
4626 };
4627
4628 let analytic = baseline_chain_rule_gradient(
4629 age_entry.view(),
4630 age_exit.view(),
4631 age_exit.view(),
4632 &cfg,
4633 &residuals,
4634 )
4635 .expect("analytic gradient ok")
4636 .expect("Weibull baseline has a θ-gradient");
4637 assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4638
4639 let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4640 let mut acc = 0.0;
4641 for i in 0..age_exit.len() {
4642 let (eta_exit_i, od_exit_i) =
4643 evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4644 acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4645 if residuals.entry[i] != 0.0 {
4646 let (eta_entry_i, _) =
4647 evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4648 acc += residuals.entry[i] * eta_entry_i;
4649 }
4650 }
4651 acc
4652 };
4653
4654 let theta0 = survival_baseline_theta_from_config(&cfg)
4655 .expect("theta seed")
4656 .expect("Weibull has θ");
4657 let delta = 1e-4;
4658 let mut fd = Array1::<f64>::zeros(analytic.len());
4659 for k in 0..analytic.len() {
4660 let mut theta_plus = theta0.clone();
4661 theta_plus[k] += delta;
4662 let mut theta_minus = theta0.clone();
4663 theta_minus[k] -= delta;
4664 let cfg_plus =
4665 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4666 let cfg_minus =
4667 survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4668 let lp = loss_at_cfg(&cfg_plus);
4669 let lm = loss_at_cfg(&cfg_minus);
4670 fd[k] = (lp - lm) / (2.0 * delta);
4671 }
4672
4673 let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4674 let max_err = analytic
4675 .iter()
4676 .zip(fd.iter())
4677 .map(|(a, b)| (a - b).abs())
4678 .fold(0.0_f64, f64::max);
4679 let rel = max_err / (analytic_norm + 1e-12);
4680 eprintln!(
4681 "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4682 analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4683 analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4684 );
4685 assert!(
4686 rel < 1e-2,
4687 "analytic θ-gradient disagrees with central FD beyond 1%: \
4688 analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4689 rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4690 );
4691 }
4692
4693 fn fd_baseline_offset(
4706 age: f64,
4707 cfg: &SurvivalBaselineConfig,
4708 steps: &[f64],
4709 ) -> Vec<(f64, f64)> {
4710 let theta = survival_baseline_theta_from_config(cfg)
4711 .expect("theta")
4712 .expect("non-linear baseline");
4713 assert_eq!(
4714 steps.len(),
4715 theta.len(),
4716 "fd_baseline_offset: step vector length must match θ dimension"
4717 );
4718 (0..theta.len())
4719 .map(|k| {
4720 let h = steps[k];
4721 let mut theta_plus = theta.clone();
4722 theta_plus[k] += h;
4723 let mut theta_minus = theta.clone();
4724 theta_minus[k] -= h;
4725 let cfg_plus =
4726 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4727 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4728 .expect("minus cfg");
4729 let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4730 let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4731 ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4732 })
4733 .collect()
4734 }
4735
4736 fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
4737 let ok = if expected.abs() < 1.0 {
4741 (actual - expected).abs() <= tol
4742 } else {
4743 (actual - expected).abs() <= tol * expected.abs().max(1.0)
4744 };
4745 assert!(
4746 ok,
4747 "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
4748 );
4749 }
4750
4751 #[test]
4752 fn gompertz_offset_partials_match_central_diff() {
4753 let cases = [
4757 (0.5_f64, 0.01_f64, 30.0_f64),
4758 (0.2, 0.05, 60.0),
4759 (1.0, 0.001, 10.0),
4760 (0.4, 5e-11, 25.0),
4761 (0.4, -5e-11, 25.0),
4762 (0.3, -0.02, 40.0),
4763 (0.8, 0.2, 5.0),
4764 ];
4765 for &(rate, shape, age) in &cases {
4766 let cfg = SurvivalBaselineConfig {
4767 target: SurvivalBaselineTarget::Gompertz,
4768 scale: None,
4769 shape: Some(shape),
4770 rate: Some(rate),
4771 makeham: None,
4772 };
4773 let analytic = baseline_offset_theta_partials(age, &cfg)
4774 .expect("ok")
4775 .expect("non-linear");
4776 let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4782 let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
4783 assert_eq!(analytic.len(), 2);
4784 assert_close(
4786 analytic[0].0,
4787 fd[0].0,
4788 1e-7,
4789 &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4790 );
4791 assert_close(
4792 analytic[0].1,
4793 fd[0].1,
4794 1e-7,
4795 &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4796 );
4797 assert_close(
4800 analytic[1].0,
4801 fd[1].0,
4802 1e-5,
4803 &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
4804 );
4805 assert_close(
4806 analytic[1].1,
4807 fd[1].1,
4808 1e-5,
4809 &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
4810 );
4811 }
4812 }
4813
4814 #[test]
4815 fn gompertz_offset_partials_log_rate_channel_is_trivial() {
4816 let cfg = SurvivalBaselineConfig {
4820 target: SurvivalBaselineTarget::Gompertz,
4821 scale: None,
4822 shape: Some(0.05),
4823 rate: Some(0.3),
4824 makeham: None,
4825 };
4826 let partials = baseline_offset_theta_partials(42.0, &cfg)
4827 .expect("ok")
4828 .expect("non-linear");
4829 assert_eq!(partials[0].0, 1.0);
4830 assert_eq!(partials[0].1, 0.0);
4831 }
4832
4833 #[test]
4834 fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
4835 let age = 25.0;
4842 let rate = 0.4;
4843 let cfg_taylor = SurvivalBaselineConfig {
4844 target: SurvivalBaselineTarget::Gompertz,
4845 scale: None,
4846 shape: Some(0.5e-10),
4847 rate: Some(rate),
4848 makeham: None,
4849 };
4850 let cfg_direct = SurvivalBaselineConfig {
4851 target: SurvivalBaselineTarget::Gompertz,
4852 scale: None,
4853 shape: Some(2.0e-10),
4854 rate: Some(rate),
4855 makeham: None,
4856 };
4857 let p_t = baseline_offset_theta_partials(age, &cfg_taylor)
4858 .expect("ok")
4859 .expect("nl");
4860 let p_d = baseline_offset_theta_partials(age, &cfg_direct)
4861 .expect("ok")
4862 .expect("nl");
4863 assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
4865 assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
4866 assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
4868 assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
4869 }
4870
4871 #[test]
4883 fn gompertz_hazard_shape_derivatives_match_central_diff() {
4884 let cases = [
4889 (10.0_f64, 0.012_f64, 0.05_f64),
4890 (2.5, 0.5, 0.2),
4891 (15.0, 0.003, 0.01),
4892 (40.0, 0.3, 0.001),
4893 ];
4894 let h = 1e-6;
4895 for &(age, rate, shape) in &cases {
4896 let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4898 let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
4899 let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
4900 assert_close(
4901 d_cum,
4902 (cum_p - cum_m) / (2.0 * h),
4903 1e-6,
4904 &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4905 );
4906 assert_close(
4907 d_inst,
4908 (inst_p - inst_m) / (2.0 * h),
4909 1e-6,
4910 &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4911 );
4912
4913 let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4915 let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
4916 let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
4917 assert_close(
4918 d2_cum,
4919 (dcum_p - dcum_m) / (2.0 * h),
4920 1e-5,
4921 &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4922 );
4923 assert_close(
4924 d2_inst,
4925 (dinst_p - dinst_m) / (2.0 * h),
4926 1e-5,
4927 &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4928 );
4929 }
4930 }
4931
4932 #[test]
4933 fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4934 let cases = [
4946 (25.0_f64, 0.4_f64, 1e-9_f64),
4947 (100.0, 0.4, 1e-6), (100.0, 0.012, 1e-6), (50.0, 1.2, 1e-8),
4950 ];
4951 for &(age, rate, shape) in &cases {
4962 let t = age;
4963 let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4964 assert_close(
4965 d_cum,
4966 rate * t * t / 2.0,
4967 1e-3,
4968 &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4969 );
4970 assert_close(
4971 d_inst,
4972 rate * t,
4973 1e-3,
4974 &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4975 );
4976
4977 let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4978 assert_close(
4979 d2_cum,
4980 rate * t * t * t / 3.0,
4981 1e-3,
4982 &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4983 );
4984 assert_close(
4985 d2_inst,
4986 rate * t * t,
4987 1e-3,
4988 &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4989 );
4990 }
4991 }
4992
4993 #[test]
4994 fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4995 let age = 100.0;
5002 let rate = 0.4;
5003 let t = age;
5004 let truth = rate * t * t * t / 3.0; for k in 5..=12 {
5011 let shape = 10f64.powi(-(k as i32)); let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
5013 assert_close(
5014 d2_cum,
5015 truth,
5016 1e-3,
5017 &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
5018 );
5019 }
5020 }
5021
5022 #[test]
5023 fn weibull_offset_partials_match_central_diff() {
5024 let cases = [
5025 (0.5_f64, 1.2_f64, 25.0_f64),
5026 (2.0, 0.8, 60.0),
5027 (0.1, 3.0, 10.0),
5028 ];
5029 for &(scale, shape, age) in &cases {
5030 let cfg = SurvivalBaselineConfig {
5031 target: SurvivalBaselineTarget::Weibull,
5032 scale: Some(scale),
5033 shape: Some(shape),
5034 rate: None,
5035 makeham: None,
5036 };
5037 let analytic = baseline_offset_theta_partials(age, &cfg)
5038 .expect("ok")
5039 .expect("nl");
5040 let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
5041 assert_eq!(analytic.len(), 2);
5042 for k in 0..2 {
5043 assert_close(
5044 analytic[k].0,
5045 fd[k].0,
5046 1e-7,
5047 &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5048 );
5049 assert_close(
5050 analytic[k].1,
5051 fd[k].1,
5052 1e-7,
5053 &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5054 );
5055 }
5056 assert_eq!(analytic[0].1, 0.0);
5058 }
5059 }
5060
5061 #[test]
5062 fn gompertz_makeham_offset_partials_match_central_diff() {
5063 let cases = [
5064 (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
5065 (0.5, 0.01, 0.01, 25.0),
5066 (0.2, 0.001, 0.005, 60.0),
5067 (0.4, 5e-11, 0.01, 25.0),
5068 (0.4, -5e-11, 0.01, 25.0),
5069 (0.8, 0.2, 0.05, 5.0),
5070 ];
5071 for &(rate, shape, makeham, age) in &cases {
5072 let cfg = SurvivalBaselineConfig {
5073 target: SurvivalBaselineTarget::GompertzMakeham,
5074 scale: None,
5075 shape: Some(shape),
5076 rate: Some(rate),
5077 makeham: Some(makeham),
5078 };
5079 let analytic = baseline_offset_theta_partials(age, &cfg)
5080 .expect("ok")
5081 .expect("nl");
5082 let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
5086 let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
5087 assert_eq!(analytic.len(), 3);
5088 for k in 0..3 {
5089 assert_close(
5090 analytic[k].0,
5091 fd[k].0,
5092 1e-5,
5093 &format!(
5094 "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5095 ),
5096 );
5097 assert_close(
5098 analytic[k].1,
5099 fd[k].1,
5100 1e-5,
5101 &format!(
5102 "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5103 ),
5104 );
5105 }
5106 }
5107 }
5108
5109 #[test]
5110 fn linear_baseline_has_no_theta_partials() {
5111 let cfg = SurvivalBaselineConfig {
5112 target: SurvivalBaselineTarget::Linear,
5113 scale: None,
5114 shape: None,
5115 rate: None,
5116 makeham: None,
5117 };
5118 assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
5119 }
5120
5121 #[test]
5122 fn baseline_offset_partials_reject_non_positive_ages() {
5123 let cfg = SurvivalBaselineConfig {
5124 target: SurvivalBaselineTarget::Gompertz,
5125 scale: None,
5126 shape: Some(0.01),
5127 rate: Some(0.5),
5128 makeham: None,
5129 };
5130 assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
5131 assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
5132 assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
5133 }
5134
5135 #[test]
5141 fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
5142 let cfg = SurvivalBaselineConfig {
5143 target: SurvivalBaselineTarget::Gompertz,
5144 scale: None,
5145 shape: Some(0.05),
5146 rate: Some(0.3),
5147 makeham: None,
5148 };
5149 let age_entry = array![10.0_f64];
5150 let age_exit = array![25.0_f64];
5151 let residuals = OffsetChannelResiduals {
5152 exit: array![0.7_f64],
5153 entry: array![-0.2_f64],
5154 derivative: array![-0.4_f64],
5155 right: Array1::<f64>::zeros(1),
5156 };
5157 let grad = baseline_chain_rule_gradient(
5158 age_entry.view(),
5159 age_exit.view(),
5160 age_exit.view(),
5161 &cfg,
5162 &residuals,
5163 )
5164 .expect("ok")
5165 .expect("non-linear");
5166 let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5168 .unwrap()
5169 .unwrap();
5170 let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5171 .unwrap()
5172 .unwrap();
5173 for k in 0..p_exit.len() {
5174 let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5175 assert!(
5176 (grad[k] - expected).abs() < 1e-12,
5177 "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5178 grad[k],
5179 expected
5180 );
5181 }
5182 }
5183
5184 #[test]
5187 fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5188 let cfg = SurvivalBaselineConfig {
5189 target: SurvivalBaselineTarget::Gompertz,
5190 scale: None,
5191 shape: Some(0.05),
5192 rate: Some(0.3),
5193 makeham: None,
5194 };
5195 let age_entry = array![0.0_f64, 5.0_f64];
5196 let age_exit = array![10.0_f64, 20.0_f64];
5197 let residuals = OffsetChannelResiduals {
5198 exit: array![0.5_f64, 0.3_f64],
5199 entry: array![0.0_f64, -0.1_f64], derivative: array![-0.2_f64, 0.0_f64],
5201 right: Array1::<f64>::zeros(2),
5202 };
5203 let grad = baseline_chain_rule_gradient(
5205 age_entry.view(),
5206 age_exit.view(),
5207 age_exit.view(),
5208 &cfg,
5209 &residuals,
5210 )
5211 .expect("must not fail on origin-entry row with r_entry=0")
5212 .expect("non-linear");
5213 assert_eq!(grad.len(), 2);
5214 let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5216 let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5217 let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5218 for k in 0..2 {
5219 let expected = 0.5 * p_exit_0[k].0
5220 + (-0.2) * p_exit_0[k].1
5221 + 0.3 * p_exit_1[k].0
5222 + (-0.1) * p_entry_1[k].0;
5223 assert!(
5224 (grad[k] - expected).abs() < 1e-12,
5225 "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5226 grad[k],
5227 expected
5228 );
5229 }
5230 }
5231
5232 #[test]
5234 fn chain_rule_gradient_linear_target_returns_none() {
5235 let cfg = SurvivalBaselineConfig {
5236 target: SurvivalBaselineTarget::Linear,
5237 scale: None,
5238 shape: None,
5239 rate: None,
5240 makeham: None,
5241 };
5242 let age_entry = array![1.0_f64];
5243 let age_exit = array![2.0_f64];
5244 let residuals = OffsetChannelResiduals {
5245 exit: array![0.1_f64],
5246 entry: array![0.0_f64],
5247 derivative: array![0.0_f64],
5248 right: Array1::<f64>::zeros(1),
5249 };
5250 let grad = baseline_chain_rule_gradient(
5251 age_entry.view(),
5252 age_exit.view(),
5253 age_exit.view(),
5254 &cfg,
5255 &residuals,
5256 )
5257 .expect("ok");
5258 assert!(grad.is_none());
5259 }
5260
5261 #[test]
5280 fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5281 let cfg = SurvivalBaselineConfig {
5284 target: SurvivalBaselineTarget::Gompertz,
5285 scale: None,
5286 shape: Some(0.03),
5287 rate: Some(0.25),
5288 makeham: None,
5289 };
5290 let age_entry = array![0.0_f64, 5.0, 8.0];
5291 let age_exit = array![4.0_f64, 12.0, 20.0];
5292 let weights = array![1.0_f64, 2.0, 0.5];
5295 let events = [1.0_f64, 1.0, 0.0];
5296 let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5301 let s_vals = [0.7_f64, 1.1, 1.5];
5302 let (r_x, r_e, r_d) = {
5303 let mut rx = Array1::<f64>::zeros(3);
5304 let mut re = Array1::<f64>::zeros(3);
5305 let mut rd = Array1::<f64>::zeros(3);
5306 for i in 0..3 {
5307 let w = weights[i];
5308 let d = events[i];
5309 rx[i] = w * (eta_exit_vals[i].exp() - d);
5310 re[i] = if i == 0 {
5311 0.0 } else {
5313 -w * eta_entry_vals[i].exp()
5314 };
5315 rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5316 }
5317 (rx, re, rd)
5318 };
5319 let residuals = OffsetChannelResiduals {
5320 exit: r_x.clone(),
5321 entry: r_e.clone(),
5322 derivative: r_d.clone(),
5323 right: Array1::<f64>::zeros(3),
5324 };
5325 let grad = baseline_chain_rule_gradient(
5326 age_entry.view(),
5327 age_exit.view(),
5328 age_exit.view(),
5329 &cfg,
5330 &residuals,
5331 )
5332 .expect("ok")
5333 .expect("non-linear");
5334
5335 let nll = |theta_plus: &Array1<f64>| -> f64 {
5340 let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5341 let mut sum = 0.0_f64;
5342 for i in 0..3 {
5343 let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5344 let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5345 let d_eta_x = eta_x_p - base.0;
5346 let d_d_x = d_x_p - base.1;
5347 let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5348 let s_new = s_vals[i] + d_d_x;
5349 let interval_entry = if i == 0 {
5350 0.0_f64
5351 } else {
5352 let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5353 let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5354 let d_eta_e = eta_e_p - base_e.0;
5355 let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5356 eta_entry_new.exp()
5357 };
5358 let w = weights[i];
5359 let d = events[i];
5360 let nll_i =
5361 w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5362 sum += nll_i;
5363 }
5364 sum
5365 };
5366
5367 let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5368 let h = 1e-6;
5369 for k in 0..theta_base.len() {
5370 let mut tp = theta_base.clone();
5371 let mut tm = theta_base.clone();
5372 tp[k] += h;
5373 tm[k] -= h;
5374 let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5375 assert!(
5376 (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5377 "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5378 grad[k],
5379 fd
5380 );
5381 }
5382 }
5383
5384 #[test]
5386 fn chain_rule_gradient_rejects_length_mismatch() {
5387 let cfg = SurvivalBaselineConfig {
5388 target: SurvivalBaselineTarget::Gompertz,
5389 scale: None,
5390 shape: Some(0.05),
5391 rate: Some(0.3),
5392 makeham: None,
5393 };
5394 let age_entry = array![1.0_f64, 2.0]; let age_exit = array![5.0_f64, 6.0, 7.0]; let residuals = OffsetChannelResiduals {
5397 exit: array![0.1_f64, 0.2, 0.3],
5398 entry: array![0.0_f64, 0.0, 0.0],
5399 derivative: array![0.0_f64, 0.0, 0.0],
5400 right: Array1::<f64>::zeros(3),
5401 };
5402 let err = baseline_chain_rule_gradient(
5403 age_entry.view(),
5404 age_exit.view(),
5405 age_exit.view(),
5406 &cfg,
5407 &residuals,
5408 )
5409 .expect_err("length mismatch must error");
5410 assert!(err.contains("length mismatch"), "err={err}");
5411 }
5412}