1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3 BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4 projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9 LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
/// Errors produced by the survival engine, its input validation, and the
/// cause-specific Royston–Parmar family.
///
/// Variants that carry a `reason` display that text verbatim.
#[derive(Debug, Error)]
pub enum SurvivalError {
    /// Input arrays disagree in length or shape.
    #[error("input dimensions are inconsistent")]
    DimensionMismatch,
    /// An input contains NaN or an infinity.
    #[error("inputs contain non-finite values")]
    NonFiniteInput,
    /// The requested survival spec (named by the payload) is not handled by the
    /// one-hazard engine.
    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
    UnsupportedSpec(&'static str),
    /// The crude-risk integration was configured inconsistently.
    #[error("crude risk integration setup is invalid")]
    InvalidIntegrationSetup,
    /// A time grid was non-finite, negative, or not strictly increasing.
    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
    InvalidTimeGrid,
    /// A cumulative hazard sequence decreased somewhere.
    #[error("cumulative hazard must be nondecreasing")]
    NonMonotoneCumulativeHazard,
    /// An instantaneous hazard reached zero or went negative during integration.
    #[error("instantaneous hazard must stay strictly positive during integration")]
    NonPositiveHazard,
    /// Generic invalid-input error with a free-form reason.
    #[error("{reason}")]
    InvalidInput { reason: String },
    /// Shape/length mismatch detected by the cause-specific family, including
    /// block-count mismatches converted through `From<BlockCountMismatch>`.
    #[error("{reason}")]
    CauseSpecificDimensionMismatch { reason: String },
    /// A numerical evaluation failed, e.g. a non-finite cumulative hazard or a
    /// non-positive derivative on an event row.
    #[error("{reason}")]
    NumericalFailure { reason: String },
    /// An event code is invalid: non-finite, negative, non-integer, out of
    /// range, non-contiguous, or not binary where a 0/1 indicator is required.
    #[error("{reason}")]
    EventCodeInvalid { reason: String },
    /// Event data is degenerate.
    /// NOTE(review): not raised in the visible code — confirm where it is used.
    #[error("{reason}")]
    EventDegenerate { reason: String },
    /// Wraps an error raised while handling a single cause-specific block.
    /// `CauseSpecificRoystonParmarFamily::new` fills `block` with a 1-based index.
    #[error("cause-specific survival block {block}: {source}")]
    CauseSpecificBlock {
        block: usize,
        #[source]
        source: Box<SurvivalError>,
    },
}
52
53impl From<SurvivalError> for String {
54 fn from(err: SurvivalError) -> Self {
55 err.to_string()
56 }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60 fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61 SurvivalError::CauseSpecificDimensionMismatch {
62 reason: err.message(),
63 }
64 }
65}
66
/// Which survival quantity the engine targets; `Net` is the default.
///
/// NOTE(review): `Crude` presumably refers to crude (competing-risk-aware)
/// survival, matching the "crude risk integration" error — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SurvivalSpec {
    #[default]
    Net,
    Crude,
}
73
/// Borrowed inputs for the one-hazard survival engine with a flat design.
///
/// This struct performs no validation itself.
/// NOTE(review): the per-row views are presumably aligned, one entry per
/// observation — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalEngineInputs<'a> {
    /// Entry age (time) per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age (time) per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-cause event indicator (presumably 0/1 — confirm).
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-cause event indicator (presumably 0/1 — confirm).
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Design evaluated at entry.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design evaluated at exit.
    pub x_exit: ArrayView2<'a, f64>,
    /// Derivative design at exit (presumably the time derivative of `x_exit` — confirm).
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional monotonicity constraint rows, paired with the offsets below.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets for the monotonicity constraint rows.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
91
/// Borrowed inputs where the design is split into a time basis (entry, exit,
/// derivative) and a covariate matrix shared across those channels.
///
/// NOTE(review): presumably feeds `SurvivalDesign::TimeCovariateShared` — confirm.
#[derive(Debug, Clone)]
pub struct SurvivalTimeCovarInputs<'a> {
    /// Entry age (time) per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age (time) per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-cause event indicator (presumably 0/1 — confirm).
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-cause event indicator (presumably 0/1 — confirm).
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Time basis evaluated at entry.
    pub time_entry: ArrayView2<'a, f64>,
    /// Time basis evaluated at exit.
    pub time_exit: ArrayView2<'a, f64>,
    /// Derivative of the time basis at exit.
    pub time_derivative: ArrayView2<'a, f64>,
    /// Covariate design shared by the entry and exit channels.
    pub covariates: ArrayView2<'a, f64>,
    /// Optional monotonicity constraint rows, paired with the offsets below.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets for the monotonicity constraint rows.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
110
/// Fixed per-row offsets added to the linear predictors.
///
/// NOTE(review): presumably added to eta at entry, eta at exit, and the exit
/// derivative respectively, as the cause-specific block does with its
/// `offset_*` fields — confirm.
#[derive(Debug, Clone)]
pub struct SurvivalBaselineOffsets<'a> {
    pub eta_entry: ArrayView1<'a, f64>,
    pub eta_exit: ArrayView1<'a, f64>,
    pub derivative_exit: ArrayView1<'a, f64>,
}
126
/// A single quadratic penalty `0.5 * lambda * b' S b` acting on the coefficient
/// slice `beta[range]` (see `PenaltyBlocks`).
#[derive(Debug, Clone)]
pub struct PenaltyBlock {
    /// Penalty matrix `S`; must be square with side `range.len()`.
    pub matrix: Array2<f64>,
    /// Smoothing parameter; blocks with `lambda == 0.0` are skipped entirely.
    pub lambda: f64,
    /// Coefficient indices the penalty applies to.
    pub range: Range<usize>,
    /// Dimension of the penalty null space (not used by `PenaltyBlocks` itself).
    pub nullspace_dim: usize,
}
136
/// Collection of block-diagonal quadratic penalties on the coefficient vector.
#[derive(Debug, Clone)]
pub struct PenaltyBlocks {
    pub blocks: Vec<PenaltyBlock>,
}
141
142impl PenaltyBlocks {
143 pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144 Self { blocks }
145 }
146
147 pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148 let mut grad = Array1::zeros(beta.len());
149 for block in &self.blocks {
150 if block.lambda == 0.0 {
151 continue;
152 }
153 let b = beta.slice(ndarray::s![block.range.clone()]);
154 let g = block.matrix.dot(&b);
155 let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156 dst += &(block.lambda * g);
157 }
158 grad
159 }
160
161 pub fn hessian(&self, dim: usize) -> Array2<f64> {
162 let mut h = Array2::zeros((dim, dim));
163 self.addhessian_inplace(&mut h);
164 h
165 }
166
167 pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168 let mut value = 0.0;
169 for block in &self.blocks {
170 if block.lambda == 0.0 {
171 continue;
172 }
173 let b = beta.slice(ndarray::s![block.range.clone()]);
174 value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175 }
176 value
177 }
178
179 pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180 for block in &self.blocks {
181 if block.lambda == 0.0 {
182 continue;
183 }
184 let start = block.range.start;
185 let end = block.range.end;
186 h.slice_mut(ndarray::s![start..end, start..end])
187 .scaled_add(block.lambda, &block.matrix);
188 }
189 }
190}
191
/// Entry ages at or below this value count as "entry at the origin": the
/// cumulative hazard at entry is taken as zero instead of `exp(eta_entry)`.
pub const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1e-8;

/// Fraction of the remaining distance to the derivative floor that one step may
/// use in `max_feasible_step_size`, which keeps iterates strictly above the floor.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 0.995;
214
/// Data for one cause in a cause-specific Royston–Parmar model.
///
/// The cumulative hazard is `exp(eta)` with `eta = X beta + offset`. All per-row
/// arrays share the row count of `event_target` (checked by
/// `validate_cause_specific_block`).
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarBlock {
    /// Entry age per row; at or below `ENTRY_AT_ORIGIN_THRESHOLD` means no entry term.
    pub age_entry: Array1<f64>,
    /// Exit age per row; must be `>= age_entry` on positively weighted rows.
    pub age_exit: Array1<f64>,
    /// Binary 0/1 event indicator for this cause.
    pub event_target: Array1<u8>,
    /// Non-negative row weights; rows with weight `<= 0` are ignored.
    pub sampleweight: Array1<f64>,
    /// Design for eta at entry.
    pub x_entry: Array2<f64>,
    /// Design for eta at exit; its column count defines the coefficient count.
    pub x_exit: Array2<f64>,
    /// Design for the derivative channel; must be positive on event rows.
    pub x_derivative: Array2<f64>,
    /// Fixed offset added to eta at entry.
    pub offset_eta_entry: Array1<f64>,
    /// Fixed offset added to eta at exit.
    pub offset_eta_exit: Array1<f64>,
    /// Fixed offset added to the derivative channel.
    pub offset_derivative_exit: Array1<f64>,
    /// Non-negative lower bound imposed on the derivative channel.
    pub derivative_floor: f64,
}
229
/// Cause-specific survival family: one independent Royston–Parmar block per
/// cause. The joint log-likelihood is the sum of the per-block terms, and the
/// blocks are uncoupled.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarFamily {
    blocks: Vec<CauseSpecificRoystonParmarBlock>,
}
239
240impl CauseSpecificRoystonParmarFamily {
241 pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
242 if blocks.is_empty() {
243 return Err(SurvivalError::InvalidInput {
244 reason: "cause-specific survival family requires at least one endpoint".to_string(),
245 }
246 .into());
247 }
248 for (idx, block) in blocks.iter().enumerate() {
249 validate_cause_specific_block(block).map_err(|err| {
250 SurvivalError::CauseSpecificBlock {
251 block: idx + 1,
252 source: Box::new(err),
253 }
254 .to_string()
255 })?;
256 }
257 Ok(Self { blocks })
258 }
259
260 pub fn cause_count(&self) -> usize {
261 self.blocks.len()
262 }
263}
264
265fn validate_cause_specific_block(
266 block: &CauseSpecificRoystonParmarBlock,
267) -> Result<(), SurvivalError> {
268 let n = block.event_target.len();
269 let p = block.x_exit.ncols();
270 if n == 0 || p == 0 {
271 bail_invalid_surv!("empty event vector or coefficient block");
272 }
273 if block.age_entry.len() != n
274 || block.age_exit.len() != n
275 || block.sampleweight.len() != n
276 || block.x_entry.nrows() != n
277 || block.x_exit.nrows() != n
278 || block.x_derivative.nrows() != n
279 || block.x_entry.ncols() != p
280 || block.x_derivative.ncols() != p
281 || block.offset_eta_entry.len() != n
282 || block.offset_eta_exit.len() != n
283 || block.offset_derivative_exit.len() != n
284 {
285 return Err(SurvivalError::CauseSpecificDimensionMismatch {
286 reason: "dimension mismatch".to_string(),
287 });
288 }
289 if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
295 return Err(SurvivalError::EventCodeInvalid {
296 reason: format!(
297 "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
298 ),
299 });
300 }
301 if block.age_entry.iter().any(|v| !v.is_finite())
302 || block.age_exit.iter().any(|v| !v.is_finite())
303 || block
304 .sampleweight
305 .iter()
306 .any(|v| !v.is_finite() || *v < 0.0)
307 || block.x_entry.iter().any(|v| !v.is_finite())
308 || block.x_exit.iter().any(|v| !v.is_finite())
309 || block.x_derivative.iter().any(|v| !v.is_finite())
310 || block.offset_eta_entry.iter().any(|v| !v.is_finite())
311 || block.offset_eta_exit.iter().any(|v| !v.is_finite())
312 || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
313 || !block.derivative_floor.is_finite()
314 || block.derivative_floor < 0.0
315 {
316 bail_invalid_surv!("non-finite input");
317 }
318 Ok(())
319}
320
/// Evaluates one cause-specific Royston–Parmar block at `beta`.
///
/// With `eta = X beta + offset` for each design, every positively weighted row
/// contributes:
/// - `-w (H_exit - H_entry)`, where `H_exit = exp(eta_exit)` and `H_entry =
///   exp(eta_entry)`. `H_entry` is taken as 0 when the entry age is at the origin.
/// - On event rows, an extra `w (eta_exit + ln d)`, where `d` is the derivative channel.
///
/// Returns `(log_likelihood, gradient, hessian)`. `gradient` is the gradient of
/// the log-likelihood; `hessian` is the Hessian of the *negative*
/// log-likelihood (opposite sign conventions).
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if `beta.len() != x_exit.ncols()`;
/// - an early return via `bail_invalid_surv!` if `age_exit < age_entry` on a
///   positively weighted row;
/// - `NumericalFailure` for a non-finite cumulative hazard, or for a
///   non-finite / non-positive derivative on an event row.
fn evaluate_cause_specific_block(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
    let n = block.event_target.len();
    let p = block.x_exit.ncols();
    if beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let mut log_likelihood = 0.0;
    // Per-row weights for the design cross-products assembled after the loop.
    let mut w_exit = Array1::<f64>::zeros(n);
    let mut w_entry = Array1::<f64>::zeros(n);
    let mut w_event = Array1::<f64>::zeros(n);
    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
    let mut w_event_outer = Array1::<f64>::zeros(n);

    for i in 0..n {
        let weight = block.sampleweight[i];
        // Zero-weight rows are skipped, including the age-ordering check below.
        if weight <= 0.0 {
            continue;
        }
        if block.age_exit[i] < block.age_entry[i] {
            bail_invalid_surv!("age_exit < age_entry at row {i}");
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        let h_exit = eta_exit[i].exp();
        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
        if !(h_exit.is_finite() && h_entry.is_finite()) {
            return Err(SurvivalError::NumericalFailure {
                reason: format!("non-finite cumulative hazard at row {i}"),
            });
        }
        log_likelihood -= weight * (h_exit - h_entry);
        w_exit[i] = weight * h_exit;
        w_entry[i] = weight * h_entry;
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            // ln(deriv) below requires a strictly positive derivative on event rows.
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            log_likelihood += weight * (eta_exit[i] + deriv.ln());
            w_event[i] = weight;
            w_event_inv_deriv[i] = weight / deriv;
            w_event_outer[i] = weight / (deriv * deriv);
        }
    }

    // Gradient of the negative log-likelihood, then flipped to the log-likelihood gradient.
    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
    let gradient = -nll_gradient;

    // Hessian of the negative log-likelihood (the eta_exit event term is linear and drops out).
    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
    Ok((log_likelihood, gradient, hessian))
}
388
389impl CustomFamily for CauseSpecificRoystonParmarFamily {
390 fn joint_jeffreys_term_required(&self) -> bool {
394 true
395 }
396
397 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
398 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
399 "cause-specific survival",
400 self.blocks.len(),
401 block_states.len(),
402 )?;
403 let mut log_likelihood = 0.0;
404 let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
405 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
406 let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
407 log_likelihood += ll;
408 blockworking_sets.push(BlockWorkingSet::ExactNewton {
409 gradient,
410 hessian: SymmetricMatrix::Dense(hessian),
411 });
412 }
413 Ok(FamilyEvaluation {
414 log_likelihood,
415 blockworking_sets,
416 })
417 }
418
419 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
420 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
421 "cause-specific survival",
422 self.blocks.len(),
423 block_states.len(),
424 )?;
425 let mut log_likelihood = 0.0;
426 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
427 let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
428 log_likelihood += ll;
429 }
430 Ok(log_likelihood)
431 }
432
433 fn likelihood_blocks_uncoupled(&self) -> bool {
434 true
435 }
436
437 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
438 true
439 }
440
441 fn output_channel_assignment(
442 &self,
443 specs: &[crate::custom_family::ParameterBlockSpec],
444 ) -> Option<Vec<usize>> {
445 if specs.len() != self.blocks.len() {
446 return Some((0..self.blocks.len()).collect());
447 }
448 Some((0..specs.len()).collect())
449 }
450
451 fn coefficient_hessian_cost(
452 &self,
453 specs: &[crate::custom_family::ParameterBlockSpec],
454 ) -> u64 {
455 crate::custom_family::default_coefficient_hessian_cost(specs)
456 }
457
458 fn block_linear_constraints(
459 &self,
460 _: &[ParameterBlockState],
461 block_idx: usize,
462 spec: &crate::custom_family::ParameterBlockSpec,
463 ) -> Result<Option<LinearInequalityConstraints>, String> {
464 let block = self.blocks.get(block_idx).ok_or_else(|| {
465 SurvivalError::CauseSpecificDimensionMismatch {
466 reason: format!(
467 "cause-specific survival expected block index < {}, got {block_idx}",
468 self.blocks.len()
469 ),
470 }
471 .to_string()
472 })?;
473 if block.x_derivative.ncols() != spec.design.ncols() {
474 return Err(SurvivalError::CauseSpecificDimensionMismatch {
475 reason: format!(
476 "cause-specific survival derivative design has {} columns but block '{}' has {}",
477 block.x_derivative.ncols(),
478 spec.name,
479 spec.design.ncols()
480 ),
481 }
482 .into());
483 }
484 let rhs = block
485 .offset_derivative_exit
486 .mapv(|offset| block.derivative_floor - offset);
487 Ok(Some(LinearInequalityConstraints {
488 a: block.x_derivative.clone(),
489 b: rhs,
490 }))
491 }
492
493 fn max_feasible_step_size(
494 &self,
495 block_states: &[ParameterBlockState],
496 block_idx: usize,
497 delta: &Array1<f64>,
498 ) -> Result<Option<f64>, String> {
499 let block = self.blocks.get(block_idx).ok_or_else(|| {
500 SurvivalError::CauseSpecificDimensionMismatch {
501 reason: format!(
502 "cause-specific survival expected block index < {}, got {block_idx}",
503 self.blocks.len()
504 ),
505 }
506 .to_string()
507 })?;
508 let state = block_states.get(block_idx).ok_or_else(|| {
509 SurvivalError::CauseSpecificDimensionMismatch {
510 reason: format!(
511 "cause-specific survival expected {} block states, got {}",
512 self.blocks.len(),
513 block_states.len()
514 ),
515 }
516 .to_string()
517 })?;
518 if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
519 return Err(SurvivalError::CauseSpecificDimensionMismatch {
520 reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
521 }
522 .into());
523 }
524 let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
525 let derivative_delta = fast_av(&block.x_derivative, delta);
526 let mut alpha_max = 1.0_f64;
527 for i in 0..derivative.len() {
528 if block.sampleweight[i] <= 0.0 {
529 continue;
530 }
531 let current = derivative[i] - block.derivative_floor;
532 let slope = derivative_delta[i];
533 if slope < 0.0 {
534 if current <= 0.0 {
535 return Ok(Some(0.0));
536 }
537 alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
538 }
539 }
540 Ok(Some(alpha_max.clamp(0.0, 1.0)))
541 }
542
543 fn exact_newton_hessian_directional_derivative(
544 &self,
545 block_states: &[ParameterBlockState],
546 block_idx: usize,
547 d_beta: &Array1<f64>,
548 ) -> Result<Option<Array2<f64>>, String> {
549 let block = self.blocks.get(block_idx).ok_or_else(|| {
550 SurvivalError::CauseSpecificDimensionMismatch {
551 reason: format!(
552 "cause-specific survival expected block index < {}, got {block_idx}",
553 self.blocks.len()
554 ),
555 }
556 .to_string()
557 })?;
558 let state = block_states.get(block_idx).ok_or_else(|| {
559 SurvivalError::CauseSpecificDimensionMismatch {
560 reason: format!(
561 "cause-specific survival expected {} block states, got {}",
562 self.blocks.len(),
563 block_states.len()
564 ),
565 }
566 .to_string()
567 })?;
568 Ok(Some(cause_specific_hessian_directional_derivative(
569 block,
570 &state.beta,
571 d_beta,
572 )?))
573 }
574
575 fn exact_newton_hessian_second_directional_derivative(
576 &self,
577 block_states: &[ParameterBlockState],
578 block_idx: usize,
579 d_beta_u: &Array1<f64>,
580 d_beta_v: &Array1<f64>,
581 ) -> Result<Option<Array2<f64>>, String> {
582 let block = self.blocks.get(block_idx).ok_or_else(|| {
583 SurvivalError::CauseSpecificDimensionMismatch {
584 reason: format!(
585 "cause-specific survival expected block index < {}, got {block_idx}",
586 self.blocks.len()
587 ),
588 }
589 .to_string()
590 })?;
591 let state = block_states.get(block_idx).ok_or_else(|| {
592 SurvivalError::CauseSpecificDimensionMismatch {
593 reason: format!(
594 "cause-specific survival expected {} block states, got {}",
595 self.blocks.len(),
596 block_states.len()
597 ),
598 }
599 .to_string()
600 })?;
601 Ok(Some(cause_specific_hessian_second_directional_derivative(
602 block,
603 &state.beta,
604 d_beta_u,
605 d_beta_v,
606 )?))
607 }
608}
609
/// Directional derivative, along `d_beta`, of the negative-log-likelihood
/// Hessian produced by `evaluate_cause_specific_block`.
///
/// With `u = X d_beta` for each design, the per-row weights are:
/// - exit: `w exp(eta_exit) u_exit`;
/// - entry: `w exp(eta_entry) u_entry`, only past the origin, and subtracted;
/// - derivative: `-2 w u_deriv / d^3`, on event rows only.
///
/// They give `X_exit' W_exit X_exit - X_entry' W_entry X_entry + X_deriv' W_deriv X_deriv`.
///
/// # Errors
/// `CauseSpecificDimensionMismatch` when `beta` or `d_beta` differ in length from
/// `x_exit.ncols()`. `NumericalFailure` when an event row has a non-finite or
/// non-positive derivative. Unlike the evaluator, this function does not
/// re-check `age_exit >= age_entry` or the finiteness of `exp(eta)`.
fn cause_specific_hessian_directional_derivative(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
    d_beta: &Array1<f64>,
) -> Result<Array2<f64>, SurvivalError> {
    let p = block.x_exit.ncols();
    if beta.len() != p || d_beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    // Directional changes of each linear predictor (offsets are constant, so they drop out).
    let d_eta_entry = fast_av(&block.x_entry, d_beta);
    let d_eta_exit = fast_av(&block.x_exit, d_beta);
    let d_derivative = fast_av(&block.x_derivative, d_beta);
    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());

    for i in 0..block.event_target.len() {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
        if has_entry {
            w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
        }
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            // d/dbeta of w / d^2 along d_beta.
            w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
        }
    }

    let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
    Ok(d_hessian)
}
659
/// Second directional derivative, along `d_beta_u` and `d_beta_v`, of the
/// negative-log-likelihood Hessian produced by `evaluate_cause_specific_block`.
///
/// With `u = X d_beta_u` and `v = X d_beta_v`, the per-row weights are:
/// - exit: `w exp(eta_exit) u_exit v_exit`;
/// - entry: `w exp(eta_entry) u_entry v_entry`, only past the origin, and subtracted;
/// - derivative: `6 w u_deriv v_deriv / d^4`, on event rows only.
///
/// # Errors
/// `CauseSpecificDimensionMismatch` on any length mismatch with `x_exit.ncols()`.
/// `NumericalFailure` when an event row has a non-finite or non-positive derivative.
fn cause_specific_hessian_second_directional_derivative(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
    d_beta_u: &Array1<f64>,
    d_beta_v: &Array1<f64>,
) -> Result<Array2<f64>, SurvivalError> {
    let p = block.x_exit.ncols();
    if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: "cause-specific survival second Hessian derivative dimension mismatch"
                .to_string(),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
    let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
    let u_derivative = fast_av(&block.x_derivative, d_beta_u);
    let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
    let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
    let v_derivative = fast_av(&block.x_derivative, d_beta_v);
    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());

    for i in 0..block.event_target.len() {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
        if has_entry {
            w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
        }
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            // Second derivative of w / d^2: 6 w / d^4 times both directional changes.
            w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
        }
    }

    let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
    Ok(d2_hessian)
}
714
715pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
716 const INTEGER_TOL: f64 = 1e-8;
717 const MAX_AUTO_CAUSES: u8 = 32;
718 if !value.is_finite() {
719 return Err(SurvivalError::EventCodeInvalid {
720 reason: format!(
721 "survival event value at row {} is non-finite",
722 row_index + 1
723 ),
724 }
725 .into());
726 }
727 if value < 0.0 {
728 return Err(SurvivalError::EventCodeInvalid {
729 reason: format!(
730 "survival event value at row {} is negative: {value}",
731 row_index + 1
732 ),
733 }
734 .into());
735 }
736 let rounded = value.round();
737 if (value - rounded).abs() > INTEGER_TOL {
738 return Err(SurvivalError::EventCodeInvalid {
739 reason: format!(
740 "survival event value at row {} must be an integer code with 0=censored, got {value}",
741 row_index + 1
742 ),
743 }
744 .into());
745 }
746 if rounded > f64::from(MAX_AUTO_CAUSES) {
747 return Err(SurvivalError::EventCodeInvalid {
748 reason: format!(
749 "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
750 row_index + 1
751 ),
752 }
753 .into());
754 }
755 Ok(rounded as u8)
756}
757
758pub fn cause_count_from_event_codes(
759 event_codes: ArrayView1<'_, u8>,
760) -> Result<usize, SurvivalError> {
761 let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
762 if max_code == 0 {
763 return Ok(1);
764 }
765
766 let mut present = vec![false; max_code + 1];
767 for code in event_codes.iter().copied() {
768 present[usize::from(code)] = true;
769 }
770 if (1..=max_code).any(|code| !present[code]) {
771 let actual = present
772 .iter()
773 .enumerate()
774 .skip(1)
775 .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
776 .collect::<Vec<_>>()
777 .join(", ");
778 return Err(SurvivalError::EventCodeInvalid {
779 reason: format!(
780 "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
781 ),
782 });
783 }
784
785 Ok(max_code)
786}
787
788pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
801 event_codes.mapv(|label| u8::from(label > 0))
802}
803
804pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
814 let cause_code = cause as u8;
815 event_codes.mapv(|observed| u8::from(observed == cause_code))
816}
817
818fn compress_positive_collinear_constraints(
819 a: &Array2<f64>,
820 b: &Array1<f64>,
821) -> LinearInequalityConstraints {
822 const SCALE_TOL: f64 = 1e-14;
823 const KEY_TOL: f64 = 1e-8;
824
825 let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
826 let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
827
828 for i in 0..a.nrows() {
829 let row = a.row(i);
830 let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
831 if !scale.is_finite() || scale <= SCALE_TOL {
832 if b[i] > 0.0 {
833 fallbackrows.push((row.to_vec(), b[i]));
834 }
835 continue;
836 }
837
838 let normalizedrow: Vec<f64> = row
839 .iter()
840 .map(|&v| {
841 let scaled = v / scale;
842 if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
843 })
844 .collect();
845 let normalized_rhs = b[i] / scale;
846 let key: Vec<i64> = normalizedrow
847 .iter()
848 .map(|&v| (v / KEY_TOL).round() as i64)
849 .collect();
850
851 match grouped.get_mut(&key) {
852 Some((_, rhs_max)) => {
853 if normalized_rhs > *rhs_max {
854 *rhs_max = normalized_rhs;
855 }
856 }
857 None => {
858 grouped.insert(key, (normalizedrow, normalized_rhs));
859 }
860 }
861 }
862
863 let nrows = grouped.len() + fallbackrows.len();
864 let n_cols = a.ncols();
865 let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
866 let mut b_out = Array1::<f64>::zeros(nrows);
867
868 let mut outrow = 0usize;
869 for (_, (row, rhs)) in grouped {
870 for (j, value) in row.into_iter().enumerate() {
871 a_out[[outrow, j]] = value;
872 }
873 b_out[outrow] = rhs;
874 outrow += 1;
875 }
876 for (row, rhs) in fallbackrows {
877 for (j, value) in row.into_iter().enumerate() {
878 a_out[[outrow, j]] = value;
879 }
880 b_out[outrow] = rhs;
881 outrow += 1;
882 }
883
884 LinearInequalityConstraints { a: a_out, b: b_out }
885}
886
/// Monotonicity settings for the survival working model.
///
/// NOTE(review): `tolerance` is not used in the visible code; it presumably bounds
/// how far the derivative may dip before being treated as a violation — confirm.
#[derive(Debug, Clone, Copy, Default)]
pub struct SurvivalMonotonicityPenalty {
    pub tolerance: f64,
}
891
/// Design layout of the survival working model.
///
/// - `Flat`: full entry/exit/derivative design matrices over all coefficients.
/// - `TimeCovariateShared`: the coefficients are `[time | covariates]`. Entry and
///   exit rows are `[time_row | covariate_row]`; the derivative uses only the
///   time basis, so covariate columns are zero there.
#[derive(Debug, Clone)]
enum SurvivalDesign {
    Flat {
        x_entry: Array2<f64>,
        x_exit: Array2<f64>,
        x_derivative: Array2<f64>,
    },
    TimeCovariateShared {
        time_entry: Array2<f64>,
        time_exit: Array2<f64>,
        time_derivative: Array2<f64>,
        covariates: Array2<f64>,
    },
}
906
907impl SurvivalDesign {
908 fn p_total(&self) -> usize {
909 match self {
910 Self::Flat { x_exit, .. } => x_exit.ncols(),
911 Self::TimeCovariateShared {
912 time_exit,
913 covariates,
914 ..
915 } => time_exit.ncols() + covariates.ncols(),
916 }
917 }
918
919 fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
920 match self {
921 Self::Flat { .. } => time_mat.dot(beta),
922 Self::TimeCovariateShared { covariates, .. } => {
923 let p_time = time_mat.ncols();
924 let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
925 if covariates.ncols() > 0 {
926 out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
927 }
928 out
929 }
930 }
931 }
932
933 fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
934 match self {
935 Self::Flat { .. } => {
936 for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
937 *dst = src;
938 }
939 }
940 Self::TimeCovariateShared { covariates, .. } => {
941 let p_time = time_mat.ncols();
942 for j in 0..p_time {
943 out[j] = time_mat[[i, j]];
944 }
945 for j in 0..covariates.ncols() {
946 out[p_time + j] = covariates[[i, j]];
947 }
948 }
949 }
950 }
951}
952
/// Reusable per-row weight buffers (one entry per observation), re-zeroed by
/// `reset` so they can be reused across evaluations without reallocating.
///
/// NOTE(review): the field meanings presumably mirror the same-named weights in
/// `evaluate_cause_specific_block` — confirm against their users.
#[derive(Debug, Clone)]
struct SurvivalWorkspace {
    w_event: Array1<f64>,
    w_event_inv_deriv: Array1<f64>,
    w_event_outer: Array1<f64>,
    w_hess_exit: Array1<f64>,
    w_hess_entry: Array1<f64>,
}
962
963impl SurvivalWorkspace {
964 fn new(n: usize) -> Self {
965 Self {
966 w_event: Array1::zeros(n),
967 w_event_inv_deriv: Array1::zeros(n),
968 w_event_outer: Array1::zeros(n),
969 w_hess_exit: Array1::zeros(n),
970 w_hess_entry: Array1::zeros(n),
971 }
972 }
973
974 fn reset(&mut self, n: usize) {
975 if self.w_event.len() != n {
976 *self = Self::new(n);
977 } else {
978 self.w_event.fill(0.0);
979 self.w_event_inv_deriv.fill(0.0);
980 self.w_event_outer.fill(0.0);
981 self.w_hess_exit.fill(0.0);
982 self.w_hess_entry.fill(0.0);
983 }
984 }
985}
986
/// Per-row residual vectors, one per offset channel.
///
/// NOTE(review): `exit`, `entry`, and `derivative` presumably correspond to the
/// exit-eta, entry-eta, and derivative offsets; `right`'s meaning is not visible
/// here — confirm.
#[derive(Clone, Debug)]
pub struct OffsetChannelResiduals {
    pub exit: Array1<f64>,
    pub entry: Array1<f64>,
    pub derivative: Array1<f64>,
    pub right: Array1<f64>,
}
1012
/// One 3x3 curvature matrix per row.
///
/// NOTE(review): the axes are presumably the (exit, entry, derivative) offset
/// channels — confirm the ordering with the producer.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    pub rows: Vec<[[f64; 3]; 3]>,
}
1019
/// Penalized one-hazard survival working model used by the PIRLS solver.
///
/// Holds per-row data, the design layout, fixed offsets, penalties, and
/// monotonicity settings. Scratch weight buffers sit behind a `Mutex`, so they
/// can be reused from `&self` methods; `Clone` is implemented by hand because
/// `Mutex` is not `Clone`.
#[derive(Debug)]
pub struct WorkingModelSurvival {
    age_entry: Array1<f64>,
    age_exit: Array1<f64>,
    // NOTE(review): presumably `age_entry <= ENTRY_AT_ORIGIN_THRESHOLD` per row — confirm.
    entry_at_origin: Array1<bool>,
    event_target: Array1<u8>,
    sampleweight: Array1<f64>,
    design: SurvivalDesign,
    offset_eta_entry: Array1<f64>,
    offset_eta_exit: Array1<f64>,
    offset_derivative_exit: Array1<f64>,
    penalties: PenaltyBlocks,
    monotonicity: SurvivalMonotonicityPenalty,
    structurally_monotonic: bool,
    structural_time_columns: usize,
    monotonicity_constraint_rows: Option<Array2<f64>>,
    monotonicity_constraint_offsets: Option<Array1<f64>>,
    workspace: std::sync::Mutex<SurvivalWorkspace>,
}
1039
1040impl Clone for WorkingModelSurvival {
1041 fn clone(&self) -> Self {
1042 let workspace = self.workspace.lock().unwrap().clone();
1043 Self {
1044 age_entry: self.age_entry.clone(),
1045 age_exit: self.age_exit.clone(),
1046 entry_at_origin: self.entry_at_origin.clone(),
1047 event_target: self.event_target.clone(),
1048 sampleweight: self.sampleweight.clone(),
1049 design: self.design.clone(),
1050 offset_eta_entry: self.offset_eta_entry.clone(),
1051 offset_eta_exit: self.offset_eta_exit.clone(),
1052 offset_derivative_exit: self.offset_derivative_exit.clone(),
1053 penalties: self.penalties.clone(),
1054 monotonicity: self.monotonicity,
1055 structurally_monotonic: self.structurally_monotonic,
1056 structural_time_columns: self.structural_time_columns,
1057 monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1058 monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1059 workspace: std::sync::Mutex::new(workspace),
1060 }
1061 }
1062}
1063
1064impl WorkingModelSurvival {
    /// Natural log of `f64::MAX`. `scaled_exp_component` rejects log-magnitudes
    /// above this value because their exponential would overflow.
    const LOG_F64_MAX: f64 = 709.782712893384;
1066
1067 #[inline]
1068 fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
1069 if base == 0.0 {
1070 return Ok(0.0);
1071 }
1072 let log_abs = log_scale + base.abs().ln();
1073 if !log_abs.is_finite() {
1074 crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
1075 }
1076 if log_abs > Self::LOG_F64_MAX {
1077 crate::bail_invalid_estim!(
1078 "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
1079 );
1080 }
1081 Ok(base.signum() * log_abs.exp())
1082 }
1083
1084 fn coefficient_dim(&self) -> usize {
1085 self.design.p_total()
1086 }
1087
1088 fn nrows(&self) -> usize {
1089 self.sampleweight.len()
1090 }
1091
1092 fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1093 let time_mat = match &self.design {
1094 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1095 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1096 };
1097 self.design.design_dot(time_mat, beta)
1098 }
1099
1100 fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1101 let time_mat = match &self.design {
1102 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1103 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1104 };
1105 self.design.design_dot(time_mat, beta)
1106 }
1107
1108 fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1109 match &self.design {
1110 SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1111 SurvivalDesign::TimeCovariateShared {
1112 time_derivative, ..
1113 } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1114 }
1115 }
1116
1117 fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1118 let time_mat = match &self.design {
1119 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1120 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1121 };
1122 self.design.fill_row(time_mat, i, out);
1123 }
1124
1125 fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1126 let time_mat = match &self.design {
1127 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1128 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1129 };
1130 self.design.fill_row(time_mat, i, out);
1131 }
1132
1133 fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1134 match &self.design {
1135 SurvivalDesign::Flat { x_derivative, .. } => {
1136 for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1137 *dst = src;
1138 }
1139 }
1140 SurvivalDesign::TimeCovariateShared {
1141 time_derivative, ..
1142 } => {
1143 let p_time = time_derivative.ncols();
1144 for j in 0..p_time {
1145 out[j] = time_derivative[[i, j]];
1146 }
1147 for dst in out.iter_mut().skip(p_time) {
1148 *dst = 0.0;
1149 }
1150 }
1151 }
1152 }
1153
1154 fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1155 match &self.design {
1156 SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1157 SurvivalDesign::TimeCovariateShared {
1158 time_derivative,
1159 covariates,
1160 ..
1161 } => {
1162 let p_time = time_derivative.ncols();
1163 let p_cov = covariates.ncols();
1164 let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1165 let time_block = fast_xt_diag_x(time_derivative, weights);
1166 out.slice_mut(ndarray::s![..p_time, ..p_time])
1167 .assign(&time_block);
1168 out
1169 }
1170 }
1171 }
1172
1173 fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1177 match &self.design {
1178 SurvivalDesign::Flat {
1179 x_entry, x_exit, ..
1180 } => {
1181 let mut h = fast_xt_diag_x(x_exit, w_exit);
1182 h -= &fast_xt_diag_x(x_entry, w_entry);
1183 h
1184 }
1185 SurvivalDesign::TimeCovariateShared {
1186 time_entry,
1187 time_exit,
1188 covariates,
1189 ..
1190 } => {
1191 let p_time = time_exit.ncols();
1192 let p_cov = covariates.ncols();
1193 let p = p_time + p_cov;
1194 let mut h = Array2::<f64>::zeros((p, p));
1195 let tt = {
1197 let mut block = fast_xt_diag_x(time_exit, w_exit);
1198 block -= &fast_xt_diag_x(time_entry, w_entry);
1199 block
1200 };
1201 h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1202 if p_cov > 0 {
1203 let tc = {
1205 let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1206 block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1207 block
1208 };
1209 h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1210 h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1211 let w_diff = w_exit - w_entry;
1213 let cc = fast_xt_diag_x(covariates, &w_diff);
1214 h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1215 }
1216 h
1217 }
1218 }
1219 }
1220
1221 fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1222 const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1223 if !self.structurally_monotonic {
1224 return None;
1225 }
1226 if deriv >= 1e-12 {
1227 return Some(deriv);
1228 }
1229 if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1230 return Some(1e-12);
1231 }
1232 None
1233 }
1234
1235 fn validate_penalties(
1236 penalties: &PenaltyBlocks,
1237 coefficient_dim: usize,
1238 ) -> Result<(), SurvivalError> {
1239 for block in &penalties.blocks {
1240 if !block.lambda.is_finite() || block.lambda < 0.0 {
1241 return Err(SurvivalError::NonFiniteInput);
1242 }
1243 if block.range.start > block.range.end || block.range.end > coefficient_dim {
1244 return Err(SurvivalError::DimensionMismatch);
1245 }
1246 let block_dim = block.range.end - block.range.start;
1247 if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1248 return Err(SurvivalError::DimensionMismatch);
1249 }
1250 if block.matrix.iter().any(|v| !v.is_finite()) {
1251 return Err(SurvivalError::NonFiniteInput);
1252 }
1253 }
1254 Ok(())
1255 }
1256
1257 fn derivative_guard(&self) -> f64 {
1258 if self.structurally_monotonic {
1259 return 0.0;
1263 }
1264 self.monotonicity.tolerance.max(0.0)
1265 }
1266
1267 fn derivative_guard_numerical(&self) -> f64 {
1268 let derivative_guard = self.derivative_guard();
1269 if derivative_guard <= 0.0 {
1270 if self.structurally_monotonic {
1279 -1e-10
1280 } else {
1281 1e-12
1282 }
1283 } else {
1284 (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1285 }
1286 }
1287
1288 fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1289 let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1290 1e-10 * scale
1291 }
1292
1293 fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1294 if !self.structurally_monotonic {
1295 return None;
1296 }
1297 let p = self.coefficient_dim();
1298 let time_columns = self.structural_time_columns.min(p);
1299 if time_columns == 0 {
1300 return None;
1301 }
1302 const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1303 let mut active_columns = vec![false; time_columns];
1304 let mut derivative_row = vec![0.0_f64; p];
1305 for i in 0..self.nrows() {
1306 if self.sampleweight[i] <= 0.0 {
1307 continue;
1308 }
1309 self.fill_derivative_row(i, &mut derivative_row);
1310 for j in 0..time_columns {
1311 if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1312 active_columns[j] = true;
1313 }
1314 }
1315 }
1316 if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1317 for i in 0..rows.nrows() {
1318 for j in 0..time_columns {
1319 if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1320 active_columns[j] = true;
1321 }
1322 }
1323 }
1324 }
1325 let active_columns: Vec<usize> = active_columns
1326 .into_iter()
1327 .enumerate()
1328 .filter_map(|(j, active)| active.then_some(j))
1329 .collect();
1330 if active_columns.is_empty() {
1331 return None;
1332 }
1333 let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1334 let b = Array1::<f64>::zeros(active_columns.len());
1335 for (row, &col) in active_columns.iter().enumerate() {
1336 a[[row, col]] = 1.0;
1337 }
1338 Some(LinearInequalityConstraints { a, b })
1339 }
1340
1341 pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1342 let p = self.coefficient_dim();
1343 const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1344 if p == 0 {
1345 return None;
1346 }
1347 if self.structurally_monotonic {
1348 return self.structural_time_coefficient_constraints();
1349 }
1350 if let (Some(rows), Some(offsets)) = (
1351 self.monotonicity_constraint_rows.as_ref(),
1352 self.monotonicity_constraint_offsets.as_ref(),
1353 ) {
1354 let activerows: Vec<usize> = (0..rows.nrows())
1355 .filter(|&i| {
1356 rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1357 > DERIVATIVE_ROW_NORM_TOL
1358 })
1359 .collect();
1360 if activerows.is_empty() {
1361 return None;
1362 }
1363 let mut a = Array2::<f64>::zeros((activerows.len(), p));
1364 let mut b = Array1::<f64>::zeros(activerows.len());
1365 for (r, &i) in activerows.iter().enumerate() {
1366 a.row_mut(r).assign(&rows.row(i));
1367 b[r] = self.derivative_guard() - offsets[i];
1368 }
1369 return Some(compress_positive_collinear_constraints(&a, &b));
1370 }
1371 None
1372 }
1373
1374 pub fn from_engine_inputs(
1375 inputs: SurvivalEngineInputs<'_>,
1376 penalties: PenaltyBlocks,
1377 monotonicity: SurvivalMonotonicityPenalty,
1378 spec: SurvivalSpec,
1379 ) -> Result<Self, SurvivalError> {
1380 Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
1381 }
1382
1383 fn validate_offsets(
1384 offsets: Option<SurvivalBaselineOffsets<'_>>,
1385 n: usize,
1386 ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1387 if let Some(off) = offsets {
1388 if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1389 {
1390 return Err(SurvivalError::DimensionMismatch);
1391 }
1392 if off.eta_entry.iter().any(|v| !v.is_finite())
1393 || off.eta_exit.iter().any(|v| !v.is_finite())
1394 || off.derivative_exit.iter().any(|v| !v.is_finite())
1395 {
1396 return Err(SurvivalError::NonFiniteInput);
1397 }
1398 Ok((
1399 off.eta_entry.to_owned(),
1400 off.eta_exit.to_owned(),
1401 off.derivative_exit.to_owned(),
1402 ))
1403 } else {
1404 Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1405 }
1406 }
1407
1408 fn validate_common_inputs(
1409 age_entry: &ArrayView1<f64>,
1410 age_exit: &ArrayView1<f64>,
1411 event_target: &ArrayView1<u8>,
1412 event_competing: &ArrayView1<u8>,
1413 sampleweight: &ArrayView1<f64>,
1414 ) -> Result<(), SurvivalError> {
1415 if age_entry.iter().any(|v| !v.is_finite())
1416 || age_exit.iter().any(|v| !v.is_finite())
1417 || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1418 {
1419 return Err(SurvivalError::NonFiniteInput);
1420 }
1421 if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1428 return Err(SurvivalError::EventCodeInvalid {
1429 reason: format!(
1430 "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1431 ),
1432 });
1433 }
1434 if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1435 return Err(SurvivalError::EventCodeInvalid {
1436 reason: format!(
1437 "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1438 ),
1439 });
1440 }
1441 if event_target
1442 .iter()
1443 .zip(event_competing.iter())
1444 .any(|(&target, &competing)| target > 0 && competing > 0)
1445 {
1446 return Err(SurvivalError::EventCodeInvalid {
1447 reason: "a row cannot be simultaneously a target event and a competing event"
1448 .to_string(),
1449 });
1450 }
1451 if age_entry
1465 .iter()
1466 .zip(age_exit.iter())
1467 .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1468 {
1469 return Err(SurvivalError::NonFiniteInput);
1470 }
1471 Ok::<(), _>(())
1472 }
1473
1474 fn validate_monotonicity_constraints(
1475 rows: Option<ArrayView2<'_, f64>>,
1476 offsets: Option<ArrayView1<'_, f64>>,
1477 coefficient_dim: usize,
1478 ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1479 match (rows, offsets) {
1480 (None, None) => Ok((None, None)),
1481 (Some(rows), Some(offsets)) => {
1482 if rows.ncols() != coefficient_dim
1483 || rows.nrows() != offsets.len()
1484 || rows.iter().any(|v| !v.is_finite())
1485 || offsets.iter().any(|v| !v.is_finite())
1486 {
1487 return Err(SurvivalError::DimensionMismatch);
1488 }
1489 Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1490 }
1491 _ => Err(SurvivalError::DimensionMismatch),
1492 }
1493 }
1494
1495 fn finish_construction(
1496 age_entry: ArrayView1<f64>,
1497 age_exit: ArrayView1<f64>,
1498 event_target: ArrayView1<u8>,
1499 sampleweight: ArrayView1<f64>,
1500 design: SurvivalDesign,
1501 offset_eta_entry: Array1<f64>,
1502 offset_eta_exit: Array1<f64>,
1503 offset_derivative_exit: Array1<f64>,
1504 penalties: PenaltyBlocks,
1505 monotonicity: SurvivalMonotonicityPenalty,
1506 monotonicity_constraint_rows: Option<Array2<f64>>,
1507 monotonicity_constraint_offsets: Option<Array1<f64>>,
1508 ) -> Self {
1509 let n = age_entry.len();
1510 Self {
1511 age_entry: age_entry.to_owned(),
1512 age_exit: age_exit.to_owned(),
1513 entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1514 event_target: event_target.to_owned(),
1515 sampleweight: sampleweight.to_owned(),
1516 design,
1517 offset_eta_entry,
1518 offset_eta_exit,
1519 offset_derivative_exit,
1520 penalties,
1521 monotonicity,
1522 structurally_monotonic: false,
1523 structural_time_columns: 0,
1524 monotonicity_constraint_rows,
1525 monotonicity_constraint_offsets,
1526 workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1527 }
1528 }
1529
1530 pub fn from_engine_inputswith_offsets(
1531 inputs: SurvivalEngineInputs<'_>,
1532 offsets: Option<SurvivalBaselineOffsets<'_>>,
1533 penalties: PenaltyBlocks,
1534 monotonicity: SurvivalMonotonicityPenalty,
1535 spec: SurvivalSpec,
1536 ) -> Result<Self, SurvivalError> {
1537 if spec == SurvivalSpec::Crude {
1538 return Err(SurvivalError::UnsupportedSpec("crude"));
1539 }
1540 let n = inputs.age_entry.len();
1541 let p = inputs.x_entry.ncols();
1542 if inputs.age_exit.len() != n
1543 || inputs.event_target.len() != n
1544 || inputs.event_competing.len() != n
1545 || inputs.sampleweight.len() != n
1546 || inputs.x_entry.nrows() != n
1547 || inputs.x_exit.nrows() != n
1548 || inputs.x_derivative.nrows() != n
1549 || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1550 || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1551 {
1552 return Err(SurvivalError::DimensionMismatch);
1553 }
1554 Self::validate_penalties(&penalties, p)?;
1555 Self::validate_common_inputs(
1556 &inputs.age_entry,
1557 &inputs.age_exit,
1558 &inputs.event_target,
1559 &inputs.event_competing,
1560 &inputs.sampleweight,
1561 )?;
1562 if inputs.x_entry.iter().any(|v| !v.is_finite())
1563 || inputs.x_exit.iter().any(|v| !v.is_finite())
1564 || inputs.x_derivative.iter().any(|v| !v.is_finite())
1565 {
1566 return Err(SurvivalError::NonFiniteInput);
1567 }
1568 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1569 Self::validate_offsets(offsets, n)?;
1570 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1571 Self::validate_monotonicity_constraints(
1572 inputs.monotonicity_constraint_rows,
1573 inputs.monotonicity_constraint_offsets,
1574 p,
1575 )?;
1576
1577 Ok(Self::finish_construction(
1578 inputs.age_entry,
1579 inputs.age_exit,
1580 inputs.event_target,
1581 inputs.sampleweight,
1582 SurvivalDesign::Flat {
1583 x_entry: inputs.x_entry.to_owned(),
1584 x_exit: inputs.x_exit.to_owned(),
1585 x_derivative: inputs.x_derivative.to_owned(),
1586 },
1587 offset_eta_entry,
1588 offset_eta_exit,
1589 offset_derivative_exit,
1590 penalties,
1591 monotonicity,
1592 monotonicity_constraint_rows,
1593 monotonicity_constraint_offsets,
1594 ))
1595 }
1596
1597 pub fn from_time_covariate_inputswith_offsets(
1598 inputs: SurvivalTimeCovarInputs<'_>,
1599 offsets: Option<SurvivalBaselineOffsets<'_>>,
1600 penalties: PenaltyBlocks,
1601 monotonicity: SurvivalMonotonicityPenalty,
1602 spec: SurvivalSpec,
1603 ) -> Result<Self, SurvivalError> {
1604 if spec == SurvivalSpec::Crude {
1605 return Err(SurvivalError::UnsupportedSpec("crude"));
1606 }
1607 let n = inputs.age_entry.len();
1608 let p_time = inputs.time_entry.ncols();
1609 let p_cov = inputs.covariates.ncols();
1610 let p = p_time + p_cov;
1611 if inputs.age_exit.len() != n
1612 || inputs.event_target.len() != n
1613 || inputs.event_competing.len() != n
1614 || inputs.sampleweight.len() != n
1615 || inputs.time_entry.nrows() != n
1616 || inputs.time_exit.nrows() != n
1617 || inputs.time_derivative.nrows() != n
1618 || inputs.covariates.nrows() != n
1619 || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1620 || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1621 {
1622 return Err(SurvivalError::DimensionMismatch);
1623 }
1624 Self::validate_penalties(&penalties, p)?;
1625 Self::validate_common_inputs(
1626 &inputs.age_entry,
1627 &inputs.age_exit,
1628 &inputs.event_target,
1629 &inputs.event_competing,
1630 &inputs.sampleweight,
1631 )?;
1632 if inputs.time_entry.iter().any(|v| !v.is_finite())
1633 || inputs.time_exit.iter().any(|v| !v.is_finite())
1634 || inputs.time_derivative.iter().any(|v| !v.is_finite())
1635 || inputs.covariates.iter().any(|v| !v.is_finite())
1636 {
1637 return Err(SurvivalError::NonFiniteInput);
1638 }
1639 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1640 Self::validate_offsets(offsets, n)?;
1641 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1642 Self::validate_monotonicity_constraints(
1643 inputs.monotonicity_constraint_rows,
1644 inputs.monotonicity_constraint_offsets,
1645 p,
1646 )?;
1647
1648 Ok(Self::finish_construction(
1649 inputs.age_entry,
1650 inputs.age_exit,
1651 inputs.event_target,
1652 inputs.sampleweight,
1653 SurvivalDesign::TimeCovariateShared {
1654 time_entry: inputs.time_entry.to_owned(),
1655 time_exit: inputs.time_exit.to_owned(),
1656 time_derivative: inputs.time_derivative.to_owned(),
1657 covariates: inputs.covariates.to_owned(),
1658 },
1659 offset_eta_entry,
1660 offset_eta_exit,
1661 offset_derivative_exit,
1662 penalties,
1663 monotonicity,
1664 monotonicity_constraint_rows,
1665 monotonicity_constraint_offsets,
1666 ))
1667 }
1668
1669 pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1684 if lambdas.len() != self.penalties.blocks.len() {
1685 crate::bail_invalid_estim!(
1686 "set_penalty_lambdas expects {} lambdas, got {}",
1687 self.penalties.blocks.len(),
1688 lambdas.len()
1689 );
1690 }
1691 for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1692 if !lambda.is_finite() || lambda < 0.0 {
1693 crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1694 }
1695 block.lambda = lambda;
1696 }
1697 Ok(())
1698 }
1699
    /// Enables or disables structural (coefficient-sign) monotonicity.
    ///
    /// When `enabled`, the first `time_columns` coefficients form the time block.
    /// Before the flag is set, the following are checked with a `1e-12` tolerance:
    /// - every `offset_derivative_exit` entry is nonnegative;
    /// - every derivative design row (all rows, regardless of sample weight) is
    ///   nonnegative on the time block and zero outside it;
    /// - the same two properties hold for the optional collocation offsets and
    ///   rows (`monotonicity_constraint_offsets` / `monotonicity_constraint_rows`).
    ///
    /// Disabling skips these checks and resets the stored time-column count to 0.
    ///
    /// # Errors
    /// Invalid-input error if `time_columns` exceeds the coefficient dimension, if
    /// enabling with `time_columns == 0`, or if any check above fails. The flag
    /// and column count are only written after all checks pass, so on error the
    /// model is unchanged.
    pub fn set_structural_monotonicity(
        &mut self,
        enabled: bool,
        time_columns: usize,
    ) -> Result<(), EstimationError> {
        let p = self.coefficient_dim();
        if time_columns > p {
            crate::bail_invalid_estim!(
                "structural time columns {} exceed coefficient dimension {}",
                time_columns,
                p
            );
        }
        if enabled && time_columns == 0 {
            crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
        }
        if enabled {
            // NOTE(review): these checks presumably exist so that nonnegative time
            // coefficients (see structural_time_coefficient_constraints) imply a
            // nonnegative derivative row·beta + offset; confirm with the solver.
            const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
            for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
                if offset < -STRUCTURAL_DERIV_TOL {
                    crate::bail_invalid_estim!(
                        "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
                    );
                }
            }
            // Scratch buffer reused for every row of the derivative design.
            let mut derivative_row = vec![0.0_f64; p];
            for i in 0..self.nrows() {
                self.fill_derivative_row(i, &mut derivative_row);
                // Time block: entries must be nonnegative.
                for j in 0..time_columns {
                    let v = derivative_row[j];
                    if v < -STRUCTURAL_DERIV_TOL {
                        crate::bail_invalid_estim!(
                            "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
                        );
                    }
                }
                // Outside the time block: entries must be (numerically) zero.
                for j in time_columns..p {
                    let v = derivative_row[j];
                    if v.abs() > STRUCTURAL_DERIV_TOL {
                        crate::bail_invalid_estim!(
                            "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
                        );
                    }
                }
            }
            // Apply the same sign/support rules to the optional collocation rows.
            if let (Some(rows), Some(offsets)) = (
                self.monotonicity_constraint_rows.as_ref(),
                self.monotonicity_constraint_offsets.as_ref(),
            ) {
                for (i, &offset) in offsets.iter().enumerate() {
                    if offset < -STRUCTURAL_DERIV_TOL {
                        crate::bail_invalid_estim!(
                            "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
                        );
                    }
                }
                for i in 0..rows.nrows() {
                    for j in 0..time_columns {
                        let v = rows[[i, j]];
                        if v < -STRUCTURAL_DERIV_TOL {
                            crate::bail_invalid_estim!(
                                "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
                            );
                        }
                    }
                    for j in time_columns..p {
                        let v = rows[[i, j]];
                        if v.abs() > STRUCTURAL_DERIV_TOL {
                            crate::bail_invalid_estim!(
                                "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
                            );
                        }
                    }
                }
            }
        }
        // Only reached after every check has passed.
        self.structurally_monotonic = enabled;
        self.structural_time_columns = if enabled { time_columns } else { 0 };
        Ok(())
    }
1780
    /// Evaluates the penalized survival negative log-likelihood (NLL) and its
    /// derivatives at `beta`.
    ///
    /// For every row with positive sample weight this accumulates:
    /// - the interval term `w * (exp(eta_exit) - exp(eta_entry))`, where the entry
    ///   part is dropped for rows flagged as entering at the origin;
    /// - for event rows, the event term `-w * (eta_exit + ln(d_eta/dt))`.
    ///
    /// The returned `WorkingState` holds:
    /// - `gradient`: NLL score, plus the penalty gradient, plus `ridge * beta`;
    /// - `hessian`: dense observed Hessian of the NLL, plus the penalty Hessian,
    ///   plus a fixed `1e-8` diagonal ridge;
    /// - `log_likelihood = -NLL` and `deviance = 2 * NLL`;
    /// - `penalty_term`: the penalty deviance plus `0.5 * ridge * ||beta||^2`.
    ///
    /// # Errors
    /// - Invalid-input error when `beta` has the wrong length, when a weighted row
    ///   has non-finite ages or `age_exit < age_entry`, or when an interval or
    ///   exit term overflows `f64`.
    /// - `ParameterConstraintViolation` when the (stabilized) time derivative is
    ///   non-finite or below its floor, or when the cumulative hazard decreases
    ///   over an interval by more than the roundoff guard.
    ///
    /// # Panics
    /// Panics if the workspace mutex is poisoned (`lock().unwrap()`).
    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
        if beta.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!("survival beta dimension mismatch");
        }

        let n = self.nrows();
        let p = self.coefficient_dim();

        // Entry/exit linear predictors and raw time derivative, offsets included.
        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;

        let mut nll = 0.0;
        let derivative_guard = self.derivative_guard();
        let derivative_guard_numerical = self.derivative_guard_numerical();
        // Per-row weight buffers are reused across calls; reset() prepares them for n rows.
        let mut workspace = self.workspace.lock().unwrap();
        workspace.reset(n);
        let SurvivalWorkspace {
            w_event,
            w_event_inv_deriv,
            w_event_outer,
            w_hess_exit,
            w_hess_entry,
        } = &mut *workspace;

        // Pass 1: per-row likelihood terms, constraint checks, and Hessian weights.
        for i in 0..n {
            let w = self.sampleweight[i];
            if w <= 0.0 {
                continue;
            }
            let entry_age = self.age_entry[i];
            let exit_age = self.age_exit[i];
            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
                crate::bail_invalid_estim!(
                    "survival ages must be finite with age_exit >= age_entry"
                );
            }
            let d = f64::from(self.event_target[i]);

            // Factor out exp(interval_scale): both scaled terms are then <= 1, and
            // the difference exp(eta_exit) - exp(eta_entry) is only rescaled, with
            // an overflow check, inside scaled_exp_component.
            let has_entry_interval = !self.entry_at_origin[i];
            let interval_scale = if has_entry_interval {
                eta_exit[i].max(eta_entry[i])
            } else {
                eta_exit[i]
            };
            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
            let h_s_scaled = if has_entry_interval {
                (eta_entry[i] - interval_scale).exp()
            } else {
                0.0
            };
            let interval_scaled = h_e_scaled - h_s_scaled;
            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
            let deriv = self
                .stabilized_structural_derivative(derivative_raw[i])
                .unwrap_or(derivative_raw[i]);
            // Event rows take ln(deriv) below and need the numerical guard;
            // censored rows only need a nonnegative derivative.
            let mono_floor = if d > 0.0 {
                derivative_guard_numerical
            } else {
                0.0
            };
            // NOTE(review): the message reports `derivative_guard`, but the value
            // actually compared is `mono_floor` (0.0 for censored rows).
            if !deriv.is_finite() || deriv < mono_floor {
                return Err(EstimationError::ParameterConstraintViolation(format!(
                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
                    i, deriv, derivative_guard
                )));
            }
            if has_entry_interval {
                // Both scaled terms are <= 1, so this is a small absolute slack that
                // tolerates roundoff-level decreases of the cumulative hazard.
                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
                if interval_scaled + increment_guard < 0.0 {
                    return Err(EstimationError::ParameterConstraintViolation(format!(
                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
                        i, interval
                    )));
                }
            }
            nll += w * interval;

            // Diagonal weights for the interval Hessian X' diag(w) X blocks.
            let w_exit_i = w * eta_exit[i].exp();
            let w_entry_i = if has_entry_interval {
                w * eta_entry[i].exp()
            } else {
                0.0
            };
            if !w_exit_i.is_finite() {
                crate::bail_invalid_estim!(
                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
                );
            }
            w_hess_exit[i] = w_exit_i;
            w_hess_entry[i] = w_entry_i;

            if d > 0.0 {
                // Event term -w * (eta_exit + ln(deriv)). eta_exit enters linearly,
                // so its curvature comes only from ln(deriv): w / deriv^2.
                let inv_deriv = 1.0 / deriv;
                nll += -w * (eta_exit[i] + deriv.ln());
                w_event[i] = w;
                w_event_inv_deriv[i] = w * inv_deriv;
                w_event_outer[i] = w * inv_deriv * inv_deriv;
            }
        }

        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
        let mut grad = Array1::<f64>::zeros(p);
        // Running compensation for Neumaier (improved Kahan) summation of the
        // per-row score contributions.
        let mut grad_comp = Array1::<f64>::zeros(p);
        let mut row_exit = vec![0.0_f64; p];
        let mut row_entry = vec![0.0_f64; p];
        let mut row_derivative = vec![0.0_f64; p];
        // Pass 2: score of the NLL, row by row.
        for i in 0..n {
            let w_interval_exit = w_hess_exit[i];
            let w_interval_entry = w_hess_entry[i];
            let w_event_exit = w_event[i];
            let w_event_derivative = w_event_inv_deriv[i];
            // Rows skipped in pass 1 (or contributing nothing) have all-zero weights.
            if w_interval_exit == 0.0
                && w_interval_entry == 0.0
                && w_event_exit == 0.0
                && w_event_derivative == 0.0
            {
                continue;
            }
            self.fill_exit_row(i, &mut row_exit);
            self.fill_entry_row(i, &mut row_entry);
            self.fill_derivative_row(i, &mut row_derivative);
            for j in 0..p {
                // dNLL_i/dbeta_j = w e^{eta_exit} x_exit - w e^{eta_entry} x_entry
                //                 - w d x_exit - (w d / deriv) x_deriv
                let contribution = w_interval_exit * row_exit[j]
                    - w_interval_entry * row_entry[j]
                    - w_event_exit * row_exit[j]
                    - w_event_derivative * row_derivative[j];
                let t = grad[j] + contribution;
                if grad[j].abs() >= contribution.abs() {
                    grad_comp[j] += (grad[j] - t) + contribution;
                } else {
                    grad_comp[j] += (contribution - t) + grad[j];
                }
                grad[j] = t;
            }
        }
        grad += &grad_comp;

        // Event curvature: sum_i (w / deriv^2) x_deriv x_deriv'.
        h += &self.derivative_xt_diag_x(w_event_outer);

        let score_norm = array1_l2_norm(&grad);

        let penaltygrad = self.penalties.gradient(beta);
        let penalty_dev = self.penalties.deviance(beta);
        let penaltygrad_norm = array1_l2_norm(&penaltygrad);

        let mut totalgrad = grad;
        totalgrad += &penaltygrad;

        self.penalties.addhessian_inplace(&mut h);
        // Fixed ridge 0.5 * r * ||beta||^2, added consistently to the Hessian
        // diagonal, the gradient, and the penalty term.
        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
        for d in 0..p {
            h[[d, d]] += ridge_used;
        }
        totalgrad += &beta.mapv(|v| ridge_used * v);
        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);

        let log_likelihood = -nll;
        let deviance = 2.0 * nll;

        Ok(WorkingState {
            eta: LinearPredictor::new(eta_exit),
            gradient: totalgrad,
            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
            log_likelihood,
            deviance,
            penalty_term: penalty_dev + ridge_penalty,
            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
            ridge_used,
            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
        })
    }
2016
2017 pub(crate) fn survival_hessian_derivative_correction(
2027 &self,
2028 beta: &Array1<f64>,
2029 u_k: &Array1<f64>,
2030 ) -> Result<Array2<f64>, EstimationError> {
2031 let p = beta.len();
2032 let n = self.nrows();
2033
2034 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2035 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2036 let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2037 let exp_entry = eta_entry.mapv(f64::exp);
2038 let exp_exit = eta_exit.mapv(f64::exp);
2039 let guard = self.derivative_guard();
2040 let guard_numerical = self.derivative_guard_numerical();
2041
2042 let jac = Array1::<f64>::ones(p);
2043 let curvature = Array1::<f64>::zeros(p);
2044 let third = Array1::<f64>::zeros(p);
2045
2046 let mut row_exit = vec![0.0_f64; p];
2047 let mut row_entry = vec![0.0_f64; p];
2048 let mut row_derivative = vec![0.0_f64; p];
2049 let mut ge = vec![0.0_f64; p];
2050 let mut gs = vec![0.0_f64; p];
2051 let mut gsd = vec![0.0_f64; p];
2052 let mut he = vec![0.0_f64; p];
2053 let mut hs = vec![0.0_f64; p];
2054 let mut hsd = vec![0.0_f64; p];
2055 let mut te = vec![0.0_f64; p];
2056 let mut ts = vec![0.0_f64; p];
2057 let mut tsd = vec![0.0_f64; p];
2058
2059 let mut b_dir = Array2::<f64>::zeros((p, p));
2060
2061 for i in 0..n {
2062 let w_i = self.sampleweight[i];
2063 if w_i <= 0.0 {
2064 continue;
2065 }
2066 let has_entry = !self.entry_at_origin[i];
2067 let mut deta_e = 0.0_f64;
2068 let mut deta_s = 0.0_f64;
2069 let mut ds = 0.0_f64;
2070 self.fill_exit_row(i, &mut row_exit);
2071 self.fill_entry_row(i, &mut row_entry);
2072 self.fill_derivative_row(i, &mut row_derivative);
2073 for j in 0..p {
2074 ge[j] = row_exit[j] * jac[j];
2075 gs[j] = row_entry[j] * jac[j];
2076 gsd[j] = row_derivative[j] * jac[j];
2077 he[j] = row_exit[j] * curvature[j];
2078 hs[j] = row_entry[j] * curvature[j];
2079 hsd[j] = row_derivative[j] * curvature[j];
2080 te[j] = row_exit[j] * third[j];
2081 ts[j] = row_entry[j] * third[j];
2082 tsd[j] = row_derivative[j] * third[j];
2083 deta_e += ge[j] * u_k[j];
2084 if has_entry {
2085 deta_s += gs[j] * u_k[j];
2086 }
2087 ds += gsd[j] * u_k[j];
2088 }
2089
2090 for r in 0..p {
2092 let dge_r = he[r] * u_k[r];
2093 let dgs_r = hs[r] * u_k[r];
2094 let dhe_r = te[r] * u_k[r];
2095 let dhs_r = ts[r] * u_k[r];
2096 for c in 0..p {
2097 let dge_c = he[c] * u_k[c];
2098 let dgs_c = hs[c] * u_k[c];
2099 let mut d_h_rc =
2100 exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2101 if r == c {
2102 d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2103 }
2104 if has_entry {
2105 d_h_rc -=
2106 exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2107 if r == c {
2108 d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2109 }
2110 }
2111 b_dir[[r, c]] += w_i * d_h_rc;
2112 }
2113 }
2114
2115 let s_i = self
2117 .stabilized_structural_derivative(deriv_raw[i])
2118 .unwrap_or(deriv_raw[i]);
2119 if !s_i.is_finite() {
2120 return Err(EstimationError::ParameterConstraintViolation(format!(
2121 "survival monotonicity violated in unified trace contraction at row {i}: \
2122 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2123 )));
2124 }
2125 if self.event_target[i] > 0 {
2126 if s_i < guard_numerical {
2127 return Err(EstimationError::ParameterConstraintViolation(format!(
2128 "survival monotonicity violated in unified trace contraction at row {i}: \
2129 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2130 )));
2131 }
2132 let inv_s = 1.0 / s_i;
2133 let inv_s2 = inv_s * inv_s;
2134 let inv_s3 = inv_s2 * inv_s;
2135 for r in 0..p {
2136 let dgd_r = hsd[r] * u_k[r];
2137 let dtsd_r = tsd[r] * u_k[r];
2138 let dte_r = te[r] * u_k[r];
2139 for c in 0..p {
2140 let dgd_c = hsd[c] * u_k[c];
2141 let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2142 - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2143 if r == c {
2144 d_h_rc += -dte_r;
2145 d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2146 }
2147 b_dir[[r, c]] += w_i * d_h_rc;
2148 }
2149 }
2150 }
2151 }
2152
2153 Ok(b_dir)
2154 }
2155
2156 pub fn offset_channel_residuals(
2194 &self,
2195 beta: &Array1<f64>,
2196 ) -> Result<OffsetChannelResiduals, EstimationError> {
2197 if beta.len() != self.coefficient_dim() {
2198 crate::bail_invalid_estim!(
2199 "survival beta dimension mismatch in offset_channel_residuals"
2200 );
2201 }
2202 let n = self.nrows();
2203 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2204 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2205 let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2206
2207 let derivative_guard_numerical = self.derivative_guard_numerical();
2208 let mut r_exit = Array1::<f64>::zeros(n);
2209 let mut r_entry = Array1::<f64>::zeros(n);
2210 let mut r_deriv = Array1::<f64>::zeros(n);
2211
2212 for i in 0..n {
2213 let w = self.sampleweight[i];
2214 if w <= 0.0 {
2215 continue;
2216 }
2217 let entry_age = self.age_entry[i];
2218 let exit_age = self.age_exit[i];
2219 if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2220 crate::bail_invalid_estim!(
2221 "survival ages must be finite with age_exit >= age_entry"
2222 );
2223 }
2224 let has_entry_interval = !self.entry_at_origin[i];
2225 let d = f64::from(self.event_target[i]);
2226 let w_exit_i = w * eta_exit[i].exp();
2230 let w_entry_i = if has_entry_interval {
2231 w * eta_entry[i].exp()
2232 } else {
2233 0.0
2234 };
2235 if !w_exit_i.is_finite() {
2236 crate::bail_invalid_estim!(
2237 "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2238 );
2239 }
2240 r_exit[i] = w_exit_i - d * w;
2241 r_entry[i] = -w_entry_i;
2242 let deriv_raw = derivative_raw[i];
2247 let deriv = self
2248 .stabilized_structural_derivative(deriv_raw)
2249 .unwrap_or(deriv_raw);
2250 let mono_floor = if d > 0.0 {
2251 derivative_guard_numerical
2252 } else {
2253 0.0
2254 };
2255 if !deriv.is_finite() || deriv < mono_floor {
2256 return Err(EstimationError::ParameterConstraintViolation(format!(
2257 "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2258 )));
2259 }
2260 if d > 0.0 {
2261 r_deriv[i] = -w * d / deriv;
2262 }
2263 }
2264
2265 let right = Array1::<f64>::zeros(r_exit.len());
2266 Ok(OffsetChannelResiduals {
2267 exit: r_exit,
2268 entry: r_entry,
2269 derivative: r_deriv,
2270 right,
2271 })
2272 }
2273
2274 pub fn unified_lamlobjective_and_rhogradient(
2280 &self,
2281 beta: &Array1<f64>,
2282 state: &WorkingState,
2283 rho: &Array1<f64>,
2284 ) -> Result<(f64, Array1<f64>), EstimationError> {
2285 use gam_solve::estimate::reml::assembly::{
2286 InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
2287 };
2288 use gam_solve::estimate::reml::reml_outer_engine::{
2289 DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
2290 compute_block_penalty_logdet_derivs,
2291 };
2292 use gam_problem::EvalMode;
2293
2294 let p = beta.len();
2295 let active_penalty_blocks: Vec<&PenaltyBlock> = self
2296 .penalties
2297 .blocks
2298 .iter()
2299 .filter(|b| b.lambda > 0.0)
2300 .collect();
2301 if rho.len() != active_penalty_blocks.len() {
2302 crate::bail_invalid_estim!(
2303 "survival LAML rho dimension {} does not match active penalty block count {}",
2304 rho.len(),
2305 active_penalty_blocks.len()
2306 );
2307 }
2308 let k_count = active_penalty_blocks.len();
2309
2310 let h_dense = state.hessian.to_dense();
2312 let hop = DenseSpectralOperator::from_symmetric(&h_dense)
2313 .map_err(EstimationError::InvalidInput)?;
2314
2315 let block_descs: Vec<PenaltyBlockDesc> = self
2317 .penalties
2318 .blocks
2319 .iter()
2320 .filter(|b| b.lambda > 0.0)
2321 .map(|b| PenaltyBlockDesc {
2322 matrix: &b.matrix,
2323 range_start: b.range.start,
2324 range_end: b.range.end,
2325 })
2326 .collect();
2327 let penalty_coords =
2328 penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;
2329
2330 let per_block_rho: Vec<Array1<f64>> =
2332 rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
2333 let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
2334 .iter()
2335 .map(|b| vec![b.matrix.clone()])
2336 .collect();
2337 let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
2338 .iter()
2339 .map(|v| v.as_slice())
2340 .collect();
2341 let penalty_logdet = if k_count > 0 {
2342 compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
2343 .map_err(EstimationError::InvalidInput)?
2344 } else {
2345 PenaltyLogdetDerivs {
2346 value: 0.0,
2347 first: Array1::zeros(0),
2348 second: Some(Array2::zeros((0, 0))),
2349 }
2350 };
2351
2352 let penalty_quadratic = 2.0 * state.penalty_term;
2354 let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());
2355
2356 const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
2370 let kkt_residual = {
2371 let raw = state.gradient.clone();
2372 let projected = match self.monotonicity_linear_constraints() {
2373 Some(constraints) => {
2374 projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
2375 .ok_or_else(|| {
2376 EstimationError::InvalidInput(
2377 "survival LAML could not project the monotonicity KKT residual"
2378 .to_string(),
2379 )
2380 })?
2381 }
2382 None => raw,
2383 };
2384 let projected_norm = array1_l2_norm(&projected);
2385 let relative_projected_norm = state.relative_gradient_norm(projected_norm);
2386 if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
2387 Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
2388 } else {
2389 None
2390 }
2391 };
2392
2393 let result = InnerAssembly {
2394 log_likelihood: state.log_likelihood,
2395 penalty_quadratic,
2396 beta: beta.clone(),
2397 n_observations: self.nrows(),
2398 hessian_op: std::sync::Arc::new(hop),
2399 penalty_coords,
2400 penalty_logdet,
2401 dispersion: DispersionHandling::Fixed {
2402 phi: 1.0,
2403 include_logdet_h: true,
2404 include_logdet_s: true,
2405 },
2406 rho_curvature_scale: 1.0,
2407 rho_prior: gam_problem::RhoPrior::Flat,
2408 hessian_logdet_correction: 0.0,
2409 penalty_subspace_trace: None,
2410 deriv_provider: Some(Box::new(provider)),
2411 firth: None,
2412 nullspace_dim: None,
2413 barrier_config: None,
2414 ext_coords: Vec::new(),
2415 ext_coord_pair_fn: None,
2416 rho_ext_pair_fn: None,
2417 fixed_drift_deriv: None,
2418 contracted_psi_second_order: None,
2419 kkt_residual,
2420 active_constraints: None,
2421 }
2422 .evaluate(
2423 rho.as_slice().expect("rho must be contiguous"),
2424 EvalMode::ValueAndGradient,
2425 None,
2426 )
2427 .map_err(EstimationError::InvalidInput)?;
2428
2429 let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
2430 Ok((result.cost, gradient))
2431 }
2432
2433 pub fn evaluate_survival_lamlcost_and_gradient(
2454 &self,
2455 rho: &[f64],
2456 beta0: &Array1<f64>,
2457 ) -> Result<(f64, Array1<f64>), EstimationError> {
2458 let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2459 let rho_arr = Array1::from_vec(rho.to_vec());
2464 let state = candidate.update_state(&beta)?;
2465 candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2466 }
2467
 /// Re-solves the penalized inner problem at smoothing parameters `rho` and returns the
 /// re-fitted model clone together with its coefficient vector.
 ///
 /// Steps:
 /// 1. Check that `rho` has one entry per currently active penalty block (`lambda > 0`) and
 ///    that `beta0` has `coefficient_dim()` entries.
 /// 2. Clone `self` and set every active block's lambda to `exp(rho[k])`, in block order.
 ///    Inactive blocks keep their current lambda.
 /// 3. Run the working-model PIRLS solver from `beta0` with tight tolerances and no
 ///    coefficient bounds or linear constraints.
 /// 4. Polish the PIRLS result with up to `POLISH_MAX_ITERS` damped Newton steps on
 ///    `-log_likelihood + penalty_term`. Polishing stops when the gradient norm drops below
 ///    `POLISH_TOL`, when the state cannot be evaluated, or when no trial step is accepted.
 ///
 /// NOTE(review): if some `rho[k]` is negative enough for `exp` to underflow to 0, that block
 /// becomes inactive in the returned model. Later code that filters on `lambda > 0` (e.g. the
 /// LAML evaluation) would then see one fewer active block. Confirm whether callers bound `rho`.
 fn reconverge_survival_inner_mode(
 &self,
 rho: &[f64],
 beta0: &Array1<f64>,
 ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
 // PIRLS settings for the inner refit (tighter than a typical fit).
 const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
 const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
 const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
 const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;

 let active_block_count = self
 .penalties
 .blocks
 .iter()
 .filter(|b| b.lambda > 0.0)
 .count();
 if rho.len() != active_block_count {
 crate::bail_invalid_estim!(
 "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
 rho.len(),
 active_block_count
 );
 }
 if beta0.len() != self.coefficient_dim() {
 crate::bail_invalid_estim!(
 "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
 beta0.len(),
 self.coefficient_dim()
 );
 }

 let mut candidate = self.clone();
 // Start from the current lambdas, then overwrite only the active blocks with exp(rho).
 let mut lambdas: Vec<f64> = candidate
 .penalties
 .blocks
 .iter()
 .map(|b| b.lambda)
 .collect();
 let mut active_idx = 0usize;
 for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
 if block.lambda > 0.0 {
 *lambda = rho[active_idx].exp();
 active_idx += 1;
 }
 }
 candidate.set_penalty_lambdas(&lambdas)?;

 let opts = gam_solve::pirls::WorkingModelPirlsOptions {
 max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
 convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
 adaptive_kkt_tolerance: None,
 max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
 min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
 firth_bias_reduction: false,
 coefficient_lower_bounds: None,
 linear_constraints: None,
 initial_lm_lambda: None,
 geodesic_acceleration: false,
 arrow_schur: None,
 };
 let summary = gam_solve::pirls::runworking_model_pirls(
 &mut candidate,
 Coefficients::new(beta0.clone()),
 &opts,
 |_| {},
 )?;
 let mut beta = summary.beta.as_ref().to_owned();

 // Newton polish of the PIRLS result on the penalized objective.
 {
 const POLISH_MAX_ITERS: usize = 400;
 const POLISH_TOL: f64 = 1e-13;
 const ARMIJO_C: f64 = 1e-4;
 const BACKTRACK: f64 = 0.5;
 const MAX_BACKTRACK: usize = 80;
 let p = beta.len();
 let penalized_objective =
 |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
 for _ in 0..POLISH_MAX_ITERS {
 // Stop polishing (keeping the last accepted beta) if the state cannot be evaluated.
 let st = match candidate.update_state(&beta) {
 Ok(st) => st,
 Err(_) => break,
 };
 let r = st.gradient.clone();
 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
 if !r_norm.is_finite() || r_norm < POLISH_TOL {
 break;
 }
 let h = st.hessian.to_dense();
 let f0 = penalized_objective(&st);
 // Scale for the Levenberg-Marquardt ridge: largest |diagonal| of H, at least 1.
 let h_scale = (0..p)
 .map(|d| h[[d, d]].abs())
 .fold(0.0_f64, f64::max)
 .max(1.0);
 let mut step: Option<Array1<f64>> = None;
 let mut dir_deriv = 0.0_f64;
 // Try an undamped Cholesky solve first, then ridges growing tenfold from
 // 1e-11*h_scale, until H + λI factors and gives a strict descent direction.
 for lm_pow in 0..18 {
 let lambda_lm = if lm_pow == 0 {
 0.0
 } else {
 1e-12 * h_scale * 10f64.powi(lm_pow)
 };
 let mut h_reg = h.clone();
 for d in 0..p {
 h_reg[[d, d]] += lambda_lm;
 }
 let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
 &h_reg,
 faer::Side::Lower,
 ) {
 Ok(f) => f,
 Err(_) => continue,
 };
 let candidate_step = factor.solvevec(&r);
 if candidate_step.iter().any(|v| !v.is_finite()) {
 continue;
 }
 // Directional derivative of the objective along -candidate_step.
 let dd = -r.dot(&candidate_step);
 if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
 step = Some(candidate_step);
 dir_deriv = dd;
 break;
 }
 }
 // No usable Newton/LM direction: fall back to steepest descent.
 let (step, dir_deriv) = match step {
 Some(s) => (s, dir_deriv),
 None => {
 (r.clone(), -r_norm * r_norm)
 }
 };
 let mut alpha = 1.0_f64;
 let mut accepted = false;
 // Backtracking: accept a trial that satisfies Armijo sufficient decrease OR simply
 // reduces the gradient norm.
 for _ in 0..MAX_BACKTRACK {
 let trial = &beta - &(alpha * &step);
 if let Ok(ts) = candidate.update_state(&trial) {
 let ft = penalized_objective(&ts);
 let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
 let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
 let residual_ok = tn.is_finite() && tn < r_norm;
 if armijo_ok || residual_ok {
 beta = trial;
 accepted = true;
 break;
 }
 }
 alpha *= BACKTRACK;
 }
 if !accepted {
 break;
 }
 }
 }

 Ok((candidate, beta))
 }
2699}
2700
2701pub(crate) struct SurvivalDerivProvider {
2710 model: WorkingModelSurvival,
2711 beta: Array1<f64>,
2712}
2713
2714impl SurvivalDerivProvider {
2715 pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2716 Self { model, beta }
2717 }
2718}
2719
2720impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2721 fn hessian_derivative_correction(
2722 &self,
2723 v_k: &Array1<f64>,
2724 ) -> Result<Option<Array2<f64>>, String> {
2725 let u_k = -v_k;
2728 match self
2729 .model
2730 .survival_hessian_derivative_correction(&self.beta, &u_k)
2731 {
2732 Ok(correction) => Ok(Some(correction)),
2733 Err(e) => Err(e.to_string()),
2734 }
2735 }
2736
2737 fn has_corrections(&self) -> bool {
2738 true
2739 }
2740}
2741
/// Output of `calculate_crude_risk_quadrature`: the quadrature estimate of the crude disease
/// risk over `(t0, t1]` and its gradients.
#[derive(Debug, Clone)]
pub struct CrudeRiskResult {
 /// Crude risk; `0.0` when `t1 <= t0`.
 pub risk: f64,
 /// Gradient of `risk` w.r.t. the coefficients paired with the disease design
 /// (length `design_d_t0.len()`).
 pub diseasegradient: Array1<f64>,
 /// Gradient of `risk` w.r.t. the coefficients paired with the mortality design
 /// (length `design_m_t0.len()`).
 pub mortalitygradient: Array1<f64>,
}
2748
/// Output of the competing-risks CIF assembly.
#[derive(Debug, Clone)]
pub struct CompetingRisksCifResult {
 /// One `(n_rows, n_times)` matrix per endpoint, holding that endpoint's cumulative
 /// incidence at each grid time, clamped to `[0, 1]`.
 pub cif: Vec<Array2<f64>>,
 /// `(n_rows, n_times)` overall survival, computed as `1 − Σ_k cif[k]` and clamped to `[0, 1]`.
 pub overall_survival: Array2<f64>,
}
2758
/// Minimum row count at which `assemble_competing_risks_cif_from_endpoints` assembles rows
/// in parallel with rayon. It does so only when not already running on a rayon worker thread.
const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2764
2765pub fn assemble_competing_risks_cif(
2766 times: ArrayView1<'_, f64>,
2767 cumulative_hazard: ArrayView3<'_, f64>,
2768) -> Result<CompetingRisksCifResult, SurvivalError> {
2769 let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2770 if n_endpoints == 0 {
2771 return Err(SurvivalError::DimensionMismatch);
2772 }
2773 let endpoint_hazards = cumulative_hazard
2774 .axis_iter(Axis(0))
2775 .map(|view| view.to_owned())
2776 .collect::<Vec<_>>();
2777 assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2778 if result.overall_survival.dim() != (n_rows, n_times) {
2779 Err(SurvivalError::DimensionMismatch)
2780 } else {
2781 Ok(result)
2782 }
2783 })
2784}
2785
2786pub fn assemble_competing_risks_cif_from_endpoints(
2787 times: ArrayView1<'_, f64>,
2788 cumulative_hazards: &[Array2<f64>],
2789) -> Result<CompetingRisksCifResult, SurvivalError> {
2790 let n_endpoints = cumulative_hazards.len();
2791 if n_endpoints == 0 || times.is_empty() {
2792 return Err(SurvivalError::DimensionMismatch);
2793 }
2794 let (n_rows, n_times) = cumulative_hazards[0].dim();
2795 if n_rows == 0 || n_times == 0 || times.len() != n_times {
2796 return Err(SurvivalError::DimensionMismatch);
2797 }
2798 if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2799 return Err(SurvivalError::InvalidTimeGrid);
2800 }
2801 if times
2802 .iter()
2803 .zip(times.iter().skip(1))
2804 .any(|(previous, current)| current <= previous)
2805 {
2806 return Err(SurvivalError::InvalidTimeGrid);
2807 }
2808 for endpoint_hazard in cumulative_hazards {
2809 if endpoint_hazard.dim() != (n_rows, n_times) {
2810 return Err(SurvivalError::DimensionMismatch);
2811 }
2812 if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2813 return Err(SurvivalError::NonFiniteInput);
2814 }
2815 }
2816
2817 let max_abs_hazard = cumulative_hazards
2818 .iter()
2819 .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2820 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2821 let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2822 let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2823 .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2824 .collect();
2825 let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2826
2827 let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2839 let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2840 let mut surv_row = vec![0.0_f64; n_times];
2841 let mut previous_cif = vec![0.0_f64; n_endpoints];
2842 let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2843 let mut increments = vec![0.0_f64; n_endpoints];
2844 let mut previous_total_cumulative = 0.0_f64;
2845 for time_idx in 0..n_times {
2846 let mut total_increment = 0.0_f64;
2847 for endpoint in 0..n_endpoints {
2848 let current = cumulative_hazards[endpoint][[row, time_idx]];
2849 if current < -monotone_tolerance {
2850 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2851 }
2852 let raw_increment = current - previous_cumulative[endpoint];
2853 if raw_increment < -monotone_tolerance {
2854 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2855 }
2856 let increment = raw_increment.max(0.0);
2857 increments[endpoint] = increment;
2858 total_increment += increment;
2859 previous_cumulative[endpoint] += increment;
2860 }
2861
2862 let survival_left = (-previous_total_cumulative).exp();
2863 let interval_failure = -(-total_increment).exp_m1();
2864 for endpoint in 0..n_endpoints {
2865 if total_increment > 0.0 {
2866 previous_cif[endpoint] +=
2867 survival_left * interval_failure * increments[endpoint] / total_increment;
2868 }
2869 cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2870 }
2871 previous_total_cumulative += total_increment;
2872 let mut fsum_at_t = 0.0_f64;
2889 for endpoint in 0..n_endpoints {
2890 fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2891 }
2892 surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2893 }
2894 Ok((cif_flat, surv_row))
2895 };
2896
2897 let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2901 && rayon::current_thread_index().is_none()
2902 {
2903 use rayon::prelude::*;
2904 (0..n_rows)
2905 .into_par_iter()
2906 .map(assemble_row)
2907 .collect::<Result<_, _>>()?
2908 } else {
2909 (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2910 };
2911
2912 for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2913 for endpoint in 0..n_endpoints {
2914 for time_idx in 0..n_times {
2915 cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2916 }
2917 }
2918 for time_idx in 0..n_times {
2919 overall_survival[[row, time_idx]] = surv_row[time_idx];
2920 }
2921 }
2922
2923 Ok(CompetingRisksCifResult {
2924 cif,
2925 overall_survival,
2926 })
2927}
2928
/// Nodes and weights of the `n`-point Gauss–Legendre rule on `[-1, 1]`, sorted by node.
///
/// Each root of `P_n` in one half of the interval is found by Newton iteration (at most 100
/// steps, stopping once a step is below `1e-14`). The iteration starts from
/// `cos(π(i + 3/4)/(n + 1/2))`, and the mirrored pair `(−x, x)` is emitted. For odd `n` the
/// centre node is emitted once at exactly `0`. `n == 0` yields an empty rule.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    let order = n as f64;
    let half = n.div_ceil(2);
    let has_center = n % 2 == 1;

    // Returns (P_n(x), P_{n-1}(x)) via the three-term Legendre recurrence.
    let legendre_pair = |x: f64| -> (f64, f64) {
        let (mut current, mut previous) = (1.0_f64, 0.0_f64);
        for j in 0..n {
            let older = previous;
            previous = current;
            current =
                ((2.0 * j as f64 + 1.0) * x * previous - j as f64 * older) / (j as f64 + 1.0);
        }
        (current, previous)
    };

    let mut nodes = Vec::with_capacity(n);
    for i in 0..half {
        let mut root = (std::f64::consts::PI * (i as f64 + 0.75) / (order + 0.5)).cos();
        // P_n'(root), from the last Newton iteration; reused for the weight.
        let mut slope = 0.0;
        for _ in 0..100 {
            let (p_n, p_n_minus_1) = legendre_pair(root);
            slope = order * (root * p_n - p_n_minus_1) / (root * root - 1.0);
            let next = root - p_n / slope;
            let step = (next - root).abs();
            root = next;
            if step < 1e-14 {
                break;
            }
        }

        let weight = 2.0 / ((1.0 - root * root) * slope * slope);
        if has_center && i + 1 == half {
            nodes.push((0.0, weight));
        } else {
            nodes.push((-root, weight));
            nodes.push((root, weight));
        }
    }

    nodes.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    nodes
}

/// Process-wide 40-point Gauss–Legendre rule, computed once on first use.
fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
    static RULE: LazyLock<Vec<(f64, f64)>> =
        LazyLock::new(|| compute_gauss_legendre_nodes(40));
    RULE.as_slice()
}
2976
2977pub fn calculate_crude_risk_quadrature<F>(
3001 t0: f64,
3002 t1: f64,
3003 breakpoints: &[f64],
3004 h_dis_t0: f64,
3005 h_mor_t0: f64,
3006 design_d_t0: ArrayView1<'_, f64>,
3007 design_m_t0: ArrayView1<'_, f64>,
3008 mut eval_at: F,
3009) -> Result<CrudeRiskResult, SurvivalError>
3010where
3011 F: FnMut(
3012 f64,
3013 &mut Array1<f64>,
3014 &mut Array1<f64>,
3015 &mut Array1<f64>,
3016 ) -> Result<(f64, f64, f64), SurvivalError>,
3017{
3018 let coeff_len_d = design_d_t0.len();
3019 let coeff_len_m = design_m_t0.len();
3020 if coeff_len_d == 0 || coeff_len_m == 0 {
3021 return Err(SurvivalError::InvalidIntegrationSetup);
3022 }
3023 if !t0.is_finite()
3024 || !t1.is_finite()
3025 || !h_dis_t0.is_finite()
3026 || !h_mor_t0.is_finite()
3027 || design_d_t0.iter().any(|v| !v.is_finite())
3028 || design_m_t0.iter().any(|v| !v.is_finite())
3029 {
3030 return Err(SurvivalError::NonFiniteInput);
3031 }
3032 if t1 <= t0 {
3033 return Ok(CrudeRiskResult {
3034 risk: 0.0,
3035 diseasegradient: Array1::zeros(coeff_len_d),
3036 mortalitygradient: Array1::zeros(coeff_len_m),
3037 });
3038 }
3039
3040 let mut sorted_breaks: Vec<f64> = breakpoints
3041 .iter()
3042 .copied()
3043 .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3044 .collect();
3045 sorted_breaks.push(t0);
3046 sorted_breaks.push(t1);
3047 sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3048 sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3049 if sorted_breaks.len() < 2 {
3050 return Err(SurvivalError::InvalidIntegrationSetup);
3051 }
3052
3053 let mut total_risk = 0.0;
3054 let mut diseasegradient = Array1::zeros(coeff_len_d);
3055 let mut mortalitygradient = Array1::zeros(coeff_len_m);
3056 let nodesweights = gauss_legendre_quadrature();
3057
3058 let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3059 let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3060 let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3061
3062 for segment in sorted_breaks.windows(2) {
3063 let a = segment[0];
3064 let b = segment[1];
3065 let center = 0.5 * (b + a);
3066 let halfwidth = 0.5 * (b - a);
3067 if halfwidth <= 0.0 {
3068 continue;
3069 }
3070
3071 for &(x, w) in nodesweights {
3072 let u = center + halfwidth * x;
3073 let (inst_hazard_d, hazard_d, hazard_m) =
3074 eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3075 if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3076 return Err(SurvivalError::NonFiniteInput);
3077 }
3078 if inst_hazard_d <= 0.0 {
3079 return Err(SurvivalError::NonPositiveHazard);
3080 }
3081
3082 if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3083 return Err(SurvivalError::NonMonotoneCumulativeHazard);
3084 }
3085
3086 let h_dis_cond = hazard_d - h_dis_t0;
3087 let h_mor_cond = hazard_m - h_mor_t0;
3088 let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3089
3090 total_risk += w * inst_hazard_d * s_total * halfwidth;
3091
3092 let weight = w * s_total * halfwidth;
3098 for j in 0..coeff_len_d {
3099 let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3100 let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3101 let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3102 diseasegradient[j] += weight * g;
3103 }
3104
3105 let weight = w * inst_hazard_d * s_total * halfwidth;
3108 for j in 0..coeff_len_m {
3109 let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3110 mortalitygradient[j] += weight * g;
3111 }
3112 }
3113 }
3114
3115 Ok(CrudeRiskResult {
3116 risk: total_risk,
3117 diseasegradient,
3118 mortalitygradient,
3119 })
3120}
3121
3122impl PirlsWorkingModel for WorkingModelSurvival {
3123 fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
3124 self.update_state(beta)
3125 }
3126}
3127
3128#[cfg(test)]
3129mod tests {
3130 use super::*;
3131 use ndarray::{Array1, Array2, Array3, array, s};
3132
3133 #[test]
3134 fn competing_risks_cif_constant_hazard_matches_closed_form() {
3135 let times = array![0.0, 2.0, 5.0, 10.0];
3136 let disease_rates = [0.12, 0.06];
3137 let death_rates = [0.05, 0.02];
3138 let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
3139 let rate = if endpoint == 0 {
3140 disease_rates[row]
3141 } else {
3142 death_rates[row]
3143 };
3144 rate * times[time_idx]
3145 });
3146
3147 let result =
3148 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3149
3150 for row in 0..2 {
3151 let total_rate = disease_rates[row] + death_rates[row];
3152 for time_idx in 0..times.len() {
3153 let failure = 1.0 - (-total_rate * times[time_idx]).exp();
3154 let expected_disease = disease_rates[row] / total_rate * failure;
3155 let expected_death = death_rates[row] / total_rate * failure;
3156 assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
3157 assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
3158 assert!(
3159 (result.cif[0][[row, time_idx]]
3160 + result.cif[1][[row, time_idx]]
3161 + result.overall_survival[[row, time_idx]]
3162 - 1.0)
3163 .abs()
3164 < 1e-12
3165 );
3166 }
3167 }
3168 }
3169
3170 #[test]
3171 fn competing_risks_cif_rejects_nonmonotone_hazards() {
3172 let times = array![0.0, 1.0, 2.0];
3173 let cumulative = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");
3174 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3175 .expect_err("nonmonotone cumulative hazard should be rejected");
3176 assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
3177 }
3178
3179 #[test]
3180 fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
3181 let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
3182 let cumulative = Array3::from_shape_vec(
3183 (3, 2, 5),
3184 vec![
3185 0.0, 0.2, 0.2, 0.5, 1.1, 0.0, 0.0, 0.4, 0.4, 0.9, 0.0, 0.1, 0.3, 0.3, 0.7, 0.0, 0.2, 0.2, 0.8, 0.8, 0.0, 0.0, 0.2, 0.6, 0.6, 0.0, 0.1, 0.5, 0.5, 1.5,
3189 ],
3190 )
3191 .expect("shape");
3192
3193 let result =
3194 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3195
3196 for row in 0..2 {
3197 for time_idx in 0..times.len() {
3198 let total_cif = result.cif[0][[row, time_idx]]
3199 + result.cif[1][[row, time_idx]]
3200 + result.cif[2][[row, time_idx]];
3201 assert!(
3202 (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
3203 "probability mass mismatch at row={row}, time_idx={time_idx}"
3204 );
3205 assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));
3206 for cause in 0..3 {
3207 assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
3208 if time_idx > 0 {
3209 assert!(
3210 result.cif[cause][[row, time_idx]] + 1e-12
3211 >= result.cif[cause][[row, time_idx - 1]],
3212 "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
3213 );
3214 }
3215 }
3216 }
3217 }
3218
3219 assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
3222 assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
3225 assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
3226 }
3227
3228 #[test]
3229 fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
3230 let cumulative = Array3::zeros((2, 1, 2));
3231
3232 for times in [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]] {
3233 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3234 .expect_err("bad time grid should be rejected");
3235 assert!(matches!(err, SurvivalError::InvalidTimeGrid));
3236 }
3237
3238 let times = array![0.0, 1.0];
3239 let nonfinite = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
3240 let err = assemble_competing_risks_cif(times.view(), nonfinite.view())
3241 .expect_err("nonfinite hazard should be rejected");
3242 assert!(matches!(err, SurvivalError::NonFiniteInput));
3243 }
3244
3245 #[test]
3246 fn competing_risks_cif_extreme_hazards_remain_bounded() {
3247 let times = array![0.0, 1.0, 2.0];
3248 let cumulative =
3249 Array3::from_shape_vec((2, 1, 3), vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0])
3250 .expect("shape");
3251
3252 let result =
3253 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3254
3255 for value in result
3256 .cif
3257 .iter()
3258 .flat_map(|m| m.iter())
3259 .chain(result.overall_survival.iter())
3260 {
3261 assert!(value.is_finite());
3262 assert!((0.0..=1.0).contains(value));
3263 }
3264 assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
3265 assert_eq!(result.overall_survival[[0, 2]], 0.0);
3266 }
3267
3268 fn toy_penalties() -> PenaltyBlocks {
3269 let s = array![[2.0, 0.5], [0.5, 3.0]];
3270 PenaltyBlocks::new(vec![PenaltyBlock {
3271 matrix: s,
3272 lambda: 1.7,
3273 range: 1..3,
3274 nullspace_dim: 0,
3275 }])
3276 }
3277
3278 fn survival_inputs<'a>(
3279 age_entry: &'a Array1<f64>,
3280 age_exit: &'a Array1<f64>,
3281 event_target: &'a Array1<u8>,
3282 event_competing: &'a Array1<u8>,
3283 sampleweight: &'a Array1<f64>,
3284 x_entry: &'a Array2<f64>,
3285 x_exit: &'a Array2<f64>,
3286 x_derivative: &'a Array2<f64>,
3287 ) -> SurvivalEngineInputs<'a> {
3288 SurvivalEngineInputs {
3289 age_entry: age_entry.view(),
3290 age_exit: age_exit.view(),
3291 event_target: event_target.view(),
3292 event_competing: event_competing.view(),
3293 sampleweight: sampleweight.view(),
3294 x_entry: x_entry.view(),
3295 x_exit: x_exit.view(),
3296 x_derivative: x_derivative.view(),
3297 monotonicity_constraint_rows: None,
3298 monotonicity_constraint_offsets: None,
3299 }
3300 }
3301
3302 fn survival_model(
3303 inputs: SurvivalEngineInputs<'_>,
3304 penalties: PenaltyBlocks,
3305 monotonicity: SurvivalMonotonicityPenalty,
3306 spec: SurvivalSpec,
3307 ) -> Result<WorkingModelSurvival, SurvivalError> {
3308 WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
3309 }
3310
3311 fn survival_model_with_offsets(
3312 inputs: SurvivalEngineInputs<'_>,
3313 offsets: Option<SurvivalBaselineOffsets<'_>>,
3314 penalties: PenaltyBlocks,
3315 monotonicity: SurvivalMonotonicityPenalty,
3316 spec: SurvivalSpec,
3317 ) -> Result<WorkingModelSurvival, SurvivalError> {
3318 WorkingModelSurvival::from_engine_inputswith_offsets(
3319 inputs,
3320 offsets,
3321 penalties,
3322 monotonicity,
3323 spec,
3324 )
3325 }
3326
3327 #[test]
3328 fn penaltyhessian_matchesgradient_jacobian() {
3329 let penalties = toy_penalties();
3330 let beta = array![10.0, -0.3, 1.2, 7.0];
3331
3332 let grad = penalties.gradient(&beta);
3333 let h = penalties.hessian(beta.len());
3334 let b_block = beta.slice(s![1..3]).to_owned();
3335 let expected = 1.7 * array![[2.0, 0.5], [0.5, 3.0]].dot(&b_block);
3336
3337 assert!((grad[1] - expected[0]).abs() < 1e-12);
3338 assert!((grad[2] - expected[1]).abs() < 1e-12);
3339 assert!((h[[1, 1]] - 1.7 * 2.0).abs() < 1e-12);
3340 assert!((h[[1, 2]] - 1.7 * 0.5).abs() < 1e-12);
3341 assert!((h[[2, 1]] - 1.7 * 0.5).abs() < 1e-12);
3342 assert!((h[[2, 2]] - 1.7 * 3.0).abs() < 1e-12);
3343 }
3344
3345 #[test]
3346 fn penaltygradient_matches_deviance_finite_difference() {
3347 let penalties = toy_penalties();
3348 let beta = array![10.0, -0.3, 1.2, 7.0];
3349 let grad = penalties.gradient(&beta);
3350 let eps = 1e-7;
3351
3352 for idx in 0..beta.len() {
3353 let mut plus = beta.clone();
3354 let mut minus = beta.clone();
3355 plus[idx] += eps;
3356 minus[idx] -= eps;
3357 let fd = (penalties.deviance(&plus) - penalties.deviance(&minus)) / (2.0 * eps);
3358 assert_eq!(
3359 grad[idx].signum(),
3360 fd.signum(),
3361 "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
3362 grad[idx]
3363 );
3364 assert!(
3365 (grad[idx] - fd).abs() < 1e-6,
3366 "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
3367 grad[idx]
3368 );
3369 }
3370 }
3371
3372 #[test]
3373 fn zero_offsets_match_default_survival_state() {
3374 let age_entry = array![1.0_f64, 2.0_f64];
3375 let age_exit = array![2.0_f64, 3.5_f64];
3376 let event_target = array![1u8, 0u8];
3377 let event_competing = array![0u8, 0u8];
3378 let sampleweight = array![1.0, 1.0];
3379 let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3380 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3381 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3382 let penalties = PenaltyBlocks::new(Vec::new());
3383 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3384 let beta = array![-1.0, 0.8];
3385
3386 let base = survival_model(
3387 survival_inputs(
3388 &age_entry,
3389 &age_exit,
3390 &event_target,
3391 &event_competing,
3392 &sampleweight,
3393 &x_entry,
3394 &x_exit,
3395 &x_derivative,
3396 ),
3397 penalties.clone(),
3398 mono,
3399 SurvivalSpec::Net,
3400 )
3401 .expect("construct base survival model");
3402
3403 let zero_offsets = survival_model_with_offsets(
3404 survival_inputs(
3405 &age_entry,
3406 &age_exit,
3407 &event_target,
3408 &event_competing,
3409 &sampleweight,
3410 &x_entry,
3411 &x_exit,
3412 &x_derivative,
3413 ),
3414 Some(SurvivalBaselineOffsets {
3415 eta_entry: array![0.0, 0.0].view(),
3416 eta_exit: array![0.0, 0.0].view(),
3417 derivative_exit: array![0.0, 0.0].view(),
3418 }),
3419 penalties,
3420 mono,
3421 SurvivalSpec::Net,
3422 )
3423 .expect("construct offset survival model");
3424
3425 let state_base = base.update_state(&beta).expect("base state");
3426 let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
3427 assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
3428 assert!(
3429 state_base
3430 .gradient
3431 .iter()
3432 .zip(statezero.gradient.iter())
3433 .all(|(a, b)| (a - b).abs() < 1e-12)
3434 );
3435 }
3436
3437 #[test]
3438 fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
3439 let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
3453 let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
3454 let cause_labels = array![0u8, 1u8, 2u8, 0u8];
3456 let event_competing = Array1::<u8>::zeros(cause_labels.len());
3457 let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];
3458 let x_entry = array![
3459 [1.0, age_entry[0].max(1e-8).ln()],
3460 [1.0, age_entry[1].max(1e-8).ln()],
3461 [1.0, age_entry[2].max(1e-8).ln()],
3462 [1.0, age_entry[3].max(1e-8).ln()],
3463 ];
3464 let x_exit = array![
3465 [1.0, age_exit[0].ln()],
3466 [1.0, age_exit[1].ln()],
3467 [1.0, age_exit[2].ln()],
3468 [1.0, age_exit[3].ln()],
3469 ];
3470 let x_derivative = array![
3471 [0.0, 1.0 / age_exit[0]],
3472 [0.0, 1.0 / age_exit[1]],
3473 [0.0, 1.0 / age_exit[2]],
3474 [0.0, 1.0 / age_exit[3]],
3475 ];
3476 let penalties = PenaltyBlocks::new(Vec::new());
3477 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3478
3479 let raw = survival_model(
3484 survival_inputs(
3485 &age_entry,
3486 &age_exit,
3487 &cause_labels,
3488 &event_competing,
3489 &sampleweight,
3490 &x_entry,
3491 &x_exit,
3492 &x_derivative,
3493 ),
3494 penalties.clone(),
3495 mono,
3496 SurvivalSpec::Net,
3497 );
3498 assert!(
3499 matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
3500 "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
3501 );
3502
3503 let any_event = pooled_any_event_indicator(cause_labels.view());
3506 assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
3507 assert_eq!(
3509 cause_specific_event_indicator(cause_labels.view(), 1),
3510 array![0u8, 1u8, 0u8, 0u8]
3511 );
3512 assert_eq!(
3513 cause_specific_event_indicator(cause_labels.view(), 2),
3514 array![0u8, 0u8, 1u8, 0u8]
3515 );
3516 let model = survival_model(
3517 survival_inputs(
3518 &age_entry,
3519 &age_exit,
3520 &any_event,
3521 &event_competing,
3522 &sampleweight,
3523 &x_entry,
3524 &x_exit,
3525 &x_derivative,
3526 ),
3527 penalties,
3528 mono,
3529 SurvivalSpec::Net,
3530 )
3531 .expect("pooled any-event baseline model must construct from competing-risks data");
3532
3533 let beta = array![-1.0_f64, 0.8];
3536 let state = model.update_state(&beta).expect("pooled baseline state");
3537 assert!(
3538 state.deviance.is_finite(),
3539 "pooled baseline deviance must be finite, got {}",
3540 state.deviance
3541 );
3542 assert!(
3543 state.gradient.iter().all(|g| g.is_finite()),
3544 "pooled baseline gradient must be finite"
3545 );
3546 }
3547
3548 #[test]
3549 fn offset_channel_residuals_match_central_fd_of_nll() {
3550 let age_entry = array![0.5_f64, 0.0, 0.3];
3555 let age_exit = array![1.4_f64, 1.0, 2.0];
3556 let event_target = array![1u8, 1u8, 0u8];
3557 let event_competing = array![0u8, 0u8, 0u8];
3558 let sampleweight = array![1.0_f64, 2.5, 0.7];
3559 let x_entry = array![
3560 [1.0, age_entry[0].ln()],
3561 [1.0, age_entry[1].max(1e-8).ln()],
3562 [1.0, age_entry[2].ln()]
3563 ];
3564 let x_exit = array![
3565 [1.0, age_exit[0].ln()],
3566 [1.0, age_exit[1].ln()],
3567 [1.0, age_exit[2].ln()]
3568 ];
3569 let x_derivative = array![
3570 [0.0, 1.0 / age_exit[0]],
3571 [0.0, 1.0 / age_exit[1]],
3572 [0.0, 1.0 / age_exit[2]]
3573 ];
3574 let o_entry = array![0.2_f64, 0.0, 0.1];
3577 let o_exit = array![0.4_f64, 0.5, 0.7];
3578 let o_deriv = array![0.3_f64, 0.8, 0.5];
3579 let penalties = PenaltyBlocks::new(Vec::new());
3580 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3581 let beta = array![-0.7_f64, 0.6];
3582
3583 let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
3584 survival_model_with_offsets(
3585 survival_inputs(
3586 &age_entry,
3587 &age_exit,
3588 &event_target,
3589 &event_competing,
3590 &sampleweight,
3591 &x_entry,
3592 &x_exit,
3593 &x_derivative,
3594 ),
3595 Some(SurvivalBaselineOffsets {
3596 eta_entry: o_e.view(),
3597 eta_exit: o_x.view(),
3598 derivative_exit: o_d.view(),
3599 }),
3600 penalties.clone(),
3601 mono,
3602 SurvivalSpec::Net,
3603 )
3604 .expect("model build")
3605 };
3606
3607 let base = build(&o_entry, &o_exit, &o_deriv);
3608 let resid = base
3609 .offset_channel_residuals(&beta)
3610 .expect("offset residuals");
3611 assert_eq!(resid.exit.len(), 3);
3612 assert_eq!(resid.entry.len(), 3);
3613 assert_eq!(resid.derivative.len(), 3);
3614
3615 let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
3618 let h = 1e-6;
3619
3620 assert_eq!(resid.entry[1], 0.0);
3624 assert_eq!(resid.derivative[2], 0.0);
3625
3626 for i in 0..3 {
3627 {
3629 let mut op = o_exit.clone();
3630 let mut om = o_exit.clone();
3631 op[i] += h;
3632 om[i] -= h;
3633 let fd = (nll(&build(&o_entry, &op, &o_deriv))
3634 - nll(&build(&o_entry, &om, &o_deriv)))
3635 / (2.0 * h);
3636 assert!(
3637 (resid.exit[i] - fd).abs() < 1e-6,
3638 "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
3639 resid.exit[i],
3640 fd
3641 );
3642 }
3643 {
3647 let mut op = o_entry.clone();
3648 let mut om = o_entry.clone();
3649 op[i] += h;
3650 om[i] -= h;
3651 let fd = (nll(&build(&op, &o_exit, &o_deriv))
3652 - nll(&build(&om, &o_exit, &o_deriv)))
3653 / (2.0 * h);
3654 assert!(
3655 (resid.entry[i] - fd).abs() < 1e-6,
3656 "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
3657 resid.entry[i],
3658 fd
3659 );
3660 }
3661 {
3663 let mut op = o_deriv.clone();
3664 let mut om = o_deriv.clone();
3665 op[i] += h;
3666 om[i] -= h;
3667 let fd = (nll(&build(&o_entry, &o_exit, &op))
3668 - nll(&build(&o_entry, &o_exit, &om)))
3669 / (2.0 * h);
3670 assert!(
3671 (resid.derivative[i] - fd).abs() < 1e-6,
3672 "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
3673 resid.derivative[i],
3674 fd
3675 );
3676 }
3677 }
3678 }
3679
3680 #[test]
3681 fn offset_channel_residuals_respect_zero_sampleweight() {
3682 let age_entry = array![1.0_f64, 2.0];
3683 let age_exit = array![2.0_f64, 3.5];
3684 let event_target = array![1u8, 1u8];
3685 let event_competing = array![0u8, 0u8];
3686 let sampleweight = array![0.0_f64, 1.2]; let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3688 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3689 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3690 let penalties = PenaltyBlocks::new(Vec::new());
3691 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3692 let beta = array![-1.0_f64, 0.8];
3693
3694 let model = survival_model_with_offsets(
3695 survival_inputs(
3696 &age_entry,
3697 &age_exit,
3698 &event_target,
3699 &event_competing,
3700 &sampleweight,
3701 &x_entry,
3702 &x_exit,
3703 &x_derivative,
3704 ),
3705 Some(SurvivalBaselineOffsets {
3706 eta_entry: array![0.0_f64, 0.1].view(),
3707 eta_exit: array![0.0_f64, 0.2].view(),
3708 derivative_exit: array![0.0_f64, 0.1].view(),
3709 }),
3710 penalties,
3711 mono,
3712 SurvivalSpec::Net,
3713 )
3714 .expect("model");
3715 let r = model.offset_channel_residuals(&beta).expect("resid");
3716 assert_eq!(r.exit[0], 0.0);
3718 assert_eq!(r.entry[0], 0.0);
3719 assert_eq!(r.derivative[0], 0.0);
3720 assert!(r.exit[1] != 0.0);
3722 }
3723
3724 #[test]
3725 fn offset_channel_residuals_reject_beta_dim_mismatch() {
3726 let age_entry = array![1.0_f64];
3727 let age_exit = array![2.0_f64];
3728 let event_target = array![1u8];
3729 let event_competing = array![0u8];
3730 let sampleweight = array![1.0_f64];
3731 let x_entry = array![[1.0, 0.0]];
3732 let x_exit = array![[1.0, 0.7]];
3733 let x_derivative = array![[0.0, 0.5]];
3734 let penalties = PenaltyBlocks::new(Vec::new());
3735 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3736 let model = survival_model(
3737 survival_inputs(
3738 &age_entry,
3739 &age_exit,
3740 &event_target,
3741 &event_competing,
3742 &sampleweight,
3743 &x_entry,
3744 &x_exit,
3745 &x_derivative,
3746 ),
3747 penalties,
3748 mono,
3749 SurvivalSpec::Net,
3750 )
3751 .expect("model");
3752 let bad_beta = array![0.0_f64]; let err = model
3754 .offset_channel_residuals(&bad_beta)
3755 .expect_err("mismatch must error");
3756 match err {
3757 EstimationError::InvalidInput(msg) => {
3758 assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
3759 }
3760 other => panic!("expected InvalidInput, got {other:?}"),
3761 }
3762 }
3763
3764 #[test]
3765 fn crudespec_is_rejected_by_one_hazard_engine() {
3766 let age_entry = array![1.0_f64];
3767 let age_exit = array![2.0_f64];
3768 let event_target = array![0u8];
3769 let event_competing = array![1u8];
3770 let sampleweight = array![1.0];
3771 let x_entry = array![[0.1]];
3772 let x_exit = array![[0.4]];
3773 let x_derivative = array![[1.0]];
3774 let penalties = PenaltyBlocks::new(Vec::new());
3775 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3776
3777 let err = survival_model(
3778 survival_inputs(
3779 &age_entry,
3780 &age_exit,
3781 &event_target,
3782 &event_competing,
3783 &sampleweight,
3784 &x_entry,
3785 &x_exit,
3786 &x_derivative,
3787 ),
3788 penalties,
3789 mono,
3790 SurvivalSpec::Crude,
3791 )
3792 .expect_err("crude fitting should be rejected by the one-hazard engine");
3793 assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
3794 }
3795
3796 #[test]
3797 fn nonstructural_models_require_explicit_monotonicity_collocation() {
3798 let age_entry = array![1.0_f64, 1.5_f64];
3799 let age_exit = array![2.0_f64, 2.5_f64];
3800 let event_target = array![0u8, 0u8];
3801 let event_competing = array![0u8, 1u8];
3802 let sampleweight = array![1.0, 1.0];
3803 let x_entry = array![[0.2], [0.1]];
3804 let x_exit = array![[0.3], [0.2]];
3805 let x_derivative = array![[1.0], [1.0]];
3806
3807 let model = survival_model(
3808 survival_inputs(
3809 &age_entry,
3810 &age_exit,
3811 &event_target,
3812 &event_competing,
3813 &sampleweight,
3814 &x_entry,
3815 &x_exit,
3816 &x_derivative,
3817 ),
3818 PenaltyBlocks::new(Vec::new()),
3819 SurvivalMonotonicityPenalty { tolerance: 0.0 },
3820 SurvivalSpec::Net,
3821 )
3822 .expect("construct censored survival model");
3823
3824 assert!(
3825 model.monotonicity_linear_constraints().is_none(),
3826 "non-structural survival models must not fabricate rowwise monotonicity constraints"
3827 );
3828 }
3829
3830 #[test]
3831 fn decreasing_interval_is_rejectedwithout_target_events() {
3832 let age_entry = array![1.0_f64];
3833 let age_exit = array![2.0_f64];
3834 let event_target = array![0u8];
3835 let event_competing = array![0u8];
3836 let sampleweight = array![1.0];
3837 let x_entry = array![[0.5]];
3838 let x_exit = array![[0.0]];
3839 let x_derivative = array![[1.0]];
3840
3841 let model = survival_model(
3842 survival_inputs(
3843 &age_entry,
3844 &age_exit,
3845 &event_target,
3846 &event_competing,
3847 &sampleweight,
3848 &x_entry,
3849 &x_exit,
3850 &x_derivative,
3851 ),
3852 PenaltyBlocks::new(Vec::new()),
3853 SurvivalMonotonicityPenalty { tolerance: 0.0 },
3854 SurvivalSpec::Net,
3855 )
3856 .expect("construct censored survival model");
3857
3858 let err = model
3859 .update_state(&array![1.0])
3860 .expect_err("decreasing cumulative hazard increment should be rejected");
3861 assert!(
3862 err.to_string().contains("cumulative hazard decreased"),
3863 "unexpected error: {err}"
3864 );
3865 }
3866
3867 fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
3868 calculate_crude_risk_quadrature(
3869 0.0,
3870 1.0,
3871 &[0.0, 1.0],
3872 beta_d.exp(),
3873 beta_m.exp(),
3874 array![1.0].view(),
3875 array![1.0].view(),
3876 |u, design_d, deriv_d, design_m| {
3877 let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
3878 let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
3879 let inst_hazard_d = 0.2 * beta_d.exp();
3880 design_d[0] = 1.0;
3881 deriv_d[0] = 0.0;
3884 design_m[0] = 1.0;
3885 Ok((inst_hazard_d, cumulative_d, cumulative_m))
3886 },
3887 )
3888 .expect("smooth crude-risk quadrature should succeed")
3889 }
3890
3891 #[test]
3892 fn crude_riskgradient_matches_monotoneobjective() {
3893 let beta_d = -0.2_f64;
3894 let beta_m = -0.5_f64;
3895 let result = smooth_crude_risk(beta_d, beta_m);
3896 let eps = 1e-6;
3897
3898 let fd_d = (smooth_crude_risk(beta_d + eps, beta_m).risk
3899 - smooth_crude_risk(beta_d - eps, beta_m).risk)
3900 / (2.0 * eps);
3901 let fd_m = (smooth_crude_risk(beta_d, beta_m + eps).risk
3902 - smooth_crude_risk(beta_d, beta_m - eps).risk)
3903 / (2.0 * eps);
3904
3905 assert!(
3906 (result.diseasegradient[0] - fd_d).abs() < 1e-5,
3907 "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
3908 result.diseasegradient[0]
3909 );
3910 assert!(
3911 (result.mortalitygradient[0] - fd_m).abs() < 1e-5,
3912 "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
3913 result.mortalitygradient[0]
3914 );
3915 }
3916
3917 #[test]
3918 fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
3919 let age_entry = array![1.0_f64, 2.0_f64];
3920 let age_exit = array![2.0_f64, 3.5_f64];
3921 let event_target = array![1u8, 0u8];
3922 let event_competing = array![0u8, 0u8];
3923 let sampleweight = array![1.0, 1.0];
3924 let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3925 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3926 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3927 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3928 matrix: array![[2.0]],
3929 lambda: 1.7,
3930 range: 1..2,
3931 nullspace_dim: 0,
3932 }]);
3933 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3934 let beta = array![-1.2, 0.4];
3935
3936 let model = survival_model(
3937 survival_inputs(
3938 &age_entry,
3939 &age_exit,
3940 &event_target,
3941 &event_competing,
3942 &sampleweight,
3943 &x_entry,
3944 &x_exit,
3945 &x_derivative,
3946 ),
3947 penalties.clone(),
3948 mono,
3949 SurvivalSpec::Net,
3950 )
3951 .expect("construct survival model");
3952
3953 let state = model.update_state(&beta).expect("survival state");
3954 let expected_penalty = penalties.deviance(&beta) + 0.5 * state.ridge_used * beta.dot(&beta);
3955 assert!(
3956 (state.penalty_term - expected_penalty).abs() < 1e-12,
3957 "penalty_term mismatch: state={} expected={}",
3958 state.penalty_term,
3959 expected_penalty
3960 );
3961 }
3962
3963 #[test]
3964 fn negative_penalty_lambda_is_rejected() {
3965 let age_entry = array![1.0_f64];
3966 let age_exit = array![2.0_f64];
3967 let event_target = array![1u8];
3968 let event_competing = array![0u8];
3969 let sampleweight = array![1.0];
3970 let x_entry = array![[1.0, 0.0]];
3971 let x_exit = array![[1.0, 0.5]];
3972 let x_derivative = array![[0.0, 1.0]];
3973 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3974 matrix: array![[1.0]],
3975 lambda: -0.1,
3976 range: 1..2,
3977 nullspace_dim: 0,
3978 }]);
3979
3980 let err = survival_model(
3981 survival_inputs(
3982 &age_entry,
3983 &age_exit,
3984 &event_target,
3985 &event_competing,
3986 &sampleweight,
3987 &x_entry,
3988 &x_exit,
3989 &x_derivative,
3990 ),
3991 penalties,
3992 SurvivalMonotonicityPenalty { tolerance: 0.0 },
3993 SurvivalSpec::Net,
3994 )
3995 .expect_err("negative lambda must be rejected");
3996
3997 assert!(matches!(err, SurvivalError::NonFiniteInput));
3998 }
3999
4000 #[test]
4001 fn penalty_block_range_and_shapemust_match_coefficients() {
4002 let age_entry = array![1.0_f64];
4003 let age_exit = array![2.0_f64];
4004 let event_target = array![1u8];
4005 let event_competing = array![0u8];
4006 let sampleweight = array![1.0];
4007 let x_entry = array![[1.0, 0.0]];
4008 let x_exit = array![[1.0, 0.5]];
4009 let x_derivative = array![[0.0, 1.0]];
4010 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4011 matrix: array![[1.0]],
4012 lambda: 0.5,
4013 range: 0..2,
4014 nullspace_dim: 0,
4015 }]);
4016
4017 let err = survival_model(
4018 survival_inputs(
4019 &age_entry,
4020 &age_exit,
4021 &event_target,
4022 &event_competing,
4023 &sampleweight,
4024 &x_entry,
4025 &x_exit,
4026 &x_derivative,
4027 ),
4028 penalties,
4029 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4030 SurvivalSpec::Net,
4031 )
4032 .expect_err("penalty block geometry must match coefficient support");
4033
4034 assert!(matches!(err, SurvivalError::DimensionMismatch));
4035 }
4036
4037 #[test]
4038 fn survivalgradient_matchesobjectivefdwithridge_scaling() {
4039 let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
4040 let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
4041 let event_target = array![1u8, 0u8, 1u8];
4042 let event_competing = array![0u8, 0u8, 0u8];
4043 let sampleweight = array![1.0, 1.0, 1.0];
4044 let x_entry = array![
4045 [1.0, age_entry[0].ln()],
4046 [1.0, age_entry[1].ln()],
4047 [1.0, age_entry[2].ln()]
4048 ];
4049 let x_exit = array![
4050 [1.0, age_exit[0].ln()],
4051 [1.0, age_exit[1].ln()],
4052 [1.0, age_exit[2].ln()]
4053 ];
4054 let x_derivative = array![
4055 [0.0, 1.0 / age_exit[0]],
4056 [0.0, 1.0 / age_exit[1]],
4057 [0.0, 1.0 / age_exit[2]]
4058 ];
4059 let penalties = PenaltyBlocks::new(Vec::new());
4060 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4061 let beta = array![-1.0, 3.0];
4062
4063 let model = survival_model(
4064 survival_inputs(
4065 &age_entry,
4066 &age_exit,
4067 &event_target,
4068 &event_competing,
4069 &sampleweight,
4070 &x_entry,
4071 &x_exit,
4072 &x_derivative,
4073 ),
4074 penalties,
4075 mono,
4076 SurvivalSpec::Net,
4077 )
4078 .expect("construct survival model");
4079
4080 let state = model.update_state(&beta).expect("state at beta");
4081 let eps = 1e-7;
4082 for j in 0..beta.len() {
4083 let mut plus = beta.clone();
4084 let mut minus = beta.clone();
4085 plus[j] += eps;
4086 minus[j] -= eps;
4087 let state_plus = model.update_state(&plus).expect("state at beta + eps");
4088 let state_minus = model.update_state(&minus).expect("state at beta - eps");
4089 let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4090 let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4091 let fd = (obj_plus - obj_minus) / (2.0 * eps);
4092 assert_eq!(
4093 state.gradient[j].signum(),
4094 fd.signum(),
4095 "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4096 state.gradient[j]
4097 );
4098 assert!(
4099 (state.gradient[j] - fd).abs() < 1e-5,
4100 "objective/gradient mismatch at j={j}: grad={} fd={fd}",
4101 state.gradient[j]
4102 );
4103 }
4104 }
4105
4106 fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
4107 let age_entry: Array1<f64> = Array1::from(vec![
4114 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
4115 34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
4116 ]);
4117 let age_exit: Array1<f64> = Array1::from(vec![
4118 45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
4119 48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
4120 ]);
4121 let event_target = Array1::from(vec![
4122 1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
4123 ]);
4124 let event_competing = Array1::<u8>::zeros(age_entry.len());
4125 let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
4126 let n = age_entry.len();
4127 let ln_age_mean: f64 = {
4128 let mut sum = 0.0;
4129 for i in 0..n {
4130 sum += age_entry[i].ln() + age_exit[i].ln();
4131 }
4132 sum / (2.0 * n as f64)
4133 };
4134 let mut x_entry = Array2::<f64>::zeros((n, 2));
4135 let mut x_exit = Array2::<f64>::zeros((n, 2));
4136 let mut x_derivative = Array2::<f64>::zeros((n, 2));
4137 for i in 0..n {
4138 x_entry[[i, 0]] = 1.0;
4139 x_exit[[i, 0]] = 1.0;
4140 x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
4141 x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
4142 x_derivative[[i, 0]] = 0.0;
4143 x_derivative[[i, 1]] = 1.0 / age_exit[i];
4144 }
4145 let penalties = PenaltyBlocks::new(vec![
4146 PenaltyBlock {
4147 matrix: array![[3.0]],
4148 lambda: 0.0,
4149 range: 0..1,
4150 nullspace_dim: 0,
4151 },
4152 PenaltyBlock {
4153 matrix: array![[2.5]],
4154 lambda,
4155 range: 1..2,
4156 nullspace_dim: 0,
4157 },
4158 ]);
4159 survival_model(
4160 survival_inputs(
4161 &age_entry,
4162 &age_exit,
4163 &event_target,
4164 &event_competing,
4165 &sampleweight,
4166 &x_entry,
4167 &x_exit,
4168 &x_derivative,
4169 ),
4170 penalties,
4171 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4172 SurvivalSpec::Net,
4173 )
4174 .expect("construct LAML FD survival model")
4175 }
4176
4177 fn laml_test_logdet_h(state: &WorkingState) -> f64 {
4178 use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4179 use gam_linalg::faer_ndarray::FaerEigh;
4180
4181 let h_dense = state.hessian.to_dense();
4182 let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4183 let eps = spectral_epsilon(evals.as_slice().unwrap());
4184 evals
4185 .iter()
4186 .map(|&sigma| spectral_regularize(sigma, eps).ln())
4187 .sum()
4188 }
4189
4190 #[test]
4191 fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
4192 let rho0 = -0.35_f64;
4206 let beta = array![-2.5_f64, 1.0];
4207 let model = laml_fd_test_model(rho0.exp());
4208 let state = model
4209 .update_state(&beta)
4210 .expect("state for LAML prefix-skip test");
4211
4212 assert_eq!(model.penalties.blocks.len(), 2);
4217 assert_eq!(model.penalties.blocks[0].lambda, 0.0);
4218 assert!(model.penalties.blocks[1].lambda > 0.0);
4219
4220 let rho = Array1::from_iter(
4221 model
4222 .penalties
4223 .blocks
4224 .iter()
4225 .filter(|b| b.lambda > 0.0)
4226 .map(|b| b.lambda.ln()),
4227 );
4228 assert_eq!(
4229 rho.len(),
4230 1,
4231 "fixture should expose exactly one active penalty block for the rho vector"
4232 );
4233
4234 let (obj, grad) = model
4235 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4236 .expect("survival LAML objective and gradient");
4237
4238 let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
4239 - 0.5 * (rho0 + 2.5_f64.ln());
4240 assert_eq!(
4241 grad.len(),
4242 1,
4243 "rho-gradient must match the active-penalty count, not the full block list"
4244 );
4245 assert!(
4246 (obj - expected).abs() < 1e-10,
4247 "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
4248 );
4249 assert!(
4250 grad[0].is_finite(),
4251 "rho-gradient must be finite: {}",
4252 grad[0]
4253 );
4254 }
4255
4256 #[test]
4257 fn structural_monotonicgradient_matchesobjectivefd() {
4258 let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
4259 let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
4260 let event_target = array![1u8, 0u8, 1u8];
4261 let event_competing = array![0u8, 0u8, 0u8];
4262 let sampleweight = array![1.0, 1.0, 1.0];
4263
4264 let x_entry = array![
4267 [1.0, 0.2, 0.05, -0.7],
4268 [1.0, 0.5, 0.20, 0.1],
4269 [1.0, 0.9, 0.60, 1.2]
4270 ];
4271 let x_exit = array![
4272 [1.0, 0.4, 0.16, -0.7],
4273 [1.0, 0.8, 0.64, 0.1],
4274 [1.0, 1.1, 1.21, 1.2]
4275 ];
4276 let x_derivative = array![
4277 [0.0, 0.8, 0.64, 0.0],
4278 [0.0, 0.7, 1.12, 0.0],
4279 [0.0, 0.6, 1.32, 0.0]
4280 ];
4281 let penalties = PenaltyBlocks::new(Vec::new());
4282 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4283 let mut model = survival_model(
4284 survival_inputs(
4285 &age_entry,
4286 &age_exit,
4287 &event_target,
4288 &event_competing,
4289 &sampleweight,
4290 &x_entry,
4291 &x_exit,
4292 &x_derivative,
4293 ),
4294 penalties,
4295 mono,
4296 SurvivalSpec::Net,
4297 )
4298 .expect("construct structural survival model");
4299 model
4300 .set_structural_monotonicity(true, 3)
4301 .expect("enable structural monotonicity");
4302 let constraints = model
4303 .monotonicity_linear_constraints()
4304 .expect("structural derivative constraints");
4305 assert_eq!(constraints.a.nrows(), 2);
4306 assert_eq!(constraints.a.ncols(), 4);
4307 assert_eq!(constraints.a.row(0).to_vec(), vec![0.0, 1.0, 0.0, 0.0]);
4308 assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
4309 assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4310
4311 let beta = array![0.2, 0.2, 0.1, 0.2];
4312 let state = model.update_state(&beta).expect("state at structural beta");
4313 let eps = 1e-7;
4314 for j in 0..beta.len() {
4315 let mut plus = beta.clone();
4316 let mut minus = beta.clone();
4317 plus[j] += eps;
4318 minus[j] -= eps;
4319 let state_plus = model.update_state(&plus).expect("state at beta + eps");
4320 let state_minus = model.update_state(&minus).expect("state at beta - eps");
4321 let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4322 let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4323 let fd = (obj_plus - obj_minus) / (2.0 * eps);
4324 assert_eq!(
4325 state.gradient[j].signum(),
4326 fd.signum(),
4327 "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4328 state.gradient[j]
4329 );
4330 assert!(
4331 (state.gradient[j] - fd).abs() < 2e-5,
4332 "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
4333 state.gradient[j]
4334 );
4335 }
4336 }
4337
4338 #[test]
4339 fn structural_monotonic_lamlgradient_returns_finitevalues() {
4340 let age_entry = array![1.0_f64, 1.2_f64];
4341 let age_exit = array![1.5_f64, 2.0_f64];
4342 let event_target = array![1u8, 0u8];
4343 let event_competing = array![0u8, 0u8];
4344 let sampleweight = array![1.0, 1.0];
4345 let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
4346 let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
4347 let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];
4348 let penalties = PenaltyBlocks::new(Vec::new());
4349 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4350 let mut model = survival_model(
4351 survival_inputs(
4352 &age_entry,
4353 &age_exit,
4354 &event_target,
4355 &event_competing,
4356 &sampleweight,
4357 &x_entry,
4358 &x_exit,
4359 &x_derivative,
4360 ),
4361 penalties,
4362 mono,
4363 SurvivalSpec::Net,
4364 )
4365 .expect("construct structural survival model");
4366 model
4367 .set_structural_monotonicity(true, 2)
4368 .expect("enable structural monotonicity");
4369 model.penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4371 matrix: array![[1.0]],
4372 lambda: 0.7,
4373 range: 1..2,
4374 nullspace_dim: 0,
4375 }]);
4376 let beta = array![0.2, 0.2, 0.1];
4377 let state = model.update_state(&beta).expect("state at structural beta");
4378 let rho = Array1::from_iter(
4379 model
4380 .penalties
4381 .blocks
4382 .iter()
4383 .filter(|b| b.lambda > 0.0)
4384 .map(|b| b.lambda.ln()),
4385 );
4386 let (obj, grad) = model
4387 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4388 .expect("laml gradient should work in structural mode");
4389 assert!(obj.is_finite());
4390 assert_eq!(grad.len(), 1);
4391 assert!(grad[0].is_finite());
4392 }
4393
4394 #[test]
4395 fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
4396 let age_entry = array![1.0_f64];
4397 let age_exit = array![2.0_f64];
4398 let event_target = array![1u8];
4399 let event_competing = array![0u8];
4400 let sampleweight = array![1.0];
4401 let x_entry = array![[0.0]];
4402 let x_exit = array![[0.2]];
4403 let x_derivative = array![[1.0]];
4404
4405 let penalties = PenaltyBlocks::new(Vec::new());
4406 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4407 let mut model = survival_model(
4408 survival_inputs(
4409 &age_entry,
4410 &age_exit,
4411 &event_target,
4412 &event_competing,
4413 &sampleweight,
4414 &x_entry,
4415 &x_exit,
4416 &x_derivative,
4417 ),
4418 penalties,
4419 mono,
4420 SurvivalSpec::Net,
4421 )
4422 .expect("construct structural survival model");
4423
4424 let beta = array![-3.0];
4425 assert!(
4426 model.update_state(&beta).is_err(),
4427 "negative derivative coefficient should violate derivative guard"
4428 );
4429
4430 model
4431 .set_structural_monotonicity(true, 1)
4432 .expect("enable structural monotonicity");
4433 let constraints = model
4434 .monotonicity_linear_constraints()
4435 .expect("structural derivative constraints");
4436 assert_eq!(constraints.a.nrows(), 1);
4437 assert_eq!(constraints.a.ncols(), 1);
4438 assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
4439 assert!(constraints.b[0].abs() <= 1e-12);
4441 let state = model
4442 .update_state(&array![1e-6])
4443 .expect("small positive derivative coefficient should remain feasible");
4444 assert!(state.deviance.is_finite());
4445 }
4446
4447 #[test]
4448 fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
4449 let age_entry = array![1.0_f64];
4450 let age_exit = array![2.0_f64];
4451 let event_target = array![1u8];
4452 let event_competing = array![0u8];
4453 let sampleweight = array![1.0];
4454 let x_entry = array![[1.0, 0.0]];
4455 let x_exit = array![[1.0, 0.0]];
4456 let x_derivative = array![[0.0, 0.0]];
4457 let penalties = PenaltyBlocks::new(Vec::new());
4458 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
4459 let eta_entry_offset = array![0.0];
4460 let eta_exit_offset = array![0.0];
4461 let derivative_offset_below_guard = array![2.0];
4462 let derivative_offset_above_guard = array![3.1];
4463 let offsets_below_guard = SurvivalBaselineOffsets {
4464 eta_entry: eta_entry_offset.view(),
4465 eta_exit: eta_exit_offset.view(),
4466 derivative_exit: derivative_offset_below_guard.view(),
4467 };
4468 let offsets_above_guard = SurvivalBaselineOffsets {
4469 eta_entry: eta_entry_offset.view(),
4470 eta_exit: eta_exit_offset.view(),
4471 derivative_exit: derivative_offset_above_guard.view(),
4472 };
4473
4474 let model_below_guard = survival_model_with_offsets(
4475 survival_inputs(
4476 &age_entry,
4477 &age_exit,
4478 &event_target,
4479 &event_competing,
4480 &sampleweight,
4481 &x_entry,
4482 &x_exit,
4483 &x_derivative,
4484 ),
4485 Some(offsets_below_guard),
4486 penalties.clone(),
4487 monotonicity,
4488 SurvivalSpec::Net,
4489 )
4490 .expect("construct model with derivative offset below guard");
4491 let err = model_below_guard
4492 .update_state(&array![0.0, 0.0])
4493 .expect_err("derivative offset below guard should be rejected");
4494 let err_text = err.to_string();
4495 assert!(
4496 err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
4497 "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
4498 );
4499
4500 let model_above_guard = survival_model_with_offsets(
4501 survival_inputs(
4502 &age_entry,
4503 &age_exit,
4504 &event_target,
4505 &event_competing,
4506 &sampleweight,
4507 &x_entry,
4508 &x_exit,
4509 &x_derivative,
4510 ),
4511 Some(offsets_above_guard),
4512 penalties,
4513 SurvivalMonotonicityPenalty { tolerance: 3.0 },
4514 SurvivalSpec::Net,
4515 )
4516 .expect("construct model with derivative offset above guard");
4517 let state = model_above_guard
4518 .update_state(&array![0.0, 0.0])
4519 .expect("derivative offset above guard should remain feasible");
4520 assert!(state.deviance.is_finite());
4521 }
4522
4523 #[test]
4524 fn structural_monotonicity_rejects_negative_derivative_offsets() {
4525 let age_entry = array![1.0_f64];
4526 let age_exit = array![2.0_f64];
4527 let event_target = array![1u8];
4528 let event_competing = array![0u8];
4529 let sampleweight = array![1.0];
4530 let x_entry = array![[0.0]];
4531 let x_exit = array![[0.2]];
4532 let x_derivative = array![[1.0]];
4533 let eta_entry = array![0.0];
4534 let eta_exit = array![0.0];
4535 let derivative_exit = array![-1e-3];
4536 let offsets = SurvivalBaselineOffsets {
4537 eta_entry: eta_entry.view(),
4538 eta_exit: eta_exit.view(),
4539 derivative_exit: derivative_exit.view(),
4540 };
4541
4542 let mut model = survival_model_with_offsets(
4543 survival_inputs(
4544 &age_entry,
4545 &age_exit,
4546 &event_target,
4547 &event_competing,
4548 &sampleweight,
4549 &x_entry,
4550 &x_exit,
4551 &x_derivative,
4552 ),
4553 Some(offsets),
4554 PenaltyBlocks::new(Vec::new()),
4555 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4556 SurvivalSpec::Net,
4557 )
4558 .expect("construct structural survival model");
4559 let err = model
4560 .set_structural_monotonicity(true, 1)
4561 .expect_err("negative derivative offsets must be rejected");
4562 assert!(
4563 err.to_string()
4564 .contains("structural monotonicity requires nonnegative derivative offsets"),
4565 "unexpected error: {err}"
4566 );
4567 }
4568
4569 #[test]
4570 fn structural_monotonicity_emits_coefficient_constraints() {
4571 let age_entry = array![1.0_f64, 1.5_f64];
4572 let age_exit = array![2.0_f64, 3.0_f64];
4573 let event_target = array![1u8, 0u8];
4574 let event_competing = array![0u8, 0u8];
4575 let sampleweight = array![1.0, 1.0];
4576 let x_entry = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
4577 let x_exit = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
4578 let x_derivative = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];
4579
4580 let mut model = survival_model(
4581 survival_inputs(
4582 &age_entry,
4583 &age_exit,
4584 &event_target,
4585 &event_competing,
4586 &sampleweight,
4587 &x_entry,
4588 &x_exit,
4589 &x_derivative,
4590 ),
4591 PenaltyBlocks::new(Vec::new()),
4592 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4593 SurvivalSpec::Net,
4594 )
4595 .expect("construct structural survival model");
4596 model
4597 .set_structural_monotonicity(true, 2)
4598 .expect("enable structural monotonicity");
4599
4600 let constraints = model
4601 .monotonicity_linear_constraints()
4602 .expect("structural derivative constraints");
4603
4604 assert_eq!(constraints.a.nrows(), 2);
4605 assert_eq!(constraints.a.ncols(), 3);
4606 assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
4607 assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
4608 assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4609 }
4610
4611 #[test]
4612 fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
4613 let age_entry = array![1.0_f64];
4614 let age_exit = array![2.0_f64];
4615 let event_target = array![1u8];
4616 let event_competing = array![0u8];
4617 let sampleweight = array![1.0];
4618 let x_entry = array![[1.0, 0.2]];
4619 let x_exit = array![[1.0, 0.6]];
4620 let x_derivative = array![[0.0, 1.0]];
4621
4622 let mut model = survival_model(
4623 survival_inputs(
4624 &age_entry,
4625 &age_exit,
4626 &event_target,
4627 &event_competing,
4628 &sampleweight,
4629 &x_entry,
4630 &x_exit,
4631 &x_derivative,
4632 ),
4633 PenaltyBlocks::new(Vec::new()),
4634 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4635 SurvivalSpec::Net,
4636 )
4637 .expect("construct structural survival model");
4638 model
4639 .set_structural_monotonicity(true, 2)
4640 .expect("enable structural monotonicity");
4641
4642 let constraints = model
4643 .monotonicity_linear_constraints()
4644 .expect("structural derivative constraints");
4645
4646 assert_eq!(constraints.a.nrows(), 1);
4647 assert!(
4648 constraints.a[[0, 0]].abs() <= 1e-12,
4649 "inactive time column should remain unconstrained"
4650 );
4651 assert!(
4652 (constraints.a[[0, 1]] - 1.0).abs() <= 1e-12,
4653 "active time column should remain constrained"
4654 );
4655 }
4656
4657 #[test]
4658 fn structural_monotonicity_preserves_sparse_row_patterns() {
4659 let age_entry = array![1.0_f64, 1.5_f64];
4660 let age_exit = array![2.0_f64, 2.5_f64];
4661 let event_target = array![1u8, 1u8];
4662 let event_competing = array![0u8, 0u8];
4663 let sampleweight = array![1.0, 1.0];
4664 let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
4665 let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
4666 let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];
4667
4668 let mut model = survival_model(
4669 survival_inputs(
4670 &age_entry,
4671 &age_exit,
4672 &event_target,
4673 &event_competing,
4674 &sampleweight,
4675 &x_entry,
4676 &x_exit,
4677 &x_derivative,
4678 ),
4679 PenaltyBlocks::new(Vec::new()),
4680 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4681 SurvivalSpec::Net,
4682 )
4683 .expect("construct structural survival model");
4684 model
4685 .set_structural_monotonicity(true, 2)
4686 .expect("enable structural monotonicity");
4687
4688 let constraints = model
4689 .monotonicity_linear_constraints()
4690 .expect("structural derivative constraints");
4691
4692 assert_eq!(constraints.a.nrows(), 2);
4693 assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
4694 assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
4695 }
4696
4697 #[test]
4698 fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
4699 let age_entry = array![1.0_f64];
4700 let age_exit = array![1.1_f64];
4701 let event_target = array![0u8];
4702 let event_competing = array![0u8];
4703 let sampleweight = array![1.0];
4704 let x_entry = array![[0.0]];
4705 let x_exit = array![[0.0]];
4706 let x_derivative = array![[-1.0]];
4707 let penalties = PenaltyBlocks::new(Vec::new());
4708 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4709 let model = survival_model(
4710 survival_inputs(
4711 &age_entry,
4712 &age_exit,
4713 &event_target,
4714 &event_competing,
4715 &sampleweight,
4716 &x_entry,
4717 &x_exit,
4718 &x_derivative,
4719 ),
4720 penalties,
4721 mono,
4722 SurvivalSpec::Net,
4723 )
4724 .expect("construct censored survival model");
4725
4726 let err = model
4727 .update_state(&array![1.0])
4728 .expect_err("censored row should still enforce monotonic derivative");
4729 assert!(
4730 matches!(err, EstimationError::ParameterConstraintViolation(_)),
4731 "unexpected error: {err:?}"
4732 );
4733 }
4734
4735 fn crude_risk_quadrature_error(
4736 cumulative_entry: f64,
4737 cumulative_exit: f64,
4738 hazard_exit: f64,
4739 ) -> SurvivalError {
4740 calculate_crude_risk_quadrature(
4741 1.0,
4742 2.0,
4743 &[],
4744 0.4,
4745 0.2,
4746 array![1.0].view(),
4747 array![1.0].view(),
4748 |_, design_d, deriv_d, design_m| {
4749 design_d[0] = 1.0;
4750 deriv_d[0] = 0.0;
4751 design_m[0] = 1.0;
4752 Ok((cumulative_entry, cumulative_exit, hazard_exit))
4753 },
4754 )
4755 .expect_err("invalid hazards should fail")
4756 }
4757
4758 #[test]
4759 fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
4760 let err = crude_risk_quadrature_error(0.1, 0.3, 0.25);
4761 assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
4762 }
4763
4764 #[test]
4765 fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
4766 let err = crude_risk_quadrature_error(0.0, 0.4, 0.25);
4767 assert!(matches!(err, SurvivalError::NonPositiveHazard));
4768 }
4769
4770 #[test]
4771 fn laml_no_penalties_matches_documentedobjective() {
4772 let age_entry = array![40.0, 45.0, 50.0, 55.0];
4773 let age_exit = array![44.0, 49.0, 54.0, 59.0];
4774 let event_target = array![1u8, 0u8, 1u8, 0u8];
4775 let event_competing = Array1::<u8>::zeros(4);
4776 let sampleweight = Array1::ones(4);
4777 let x_entry = array![
4778 [1.0, -0.2, 0.04],
4779 [1.0, -0.1, 0.01],
4780 [1.0, 0.0, 0.0],
4781 [1.0, 0.1, 0.01]
4782 ];
4783 let x_exit = array![
4784 [1.0, -0.12, 0.0144],
4785 [1.0, -0.02, 0.0004],
4786 [1.0, 0.08, 0.0064],
4787 [1.0, 0.18, 0.0324]
4788 ];
4789 let x_derivative = array![
4790 [0.0, 0.02, 0.001],
4791 [0.0, 0.02, 0.001],
4792 [0.0, 0.02, 0.001],
4793 [0.0, 0.02, 0.001]
4794 ];
4795 let penalties = PenaltyBlocks::new(Vec::new());
4796 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4797 let beta = array![-2.0, 0.7, 0.2];
4798
4799 let model = survival_model(
4800 survival_inputs(
4801 &age_entry,
4802 &age_exit,
4803 &event_target,
4804 &event_competing,
4805 &sampleweight,
4806 &x_entry,
4807 &x_exit,
4808 &x_derivative,
4809 ),
4810 penalties,
4811 mono,
4812 SurvivalSpec::Net,
4813 )
4814 .expect("construct survival model");
4815
4816 let state = model.update_state(&beta).expect("state at beta");
4817 let rho = Array1::from_iter(
4818 model
4819 .penalties
4820 .blocks
4821 .iter()
4822 .filter(|b| b.lambda > 0.0)
4823 .map(|b| b.lambda.ln()),
4824 );
4825 let (obj, grad) = model
4826 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4827 .expect("laml objective for no-penalty model");
4828
4829 let h_dense = state.hessian.to_dense();
4830 let logdet_h: f64 = {
4831 use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4832 use gam_linalg::faer_ndarray::FaerEigh;
4833 let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4834 let eps = spectral_epsilon(evals.as_slice().unwrap());
4835 evals
4836 .iter()
4837 .map(|&sigma| spectral_regularize(sigma, eps).ln())
4838 .sum()
4839 };
4840 let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * logdet_h;
4841
4842 assert_eq!(grad.len(), 0);
4843 assert!(
4844 (obj - expected).abs() < 1e-10,
4845 "no-penalty LAML objective mismatch: obj={} expected={}",
4846 obj,
4847 expected
4848 );
4849 }
4850
4851 #[test]
4852 fn monotonicity_constraints_collapse_positive_collinearrows() {
4853 let a = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
4854 let b = array![1e-8, 1e-8, 1e-8];
4855
4856 let compressed = compress_positive_collinear_constraints(&a, &b);
4857
4858 assert_eq!(compressed.a.nrows(), 1);
4859 assert_eq!(compressed.a.ncols(), 3);
4860 assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4861 assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4862 assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4863 assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
4864 }
4865
4866 #[test]
4867 fn monotonicity_constraints_preserve_distinct_directions() {
4868 let a = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
4869 let b = array![0.2, 0.3, 0.1];
4870
4871 let compressed = compress_positive_collinear_constraints(&a, &b);
4872
4873 assert_eq!(compressed.a.nrows(), 2);
4874 let mut saw_x = false;
4875 let mut saw_y = false;
4876 for i in 0..compressed.a.nrows() {
4877 if (compressed.a[[i, 0]] - 1.0).abs() <= 1e-12 && compressed.a[[i, 1]].abs() <= 1e-12 {
4878 saw_x = true;
4879 assert!((compressed.b[i] - 0.2).abs() <= 1e-12);
4880 }
4881 if compressed.a[[i, 0]].abs() <= 1e-12 && (compressed.a[[i, 1]] - 1.0).abs() <= 1e-12 {
4882 saw_y = true;
4883 assert!((compressed.b[i] - 0.3).abs() <= 1e-12);
4884 }
4885 }
4886 assert!(saw_x);
4887 assert!(saw_y);
4888 }
4889
4890 #[test]
4891 fn monotonicity_constraints_cluster_near_collinearrows() {
4892 let a = array![
4893 [0.0, 0.5, 0.0],
4894 [0.0, 0.50000000003, 0.0],
4895 [0.0, 0.49999999997, 0.0]
4896 ];
4897 let b = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];
4898
4899 let compressed = compress_positive_collinear_constraints(&a, &b);
4900
4901 assert_eq!(compressed.a.nrows(), 1);
4902 assert_eq!(compressed.a.ncols(), 3);
4903 assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4904 assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4905 assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4906 assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
4907 }
4908
4909 #[test]
4910 fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
4911 let a = array![
4912 [0.0, 0.401, 0.302, 0.197],
4913 [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
4914 [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
4915 [0.0, 0.125, 0.500, 0.375]
4916 ];
4917 let b = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];
4918
4919 let compressed = compress_positive_collinear_constraints(&a, &b);
4920
4921 assert_eq!(compressed.a.nrows(), 2);
4922 let mut clustered_face = false;
4923 let mut distinct_face = false;
4924 for i in 0..compressed.a.nrows() {
4925 let row = compressed.a.row(i);
4926 if row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49 {
4927 clustered_face = true;
4928 assert!((compressed.b[i] - (2.0e-8 / 0.401)).abs() <= 1e-12);
4929 } else {
4930 distinct_face = true;
4931 assert!((row[1] - 0.25).abs() <= 1e-12);
4932 assert!((row[2] - 1.0).abs() <= 1e-12);
4933 assert!((row[3] - 0.75).abs() <= 1e-12);
4934 assert!((compressed.b[i] - 6.0e-8).abs() <= 1e-18);
4935 }
4936 }
4937 assert!(clustered_face);
4938 assert!(distinct_face);
4939 }
4940
4941 #[test]
4942 fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
4943 let age_entry = array![1.0_f64, 1.0, 1.0];
4944 let age_exit = array![2.0_f64, 4.0, 8.0];
4945 let event_target = array![0u8, 1u8, 0u8];
4946 let event_competing = array![0u8, 0u8, 0u8];
4947 let sampleweight = array![1.0, 1.0, 1.0];
4948 let x_entry = array![
4949 [1.0, age_entry[0].ln()],
4950 [1.0, age_entry[1].ln()],
4951 [1.0, age_entry[2].ln()]
4952 ];
4953 let x_exit = array![
4954 [1.0, age_exit[0].ln()],
4955 [1.0, age_exit[1].ln()],
4956 [1.0, age_exit[2].ln()]
4957 ];
4958 let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
4959 let penalties = PenaltyBlocks::new(Vec::new());
4960 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4961
4962 let collocation_offsets = Array1::zeros(x_derivative.nrows());
4963 let mut inputs = survival_inputs(
4964 &age_entry,
4965 &age_exit,
4966 &event_target,
4967 &event_competing,
4968 &sampleweight,
4969 &x_entry,
4970 &x_exit,
4971 &x_derivative,
4972 );
4973 inputs.monotonicity_constraint_rows = Some(x_derivative.view());
4974 inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
4975
4976 let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
4977 .expect("construct linear survival model");
4978
4979 let constraints = model
4980 .monotonicity_linear_constraints()
4981 .expect("monotonicity constraints");
4982 assert_eq!(constraints.a.nrows(), 1);
4983 assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
4984 assert!((constraints.b[0] - 8e-8).abs() <= 1e-12);
4985 }
4986
4987 #[test]
4988 fn monotonicity_constraints_skip_numericallyzerorows() {
4989 let age_entry = array![1.0_f64, 1.0, 1.0];
4990 let age_exit = array![2.0_f64, 3.0, 4.0];
4991 let event_target = array![0u8, 0u8, 0u8];
4992 let event_competing = array![0u8, 0u8, 0u8];
4993 let sampleweight = array![1.0, 1.0, 1.0];
4994 let x_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
4995 let x_exit = x_entry.clone();
4996 let x_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
4997
4998 let collocation_offsets = Array1::zeros(x_derivative.nrows());
4999 let mut inputs = survival_inputs(
5000 &age_entry,
5001 &age_exit,
5002 &event_target,
5003 &event_competing,
5004 &sampleweight,
5005 &x_entry,
5006 &x_exit,
5007 &x_derivative,
5008 );
5009 inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5010 inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5011
5012 let model = survival_model(
5013 inputs,
5014 PenaltyBlocks::new(Vec::new()),
5015 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5016 SurvivalSpec::Net,
5017 )
5018 .expect("construct survival model");
5019
5020 let constraints = model
5021 .monotonicity_linear_constraints()
5022 .expect("nonzero derivative row should remain");
5023 assert_eq!(constraints.a.nrows(), 1);
5024 assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5025 assert!(constraints.b[0].abs() <= 1e-18);
5026 }
5027
5028 #[test]
5029 fn censoredrows_allowzero_boundary_derivative() {
5030 let age_entry = array![1.0_f64];
5031 let age_exit = array![2.0_f64];
5032 let event_target = array![0u8];
5033 let event_competing = array![0u8];
5034 let sampleweight = array![1.0];
5035 let x_entry = array![[0.0]];
5036 let x_exit = array![[0.0]];
5037 let x_derivative = array![[1.0]];
5038
5039 let model = survival_model(
5040 survival_inputs(
5041 &age_entry,
5042 &age_exit,
5043 &event_target,
5044 &event_competing,
5045 &sampleweight,
5046 &x_entry,
5047 &x_exit,
5048 &x_derivative,
5049 ),
5050 PenaltyBlocks::new(Vec::new()),
5051 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5052 SurvivalSpec::Net,
5053 )
5054 .expect("construct censored survival model");
5055
5056 let state = model
5057 .update_state(&array![0.0])
5058 .expect("censored boundary derivative should remain feasible with zero tolerance");
5059 assert!(state.deviance.is_finite());
5060 }
5061
5062 #[test]
5063 fn eventrows_keep_positive_derivative_constraint() {
5064 let age_entry = array![1.0_f64, 1.0];
5065 let age_exit = array![2.0_f64, 4.0];
5066 let event_target = array![0u8, 1u8];
5067 let event_competing = array![0u8, 0u8];
5068 let sampleweight = array![1.0, 1.0];
5069 let x_entry = array![[0.0], [0.0]];
5070 let x_exit = array![[0.0], [0.0]];
5071 let x_derivative = array![[0.5], [0.25]];
5072
5073 let collocation_offsets = Array1::zeros(x_derivative.nrows());
5074 let mut inputs = survival_inputs(
5075 &age_entry,
5076 &age_exit,
5077 &event_target,
5078 &event_competing,
5079 &sampleweight,
5080 &x_entry,
5081 &x_exit,
5082 &x_derivative,
5083 );
5084 inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5085 inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5086
5087 let model = survival_model(
5088 inputs,
5089 PenaltyBlocks::new(Vec::new()),
5090 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5091 SurvivalSpec::Net,
5092 )
5093 .expect("construct mixed survival model");
5094
5095 let constraints = model
5096 .monotonicity_linear_constraints()
5097 .expect("event row should induce positive lower bound");
5098 assert_eq!(constraints.a.nrows(), 1);
5099 assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
5100 assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
5101 }
5102
5103 #[test]
5104 fn structural_monotonicity_clamps_tiny_negative_roundoff() {
5105 let age_entry = array![1.0_f64];
5106 let age_exit = array![2.0_f64];
5107 let event_target = array![1u8];
5108 let event_competing = array![0u8];
5109 let sampleweight = array![1.0];
5110 let x_entry = array![[0.0]];
5111 let x_exit = array![[0.0]];
5112 let x_derivative = array![[1.0]];
5113 let mut model = survival_model(
5114 survival_inputs(
5115 &age_entry,
5116 &age_exit,
5117 &event_target,
5118 &event_competing,
5119 &sampleweight,
5120 &x_entry,
5121 &x_exit,
5122 &x_derivative,
5123 ),
5124 PenaltyBlocks::new(Vec::new()),
5125 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5126 SurvivalSpec::Net,
5127 )
5128 .expect("construct survival model");
5129 model
5130 .set_structural_monotonicity(true, 1)
5131 .expect("enable structural monotonicity");
5132
5133 let state = model
5134 .update_state(&array![-1e-8])
5135 .expect("tiny structural roundoff should be clamped");
5136 assert!(state.deviance.is_finite());
5137 }
5138
5139 #[test]
5140 fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
5141 let uncompressed_constraints = LinearInequalityConstraints {
5142 a: array![
5143 [0.0, 0.5, 0.0],
5144 [0.0, 1.0 / 3.0, 0.0],
5145 [0.0, 0.2, 0.0],
5146 [0.0, 0.125, 0.0]
5147 ],
5148 b: Array1::from_elem(4, 1e-8),
5149 };
5150 let compressed_constraints = compress_positive_collinear_constraints(
5151 &uncompressed_constraints.a,
5152 &uncompressed_constraints.b,
5153 );
5154
5155 let candidates = [
5156 array![0.0, 1e-9, 0.0],
5157 array![0.0, 4e-8, 0.0],
5158 array![0.0, 8e-8, 0.0],
5159 array![0.0, 2e-7, 1.5],
5160 ];
5161 for beta in candidates {
5162 let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
5163 uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
5164 });
5165 let compressed_ok = (0..compressed_constraints.a.nrows())
5166 .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
5167 assert_eq!(compressed_ok, uncompressed_ok);
5168 }
5169 }
5170
5171 #[test]
5172 fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
5173 let age_entry = array![10.0_f64, 20.0, 25.0];
5174 let age_exit = array![15.0_f64, 30.0, 40.0];
5175 let event_target = array![1u8, 0u8, 1u8];
5176 let event_competing = array![0u8, 0u8, 0u8];
5177 let sampleweight = array![1.0, 2.0, 0.5];
5178 let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
5179 let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
5180 let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
5181 let beta = array![0.8, 1.1, -0.2];
5182
5183 let base_model = survival_model(
5184 survival_inputs(
5185 &age_entry,
5186 &age_exit,
5187 &event_target,
5188 &event_competing,
5189 &sampleweight,
5190 &x_entry,
5191 &x_exit,
5192 &x_derivative,
5193 ),
5194 PenaltyBlocks::new(Vec::new()),
5195 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5196 SurvivalSpec::Net,
5197 )
5198 .expect("construct base survival model");
5199 let base_state = base_model
5200 .update_state(&beta)
5201 .expect("evaluate base survival state");
5202
5203 let time_scale = 365.25;
5204 let scaled_age_entry = age_entry.mapv(|v| v * time_scale);
5205 let scaled_age_exit = age_exit.mapv(|v| v * time_scale);
5206 let scaled_x_derivative = x_derivative.mapv(|v| v / time_scale);
5207 let scaled_model = survival_model(
5208 survival_inputs(
5209 &scaled_age_entry,
5210 &scaled_age_exit,
5211 &event_target,
5212 &event_competing,
5213 &sampleweight,
5214 &x_entry,
5215 &x_exit,
5216 &scaled_x_derivative,
5217 ),
5218 PenaltyBlocks::new(Vec::new()),
5219 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5220 SurvivalSpec::Net,
5221 )
5222 .expect("construct scaled survival model");
5223 let scaled_state = scaled_model
5224 .update_state(&beta)
5225 .expect("evaluate scaled survival state");
5226
5227 let weighted_events = sampleweight
5228 .iter()
5229 .zip(event_target.iter())
5230 .map(|(w, d)| *w * f64::from(*d))
5231 .sum::<f64>();
5232 let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
5233 assert!(
5234 (scaled_state.deviance - base_state.deviance - expected_deviance_shift).abs() <= 1e-10,
5235 "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
5236 scaled_state.deviance,
5237 base_state.deviance
5238 );
5239
5240 for j in 0..beta.len() {
5241 assert!(
5242 (scaled_state.gradient[j] - base_state.gradient[j]).abs() <= 1e-12,
5243 "gradient mismatch at j={j}: scaled={} base={}",
5244 scaled_state.gradient[j],
5245 base_state.gradient[j]
5246 );
5247 }
5248
5249 let base_hessian = base_state.hessian.to_dense();
5250 let scaled_hessian = scaled_state.hessian.to_dense();
5251 for r in 0..beta.len() {
5252 for c in 0..beta.len() {
5253 assert!(
5254 (scaled_hessian[[r, c]] - base_hessian[[r, c]]).abs() <= 1e-12,
5255 "hessian mismatch at ({r},{c}): scaled={} base={}",
5256 scaled_hessian[[r, c]],
5257 base_hessian[[r, c]]
5258 );
5259 }
5260 }
5261 }
5262}