1use gam_terms::basis::{
13 BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14 BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15 create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18 DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19 SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24 WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25 monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::{InverseLink, StandardLink};
31use ndarray::{Array1, Array2, array, s};
32use rayon::prelude::*;
33
/// Error raised while validating survival inputs or building survival model pieces.
///
/// Every variant carries a human-readable `reason`. Functions in this module return
/// `Result<_, String>` and convert these values into `String` with `.into()`.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// A configuration value is unusable, or a baseline optimisation driven by it did not
    /// converge.
    InvalidConfig { reason: String },
    /// A required input column is missing.
    MissingColumn { reason: String },
    /// Inputs have mismatched lengths, or a parameter vector has the wrong size.
    IncompatibleDimensions { reason: String },
    /// Input data failed validation (e.g. non-finite or negative survival times).
    DataValidationFailed { reason: String },
    /// A basis could not be constructed.
    BasisConstructionFailed { reason: String },
    /// An option string named an unsupported distribution, baseline target, likelihood
    /// mode, or time basis.
    UnsupportedDistribution { reason: String },
}
73
// Generates the shared error plumbing for every listed `SurvivalConstructionError` variant.
// NOTE(review): presumably this provides `Display`, `std::error::Error` and the
// conversion into `String` that the `.into()` calls in this module rely on — confirm
// against the macro definition.
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
84
/// Parametric baseline family selected with `--baseline-target`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// `linear`: carries no baseline parameters.
    Linear,
    /// `weibull`: parameterised by a positive `scale` and a positive `shape`.
    Weibull,
    /// `gompertz`: parameterised by a positive `rate` and a finite `shape`.
    Gompertz,
    /// `gompertz-makeham`: a positive `rate`, a finite `shape` and a positive `makeham`
    /// parameter.
    GompertzMakeham,
}
111
/// Parameters of a parametric baseline as validated by `parse_survival_baseline_config`.
///
/// Parameters that the selected `target` does not use are `None`.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    /// Which parametric family the remaining fields describe.
    pub target: SurvivalBaselineTarget,
    /// Weibull scale (> 0); `None` for other targets.
    pub scale: Option<f64>,
    /// Weibull shape (> 0), or Gompertz / Gompertz-Makeham shape (any finite value).
    pub shape: Option<f64>,
    /// Gompertz / Gompertz-Makeham rate (> 0).
    pub rate: Option<f64>,
    /// Gompertz-Makeham `makeham` parameter (> 0).
    pub makeham: Option<f64>,
}
120
/// Requested survival time basis, e.g. as produced by `parse_survival_time_basis_config`.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No time columns: the built design matrices have zero columns.
    None,
    /// An intercept column plus a log-time column.
    Linear,
    /// B-spline basis evaluated in log time.
    BSpline {
        /// Spline degree.
        degree: usize,
        /// Knot vector; when empty, knots are inferred from the data at build time.
        knots: Array1<f64>,
        /// Smoothing parameter recorded alongside the build.
        smooth_lambda: f64,
    },
    /// I-spline basis: the only basis accepted where structural monotonicity is required.
    ISpline {
        /// Spline degree (the parser requires `>= 1`).
        degree: usize,
        /// Knot vector; the parser leaves it empty.
        knots: Array1<f64>,
        /// NOTE(review): presumably the indices of basis columns to retain — confirm
        /// against the I-spline builder. The parser leaves it empty.
        keep_cols: Vec<usize>,
        /// Smoothing parameter (the parser requires finite and `>= 0`).
        smooth_lambda: f64,
    },
}
174
/// Metadata snapshot of a built survival time basis plus a caller-supplied `anchor`.
///
/// Only descriptive fields are kept; design matrices and penalties are not stored.
#[derive(Clone, Debug, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Name of the basis that was built (e.g. `"none"`, `"linear"`, `"bspline"`).
    pub basisname: String,
    /// Spline degree, when the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, when the basis has one.
    pub knots: Option<Vec<f64>>,
    /// NOTE(review): presumably retained basis-column indices (I-spline) — confirm.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing parameter recorded for spline bases.
    pub smooth_lambda: Option<f64>,
    /// NOTE(review): caller-defined anchor value; its meaning (presumably a reference
    /// time) is not established here — confirm with callers.
    pub anchor: f64,
}
197
198impl SavedSurvivalTimeBasis {
199 pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
202 Self {
203 basisname: build.basisname.clone(),
204 degree: build.degree,
205 knots: build.knots.clone(),
206 keep_cols: build.keep_cols.clone(),
207 smooth_lambda: build.smooth_lambda,
208 anchor,
209 }
210 }
211}
212
/// Design matrices, penalties and metadata produced by `build_survival_time_basis`.
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time-basis design evaluated at each row's entry time.
    pub x_entry_time: DesignMatrix,
    /// Time-basis design evaluated at each row's exit time.
    pub x_exit_time: DesignMatrix,
    /// Derivative of the exit-time design with respect to time. For the `linear` and
    /// `bspline` bases, the chain-rule factor `1 / t` of the log-time transform is
    /// already applied.
    pub x_derivative_time: DesignMatrix,
    /// Penalty matrices for the time coefficients (empty for `none` and `linear`).
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions reported by the basis builder alongside `penalties`.
    pub nullspace_dims: Vec<usize>,
    /// Name of the basis that was built (e.g. `"none"`, `"linear"`, `"bspline"`).
    pub basisname: String,
    /// Spline degree, when the basis has one.
    pub degree: Option<usize>,
    /// Knot vector actually used, when the basis has one.
    pub knots: Option<Vec<f64>>,
    /// NOTE(review): presumably retained basis-column indices (I-spline) — confirm.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing parameter recorded for spline bases.
    pub smooth_lambda: Option<f64>,
}
227
/// Smallest time used in place of zero or tiny survival times.
///
/// Times are clamped to this value before taking logarithms, and it is also the minimum
/// gap enforced between a normalized entry time and its exit time.
pub const SURVIVAL_TIME_FLOOR: f64 = 1.0e-9;

/// Seed value for the survival time-basis smoothing lambda.
/// NOTE(review): its consumers are not visible alongside this definition — confirm usage.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1.0e-2;

/// Gompertz / Gompertz-Makeham shape used when the caller supplies none.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 1.0e-2;
247
/// Survival likelihood family selected with `--survival-likelihood`.
///
/// Each variant corresponds to the lowercase name accepted by
/// `parse_survival_likelihood_mode` and returned by `survival_likelihood_modename`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `transformation`
    Transformation,
    /// `weibull`
    Weibull,
    /// `location-scale`
    LocationScale,
    /// `marginal-slope`
    MarginalSlope,
    /// `latent`
    Latent,
    /// `latent-binary`
    LatentBinary,
}
257
/// Result of building a survival time "wiggle" basis.
///
/// NOTE(review): this type is constructed elsewhere; the field descriptions below follow
/// their names and should be confirmed against the builder.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices for the wiggle coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions reported alongside `penalties`.
    pub nullspace_dims: Vec<usize>,
    /// Knot vector of the wiggle spline.
    pub knots: Array1<f64>,
    /// Spline degree of the wiggle basis.
    pub degree: usize,
    /// Number of columns in the wiggle basis.
    pub ncols: usize,
}
265
266pub fn normalize_survival_time_pair(
271 entry_raw: f64,
272 exit_raw: f64,
273 row_index: usize,
274) -> Result<(f64, f64), String> {
275 if !entry_raw.is_finite() || !exit_raw.is_finite() {
276 return Err(SurvivalConstructionError::DataValidationFailed {
277 reason: format!("non-finite survival times at row {}", row_index + 1),
278 }
279 .into());
280 }
281 if entry_raw < 0.0 || exit_raw < 0.0 {
282 return Err(SurvivalConstructionError::DataValidationFailed {
283 reason: format!("negative survival times at row {}", row_index + 1),
284 }
285 .into());
286 }
287
288 let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
289 let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
290 Ok((entry, exit))
291}
292
/// Reports whether `basisname` names a time basis whose basis functions make the
/// cumulative time effect monotone by construction.
///
/// Only the I-spline basis qualifies. The comparison ignores ASCII case but is otherwise
/// exact (no trimming).
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURALLY_MONOTONE_BASIS: &str = "ispline";
    STRUCTURALLY_MONOTONE_BASIS.eq_ignore_ascii_case(basisname)
}
300
301pub fn require_structural_survival_time_basis(
302 basisname: &str,
303 context: &str,
304) -> Result<(), String> {
305 if survival_basis_supports_structural_monotonicity(basisname) {
306 return Ok(());
307 }
308 Err(SurvivalConstructionError::UnsupportedDistribution {
309 reason: format!(
310 "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
311Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
312`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
313Re-run with `--time-basis ispline`."
314 ),
315 }
316 .into())
317}
318
319pub fn parse_survival_baseline_config(
324 target_raw: &str,
325 scale: Option<f64>,
326 shape: Option<f64>,
327 rate: Option<f64>,
328 makeham: Option<f64>,
329) -> Result<SurvivalBaselineConfig, String> {
330 let target = match target_raw.to_ascii_lowercase().as_str() {
331 "linear" => SurvivalBaselineTarget::Linear,
332 "weibull" => SurvivalBaselineTarget::Weibull,
333 "gompertz" => SurvivalBaselineTarget::Gompertz,
334 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
335 other => {
336 return Err(SurvivalConstructionError::UnsupportedDistribution {
337 reason: format!(
338 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
339 ),
340 }
341 .into());
342 }
343 };
344
345 match target {
346 SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
347 target,
348 scale: None,
349 shape: None,
350 rate: None,
351 makeham: None,
352 }),
353 SurvivalBaselineTarget::Weibull => {
354 let scale = scale.ok_or_else(|| {
355 "--baseline-target weibull requires --baseline-scale > 0".to_string()
356 })?;
357 let shape = shape.ok_or_else(|| {
358 "--baseline-target weibull requires --baseline-shape > 0".to_string()
359 })?;
360 if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
361 return Err(
362 "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
363 .to_string(),
364 );
365 }
366 Ok(SurvivalBaselineConfig {
367 target,
368 scale: Some(scale),
369 shape: Some(shape),
370 rate: None,
371 makeham: None,
372 })
373 }
374 SurvivalBaselineTarget::Gompertz => {
375 let rate = rate.unwrap_or(1.0);
376 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
377 if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
378 return Err(
379 "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
380 .to_string(),
381 );
382 }
383 Ok(SurvivalBaselineConfig {
384 target,
385 scale: None,
386 shape: Some(shape),
387 rate: Some(rate),
388 makeham: None,
389 })
390 }
391 SurvivalBaselineTarget::GompertzMakeham => {
392 let rate = rate.unwrap_or(0.5);
393 let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
394 let makeham = makeham.unwrap_or(0.5);
395 if !rate.is_finite()
396 || rate <= 0.0
397 || !shape.is_finite()
398 || !makeham.is_finite()
399 || makeham <= 0.0
400 {
401 return Err(
402 "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
403 .to_string(),
404 );
405 }
406 Ok(SurvivalBaselineConfig {
407 target,
408 scale: None,
409 shape: Some(shape),
410 rate: Some(rate),
411 makeham: Some(makeham),
412 })
413 }
414 }
415}
416
417pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
422 match raw.to_ascii_lowercase().as_str() {
423 "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
424 "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
425 "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
426 "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
427 "latent" => Ok(SurvivalLikelihoodMode::Latent),
428 "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
429 other => Err(SurvivalConstructionError::UnsupportedDistribution {
430 reason: format!(
431 "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
432 ),
433 }
434 .into()),
435 }
436}
437
438pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
439 match mode {
440 SurvivalLikelihoodMode::Transformation => "transformation",
441 SurvivalLikelihoodMode::Weibull => "weibull",
442 SurvivalLikelihoodMode::LocationScale => "location-scale",
443 SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
444 SurvivalLikelihoodMode::Latent => "latent",
445 SurvivalLikelihoodMode::LatentBinary => "latent-binary",
446 }
447}
448
449pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
450 match raw.to_ascii_lowercase().as_str() {
451 "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
452 "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
453 "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
454 other => Err(SurvivalConstructionError::UnsupportedDistribution {
455 reason: format!(
456 "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
457 ),
458 }
459 .into()),
460 }
461}
462
463pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
464 match target {
465 SurvivalBaselineTarget::Linear => "linear",
466 SurvivalBaselineTarget::Weibull => "weibull",
467 SurvivalBaselineTarget::Gompertz => "gompertz",
468 SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
469 }
470}
471
472pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
473 let sum = age_exit
474 .iter()
475 .copied()
476 .filter(|value| value.is_finite() && *value > 0.0)
477 .sum::<f64>();
478 let count = age_exit
479 .iter()
480 .filter(|value| value.is_finite() && **value > 0.0)
481 .count()
482 .max(1);
483 (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
484}
485
486pub fn initial_survival_baseline_config_for_fit(
487 target_raw: &str,
488 scale: Option<f64>,
489 shape: Option<f64>,
490 rate: Option<f64>,
491 makeham: Option<f64>,
492 age_exit: &Array1<f64>,
493) -> Result<SurvivalBaselineConfig, String> {
494 let target = match target_raw.trim().to_ascii_lowercase().as_str() {
495 "linear" => SurvivalBaselineTarget::Linear,
496 "weibull" => SurvivalBaselineTarget::Weibull,
497 "gompertz" => SurvivalBaselineTarget::Gompertz,
498 "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
499 other => {
500 return Err(SurvivalConstructionError::UnsupportedDistribution {
501 reason: format!(
502 "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
503 ),
504 }
505 .into());
506 }
507 };
508 let time_scale_seed = positive_survival_time_seed(age_exit);
509 let cfg = match target {
510 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
511 target,
512 scale: None,
513 shape: None,
514 rate: None,
515 makeham: None,
516 },
517 SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
518 target,
519 scale: Some(scale.unwrap_or(time_scale_seed)),
520 shape: Some(shape.unwrap_or(1.0)),
521 rate: None,
522 makeham: None,
523 },
524 SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
525 target,
526 scale: None,
527 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
528 rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
529 makeham: None,
530 },
531 SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
532 target,
533 scale: None,
534 shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
535 rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
536 makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
537 },
538 };
539 parse_survival_baseline_config(
540 survival_baseline_targetname(cfg.target),
541 cfg.scale,
542 cfg.shape,
543 cfg.rate,
544 cfg.makeham,
545 )
546}
547
548fn survival_baseline_theta_from_config(
549 cfg: &SurvivalBaselineConfig,
550) -> Result<Option<Array1<f64>>, String> {
551 Ok(match cfg.target {
552 SurvivalBaselineTarget::Linear => None,
553 SurvivalBaselineTarget::Weibull => Some(array![
554 cfg.scale
555 .ok_or_else(|| "missing weibull baseline scale".to_string())?
556 .ln(),
557 cfg.shape
558 .ok_or_else(|| "missing weibull baseline shape".to_string())?
559 .ln(),
560 ]),
561 SurvivalBaselineTarget::Gompertz => Some(array![
562 cfg.rate
563 .ok_or_else(|| "missing gompertz baseline rate".to_string())?
564 .ln(),
565 cfg.shape
566 .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
567 ]),
568 SurvivalBaselineTarget::GompertzMakeham => Some(array![
569 cfg.rate
570 .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
571 .ln(),
572 cfg.shape
573 .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
574 cfg.makeham
575 .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
576 .ln(),
577 ]),
578 })
579}
580
581fn survival_baseline_config_from_theta(
582 target: SurvivalBaselineTarget,
583 theta: &Array1<f64>,
584) -> Result<SurvivalBaselineConfig, String> {
585 let cfg = match target {
586 SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
587 target,
588 scale: None,
589 shape: None,
590 rate: None,
591 makeham: None,
592 },
593 SurvivalBaselineTarget::Weibull => {
594 if theta.len() != 2 {
595 return Err(SurvivalConstructionError::IncompatibleDimensions {
596 reason: format!(
597 "weibull baseline parameter dimension mismatch: expected 2, got {}",
598 theta.len()
599 ),
600 }
601 .into());
602 }
603 SurvivalBaselineConfig {
604 target,
605 scale: Some(theta[0].exp()),
606 shape: Some(theta[1].exp()),
607 rate: None,
608 makeham: None,
609 }
610 }
611 SurvivalBaselineTarget::Gompertz => {
612 if theta.len() != 2 {
613 return Err(SurvivalConstructionError::IncompatibleDimensions {
614 reason: format!(
615 "gompertz baseline parameter dimension mismatch: expected 2, got {}",
616 theta.len()
617 ),
618 }
619 .into());
620 }
621 SurvivalBaselineConfig {
622 target,
623 scale: None,
624 shape: Some(theta[1]),
625 rate: Some(theta[0].exp()),
626 makeham: None,
627 }
628 }
629 SurvivalBaselineTarget::GompertzMakeham => {
630 if theta.len() != 3 {
631 return Err(SurvivalConstructionError::IncompatibleDimensions {
632 reason: format!(
633 "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
634 theta.len()
635 ),
636 }
637 .into());
638 }
639 SurvivalBaselineConfig {
640 target,
641 scale: None,
642 shape: Some(theta[1]),
643 rate: Some(theta[0].exp()),
644 makeham: Some(theta[2].exp()),
645 }
646 }
647 };
648 parse_survival_baseline_config(
649 survival_baseline_targetname(cfg.target),
650 cfg.scale,
651 cfg.shape,
652 cfg.rate,
653 cfg.makeham,
654 )
655}
656
/// Which derivatives a baseline objective supplies to the outer optimiser.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Cost and analytic gradient only; the problem declares the Hessian unavailable.
    GradientOnly,
    /// Cost, analytic gradient and analytic Hessian; the problem declares
    /// `DeclaredHessianForm::Either`.
    GradientHessian,
}
679
680impl BaselineDerivativeContract {
681 fn configure(
686 self,
687 problem: gam_solve::rho_optimizer::OuterProblem,
688 ) -> gam_solve::rho_optimizer::OuterProblem {
689 use gam_problem::{DeclaredHessianForm, Derivative};
690 match self {
691 BaselineDerivativeContract::GradientOnly => problem
694 .with_gradient(Derivative::Analytic)
695 .with_hessian(DeclaredHessianForm::Unavailable)
696 .with_tolerance(1e-4)
697 .with_max_iter(240),
698 BaselineDerivativeContract::GradientHessian => problem
699 .with_gradient(Derivative::Analytic)
700 .with_hessian(DeclaredHessianForm::Either)
701 .with_tolerance(1e-4)
702 .with_max_iter(240),
703 }
704 }
705}
706
/// Runs the bounded outer optimiser over the baseline parameter vector `theta`.
///
/// The seed is `survival_baseline_theta_from_config(initial)`. For a target with no free
/// parameters (`None` seed), `initial` is returned unchanged and nothing is run. Each
/// coordinate is boxed to `seed ± 6.0` on the theta scale. `cost_fn` returns the objective
/// value and `eval_fn` returns the full `OuterEval`; both receive the candidate `theta`
/// and an unused `()` state.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the seed cannot be built;
/// - the optimiser run errors (message prefixed with `context`);
/// - the run reports non-convergence (`InvalidConfig`);
/// - the final `theta` does not map back to a valid configuration.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    // `None` means the target has no free parameters (Linear), so there is nothing to fit.
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // Box constraint of ±6 around the seed. For log-transformed parameters this allows
    // roughly a factor of e^6 (about 400x) in either direction.
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    // Single seed, starting exactly at `seed`.
    // NOTE(review): `num_auxiliary_trailing: dim` presumably tells the seeding logic that
    // every coordinate is an auxiliary (non-smoothing) parameter — confirm.
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    // No state, no reset hook, and no EFS evaluator; the typed `None`s only pin the
    // generic parameters of `build_objective`.
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    // A non-converged run is treated as an error rather than returning a partial fit.
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    survival_baseline_config_from_theta(target, &result.rho)
}
777
/// Adapts a configuration-level objective to the theta-level optimiser.
///
/// Each candidate `theta` is mapped through `survival_baseline_config_from_theta`, which
/// re-validates it. `objective` is then called on the resulting configuration. The
/// returned gradient must have length `theta.len()`, and an analytic Hessian must be
/// `theta.len() x theta.len()`. Errors from any step are surfaced as
/// `EstimationError::InvalidInput`.
///
/// # Errors
///
/// Propagates the errors of `run_baseline_theta_optimizer`.
fn run_baseline_theta_optimizer_with_eval<F>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    objective: F,
) -> Result<SurvivalBaselineConfig, String>
where
    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
{
    let target = initial.target;
    let engine_context = context.to_string();
    // The cost closure and the full-eval closure both need to call the same `FnMut`
    // objective, so it is shared through `Rc<RefCell<_>>`.
    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
                        theta: &Array1<f64>|
          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
        let cfg = survival_baseline_config_from_theta(target, theta)
            .map_err(crate::model_types::EstimationError::InvalidInput)?;
        let eval =
            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
        // Guard against objectives that return derivatives in the wrong dimension.
        if eval.gradient.len() != theta.len() {
            return Err(crate::model_types::EstimationError::InvalidInput(format!(
                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
                eval.gradient.len(),
                theta.len()
            )));
        }
        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
            if h.nrows() != theta.len() || h.ncols() != theta.len() {
                return Err(crate::model_types::EstimationError::InvalidInput(format!(
                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
                    h.nrows(),
                    h.ncols(),
                    theta.len(),
                    theta.len()
                )));
            }
        }
        Ok(eval)
    };
    let cost_objective = std::rc::Rc::clone(&objective);
    let cost_eval = eval_at.clone();
    // Cost-only requests still run the full objective, including its derivatives.
    // Only `cost` is kept.
    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
    };
    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
}
837
838pub fn optimize_survival_baseline_config_with_gradient_only<F>(
849 initial: &SurvivalBaselineConfig,
850 context: &str,
851 mut objective: F,
852) -> Result<SurvivalBaselineConfig, String>
853where
854 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
855{
856 use gam_problem::{HessianResult, OuterEval};
857 run_baseline_theta_optimizer_with_eval(
858 initial,
859 context,
860 BaselineDerivativeContract::GradientOnly,
861 move |cfg| {
862 let (cost, gradient) = objective(cfg)?;
863 Ok(OuterEval {
864 cost,
865 gradient,
866 hessian: HessianResult::Unavailable,
867 inner_beta_hint: None,
868 })
869 },
870 )
871}
872
873pub fn optimize_survival_baseline_config_with_gradient<F>(
878 initial: &SurvivalBaselineConfig,
879 context: &str,
880 mut objective: F,
881) -> Result<SurvivalBaselineConfig, String>
882where
883 F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
884{
885 use gam_problem::{HessianResult, OuterEval};
886 run_baseline_theta_optimizer_with_eval(
887 initial,
888 context,
889 BaselineDerivativeContract::GradientHessian,
890 move |cfg| {
891 let (cost, gradient, hessian) = objective(cfg)?;
892 Ok(OuterEval {
893 cost,
894 gradient,
895 hessian: HessianResult::Analytic(hessian),
896 inner_beta_hint: None,
897 })
898 },
899 )
900}
901
902pub fn parse_survival_time_basis_config(
907 time_basis: &str,
908 time_degree: usize,
909 time_num_internal_knots: usize,
910 time_smooth_lambda: f64,
911) -> Result<SurvivalTimeBasisConfig, String> {
912 match time_basis.to_ascii_lowercase().as_str() {
913 "none" => Ok(SurvivalTimeBasisConfig::None),
914 "ispline" => {
915 if time_degree < 1 {
916 return Err(
917 "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
918 .to_string(),
919 );
920 }
921 if time_num_internal_knots == 0 {
922 return Err(
923 "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
924 .to_string(),
925 );
926 }
927 if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
928 return Err(
929 "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
930 .to_string(),
931 );
932 }
933 Ok(SurvivalTimeBasisConfig::ISpline {
934 degree: time_degree,
935 knots: Array1::zeros(0),
936 keep_cols: Vec::new(),
937 smooth_lambda: time_smooth_lambda,
938 })
939 }
940 "linear" | "bspline" => {
941 match require_structural_survival_time_basis(time_basis, "survival model configuration")
948 {
949 Err(e) => Err(e),
950 Ok(()) => Err(format!(
951 "internal: structural-basis check accepted non-structural \
952 survival time basis '{time_basis}'"
953 )),
954 }
955 }
956 other => Err(format!(
957 "unsupported --time-basis '{other}'; accepted values: ispline, none"
958 )),
959 }
960}
961
962pub fn build_survival_time_basis(
967 age_entry: &Array1<f64>,
968 age_exit: &Array1<f64>,
969 cfg: SurvivalTimeBasisConfig,
970 infer_knots_if_needed: Option<(usize, f64)>,
971) -> Result<SurvivalTimeBuildOutput, String> {
972 fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
973 if let Some(row) = times.iter().position(|t| !t.is_finite()) {
974 return Err(SurvivalConstructionError::DataValidationFailed {
975 reason: format!(
976 "survival time basis requires finite {label} times (row {})",
977 row + 1
978 ),
979 }
980 .into());
981 }
982 if let Some(row) = times.iter().position(|t| *t < 0.0) {
983 return Err(SurvivalConstructionError::DataValidationFailed {
984 reason: format!(
985 "survival time basis requires non-negative {label} times (row {})",
986 row + 1
987 ),
988 }
989 .into());
990 }
991 Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
992 }
993
994 let n = age_entry.len();
995 if n != age_exit.len() {
996 return Err(SurvivalConstructionError::IncompatibleDimensions {
997 reason: "survival time basis requires matching entry/exit lengths".to_string(),
998 }
999 .into());
1000 }
1001 for i in 0..n {
1002 if age_exit[i] < age_entry[i] {
1003 return Err(format!(
1004 "survival time basis requires exit times >= entry times (row {})",
1005 i + 1
1006 ));
1007 }
1008 }
1009 let log_entry = checked_log_survival_times(age_entry, "entry")?;
1010 let log_exit = checked_log_survival_times(age_exit, "exit")?;
1011
1012 fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1013 let n = log_entry.len();
1014 let entry_range = log_entry
1015 .iter()
1016 .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1017 (lo.min(v), hi.max(v))
1018 });
1019 let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1020 if entry_degenerate {
1021 log_exit.clone()
1022 } else {
1023 let mut combined = Array1::<f64>::zeros(2 * n);
1024 for i in 0..n {
1025 combined[i] = log_entry[i];
1026 combined[n + i] = log_exit[i];
1027 }
1028 combined
1029 }
1030 }
1031
1032 fn data_capped_internal_knots(
1055 combined: &Array1<f64>,
1056 degree: usize,
1057 requested_internal_knots: usize,
1058 ) -> usize {
1059 if requested_internal_knots == 0 {
1060 return 0;
1061 }
1062 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1063 sorted.sort_by(f64::total_cmp);
1064 let minval = sorted.first().copied().unwrap_or(0.0);
1065 let maxval = sorted.last().copied().unwrap_or(minval);
1066 if minval == maxval {
1067 return 1.min(requested_internal_knots);
1069 }
1070 let scale = (maxval - minval).abs().max(1.0);
1071 let tol = 1e-12 * scale;
1072 let mut distinct_interior = 0usize;
1075 let mut last: Option<f64> = None;
1076 for &x in &sorted {
1077 if x <= minval + tol || x >= maxval - tol {
1078 continue;
1079 }
1080 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1081 continue;
1082 }
1083 distinct_interior += 1;
1084 last = Some(x);
1085 }
1086 let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1089 let n_distinct = {
1095 let mut count = 0usize;
1096 let mut last: Option<f64> = None;
1097 for &x in &sorted {
1098 if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1099 continue;
1100 }
1101 count += 1;
1102 last = Some(x);
1103 }
1104 count
1105 };
1106 let dim_budget = n_distinct / 4;
1107 let dim_cap = dim_budget.saturating_sub(degree);
1108 cap = cap.min(dim_cap.max(1));
1109 cap.max(1)
1110 }
1111
1112 fn infer_survival_time_knots(
1113 combined: &Array1<f64>,
1114 knot_degree: usize,
1115 validation_degree: usize,
1116 num_internal_knots: usize,
1117 basis_options: BasisOptions,
1118 ) -> Result<Array1<f64>, String> {
1119 let num_internal_knots =
1125 data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1126
1127 fn quantile_knot_inference_needs_uniform_fallback(
1128 combined: &Array1<f64>,
1129 num_internal_knots: usize,
1130 ) -> bool {
1131 if num_internal_knots == 0 || combined.is_empty() {
1132 return false;
1133 }
1134
1135 let mut sorted: Vec<f64> = combined.iter().copied().collect();
1136 sorted.sort_by(f64::total_cmp);
1137 let minval = sorted[0];
1138 let maxval = *sorted.last().unwrap_or(&minval);
1139 if minval == maxval {
1140 return false;
1141 }
1142
1143 let scale = (maxval - minval).abs().max(1.0);
1144 let tol = 1e-12 * scale;
1145 let mut support = Vec::with_capacity(sorted.len());
1146 let mut last: Option<f64> = None;
1147 for &x in &sorted {
1148 if x <= minval + tol || x >= maxval - tol {
1149 continue;
1150 }
1151 if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1152 continue;
1153 }
1154 support.push(x);
1155 last = Some(x);
1156 }
1157 if support.is_empty() {
1158 return true;
1159 }
1160
1161 let n = support.len();
1162 let mut prev_q = minval;
1163 for j in 1..=num_internal_knots {
1164 let p = j as f64 / (num_internal_knots + 1) as f64;
1165 let pos = p * (n.saturating_sub(1) as f64);
1166 let lo = pos.floor() as usize;
1167 let hi = pos.ceil() as usize;
1168 let frac = pos - lo as f64;
1169 let q = if lo == hi {
1170 support[lo]
1171 } else {
1172 support[lo] * (1.0 - frac) + support[hi] * frac
1173 }
1174 .clamp(minval, maxval);
1175 if q <= prev_q + tol || q >= maxval - tol {
1176 return true;
1177 }
1178 prev_q = q;
1179 }
1180
1181 false
1182 }
1183
1184 let inferwith =
1185 |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1186 let built = build_bspline_basis_1d(
1187 combined.view(),
1188 &BSplineBasisSpec {
1189 degree: knot_degree,
1190 penalty_order: 2,
1191 knotspec: BSplineKnotSpec::Automatic {
1192 num_internal_knots: Some(num_internal_knots),
1193 placement,
1194 },
1195 double_penalty: false,
1196 identifiability: BSplineIdentifiability::None,
1197 boundary: OneDimensionalBoundary::Open,
1198 boundary_conditions: BSplineBoundaryConditions::default(),
1199 },
1200 )
1201 .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1202 let knots = match built.metadata {
1203 BasisMetadata::BSpline1D { knots, .. } => knots,
1204 _ => {
1205 return Err(
1206 "internal error: expected BSpline1D metadata for survival time basis"
1207 .to_string(),
1208 );
1209 }
1210 };
1211 create_basis::<Dense>(
1220 combined.view(),
1221 KnotSource::Provided(knots.view()),
1222 validation_degree,
1223 basis_options,
1224 )
1225 .map_err(|e| e.to_string())?;
1226 Ok(knots)
1227 };
1228
1229 if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1230 inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1231 } else {
1232 inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1233 }
1234 }
1235
1236 match cfg {
1237 SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1238 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1239 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1240 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1241 penalties: Vec::new(),
1242 nullspace_dims: Vec::new(),
1243 basisname: "none".to_string(),
1244 degree: None,
1245 knots: None,
1246 keep_cols: None,
1247 smooth_lambda: None,
1248 }),
1249 SurvivalTimeBasisConfig::Linear => {
1250 let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1251 let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1252 let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1253 for i in 0..n {
1254 x_entry_time[[i, 0]] = 1.0;
1255 x_exit_time[[i, 0]] = 1.0;
1256 x_entry_time[[i, 1]] = log_entry[i];
1257 x_exit_time[[i, 1]] = log_exit[i];
1258 x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1259 }
1260 Ok(SurvivalTimeBuildOutput {
1261 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1262 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1263 x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1264 penalties: Vec::new(),
1265 nullspace_dims: Vec::new(),
1266 basisname: "linear".to_string(),
1267 degree: None,
1268 knots: None,
1269 keep_cols: None,
1270 smooth_lambda: None,
1271 })
1272 }
1273 SurvivalTimeBasisConfig::BSpline {
1274 degree,
1275 knots,
1276 smooth_lambda,
1277 } => {
1278 let knotvec = if knots.is_empty() {
1279 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1280 "internal error: bspline time basis requested without knot source".to_string()
1281 })?;
1282 let combined = survival_time_knot_input(&log_entry, &log_exit);
1283 infer_survival_time_knots(
1284 &combined,
1285 degree,
1286 degree,
1287 num_internal_knots,
1288 BasisOptions::value(),
1289 )?
1290 } else {
1291 knots
1292 };
1293
1294 let entry_basis = build_bspline_basis_1d(
1295 log_entry.view(),
1296 &BSplineBasisSpec {
1297 degree,
1298 penalty_order: 2,
1299 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1300 double_penalty: false,
1301 identifiability: BSplineIdentifiability::None,
1302 boundary: OneDimensionalBoundary::Open,
1303 boundary_conditions: BSplineBoundaryConditions::default(),
1304 },
1305 )
1306 .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1307 let exit_basis = build_bspline_basis_1d(
1308 log_exit.view(),
1309 &BSplineBasisSpec {
1310 degree,
1311 penalty_order: 2,
1312 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1313 double_penalty: false,
1314 identifiability: BSplineIdentifiability::None,
1315 boundary: OneDimensionalBoundary::Open,
1316 boundary_conditions: BSplineBoundaryConditions::default(),
1317 },
1318 )
1319 .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1320
1321 let p_time = exit_basis.design.ncols();
1322 let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1326 let mut deriv_buf = vec![0.0_f64; p_time];
1327 for i in 0..n {
1328 deriv_buf.fill(0.0);
1329 evaluate_bspline_derivative_scalar(
1330 log_exit[i],
1331 knotvec.view(),
1332 degree,
1333 &mut deriv_buf,
1334 )
1335 .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1336 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1337 for j in 0..p_time {
1338 let v = deriv_buf[j] * chain;
1339 if v.abs() > 1e-15 {
1340 deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1341 }
1342 }
1343 }
1344 let x_derivative_time =
1345 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1346 {
1347 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1348 Err(_) => {
1349 let mut dense = Array2::<f64>::zeros((n, p_time));
1351 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1352 dense[[row, col]] = val;
1353 }
1354 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1355 }
1356 };
1357
1358 Ok(SurvivalTimeBuildOutput {
1359 x_entry_time: entry_basis.design,
1360 x_exit_time: exit_basis.design,
1361 x_derivative_time,
1362 nullspace_dims: entry_basis.nullspace_dims,
1363 penalties: entry_basis.penalties,
1364 basisname: "bspline".to_string(),
1365 degree: Some(degree),
1366 knots: Some(knotvec.to_vec()),
1367 keep_cols: None,
1368 smooth_lambda: Some(smooth_lambda),
1369 })
1370 }
1371 SurvivalTimeBasisConfig::ISpline {
1372 degree,
1373 knots,
1374 keep_cols,
1375 smooth_lambda,
1376 } => {
1377 let bspline_degree = degree
1378 .checked_add(1)
1379 .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1380 let knotvec = if knots.is_empty() {
1381 let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1382 "internal error: ispline time basis requested without knot source".to_string()
1383 })?;
1384 let combined = survival_time_knot_input(&log_entry, &log_exit);
1385 infer_survival_time_knots(
1386 &combined,
1387 bspline_degree,
1388 degree,
1389 num_internal_knots,
1390 BasisOptions::i_spline(),
1391 )?
1392 } else {
1393 knots
1394 };
1395
1396 let (db_exit_arc, _) = create_basis::<Dense>(
1397 log_exit.view(),
1398 KnotSource::Provided(knotvec.view()),
1399 bspline_degree,
1400 BasisOptions::first_derivative(),
1401 )
1402 .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1403
1404 let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1407 let (entry_arc, _) = create_basis::<Dense>(
1408 log_entry.view(),
1409 KnotSource::Provided(knotvec.view()),
1410 degree,
1411 BasisOptions::i_spline(),
1412 )
1413 .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1414 let (exit_arc, _) = create_basis::<Dense>(
1415 log_exit.view(),
1416 KnotSource::Provided(knotvec.view()),
1417 degree,
1418 BasisOptions::i_spline(),
1419 )
1420 .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1421
1422 let x_entry_full = entry_arc.as_ref();
1423 let x_exit_full = exit_arc.as_ref();
1424 let p_time_full = x_exit_full.ncols();
1425 if p_time_full == 0 {
1426 return Err(SurvivalConstructionError::BasisConstructionFailed {
1427 reason: "internal error: empty ispline time basis".to_string(),
1428 }
1429 .into());
1430 }
1431 let db_exit = db_exit_arc.as_ref();
1432 if db_exit.ncols() != p_time_full + 1 {
1433 return Err(
1434 "internal error: ispline derivative basis width must exceed basis width by one"
1435 .to_string(),
1436 );
1437 }
1438
1439 let keep_cols = if keep_cols.is_empty() {
1440 let constant_tol = 1e-12_f64;
1441 let mut inferred_keep_cols: Vec<usize> = Vec::new();
1442 for j in 0..p_time_full {
1443 let mut minv = f64::INFINITY;
1444 let mut maxv = f64::NEG_INFINITY;
1445 for i in 0..n {
1446 let ve = x_exit_full[[i, j]];
1447 let vs = x_entry_full[[i, j]];
1448 minv = minv.min(ve.min(vs));
1449 maxv = maxv.max(ve.max(vs));
1450 }
1451 if (maxv - minv) > constant_tol {
1452 inferred_keep_cols.push(j);
1453 }
1454 }
1455 inferred_keep_cols
1456 } else {
1457 keep_cols
1458 };
1459 if keep_cols.is_empty() {
1460 return Err(
1461 "internal error: ispline basis has no shape-varying time columns"
1462 .to_string(),
1463 );
1464 }
1465 if keep_cols.iter().any(|&j| j >= p_time_full) {
1466 return Err(SurvivalConstructionError::MissingColumn {
1467 reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1468 }
1469 .into());
1470 }
1471
1472 let p_time = keep_cols.len();
1473 let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1474 let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1475 (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1478 };
1479 let db_exit = db_exit_arc.as_ref();
1480
1481 let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1486 let mut found_nonfinite: Option<(usize, usize)> = None;
1487 for i in 0..n {
1488 let mut running = 0.0_f64;
1489 let mut d_i_log_full = vec![0.0_f64; p_time_full];
1490 for j in (1..db_exit.ncols()).rev() {
1491 let term = db_exit[[i, j]];
1492 if term.is_finite() {
1493 running += term;
1494 }
1495 d_i_log_full[j - 1] = running;
1496 }
1497 let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1498 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1499 let raw_v = d_i_log_full[j_old] * chain;
1500 let v = if (-1e-12..0.0).contains(&raw_v) {
1501 0.0
1502 } else {
1503 raw_v
1504 };
1505 if !v.is_finite() {
1506 found_nonfinite = Some((i, j_new));
1507 }
1508 if v < -1e-12 {
1509 return Err(format!(
1510 "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1511 i + 1,
1512 j_new + 1,
1513 v
1514 ));
1515 }
1516 if v.abs() > 1e-15 {
1517 deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1518 }
1519 }
1520 }
1521 if let Some((row, col)) = found_nonfinite {
1522 return Err(format!(
1523 "survival ispline derivative basis produced non-finite value at row {}, column {}",
1524 row + 1,
1525 col + 1
1526 ));
1527 }
1528 let x_derivative_time =
1529 match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1530 {
1531 Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1532 Err(_) => {
1533 let mut dense = Array2::<f64>::zeros((n, p_time));
1534 for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1535 dense[[row, col]] = val;
1536 }
1537 DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1538 }
1539 };
1540
1541 let penalty_basis = build_bspline_basis_1d(
1542 log_exit.view(),
1543 &BSplineBasisSpec {
1544 degree: bspline_degree,
1545 penalty_order: 2,
1546 knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1547 double_penalty: false,
1548 identifiability: BSplineIdentifiability::None,
1549 boundary: OneDimensionalBoundary::Open,
1550 boundary_conditions: BSplineBoundaryConditions::default(),
1551 },
1552 )
1553 .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1554 if penalty_basis.design.ncols() != p_time_full + 1 {
1555 return Err("internal error: ispline penalty dimension mismatch".to_string());
1556 }
1557 let mut penalties = Vec::<Array2<f64>>::new();
1591 for s_mat in &penalty_basis.penalties {
1592 if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1593 continue;
1594 }
1595 let s_increment = s_mat.slice(s![1.., 1..]);
1624 if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1625 return Err(format!(
1626 "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1627 s_increment.nrows(),
1628 s_increment.ncols(),
1629 ));
1630 }
1631 let mut s_full = s_increment.to_owned();
1636 symmetrize_in_place(&mut s_full);
1637 let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1641 for i in 0..p_time_full {
1642 for j in 0..p_time_full {
1643 let mut v = 0.0;
1644 for k in j..p_time_full {
1645 v += s_full[[i, k]];
1646 }
1647 s_mid_full[[i, j]] = v;
1648 }
1649 }
1650 let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1654 for i in 0..p_time_full {
1655 for j in 0..p_time_full {
1656 let mut v = 0.0;
1657 for k in i..p_time_full {
1658 v += s_mid_full[[k, j]];
1659 }
1660 s_full_congruent[[i, j]] = v;
1661 }
1662 }
1663 let mut local = Array2::<f64>::zeros((p_time, p_time));
1665 for (i_new, &i_old) in keep_cols.iter().enumerate() {
1666 for (j_new, &j_old) in keep_cols.iter().enumerate() {
1667 local[[i_new, j_new]] = 0.5
1670 * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1671 }
1672 }
1673 penalties.push(local);
1674 }
1675
1676 for (idx, s_mat) in penalties.iter().enumerate() {
1686 let p = s_mat.nrows();
1687 if p == 0 {
1688 continue;
1689 }
1690 if let Ok((evals, _)) =
1691 gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1692 {
1693 let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1694 "internal error: ispline penalty eigenvalues not contiguous".to_string()
1695 })?;
1696 let max_ev = evals_slice
1697 .iter()
1698 .copied()
1699 .fold(0.0_f64, |a, b| a.max(b.abs()))
1700 .max(1.0);
1701 let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1702 let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1703 if min_ev < neg_tol {
1704 return Err(format!(
1705 "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1706 indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1707 {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1708 ));
1709 }
1710 }
1711 }
1712
1713 let nullspace_dims: Vec<usize> = penalties
1717 .iter()
1718 .map(|s_mat| {
1719 let p = s_mat.nrows();
1720 if p == 0 {
1721 return 0;
1722 }
1723 match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1724 Ok((evals, _)) => {
1725 let evals_slice: &[f64] = evals.as_slice().unwrap();
1726 let max_ev = evals_slice
1727 .iter()
1728 .copied()
1729 .fold(0.0_f64, |a, b| a.max(b.abs()))
1730 .max(1.0);
1731 let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1732 evals_slice.iter().filter(|&&e| e <= threshold).count()
1733 }
1734 Err(_) => 0,
1735 }
1736 })
1737 .collect();
1738 Ok(SurvivalTimeBuildOutput {
1739 x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1740 x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1741 x_derivative_time,
1742 penalties,
1743 nullspace_dims,
1744 basisname: "ispline".to_string(),
1745 degree: Some(degree),
1746 knots: Some(knotvec.to_vec()),
1747 keep_cols: Some(keep_cols),
1748 smooth_lambda: Some(smooth_lambda),
1749 })
1750 }
1751 }
1752}
1753
1754pub fn resolved_survival_time_basis_config_from_build(
1755 basisname: &str,
1756 degree: Option<usize>,
1757 knots: Option<&Vec<f64>>,
1758 keep_cols: Option<&Vec<usize>>,
1759 smooth_lambda: Option<f64>,
1760) -> Result<SurvivalTimeBasisConfig, String> {
1761 match basisname {
1762 "none" => Ok(SurvivalTimeBasisConfig::None),
1763 "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1764 "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1765 degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1766 knots: Array1::from_vec(
1767 knots
1768 .cloned()
1769 .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1770 ),
1771 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1772 }),
1773 "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1774 degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1775 knots: Array1::from_vec(
1776 knots
1777 .cloned()
1778 .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1779 ),
1780 keep_cols: keep_cols
1781 .cloned()
1782 .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1783 smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1784 }),
1785 other => Err(format!("unsupported survival time basis '{other}'")),
1786 }
1787}
1788
1789pub fn resolve_survival_time_anchor_value(
1790 age_entry: &Array1<f64>,
1791 time_anchor: Option<f64>,
1792) -> Result<f64, String> {
1793 if age_entry.is_empty() {
1794 return Err("survival time anchor requires non-empty entry times".to_string());
1795 }
1796 let anchor = match time_anchor {
1797 Some(t_anchor) => {
1798 if !t_anchor.is_finite() || t_anchor < 0.0 {
1799 return Err(format!(
1800 "survival time anchor must be finite and non-negative, got {t_anchor}"
1801 ));
1802 }
1803 t_anchor
1804 }
1805 None => age_entry
1806 .iter()
1807 .copied()
1808 .min_by(f64::total_cmp)
1809 .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1810 };
1811 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1812}
1813
1814pub fn resolve_survival_marginal_slope_time_anchor_value(
1846 age_entry: &Array1<f64>,
1847 age_exit: &Array1<f64>,
1848 time_anchor: Option<f64>,
1849) -> Result<f64, String> {
1850 if age_entry.is_empty() || age_exit.is_empty() {
1851 return Err(
1852 "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1853 );
1854 }
1855 let anchor = match time_anchor {
1856 Some(t_anchor) => {
1857 if !t_anchor.is_finite() || t_anchor < 0.0 {
1858 return Err(format!(
1859 "survival time anchor must be finite and non-negative, got {t_anchor}"
1860 ));
1861 }
1862 t_anchor
1863 }
1864 None => {
1865 let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1866 sorted.sort_by(f64::total_cmp);
1867 let m = sorted.len();
1868 if m % 2 == 1 {
1869 sorted[m / 2]
1870 } else {
1871 0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1872 }
1873 }
1874 };
1875 Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1876}
1877
1878pub fn evaluate_survival_time_basis_row(
1879 age: f64,
1880 cfg: &SurvivalTimeBasisConfig,
1881) -> Result<Array1<f64>, String> {
1882 if !age.is_finite() || age < 0.0 {
1883 return Err(format!(
1884 "survival time basis row requires finite non-negative age, got {age}"
1885 ));
1886 }
1887 let age = age.max(SURVIVAL_TIME_FLOOR);
1888 let log_age = array![age.ln()];
1889 match cfg {
1890 SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1891 SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1892 SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1893 if knots.is_empty() {
1894 return Err(
1895 "survival BSpline anchor evaluation requires resolved knot metadata"
1896 .to_string(),
1897 );
1898 }
1899 let built = build_bspline_basis_1d(
1900 log_age.view(),
1901 &BSplineBasisSpec {
1902 degree: *degree,
1903 penalty_order: 2,
1904 knotspec: BSplineKnotSpec::Provided(knots.clone()),
1905 double_penalty: false,
1906 identifiability: BSplineIdentifiability::None,
1907 boundary: OneDimensionalBoundary::Open,
1908 boundary_conditions: BSplineBoundaryConditions::default(),
1909 },
1910 )
1911 .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1912 Ok(built.design.to_dense().row(0).to_owned())
1913 }
1914 SurvivalTimeBasisConfig::ISpline {
1915 degree,
1916 knots,
1917 keep_cols,
1918 ..
1919 } => {
1920 if knots.is_empty() {
1921 return Err(
1922 "survival ISpline anchor evaluation requires resolved knot metadata"
1923 .to_string(),
1924 );
1925 }
1926 let (basis_arc, _) = create_basis::<Dense>(
1927 log_age.view(),
1928 KnotSource::Provided(knots.view()),
1929 *degree,
1930 BasisOptions::i_spline(),
1931 )
1932 .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1933 let basis = basis_arc.as_ref();
1934 let row = basis.row(0);
1935 if keep_cols.is_empty() {
1936 return Ok(row.to_owned());
1937 }
1938 if keep_cols.iter().any(|&j| j >= row.len()) {
1939 return Err(SurvivalConstructionError::MissingColumn {
1940 reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
1941 }
1942 .into());
1943 }
1944 Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
1945 }
1946 }
1947}
1948
1949pub fn center_survival_time_designs_at_anchor(
1950 design_entry: &mut DesignMatrix,
1951 design_exit: &mut DesignMatrix,
1952 anchor_row: &Array1<f64>,
1953) -> Result<(), String> {
1954 if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
1955 return Err(format!(
1956 "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
1957 design_entry.ncols(),
1958 design_exit.ncols(),
1959 anchor_row.len()
1960 ));
1961 }
1962 fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
1965 let mut dense = dm.to_dense();
1966 for mut row in dense.rows_mut() {
1967 row -= &anchor.view();
1968 }
1969 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
1970 }
1971 center_dense(design_entry, anchor_row);
1972 center_dense(design_exit, anchor_row);
1973 Ok(())
1974}
1975
1976pub fn baseline_offset_theta_partials(
2006 age: f64,
2007 cfg: &SurvivalBaselineConfig,
2008) -> Result<Option<Vec<(f64, f64)>>, String> {
2009 let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2010 else {
2011 return Ok(None);
2012 };
2013
2014 match params {
2015 ValidatedBaselineTarget::Weibull { scale, shape } => {
2016 let eta = shape * (age.ln() - scale.ln());
2025 let o_d = shape / age;
2026 let d_eta_d_log_scale = -shape;
2027 let d_od_d_log_scale = 0.0;
2028 let d_eta_d_log_shape = eta;
2029 let d_od_d_log_shape = o_d;
2030 Ok(Some(vec![
2031 (d_eta_d_log_scale, d_od_d_log_scale),
2032 (d_eta_d_log_shape, d_od_d_log_shape),
2033 ]))
2034 }
2035 ValidatedBaselineTarget::Gompertz { shape, .. } => {
2036 let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2046 Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2047 }
2048 ValidatedBaselineTarget::GompertzMakeham {
2049 rate,
2050 shape,
2051 makeham,
2052 } => {
2053 let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2068 let cum_total = makeham * age + cum_g;
2069 if cum_total <= 0.0 || !cum_total.is_finite() {
2070 return Err(SurvivalConstructionError::DataValidationFailed {
2071 reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2072 }
2073 .into());
2074 }
2075 let inst_total = makeham + inst_g;
2076 let o_d = inst_total / cum_total;
2077 let inv_cum = 1.0 / cum_total;
2078 let d_cum_dlr = cum_g;
2083 let d_inst_dlr = inst_g;
2084 let d_eta_dlr = d_cum_dlr * inv_cum;
2085 let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2086 let (d_cum_dshape, d_inst_dshape) =
2088 gompertz_cumulative_shape_derivative(age, rate, shape);
2089 let d_eta_dshape = d_cum_dshape * inv_cum;
2090 let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2091 let d_cum_dlm = makeham * age;
2094 let d_inst_dlm = makeham;
2095 let d_eta_dlm = d_cum_dlm * inv_cum;
2096 let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2097 Ok(Some(vec![
2098 (d_eta_dlr, d_od_dlr),
2099 (d_eta_dshape, d_od_dshape),
2100 (d_eta_dlm, d_od_dlm),
2101 ]))
2102 }
2103 }
2104}
2105
2106fn baseline_chain_rule_gradient_with_partials<F>(
2134 label: &'static str,
2135 age_entry: ndarray::ArrayView1<'_, f64>,
2136 age_exit: ndarray::ArrayView1<'_, f64>,
2137 age_right: ndarray::ArrayView1<'_, f64>,
2138 cfg: &SurvivalBaselineConfig,
2139 residuals: &crate::survival::OffsetChannelResiduals,
2140 partials: F,
2141) -> Result<Option<Array1<f64>>, String>
2142where
2143 F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2144{
2145 let n = age_exit.len();
2146 if age_entry.len() != n
2147 || age_right.len() != n
2148 || residuals.exit.len() != n
2149 || residuals.entry.len() != n
2150 || residuals.derivative.len() != n
2151 || residuals.right.len() != n
2152 {
2153 return Err(format!(
2154 "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2155 age_entry.len(),
2156 n,
2157 age_right.len(),
2158 residuals.exit.len(),
2159 residuals.entry.len(),
2160 residuals.derivative.len(),
2161 residuals.right.len(),
2162 ));
2163 }
2164 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2167 let theta_dim = match probe_age {
2168 Some(t) => match partials(t, cfg)? {
2169 None => return Ok(None),
2170 Some(v) => v.len(),
2171 },
2172 None => {
2173 return Err(format!("{label}: no valid positive age for dim probe"));
2174 }
2175 };
2176 let mut grad = Array1::<f64>::zeros(theta_dim);
2187 for i in 0..n {
2188 let partials_exit = partials(age_exit[i], cfg)?
2190 .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2191 if partials_exit.len() != theta_dim {
2192 return Err(format!(
2193 "{label}: theta_dim drifted ({} != {})",
2194 partials_exit.len(),
2195 theta_dim
2196 ));
2197 }
2198 let r_x = residuals.exit[i];
2199 let r_d = residuals.derivative[i];
2200 for k in 0..theta_dim {
2201 let (d_eta_dk, d_od_dk) = partials_exit[k];
2202 grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2203 }
2204 let r_e = residuals.entry[i];
2208 if r_e != 0.0 {
2209 let partials_entry = partials(age_entry[i], cfg)?
2210 .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2211 for k in 0..theta_dim {
2212 grad[k] += r_e * partials_entry[k].0;
2213 }
2214 }
2215 let r_r = residuals.right[i];
2224 if r_r != 0.0 {
2225 let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2226 format!("{label}: unexpected None from partials at right boundary")
2227 })?;
2228 if partials_right.len() != theta_dim {
2229 return Err(format!(
2230 "{label}: theta_dim drifted at right boundary ({} != {})",
2231 partials_right.len(),
2232 theta_dim
2233 ));
2234 }
2235 for k in 0..theta_dim {
2236 grad[k] += r_r * partials_right[k].0;
2237 }
2238 }
2239 }
2240 Ok(Some(grad))
2241}
2242
2243pub fn baseline_chain_rule_gradient(
2277 age_entry: ndarray::ArrayView1<'_, f64>,
2278 age_exit: ndarray::ArrayView1<'_, f64>,
2279 age_right: ndarray::ArrayView1<'_, f64>,
2280 cfg: &SurvivalBaselineConfig,
2281 residuals: &crate::survival::OffsetChannelResiduals,
2282) -> Result<Option<Array1<f64>>, String> {
2283 baseline_chain_rule_gradient_with_partials(
2284 "baseline_chain_rule_gradient",
2285 age_entry,
2286 age_exit,
2287 age_right,
2288 cfg,
2289 residuals,
2290 baseline_offset_theta_partials,
2291 )
2292}
2293
2294pub fn marginal_slope_baseline_chain_rule_gradient(
2301 age_entry: ndarray::ArrayView1<'_, f64>,
2302 age_exit: ndarray::ArrayView1<'_, f64>,
2303 cfg: &SurvivalBaselineConfig,
2304 residuals: &crate::survival::OffsetChannelResiduals,
2305) -> Result<Option<Array1<f64>>, String> {
2306 baseline_chain_rule_gradient_with_partials(
2310 "marginal_slope_baseline_chain_rule_gradient",
2311 age_entry,
2312 age_exit,
2313 age_exit,
2314 cfg,
2315 residuals,
2316 marginal_slope_baseline_offset_theta_partials,
2317 )
2318}
2319
/// Gompertz cumulative and instantaneous hazard at `age`.
///
/// With `x = shape * age`, returns `(H, h)` where `h = rate * e^x` and
/// `H = (rate / shape) * (e^x - 1)`. When `|shape| < 1e-10`, both are replaced by their
/// second-order series in `x`, which also covers `shape == 0`.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if shape.abs() < 1e-10 {
        // (e^x - 1) / x ~ 1 + x/2 + x^2/6 and e^x ~ 1 + x + x^2/2.
        let cumulative = rate * age * (1.0 + 0.5 * x + x * x / 6.0);
        let instant = rate * (1.0 + x + 0.5 * x * x);
        return (cumulative, instant);
    }
    ((rate / shape) * x.exp_m1(), rate * x.exp())
}
2340
/// Derivatives of the Gompertz cumulative and instantaneous hazards with respect to the raw
/// shape, holding `rate` fixed.
///
/// With `x = shape * age`, returns `(dH/dshape, dh/dshape)` where
/// `dh/dshape = rate * age * e^x` and
/// `dH/dshape = rate * (age * shape * e^x - (e^x - 1)) / shape^2`.
/// For `|x| < 1e-4`, the second expression is replaced by its series
/// `rate * age^2 * (1/2 + x/3 + x^2/8)`, avoiding the cancellation in the numerator.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let growth = x.exp();
    let d_instant = rate * age * growth;
    if x.abs() < 1e-4 {
        let series = 0.5 + x / 3.0 + x * x / 8.0;
        return (rate * age * age * series, d_instant);
    }
    let numerator = age * growth * shape - x.exp_m1();
    (rate * numerator / (shape * shape), d_instant)
}
2381
/// Derivatives of the Gompertz offset pair with respect to the raw shape.
///
/// With `x = shape * age`, the offset is `eta = ln H` with `H = (rate / shape) * (e^x - 1)`,
/// and the derivative channel is `o_d = h / H = shape * e^x / (e^x - 1)`. Neither derivative
/// depends on `rate`, so it is not a parameter.
///
/// Returns `(d eta / d shape, d o_d / d shape)`.
///
/// The closed form `-1/shape + age * e^x / (e^x - 1)` subtracts two terms of size
/// `1/|shape|` to produce a result of size `age / 2`, so it loses about `log10(2 / |x|)`
/// significant digits. The series branch is therefore selected on `|x| < 1e-4`, not on the
/// shape alone. This keeps the loss bounded for every time scale and matches the threshold
/// used by `gompertz_cumulative_shape_derivative`.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-4 {
        let t = age;
        // t * (e^x / (e^x - 1) - 1/x) = t * (1/2 + x/12 + O(x^3)).
        let d_eta = 0.5 * t + shape * t * t / 12.0;
        // 1/shape - t / (e^x - 1) = t/2 - x t/12 + O(x^3).
        let dlog_od = 0.5 * t - shape * t * t / 12.0;
        // shape * e^x / (e^x - 1) = 1/t + shape/2 + shape x/12 + O(x^3).
        let o_d = 1.0 / t + 0.5 * shape + shape * shape * t / 12.0;
        (d_eta, o_d * dlog_od)
    } else {
        let e = x.exp();
        let em1 = x.exp_m1();
        let d_eta = -1.0 / shape + age * e / em1;
        let o_d = shape * e / em1;
        // d ln(o_d) / d shape = 1/shape + age - age e^x / (e^x - 1) = 1/shape - age / (e^x - 1).
        let dlog_od = 1.0 / shape - age / em1;
        (d_eta, o_d * dlog_od)
    }
}
2414
/// Baseline parameters that passed the checks in `validated_baseline_params`.
///
/// There is one variant per parametric `SurvivalBaselineTarget`; the linear target carries no
/// parameters and is not represented here.
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    /// Weibull baseline; `scale` and `shape` are both finite and strictly positive.
    Weibull { scale: f64, shape: f64 },
    /// Gompertz baseline; `rate` is finite and strictly positive, `shape` is finite (any sign).
    Gompertz { rate: f64, shape: f64 },
    /// Gompertz-Makeham baseline; `rate` and `makeham` are finite and strictly positive,
    /// `shape` is finite (any sign).
    GompertzMakeham { rate: f64, shape: f64, makeham: f64 },
}
2429
2430fn validated_baseline_params(
2436 age: f64,
2437 cfg: &SurvivalBaselineConfig,
2438 context: &str,
2439) -> Result<Option<ValidatedBaselineTarget>, String> {
2440 if !age.is_finite() || age <= 0.0 {
2441 return Err(format!(
2442 "survival ages must be finite and positive for {context}"
2443 ));
2444 }
2445
2446 match cfg.target {
2447 SurvivalBaselineTarget::Linear => Ok(None),
2448 SurvivalBaselineTarget::Weibull => {
2449 let scale = cfg
2450 .scale
2451 .ok_or_else(|| "weibull missing scale".to_string())?;
2452 let shape = cfg
2453 .shape
2454 .ok_or_else(|| "weibull missing shape".to_string())?;
2455 if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2456 return Err(SurvivalConstructionError::InvalidConfig {
2457 reason: "weibull baseline requires finite positive scale and shape".to_string(),
2458 }
2459 .into());
2460 }
2461 Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2462 }
2463 SurvivalBaselineTarget::Gompertz => {
2464 let rate = cfg
2465 .rate
2466 .ok_or_else(|| "gompertz missing rate".to_string())?;
2467 let shape = cfg
2468 .shape
2469 .ok_or_else(|| "gompertz missing shape".to_string())?;
2470 if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2471 return Err(
2472 "gompertz baseline requires finite positive rate and finite shape".to_string(),
2473 );
2474 }
2475 Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2476 }
2477 SurvivalBaselineTarget::GompertzMakeham => {
2478 let rate = cfg
2479 .rate
2480 .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2481 let shape = cfg
2482 .shape
2483 .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2484 let makeham = cfg
2485 .makeham
2486 .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2487 if !(rate.is_finite()
2488 && shape.is_finite()
2489 && makeham.is_finite()
2490 && rate > 0.0
2491 && makeham > 0.0)
2492 {
2493 return Err(
2494 "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2495 .to_string(),
2496 );
2497 }
2498 Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2499 rate,
2500 shape,
2501 makeham,
2502 }))
2503 }
2504 }
2505}
2506
2507fn survival_hazard_theta_partials(
2508 age: f64,
2509 cfg: &SurvivalBaselineConfig,
2510) -> Result<Option<Vec<(f64, f64)>>, String> {
2511 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2512 return Ok(None);
2513 };
2514
2515 match params {
2516 ValidatedBaselineTarget::Weibull { scale, shape } => {
2517 let log_time_ratio = age.ln() - scale.ln();
2518 let cumulative_hazard = (age / scale).powf(shape);
2519 let instant_hazard = shape * cumulative_hazard / age;
2520 let eta = shape * log_time_ratio;
2521 Ok(Some(vec![
2522 (-shape * cumulative_hazard, -shape * instant_hazard),
2523 (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2524 ]))
2525 }
2526 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2527 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2528 let (d_cum_dshape, d_inst_dshape) =
2529 gompertz_cumulative_shape_derivative(age, rate, shape);
2530 Ok(Some(vec![
2531 (cumulative_hazard, instant_hazard),
2532 (d_cum_dshape, d_inst_dshape),
2533 ]))
2534 }
2535 ValidatedBaselineTarget::GompertzMakeham {
2536 rate,
2537 shape,
2538 makeham,
2539 } => {
2540 let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2541 let (d_cum_dshape, d_inst_dshape) =
2542 gompertz_cumulative_shape_derivative(age, rate, shape);
2543 Ok(Some(vec![
2544 (cum_gompertz, inst_gompertz),
2545 (d_cum_dshape, d_inst_dshape),
2546 (makeham * age, makeham),
2547 ]))
2548 }
2549 }
2550}
2551
2552fn survival_cumulative_and_instant_hazard(
2553 age: f64,
2554 cfg: &SurvivalBaselineConfig,
2555) -> Result<Option<(f64, f64)>, String> {
2556 let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2557 return Ok(None);
2558 };
2559
2560 match params {
2561 ValidatedBaselineTarget::Weibull { scale, shape } => {
2562 let cumulative_hazard = (age / scale).powf(shape);
2563 let instant_hazard = shape * cumulative_hazard / age;
2564 Ok(Some((cumulative_hazard, instant_hazard)))
2565 }
2566 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2567 let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2568 Ok(Some((cumulative_hazard, instant_hazard)))
2569 }
2570 ValidatedBaselineTarget::GompertzMakeham {
2571 rate,
2572 shape,
2573 makeham,
2574 } => {
2575 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2576 Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2577 }
2578 }
2579}
2580
/// Baseline quantities at a single time point for the marginal-slope (probit-scale) survival
/// baseline, as assembled by `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Baseline instantaneous hazard `h(t)`.
    instant_hazard: f64,
    /// Probit-scale baseline `q = -Phi^{-1}(S(t))`, with survival `S(t) = exp(-H(t))`.
    q: f64,
    /// `h(t) * S(t) / phi(q)`, i.e. the time derivative `dq/dt`.
    q_t: f64,
}
2587
2588fn evaluate_marginal_slope_baseline_point(
2589 age: f64,
2590 cfg: &SurvivalBaselineConfig,
2591) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2592 let Some((cumulative_hazard, instant_hazard)) =
2593 survival_cumulative_and_instant_hazard(age, cfg)?
2594 else {
2595 return Ok(None);
2596 };
2597 if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2598 return Err(format!(
2599 "{} marginal-slope baseline produced non-positive cumulative hazard",
2600 survival_baseline_targetname(cfg.target)
2601 ));
2602 }
2603 if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2604 return Err(format!(
2605 "{} marginal-slope baseline produced non-positive instant hazard",
2606 survival_baseline_targetname(cfg.target)
2607 ));
2608 }
2609 let survival = (-cumulative_hazard).exp();
2610 if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2611 return Err(format!(
2612 "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2613 survival_baseline_targetname(cfg.target)
2614 ));
2615 }
2616 let q = -standard_normal_quantile(survival).map_err(|e| {
2617 format!(
2618 "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2619 survival_baseline_targetname(cfg.target)
2620 )
2621 })?;
2622 let phi_q = normal_pdf(q);
2623 if !(phi_q.is_finite() && phi_q > 0.0) {
2624 return Err(format!(
2625 "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2626 survival_baseline_targetname(cfg.target)
2627 ));
2628 }
2629 Ok(Some(MarginalSlopeBaselinePoint {
2630 instant_hazard,
2631 q,
2632 q_t: instant_hazard * survival / phi_q,
2633 }))
2634}
2635
2636pub fn evaluate_survival_baseline(
2639 age: f64,
2640 cfg: &SurvivalBaselineConfig,
2641) -> Result<(f64, f64), String> {
2642 if !age.is_finite() || age < 0.0 {
2643 return Err(
2644 "survival ages must be finite and non-negative for baseline target evaluation"
2645 .to_string(),
2646 );
2647 }
2648
2649 if age == 0.0 {
2660 return match cfg.target {
2661 SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2662 SurvivalBaselineTarget::Weibull
2663 | SurvivalBaselineTarget::Gompertz
2664 | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2665 };
2666 }
2667
2668 let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2669 return Ok((0.0, 0.0));
2670 };
2671
2672 match params {
2673 ValidatedBaselineTarget::Weibull { scale, shape } => {
2674 let eta = shape * (age.ln() - scale.ln());
2675 let derivative = shape / age;
2676 Ok((eta, derivative))
2677 }
2678 ValidatedBaselineTarget::Gompertz { rate, shape } => {
2679 let (h, inst) = gompertz_hazard_components(age, rate, shape);
2680 if h <= 0.0 || !h.is_finite() {
2681 return Err(if shape.abs() < 1e-10 {
2682 "invalid gompertz baseline at near-zero shape".to_string()
2683 } else {
2684 "gompertz baseline produced non-positive cumulative hazard".to_string()
2685 });
2686 }
2687 let derivative = inst / h;
2688 Ok((h.ln(), derivative))
2689 }
2690 ValidatedBaselineTarget::GompertzMakeham {
2691 rate,
2692 shape,
2693 makeham,
2694 } => {
2695 let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2696 let h = makeham * age + h_gompertz;
2697 if h <= 0.0 || !h.is_finite() {
2698 return Err(
2699 "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2700 );
2701 }
2702 let inst = makeham + inst_gompertz;
2703 let derivative = inst / h;
2704 Ok((h.ln(), derivative))
2705 }
2706 }
2707}
2708
2709pub fn evaluate_survival_marginal_slope_baseline(
2715 age: f64,
2716 cfg: &SurvivalBaselineConfig,
2717) -> Result<(f64, f64), String> {
2718 if age == 0.0 {
2730 return Ok((0.0, 0.0));
2731 }
2732 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2733 return Ok((0.0, 0.0));
2734 };
2735 Ok((point.q, point.q_t))
2736}
2737
2738pub fn marginal_slope_baseline_offset_theta_partials(
2751 age: f64,
2752 cfg: &SurvivalBaselineConfig,
2753) -> Result<Option<Vec<(f64, f64)>>, String> {
2754 let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2755 return Ok(None);
2756 };
2757 let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2758 .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2759 let a = point.q_t / point.instant_hazard;
2760 let a_log_derivative_factor = point.q * a - 1.0;
2761 Ok(Some(
2762 hazard_partials
2763 .into_iter()
2764 .map(|(d_h_cum, d_h_inst)| {
2765 (
2766 a * d_h_cum,
2767 a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2768 )
2769 })
2770 .collect(),
2771 ))
2772}
2773
2774pub fn marginal_slope_baseline_chain_rule_hessian(
2777 age_entry: ndarray::ArrayView1<'_, f64>,
2778 age_exit: ndarray::ArrayView1<'_, f64>,
2779 cfg: &SurvivalBaselineConfig,
2780 residuals: &crate::survival::OffsetChannelResiduals,
2781 curvatures: &crate::survival::OffsetChannelCurvatures,
2782) -> Result<Option<Array2<f64>>, String> {
2783 let n = age_exit.len();
2784 if age_entry.len() != n
2785 || residuals.exit.len() != n
2786 || residuals.entry.len() != n
2787 || residuals.derivative.len() != n
2788 || curvatures.rows.len() != n
2789 {
2790 return Err(format!(
2791 "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2792 age_entry.len(),
2793 n,
2794 residuals.exit.len(),
2795 residuals.entry.len(),
2796 residuals.derivative.len(),
2797 curvatures.rows.len(),
2798 ));
2799 }
2800 let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2801 let dim = match probe_age {
2802 Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2803 None => return Ok(None),
2804 Some(parts) => parts.first.len(),
2805 },
2806 None => {
2807 return Err(
2808 "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2809 .to_string(),
2810 );
2811 }
2812 };
2813 let hessian = (0..n)
2818 .into_par_iter()
2819 .try_fold(
2820 || Array2::<f64>::zeros((dim, dim)),
2821 |mut acc, i| -> Result<Array2<f64>, String> {
2822 let exit_parts =
2823 marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2824 .ok_or_else(|| {
2825 "unexpected None from marginal-slope second partials at exit"
2826 .to_string()
2827 })?;
2828 if exit_parts.first.len() != dim {
2829 return Err(
2830 "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2831 );
2832 }
2833 let mut entry_parts = None;
2834 if residuals.entry[i] != 0.0 {
2835 entry_parts = Some(
2836 marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2837 .ok_or_else(|| {
2838 "unexpected None from marginal-slope second partials at entry"
2839 .to_string()
2840 })?,
2841 );
2842 }
2843 for a in 0..dim {
2844 for b in 0..dim {
2845 let j_exit_a = exit_parts.first[a].0;
2846 let j_exit_b = exit_parts.first[b].0;
2847 let j_deriv_a = exit_parts.first[a].1;
2848 let j_deriv_b = exit_parts.first[b].1;
2849 let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2850 + residuals.derivative[i] * exit_parts.second[a][b].1;
2851 if let Some(parts) = entry_parts.as_ref() {
2852 value += residuals.entry[i] * parts.second[a][b].0;
2853 }
2854 let curv = curvatures.rows[i];
2855 let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2856 let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2857 let ja = [j_entry_a, j_exit_a, j_deriv_a];
2858 let jb = [j_entry_b, j_exit_b, j_deriv_b];
2859 for u in 0..3 {
2860 for v in 0..3 {
2861 value += ja[u] * curv[u][v] * jb[v];
2862 }
2863 }
2864 acc[[a, b]] += value;
2865 }
2866 }
2867 Ok(acc)
2868 },
2869 )
2870 .try_reduce(|| Array2::<f64>::zeros((dim, dim)), |a, b| Ok(a + b))?;
2871 Ok(Some(hessian))
2872}
2873
/// θ-derivatives of the marginal-slope baseline offsets at a single age.
///
/// In both fields, `.0` belongs to the probit offset `q` and `.1` to its time
/// derivative `q_t`. The chain-rule Hessian pairs `.0` with the entry/exit
/// residuals and `.1` with the derivative residual.
struct MarginalSlopeThetaSecondPartials {
    // One `(∂q/∂θ_a, ∂q_t/∂θ_a)` pair per baseline parameter.
    first: Vec<(f64, f64)>,
    // Dense dim × dim table of `(∂²q/∂θ_a∂θ_b, ∂²q_t/∂θ_a∂θ_b)`.
    second: Vec<Vec<(f64, f64)>>,
}
2878
/// First and second θ-derivatives of the probit-survival offsets `(q, q_t)` at `age`.
///
/// With `q = -Φ⁻¹(S)`, `S = exp(-H)` and `q_t = h S / φ(q)`, define
///
/// * `a = S / φ(q)`: the derivative of `q` with respect to `H`;
/// * `b = q a - 1`: the derivative of `ln a` with respect to `H`.
///
/// The hazard derivatives from `survival_hazard_theta_first_second` are then
/// pushed through `q(H)` and `q_t = a h` by the product and chain rules.
///
/// Returns `Ok(None)` when either evaluator reports no nonlinear baseline.
fn marginal_slope_baseline_offset_theta_second_partials(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
        return Ok(None);
    };
    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
        return Ok(None);
    };
    let (cum_hazard, instant_hazard) = hazard;
    let survival = (-cum_hazard).exp();
    // a = dq/dH.
    let a = survival / normal_pdf(point.q);
    // b = d ln(a) / dH, so that da/dθ_j = a * b * dH/dθ_j.
    let b = point.q * a - 1.0;
    // db/dH / a = a + q * b; used to build db/dθ_j below.
    let b_factor = a + point.q * b;
    let dim = first.len();
    let mut first_out = Vec::with_capacity(dim);
    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
    // First order: dq = a dH and dq_t = a (dh + h b dH).
    for i in 0..dim {
        let (h_i, inst_i) = first[i];
        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
    }
    for i in 0..dim {
        for j in 0..dim {
            let (h_i, inst_i) = first[i];
            let (h_j, inst_j) = first[j];
            let (h_ij, inst_ij) = second[i][j];
            // da/dθ_j and db/dθ_j.
            let a_j = a * b * h_j;
            let b_j = a * h_j * b_factor;
            // d/dθ_j of (a H_i).
            let q_ij = a * h_ij + a * b * h_i * h_j;
            // d/dθ_j of a * (h_i + h b H_i): product rule over a, h, b and H_i.
            let qt_inner_i = inst_i + instant_hazard * b * h_i;
            let qt_ij = a_j * qt_inner_i
                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
            second_out[i][j] = (q_ij, qt_ij);
        }
    }
    Ok(Some(MarginalSlopeThetaSecondPartials {
        first: first_out,
        second: second_out,
    }))
}
2920
/// `(hazard, first, second)` as returned by `survival_hazard_theta_first_second`.
///
/// * `hazard = (H, h)`: the cumulative and instantaneous hazard at one age.
/// * `first[a]`: their θ_a-derivatives.
/// * `second[a][b]`: their mixed second derivatives.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2922
/// Cumulative and instantaneous hazard at `age`, with their first and second
/// derivatives with respect to the baseline parameters θ.
///
/// Returns `Ok(None)` when `survival_cumulative_and_instant_hazard` reports no
/// hazard, or when the target is `Linear`. First derivatives come from
/// `survival_hazard_theta_partials`. Second derivatives are filled per target
/// below; entries that are not written stay `(0.0, 0.0)`.
///
/// NOTE(review): the closed forms are consistent with these parameterisations;
/// confirm against `survival_hazard_theta_partials`:
/// * Weibull: θ = (log scale, log shape);
/// * Gompertz(-Makeham): θ = (log rate, shape[, log makeham]).
fn survival_hazard_theta_first_second(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<HazardFirstSecond>, String> {
    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
        return Ok(None);
    };
    let first = survival_hazard_theta_partials(age, cfg)?
        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
    let dim = first.len();
    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
    match cfg.target {
        SurvivalBaselineTarget::Linear => return Ok(None),
        SurvivalBaselineTarget::Weibull => {
            let scale = cfg
                .scale
                .ok_or_else(|| "weibull missing scale".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "weibull missing shape".to_string())?;
            // eta = shape * ln(age / scale) is the Weibull log cumulative hazard
            // (same expression as in `evaluate_survival_baseline`), so H = exp(eta)
            // and h = (shape / age) * H.
            let log_time_ratio = age.ln() - scale.ln();
            let cumulative_hazard = hazard.0;
            let instant_hazard = hazard.1;
            let eta = shape * log_time_ratio;
            second[0][0] = (
                shape * shape * cumulative_hazard,
                shape * shape * instant_hazard,
            );
            // Mixed term; the Hessian is symmetric.
            second[0][1] = (
                -shape * cumulative_hazard * (1.0 + eta),
                -shape * instant_hazard * (2.0 + eta),
            );
            second[1][0] = second[0][1];
            second[1][1] = (
                eta * cumulative_hazard * (1.0 + eta),
                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
            );
        }
        SurvivalBaselineTarget::Gompertz => {
            let rate = cfg
                .rate
                .ok_or_else(|| "gompertz missing rate".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "gompertz missing shape".to_string())?;
            // The hazard is proportional to `rate`, so differentiating the first
            // partials again along θ_0 reproduces them (assuming θ_0 = log rate).
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        }
        SurvivalBaselineTarget::GompertzMakeham => {
            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
            // Gompertz block as above. The Makeham term makeham * age (hazard:
            // makeham) is proportional to its parameter, giving second[2][2] =
            // first[2] with no cross terms to the Gompertz parameters.
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
            second[2][2] = first[2];
        }
    }
    Ok(Some((hazard, first, second)))
}
2985
/// Second derivatives of the Gompertz cumulative and instantaneous hazard with
/// respect to the raw `shape` parameter.
///
/// With `x = shape * age`, `H = rate * (e^x - 1) / shape` and `h = rate * e^x`,
/// this returns `(∂²H/∂shape², ∂²h/∂shape²)`.
///
/// The closed form for `∂²H/∂shape²` divides by `shape³` and cancels badly as
/// `x → 0`. For `|x| < 1e-3` both components therefore use second-order
/// Maclaurin expansions in `x` instead.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-3 {
        let t = age;
        let cumulative = rate * t * t * t * (1.0 / 3.0 + x / 4.0 + x * x / 10.0);
        let instant = rate * t * t * (1.0 + x + 0.5 * x * x);
        return (cumulative, instant);
    }
    let growth = x.exp();
    let growth_minus_one = x.exp_m1();
    let numerator = shape * age * growth - growth_minus_one;
    let shape_cubed = shape * shape * shape;
    let cumulative = rate * (age * age * growth / shape - 2.0 * numerator / shape_cubed);
    let instant = rate * age * age * growth;
    (cumulative, instant)
}
3016
/// Selects the scale on which survival baseline offsets are evaluated.
///
/// * `LogCumulativeHazard`: offsets are `ln H(t)` and its time derivative.
/// * `ProbitSurvival`: offsets are `q = -Φ⁻¹(S(t))` and its time derivative.
///
/// `Debug`/`PartialEq`/`Eq` are derived so the selector can be logged and
/// compared like the module's other small enums.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineOffsetEvaluator {
    LogCumulativeHazard,
    ProbitSurvival,
}
3026
3027impl BaselineOffsetEvaluator {
3028 fn length_error(self) -> String {
3029 match self {
3030 Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3031 reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3032 }
3033 .into(),
3034 Self::ProbitSurvival => {
3035 "survival probit baseline offsets require matching entry/exit lengths".to_string()
3036 }
3037 }
3038 }
3039
3040 fn finite_error(self) -> &'static str {
3041 match self {
3042 Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3043 Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3044 }
3045 }
3046
3047 fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3048 match self {
3049 Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3050 Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3051 }
3052 }
3053
3054 fn exit_is_finite(self, value: f64, age: f64) -> bool {
3055 match self {
3056 Self::LogCumulativeHazard => {
3057 value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3058 }
3059 Self::ProbitSurvival => value.is_finite(),
3060 }
3061 }
3062}
3063
3064fn build_survival_offsets_with_evaluator(
3065 age_entry: &Array1<f64>,
3066 age_exit: &Array1<f64>,
3067 cfg: &SurvivalBaselineConfig,
3068 evaluator: BaselineOffsetEvaluator,
3069) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3070 if age_entry.len() != age_exit.len() {
3071 return Err(evaluator.length_error());
3072 }
3073 let n = age_entry.len();
3074 let triples: Vec<(f64, f64, f64)> = (0..n)
3077 .into_par_iter()
3078 .map(|i| -> Result<(f64, f64, f64), String> {
3079 let entry_age = age_entry[i];
3083 let e0 = if !entry_age.is_finite() {
3084 return Err(SurvivalConstructionError::DataValidationFailed {
3085 reason: format!("non-finite entry age at row {i}"),
3086 }
3087 .into());
3088 } else if entry_age <= 0.0 {
3089 0.0
3090 } else {
3091 evaluator.evaluate(entry_age, cfg)?.0
3092 };
3093 let exit_age = age_exit[i];
3094 let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3095 if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3096 return Err(SurvivalConstructionError::DataValidationFailed {
3097 reason: evaluator.finite_error().to_string(),
3098 }
3099 .into());
3100 }
3101 Ok((e0, e1, d1))
3102 })
3103 .collect::<Result<Vec<_>, String>>()?;
3104 let mut eta_entry = Array1::<f64>::zeros(n);
3105 let mut eta_exit = Array1::<f64>::zeros(n);
3106 let mut derivative_exit = Array1::<f64>::zeros(n);
3107 for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3108 eta_entry[i] = e0;
3109 eta_exit[i] = e1;
3110 derivative_exit[i] = d1;
3111 }
3112 Ok((eta_entry, eta_exit, derivative_exit))
3113}
3114
3115pub fn build_survival_baseline_offsets(
3118 age_entry: &Array1<f64>,
3119 age_exit: &Array1<f64>,
3120 cfg: &SurvivalBaselineConfig,
3121) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3122 build_survival_offsets_with_evaluator(
3123 age_entry,
3124 age_exit,
3125 cfg,
3126 BaselineOffsetEvaluator::LogCumulativeHazard,
3127 )
3128}
3129
3130pub fn build_survival_marginal_slope_baseline_offsets(
3133 age_entry: &Array1<f64>,
3134 age_exit: &Array1<f64>,
3135 cfg: &SurvivalBaselineConfig,
3136) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3137 build_survival_offsets_with_evaluator(
3138 age_entry,
3139 age_exit,
3140 cfg,
3141 BaselineOffsetEvaluator::ProbitSurvival,
3142 )
3143}
3144
3145pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3146 matches!(
3147 inverse_link,
3148 Some(
3149 InverseLink::Standard(StandardLink::Probit)
3150 | InverseLink::LatentCLogLog(_)
3151 | InverseLink::Sas(_)
3152 | InverseLink::BetaLogistic(_)
3153 | InverseLink::Mixture(_)
3154 )
3155 )
3156}
3157
3158pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3159 match likelihood_mode {
3160 SurvivalLikelihoodMode::LocationScale
3161 | SurvivalLikelihoodMode::Latent
3162 | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3163 SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3164 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3165 }
3166}
3167
3168pub fn build_survival_time_offsets_for_likelihood(
3169 age_entry: &Array1<f64>,
3170 age_exit: &Array1<f64>,
3171 baseline_cfg: &SurvivalBaselineConfig,
3172 likelihood_mode: SurvivalLikelihoodMode,
3173 inverse_link: Option<&InverseLink>,
3174) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3175 if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3176 || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3177 && location_scale_uses_probit_survival_baseline(inverse_link))
3178 {
3179 build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3180 } else {
3181 build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3182 }
3183}
3184
3185pub fn add_survival_time_derivative_guard_offset(
3186 age_entry: &Array1<f64>,
3187 age_exit: &Array1<f64>,
3188 anchor_time: f64,
3189 derivative_guard: f64,
3190 eta_offset_entry: &mut Array1<f64>,
3191 eta_offset_exit: &mut Array1<f64>,
3192 derivative_offset_exit: &mut Array1<f64>,
3193) -> Result<(), String> {
3194 if derivative_guard <= 0.0 {
3195 return Ok(());
3196 }
3197 let n = age_entry.len();
3198 if age_exit.len() != n
3199 || eta_offset_entry.len() != n
3200 || eta_offset_exit.len() != n
3201 || derivative_offset_exit.len() != n
3202 {
3203 return Err(SurvivalConstructionError::IncompatibleDimensions {
3204 reason: "survival derivative-guard offset lengths must match".to_string(),
3205 }
3206 .into());
3207 }
3208 for i in 0..n {
3209 eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3210 eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3211 derivative_offset_exit[i] += derivative_guard;
3212 }
3213 Ok(())
3214}
3215
/// Per-row baseline offsets for the latent survival likelihood, as built by
/// `build_latent_survival_baseline_offsets`.
///
/// With full hazard loading:
/// * the `loaded_*` channels carry the whole baseline on the
///   log-cumulative-hazard scale;
/// * the `unloaded_*` channels are zero.
///
/// With loaded-vs-unloaded splitting (Gompertz-Makeham only):
/// * the `loaded_*` channels carry the Gompertz part;
/// * the `unloaded_*` channels carry the Makeham term on the natural hazard scale.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    // Log cumulative hazard of the loaded component at the entry age.
    pub loaded_eta_entry: Array1<f64>,
    // Log cumulative hazard of the loaded component at the exit age.
    pub loaded_eta_exit: Array1<f64>,
    // Time derivative of the loaded log cumulative hazard at exit (h / H).
    pub loaded_derivative_exit: Array1<f64>,
    // Unloaded cumulative hazard at entry (`makeham * entry`, or 0 with full loading).
    pub unloaded_mass_entry: Array1<f64>,
    // Unloaded cumulative hazard at exit (`makeham * exit`, or 0 with full loading).
    pub unloaded_mass_exit: Array1<f64>,
    // Unloaded instantaneous hazard at exit (`makeham`, or 0 with full loading).
    pub unloaded_hazard_exit: Array1<f64>,
}
3225
3226pub fn build_latent_survival_baseline_offsets(
3227 age_entry: &Array1<f64>,
3228 age_exit: &Array1<f64>,
3229 cfg: &SurvivalBaselineConfig,
3230 loading: HazardLoading,
3231) -> Result<LatentSurvivalBaselineOffsets, String> {
3232 if age_entry.len() != age_exit.len() {
3233 return Err(
3234 "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3235 );
3236 }
3237
3238 fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3239 if shape.abs() < 1e-10 {
3240 let x = shape * age;
3247 return (
3248 rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3249 rate * (1.0 + x + 0.5 * x * x),
3250 );
3251 }
3252 let shape_age = shape * age;
3253 let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3254 let instant_hazard = rate * shape_age.exp();
3255 (cumulative_hazard, instant_hazard)
3256 }
3257
3258 let n = age_entry.len();
3259
3260 let rows: Vec<[f64; 6]> = (0..n)
3263 .into_par_iter()
3264 .map(|i| -> Result<[f64; 6], String> {
3265 let entry = age_entry[i];
3266 let exit = age_exit[i];
3267 if !entry.is_finite()
3268 || !exit.is_finite()
3269 || entry <= 0.0
3270 || exit <= 0.0
3271 || exit < entry
3272 {
3273 return Err(format!(
3274 "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3275 i + 1
3276 ));
3277 }
3278 match loading {
3279 HazardLoading::Full => {
3280 let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3281 let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3282 Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3283 }
3284 HazardLoading::LoadedVsUnloaded => {
3285 if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3286 return Err(format!(
3287 "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3288 survival_baseline_targetname(cfg.target)
3289 ));
3290 }
3291 let rate = cfg.rate.ok_or_else(|| {
3292 "gompertz-makeham latent survival is missing baseline rate".to_string()
3293 })?;
3294 let shape = cfg.shape.ok_or_else(|| {
3295 "gompertz-makeham latent survival is missing baseline shape".to_string()
3296 })?;
3297 let makeham = cfg.makeham.ok_or_else(|| {
3298 "gompertz-makeham latent survival is missing baseline makeham".to_string()
3299 })?;
3300 let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3301 let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3302 if !(loaded_entry.is_finite()
3303 && loaded_entry > 0.0
3304 && loaded_exit.is_finite()
3305 && loaded_exit > 0.0
3306 && loaded_hazard.is_finite()
3307 && loaded_hazard > 0.0)
3308 {
3309 return Err(format!(
3310 "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3311 i + 1
3312 ));
3313 }
3314 Ok([
3315 loaded_entry.ln(),
3316 loaded_exit.ln(),
3317 loaded_hazard / loaded_exit,
3318 makeham * entry,
3319 makeham * exit,
3320 makeham,
3321 ])
3322 }
3323 }
3324 })
3325 .collect::<Result<Vec<_>, String>>()?;
3326
3327 let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3328 let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3329 let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3330 let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3331 let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3332 let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3333 for (i, row) in rows.into_iter().enumerate() {
3334 loaded_eta_entry[i] = row[0];
3335 loaded_eta_exit[i] = row[1];
3336 loaded_derivative_exit[i] = row[2];
3337 unloaded_mass_entry[i] = row[3];
3338 unloaded_mass_exit[i] = row[4];
3339 unloaded_hazard_exit[i] = row[5];
3340 }
3341
3342 Ok(LatentSurvivalBaselineOffsets {
3343 loaded_eta_entry,
3344 loaded_eta_exit,
3345 loaded_derivative_exit,
3346 unloaded_mass_entry,
3347 unloaded_mass_exit,
3348 unloaded_hazard_exit,
3349 })
3350}
3351
3352pub fn build_survival_timewiggle_derivative_design(
3357 eta_exit: &Array1<f64>,
3358 derivative_exit: &Array1<f64>,
3359 knots: &Array1<f64>,
3360 degree: usize,
3361) -> Result<Array2<f64>, String> {
3362 let mut design_derivative_exit =
3363 monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3364 for i in 0..design_derivative_exit.nrows() {
3365 let chain = derivative_exit[i];
3366 for j in 0..design_derivative_exit.ncols() {
3367 design_derivative_exit[[i, j]] *= chain;
3368 }
3369 }
3370 Ok(design_derivative_exit)
3371}
3372
3373pub fn build_survival_timewiggle_from_baseline(
3383 eta_entry: &Array1<f64>,
3384 eta_exit: &Array1<f64>,
3385 derivative_exit: &Array1<f64>,
3386 cfg: &LinkWiggleFormulaSpec,
3387) -> Result<SurvivalTimeWiggleBuild, String> {
3388 if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3389 return Err(
3390 "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3391 );
3392 }
3393 let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3396 && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3397 && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3398 if all_zero {
3399 return Err(
3400 "timewiggle requires a non-linear scalar survival baseline target; \
3401 the provided baseline offsets are all zero (linear baseline)"
3402 .to_string(),
3403 );
3404 }
3405 let n = eta_exit.len();
3406 let mut seed = Array1::<f64>::zeros(2 * n);
3407 for i in 0..n {
3408 seed[i] = eta_entry[i];
3409 seed[n + i] = eta_exit[i];
3410 }
3411 let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3415 let wiggle_cfg = WiggleBlockConfig {
3416 degree: cfg.degree,
3417 num_internal_knots: cfg.num_internal_knots,
3418 penalty_order: primary_order,
3419 double_penalty: cfg.double_penalty,
3420 };
3421 let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3422 append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3423 let ncols = combined_block.design.ncols();
3424 Ok(SurvivalTimeWiggleBuild {
3425 nullspace_dims: combined_block.nullspace_dims.clone(),
3426 penalties: {
3427 combined_block
3428 .penalties
3429 .into_iter()
3430 .map(|ps| ps.to_global(ncols))
3431 .collect()
3432 },
3433 knots,
3434 degree: cfg.degree,
3435 ncols,
3436 })
3437}
3438
3439pub fn append_zero_tail_columns(
3440 x_entry: &mut DesignMatrix,
3441 x_exit: &mut DesignMatrix,
3442 x_derivative: &mut DesignMatrix,
3443 tail_cols: usize,
3444) {
3445 if tail_cols == 0 {
3446 return;
3447 }
3448 fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3451 let old = dm.to_dense();
3452 let n = old.nrows();
3453 let p_base = old.ncols();
3454 let mut out = Array2::<f64>::zeros((n, p_base + tail));
3455 out.slice_mut(s![.., 0..p_base]).assign(&old);
3456 *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3457 }
3458 append_dense(x_entry, tail_cols);
3459 append_dense(x_exit, tail_cols);
3460 append_dense(x_derivative, tail_cols);
3461}
3462
3463pub fn build_time_varying_survival_covariate_template(
3474 age_entry: &Array1<f64>,
3475 age_exit: &Array1<f64>,
3476 time_k: usize,
3477 time_degree: usize,
3478 block_name: &str,
3479) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3480 if time_k < time_degree + 1 {
3481 return Err(format!(
3482 "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3483 time_degree + 1
3484 ));
3485 }
3486 let num_internal_knots = time_k - (time_degree + 1);
3487
3488 let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3489 let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3490
3491 let time_spec = BSplineBasisSpec {
3492 degree: time_degree,
3493 penalty_order: 2,
3494 knotspec: BSplineKnotSpec::Automatic {
3495 num_internal_knots: Some(num_internal_knots),
3496 placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3497 },
3498 double_penalty: false,
3499 identifiability: BSplineIdentifiability::None,
3500 boundary: OneDimensionalBoundary::Open,
3501 boundary_conditions: BSplineBoundaryConditions::default(),
3502 };
3503
3504 let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3505 .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3506 let time_design_exit = time_build.design.to_dense();
3507
3508 let knots = match &time_build.metadata {
3509 BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3510 _ => {
3511 return Err(format!(
3512 "{block_name} time-margin basis returned unexpected metadata type"
3513 ));
3514 }
3515 };
3516
3517 let time_build_entry = build_bspline_basis_1d(
3518 log_entry.view(),
3519 &BSplineBasisSpec {
3520 degree: time_degree,
3521 penalty_order: 2,
3522 knotspec: BSplineKnotSpec::Provided(knots.clone()),
3523 double_penalty: false,
3524 identifiability: BSplineIdentifiability::None,
3525 boundary: OneDimensionalBoundary::Open,
3526 boundary_conditions: BSplineBoundaryConditions::default(),
3527 },
3528 )
3529 .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3530 let time_design_entry = time_build_entry.design.to_dense();
3531 let p_time = time_design_exit.ncols();
3532 let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3533 time_design_derivative_exit
3537 .as_slice_mut()
3538 .expect("zeros are contiguous")
3539 .par_chunks_mut(p_time)
3540 .enumerate()
3541 .try_for_each(|(i, row_out)| -> Result<(), String> {
3542 let mut deriv_buf = vec![0.0_f64; p_time];
3543 evaluate_bspline_derivative_scalar(
3544 log_exit[i],
3545 knots.view(),
3546 time_degree,
3547 &mut deriv_buf,
3548 )
3549 .map_err(|e| {
3550 format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3551 })?;
3552 let chain = 1.0 / age_exit[i].max(1e-12);
3553 for j in 0..p_time {
3554 row_out[j] = deriv_buf[j] * chain;
3555 }
3556 Ok(())
3557 })?;
3558
3559 Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3560 time_basis_entry: time_design_entry,
3561 time_basis_exit: time_design_exit,
3562 time_basis_derivative_exit: time_design_derivative_exit,
3563 time_penalties: time_build.penalties,
3564 })
3565}
3566
3567#[cfg(test)]
3568mod tests {
3569 use super::{
3570 SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3571 baseline_chain_rule_gradient, baseline_offset_theta_partials,
3572 build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3573 build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3574 evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3575 gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3576 marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3577 marginal_slope_baseline_offset_theta_partials,
3578 optimize_survival_baseline_config_with_gradient,
3579 optimize_survival_baseline_config_with_gradient_only,
3580 resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3581 survival_baseline_theta_from_config,
3582 };
3583 use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3584 use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3585 use crate::probability::normal_cdf;
3586 use ndarray::{Array1, Array2, array};
3587
3588 #[test]
3589 fn survival_timewiggle_keeps_requested_order_one_penalty() {
3590 let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3591 let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3592 let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3593 let cfg = LinkWiggleFormulaSpec {
3594 degree: 3,
3595 num_internal_knots: 4,
3596 penalty_orders: vec![1, 2, 3],
3597 double_penalty: false,
3598 };
3599
3600 let build =
3601 build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3602 .expect("build survival timewiggle");
3603
3604 assert_eq!(build.penalties.len(), 3);
3605 assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3606 assert!(build.ncols > 0);
3607 }
3608
3609 #[test]
3610 fn marginal_slope_time_anchor_defaults_to_median_exit() {
3611 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3612 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3613 let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3614 .expect("resolve marginal-slope default time anchor");
3615
3616 assert!(
3617 (anchor - 19.0).abs() <= 1e-12,
3618 "marginal-slope default anchor should be median exit, got {anchor}"
3619 );
3620 }
3621
3622 #[test]
3623 fn marginal_slope_time_anchor_honors_explicit_value() {
3624 let age_entry = array![9.0, 1.0, 4.0, 6.0];
3625 let age_exit = array![20.0, 12.0, 18.0, 30.0];
3626 let anchor =
3627 resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3628 .expect("resolve explicit marginal-slope time anchor");
3629
3630 assert!(
3631 (anchor - 7.5).abs() <= 1e-12,
3632 "explicit marginal-slope anchor should round-trip, got {anchor}"
3633 );
3634 }
3635
3636 #[test]
3647 fn baseline_optimizer_contracts_agree_on_shared_surface() {
3648 let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3653 let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3654
3655 let initial = SurvivalBaselineConfig {
3658 target: SurvivalBaselineTarget::Weibull,
3659 scale: Some(1.0),
3660 shape: Some(1.0),
3661 rate: None,
3662 makeham: None,
3663 };
3664
3665 let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3668 survival_baseline_theta_from_config(cfg)
3669 .expect("config→θ")
3670 .expect("Weibull config has a θ")
3671 };
3672
3673 let curvature_cost = curvature.clone();
3676 let star_cost = theta_star.clone();
3677 let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3678 let theta = survival_baseline_theta_from_config(cfg)?
3679 .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3680 let d = &theta - &star_cost;
3681 let ad = curvature_cost.dot(&d);
3682 Ok(0.5 * d.dot(&ad))
3683 };
3684
3685 let curvature_grad = curvature.clone();
3686 let star_grad = theta_star.clone();
3687 let cost_for_grad = cost_at.clone();
3688 let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3689 &initial,
3690 "baseline parity (gradient-only)",
3691 move |cfg| {
3692 let cost = cost_for_grad(cfg)?;
3693 let theta = survival_baseline_theta_from_config(cfg)?
3694 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3695 let gradient = curvature_grad.dot(&(&theta - &star_grad));
3696 Ok((cost, gradient))
3697 },
3698 )
3699 .expect("gradient-only baseline optimization converges");
3700
3701 let curvature_hess = curvature.clone();
3702 let star_hess = theta_star.clone();
3703 let cost_for_hess = cost_at.clone();
3704 let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3705 &initial,
3706 "baseline parity (gradient+Hessian)",
3707 move |cfg| {
3708 let cost = cost_for_hess(cfg)?;
3709 let theta = survival_baseline_theta_from_config(cfg)?
3710 .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3711 let gradient = curvature_hess.dot(&(&theta - &star_hess));
3712 Ok((cost, gradient, curvature_hess.clone()))
3713 },
3714 )
3715 .expect("gradient+Hessian baseline optimization converges");
3716
3717 let theta_grad_only = recovered_theta(&result_grad_only);
3718 let theta_grad_hess = recovered_theta(&result_grad_hess);
3719
3720 for (label, theta) in [
3723 ("gradient-only", &theta_grad_only),
3724 ("gradient+Hessian", &theta_grad_hess),
3725 ] {
3726 let err = (theta - &theta_star)
3727 .mapv(f64::abs)
3728 .fold(0.0_f64, |a, &v| a.max(v));
3729 assert!(
3730 err <= 2e-3,
3731 "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3732 );
3733 }
3734
3735 let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3739 (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3740 };
3741 assert!(
3742 pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3743 "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3744 );
3745 }
3746
#[test]
fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
    // Common entry age with staggered exits, so automatic placement has a
    // spread of times to choose the single interior knot from.
    let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
    let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
    let requested_degree = 3;
    let num_internal_knots = 1;

    let config = SurvivalTimeBasisConfig::ISpline {
        degree: requested_degree,
        knots: Array1::zeros(0),
        keep_cols: Vec::new(),
        smooth_lambda: 1e-2,
    };
    let built = build_survival_time_basis(
        &age_entry,
        &age_exit,
        config,
        Some((num_internal_knots, 1e-2)),
    )
    .expect("automatic cubic ispline with one interior knot builds");

    // The knot vector is sized for a working B-spline one degree above the
    // requested I-spline degree: (working_degree + 1) clamped repeats per end.
    let working_degree = requested_degree + 1;
    let expected_knot_count = num_internal_knots + 2 * (working_degree + 1);
    let knots = built.knots.expect("resolved ispline knots");
    assert_eq!(
        knots.len(),
        expected_knot_count,
        "I-spline automatic knots must be clamped for the working B-spline degree"
    );
    assert_eq!(built.degree, Some(requested_degree));

    // Entry, exit and derivative designs must all share one non-empty column space.
    let exit_cols = built.x_exit_time.ncols();
    assert!(exit_cols > 0);
    assert_eq!(built.x_entry_time.ncols(), exit_cols);
    assert_eq!(built.x_derivative_time.ncols(), exit_cols);
}
3779
#[test]
fn ispline_time_derivative_is_nonzero_at_right_boundary() {
    // Every subject exits exactly at the right end of the log-time knot span.
    let age_entry = array![1.0_f64, 1.0, 1.0];
    let age_exit = array![4.0_f64, 4.0, 4.0];
    let (left, right) = (1.0_f64.ln(), 4.0_f64.ln());
    let mid = left + 0.5 * (right - left);
    // Clamped knots with four-fold boundary multiplicity and one interior knot.
    let knots = array![left, left, left, left, mid, right, right, right, right];

    let config = SurvivalTimeBasisConfig::ISpline {
        degree: 2,
        knots,
        keep_cols: Vec::new(),
        smooth_lambda: 1e-2,
    };
    let built = build_survival_time_basis(&age_entry, &age_exit, config, None)
        .expect("build right-boundary ispline time basis");

    let derivative = built.x_derivative_time.as_dense_cow();
    let max_abs = derivative
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs > 1e-8,
        "right-boundary I-spline derivative must use the left-hand endpoint slope"
    );
    for row in derivative.rows() {
        let has_positive_entry = row.iter().any(|&v| v > 1e-8);
        assert!(
            has_positive_entry,
            "each row at the right boundary needs a positive hazard derivative"
        );
    }
}
3814 }
3815
#[test]
fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.07),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    let age = 11.5;
    let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
        .expect("evaluate marginal-slope gompertz-makeham baseline");
    let shape = cfg.shape.expect("shape");
    let rate = cfg.rate.expect("rate");
    let makeham = cfg.makeham.expect("makeham");

    // Closed-form Gompertz–Makeham hazard:
    //   H(t) = makeham·t + (rate/shape)·(e^{shape·t} − 1),  h(t) = dH/dt.
    let growth = (shape * age).exp();
    let cumulative_hazard = makeham * age + (rate / shape) * (growth - 1.0);
    let instant_hazard = makeham + rate * growth;
    let expected_survival = (-cumulative_hazard).exp();

    // The probit index reproduces the survival curve: S(t) = Φ(−q(t)).
    let actual_survival = normal_cdf(-q);
    assert!(
        (actual_survival - expected_survival).abs() <= 1e-12,
        "Φ(−q) = {actual_survival:e} must equal exp(−H) = {expected_survival:e}"
    );

    // Central finite difference of q in age.
    let step = 1e-5;
    let q_plus = evaluate_survival_marginal_slope_baseline(age + step, &cfg)
        .expect("q plus")
        .0;
    let q_minus = evaluate_survival_marginal_slope_baseline(age - step, &cfg)
        .expect("q minus")
        .0;
    let fd = (q_plus - q_minus) / (2.0 * step);
    assert!(
        (q_derivative - fd).abs() <= 1e-7,
        "q' = {q_derivative:e} disagrees with central FD {fd:e}"
    );

    // Analytic slope: differentiating S = Φ(−q) gives −φ(q)·q' = dS/dt = −h·S,
    // so q' = h·S / φ(q). This ties the reported slope to the instantaneous
    // hazard. A positivity check on `instant_hazard` alone verified nothing.
    let std_normal_density = (-0.5 * q * q).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let analytic_derivative = instant_hazard * expected_survival / std_normal_density;
    assert!(
        (q_derivative - analytic_derivative).abs() <= 1e-8 * analytic_derivative.abs().max(1.0),
        "q' = {q_derivative:e} must equal h·S/φ(q) = {analytic_derivative:e}"
    );
}
3848
#[test]
fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
    // One representative configuration per baseline family.
    let linear = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Linear,
        scale: None,
        shape: None,
        rate: None,
        makeham: None,
    };
    let weibull = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Weibull,
        scale: Some(2.5),
        shape: Some(1.3),
        rate: None,
        makeham: None,
    };
    let gompertz = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.01),
        makeham: None,
    };
    let gompertz_makeham = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.07),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    let configs = [linear, weibull, gompertz, gompertz_makeham];

    for cfg in configs.iter() {
        // Probit index and its slope are pinned to exactly zero at t = 0.
        let (q_origin, q_origin_slope) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
            .expect("marginal-slope baseline must be evaluable at the origin");
        assert_eq!(q_origin, 0.0);
        assert_eq!(q_origin_slope, 0.0);

        // The log-cumulative-hazard baseline may be −∞ at the origin, but never NaN/+∞.
        let (eta_origin, eta_origin_slope) =
            evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
        assert!(eta_origin_slope.is_finite());
        assert!(eta_origin == f64::NEG_INFINITY || eta_origin.is_finite());

        // Both rows enter at the origin; row 0 also exits there.
        let age_entry = array![0.0, 0.0];
        let age_exit = array![0.0, 1.5];
        let (entry, exit, derivative) =
            build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
                .expect("probit baseline offsets must build through the origin");
        assert!(entry.iter().all(|v| v.is_finite()));
        assert!(exit.iter().all(|v| v.is_finite()));
        assert!(derivative.iter().all(|v| v.is_finite()));
        assert_eq!(exit[0], 0.0);
    }
}
3920
#[test]
fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.03),
        rate: Some(0.01),
        makeham: Some(0.002),
    };
    let age_entry = array![2.0, 4.0];
    let age_exit = array![5.0, 9.0];
    let (entry, exit, derivative) =
        build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
            .expect("marginal-slope baseline offsets");

    let makeham = cfg.makeham.expect("makeham");
    let rate = cfg.rate.expect("rate");
    let shape = cfg.shape.expect("shape");
    // Closed-form Gompertz–Makeham cumulative hazard H(t).
    let cumulative_hazard = |t: f64| makeham * t + (rate / shape) * ((shape * t).exp() - 1.0);

    for (i, (&t_entry, &t_exit)) in age_entry.iter().zip(age_exit.iter()).enumerate() {
        // Each offset q must satisfy Φ(−q) = exp(−H(t)) at its own age.
        let entry_survival = (-cumulative_hazard(t_entry)).exp();
        let exit_survival = (-cumulative_hazard(t_exit)).exp();
        assert!((normal_cdf(-entry[i]) - entry_survival).abs() <= 1e-12);
        assert!((normal_cdf(-exit[i]) - exit_survival).abs() <= 1e-12);
        assert!(derivative[i].is_finite() && derivative[i] > 0.0);
    }
}
3947
3948 fn fd_marginal_slope_baseline_offset(
3949 age: f64,
3950 cfg: &SurvivalBaselineConfig,
3951 steps: &[f64],
3952 ) -> Vec<(f64, f64)> {
3953 let theta = survival_baseline_theta_from_config(cfg)
3954 .expect("theta")
3955 .expect("non-linear baseline");
3956 assert_eq!(
3957 steps.len(),
3958 theta.len(),
3959 "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
3960 );
3961 (0..theta.len())
3962 .map(|k| {
3963 let h = steps[k];
3964 let mut theta_plus = theta.clone();
3965 theta_plus[k] += h;
3966 let mut theta_minus = theta.clone();
3967 theta_minus[k] -= h;
3968 let cfg_plus =
3969 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
3970 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
3971 .expect("minus cfg");
3972 let (q_p, qt_p) =
3973 evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
3974 let (q_m, qt_m) =
3975 evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
3976 ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
3977 })
3978 .collect()
3979 }
3980
#[test]
fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.04),
        rate: Some(0.013),
        makeham: Some(0.002),
    };
    let age = 17.0;
    // Uniform stencil across the three GM θ components.
    let steps = [1e-5; 3];

    let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
        .expect("partials")
        .expect("nonlinear");
    let fd = fd_marginal_slope_baseline_offset(age, &cfg, &steps);
    assert_eq!(analytic.len(), fd.len());
    for (k, (&(aq, aqt), &(fq, fqt))) in analytic.iter().zip(&fd).enumerate() {
        assert_close(aq, fq, 1e-6, &format!("gm-probit q theta[{k}]"));
        assert_close(aqt, fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
    }
}
4001
#[test]
fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
    // A vanishingly small Gompertz shape exercises the near-zero-shape path.
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(1e-14),
        rate: Some(0.013),
        makeham: Some(0.002),
    };
    let age = 17.0;
    // The shape coordinate (θ[1]) gets a tiny stencil comparable to the shape itself.
    let steps = [1e-5, 1e-11, 1e-5];

    let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
        .expect("partials")
        .expect("nonlinear");
    let fd = fd_marginal_slope_baseline_offset(age, &cfg, &steps);
    assert_eq!(analytic.len(), fd.len());
    for (k, (&(aq, aqt), &(fq, fqt))) in analytic.iter().zip(&fd).enumerate() {
        assert_close(aq, fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
        assert_close(aqt, fqt, 1e-5, &format!("near-zero gm-probit q' theta[{k}]"));
    }
}
4027
4028 fn shifted_quadratic_offset_residuals(
4029 age_entry: ndarray::ArrayView1<'_, f64>,
4030 age_exit: ndarray::ArrayView1<'_, f64>,
4031 base_cfg: &SurvivalBaselineConfig,
4032 candidate_cfg: &SurvivalBaselineConfig,
4033 base: &OffsetChannelResiduals,
4034 curvatures: &OffsetChannelCurvatures,
4035 ) -> OffsetChannelResiduals {
4036 let n = age_exit.len();
4037 let mut entry = base.entry.clone();
4038 let mut exit = base.exit.clone();
4039 let mut derivative = base.derivative.clone();
4040 for row in 0..n {
4041 let (_, base_exit, base_deriv) =
4042 baseline_marginal_slope_channels(age_exit[row], base_cfg);
4043 let (_, cand_exit, cand_deriv) =
4044 baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
4045 let base_entry = if base.entry[row] == 0.0 {
4046 0.0
4047 } else {
4048 baseline_marginal_slope_channels(age_entry[row], base_cfg).1
4049 };
4050 let cand_entry = if base.entry[row] == 0.0 {
4051 0.0
4052 } else {
4053 baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
4054 };
4055 let delta = [
4056 cand_entry - base_entry,
4057 cand_exit - base_exit,
4058 cand_deriv - base_deriv,
4059 ];
4060 let mut shift = [0.0; 3];
4061 for i in 0..3 {
4062 for j in 0..3 {
4063 shift[i] += curvatures.rows[row][i][j] * delta[j];
4064 }
4065 }
4066 if base.entry[row] != 0.0 {
4067 entry[row] += shift[0];
4068 }
4069 exit[row] += shift[1];
4070 derivative[row] += shift[2];
4071 }
4072 OffsetChannelResiduals {
4073 entry,
4074 exit,
4075 derivative,
4076 right: base.right.clone(),
4077 }
4078 }
4079
4080 fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4081 let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4082 (q, q, q_t)
4083 }
4084
#[test]
fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.025),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    let theta = survival_baseline_theta_from_config(&cfg)
        .expect("theta")
        .expect("nonlinear");

    // Row 1 enters at the origin and carries a zero entry residual.
    let age_entry = array![2.5, 0.0, 5.0];
    let age_exit = array![7.5, 11.0, 15.0];
    let base_residuals = OffsetChannelResiduals {
        entry: array![0.2, 0.0, -0.1],
        exit: array![0.6, -0.3, 0.4],
        derivative: array![-0.5, 0.25, 0.15],
        right: Array1::<f64>::zeros(3),
    };
    // Symmetric per-row curvatures over the (entry, exit, derivative) channels.
    let curvatures = OffsetChannelCurvatures {
        rows: vec![
            [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
            [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
            [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
        ],
    };
    let analytic = marginal_slope_baseline_chain_rule_hessian(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &base_residuals,
        &curvatures,
    )
    .expect("hessian")
    .expect("nonlinear");

    // Chain-rule gradient at a perturbed θ, with the residuals moved along
    // the local quadratic model. Its Jacobian is what the analytic Hessian claims to be.
    let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
        let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
            .expect("candidate cfg");
        let shifted = shifted_quadratic_offset_residuals(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &candidate,
            &base_residuals,
            &curvatures,
        );
        marginal_slope_baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            &candidate,
            &shifted,
        )
        .expect("gradient")
        .expect("nonlinear")
    };

    for j in 0..theta.len() {
        // The shape coordinate (θ[1]) uses a slightly wider stencil.
        let step = if j == 1 { 2e-5 } else { 1e-5 };
        let mut plus = theta.clone();
        plus[j] += step;
        let mut minus = theta.clone();
        minus[j] -= step;
        let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
        for i in 0..theta.len() {
            assert_close(
                analytic[[i, j]],
                fd_col[i],
                2e-5,
                &format!("baseline Hessian ({i},{j})"),
            );
        }
    }
}
4160
#[test]
fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.03),
        rate: Some(0.01),
        makeham: Some(0.002),
    };
    let age_entry = array![3.0, 6.0];
    let age_exit = array![8.0, 12.0];
    let residuals = OffsetChannelResiduals {
        entry: array![0.1, 0.4],
        exit: array![0.7, -0.2],
        derivative: array![1.3, -0.6],
        right: Array1::<f64>::zeros(2),
    };
    let grad = marginal_slope_baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("gradient")
    .expect("nonlinear");

    // Expected: g_k = Σ_i r_exit·∂q(exit)/∂θ_k + r_deriv·∂q'(exit)/∂θ_k + r_entry·∂q(entry)/∂θ_k.
    let mut expected = Array1::<f64>::zeros(3);
    for i in 0..age_exit.len() {
        let at_exit = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
            .expect("exit partials")
            .expect("nonlinear");
        let at_entry = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
            .expect("entry partials")
            .expect("nonlinear");
        for (k, slot) in expected.iter_mut().enumerate() {
            *slot += residuals.exit[i] * at_exit[k].0
                + residuals.derivative[i] * at_exit[k].1
                + residuals.entry[i] * at_entry[k].0;
        }
    }
    for k in 0..3 {
        assert_close(
            grad[k],
            expected[k],
            1e-12,
            &format!("gm-probit chain gradient theta[{k}]"),
        );
    }
}
4210
#[test]
fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.028),
        rate: Some(0.011),
        makeham: Some(0.0025),
    };
    // Row 1 enters at the origin with a zero entry residual, so the reference
    // must skip its entry-channel partials.
    let age_entry = array![3.0, 0.0, 5.5];
    let age_exit = array![8.0, 12.0, 16.0];
    let residuals = OffsetChannelResiduals {
        exit: array![0.7, -0.2, 0.45],
        entry: array![0.1, 0.0, -0.3],
        derivative: array![1.3, -0.6, 0.2],
        right: Array1::<f64>::zeros(3),
    };

    // Straight-line reference contraction, given a per-age partials function:
    //   g_k = Σ_i r_exit[i]·P_exit[k].0 + r_deriv[i]·P_exit[k].1 (+ r_entry[i]·P_entry[k].0
    //   when r_entry[i] ≠ 0).
    let reference_gradient = |partials: &dyn Fn(
        f64,
        &SurvivalBaselineConfig,
    ) -> Result<Option<Vec<(f64, f64)>>, String>|
     -> Array1<f64> {
        let partials_at =
            |age: f64, what: &str| partials(age, &cfg).expect(what).expect("nonlinear");
        let theta_dim = partials_at(age_exit[0], "probe partials").len();
        let mut acc = Array1::<f64>::zeros(theta_dim);
        for i in 0..age_exit.len() {
            let p_exit = partials_at(age_exit[i], "exit partials");
            let (r_x, r_d) = (residuals.exit[i], residuals.derivative[i]);
            for (k, slot) in acc.iter_mut().enumerate() {
                *slot += r_x * p_exit[k].0 + r_d * p_exit[k].1;
            }
            let r_e = residuals.entry[i];
            if r_e != 0.0 {
                let p_entry = partials_at(age_entry[i], "entry partials");
                for (k, slot) in acc.iter_mut().enumerate() {
                    *slot += r_e * p_entry[k].0;
                }
            }
        }
        acc
    };

    // Tolerance 0.0 below: each engine must match its reference exactly.
    let rp_engine = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("rp gradient")
    .expect("rp nonlinear");
    let rp_reference = reference_gradient(&baseline_offset_theta_partials);
    assert_eq!(rp_engine.len(), rp_reference.len());
    for (k, expected) in rp_reference.iter().enumerate() {
        assert_close(
            rp_engine[k],
            *expected,
            0.0,
            &format!("rp engine vs inline reference theta[{k}]"),
        );
    }

    let probit_engine = marginal_slope_baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("probit gradient")
    .expect("probit nonlinear");
    let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
    assert_eq!(probit_engine.len(), probit_reference.len());
    for (k, expected) in probit_reference.iter().enumerate() {
        assert_close(
            probit_engine[k],
            *expected,
            0.0,
            &format!("probit engine vs inline reference theta[{k}]"),
        );
    }
}
4316
#[test]
fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::GompertzMakeham,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.012),
        makeham: Some(0.003),
    };
    // Eight rows over a wide age range; row 3 has a zero entry residual.
    let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
    let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
    let residuals = OffsetChannelResiduals {
        exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
        entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
        derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
        right: Array1::<f64>::zeros(8),
    };

    let analytic = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("analytic gradient ok")
    .expect("GM baseline has a θ-gradient");
    assert_eq!(analytic.len(), 3, "GM θ has 3 components");

    // Linear functional L(θ) = Σ r_exit·η(exit) + r_deriv·o_D(exit) + r_entry·η(entry)
    // (entry term only where r_entry ≠ 0). The engine's gradient should equal ∇_θ L.
    let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
        (0..age_exit.len()).fold(0.0, |acc, i| {
            let (eta_exit_i, od_exit_i) =
                evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
            let acc =
                acc + (residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i);
            if residuals.entry[i] == 0.0 {
                acc
            } else {
                let (eta_entry_i, _) =
                    evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
                acc + residuals.entry[i] * eta_entry_i
            }
        })
    };

    let theta0 = survival_baseline_theta_from_config(&cfg)
        .expect("theta seed")
        .expect("GM has θ");
    let delta = 1e-4;
    let fd: Array1<f64> = (0..analytic.len())
        .map(|k| {
            let mut theta_plus = theta0.clone();
            theta_plus[k] += delta;
            let mut theta_minus = theta0.clone();
            theta_minus[k] -= delta;
            let cfg_plus =
                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
            let cfg_minus =
                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
            (loss_at_cfg(&cfg_plus) - loss_at_cfg(&cfg_minus)) / (2.0 * delta)
        })
        .collect();

    let analytic_norm = analytic.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
    let max_err = analytic
        .iter()
        .zip(fd.iter())
        .fold(0.0_f64, |m, (a, b)| m.max((a - b).abs()));
    let rel = max_err / (analytic_norm + 1e-12);
    eprintln!(
        "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
         analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
         analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
    );
    assert!(
        rel < 1e-2,
        "analytic θ-gradient disagrees with central FD beyond 1%: \
         analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
         rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
    );
}
4429
#[test]
fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Weibull,
        scale: Some(11.0),
        shape: Some(1.4),
        rate: None,
        makeham: None,
    };
    // Same row layout as the GM variant; row 3 has a zero entry residual.
    let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
    let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
    let residuals = OffsetChannelResiduals {
        exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
        entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
        derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
        right: Array1::<f64>::zeros(8),
    };

    let analytic = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("analytic gradient ok")
    .expect("Weibull baseline has a θ-gradient");
    assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");

    // Linear functional whose θ-gradient the engine should reproduce.
    let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
        (0..age_exit.len()).fold(0.0, |acc, i| {
            let (eta_exit_i, od_exit_i) =
                evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
            let acc =
                acc + (residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i);
            if residuals.entry[i] == 0.0 {
                acc
            } else {
                let (eta_entry_i, _) =
                    evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
                acc + residuals.entry[i] * eta_entry_i
            }
        })
    };

    let theta0 = survival_baseline_theta_from_config(&cfg)
        .expect("theta seed")
        .expect("Weibull has θ");
    let delta = 1e-4;
    let fd: Array1<f64> = (0..analytic.len())
        .map(|k| {
            let mut theta_plus = theta0.clone();
            theta_plus[k] += delta;
            let mut theta_minus = theta0.clone();
            theta_minus[k] -= delta;
            let cfg_plus =
                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
            let cfg_minus =
                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
            (loss_at_cfg(&cfg_plus) - loss_at_cfg(&cfg_minus)) / (2.0 * delta)
        })
        .collect();

    let analytic_norm = analytic.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
    let max_err = analytic
        .iter()
        .zip(fd.iter())
        .fold(0.0_f64, |m, (a, b)| m.max((a - b).abs()));
    let rel = max_err / (analytic_norm + 1e-12);
    eprintln!(
        "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
         analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
         analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
    );
    assert!(
        rel < 1e-2,
        "analytic θ-gradient disagrees with central FD beyond 1%: \
         analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
         rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
    );
}
4526
4527 fn fd_baseline_offset(
4540 age: f64,
4541 cfg: &SurvivalBaselineConfig,
4542 steps: &[f64],
4543 ) -> Vec<(f64, f64)> {
4544 let theta = survival_baseline_theta_from_config(cfg)
4545 .expect("theta")
4546 .expect("non-linear baseline");
4547 assert_eq!(
4548 steps.len(),
4549 theta.len(),
4550 "fd_baseline_offset: step vector length must match θ dimension"
4551 );
4552 (0..theta.len())
4553 .map(|k| {
4554 let h = steps[k];
4555 let mut theta_plus = theta.clone();
4556 theta_plus[k] += h;
4557 let mut theta_minus = theta.clone();
4558 theta_minus[k] -= h;
4559 let cfg_plus =
4560 survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4561 let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4562 .expect("minus cfg");
4563 let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4564 let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4565 ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4566 })
4567 .collect()
4568 }
4569
/// Asserts that `actual` agrees with `expected` to within `tol`.
///
/// The tolerance is absolute when `|expected| < 1` and relative to `|expected|`
/// otherwise. A NaN `actual` or `expected` always fails. `what` labels the
/// panic message.
fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
    let scale = expected.abs().max(1.0);
    let within = (actual - expected).abs() <= tol * scale;
    assert!(
        within,
        "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
    );
}
4584
#[test]
fn gompertz_offset_partials_match_central_diff() {
    // (rate, shape, age): ordinary shapes, tiny ± shapes straddling zero, and a
    // negative shape.
    let cases = [
        (0.5_f64, 0.01_f64, 30.0_f64),
        (0.2, 0.05, 60.0),
        (1.0, 0.001, 10.0),
        (0.4, 5e-11, 25.0),
        (0.4, -5e-11, 25.0),
        (0.3, -0.02, 40.0),
        (0.8, 0.2, 5.0),
    ];
    for &(rate, shape, age) in &cases {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Gompertz,
            scale: None,
            shape: Some(shape),
            rate: Some(rate),
            makeham: None,
        };
        let analytic = baseline_offset_theta_partials(age, &cfg)
            .expect("ok")
            .expect("non-linear");
        // Tiny shapes get a stencil comparable to the shape itself.
        let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
        let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
        assert_eq!(analytic.len(), 2);

        let context = format!("(rate={rate}, shape={shape}, age={age})");
        let checks = [
            (analytic[0].0, fd[0].0, 1e-7, "∂eta/∂log_rate"),
            (analytic[0].1, fd[0].1, 1e-7, "∂o_D/∂log_rate"),
            (analytic[1].0, fd[1].0, 1e-5, "∂eta/∂shape"),
            (analytic[1].1, fd[1].1, 1e-5, "∂o_D/∂shape"),
        ];
        for (value, reference, tol, quantity) in checks {
            assert_close(value, reference, tol, &format!("gompertz {quantity} {context}"));
        }
    }
}
4647
#[test]
fn gompertz_offset_partials_log_rate_channel_is_trivial() {
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.3),
        makeham: None,
    };
    let partials = baseline_offset_theta_partials(42.0, &cfg)
        .expect("ok")
        .expect("non-linear");
    // θ[0] (log rate) must shift η by exactly one unit and leave o_D untouched.
    let (d_eta_d_log_rate, d_od_d_log_rate) = partials[0];
    assert_eq!(d_eta_d_log_rate, 1.0);
    assert_eq!(d_od_d_log_rate, 0.0);
}
4666
#[test]
fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
    let age = 25.0;
    let rate = 0.4;
    let gompertz_cfg = |shape: f64| SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(shape),
        rate: Some(rate),
        makeham: None,
    };
    // Per the test name, 0.5e-10 is meant to hit the Taylor branch and 2e-10
    // the direct branch. Both must agree with the shape → 0 limits
    // 12.5 (= age / 2) and 0.5.
    let p_t = baseline_offset_theta_partials(age, &gompertz_cfg(0.5e-10))
        .expect("ok")
        .expect("nl");
    let p_d = baseline_offset_theta_partials(age, &gompertz_cfg(2.0e-10))
        .expect("ok")
        .expect("nl");
    assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
    assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
    assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
    assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
}
4704
#[test]
fn gompertz_hazard_shape_derivatives_match_central_diff() {
    // (age, rate, shape) test points.
    let cases = [
        (10.0_f64, 0.012_f64, 0.05_f64),
        (2.5, 0.5, 0.2),
        (15.0, 0.003, 0.01),
        (40.0, 0.3, 0.001),
    ];
    let h = 1e-6;
    let central = |plus: f64, minus: f64| (plus - minus) / (2.0 * h);
    for &(age, rate, shape) in &cases {
        // First shape derivatives vs a central difference of (H_G, h_G).
        let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
        let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
        let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
        assert_close(
            d_cum,
            central(cum_p, cum_m),
            1e-6,
            &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
        );
        assert_close(
            d_inst,
            central(inst_p, inst_m),
            1e-6,
            &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
        );

        // Second shape derivatives vs a central difference of the first ones.
        let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
        let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
        assert_close(
            d2_cum,
            central(dcum_p, dcum_m),
            1e-5,
            &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
        );
        assert_close(
            d2_inst,
            central(dinst_p, dinst_m),
            1e-5,
            &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
        );
    }
}
4765
#[test]
fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
    // (age, rate, shape) with shape·age small, so the shape → 0 limits
    //   ∂H/∂s → r·t²/2,  ∂h/∂s → r·t,  ∂²H/∂s² → r·t³/3,  ∂²h/∂s² → r·t²
    // hold to within the 1e-3 tolerance.
    let cases = [
        (25.0_f64, 0.4_f64, 1e-9_f64),
        (100.0, 0.4, 1e-6),
        (100.0, 0.012, 1e-6),
        (50.0, 1.2, 1e-8),
    ];
    for &(age, rate, shape) in &cases {
        let t = age;

        let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
        assert_close(
            d_cum,
            rate * t * t / 2.0,
            1e-3,
            &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
        );
        assert_close(
            d_inst,
            rate * t,
            1e-3,
            &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
        );

        let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        assert_close(
            d2_cum,
            rate * t * t * t / 3.0,
            1e-3,
            &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
        );
        assert_close(
            d2_inst,
            rate * t * t,
            1e-3,
            &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
        );
    }
}
4826
#[test]
fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
    let age = 100.0;
    let rate = 0.4;
    let t = age;
    // Small-shape limit of ∂²H_G/∂shape²: rate · t³ / 3.
    let truth = rate * t * t * t / 3.0;
    // Sweep shapes 1e-5 … 1e-12 across the old branch-pivot gap.
    for k in 5..=12_i32 {
        let shape = 10f64.powi(-k);
        let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        assert_close(
            d2_cum,
            truth,
            1e-3,
            &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
        );
    }
}
4855
#[test]
fn weibull_offset_partials_match_central_diff() {
    // (scale, shape, age) test points.
    let cases = [
        (0.5_f64, 1.2_f64, 25.0_f64),
        (2.0, 0.8, 60.0),
        (0.1, 3.0, 10.0),
    ];
    for &(scale, shape, age) in &cases {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Weibull,
            scale: Some(scale),
            shape: Some(shape),
            rate: None,
            makeham: None,
        };
        let analytic = baseline_offset_theta_partials(age, &cfg)
            .expect("ok")
            .expect("nl");
        let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
        assert_eq!(analytic.len(), 2);
        for (k, (&(a_eta, a_od), &(f_eta, f_od))) in analytic.iter().zip(&fd).enumerate() {
            assert_close(
                a_eta,
                f_eta,
                1e-7,
                &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
            );
            assert_close(
                a_od,
                f_od,
                1e-7,
                &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
            );
        }
        // θ[0] must not move the time-derivative channel at all.
        assert_eq!(analytic[0].1, 0.0);
    }
}
4894
#[test]
fn gompertz_makeham_offset_partials_match_central_diff() {
    // Each case is (rate, shape, makeham, age). The ±5e-11 shapes probe the shape → 0 limit
    // from both sides.
    let cases: [(f64, f64, f64, f64); 6] = [
        (0.3, 0.05, 0.002, 40.0),
        (0.5, 0.01, 0.01, 25.0),
        (0.2, 0.001, 0.005, 60.0),
        (0.4, 5e-11, 0.01, 25.0),
        (0.4, -5e-11, 0.01, 25.0),
        (0.8, 0.2, 0.05, 5.0),
    ];
    for (rate, shape, makeham, age) in cases.iter().copied() {
        let gm = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(shape),
            rate: Some(rate),
            makeham: Some(makeham),
        };
        let analytic = baseline_offset_theta_partials(age, &gm)
            .expect("ok")
            .expect("nl");
        // When |shape| < 1e-9, a 1e-5 step in the shape coordinate (index 1 of the step list)
        // would carry the finite-difference probes across shape = 0. A 1e-11 step keeps both
        // probes on the same side for the ±5e-11 cases above.
        let shape_step = if shape.abs() >= 1e-9 { 1e-5 } else { 1e-11 };
        let fd = fd_baseline_offset(age, &gm, &[1e-5, shape_step, 1e-5]);
        assert_eq!(analytic.len(), 3);
        for k in 0..3 {
            let eta_label = format!(
                "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
            );
            assert_close(analytic[k].0, fd[k].0, 1e-5, &eta_label);
            let d_label = format!(
                "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
            );
            assert_close(analytic[k].1, fd[k].1, 1e-5, &d_label);
        }
    }
}
4942
#[test]
fn linear_baseline_has_no_theta_partials() {
    // A linear baseline carries no parameters, so the partials call yields Ok(None).
    let linear = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Linear,
        scale: None,
        shape: None,
        rate: None,
        makeham: None,
    };
    let partials = baseline_offset_theta_partials(5.0, &linear).unwrap();
    assert!(partials.is_none());
}
4954
#[test]
fn baseline_offset_partials_reject_non_positive_ages() {
    let gompertz = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.01),
        rate: Some(0.5),
        makeham: None,
    };
    // Zero, negative and NaN ages must all be reported as errors rather than evaluated.
    for bad_age in [0.0_f64, -1.0, f64::NAN] {
        assert!(baseline_offset_theta_partials(bad_age, &gompertz).is_err());
    }
}
4968
#[test]
fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
    // With one row, the chain-rule gradient must equal
    //   r_exit · ∂η/∂θ(exit) + r_derivative · ∂o_D/∂θ(exit) + r_entry · ∂η/∂θ(entry)
    // built from the pointwise partials.
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.3),
        makeham: None,
    };
    let age_entry = array![10.0_f64];
    let age_exit = array![25.0_f64];
    let residuals = OffsetChannelResiduals {
        exit: array![0.7_f64],
        entry: array![-0.2_f64],
        derivative: array![-0.4_f64],
        right: Array1::<f64>::zeros(1),
    };
    let grad = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("ok")
    .expect("non-linear");
    let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
        .unwrap()
        .unwrap();
    let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
        .unwrap()
        .unwrap();
    // Without this, surplus gradient entries would go unchecked and a short gradient would
    // only surface as an index panic.
    assert_eq!(
        grad.len(),
        p_exit.len(),
        "gradient must have one entry per baseline θ coordinate"
    );
    // Read the weights back from the fixture so the expectation cannot drift from it.
    let (r_exit, r_deriv, r_entry) = (
        residuals.exit[0],
        residuals.derivative[0],
        residuals.entry[0],
    );
    for k in 0..p_exit.len() {
        let expected = r_exit * p_exit[k].0 + r_deriv * p_exit[k].1 + r_entry * p_entry[k].0;
        assert!(
            (grad[k] - expected).abs() < 1e-12,
            "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
            grad[k],
            expected
        );
    }
}
5017
#[test]
fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
    let gompertz = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.3),
        makeham: None,
    };
    // Row 0 enters at age 0 with a zero entry residual. Partials at age 0 are rejected
    // (see `baseline_offset_partials_reject_non_positive_ages`), so the gradient must not
    // request them for that row.
    let age_entry = array![0.0_f64, 5.0_f64];
    let age_exit = array![10.0_f64, 20.0_f64];
    let residuals = OffsetChannelResiduals {
        exit: array![0.5_f64, 0.3_f64],
        entry: array![0.0_f64, -0.1_f64],
        derivative: array![-0.2_f64, 0.0_f64],
        right: Array1::<f64>::zeros(2),
    };
    let grad = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &gompertz,
        &residuals,
    )
    .expect("must not fail on origin-entry row with r_entry=0")
    .expect("non-linear");
    assert_eq!(grad.len(), 2);

    let row0_exit = baseline_offset_theta_partials(10.0, &gompertz).unwrap().unwrap();
    let row1_exit = baseline_offset_theta_partials(20.0, &gompertz).unwrap().unwrap();
    let row1_entry = baseline_offset_theta_partials(5.0, &gompertz).unwrap().unwrap();
    for k in 0..2 {
        // Row 0 contributes exit and derivative channels. Row 1 contributes exit and entry
        // channels; its derivative residual is zero.
        let row0_part = 0.5 * row0_exit[k].0 + (-0.2) * row0_exit[k].1;
        let expected = row0_part + 0.3 * row1_exit[k].0 + (-0.1) * row1_entry[k].0;
        assert!(
            (grad[k] - expected).abs() < 1e-12,
            "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
            grad[k],
            expected
        );
    }
}
5065
#[test]
fn chain_rule_gradient_linear_target_returns_none() {
    // A linear baseline has no θ, so the chain-rule gradient is Ok(None).
    let linear = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Linear,
        scale: None,
        shape: None,
        rate: None,
        makeham: None,
    };
    let residuals = OffsetChannelResiduals {
        exit: array![0.1_f64],
        entry: Array1::<f64>::zeros(1),
        derivative: Array1::<f64>::zeros(1),
        right: Array1::<f64>::zeros(1),
    };
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let maybe_grad = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &linear,
        &residuals,
    )
    .expect("ok");
    assert!(maybe_grad.is_none());
}
5094
#[test]
fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
    // End-to-end check of the chain rule. Build residuals as the partials of a small
    // survival NLL, then compare the analytic θ-gradient against central differences of
    // that NLL. The NLL is re-evaluated with every linear predictor shifted by the change
    // in baseline offset induced by θ ± h.
    let cfg = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.03),
        rate: Some(0.25),
        makeham: None,
    };
    let age_entry = array![0.0_f64, 5.0, 8.0];
    let age_exit = array![4.0_f64, 12.0, 20.0];
    let weights = array![1.0_f64, 2.0, 0.5];
    let events = [1.0_f64, 1.0, 0.0];
    // Fixed predictor values at entry/exit, and the time-derivative channel s.
    // Row 0's entry value is a placeholder: that row enters at the origin and its entry
    // term never reaches the NLL.
    let eta_entry_vals = [-100.0_f64, 0.5, 0.8];
    let eta_exit_vals = [0.4_f64, 0.9, 1.3];
    let s_vals = [0.7_f64, 1.1, 1.5];
    let n_rows = age_exit.len();
    // Rows entering at the time origin have no entry term. Derive this from the data
    // instead of hard-coding a row index.
    let origin_entry: Vec<bool> = age_entry.iter().map(|&a| a == 0.0).collect();

    // Per-row NLL: w · (exp(η_exit) − [not origin]·exp(η_entry) − d · (η_exit + ln s)).
    // The offset-channel residuals are its partials w.r.t. η_exit, η_entry and s.
    let mut r_exit = Array1::<f64>::zeros(n_rows);
    let mut r_entry = Array1::<f64>::zeros(n_rows);
    let mut r_deriv = Array1::<f64>::zeros(n_rows);
    for i in 0..n_rows {
        let (w, d) = (weights[i], events[i]);
        r_exit[i] = w * (eta_exit_vals[i].exp() - d);
        r_entry[i] = if origin_entry[i] {
            0.0
        } else {
            -w * eta_entry_vals[i].exp()
        };
        r_deriv[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
    }
    // The residual arrays are not used again, so move them in rather than cloning.
    let residuals = OffsetChannelResiduals {
        exit: r_exit,
        entry: r_entry,
        derivative: r_deriv,
        right: Array1::<f64>::zeros(n_rows),
    };
    let grad = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &cfg,
        &residuals,
    )
    .expect("ok")
    .expect("non-linear");

    // Baseline offsets at the reference θ are invariant across NLL evaluations. Compute
    // them once instead of inside every call.
    let base_exit: Vec<_> = age_exit
        .iter()
        .map(|&a| evaluate_survival_baseline(a, &cfg).unwrap())
        .collect();
    let base_entry_eta: Vec<Option<f64>> = age_entry
        .iter()
        .zip(&origin_entry)
        .map(|(&a, &origin)| {
            if origin {
                None
            } else {
                Some(evaluate_survival_baseline(a, &cfg).unwrap().0)
            }
        })
        .collect();

    // NLL as a function of θ. Each predictor moves by (offset(θ) − offset(θ_ref)).
    let nll = |theta: &Array1<f64>| -> f64 {
        let cfg_p = survival_baseline_config_from_theta(cfg.target, theta).expect("cfg_p");
        let mut sum = 0.0_f64;
        for i in 0..n_rows {
            let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
            let eta_exit_new = eta_exit_vals[i] + (eta_x_p - base_exit[i].0);
            let s_new = s_vals[i] + (d_x_p - base_exit[i].1);
            let entry_cum = match base_entry_eta[i] {
                None => 0.0_f64,
                Some(base_eta_e) => {
                    let (eta_e_p, _) =
                        evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
                    (eta_entry_vals[i] + (eta_e_p - base_eta_e)).exp()
                }
            };
            sum += weights[i]
                * (eta_exit_new.exp() - entry_cum - events[i] * (eta_exit_new + s_new.ln()));
        }
        sum
    };

    let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
    assert_eq!(
        grad.len(),
        theta_base.len(),
        "gradient must have one entry per baseline θ coordinate"
    );
    let h = 1e-6;
    for k in 0..theta_base.len() {
        let mut theta_plus = theta_base.clone();
        let mut theta_minus = theta_base.clone();
        theta_plus[k] += h;
        theta_minus[k] -= h;
        let fd = (nll(&theta_plus) - nll(&theta_minus)) / (2.0 * h);
        assert!(
            (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
            "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
            grad[k],
            fd
        );
    }
}
5217
#[test]
fn chain_rule_gradient_rejects_length_mismatch() {
    let gompertz = SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Gompertz,
        scale: None,
        shape: Some(0.05),
        rate: Some(0.3),
        makeham: None,
    };
    // Two entry ages against three exit ages and three residual rows.
    let age_entry = array![1.0_f64, 2.0];
    let age_exit = array![5.0_f64, 6.0, 7.0];
    let residuals = OffsetChannelResiduals {
        exit: array![0.1_f64, 0.2, 0.3],
        entry: Array1::<f64>::zeros(3),
        derivative: Array1::<f64>::zeros(3),
        right: Array1::<f64>::zeros(3),
    };
    let outcome = baseline_chain_rule_gradient(
        age_entry.view(),
        age_exit.view(),
        age_exit.view(),
        &gompertz,
        &residuals,
    );
    let err = outcome.expect_err("length mismatch must error");
    assert!(err.contains("length mismatch"), "err={err}");
}
5246}