1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3 BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4 projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9 LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
/// Errors raised while validating and evaluating survival models in this module.
///
/// Variants that carry a `reason` display that string verbatim; `CauseSpecificBlock`
/// prefixes the wrapped error with the block number it was attributed to.
#[derive(Debug, Error)]
pub enum SurvivalError {
    /// Input arrays do not have mutually consistent lengths/shapes.
    #[error("input dimensions are inconsistent")]
    DimensionMismatch,
    /// At least one input value is NaN or infinite.
    #[error("inputs contain non-finite values")]
    NonFiniteInput,
    /// Names a survival specification that the one-hazard engine does not support.
    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
    UnsupportedSpec(&'static str),
    /// The setup for crude-risk integration is invalid.
    #[error("crude risk integration setup is invalid")]
    InvalidIntegrationSetup,
    /// A time grid is non-finite, negative, or not strictly increasing.
    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
    InvalidTimeGrid,
    /// A cumulative hazard sequence decreased.
    #[error("cumulative hazard must be nondecreasing")]
    NonMonotoneCumulativeHazard,
    /// An instantaneous hazard reached zero or went negative during integration.
    #[error("instantaneous hazard must stay strictly positive during integration")]
    NonPositiveHazard,
    /// Generic invalid input; `reason` is displayed as-is.
    #[error("{reason}")]
    InvalidInput { reason: String },
    /// Shape mismatch inside the cause-specific survival family.
    #[error("{reason}")]
    CauseSpecificDimensionMismatch { reason: String },
    /// A numerical failure, e.g. a non-finite cumulative hazard or non-positive derivative.
    #[error("{reason}")]
    NumericalFailure { reason: String },
    /// A raw event code is not acceptable (non-finite, negative, non-integer, out of range,
    /// or not binary where a binary indicator is required).
    #[error("{reason}")]
    EventCodeInvalid { reason: String },
    /// The event data is degenerate for the requested fit.
    #[error("{reason}")]
    EventDegenerate { reason: String },
    /// Wraps an error attributed to one cause-specific block.
    /// `CauseSpecificRoystonParmarFamily::new` passes a 1-based block number.
    #[error("cause-specific survival block {block}: {source}")]
    CauseSpecificBlock {
        block: usize,
        #[source]
        source: Box<SurvivalError>,
    },
}
52
53impl From<SurvivalError> for String {
54 fn from(err: SurvivalError) -> Self {
55 err.to_string()
56 }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60 fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61 SurvivalError::CauseSpecificDimensionMismatch {
62 reason: err.message(),
63 }
64 }
65}
66
/// Which survival quantity the engine targets; `Net` is the default.
///
/// NOTE(review): presumably `Net` means net (single-hazard) survival and `Crude` means
/// crude risk in the presence of competing events. Confirm against the engine code.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SurvivalSpec {
    #[default]
    Net,
    Crude,
}
73
/// Borrowed per-row inputs for the survival engine when a single design matrix per
/// channel (entry, exit, derivative) covers all coefficients.
///
/// NOTE(review): the 1-D views are presumably all length `n` (one entry per row) and the
/// design views `n x p`. This struct does not enforce that; check the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalEngineInputs<'a> {
    /// Age at which each row enters observation.
    pub age_entry: ArrayView1<'a, f64>,
    /// Age at which each row leaves observation (event or censoring).
    pub age_exit: ArrayView1<'a, f64>,
    /// Event indicator for the target cause (presumably 0/1).
    pub event_target: ArrayView1<'a, u8>,
    /// Event indicator for competing causes (presumably 0/1).
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Design rows evaluated at the entry age.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design rows evaluated at the exit age.
    pub x_exit: ArrayView2<'a, f64>,
    /// Design rows for the time derivative of the linear predictor.
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional linear-constraint rows used to enforce monotonicity.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets paired with `monotonicity_constraint_rows` (presumably one per row).
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
91
/// Borrowed per-row inputs for the survival engine when the design is split into a
/// time basis (evaluated at entry, exit and as a derivative) plus a separate covariate
/// matrix.
///
/// NOTE(review): presumably the covariate block is shared between the entry and exit
/// predictors (compare `SurvivalDesign::TimeCovariateShared`). Confirm with the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalTimeCovarInputs<'a> {
    /// Age at which each row enters observation.
    pub age_entry: ArrayView1<'a, f64>,
    /// Age at which each row leaves observation (event or censoring).
    pub age_exit: ArrayView1<'a, f64>,
    /// Event indicator for the target cause (presumably 0/1).
    pub event_target: ArrayView1<'a, u8>,
    /// Event indicator for competing causes (presumably 0/1).
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Time-basis rows evaluated at the entry age.
    pub time_entry: ArrayView2<'a, f64>,
    /// Time-basis rows evaluated at the exit age.
    pub time_exit: ArrayView2<'a, f64>,
    /// Time-basis rows for the time derivative.
    pub time_derivative: ArrayView2<'a, f64>,
    /// Per-row covariate design.
    pub covariates: ArrayView2<'a, f64>,
    /// Optional linear-constraint rows used to enforce monotonicity.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets paired with `monotonicity_constraint_rows` (presumably one per row).
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
110
/// Borrowed fixed offsets for the entry predictor, the exit predictor and the exit
/// derivative.
///
/// NOTE(review): presumably these play the same role as the `offset_eta_entry`,
/// `offset_eta_exit` and `offset_derivative_exit` fields of
/// `CauseSpecificRoystonParmarBlock`, i.e. they are added to the design predictors.
/// Confirm with the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalBaselineOffsets<'a> {
    /// Offset on the linear predictor at the entry age.
    pub eta_entry: ArrayView1<'a, f64>,
    /// Offset on the linear predictor at the exit age.
    pub eta_exit: ArrayView1<'a, f64>,
    /// Offset on the time derivative of the predictor at the exit age.
    pub derivative_exit: ArrayView1<'a, f64>,
}
126
/// One quadratic penalty `0.5 * lambda * b' S b`, where `b = beta[range]`
/// (see `PenaltyBlocks::deviance`).
#[derive(Debug, Clone)]
pub struct PenaltyBlock {
    /// Penalty matrix `S`. Must be square with side `range.len()`, because it is
    /// multiplied with `beta[range]` and added into the `range x range` Hessian sub-block.
    pub matrix: Array2<f64>,
    /// Smoothing parameter. `PenaltyBlocks` skips blocks whose `lambda` is exactly `0.0`.
    pub lambda: f64,
    /// Coefficient indices the penalty acts on.
    pub range: Range<usize>,
    /// NOTE(review): presumably the dimension of the null space of `matrix` (the
    /// unpenalized directions). It is not read by `PenaltyBlocks`.
    pub nullspace_dim: usize,
}
136
/// Collection of block-diagonal quadratic penalties on a coefficient vector. The total
/// penalty is the sum over blocks (see the `impl` below).
#[derive(Debug, Clone)]
pub struct PenaltyBlocks {
    /// Individual penalty blocks; ranges are used as given (overlaps simply add up).
    pub blocks: Vec<PenaltyBlock>,
}
141
142impl PenaltyBlocks {
143 pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144 Self { blocks }
145 }
146
147 pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148 let mut grad = Array1::zeros(beta.len());
149 for block in &self.blocks {
150 if block.lambda == 0.0 {
151 continue;
152 }
153 let b = beta.slice(ndarray::s![block.range.clone()]);
154 let g = block.matrix.dot(&b);
155 let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156 dst += &(block.lambda * g);
157 }
158 grad
159 }
160
161 pub fn hessian(&self, dim: usize) -> Array2<f64> {
162 let mut h = Array2::zeros((dim, dim));
163 self.addhessian_inplace(&mut h);
164 h
165 }
166
167 pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168 let mut value = 0.0;
169 for block in &self.blocks {
170 if block.lambda == 0.0 {
171 continue;
172 }
173 let b = beta.slice(ndarray::s![block.range.clone()]);
174 value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175 }
176 value
177 }
178
179 pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180 for block in &self.blocks {
181 if block.lambda == 0.0 {
182 continue;
183 }
184 let start = block.range.start;
185 let end = block.range.end;
186 h.slice_mut(ndarray::s![start..end, start..end])
187 .scaled_add(block.lambda, &block.matrix);
188 }
189 }
190}
191
/// Entry ages at or below this value count as entry at the time origin. The cause-specific
/// likelihood in this file drops the entry cumulative-hazard term for such rows.
pub const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1e-8;

/// Fraction of the remaining distance to the derivative floor that one step may cover
/// in `max_feasible_step_size`. Being below 1.0, it keeps the derivative strictly above
/// the floor for rows that were strictly feasible.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 0.995;
214
/// Data for one cause-specific Royston–Parmar likelihood block.
///
/// Linear predictors are `X * beta + offset`. `exp(eta)` is used as the cumulative hazard
/// at entry/exit, and `derivative` enters the event-row log-likelihood as `ln(derivative)`.
/// `validate_cause_specific_block` requires all per-row arrays to have `n` entries and all
/// designs to have `p = x_exit.ncols()` columns.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarBlock {
    /// Entry ages; rows at or below `ENTRY_AT_ORIGIN_THRESHOLD` get no entry term.
    pub age_entry: Array1<f64>,
    /// Exit ages; must be `>= age_entry` on positively weighted rows.
    pub age_exit: Array1<f64>,
    /// Binary indicator (0/1) that this row's event is of this block's cause.
    pub event_target: Array1<u8>,
    /// Finite, non-negative weights; rows with weight `<= 0` are ignored by the likelihood.
    pub sampleweight: Array1<f64>,
    /// Design rows for the linear predictor at the entry age.
    pub x_entry: Array2<f64>,
    /// Design rows for the linear predictor at the exit age.
    pub x_exit: Array2<f64>,
    /// Design rows for the derivative term at the exit age.
    pub x_derivative: Array2<f64>,
    /// Fixed offset added to the entry predictor.
    pub offset_eta_entry: Array1<f64>,
    /// Fixed offset added to the exit predictor.
    pub offset_eta_exit: Array1<f64>,
    /// Fixed offset added to the derivative term.
    pub offset_derivative_exit: Array1<f64>,
    /// Finite, non-negative lower bound imposed on the derivative term through
    /// `block_linear_constraints` and `max_feasible_step_size`.
    pub derivative_floor: f64,
}
229
/// Family made of one independent Royston–Parmar block per cause. The joint
/// log-likelihood is the sum of the block log-likelihoods. Construct it with
/// `CauseSpecificRoystonParmarFamily::new`, which validates every block.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarFamily {
    blocks: Vec<CauseSpecificRoystonParmarBlock>,
}
239
240impl CauseSpecificRoystonParmarFamily {
241 pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
242 if blocks.is_empty() {
243 return Err(SurvivalError::InvalidInput {
244 reason: "cause-specific survival family requires at least one endpoint".to_string(),
245 }
246 .into());
247 }
248 for (idx, block) in blocks.iter().enumerate() {
249 validate_cause_specific_block(block).map_err(|err| {
250 SurvivalError::CauseSpecificBlock {
251 block: idx + 1,
252 source: Box::new(err),
253 }
254 .to_string()
255 })?;
256 }
257 Ok(Self { blocks })
258 }
259
260 pub fn cause_count(&self) -> usize {
261 self.blocks.len()
262 }
263}
264
265fn validate_cause_specific_block(
266 block: &CauseSpecificRoystonParmarBlock,
267) -> Result<(), SurvivalError> {
268 let n = block.event_target.len();
269 let p = block.x_exit.ncols();
270 if n == 0 || p == 0 {
271 bail_invalid_surv!("empty event vector or coefficient block");
272 }
273 if block.age_entry.len() != n
274 || block.age_exit.len() != n
275 || block.sampleweight.len() != n
276 || block.x_entry.nrows() != n
277 || block.x_exit.nrows() != n
278 || block.x_derivative.nrows() != n
279 || block.x_entry.ncols() != p
280 || block.x_derivative.ncols() != p
281 || block.offset_eta_entry.len() != n
282 || block.offset_eta_exit.len() != n
283 || block.offset_derivative_exit.len() != n
284 {
285 return Err(SurvivalError::CauseSpecificDimensionMismatch {
286 reason: "dimension mismatch".to_string(),
287 });
288 }
289 if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
295 return Err(SurvivalError::EventCodeInvalid {
296 reason: format!(
297 "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
298 ),
299 });
300 }
301 if block.age_entry.iter().any(|v| !v.is_finite())
302 || block.age_exit.iter().any(|v| !v.is_finite())
303 || block
304 .sampleweight
305 .iter()
306 .any(|v| !v.is_finite() || *v < 0.0)
307 || block.x_entry.iter().any(|v| !v.is_finite())
308 || block.x_exit.iter().any(|v| !v.is_finite())
309 || block.x_derivative.iter().any(|v| !v.is_finite())
310 || block.offset_eta_entry.iter().any(|v| !v.is_finite())
311 || block.offset_eta_exit.iter().any(|v| !v.is_finite())
312 || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
313 || !block.derivative_floor.is_finite()
314 || block.derivative_floor < 0.0
315 {
316 bail_invalid_surv!("non-finite input");
317 }
318 Ok(())
319}
320
/// Evaluate one cause-specific Royston–Parmar block at `beta`.
///
/// With `eta_entry`, `eta_exit` and `derivative` equal to the design products plus their
/// offsets, each row with positive weight `w` contributes
///
/// `w * [ event * (eta_exit + ln(derivative)) - (exp(eta_exit) - exp(eta_entry)) ]`
///
/// where `exp(eta)` is the cumulative hazard. The `exp(eta_entry)` term is dropped when
/// `age_entry <= ENTRY_AT_ORIGIN_THRESHOLD`. Rows with weight `<= 0` are skipped.
///
/// # Returns
/// `(log_likelihood, gradient, hessian)`:
/// - `gradient` is the gradient of the log-likelihood;
/// - `hessian` is the Hessian of the negative log-likelihood,
///   `X_exit' W_exit X_exit - X_entry' W_entry X_entry + X_der' (w * event / derivative^2) X_der`.
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if `beta.len() != x_exit.ncols()`;
/// - the `bail_invalid_surv!` error if `age_exit < age_entry` on a weighted row;
/// - `NumericalFailure` if a cumulative hazard overflows to a non-finite value, or an event
///   row's derivative is not finite and strictly positive.
fn evaluate_cause_specific_block(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
    let n = block.event_target.len();
    let p = block.x_exit.ncols();
    if beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let mut log_likelihood = 0.0;
    // Per-row weights for the X' diag(w) X / X' w assemblies below.
    let mut w_exit = Array1::<f64>::zeros(n);
    let mut w_entry = Array1::<f64>::zeros(n);
    let mut w_event = Array1::<f64>::zeros(n);
    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
    let mut w_event_outer = Array1::<f64>::zeros(n);

    for i in 0..n {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        if block.age_exit[i] < block.age_entry[i] {
            bail_invalid_surv!("age_exit < age_entry at row {i}");
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        // Cumulative hazards at exit and (if the row entered after the origin) at entry.
        let h_exit = eta_exit[i].exp();
        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
        if !(h_exit.is_finite() && h_entry.is_finite()) {
            return Err(SurvivalError::NumericalFailure {
                reason: format!("non-finite cumulative hazard at row {i}"),
            });
        }
        // Survival part: -(H(exit) - H(entry)).
        log_likelihood -= weight * (h_exit - h_entry);
        w_exit[i] = weight * h_exit;
        w_entry[i] = weight * h_entry;
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            // Event part: log hazard = eta_exit + ln(derivative).
            log_likelihood += weight * (eta_exit[i] + deriv.ln());
            w_event[i] = weight;
            w_event_inv_deriv[i] = weight / deriv;
            w_event_outer[i] = weight / (deriv * deriv);
        }
    }

    // Gradient of the negative log-likelihood first; its sign is flipped afterwards.
    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
    let gradient = -nll_gradient;

    // Hessian of the negative log-likelihood (the eta_exit event term is linear in beta,
    // so it contributes no curvature).
    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
    Ok((log_likelihood, gradient, hessian))
}
388
389impl CustomFamily for CauseSpecificRoystonParmarFamily {
390 fn joint_jeffreys_term_required(&self) -> bool {
394 true
395 }
396
397 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
398 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
399 "cause-specific survival",
400 self.blocks.len(),
401 block_states.len(),
402 )?;
403 let mut log_likelihood = 0.0;
404 let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
405 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
406 let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
407 log_likelihood += ll;
408 blockworking_sets.push(BlockWorkingSet::ExactNewton {
409 gradient,
410 hessian: SymmetricMatrix::Dense(hessian),
411 });
412 }
413 Ok(FamilyEvaluation {
414 log_likelihood,
415 blockworking_sets,
416 })
417 }
418
419 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
420 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
421 "cause-specific survival",
422 self.blocks.len(),
423 block_states.len(),
424 )?;
425 let mut log_likelihood = 0.0;
426 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
427 let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
428 log_likelihood += ll;
429 }
430 Ok(log_likelihood)
431 }
432
433 fn likelihood_blocks_uncoupled(&self) -> bool {
434 true
435 }
436
437 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
438 true
439 }
440
441 fn output_channel_assignment(
442 &self,
443 specs: &[crate::custom_family::ParameterBlockSpec],
444 ) -> Option<Vec<usize>> {
445 if specs.len() != self.blocks.len() {
446 return Some((0..self.blocks.len()).collect());
447 }
448 Some((0..specs.len()).collect())
449 }
450
451 fn coefficient_hessian_cost(
452 &self,
453 specs: &[crate::custom_family::ParameterBlockSpec],
454 ) -> u64 {
455 crate::custom_family::default_coefficient_hessian_cost(specs)
456 }
457
458 fn block_linear_constraints(
459 &self,
460 _: &[ParameterBlockState],
461 block_idx: usize,
462 spec: &crate::custom_family::ParameterBlockSpec,
463 ) -> Result<Option<LinearInequalityConstraints>, String> {
464 let block = self.blocks.get(block_idx).ok_or_else(|| {
465 SurvivalError::CauseSpecificDimensionMismatch {
466 reason: format!(
467 "cause-specific survival expected block index < {}, got {block_idx}",
468 self.blocks.len()
469 ),
470 }
471 .to_string()
472 })?;
473 if block.x_derivative.ncols() != spec.design.ncols() {
474 return Err(SurvivalError::CauseSpecificDimensionMismatch {
475 reason: format!(
476 "cause-specific survival derivative design has {} columns but block '{}' has {}",
477 block.x_derivative.ncols(),
478 spec.name,
479 spec.design.ncols()
480 ),
481 }
482 .into());
483 }
484 let rhs = block
485 .offset_derivative_exit
486 .mapv(|offset| block.derivative_floor - offset);
487 Ok(Some(LinearInequalityConstraints {
488 a: block.x_derivative.clone(),
489 b: rhs,
490 }))
491 }
492
493 fn max_feasible_step_size(
494 &self,
495 block_states: &[ParameterBlockState],
496 block_idx: usize,
497 delta: &Array1<f64>,
498 ) -> Result<Option<f64>, String> {
499 let block = self.blocks.get(block_idx).ok_or_else(|| {
500 SurvivalError::CauseSpecificDimensionMismatch {
501 reason: format!(
502 "cause-specific survival expected block index < {}, got {block_idx}",
503 self.blocks.len()
504 ),
505 }
506 .to_string()
507 })?;
508 let state = block_states.get(block_idx).ok_or_else(|| {
509 SurvivalError::CauseSpecificDimensionMismatch {
510 reason: format!(
511 "cause-specific survival expected {} block states, got {}",
512 self.blocks.len(),
513 block_states.len()
514 ),
515 }
516 .to_string()
517 })?;
518 if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
519 return Err(SurvivalError::CauseSpecificDimensionMismatch {
520 reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
521 }
522 .into());
523 }
524 let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
525 let derivative_delta = fast_av(&block.x_derivative, delta);
526 let mut alpha_max = 1.0_f64;
527 for i in 0..derivative.len() {
528 if block.sampleweight[i] <= 0.0 {
529 continue;
530 }
531 let current = derivative[i] - block.derivative_floor;
532 let slope = derivative_delta[i];
533 if slope < 0.0 {
534 if current <= 0.0 {
535 return Ok(Some(0.0));
536 }
537 alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
538 }
539 }
540 Ok(Some(alpha_max.clamp(0.0, 1.0)))
541 }
542
543 fn exact_newton_hessian_directional_derivative(
544 &self,
545 block_states: &[ParameterBlockState],
546 block_idx: usize,
547 d_beta: &Array1<f64>,
548 ) -> Result<Option<Array2<f64>>, String> {
549 let block = self.blocks.get(block_idx).ok_or_else(|| {
550 SurvivalError::CauseSpecificDimensionMismatch {
551 reason: format!(
552 "cause-specific survival expected block index < {}, got {block_idx}",
553 self.blocks.len()
554 ),
555 }
556 .to_string()
557 })?;
558 let state = block_states.get(block_idx).ok_or_else(|| {
559 SurvivalError::CauseSpecificDimensionMismatch {
560 reason: format!(
561 "cause-specific survival expected {} block states, got {}",
562 self.blocks.len(),
563 block_states.len()
564 ),
565 }
566 .to_string()
567 })?;
568 Ok(Some(cause_specific_hessian_directional_derivative(
569 block,
570 &state.beta,
571 d_beta,
572 )?))
573 }
574
575 fn exact_newton_hessian_second_directional_derivative(
576 &self,
577 block_states: &[ParameterBlockState],
578 block_idx: usize,
579 d_beta_u: &Array1<f64>,
580 d_beta_v: &Array1<f64>,
581 ) -> Result<Option<Array2<f64>>, String> {
582 let block = self.blocks.get(block_idx).ok_or_else(|| {
583 SurvivalError::CauseSpecificDimensionMismatch {
584 reason: format!(
585 "cause-specific survival expected block index < {}, got {block_idx}",
586 self.blocks.len()
587 ),
588 }
589 .to_string()
590 })?;
591 let state = block_states.get(block_idx).ok_or_else(|| {
592 SurvivalError::CauseSpecificDimensionMismatch {
593 reason: format!(
594 "cause-specific survival expected {} block states, got {}",
595 self.blocks.len(),
596 block_states.len()
597 ),
598 }
599 .to_string()
600 })?;
601 Ok(Some(cause_specific_hessian_second_directional_derivative(
602 block,
603 &state.beta,
604 d_beta_u,
605 d_beta_v,
606 )?))
607 }
608}
609
610fn cause_specific_hessian_directional_derivative(
611 block: &CauseSpecificRoystonParmarBlock,
612 beta: &Array1<f64>,
613 d_beta: &Array1<f64>,
614) -> Result<Array2<f64>, SurvivalError> {
615 let p = block.x_exit.ncols();
616 if beta.len() != p || d_beta.len() != p {
617 return Err(SurvivalError::CauseSpecificDimensionMismatch {
618 reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
619 });
620 }
621 let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
622 let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
623 let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
624 let d_eta_entry = fast_av(&block.x_entry, d_beta);
625 let d_eta_exit = fast_av(&block.x_exit, d_beta);
626 let d_derivative = fast_av(&block.x_derivative, d_beta);
627 let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
628 let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
629 let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
630
631 for i in 0..block.event_target.len() {
632 let weight = block.sampleweight[i];
633 if weight <= 0.0 {
634 continue;
635 }
636 let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
637 w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
638 if has_entry {
639 w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
640 }
641 if block.event_target[i] > 0 {
642 let deriv = derivative[i];
643 if !(deriv.is_finite() && deriv > 0.0) {
644 return Err(SurvivalError::NumericalFailure {
645 reason: format!(
646 "cause-specific survival derivative must be positive at row {i}, got {deriv}"
647 ),
648 });
649 }
650 w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
651 }
652 }
653
654 let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
655 d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
656 d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
657 Ok(d_hessian)
658}
659
660fn cause_specific_hessian_second_directional_derivative(
661 block: &CauseSpecificRoystonParmarBlock,
662 beta: &Array1<f64>,
663 d_beta_u: &Array1<f64>,
664 d_beta_v: &Array1<f64>,
665) -> Result<Array2<f64>, SurvivalError> {
666 let p = block.x_exit.ncols();
667 if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
668 return Err(SurvivalError::CauseSpecificDimensionMismatch {
669 reason: "cause-specific survival second Hessian derivative dimension mismatch"
670 .to_string(),
671 });
672 }
673 let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
674 let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
675 let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
676 let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
677 let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
678 let u_derivative = fast_av(&block.x_derivative, d_beta_u);
679 let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
680 let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
681 let v_derivative = fast_av(&block.x_derivative, d_beta_v);
682 let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
683 let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
684 let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
685
686 for i in 0..block.event_target.len() {
687 let weight = block.sampleweight[i];
688 if weight <= 0.0 {
689 continue;
690 }
691 let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
692 w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
693 if has_entry {
694 w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
695 }
696 if block.event_target[i] > 0 {
697 let deriv = derivative[i];
698 if !(deriv.is_finite() && deriv > 0.0) {
699 return Err(SurvivalError::NumericalFailure {
700 reason: format!(
701 "cause-specific survival derivative must be positive at row {i}, got {deriv}"
702 ),
703 });
704 }
705 w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
706 }
707 }
708
709 let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
710 d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
711 d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
712 Ok(d2_hessian)
713}
714
715pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
716 const INTEGER_TOL: f64 = 1e-8;
717 const MAX_AUTO_CAUSES: u8 = 32;
718 if !value.is_finite() {
719 return Err(SurvivalError::EventCodeInvalid {
720 reason: format!(
721 "survival event value at row {} is non-finite",
722 row_index + 1
723 ),
724 }
725 .into());
726 }
727 if value < 0.0 {
728 return Err(SurvivalError::EventCodeInvalid {
729 reason: format!(
730 "survival event value at row {} is negative: {value}",
731 row_index + 1
732 ),
733 }
734 .into());
735 }
736 let rounded = value.round();
737 if (value - rounded).abs() > INTEGER_TOL {
738 return Err(SurvivalError::EventCodeInvalid {
739 reason: format!(
740 "survival event value at row {} must be an integer code with 0=censored, got {value}",
741 row_index + 1
742 ),
743 }
744 .into());
745 }
746 if rounded > f64::from(MAX_AUTO_CAUSES) {
747 return Err(SurvivalError::EventCodeInvalid {
748 reason: format!(
749 "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
750 row_index + 1
751 ),
752 }
753 .into());
754 }
755 Ok(rounded as u8)
756}
757
758pub fn cause_count_from_event_codes(
759 event_codes: ArrayView1<'_, u8>,
760) -> Result<usize, SurvivalError> {
761 let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
762 if max_code == 0 {
763 return Ok(1);
764 }
765
766 let mut present = vec![false; max_code + 1];
767 for code in event_codes.iter().copied() {
768 present[usize::from(code)] = true;
769 }
770 if (1..=max_code).any(|code| !present[code]) {
771 let actual = present
772 .iter()
773 .enumerate()
774 .skip(1)
775 .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
776 .collect::<Vec<_>>()
777 .join(", ");
778 return Err(SurvivalError::EventCodeInvalid {
779 reason: format!(
780 "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
781 ),
782 });
783 }
784
785 Ok(max_code)
786}
787
788pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
801 event_codes.mapv(|label| u8::from(label > 0))
802}
803
804pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
814 let cause_code = cause as u8;
815 event_codes.mapv(|observed| u8::from(observed == cause_code))
816}
817
818fn compress_positive_collinear_constraints(
819 a: &Array2<f64>,
820 b: &Array1<f64>,
821) -> LinearInequalityConstraints {
822 const SCALE_TOL: f64 = 1e-14;
823 const KEY_TOL: f64 = 1e-8;
824
825 let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
826 let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
827
828 for i in 0..a.nrows() {
829 let row = a.row(i);
830 let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
831 if !scale.is_finite() || scale <= SCALE_TOL {
832 if b[i] > 0.0 {
833 fallbackrows.push((row.to_vec(), b[i]));
834 }
835 continue;
836 }
837
838 let normalizedrow: Vec<f64> = row
839 .iter()
840 .map(|&v| {
841 let scaled = v / scale;
842 if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
843 })
844 .collect();
845 let normalized_rhs = b[i] / scale;
846 let key: Vec<i64> = normalizedrow
847 .iter()
848 .map(|&v| (v / KEY_TOL).round() as i64)
849 .collect();
850
851 match grouped.get_mut(&key) {
852 Some((_, rhs_max)) => {
853 if normalized_rhs > *rhs_max {
854 *rhs_max = normalized_rhs;
855 }
856 }
857 None => {
858 grouped.insert(key, (normalizedrow, normalized_rhs));
859 }
860 }
861 }
862
863 let nrows = grouped.len() + fallbackrows.len();
864 let n_cols = a.ncols();
865 let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
866 let mut b_out = Array1::<f64>::zeros(nrows);
867
868 let mut outrow = 0usize;
869 for (_, (row, rhs)) in grouped {
870 for (j, value) in row.into_iter().enumerate() {
871 a_out[[outrow, j]] = value;
872 }
873 b_out[outrow] = rhs;
874 outrow += 1;
875 }
876 for (row, rhs) in fallbackrows {
877 for (j, value) in row.into_iter().enumerate() {
878 a_out[[outrow, j]] = value;
879 }
880 b_out[outrow] = rhs;
881 outrow += 1;
882 }
883
884 LinearInequalityConstraints { a: a_out, b: b_out }
885}
886
/// Settings for the survival monotonicity penalty.
///
/// NOTE(review): how `tolerance` is applied is not visible here; it defaults to `0.0`
/// through `Default`. Confirm its meaning at the use site.
#[derive(Debug, Clone, Copy, Default)]
pub struct SurvivalMonotonicityPenalty {
    pub tolerance: f64,
}
891
/// Design layout for the survival working model.
///
/// - `Flat`: one full-width matrix per channel (entry, exit, derivative).
/// - `TimeCovariateShared`: time-basis matrices per channel plus one covariate matrix that
///   is added to the entry and exit predictors. Coefficients are ordered time columns
///   first, then covariate columns. The derivative channel uses only the time columns.
#[derive(Debug, Clone)]
enum SurvivalDesign {
    Flat {
        x_entry: Array2<f64>,
        x_exit: Array2<f64>,
        x_derivative: Array2<f64>,
    },
    TimeCovariateShared {
        time_entry: Array2<f64>,
        time_exit: Array2<f64>,
        time_derivative: Array2<f64>,
        covariates: Array2<f64>,
    },
}
906
907impl SurvivalDesign {
908 fn p_total(&self) -> usize {
909 match self {
910 Self::Flat { x_exit, .. } => x_exit.ncols(),
911 Self::TimeCovariateShared {
912 time_exit,
913 covariates,
914 ..
915 } => time_exit.ncols() + covariates.ncols(),
916 }
917 }
918
919 fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
920 match self {
921 Self::Flat { .. } => time_mat.dot(beta),
922 Self::TimeCovariateShared { covariates, .. } => {
923 let p_time = time_mat.ncols();
924 let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
925 if covariates.ncols() > 0 {
926 out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
927 }
928 out
929 }
930 }
931 }
932
933 fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
934 match self {
935 Self::Flat { .. } => {
936 for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
937 *dst = src;
938 }
939 }
940 Self::TimeCovariateShared { covariates, .. } => {
941 let p_time = time_mat.ncols();
942 for j in 0..p_time {
943 out[j] = time_mat[[i, j]];
944 }
945 for j in 0..covariates.ncols() {
946 out[p_time + j] = covariates[[i, j]];
947 }
948 }
949 }
950 }
951}
952
/// Reusable per-row scratch buffers; `reset` reuses them when the row count is unchanged.
///
/// NOTE(review): the field names mirror the per-row weights in
/// `evaluate_cause_specific_block` (event weight, weight / derivative,
/// weight / derivative^2, exit/entry Hessian weights). Their exact use lies outside this
/// view.
#[derive(Debug, Clone)]
struct SurvivalWorkspace {
    w_event: Array1<f64>,
    w_event_inv_deriv: Array1<f64>,
    w_event_outer: Array1<f64>,
    w_hess_exit: Array1<f64>,
    w_hess_entry: Array1<f64>,
}
962
963impl SurvivalWorkspace {
964 fn new(n: usize) -> Self {
965 Self {
966 w_event: Array1::zeros(n),
967 w_event_inv_deriv: Array1::zeros(n),
968 w_event_outer: Array1::zeros(n),
969 w_hess_exit: Array1::zeros(n),
970 w_hess_entry: Array1::zeros(n),
971 }
972 }
973
974 fn reset(&mut self, n: usize) {
975 if self.w_event.len() != n {
976 *self = Self::new(n);
977 } else {
978 self.w_event.fill(0.0);
979 self.w_event_inv_deriv.fill(0.0);
980 self.w_event_outer.fill(0.0);
981 self.w_hess_exit.fill(0.0);
982 self.w_hess_entry.fill(0.0);
983 }
984 }
985}
986
/// Per-row residual vectors, one per offset channel.
///
/// NOTE(review): the producer is outside this view. `exit`, `entry` and `derivative`
/// presumably correspond to the exit-predictor, entry-predictor and derivative offsets.
/// The meaning of `right` is not visible here; confirm with the caller.
#[derive(Clone, Debug)]
pub struct OffsetChannelResiduals {
    /// Residuals for the exit channel.
    pub exit: Array1<f64>,
    /// Residuals for the entry channel.
    pub entry: Array1<f64>,
    /// Residuals for the derivative channel.
    pub derivative: Array1<f64>,
    /// Residuals for the `right` channel (TODO: confirm meaning).
    pub right: Array1<f64>,
}
1012
/// One 3x3 matrix per row.
///
/// NOTE(review): presumably the curvature across three offset channels. The channel order
/// is not visible here; confirm with the producer.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    pub rows: Vec<[[f64; 3]; 3]>,
}
1019
/// Working model for the one-hazard survival engine.
///
/// Owns the per-row survival data, the design layout (`SurvivalDesign`), fixed offsets on
/// the entry/exit predictors and the exit derivative, the quadratic penalties, and the
/// monotonicity settings. The scratch workspace sits behind a `Mutex`, which is not
/// `Clone`, so `Clone` is implemented by hand.
#[derive(Debug)]
pub struct WorkingModelSurvival {
    age_entry: Array1<f64>,
    age_exit: Array1<f64>,
    // NOTE(review): presumably `age_entry <= ENTRY_AT_ORIGIN_THRESHOLD` per row; confirm at
    // construction.
    entry_at_origin: Array1<bool>,
    event_target: Array1<u8>,
    // Per-row weights; `nrows()` is defined as this array's length.
    sampleweight: Array1<f64>,
    design: SurvivalDesign,
    offset_eta_entry: Array1<f64>,
    offset_eta_exit: Array1<f64>,
    offset_derivative_exit: Array1<f64>,
    penalties: PenaltyBlocks,
    monotonicity: SurvivalMonotonicityPenalty,
    // NOTE(review): the next two presumably describe a time basis that is monotone by
    // construction and how many leading columns it spans; confirm where they are set.
    structurally_monotonic: bool,
    structural_time_columns: usize,
    monotonicity_constraint_rows: Option<Array2<f64>>,
    monotonicity_constraint_offsets: Option<Array1<f64>>,
    // Reusable per-row scratch buffers (see `SurvivalWorkspace`).
    workspace: std::sync::Mutex<SurvivalWorkspace>,
}
1039
1040impl Clone for WorkingModelSurvival {
1041 fn clone(&self) -> Self {
1042 let workspace = self.workspace.lock().unwrap().clone();
1043 Self {
1044 age_entry: self.age_entry.clone(),
1045 age_exit: self.age_exit.clone(),
1046 entry_at_origin: self.entry_at_origin.clone(),
1047 event_target: self.event_target.clone(),
1048 sampleweight: self.sampleweight.clone(),
1049 design: self.design.clone(),
1050 offset_eta_entry: self.offset_eta_entry.clone(),
1051 offset_eta_exit: self.offset_eta_exit.clone(),
1052 offset_derivative_exit: self.offset_derivative_exit.clone(),
1053 penalties: self.penalties.clone(),
1054 monotonicity: self.monotonicity,
1055 structurally_monotonic: self.structurally_monotonic,
1056 structural_time_columns: self.structural_time_columns,
1057 monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1058 monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1059 workspace: std::sync::Mutex::new(workspace),
1060 }
1061 }
1062}
1063
1064impl WorkingModelSurvival {
1065 const LOG_F64_MAX: f64 = 709.782712893384;
1066
1067 #[inline]
1068 fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
1069 if base == 0.0 {
1070 return Ok(0.0);
1071 }
1072 let log_abs = log_scale + base.abs().ln();
1073 if !log_abs.is_finite() {
1074 crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
1075 }
1076 if log_abs > Self::LOG_F64_MAX {
1077 crate::bail_invalid_estim!(
1078 "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
1079 );
1080 }
1081 Ok(base.signum() * log_abs.exp())
1082 }
1083
1084 fn coefficient_dim(&self) -> usize {
1085 self.design.p_total()
1086 }
1087
1088 fn nrows(&self) -> usize {
1089 self.sampleweight.len()
1090 }
1091
1092 fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1093 let time_mat = match &self.design {
1094 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1095 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1096 };
1097 self.design.design_dot(time_mat, beta)
1098 }
1099
1100 fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1101 let time_mat = match &self.design {
1102 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1103 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1104 };
1105 self.design.design_dot(time_mat, beta)
1106 }
1107
1108 fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1109 match &self.design {
1110 SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1111 SurvivalDesign::TimeCovariateShared {
1112 time_derivative, ..
1113 } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1114 }
1115 }
1116
1117 fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1118 let time_mat = match &self.design {
1119 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1120 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1121 };
1122 self.design.fill_row(time_mat, i, out);
1123 }
1124
1125 fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1126 let time_mat = match &self.design {
1127 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1128 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1129 };
1130 self.design.fill_row(time_mat, i, out);
1131 }
1132
1133 fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1134 match &self.design {
1135 SurvivalDesign::Flat { x_derivative, .. } => {
1136 for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1137 *dst = src;
1138 }
1139 }
1140 SurvivalDesign::TimeCovariateShared {
1141 time_derivative, ..
1142 } => {
1143 let p_time = time_derivative.ncols();
1144 for j in 0..p_time {
1145 out[j] = time_derivative[[i, j]];
1146 }
1147 for dst in out.iter_mut().skip(p_time) {
1148 *dst = 0.0;
1149 }
1150 }
1151 }
1152 }
1153
1154 fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1155 match &self.design {
1156 SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1157 SurvivalDesign::TimeCovariateShared {
1158 time_derivative,
1159 covariates,
1160 ..
1161 } => {
1162 let p_time = time_derivative.ncols();
1163 let p_cov = covariates.ncols();
1164 let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1165 let time_block = fast_xt_diag_x(time_derivative, weights);
1166 out.slice_mut(ndarray::s![..p_time, ..p_time])
1167 .assign(&time_block);
1168 out
1169 }
1170 }
1171 }
1172
1173 fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1177 match &self.design {
1178 SurvivalDesign::Flat {
1179 x_entry, x_exit, ..
1180 } => {
1181 let mut h = fast_xt_diag_x(x_exit, w_exit);
1182 h -= &fast_xt_diag_x(x_entry, w_entry);
1183 h
1184 }
1185 SurvivalDesign::TimeCovariateShared {
1186 time_entry,
1187 time_exit,
1188 covariates,
1189 ..
1190 } => {
1191 let p_time = time_exit.ncols();
1192 let p_cov = covariates.ncols();
1193 let p = p_time + p_cov;
1194 let mut h = Array2::<f64>::zeros((p, p));
1195 let tt = {
1197 let mut block = fast_xt_diag_x(time_exit, w_exit);
1198 block -= &fast_xt_diag_x(time_entry, w_entry);
1199 block
1200 };
1201 h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1202 if p_cov > 0 {
1203 let tc = {
1205 let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1206 block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1207 block
1208 };
1209 h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1210 h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1211 let w_diff = w_exit - w_entry;
1213 let cc = fast_xt_diag_x(covariates, &w_diff);
1214 h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1215 }
1216 h
1217 }
1218 }
1219 }
1220
1221 fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1222 const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1223 if !self.structurally_monotonic {
1224 return None;
1225 }
1226 if deriv >= 1e-12 {
1227 return Some(deriv);
1228 }
1229 if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1230 return Some(1e-12);
1231 }
1232 None
1233 }
1234
1235 fn validate_penalties(
1236 penalties: &PenaltyBlocks,
1237 coefficient_dim: usize,
1238 ) -> Result<(), SurvivalError> {
1239 for block in &penalties.blocks {
1240 if !block.lambda.is_finite() || block.lambda < 0.0 {
1241 return Err(SurvivalError::NonFiniteInput);
1242 }
1243 if block.range.start > block.range.end || block.range.end > coefficient_dim {
1244 return Err(SurvivalError::DimensionMismatch);
1245 }
1246 let block_dim = block.range.end - block.range.start;
1247 if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1248 return Err(SurvivalError::DimensionMismatch);
1249 }
1250 if block.matrix.iter().any(|v| !v.is_finite()) {
1251 return Err(SurvivalError::NonFiniteInput);
1252 }
1253 }
1254 Ok(())
1255 }
1256
1257 fn derivative_guard(&self) -> f64 {
1258 if self.structurally_monotonic {
1259 return 0.0;
1263 }
1264 self.monotonicity.tolerance.max(0.0)
1265 }
1266
1267 fn derivative_guard_numerical(&self) -> f64 {
1268 let derivative_guard = self.derivative_guard();
1269 if derivative_guard <= 0.0 {
1270 if self.structurally_monotonic {
1279 -1e-10
1280 } else {
1281 1e-12
1282 }
1283 } else {
1284 (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1285 }
1286 }
1287
1288 fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1289 let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1290 1e-10 * scale
1291 }
1292
1293 fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1294 if !self.structurally_monotonic {
1295 return None;
1296 }
1297 let p = self.coefficient_dim();
1298 let time_columns = self.structural_time_columns.min(p);
1299 if time_columns == 0 {
1300 return None;
1301 }
1302 const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1303 let mut active_columns = vec![false; time_columns];
1304 let mut derivative_row = vec![0.0_f64; p];
1305 for i in 0..self.nrows() {
1306 if self.sampleweight[i] <= 0.0 {
1307 continue;
1308 }
1309 self.fill_derivative_row(i, &mut derivative_row);
1310 for j in 0..time_columns {
1311 if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1312 active_columns[j] = true;
1313 }
1314 }
1315 }
1316 if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1317 for i in 0..rows.nrows() {
1318 for j in 0..time_columns {
1319 if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1320 active_columns[j] = true;
1321 }
1322 }
1323 }
1324 }
1325 let active_columns: Vec<usize> = active_columns
1326 .into_iter()
1327 .enumerate()
1328 .filter_map(|(j, active)| active.then_some(j))
1329 .collect();
1330 if active_columns.is_empty() {
1331 return None;
1332 }
1333 let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1334 let b = Array1::<f64>::zeros(active_columns.len());
1335 for (row, &col) in active_columns.iter().enumerate() {
1336 a[[row, col]] = 1.0;
1337 }
1338 Some(LinearInequalityConstraints { a, b })
1339 }
1340
1341 pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1342 let p = self.coefficient_dim();
1343 const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1344 if p == 0 {
1345 return None;
1346 }
1347 if self.structurally_monotonic {
1348 return self.structural_time_coefficient_constraints();
1349 }
1350 if let (Some(rows), Some(offsets)) = (
1351 self.monotonicity_constraint_rows.as_ref(),
1352 self.monotonicity_constraint_offsets.as_ref(),
1353 ) {
1354 let activerows: Vec<usize> = (0..rows.nrows())
1355 .filter(|&i| {
1356 rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1357 > DERIVATIVE_ROW_NORM_TOL
1358 })
1359 .collect();
1360 if activerows.is_empty() {
1361 return None;
1362 }
1363 let mut a = Array2::<f64>::zeros((activerows.len(), p));
1364 let mut b = Array1::<f64>::zeros(activerows.len());
1365 for (r, &i) in activerows.iter().enumerate() {
1366 a.row_mut(r).assign(&rows.row(i));
1367 b[r] = self.derivative_guard() - offsets[i];
1368 }
1369 return Some(compress_positive_collinear_constraints(&a, &b));
1370 }
1371 None
1372 }
1373
1374 pub fn from_engine_inputs(
1375 inputs: SurvivalEngineInputs<'_>,
1376 penalties: PenaltyBlocks,
1377 monotonicity: SurvivalMonotonicityPenalty,
1378 spec: SurvivalSpec,
1379 ) -> Result<Self, SurvivalError> {
1380 Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
1381 }
1382
1383 fn validate_offsets(
1384 offsets: Option<SurvivalBaselineOffsets<'_>>,
1385 n: usize,
1386 ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1387 if let Some(off) = offsets {
1388 if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1389 {
1390 return Err(SurvivalError::DimensionMismatch);
1391 }
1392 if off.eta_entry.iter().any(|v| !v.is_finite())
1393 || off.eta_exit.iter().any(|v| !v.is_finite())
1394 || off.derivative_exit.iter().any(|v| !v.is_finite())
1395 {
1396 return Err(SurvivalError::NonFiniteInput);
1397 }
1398 Ok((
1399 off.eta_entry.to_owned(),
1400 off.eta_exit.to_owned(),
1401 off.derivative_exit.to_owned(),
1402 ))
1403 } else {
1404 Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1405 }
1406 }
1407
1408 fn validate_common_inputs(
1409 age_entry: &ArrayView1<f64>,
1410 age_exit: &ArrayView1<f64>,
1411 event_target: &ArrayView1<u8>,
1412 event_competing: &ArrayView1<u8>,
1413 sampleweight: &ArrayView1<f64>,
1414 ) -> Result<(), SurvivalError> {
1415 if age_entry.iter().any(|v| !v.is_finite())
1416 || age_exit.iter().any(|v| !v.is_finite())
1417 || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1418 {
1419 return Err(SurvivalError::NonFiniteInput);
1420 }
1421 if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1428 return Err(SurvivalError::EventCodeInvalid {
1429 reason: format!(
1430 "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1431 ),
1432 });
1433 }
1434 if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1435 return Err(SurvivalError::EventCodeInvalid {
1436 reason: format!(
1437 "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1438 ),
1439 });
1440 }
1441 if event_target
1442 .iter()
1443 .zip(event_competing.iter())
1444 .any(|(&target, &competing)| target > 0 && competing > 0)
1445 {
1446 return Err(SurvivalError::EventCodeInvalid {
1447 reason: "a row cannot be simultaneously a target event and a competing event"
1448 .to_string(),
1449 });
1450 }
1451 if age_entry
1465 .iter()
1466 .zip(age_exit.iter())
1467 .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1468 {
1469 return Err(SurvivalError::NonFiniteInput);
1470 }
1471 Ok::<(), _>(())
1472 }
1473
1474 fn validate_monotonicity_constraints(
1475 rows: Option<ArrayView2<'_, f64>>,
1476 offsets: Option<ArrayView1<'_, f64>>,
1477 coefficient_dim: usize,
1478 ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1479 match (rows, offsets) {
1480 (None, None) => Ok((None, None)),
1481 (Some(rows), Some(offsets)) => {
1482 if rows.ncols() != coefficient_dim
1483 || rows.nrows() != offsets.len()
1484 || rows.iter().any(|v| !v.is_finite())
1485 || offsets.iter().any(|v| !v.is_finite())
1486 {
1487 return Err(SurvivalError::DimensionMismatch);
1488 }
1489 Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1490 }
1491 _ => Err(SurvivalError::DimensionMismatch),
1492 }
1493 }
1494
1495 fn finish_construction(
1496 age_entry: ArrayView1<f64>,
1497 age_exit: ArrayView1<f64>,
1498 event_target: ArrayView1<u8>,
1499 sampleweight: ArrayView1<f64>,
1500 design: SurvivalDesign,
1501 offset_eta_entry: Array1<f64>,
1502 offset_eta_exit: Array1<f64>,
1503 offset_derivative_exit: Array1<f64>,
1504 penalties: PenaltyBlocks,
1505 monotonicity: SurvivalMonotonicityPenalty,
1506 monotonicity_constraint_rows: Option<Array2<f64>>,
1507 monotonicity_constraint_offsets: Option<Array1<f64>>,
1508 ) -> Self {
1509 let n = age_entry.len();
1510 Self {
1511 age_entry: age_entry.to_owned(),
1512 age_exit: age_exit.to_owned(),
1513 entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1514 event_target: event_target.to_owned(),
1515 sampleweight: sampleweight.to_owned(),
1516 design,
1517 offset_eta_entry,
1518 offset_eta_exit,
1519 offset_derivative_exit,
1520 penalties,
1521 monotonicity,
1522 structurally_monotonic: false,
1523 structural_time_columns: 0,
1524 monotonicity_constraint_rows,
1525 monotonicity_constraint_offsets,
1526 workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1527 }
1528 }
1529
1530 pub fn from_engine_inputswith_offsets(
1531 inputs: SurvivalEngineInputs<'_>,
1532 offsets: Option<SurvivalBaselineOffsets<'_>>,
1533 penalties: PenaltyBlocks,
1534 monotonicity: SurvivalMonotonicityPenalty,
1535 spec: SurvivalSpec,
1536 ) -> Result<Self, SurvivalError> {
1537 if spec == SurvivalSpec::Crude {
1538 return Err(SurvivalError::UnsupportedSpec("crude"));
1539 }
1540 let n = inputs.age_entry.len();
1541 let p = inputs.x_entry.ncols();
1542 if inputs.age_exit.len() != n
1543 || inputs.event_target.len() != n
1544 || inputs.event_competing.len() != n
1545 || inputs.sampleweight.len() != n
1546 || inputs.x_entry.nrows() != n
1547 || inputs.x_exit.nrows() != n
1548 || inputs.x_derivative.nrows() != n
1549 || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1550 || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1551 {
1552 return Err(SurvivalError::DimensionMismatch);
1553 }
1554 Self::validate_penalties(&penalties, p)?;
1555 Self::validate_common_inputs(
1556 &inputs.age_entry,
1557 &inputs.age_exit,
1558 &inputs.event_target,
1559 &inputs.event_competing,
1560 &inputs.sampleweight,
1561 )?;
1562 if inputs.x_entry.iter().any(|v| !v.is_finite())
1563 || inputs.x_exit.iter().any(|v| !v.is_finite())
1564 || inputs.x_derivative.iter().any(|v| !v.is_finite())
1565 {
1566 return Err(SurvivalError::NonFiniteInput);
1567 }
1568 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1569 Self::validate_offsets(offsets, n)?;
1570 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1571 Self::validate_monotonicity_constraints(
1572 inputs.monotonicity_constraint_rows,
1573 inputs.monotonicity_constraint_offsets,
1574 p,
1575 )?;
1576
1577 Ok(Self::finish_construction(
1578 inputs.age_entry,
1579 inputs.age_exit,
1580 inputs.event_target,
1581 inputs.sampleweight,
1582 SurvivalDesign::Flat {
1583 x_entry: inputs.x_entry.to_owned(),
1584 x_exit: inputs.x_exit.to_owned(),
1585 x_derivative: inputs.x_derivative.to_owned(),
1586 },
1587 offset_eta_entry,
1588 offset_eta_exit,
1589 offset_derivative_exit,
1590 penalties,
1591 monotonicity,
1592 monotonicity_constraint_rows,
1593 monotonicity_constraint_offsets,
1594 ))
1595 }
1596
1597 pub fn from_time_covariate_inputswith_offsets(
1598 inputs: SurvivalTimeCovarInputs<'_>,
1599 offsets: Option<SurvivalBaselineOffsets<'_>>,
1600 penalties: PenaltyBlocks,
1601 monotonicity: SurvivalMonotonicityPenalty,
1602 spec: SurvivalSpec,
1603 ) -> Result<Self, SurvivalError> {
1604 if spec == SurvivalSpec::Crude {
1605 return Err(SurvivalError::UnsupportedSpec("crude"));
1606 }
1607 let n = inputs.age_entry.len();
1608 let p_time = inputs.time_entry.ncols();
1609 let p_cov = inputs.covariates.ncols();
1610 let p = p_time + p_cov;
1611 if inputs.age_exit.len() != n
1612 || inputs.event_target.len() != n
1613 || inputs.event_competing.len() != n
1614 || inputs.sampleweight.len() != n
1615 || inputs.time_entry.nrows() != n
1616 || inputs.time_exit.nrows() != n
1617 || inputs.time_derivative.nrows() != n
1618 || inputs.covariates.nrows() != n
1619 || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1620 || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1621 {
1622 return Err(SurvivalError::DimensionMismatch);
1623 }
1624 Self::validate_penalties(&penalties, p)?;
1625 Self::validate_common_inputs(
1626 &inputs.age_entry,
1627 &inputs.age_exit,
1628 &inputs.event_target,
1629 &inputs.event_competing,
1630 &inputs.sampleweight,
1631 )?;
1632 if inputs.time_entry.iter().any(|v| !v.is_finite())
1633 || inputs.time_exit.iter().any(|v| !v.is_finite())
1634 || inputs.time_derivative.iter().any(|v| !v.is_finite())
1635 || inputs.covariates.iter().any(|v| !v.is_finite())
1636 {
1637 return Err(SurvivalError::NonFiniteInput);
1638 }
1639 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1640 Self::validate_offsets(offsets, n)?;
1641 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1642 Self::validate_monotonicity_constraints(
1643 inputs.monotonicity_constraint_rows,
1644 inputs.monotonicity_constraint_offsets,
1645 p,
1646 )?;
1647
1648 Ok(Self::finish_construction(
1649 inputs.age_entry,
1650 inputs.age_exit,
1651 inputs.event_target,
1652 inputs.sampleweight,
1653 SurvivalDesign::TimeCovariateShared {
1654 time_entry: inputs.time_entry.to_owned(),
1655 time_exit: inputs.time_exit.to_owned(),
1656 time_derivative: inputs.time_derivative.to_owned(),
1657 covariates: inputs.covariates.to_owned(),
1658 },
1659 offset_eta_entry,
1660 offset_eta_exit,
1661 offset_derivative_exit,
1662 penalties,
1663 monotonicity,
1664 monotonicity_constraint_rows,
1665 monotonicity_constraint_offsets,
1666 ))
1667 }
1668
1669 pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1684 if lambdas.len() != self.penalties.blocks.len() {
1685 crate::bail_invalid_estim!(
1686 "set_penalty_lambdas expects {} lambdas, got {}",
1687 self.penalties.blocks.len(),
1688 lambdas.len()
1689 );
1690 }
1691 for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1692 if !lambda.is_finite() || lambda < 0.0 {
1693 crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1694 }
1695 block.lambda = lambda;
1696 }
1697 Ok(())
1698 }
1699
1700 pub fn set_structural_monotonicity(
1701 &mut self,
1702 enabled: bool,
1703 time_columns: usize,
1704 ) -> Result<(), EstimationError> {
1705 let p = self.coefficient_dim();
1706 if time_columns > p {
1707 crate::bail_invalid_estim!(
1708 "structural time columns {} exceed coefficient dimension {}",
1709 time_columns,
1710 p
1711 );
1712 }
1713 if enabled && time_columns == 0 {
1714 crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
1715 }
1716 if enabled {
1717 const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1718 for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
1719 if offset < -STRUCTURAL_DERIV_TOL {
1720 crate::bail_invalid_estim!(
1721 "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
1722 );
1723 }
1724 }
1725 let mut derivative_row = vec![0.0_f64; p];
1726 for i in 0..self.nrows() {
1727 self.fill_derivative_row(i, &mut derivative_row);
1728 for j in 0..time_columns {
1729 let v = derivative_row[j];
1730 if v < -STRUCTURAL_DERIV_TOL {
1731 crate::bail_invalid_estim!(
1732 "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
1733 );
1734 }
1735 }
1736 for j in time_columns..p {
1737 let v = derivative_row[j];
1738 if v.abs() > STRUCTURAL_DERIV_TOL {
1739 crate::bail_invalid_estim!(
1740 "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
1741 );
1742 }
1743 }
1744 }
1745 if let (Some(rows), Some(offsets)) = (
1746 self.monotonicity_constraint_rows.as_ref(),
1747 self.monotonicity_constraint_offsets.as_ref(),
1748 ) {
1749 for (i, &offset) in offsets.iter().enumerate() {
1750 if offset < -STRUCTURAL_DERIV_TOL {
1751 crate::bail_invalid_estim!(
1752 "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
1753 );
1754 }
1755 }
1756 for i in 0..rows.nrows() {
1757 for j in 0..time_columns {
1758 let v = rows[[i, j]];
1759 if v < -STRUCTURAL_DERIV_TOL {
1760 crate::bail_invalid_estim!(
1761 "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1762 );
1763 }
1764 }
1765 for j in time_columns..p {
1766 let v = rows[[i, j]];
1767 if v.abs() > STRUCTURAL_DERIV_TOL {
1768 crate::bail_invalid_estim!(
1769 "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1770 );
1771 }
1772 }
1773 }
1774 }
1775 }
1776 self.structurally_monotonic = enabled;
1777 self.structural_time_columns = if enabled { time_columns } else { 0 };
1778 Ok(())
1779 }
1780
    /// Evaluates the penalized survival objective at `beta` and returns the PIRLS
    /// working state: gradient, dense Hessian, log-likelihood, deviance, and penalty.
    ///
    /// Each row with positive weight `w` contributes two terms:
    /// - the interval term `w * (exp(eta_exit) - exp(eta_entry))`, where the entry part
    ///   is dropped for rows flagged as entering at the origin;
    /// - for target events, `-w * (eta_exit + ln(s))`, where `s` is the (structurally
    ///   stabilized) time derivative of `eta`.
    ///
    /// A fixed ridge of `1e-8` is added to the Hessian diagonal and mirrored in the
    /// gradient and the penalty term.
    ///
    /// # Errors
    /// - Invalid-estimation error on a `beta` length mismatch, on ages that are
    ///   non-finite or have `age_exit < age_entry`, or on interval terms that exceed the
    ///   `f64` range.
    /// - `ParameterConstraintViolation` when the time derivative drops below the
    ///   monotonicity floor, or when the cumulative hazard decreases over an interval
    ///   by more than the increment guard.
    ///
    /// # Panics
    /// Panics if the workspace mutex is poisoned (`lock().unwrap()`).
    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
        if beta.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!("survival beta dimension mismatch");
        }

        let n = self.nrows();
        let p = self.coefficient_dim();

        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;

        let mut nll = 0.0;
        // `derivative_guard` is only used in error messages; the enforced floor for
        // event rows is the slightly relaxed `derivative_guard_numerical`.
        let derivative_guard = self.derivative_guard();
        let derivative_guard_numerical = self.derivative_guard_numerical();
        let mut workspace = self.workspace.lock().unwrap();
        workspace.reset(n);
        // Per-row weights, filled in the first pass and consumed by the Hessian and
        // gradient assembly below.
        let SurvivalWorkspace {
            w_event,
            w_event_inv_deriv,
            w_event_outer,
            w_hess_exit,
            w_hess_entry,
        } = &mut *workspace;

        for i in 0..n {
            let w = self.sampleweight[i];
            if w <= 0.0 {
                continue;
            }
            let entry_age = self.age_entry[i];
            let exit_age = self.age_exit[i];
            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
                crate::bail_invalid_estim!(
                    "survival ages must be finite with age_exit >= age_entry"
                );
            }
            let d = f64::from(self.event_target[i]);

            let has_entry_interval = !self.entry_at_origin[i];
            // Factor out the larger linear predictor so the exp() calls below cannot
            // overflow. `scaled_exp_component` restores the scale in log space.
            let interval_scale = if has_entry_interval {
                eta_exit[i].max(eta_entry[i])
            } else {
                eta_exit[i]
            };
            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
            let h_s_scaled = if has_entry_interval {
                (eta_entry[i] - interval_scale).exp()
            } else {
                0.0
            };
            let interval_scaled = h_e_scaled - h_s_scaled;
            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
            let deriv = self
                .stabilized_structural_derivative(derivative_raw[i])
                .unwrap_or(derivative_raw[i]);
            // Event rows take ln(deriv), so they need the numerical floor; censored rows
            // only need a non-negative slope.
            let mono_floor = if d > 0.0 {
                derivative_guard_numerical
            } else {
                0.0
            };
            if !deriv.is_finite() || deriv < mono_floor {
                return Err(EstimationError::ParameterConstraintViolation(format!(
                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
                    i, deriv, derivative_guard
                )));
            }
            if has_entry_interval {
                // Allow a relative roundoff slack before declaring the cumulative hazard
                // non-monotone over the interval.
                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
                if interval_scaled + increment_guard < 0.0 {
                    return Err(EstimationError::ParameterConstraintViolation(format!(
                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
                        i, interval
                    )));
                }
            }
            nll += w * interval;

            // Interval-term Hessian weights: w * exp(eta) at exit and (if present) entry.
            let w_exit_i = w * eta_exit[i].exp();
            let w_entry_i = if has_entry_interval {
                w * eta_entry[i].exp()
            } else {
                0.0
            };
            if !w_exit_i.is_finite() {
                crate::bail_invalid_estim!(
                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
                );
            }
            w_hess_exit[i] = w_exit_i;
            w_hess_entry[i] = w_entry_i;

            if d > 0.0 {
                // Event term -w*(eta_exit + ln(deriv)). Its gradient weight on the
                // derivative row is w/deriv, and its curvature weight is w/deriv^2.
                let inv_deriv = 1.0 / deriv;
                nll += -w * (eta_exit[i] + deriv.ln());
                w_event[i] = w;
                w_event_inv_deriv[i] = w * inv_deriv;
                w_event_outer[i] = w * inv_deriv * inv_deriv;
            }
        }

        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
        // Gradient accumulated with compensated (Neumaier) summation to limit
        // cancellation between the positive interval and negative event contributions.
        let mut grad = Array1::<f64>::zeros(p);
        let mut grad_comp = Array1::<f64>::zeros(p);
        let mut row_exit = vec![0.0_f64; p];
        let mut row_entry = vec![0.0_f64; p];
        let mut row_derivative = vec![0.0_f64; p];
        for i in 0..n {
            let w_interval_exit = w_hess_exit[i];
            let w_interval_entry = w_hess_entry[i];
            let w_event_exit = w_event[i];
            let w_event_derivative = w_event_inv_deriv[i];
            if w_interval_exit == 0.0
                && w_interval_entry == 0.0
                && w_event_exit == 0.0
                && w_event_derivative == 0.0
            {
                continue;
            }
            self.fill_exit_row(i, &mut row_exit);
            self.fill_entry_row(i, &mut row_entry);
            self.fill_derivative_row(i, &mut row_derivative);
            for j in 0..p {
                let contribution = w_interval_exit * row_exit[j]
                    - w_interval_entry * row_entry[j]
                    - w_event_exit * row_exit[j]
                    - w_event_derivative * row_derivative[j];
                let t = grad[j] + contribution;
                if grad[j].abs() >= contribution.abs() {
                    grad_comp[j] += (grad[j] - t) + contribution;
                } else {
                    grad_comp[j] += (contribution - t) + grad[j];
                }
                grad[j] = t;
            }
        }
        grad += &grad_comp;

        // Event curvature from -ln(deriv): X_dᵀ diag(w / deriv^2) X_d.
        h += &self.derivative_xt_diag_x(w_event_outer);

        let score_norm = array1_l2_norm(&grad);

        let penaltygrad = self.penalties.gradient(beta);
        let penalty_dev = self.penalties.deviance(beta);
        let penaltygrad_norm = array1_l2_norm(&penaltygrad);

        let mut totalgrad = grad;
        totalgrad += &penaltygrad;

        self.penalties.addhessian_inplace(&mut h);
        // Fixed ridge applied consistently to the Hessian, gradient, and penalty term.
        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
        for d in 0..p {
            h[[d, d]] += ridge_used;
        }
        totalgrad += &beta.mapv(|v| ridge_used * v);
        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);

        let log_likelihood = -nll;
        let deviance = 2.0 * nll;

        Ok(WorkingState {
            eta: LinearPredictor::new(eta_exit),
            gradient: totalgrad,
            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
            log_likelihood,
            deviance,
            penalty_term: penalty_dev + ridge_penalty,
            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
            ridge_used,
            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
        })
    }
2016
2017 pub(crate) fn survival_hessian_derivative_correction(
2027 &self,
2028 beta: &Array1<f64>,
2029 u_k: &Array1<f64>,
2030 ) -> Result<Array2<f64>, EstimationError> {
2031 let p = beta.len();
2032 let n = self.nrows();
2033
2034 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2035 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2036 let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2037 let exp_entry = eta_entry.mapv(f64::exp);
2038 let exp_exit = eta_exit.mapv(f64::exp);
2039 let guard = self.derivative_guard();
2040 let guard_numerical = self.derivative_guard_numerical();
2041
2042 let jac = Array1::<f64>::ones(p);
2043 let curvature = Array1::<f64>::zeros(p);
2044 let third = Array1::<f64>::zeros(p);
2045
2046 let mut row_exit = vec![0.0_f64; p];
2047 let mut row_entry = vec![0.0_f64; p];
2048 let mut row_derivative = vec![0.0_f64; p];
2049 let mut ge = vec![0.0_f64; p];
2050 let mut gs = vec![0.0_f64; p];
2051 let mut gsd = vec![0.0_f64; p];
2052 let mut he = vec![0.0_f64; p];
2053 let mut hs = vec![0.0_f64; p];
2054 let mut hsd = vec![0.0_f64; p];
2055 let mut te = vec![0.0_f64; p];
2056 let mut ts = vec![0.0_f64; p];
2057 let mut tsd = vec![0.0_f64; p];
2058
2059 let mut b_dir = Array2::<f64>::zeros((p, p));
2060
2061 for i in 0..n {
2062 let w_i = self.sampleweight[i];
2063 if w_i <= 0.0 {
2064 continue;
2065 }
2066 let has_entry = !self.entry_at_origin[i];
2067 let mut deta_e = 0.0_f64;
2068 let mut deta_s = 0.0_f64;
2069 let mut ds = 0.0_f64;
2070 self.fill_exit_row(i, &mut row_exit);
2071 self.fill_entry_row(i, &mut row_entry);
2072 self.fill_derivative_row(i, &mut row_derivative);
2073 for j in 0..p {
2074 ge[j] = row_exit[j] * jac[j];
2075 gs[j] = row_entry[j] * jac[j];
2076 gsd[j] = row_derivative[j] * jac[j];
2077 he[j] = row_exit[j] * curvature[j];
2078 hs[j] = row_entry[j] * curvature[j];
2079 hsd[j] = row_derivative[j] * curvature[j];
2080 te[j] = row_exit[j] * third[j];
2081 ts[j] = row_entry[j] * third[j];
2082 tsd[j] = row_derivative[j] * third[j];
2083 deta_e += ge[j] * u_k[j];
2084 if has_entry {
2085 deta_s += gs[j] * u_k[j];
2086 }
2087 ds += gsd[j] * u_k[j];
2088 }
2089
2090 for r in 0..p {
2092 let dge_r = he[r] * u_k[r];
2093 let dgs_r = hs[r] * u_k[r];
2094 let dhe_r = te[r] * u_k[r];
2095 let dhs_r = ts[r] * u_k[r];
2096 for c in 0..p {
2097 let dge_c = he[c] * u_k[c];
2098 let dgs_c = hs[c] * u_k[c];
2099 let mut d_h_rc =
2100 exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2101 if r == c {
2102 d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2103 }
2104 if has_entry {
2105 d_h_rc -=
2106 exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2107 if r == c {
2108 d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2109 }
2110 }
2111 b_dir[[r, c]] += w_i * d_h_rc;
2112 }
2113 }
2114
2115 let s_i = self
2117 .stabilized_structural_derivative(deriv_raw[i])
2118 .unwrap_or(deriv_raw[i]);
2119 if !s_i.is_finite() {
2120 return Err(EstimationError::ParameterConstraintViolation(format!(
2121 "survival monotonicity violated in unified trace contraction at row {i}: \
2122 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2123 )));
2124 }
2125 if self.event_target[i] > 0 {
2126 if s_i < guard_numerical {
2127 return Err(EstimationError::ParameterConstraintViolation(format!(
2128 "survival monotonicity violated in unified trace contraction at row {i}: \
2129 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2130 )));
2131 }
2132 let inv_s = 1.0 / s_i;
2133 let inv_s2 = inv_s * inv_s;
2134 let inv_s3 = inv_s2 * inv_s;
2135 for r in 0..p {
2136 let dgd_r = hsd[r] * u_k[r];
2137 let dtsd_r = tsd[r] * u_k[r];
2138 let dte_r = te[r] * u_k[r];
2139 for c in 0..p {
2140 let dgd_c = hsd[c] * u_k[c];
2141 let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2142 - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2143 if r == c {
2144 d_h_rc += -dte_r;
2145 d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2146 }
2147 b_dir[[r, c]] += w_i * d_h_rc;
2148 }
2149 }
2150 }
2151 }
2152
2153 Ok(b_dir)
2154 }
2155
2156 pub fn offset_channel_residuals(
2194 &self,
2195 beta: &Array1<f64>,
2196 ) -> Result<OffsetChannelResiduals, EstimationError> {
2197 if beta.len() != self.coefficient_dim() {
2198 crate::bail_invalid_estim!(
2199 "survival beta dimension mismatch in offset_channel_residuals"
2200 );
2201 }
2202 let n = self.nrows();
2203 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2204 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2205 let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2206
2207 let derivative_guard_numerical = self.derivative_guard_numerical();
2208 let mut r_exit = Array1::<f64>::zeros(n);
2209 let mut r_entry = Array1::<f64>::zeros(n);
2210 let mut r_deriv = Array1::<f64>::zeros(n);
2211
2212 for i in 0..n {
2213 let w = self.sampleweight[i];
2214 if w <= 0.0 {
2215 continue;
2216 }
2217 let entry_age = self.age_entry[i];
2218 let exit_age = self.age_exit[i];
2219 if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2220 crate::bail_invalid_estim!(
2221 "survival ages must be finite with age_exit >= age_entry"
2222 );
2223 }
2224 let has_entry_interval = !self.entry_at_origin[i];
2225 let d = f64::from(self.event_target[i]);
2226 let w_exit_i = w * eta_exit[i].exp();
2230 let w_entry_i = if has_entry_interval {
2231 w * eta_entry[i].exp()
2232 } else {
2233 0.0
2234 };
2235 if !w_exit_i.is_finite() {
2236 crate::bail_invalid_estim!(
2237 "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2238 );
2239 }
2240 r_exit[i] = w_exit_i - d * w;
2241 r_entry[i] = -w_entry_i;
2242 let deriv_raw = derivative_raw[i];
2247 let deriv = self
2248 .stabilized_structural_derivative(deriv_raw)
2249 .unwrap_or(deriv_raw);
2250 let mono_floor = if d > 0.0 {
2251 derivative_guard_numerical
2252 } else {
2253 0.0
2254 };
2255 if !deriv.is_finite() || deriv < mono_floor {
2256 return Err(EstimationError::ParameterConstraintViolation(format!(
2257 "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2258 )));
2259 }
2260 if d > 0.0 {
2261 r_deriv[i] = -w * d / deriv;
2262 }
2263 }
2264
2265 let right = Array1::<f64>::zeros(r_exit.len());
2266 Ok(OffsetChannelResiduals {
2267 exit: r_exit,
2268 entry: r_entry,
2269 derivative: r_deriv,
2270 right,
2271 })
2272 }
2273
    /// Laplace-approximate marginal likelihood (LAML) objective and its gradient with respect
    /// to the log smoothing parameters `rho`, evaluated at an already-computed inner state.
    ///
    /// `rho[k]` pairs with the k-th penalty block whose current `lambda` is positive (in
    /// `self.penalties.blocks` order); `rho.len()` must equal the number of such blocks.
    /// `beta` and `state` are used as given — this function does not re-solve the inner
    /// problem (`evaluate_survival_lamlcost_and_gradient` refits first, then calls this).
    ///
    /// The work is delegated to `InnerAssembly::evaluate` in `ValueAndGradient` mode with fixed
    /// dispersion `phi = 1`, a flat `rho` prior, and a `SurvivalDerivProvider` that supplies
    /// Hessian-derivative corrections.
    ///
    /// Returns `(cost, gradient)`.
    ///
    /// # Errors
    /// Invalid-input errors for a `rho` length mismatch, or when building the Hessian operator,
    /// the penalty coordinates, or the penalty log-determinant derivatives fails, when the
    /// monotonicity KKT residual cannot be projected, or when the assembly evaluation fails.
    ///
    /// # Panics
    /// If `rho` is not contiguous in memory (`as_slice` returns `None`).
    pub fn unified_lamlobjective_and_rhogradient(
        &self,
        beta: &Array1<f64>,
        state: &WorkingState,
        rho: &Array1<f64>,
    ) -> Result<(f64, Array1<f64>), EstimationError> {
        use gam_solve::estimate::reml::assembly::{
            InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
        };
        use gam_solve::estimate::reml::reml_outer_engine::{
            DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
            compute_block_penalty_logdet_derivs,
        };
        use gam_problem::{EvalMode, PseudoLogdetMode};

        let p = beta.len();
        // Only blocks with a positive lambda are smoothing-parameter coordinates.
        let active_penalty_blocks: Vec<&PenaltyBlock> = self
            .penalties
            .blocks
            .iter()
            .filter(|b| b.lambda > 0.0)
            .collect();
        if rho.len() != active_penalty_blocks.len() {
            crate::bail_invalid_estim!(
                "survival LAML rho dimension {} does not match active penalty block count {}",
                rho.len(),
                active_penalty_blocks.len()
            );
        }
        let k_count = active_penalty_blocks.len();

        // Dense copy of the working-state Hessian for the spectral log-determinant operator.
        let h_dense = state.hessian.to_dense();
        // Any entry age above ENTRY_AT_ORIGIN_THRESHOLD means left truncation is present; in
        // that case the Hessian log-determinant uses HardPseudo instead of Smooth.
        // NOTE(review): the rationale for this switch is not visible here — confirm.
        let has_left_truncation = self
            .age_entry
            .iter()
            .any(|&t| t > ENTRY_AT_ORIGIN_THRESHOLD);
        let hessian_logdet_mode = if has_left_truncation {
            PseudoLogdetMode::HardPseudo
        } else {
            PseudoLogdetMode::Smooth
        };
        let hop = DenseSpectralOperator::from_symmetric_with_mode(&h_dense, hessian_logdet_mode)
            .map_err(EstimationError::InvalidInput)?;

        // Same active-block filter as above, described by coefficient range for the assembly.
        let block_descs: Vec<PenaltyBlockDesc> = self
            .penalties
            .blocks
            .iter()
            .filter(|b| b.lambda > 0.0)
            .map(|b| PenaltyBlockDesc {
                matrix: &b.matrix,
                range_start: b.range.start,
                range_end: b.range.end,
            })
            .collect();
        let penalty_coords =
            penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;

        // Each active block has exactly one penalty matrix and its own scalar rho, so the
        // penalty log-determinant and its derivatives are assembled block by block.
        let per_block_rho: Vec<Array1<f64>> =
            rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
        let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
            .iter()
            .map(|b| vec![b.matrix.clone()])
            .collect();
        let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
            .iter()
            .map(|v| v.as_slice())
            .collect();
        // NOTE(review): the meaning of the trailing `0.0` argument is defined by
        // compute_block_penalty_logdet_derivs — confirm.
        let penalty_logdet = if k_count > 0 {
            compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
                .map_err(EstimationError::InvalidInput)?
        } else {
            // No active penalties: log|S| and all its derivatives are empty/zero.
            PenaltyLogdetDerivs {
                value: 0.0,
                first: Array1::zeros(0),
                second: Some(Array2::zeros((0, 0))),
            }
        };

        // InnerAssembly receives twice `state.penalty_term`.
        // NOTE(review): this presumes penalty_term holds 0.5 * beta' S beta — confirm.
        let penalty_quadratic = 2.0 * state.penalty_term;
        // The provider owns snapshot copies of the model and beta.
        let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());

        // The projected KKT residual is forwarded only when the relative projected gradient
        // norm is at or below this gate; otherwise `kkt_residual` is None.
        const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
        let kkt_residual = {
            let raw = state.gradient.clone();
            // With monotonicity constraints, project the gradient onto the constraint
            // stationarity space first.
            let projected = match self.monotonicity_linear_constraints() {
                Some(constraints) => {
                    projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "survival LAML could not project the monotonicity KKT residual"
                                    .to_string(),
                            )
                        })?
                }
                None => raw,
            };
            let projected_norm = array1_l2_norm(&projected);
            let relative_projected_norm = state.relative_gradient_norm(projected_norm);
            if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
                Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
            } else {
                None
            }
        };

        let result = InnerAssembly {
            log_likelihood: state.log_likelihood,
            penalty_quadratic,
            beta: beta.clone(),
            n_observations: self.nrows(),
            hessian_op: std::sync::Arc::new(hop),
            penalty_coords,
            penalty_logdet,
            dispersion: DispersionHandling::Fixed {
                phi: 1.0,
                include_logdet_h: true,
                include_logdet_s: true,
            },
            rho_curvature_scale: 1.0,
            rho_prior: gam_problem::RhoPrior::Flat,
            hessian_logdet_correction: 0.0,
            penalty_subspace_trace: None,
            deriv_provider: Some(Box::new(provider)),
            firth: None,
            nullspace_dim: None,
            barrier_config: None,
            ext_coords: Vec::new(),
            ext_coord_pair_fn: None,
            rho_ext_pair_fn: None,
            fixed_drift_deriv: None,
            contracted_psi_second_order: None,
            kkt_residual,
            active_constraints: None,
        }
        .evaluate(
            rho.as_slice().expect("rho must be contiguous"),
            EvalMode::ValueAndGradient,
            None,
        )
        .map_err(EstimationError::InvalidInput)?;

        // NOTE(review): a missing gradient is replaced by zeros, which an outer optimizer
        // would read as stationarity — confirm evaluate always returns one in this mode.
        let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
        Ok((result.cost, gradient))
    }
2452
2453 pub fn evaluate_survival_lamlcost_and_gradient(
2474 &self,
2475 rho: &[f64],
2476 beta0: &Array1<f64>,
2477 ) -> Result<(f64, Array1<f64>), EstimationError> {
2478 let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2479 let rho_arr = Array1::from_vec(rho.to_vec());
2484 let state = candidate.update_state(&beta)?;
2485 candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2486 }
2487
    /// Refit the coefficients for smoothing parameters `exp(rho)` and return the refitted model
    /// copy together with its coefficient vector.
    ///
    /// A clone of `self` gets `lambda = exp(rho[k])` for the k-th block whose current `lambda`
    /// is positive. Blocks with `lambda <= 0` keep their value. The clone is then solved with
    /// the generic PIRLS driver from `beta0`, and finally "polished" with damped Newton steps
    /// that push the gradient norm toward `POLISH_TOL`.
    /// NOTE(review): presumably this is so the LAML evaluation's KKT gate (relative 1e-8) is
    /// met — confirm.
    ///
    /// Polishing is best-effort. It stops silently if a state evaluation fails, the gradient
    /// is non-finite or below tolerance, or no backtracked step is accepted.
    ///
    /// # Errors
    /// Invalid-input errors for `rho`/`beta0` length mismatches; propagates errors from
    /// `set_penalty_lambdas` and from the PIRLS run.
    fn reconverge_survival_inner_mode(
        &self,
        rho: &[f64],
        beta0: &Array1<f64>,
    ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
        const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
        const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
        const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
        const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;

        let active_block_count = self
            .penalties
            .blocks
            .iter()
            .filter(|b| b.lambda > 0.0)
            .count();
        if rho.len() != active_block_count {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
                rho.len(),
                active_block_count
            );
        }
        if beta0.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
                beta0.len(),
                self.coefficient_dim()
            );
        }

        let mut candidate = self.clone();
        // Map rho onto the active (lambda > 0) blocks in order; inactive blocks keep lambda.
        let mut lambdas: Vec<f64> = candidate
            .penalties
            .blocks
            .iter()
            .map(|b| b.lambda)
            .collect();
        let mut active_idx = 0usize;
        for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
            if block.lambda > 0.0 {
                *lambda = rho[active_idx].exp();
                active_idx += 1;
            }
        }
        candidate.set_penalty_lambdas(&lambdas)?;

        let opts = gam_solve::pirls::WorkingModelPirlsOptions {
            max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
            convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
            adaptive_kkt_tolerance: None,
            max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
            min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
            firth_bias_reduction: false,
            coefficient_lower_bounds: None,
            linear_constraints: None,
            initial_lm_lambda: None,
            geodesic_acceleration: false,
            arrow_schur: None,
        };
        let summary = gam_solve::pirls::runworking_model_pirls(
            &mut candidate,
            Coefficients::new(beta0.clone()),
            &opts,
            |_| {},
        )?;
        let mut beta = summary.beta.as_ref().to_owned();

        // Newton polish of the PIRLS solution (best-effort; never returns an error).
        {
            const POLISH_MAX_ITERS: usize = 400;
            const POLISH_TOL: f64 = 1e-13;
            const ARMIJO_C: f64 = 1e-4;
            const BACKTRACK: f64 = 0.5;
            const MAX_BACKTRACK: usize = 80;
            let p = beta.len();
            // Penalized objective that the line search tries to decrease.
            let penalized_objective =
                |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
            for _ in 0..POLISH_MAX_ITERS {
                let st = match candidate.update_state(&beta) {
                    Ok(st) => st,
                    Err(_) => break,
                };
                let r = st.gradient.clone();
                let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
                if !r_norm.is_finite() || r_norm < POLISH_TOL {
                    break;
                }
                let h = st.hessian.to_dense();
                let f0 = penalized_objective(&st);
                // Largest |diagonal| (at least 1) sets the scale of the LM damping ladder.
                let h_scale = (0..p)
                    .map(|d| h[[d, d]].abs())
                    .fold(0.0_f64, f64::max)
                    .max(1.0);
                let mut step: Option<Array1<f64>> = None;
                let mut dir_deriv = 0.0_f64;
                // Try undamped Newton first, then H + lambda*I with lambda growing tenfold from
                // 1e-11 * h_scale. Accept the first finite solution that is a descent direction.
                for lm_pow in 0..18 {
                    let lambda_lm = if lm_pow == 0 {
                        0.0
                    } else {
                        1e-12 * h_scale * 10f64.powi(lm_pow)
                    };
                    let mut h_reg = h.clone();
                    for d in 0..p {
                        h_reg[[d, d]] += lambda_lm;
                    }
                    let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
                        &h_reg,
                        faer::Side::Lower,
                    ) {
                        Ok(f) => f,
                        Err(_) => continue,
                    };
                    let candidate_step = factor.solvevec(&r);
                    if candidate_step.iter().any(|v| !v.is_finite()) {
                        continue;
                    }
                    // Trial points are beta - alpha * step, so the directional derivative
                    // along the move is -r . step; it must be clearly negative.
                    let dd = -r.dot(&candidate_step);
                    if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
                        step = Some(candidate_step);
                        dir_deriv = dd;
                        break;
                    }
                }
                let (step, dir_deriv) = match step {
                    Some(s) => (s, dir_deriv),
                    None => {
                        // No usable Newton direction: fall back to steepest descent.
                        (r.clone(), -r_norm * r_norm)
                    }
                };
                let mut alpha = 1.0_f64;
                let mut accepted = false;
                for _ in 0..MAX_BACKTRACK {
                    let trial = &beta - &(alpha * &step);
                    if let Ok(ts) = candidate.update_state(&trial) {
                        let ft = penalized_objective(&ts);
                        let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
                        // Accept on sufficient objective decrease (Armijo) OR on a strict
                        // decrease of the gradient norm.
                        let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
                        let residual_ok = tn.is_finite() && tn < r_norm;
                        if armijo_ok || residual_ok {
                            beta = trial;
                            accepted = true;
                            break;
                        }
                    }
                    alpha *= BACKTRACK;
                }
                if !accepted {
                    break;
                }
            }
        }

        Ok((candidate, beta))
    }
2719}
2720
2721pub(crate) struct SurvivalDerivProvider {
2730 model: WorkingModelSurvival,
2731 beta: Array1<f64>,
2732}
2733
2734impl SurvivalDerivProvider {
2735 pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2736 Self { model, beta }
2737 }
2738}
2739
2740impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2741 fn hessian_derivative_correction(
2742 &self,
2743 v_k: &Array1<f64>,
2744 ) -> Result<Option<Array2<f64>>, String> {
2745 let u_k = -v_k;
2748 match self
2749 .model
2750 .survival_hessian_derivative_correction(&self.beta, &u_k)
2751 {
2752 Ok(correction) => Ok(Some(correction)),
2753 Err(e) => Err(e.to_string()),
2754 }
2755 }
2756
2757 fn has_corrections(&self) -> bool {
2758 true
2759 }
2760}
2761
2762#[derive(Debug, Clone)]
2763pub struct CrudeRiskResult {
2764 pub risk: f64,
2765 pub diseasegradient: Array1<f64>,
2766 pub mortalitygradient: Array1<f64>,
2767}
2768
2769#[derive(Debug, Clone)]
2770pub struct CompetingRisksCifResult {
2771 pub cif: Vec<Array2<f64>>,
2776 pub overall_survival: Array2<f64>,
2777}
2778
2779const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2784
2785pub fn assemble_competing_risks_cif(
2786 times: ArrayView1<'_, f64>,
2787 cumulative_hazard: ArrayView3<'_, f64>,
2788) -> Result<CompetingRisksCifResult, SurvivalError> {
2789 let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2790 if n_endpoints == 0 {
2791 return Err(SurvivalError::DimensionMismatch);
2792 }
2793 let endpoint_hazards = cumulative_hazard
2794 .axis_iter(Axis(0))
2795 .map(|view| view.to_owned())
2796 .collect::<Vec<_>>();
2797 assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2798 if result.overall_survival.dim() != (n_rows, n_times) {
2799 Err(SurvivalError::DimensionMismatch)
2800 } else {
2801 Ok(result)
2802 }
2803 })
2804}
2805
2806pub fn assemble_competing_risks_cif_from_endpoints(
2807 times: ArrayView1<'_, f64>,
2808 cumulative_hazards: &[Array2<f64>],
2809) -> Result<CompetingRisksCifResult, SurvivalError> {
2810 let n_endpoints = cumulative_hazards.len();
2811 if n_endpoints == 0 || times.is_empty() {
2812 return Err(SurvivalError::DimensionMismatch);
2813 }
2814 let (n_rows, n_times) = cumulative_hazards[0].dim();
2815 if n_rows == 0 || n_times == 0 || times.len() != n_times {
2816 return Err(SurvivalError::DimensionMismatch);
2817 }
2818 if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2819 return Err(SurvivalError::InvalidTimeGrid);
2820 }
2821 if times
2822 .iter()
2823 .zip(times.iter().skip(1))
2824 .any(|(previous, current)| current <= previous)
2825 {
2826 return Err(SurvivalError::InvalidTimeGrid);
2827 }
2828 for endpoint_hazard in cumulative_hazards {
2829 if endpoint_hazard.dim() != (n_rows, n_times) {
2830 return Err(SurvivalError::DimensionMismatch);
2831 }
2832 if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2833 return Err(SurvivalError::NonFiniteInput);
2834 }
2835 }
2836
2837 let max_abs_hazard = cumulative_hazards
2838 .iter()
2839 .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2840 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2841 let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2842 let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2843 .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2844 .collect();
2845 let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2846
2847 let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2859 let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2860 let mut surv_row = vec![0.0_f64; n_times];
2861 let mut previous_cif = vec![0.0_f64; n_endpoints];
2862 let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2863 let mut increments = vec![0.0_f64; n_endpoints];
2864 let mut previous_total_cumulative = 0.0_f64;
2865 for time_idx in 0..n_times {
2866 let mut total_increment = 0.0_f64;
2867 for endpoint in 0..n_endpoints {
2868 let current = cumulative_hazards[endpoint][[row, time_idx]];
2869 if current < -monotone_tolerance {
2870 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2871 }
2872 let raw_increment = current - previous_cumulative[endpoint];
2873 if raw_increment < -monotone_tolerance {
2874 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2875 }
2876 let increment = raw_increment.max(0.0);
2877 increments[endpoint] = increment;
2878 total_increment += increment;
2879 previous_cumulative[endpoint] += increment;
2880 }
2881
2882 let survival_left = (-previous_total_cumulative).exp();
2883 let interval_failure = -(-total_increment).exp_m1();
2884 for endpoint in 0..n_endpoints {
2885 if total_increment > 0.0 {
2886 previous_cif[endpoint] +=
2887 survival_left * interval_failure * increments[endpoint] / total_increment;
2888 }
2889 cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2890 }
2891 previous_total_cumulative += total_increment;
2892 let mut fsum_at_t = 0.0_f64;
2909 for endpoint in 0..n_endpoints {
2910 fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2911 }
2912 surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2913 }
2914 Ok((cif_flat, surv_row))
2915 };
2916
2917 let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2921 && rayon::current_thread_index().is_none()
2922 {
2923 use rayon::prelude::*;
2924 (0..n_rows)
2925 .into_par_iter()
2926 .map(assemble_row)
2927 .collect::<Result<_, _>>()?
2928 } else {
2929 (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2930 };
2931
2932 for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2933 for endpoint in 0..n_endpoints {
2934 for time_idx in 0..n_times {
2935 cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2936 }
2937 }
2938 for time_idx in 0..n_times {
2939 overall_survival[[row, time_idx]] = surv_row[time_idx];
2940 }
2941 }
2942
2943 Ok(CompetingRisksCifResult {
2944 cif,
2945 overall_survival,
2946 })
2947}
2948
/// Gauss–Legendre nodes and weights on `[-1, 1]` for an `n`-point rule, sorted by node.
///
/// Roots of `P_n` are found by Newton iteration from the initial guesses
/// `cos(π (i + 3/4) / (n + 1/2))`. Each root takes at most 100 steps and stops once an update
/// is below `1e-14`. Only one half of the roots is solved, and each is mirrored. For odd `n`
/// the central node is reported as exactly `0.0`. Returns an empty vector for `n == 0`.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    let order = n as f64;
    // (P_n(x), P_{n-1}(x)) via Bonnet's recurrence (j+1) P_{j+1} = (2j+1) x P_j - j P_{j-1}.
    let legendre_pair = |x: f64| -> (f64, f64) {
        let mut p_curr = 1.0_f64;
        let mut p_prev = 0.0_f64;
        for j in 0..n {
            let jf = j as f64;
            let p_older = p_prev;
            p_prev = p_curr;
            p_curr = ((2.0 * jf + 1.0) * x * p_prev - jf * p_older) / (jf + 1.0);
        }
        (p_curr, p_prev)
    };

    let half = n.div_ceil(2);
    let mut rule: Vec<(f64, f64)> = Vec::with_capacity(n);
    for root in 0..half {
        let mut x = (std::f64::consts::PI * (root as f64 + 0.75) / (order + 0.5)).cos();
        // P_n'(x) from the last Newton evaluation; reused for the weight.
        let mut slope = 0.0;
        for _ in 0..100 {
            let (p_n, p_n_minus_1) = legendre_pair(x);
            slope = order * (x * p_n - p_n_minus_1) / (x * x - 1.0);
            let previous = x;
            x = previous - p_n / slope;
            if (x - previous).abs() < 1e-14 {
                break;
            }
        }
        let weight = 2.0 / ((1.0 - x * x) * slope * slope);
        let is_central_root = n % 2 == 1 && root + 1 == half;
        if is_central_root {
            rule.push((0.0, weight));
        } else {
            rule.extend([(-x, weight), (x, weight)]);
        }
    }

    rule.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal));
    rule
}
2986
2987fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
2988 static CACHE: LazyLock<Vec<(f64, f64)>> = LazyLock::new(|| compute_gauss_legendre_nodes(40));
2994 &CACHE
2995}
2996
2997pub fn calculate_crude_risk_quadrature<F>(
3021 t0: f64,
3022 t1: f64,
3023 breakpoints: &[f64],
3024 h_dis_t0: f64,
3025 h_mor_t0: f64,
3026 design_d_t0: ArrayView1<'_, f64>,
3027 design_m_t0: ArrayView1<'_, f64>,
3028 mut eval_at: F,
3029) -> Result<CrudeRiskResult, SurvivalError>
3030where
3031 F: FnMut(
3032 f64,
3033 &mut Array1<f64>,
3034 &mut Array1<f64>,
3035 &mut Array1<f64>,
3036 ) -> Result<(f64, f64, f64), SurvivalError>,
3037{
3038 let coeff_len_d = design_d_t0.len();
3039 let coeff_len_m = design_m_t0.len();
3040 if coeff_len_d == 0 || coeff_len_m == 0 {
3041 return Err(SurvivalError::InvalidIntegrationSetup);
3042 }
3043 if !t0.is_finite()
3044 || !t1.is_finite()
3045 || !h_dis_t0.is_finite()
3046 || !h_mor_t0.is_finite()
3047 || design_d_t0.iter().any(|v| !v.is_finite())
3048 || design_m_t0.iter().any(|v| !v.is_finite())
3049 {
3050 return Err(SurvivalError::NonFiniteInput);
3051 }
3052 if t1 <= t0 {
3053 return Ok(CrudeRiskResult {
3054 risk: 0.0,
3055 diseasegradient: Array1::zeros(coeff_len_d),
3056 mortalitygradient: Array1::zeros(coeff_len_m),
3057 });
3058 }
3059
3060 let mut sorted_breaks: Vec<f64> = breakpoints
3061 .iter()
3062 .copied()
3063 .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3064 .collect();
3065 sorted_breaks.push(t0);
3066 sorted_breaks.push(t1);
3067 sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3068 sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3069 if sorted_breaks.len() < 2 {
3070 return Err(SurvivalError::InvalidIntegrationSetup);
3071 }
3072
3073 let mut total_risk = 0.0;
3074 let mut diseasegradient = Array1::zeros(coeff_len_d);
3075 let mut mortalitygradient = Array1::zeros(coeff_len_m);
3076 let nodesweights = gauss_legendre_quadrature();
3077
3078 let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3079 let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3080 let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3081
3082 for segment in sorted_breaks.windows(2) {
3083 let a = segment[0];
3084 let b = segment[1];
3085 let center = 0.5 * (b + a);
3086 let halfwidth = 0.5 * (b - a);
3087 if halfwidth <= 0.0 {
3088 continue;
3089 }
3090
3091 for &(x, w) in nodesweights {
3092 let u = center + halfwidth * x;
3093 let (inst_hazard_d, hazard_d, hazard_m) =
3094 eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3095 if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3096 return Err(SurvivalError::NonFiniteInput);
3097 }
3098 if inst_hazard_d <= 0.0 {
3099 return Err(SurvivalError::NonPositiveHazard);
3100 }
3101
3102 if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3103 return Err(SurvivalError::NonMonotoneCumulativeHazard);
3104 }
3105
3106 let h_dis_cond = hazard_d - h_dis_t0;
3107 let h_mor_cond = hazard_m - h_mor_t0;
3108 let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3109
3110 total_risk += w * inst_hazard_d * s_total * halfwidth;
3111
3112 let weight = w * s_total * halfwidth;
3118 for j in 0..coeff_len_d {
3119 let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3120 let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3121 let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3122 diseasegradient[j] += weight * g;
3123 }
3124
3125 let weight = w * inst_hazard_d * s_total * halfwidth;
3128 for j in 0..coeff_len_m {
3129 let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3130 mortalitygradient[j] += weight * g;
3131 }
3132 }
3133 }
3134
3135 Ok(CrudeRiskResult {
3136 risk: total_risk,
3137 diseasegradient,
3138 mortalitygradient,
3139 })
3140}
3141
/// Adapter that lets the generic PIRLS driver (`gam_solve::pirls::runworking_model_pirls`)
/// iterate on the survival model.
impl PirlsWorkingModel for WorkingModelSurvival {
    /// Evaluate the working state at `beta` by delegating to `update_state`.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_state(beta)
    }
}
3147
3148#[cfg(test)]
3149mod tests {
3150 use super::*;
3151 use ndarray::{Array1, Array2, Array3, array, s};
3152
3153 #[test]
3154 fn competing_risks_cif_constant_hazard_matches_closed_form() {
3155 let times = array![0.0, 2.0, 5.0, 10.0];
3156 let disease_rates = [0.12, 0.06];
3157 let death_rates = [0.05, 0.02];
3158 let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
3159 let rate = if endpoint == 0 {
3160 disease_rates[row]
3161 } else {
3162 death_rates[row]
3163 };
3164 rate * times[time_idx]
3165 });
3166
3167 let result =
3168 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3169
3170 for row in 0..2 {
3171 let total_rate = disease_rates[row] + death_rates[row];
3172 for time_idx in 0..times.len() {
3173 let failure = 1.0 - (-total_rate * times[time_idx]).exp();
3174 let expected_disease = disease_rates[row] / total_rate * failure;
3175 let expected_death = death_rates[row] / total_rate * failure;
3176 assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
3177 assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
3178 assert!(
3179 (result.cif[0][[row, time_idx]]
3180 + result.cif[1][[row, time_idx]]
3181 + result.overall_survival[[row, time_idx]]
3182 - 1.0)
3183 .abs()
3184 < 1e-12
3185 );
3186 }
3187 }
3188 }
3189
3190 #[test]
3191 fn competing_risks_cif_rejects_nonmonotone_hazards() {
3192 let times = array![0.0, 1.0, 2.0];
3193 let cumulative = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");
3194 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3195 .expect_err("nonmonotone cumulative hazard should be rejected");
3196 assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
3197 }
3198
3199 #[test]
3200 fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
3201 let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
3202 let cumulative = Array3::from_shape_vec(
3203 (3, 2, 5),
3204 vec![
3205 0.0, 0.2, 0.2, 0.5, 1.1, 0.0, 0.0, 0.4, 0.4, 0.9, 0.0, 0.1, 0.3, 0.3, 0.7, 0.0, 0.2, 0.2, 0.8, 0.8, 0.0, 0.0, 0.2, 0.6, 0.6, 0.0, 0.1, 0.5, 0.5, 1.5,
3209 ],
3210 )
3211 .expect("shape");
3212
3213 let result =
3214 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3215
3216 for row in 0..2 {
3217 for time_idx in 0..times.len() {
3218 let total_cif = result.cif[0][[row, time_idx]]
3219 + result.cif[1][[row, time_idx]]
3220 + result.cif[2][[row, time_idx]];
3221 assert!(
3222 (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
3223 "probability mass mismatch at row={row}, time_idx={time_idx}"
3224 );
3225 assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));
3226 for cause in 0..3 {
3227 assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
3228 if time_idx > 0 {
3229 assert!(
3230 result.cif[cause][[row, time_idx]] + 1e-12
3231 >= result.cif[cause][[row, time_idx - 1]],
3232 "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
3233 );
3234 }
3235 }
3236 }
3237 }
3238
3239 assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
3242 assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
3245 assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
3246 }
3247
3248 #[test]
3249 fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
3250 let cumulative = Array3::zeros((2, 1, 2));
3251
3252 for times in [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]] {
3253 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3254 .expect_err("bad time grid should be rejected");
3255 assert!(matches!(err, SurvivalError::InvalidTimeGrid));
3256 }
3257
3258 let times = array![0.0, 1.0];
3259 let nonfinite = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
3260 let err = assemble_competing_risks_cif(times.view(), nonfinite.view())
3261 .expect_err("nonfinite hazard should be rejected");
3262 assert!(matches!(err, SurvivalError::NonFiniteInput));
3263 }
3264
3265 #[test]
3266 fn competing_risks_cif_extreme_hazards_remain_bounded() {
3267 let times = array![0.0, 1.0, 2.0];
3268 let cumulative =
3269 Array3::from_shape_vec((2, 1, 3), vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0])
3270 .expect("shape");
3271
3272 let result =
3273 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3274
3275 for value in result
3276 .cif
3277 .iter()
3278 .flat_map(|m| m.iter())
3279 .chain(result.overall_survival.iter())
3280 {
3281 assert!(value.is_finite());
3282 assert!((0.0..=1.0).contains(value));
3283 }
3284 assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
3285 assert_eq!(result.overall_survival[[0, 2]], 0.0);
3286 }
3287
3288 fn toy_penalties() -> PenaltyBlocks {
3289 let s = array![[2.0, 0.5], [0.5, 3.0]];
3290 PenaltyBlocks::new(vec![PenaltyBlock {
3291 matrix: s,
3292 lambda: 1.7,
3293 range: 1..3,
3294 nullspace_dim: 0,
3295 }])
3296 }
3297
3298 fn survival_inputs<'a>(
3299 age_entry: &'a Array1<f64>,
3300 age_exit: &'a Array1<f64>,
3301 event_target: &'a Array1<u8>,
3302 event_competing: &'a Array1<u8>,
3303 sampleweight: &'a Array1<f64>,
3304 x_entry: &'a Array2<f64>,
3305 x_exit: &'a Array2<f64>,
3306 x_derivative: &'a Array2<f64>,
3307 ) -> SurvivalEngineInputs<'a> {
3308 SurvivalEngineInputs {
3309 age_entry: age_entry.view(),
3310 age_exit: age_exit.view(),
3311 event_target: event_target.view(),
3312 event_competing: event_competing.view(),
3313 sampleweight: sampleweight.view(),
3314 x_entry: x_entry.view(),
3315 x_exit: x_exit.view(),
3316 x_derivative: x_derivative.view(),
3317 monotonicity_constraint_rows: None,
3318 monotonicity_constraint_offsets: None,
3319 }
3320 }
3321
3322 fn survival_model(
3323 inputs: SurvivalEngineInputs<'_>,
3324 penalties: PenaltyBlocks,
3325 monotonicity: SurvivalMonotonicityPenalty,
3326 spec: SurvivalSpec,
3327 ) -> Result<WorkingModelSurvival, SurvivalError> {
3328 WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
3329 }
3330
3331 fn survival_model_with_offsets(
3332 inputs: SurvivalEngineInputs<'_>,
3333 offsets: Option<SurvivalBaselineOffsets<'_>>,
3334 penalties: PenaltyBlocks,
3335 monotonicity: SurvivalMonotonicityPenalty,
3336 spec: SurvivalSpec,
3337 ) -> Result<WorkingModelSurvival, SurvivalError> {
3338 WorkingModelSurvival::from_engine_inputswith_offsets(
3339 inputs,
3340 offsets,
3341 penalties,
3342 monotonicity,
3343 spec,
3344 )
3345 }
3346
3347 #[test]
3348 fn penaltyhessian_matchesgradient_jacobian() {
3349 let penalties = toy_penalties();
3350 let beta = array![10.0, -0.3, 1.2, 7.0];
3351
3352 let grad = penalties.gradient(&beta);
3353 let h = penalties.hessian(beta.len());
3354 let b_block = beta.slice(s![1..3]).to_owned();
3355 let expected = 1.7 * array![[2.0, 0.5], [0.5, 3.0]].dot(&b_block);
3356
3357 assert!((grad[1] - expected[0]).abs() < 1e-12);
3358 assert!((grad[2] - expected[1]).abs() < 1e-12);
3359 assert!((h[[1, 1]] - 1.7 * 2.0).abs() < 1e-12);
3360 assert!((h[[1, 2]] - 1.7 * 0.5).abs() < 1e-12);
3361 assert!((h[[2, 1]] - 1.7 * 0.5).abs() < 1e-12);
3362 assert!((h[[2, 2]] - 1.7 * 3.0).abs() < 1e-12);
3363 }
3364
3365 #[test]
3366 fn penaltygradient_matches_deviance_finite_difference() {
3367 let penalties = toy_penalties();
3368 let beta = array![10.0, -0.3, 1.2, 7.0];
3369 let grad = penalties.gradient(&beta);
3370 let eps = 1e-7;
3371
3372 for idx in 0..beta.len() {
3373 let mut plus = beta.clone();
3374 let mut minus = beta.clone();
3375 plus[idx] += eps;
3376 minus[idx] -= eps;
3377 let fd = (penalties.deviance(&plus) - penalties.deviance(&minus)) / (2.0 * eps);
3378 assert_eq!(
3379 grad[idx].signum(),
3380 fd.signum(),
3381 "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
3382 grad[idx]
3383 );
3384 assert!(
3385 (grad[idx] - fd).abs() < 1e-6,
3386 "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
3387 grad[idx]
3388 );
3389 }
3390 }
3391
3392 #[test]
3393 fn zero_offsets_match_default_survival_state() {
3394 let age_entry = array![1.0_f64, 2.0_f64];
3395 let age_exit = array![2.0_f64, 3.5_f64];
3396 let event_target = array![1u8, 0u8];
3397 let event_competing = array![0u8, 0u8];
3398 let sampleweight = array![1.0, 1.0];
3399 let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3400 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3401 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3402 let penalties = PenaltyBlocks::new(Vec::new());
3403 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3404 let beta = array![-1.0, 0.8];
3405
3406 let base = survival_model(
3407 survival_inputs(
3408 &age_entry,
3409 &age_exit,
3410 &event_target,
3411 &event_competing,
3412 &sampleweight,
3413 &x_entry,
3414 &x_exit,
3415 &x_derivative,
3416 ),
3417 penalties.clone(),
3418 mono,
3419 SurvivalSpec::Net,
3420 )
3421 .expect("construct base survival model");
3422
3423 let zero_offsets = survival_model_with_offsets(
3424 survival_inputs(
3425 &age_entry,
3426 &age_exit,
3427 &event_target,
3428 &event_competing,
3429 &sampleweight,
3430 &x_entry,
3431 &x_exit,
3432 &x_derivative,
3433 ),
3434 Some(SurvivalBaselineOffsets {
3435 eta_entry: array![0.0, 0.0].view(),
3436 eta_exit: array![0.0, 0.0].view(),
3437 derivative_exit: array![0.0, 0.0].view(),
3438 }),
3439 penalties,
3440 mono,
3441 SurvivalSpec::Net,
3442 )
3443 .expect("construct offset survival model");
3444
3445 let state_base = base.update_state(&beta).expect("base state");
3446 let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
3447 assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
3448 assert!(
3449 state_base
3450 .gradient
3451 .iter()
3452 .zip(statezero.gradient.iter())
3453 .all(|(a, b)| (a - b).abs() < 1e-12)
3454 );
3455 }
3456
3457 #[test]
3458 fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
3459 let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
3473 let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
3474 let cause_labels = array![0u8, 1u8, 2u8, 0u8];
3476 let event_competing = Array1::<u8>::zeros(cause_labels.len());
3477 let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];
3478 let x_entry = array![
3479 [1.0, age_entry[0].max(1e-8).ln()],
3480 [1.0, age_entry[1].max(1e-8).ln()],
3481 [1.0, age_entry[2].max(1e-8).ln()],
3482 [1.0, age_entry[3].max(1e-8).ln()],
3483 ];
3484 let x_exit = array![
3485 [1.0, age_exit[0].ln()],
3486 [1.0, age_exit[1].ln()],
3487 [1.0, age_exit[2].ln()],
3488 [1.0, age_exit[3].ln()],
3489 ];
3490 let x_derivative = array![
3491 [0.0, 1.0 / age_exit[0]],
3492 [0.0, 1.0 / age_exit[1]],
3493 [0.0, 1.0 / age_exit[2]],
3494 [0.0, 1.0 / age_exit[3]],
3495 ];
3496 let penalties = PenaltyBlocks::new(Vec::new());
3497 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3498
3499 let raw = survival_model(
3504 survival_inputs(
3505 &age_entry,
3506 &age_exit,
3507 &cause_labels,
3508 &event_competing,
3509 &sampleweight,
3510 &x_entry,
3511 &x_exit,
3512 &x_derivative,
3513 ),
3514 penalties.clone(),
3515 mono,
3516 SurvivalSpec::Net,
3517 );
3518 assert!(
3519 matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
3520 "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
3521 );
3522
3523 let any_event = pooled_any_event_indicator(cause_labels.view());
3526 assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
3527 assert_eq!(
3529 cause_specific_event_indicator(cause_labels.view(), 1),
3530 array![0u8, 1u8, 0u8, 0u8]
3531 );
3532 assert_eq!(
3533 cause_specific_event_indicator(cause_labels.view(), 2),
3534 array![0u8, 0u8, 1u8, 0u8]
3535 );
3536 let model = survival_model(
3537 survival_inputs(
3538 &age_entry,
3539 &age_exit,
3540 &any_event,
3541 &event_competing,
3542 &sampleweight,
3543 &x_entry,
3544 &x_exit,
3545 &x_derivative,
3546 ),
3547 penalties,
3548 mono,
3549 SurvivalSpec::Net,
3550 )
3551 .expect("pooled any-event baseline model must construct from competing-risks data");
3552
3553 let beta = array![-1.0_f64, 0.8];
3556 let state = model.update_state(&beta).expect("pooled baseline state");
3557 assert!(
3558 state.deviance.is_finite(),
3559 "pooled baseline deviance must be finite, got {}",
3560 state.deviance
3561 );
3562 assert!(
3563 state.gradient.iter().all(|g| g.is_finite()),
3564 "pooled baseline gradient must be finite"
3565 );
3566 }
3567
3568 #[test]
3569 fn offset_channel_residuals_match_central_fd_of_nll() {
3570 let age_entry = array![0.5_f64, 0.0, 0.3];
3575 let age_exit = array![1.4_f64, 1.0, 2.0];
3576 let event_target = array![1u8, 1u8, 0u8];
3577 let event_competing = array![0u8, 0u8, 0u8];
3578 let sampleweight = array![1.0_f64, 2.5, 0.7];
3579 let x_entry = array![
3580 [1.0, age_entry[0].ln()],
3581 [1.0, age_entry[1].max(1e-8).ln()],
3582 [1.0, age_entry[2].ln()]
3583 ];
3584 let x_exit = array![
3585 [1.0, age_exit[0].ln()],
3586 [1.0, age_exit[1].ln()],
3587 [1.0, age_exit[2].ln()]
3588 ];
3589 let x_derivative = array![
3590 [0.0, 1.0 / age_exit[0]],
3591 [0.0, 1.0 / age_exit[1]],
3592 [0.0, 1.0 / age_exit[2]]
3593 ];
3594 let o_entry = array![0.2_f64, 0.0, 0.1];
3597 let o_exit = array![0.4_f64, 0.5, 0.7];
3598 let o_deriv = array![0.3_f64, 0.8, 0.5];
3599 let penalties = PenaltyBlocks::new(Vec::new());
3600 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3601 let beta = array![-0.7_f64, 0.6];
3602
3603 let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
3604 survival_model_with_offsets(
3605 survival_inputs(
3606 &age_entry,
3607 &age_exit,
3608 &event_target,
3609 &event_competing,
3610 &sampleweight,
3611 &x_entry,
3612 &x_exit,
3613 &x_derivative,
3614 ),
3615 Some(SurvivalBaselineOffsets {
3616 eta_entry: o_e.view(),
3617 eta_exit: o_x.view(),
3618 derivative_exit: o_d.view(),
3619 }),
3620 penalties.clone(),
3621 mono,
3622 SurvivalSpec::Net,
3623 )
3624 .expect("model build")
3625 };
3626
3627 let base = build(&o_entry, &o_exit, &o_deriv);
3628 let resid = base
3629 .offset_channel_residuals(&beta)
3630 .expect("offset residuals");
3631 assert_eq!(resid.exit.len(), 3);
3632 assert_eq!(resid.entry.len(), 3);
3633 assert_eq!(resid.derivative.len(), 3);
3634
3635 let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
3638 let h = 1e-6;
3639
3640 assert_eq!(resid.entry[1], 0.0);
3644 assert_eq!(resid.derivative[2], 0.0);
3645
3646 for i in 0..3 {
3647 {
3649 let mut op = o_exit.clone();
3650 let mut om = o_exit.clone();
3651 op[i] += h;
3652 om[i] -= h;
3653 let fd = (nll(&build(&o_entry, &op, &o_deriv))
3654 - nll(&build(&o_entry, &om, &o_deriv)))
3655 / (2.0 * h);
3656 assert!(
3657 (resid.exit[i] - fd).abs() < 1e-6,
3658 "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
3659 resid.exit[i],
3660 fd
3661 );
3662 }
3663 {
3667 let mut op = o_entry.clone();
3668 let mut om = o_entry.clone();
3669 op[i] += h;
3670 om[i] -= h;
3671 let fd = (nll(&build(&op, &o_exit, &o_deriv))
3672 - nll(&build(&om, &o_exit, &o_deriv)))
3673 / (2.0 * h);
3674 assert!(
3675 (resid.entry[i] - fd).abs() < 1e-6,
3676 "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
3677 resid.entry[i],
3678 fd
3679 );
3680 }
3681 {
3683 let mut op = o_deriv.clone();
3684 let mut om = o_deriv.clone();
3685 op[i] += h;
3686 om[i] -= h;
3687 let fd = (nll(&build(&o_entry, &o_exit, &op))
3688 - nll(&build(&o_entry, &o_exit, &om)))
3689 / (2.0 * h);
3690 assert!(
3691 (resid.derivative[i] - fd).abs() < 1e-6,
3692 "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
3693 resid.derivative[i],
3694 fd
3695 );
3696 }
3697 }
3698 }
3699
3700 #[test]
3701 fn offset_channel_residuals_respect_zero_sampleweight() {
3702 let age_entry = array![1.0_f64, 2.0];
3703 let age_exit = array![2.0_f64, 3.5];
3704 let event_target = array![1u8, 1u8];
3705 let event_competing = array![0u8, 0u8];
3706 let sampleweight = array![0.0_f64, 1.2]; let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3708 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3709 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3710 let penalties = PenaltyBlocks::new(Vec::new());
3711 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3712 let beta = array![-1.0_f64, 0.8];
3713
3714 let model = survival_model_with_offsets(
3715 survival_inputs(
3716 &age_entry,
3717 &age_exit,
3718 &event_target,
3719 &event_competing,
3720 &sampleweight,
3721 &x_entry,
3722 &x_exit,
3723 &x_derivative,
3724 ),
3725 Some(SurvivalBaselineOffsets {
3726 eta_entry: array![0.0_f64, 0.1].view(),
3727 eta_exit: array![0.0_f64, 0.2].view(),
3728 derivative_exit: array![0.0_f64, 0.1].view(),
3729 }),
3730 penalties,
3731 mono,
3732 SurvivalSpec::Net,
3733 )
3734 .expect("model");
3735 let r = model.offset_channel_residuals(&beta).expect("resid");
3736 assert_eq!(r.exit[0], 0.0);
3738 assert_eq!(r.entry[0], 0.0);
3739 assert_eq!(r.derivative[0], 0.0);
3740 assert!(r.exit[1] != 0.0);
3742 }
3743
3744 #[test]
3745 fn offset_channel_residuals_reject_beta_dim_mismatch() {
3746 let age_entry = array![1.0_f64];
3747 let age_exit = array![2.0_f64];
3748 let event_target = array![1u8];
3749 let event_competing = array![0u8];
3750 let sampleweight = array![1.0_f64];
3751 let x_entry = array![[1.0, 0.0]];
3752 let x_exit = array![[1.0, 0.7]];
3753 let x_derivative = array![[0.0, 0.5]];
3754 let penalties = PenaltyBlocks::new(Vec::new());
3755 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3756 let model = survival_model(
3757 survival_inputs(
3758 &age_entry,
3759 &age_exit,
3760 &event_target,
3761 &event_competing,
3762 &sampleweight,
3763 &x_entry,
3764 &x_exit,
3765 &x_derivative,
3766 ),
3767 penalties,
3768 mono,
3769 SurvivalSpec::Net,
3770 )
3771 .expect("model");
3772 let bad_beta = array![0.0_f64]; let err = model
3774 .offset_channel_residuals(&bad_beta)
3775 .expect_err("mismatch must error");
3776 match err {
3777 EstimationError::InvalidInput(msg) => {
3778 assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
3779 }
3780 other => panic!("expected InvalidInput, got {other:?}"),
3781 }
3782 }
3783
3784 #[test]
3785 fn crudespec_is_rejected_by_one_hazard_engine() {
3786 let age_entry = array![1.0_f64];
3787 let age_exit = array![2.0_f64];
3788 let event_target = array![0u8];
3789 let event_competing = array![1u8];
3790 let sampleweight = array![1.0];
3791 let x_entry = array![[0.1]];
3792 let x_exit = array![[0.4]];
3793 let x_derivative = array![[1.0]];
3794 let penalties = PenaltyBlocks::new(Vec::new());
3795 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3796
3797 let err = survival_model(
3798 survival_inputs(
3799 &age_entry,
3800 &age_exit,
3801 &event_target,
3802 &event_competing,
3803 &sampleweight,
3804 &x_entry,
3805 &x_exit,
3806 &x_derivative,
3807 ),
3808 penalties,
3809 mono,
3810 SurvivalSpec::Crude,
3811 )
3812 .expect_err("crude fitting should be rejected by the one-hazard engine");
3813 assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
3814 }
3815
3816 #[test]
3817 fn nonstructural_models_require_explicit_monotonicity_collocation() {
3818 let age_entry = array![1.0_f64, 1.5_f64];
3819 let age_exit = array![2.0_f64, 2.5_f64];
3820 let event_target = array![0u8, 0u8];
3821 let event_competing = array![0u8, 1u8];
3822 let sampleweight = array![1.0, 1.0];
3823 let x_entry = array![[0.2], [0.1]];
3824 let x_exit = array![[0.3], [0.2]];
3825 let x_derivative = array![[1.0], [1.0]];
3826
3827 let model = survival_model(
3828 survival_inputs(
3829 &age_entry,
3830 &age_exit,
3831 &event_target,
3832 &event_competing,
3833 &sampleweight,
3834 &x_entry,
3835 &x_exit,
3836 &x_derivative,
3837 ),
3838 PenaltyBlocks::new(Vec::new()),
3839 SurvivalMonotonicityPenalty { tolerance: 0.0 },
3840 SurvivalSpec::Net,
3841 )
3842 .expect("construct censored survival model");
3843
3844 assert!(
3845 model.monotonicity_linear_constraints().is_none(),
3846 "non-structural survival models must not fabricate rowwise monotonicity constraints"
3847 );
3848 }
3849
3850 #[test]
3851 fn decreasing_interval_is_rejectedwithout_target_events() {
3852 let age_entry = array![1.0_f64];
3853 let age_exit = array![2.0_f64];
3854 let event_target = array![0u8];
3855 let event_competing = array![0u8];
3856 let sampleweight = array![1.0];
3857 let x_entry = array![[0.5]];
3858 let x_exit = array![[0.0]];
3859 let x_derivative = array![[1.0]];
3860
3861 let model = survival_model(
3862 survival_inputs(
3863 &age_entry,
3864 &age_exit,
3865 &event_target,
3866 &event_competing,
3867 &sampleweight,
3868 &x_entry,
3869 &x_exit,
3870 &x_derivative,
3871 ),
3872 PenaltyBlocks::new(Vec::new()),
3873 SurvivalMonotonicityPenalty { tolerance: 0.0 },
3874 SurvivalSpec::Net,
3875 )
3876 .expect("construct censored survival model");
3877
3878 let err = model
3879 .update_state(&array![1.0])
3880 .expect_err("decreasing cumulative hazard increment should be rejected");
3881 assert!(
3882 err.to_string().contains("cumulative hazard decreased"),
3883 "unexpected error: {err}"
3884 );
3885 }
3886
    /// Crude-risk fixture with closed-form, smooth hazards, used for finite-difference checks.
    ///
    /// Both cumulative hazards are `exp(beta) * (1 + c * u)`, with `c = 0.2` for the
    /// disease hazard and `c = 0.1` for mortality. The disease instantaneous hazard
    /// returned by the callback, `0.2 * exp(beta_d)`, is exactly `d/du` of that disease
    /// cumulative hazard.
    ///
    /// NOTE(review): the leading arguments (`0.0`, `1.0`, `&[0.0, 1.0]`) presumably give
    /// the integration interval and its breakpoints, and the two `exp(beta)` scalars
    /// presumably give the hazard levels at the interval start. Confirm against
    /// `calculate_crude_risk_quadrature`.
    ///
    /// Panics if the quadrature reports an error.
    fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
        calculate_crude_risk_quadrature(
            0.0,
            1.0,
            &[0.0, 1.0],
            beta_d.exp(),
            beta_m.exp(),
            array![1.0].view(),
            array![1.0].view(),
            |u, design_d, deriv_d, design_m| {
                // Linear-in-u cumulative hazards scaled by exp(beta).
                let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
                let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
                // Derivative of `cumulative_d` with respect to u.
                let inst_hazard_d = 0.2 * beta_d.exp();
                // Single-coefficient design rows. The sensitivity rows are filled in
                // place; the derivative-design row is zero because the hazard shape
                // does not depend on the coefficient beyond the exp(beta) scale
                // (assumption — confirm how the quadrature consumes these buffers).
                design_d[0] = 1.0;
                deriv_d[0] = 0.0;
                design_m[0] = 1.0;
                Ok((inst_hazard_d, cumulative_d, cumulative_m))
            },
        )
        .expect("smooth crude-risk quadrature should succeed")
    }
3910
3911 #[test]
3912 fn crude_riskgradient_matches_monotoneobjective() {
3913 let beta_d = -0.2_f64;
3914 let beta_m = -0.5_f64;
3915 let result = smooth_crude_risk(beta_d, beta_m);
3916 let eps = 1e-6;
3917
3918 let fd_d = (smooth_crude_risk(beta_d + eps, beta_m).risk
3919 - smooth_crude_risk(beta_d - eps, beta_m).risk)
3920 / (2.0 * eps);
3921 let fd_m = (smooth_crude_risk(beta_d, beta_m + eps).risk
3922 - smooth_crude_risk(beta_d, beta_m - eps).risk)
3923 / (2.0 * eps);
3924
3925 assert!(
3926 (result.diseasegradient[0] - fd_d).abs() < 1e-5,
3927 "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
3928 result.diseasegradient[0]
3929 );
3930 assert!(
3931 (result.mortalitygradient[0] - fd_m).abs() < 1e-5,
3932 "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
3933 result.mortalitygradient[0]
3934 );
3935 }
3936
3937 #[test]
3938 fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
3939 let age_entry = array![1.0_f64, 2.0_f64];
3940 let age_exit = array![2.0_f64, 3.5_f64];
3941 let event_target = array![1u8, 0u8];
3942 let event_competing = array![0u8, 0u8];
3943 let sampleweight = array![1.0, 1.0];
3944 let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3945 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3946 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3947 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3948 matrix: array![[2.0]],
3949 lambda: 1.7,
3950 range: 1..2,
3951 nullspace_dim: 0,
3952 }]);
3953 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3954 let beta = array![-1.2, 0.4];
3955
3956 let model = survival_model(
3957 survival_inputs(
3958 &age_entry,
3959 &age_exit,
3960 &event_target,
3961 &event_competing,
3962 &sampleweight,
3963 &x_entry,
3964 &x_exit,
3965 &x_derivative,
3966 ),
3967 penalties.clone(),
3968 mono,
3969 SurvivalSpec::Net,
3970 )
3971 .expect("construct survival model");
3972
3973 let state = model.update_state(&beta).expect("survival state");
3974 let expected_penalty = penalties.deviance(&beta) + 0.5 * state.ridge_used * beta.dot(&beta);
3975 assert!(
3976 (state.penalty_term - expected_penalty).abs() < 1e-12,
3977 "penalty_term mismatch: state={} expected={}",
3978 state.penalty_term,
3979 expected_penalty
3980 );
3981 }
3982
3983 #[test]
3984 fn negative_penalty_lambda_is_rejected() {
3985 let age_entry = array![1.0_f64];
3986 let age_exit = array![2.0_f64];
3987 let event_target = array![1u8];
3988 let event_competing = array![0u8];
3989 let sampleweight = array![1.0];
3990 let x_entry = array![[1.0, 0.0]];
3991 let x_exit = array![[1.0, 0.5]];
3992 let x_derivative = array![[0.0, 1.0]];
3993 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3994 matrix: array![[1.0]],
3995 lambda: -0.1,
3996 range: 1..2,
3997 nullspace_dim: 0,
3998 }]);
3999
4000 let err = survival_model(
4001 survival_inputs(
4002 &age_entry,
4003 &age_exit,
4004 &event_target,
4005 &event_competing,
4006 &sampleweight,
4007 &x_entry,
4008 &x_exit,
4009 &x_derivative,
4010 ),
4011 penalties,
4012 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4013 SurvivalSpec::Net,
4014 )
4015 .expect_err("negative lambda must be rejected");
4016
4017 assert!(matches!(err, SurvivalError::NonFiniteInput));
4018 }
4019
4020 #[test]
4021 fn penalty_block_range_and_shapemust_match_coefficients() {
4022 let age_entry = array![1.0_f64];
4023 let age_exit = array![2.0_f64];
4024 let event_target = array![1u8];
4025 let event_competing = array![0u8];
4026 let sampleweight = array![1.0];
4027 let x_entry = array![[1.0, 0.0]];
4028 let x_exit = array![[1.0, 0.5]];
4029 let x_derivative = array![[0.0, 1.0]];
4030 let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4031 matrix: array![[1.0]],
4032 lambda: 0.5,
4033 range: 0..2,
4034 nullspace_dim: 0,
4035 }]);
4036
4037 let err = survival_model(
4038 survival_inputs(
4039 &age_entry,
4040 &age_exit,
4041 &event_target,
4042 &event_competing,
4043 &sampleweight,
4044 &x_entry,
4045 &x_exit,
4046 &x_derivative,
4047 ),
4048 penalties,
4049 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4050 SurvivalSpec::Net,
4051 )
4052 .expect_err("penalty block geometry must match coefficient support");
4053
4054 assert!(matches!(err, SurvivalError::DimensionMismatch));
4055 }
4056
4057 #[test]
4058 fn survivalgradient_matchesobjectivefdwithridge_scaling() {
4059 let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
4060 let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
4061 let event_target = array![1u8, 0u8, 1u8];
4062 let event_competing = array![0u8, 0u8, 0u8];
4063 let sampleweight = array![1.0, 1.0, 1.0];
4064 let x_entry = array![
4065 [1.0, age_entry[0].ln()],
4066 [1.0, age_entry[1].ln()],
4067 [1.0, age_entry[2].ln()]
4068 ];
4069 let x_exit = array![
4070 [1.0, age_exit[0].ln()],
4071 [1.0, age_exit[1].ln()],
4072 [1.0, age_exit[2].ln()]
4073 ];
4074 let x_derivative = array![
4075 [0.0, 1.0 / age_exit[0]],
4076 [0.0, 1.0 / age_exit[1]],
4077 [0.0, 1.0 / age_exit[2]]
4078 ];
4079 let penalties = PenaltyBlocks::new(Vec::new());
4080 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4081 let beta = array![-1.0, 3.0];
4082
4083 let model = survival_model(
4084 survival_inputs(
4085 &age_entry,
4086 &age_exit,
4087 &event_target,
4088 &event_competing,
4089 &sampleweight,
4090 &x_entry,
4091 &x_exit,
4092 &x_derivative,
4093 ),
4094 penalties,
4095 mono,
4096 SurvivalSpec::Net,
4097 )
4098 .expect("construct survival model");
4099
4100 let state = model.update_state(&beta).expect("state at beta");
4101 let eps = 1e-7;
4102 for j in 0..beta.len() {
4103 let mut plus = beta.clone();
4104 let mut minus = beta.clone();
4105 plus[j] += eps;
4106 minus[j] -= eps;
4107 let state_plus = model.update_state(&plus).expect("state at beta + eps");
4108 let state_minus = model.update_state(&minus).expect("state at beta - eps");
4109 let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4110 let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4111 let fd = (obj_plus - obj_minus) / (2.0 * eps);
4112 assert_eq!(
4113 state.gradient[j].signum(),
4114 fd.signum(),
4115 "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4116 state.gradient[j]
4117 );
4118 assert!(
4119 (state.gradient[j] - fd).abs() < 1e-5,
4120 "objective/gradient mismatch at j={j}: grad={} fd={fd}",
4121 state.gradient[j]
4122 );
4123 }
4124 }
4125
    /// Builds a 20-subject survival model used by the LAML tests.
    ///
    /// Design columns are `[1, ln(age) - mean ln(age)]`. The centring mean is taken over
    /// all entry and exit ages together. The derivative design is `[0, 1 / age_exit]`.
    /// Two penalty blocks are attached:
    /// - block 0 covers coefficient 0 with `lambda = 0.0`, i.e. present but inactive;
    /// - block 1 covers coefficient 1 with the caller-supplied `lambda` and matrix `[[2.5]]`.
    ///
    /// Panics if model construction fails.
    fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
        let age_entry: Array1<f64> = Array1::from(vec![
            30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
            34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
        ]);
        let age_exit: Array1<f64> = Array1::from(vec![
            45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
            48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
        ]);
        // Alternating event / censored rows; no competing events; unit weights.
        let event_target = Array1::from(vec![
            1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        ]);
        let event_competing = Array1::<u8>::zeros(age_entry.len());
        let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
        let n = age_entry.len();
        // Mean of ln(age) over all 2n entry/exit ages, used to centre the time column.
        // The summation order is kept fixed so the fixture stays bit-reproducible.
        let ln_age_mean: f64 = {
            let mut sum = 0.0;
            for i in 0..n {
                sum += age_entry[i].ln() + age_exit[i].ln();
            }
            sum / (2.0 * n as f64)
        };
        let mut x_entry = Array2::<f64>::zeros((n, 2));
        let mut x_exit = Array2::<f64>::zeros((n, 2));
        let mut x_derivative = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            x_entry[[i, 0]] = 1.0;
            x_exit[[i, 0]] = 1.0;
            x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
            x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
            // d/dt of the design at exit: intercept is constant, d(ln t)/dt = 1/t.
            x_derivative[[i, 0]] = 0.0;
            x_derivative[[i, 1]] = 1.0 / age_exit[i];
        }
        // Block 0 is a zero-lambda "prefix" block the LAML code must skip.
        // Block 1 is the single active penalty.
        let penalties = PenaltyBlocks::new(vec![
            PenaltyBlock {
                matrix: array![[3.0]],
                lambda: 0.0,
                range: 0..1,
                nullspace_dim: 0,
            },
            PenaltyBlock {
                matrix: array![[2.5]],
                lambda,
                range: 1..2,
                nullspace_dim: 0,
            },
        ]);
        survival_model(
            survival_inputs(
                &age_entry,
                &age_exit,
                &event_target,
                &event_competing,
                &sampleweight,
                &x_entry,
                &x_exit,
                &x_derivative,
            ),
            penalties,
            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
            SurvivalSpec::Net,
        )
        .expect("construct LAML FD survival model")
    }
4196
    /// Log-determinant of the working Hessian after spectral regularization.
    ///
    /// The Hessian is densified and eigendecomposed. Each eigenvalue is passed through
    /// `spectral_regularize` with a threshold from `spectral_epsilon`, and the logs are
    /// summed. Presumably this mirrors the REML outer engine's own log|H| treatment, so
    /// tests can compare against it — confirm in `reml_outer_engine`.
    ///
    /// Panics if the eigendecomposition fails or the eigenvalue array is not contiguous.
    fn laml_test_logdet_h(state: &WorkingState) -> f64 {
        use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
        use gam_linalg::faer_ndarray::FaerEigh;

        let h_dense = state.hessian.to_dense();
        let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
        // Threshold derived from the whole spectrum.
        let eps = spectral_epsilon(evals.as_slice().unwrap());
        evals
            .iter()
            .map(|&sigma| spectral_regularize(sigma, eps).ln())
            .sum()
    }
4209
4210 #[test]
4211 fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
4212 let rho0 = -0.35_f64;
4226 let beta = array![-2.5_f64, 1.0];
4227 let model = laml_fd_test_model(rho0.exp());
4228 let state = model
4229 .update_state(&beta)
4230 .expect("state for LAML prefix-skip test");
4231
4232 assert_eq!(model.penalties.blocks.len(), 2);
4237 assert_eq!(model.penalties.blocks[0].lambda, 0.0);
4238 assert!(model.penalties.blocks[1].lambda > 0.0);
4239
4240 let rho = Array1::from_iter(
4241 model
4242 .penalties
4243 .blocks
4244 .iter()
4245 .filter(|b| b.lambda > 0.0)
4246 .map(|b| b.lambda.ln()),
4247 );
4248 assert_eq!(
4249 rho.len(),
4250 1,
4251 "fixture should expose exactly one active penalty block for the rho vector"
4252 );
4253
4254 let (obj, grad) = model
4255 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4256 .expect("survival LAML objective and gradient");
4257
4258 let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
4259 - 0.5 * (rho0 + 2.5_f64.ln());
4260 assert_eq!(
4261 grad.len(),
4262 1,
4263 "rho-gradient must match the active-penalty count, not the full block list"
4264 );
4265 assert!(
4266 (obj - expected).abs() < 1e-10,
4267 "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
4268 );
4269 assert!(
4270 grad[0].is_finite(),
4271 "rho-gradient must be finite: {}",
4272 grad[0]
4273 );
4274 }
4275
4276 #[test]
4277 fn structural_monotonicgradient_matchesobjectivefd() {
4278 let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
4279 let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
4280 let event_target = array![1u8, 0u8, 1u8];
4281 let event_competing = array![0u8, 0u8, 0u8];
4282 let sampleweight = array![1.0, 1.0, 1.0];
4283
4284 let x_entry = array![
4287 [1.0, 0.2, 0.05, -0.7],
4288 [1.0, 0.5, 0.20, 0.1],
4289 [1.0, 0.9, 0.60, 1.2]
4290 ];
4291 let x_exit = array![
4292 [1.0, 0.4, 0.16, -0.7],
4293 [1.0, 0.8, 0.64, 0.1],
4294 [1.0, 1.1, 1.21, 1.2]
4295 ];
4296 let x_derivative = array![
4297 [0.0, 0.8, 0.64, 0.0],
4298 [0.0, 0.7, 1.12, 0.0],
4299 [0.0, 0.6, 1.32, 0.0]
4300 ];
4301 let penalties = PenaltyBlocks::new(Vec::new());
4302 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4303 let mut model = survival_model(
4304 survival_inputs(
4305 &age_entry,
4306 &age_exit,
4307 &event_target,
4308 &event_competing,
4309 &sampleweight,
4310 &x_entry,
4311 &x_exit,
4312 &x_derivative,
4313 ),
4314 penalties,
4315 mono,
4316 SurvivalSpec::Net,
4317 )
4318 .expect("construct structural survival model");
4319 model
4320 .set_structural_monotonicity(true, 3)
4321 .expect("enable structural monotonicity");
4322 let constraints = model
4323 .monotonicity_linear_constraints()
4324 .expect("structural derivative constraints");
4325 assert_eq!(constraints.a.nrows(), 2);
4326 assert_eq!(constraints.a.ncols(), 4);
4327 assert_eq!(constraints.a.row(0).to_vec(), vec![0.0, 1.0, 0.0, 0.0]);
4328 assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
4329 assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4330
4331 let beta = array![0.2, 0.2, 0.1, 0.2];
4332 let state = model.update_state(&beta).expect("state at structural beta");
4333 let eps = 1e-7;
4334 for j in 0..beta.len() {
4335 let mut plus = beta.clone();
4336 let mut minus = beta.clone();
4337 plus[j] += eps;
4338 minus[j] -= eps;
4339 let state_plus = model.update_state(&plus).expect("state at beta + eps");
4340 let state_minus = model.update_state(&minus).expect("state at beta - eps");
4341 let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4342 let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4343 let fd = (obj_plus - obj_minus) / (2.0 * eps);
4344 assert_eq!(
4345 state.gradient[j].signum(),
4346 fd.signum(),
4347 "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4348 state.gradient[j]
4349 );
4350 assert!(
4351 (state.gradient[j] - fd).abs() < 2e-5,
4352 "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
4353 state.gradient[j]
4354 );
4355 }
4356 }
4357
4358 #[test]
4359 fn structural_monotonic_lamlgradient_returns_finitevalues() {
4360 let age_entry = array![1.0_f64, 1.2_f64];
4361 let age_exit = array![1.5_f64, 2.0_f64];
4362 let event_target = array![1u8, 0u8];
4363 let event_competing = array![0u8, 0u8];
4364 let sampleweight = array![1.0, 1.0];
4365 let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
4366 let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
4367 let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];
4368 let penalties = PenaltyBlocks::new(Vec::new());
4369 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4370 let mut model = survival_model(
4371 survival_inputs(
4372 &age_entry,
4373 &age_exit,
4374 &event_target,
4375 &event_competing,
4376 &sampleweight,
4377 &x_entry,
4378 &x_exit,
4379 &x_derivative,
4380 ),
4381 penalties,
4382 mono,
4383 SurvivalSpec::Net,
4384 )
4385 .expect("construct structural survival model");
4386 model
4387 .set_structural_monotonicity(true, 2)
4388 .expect("enable structural monotonicity");
4389 model.penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4391 matrix: array![[1.0]],
4392 lambda: 0.7,
4393 range: 1..2,
4394 nullspace_dim: 0,
4395 }]);
4396 let beta = array![0.2, 0.2, 0.1];
4397 let state = model.update_state(&beta).expect("state at structural beta");
4398 let rho = Array1::from_iter(
4399 model
4400 .penalties
4401 .blocks
4402 .iter()
4403 .filter(|b| b.lambda > 0.0)
4404 .map(|b| b.lambda.ln()),
4405 );
4406 let (obj, grad) = model
4407 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4408 .expect("laml gradient should work in structural mode");
4409 assert!(obj.is_finite());
4410 assert_eq!(grad.len(), 1);
4411 assert!(grad[0].is_finite());
4412 }
4413
4414 #[test]
4415 fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
4416 let age_entry = array![1.0_f64];
4417 let age_exit = array![2.0_f64];
4418 let event_target = array![1u8];
4419 let event_competing = array![0u8];
4420 let sampleweight = array![1.0];
4421 let x_entry = array![[0.0]];
4422 let x_exit = array![[0.2]];
4423 let x_derivative = array![[1.0]];
4424
4425 let penalties = PenaltyBlocks::new(Vec::new());
4426 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4427 let mut model = survival_model(
4428 survival_inputs(
4429 &age_entry,
4430 &age_exit,
4431 &event_target,
4432 &event_competing,
4433 &sampleweight,
4434 &x_entry,
4435 &x_exit,
4436 &x_derivative,
4437 ),
4438 penalties,
4439 mono,
4440 SurvivalSpec::Net,
4441 )
4442 .expect("construct structural survival model");
4443
4444 let beta = array![-3.0];
4445 assert!(
4446 model.update_state(&beta).is_err(),
4447 "negative derivative coefficient should violate derivative guard"
4448 );
4449
4450 model
4451 .set_structural_monotonicity(true, 1)
4452 .expect("enable structural monotonicity");
4453 let constraints = model
4454 .monotonicity_linear_constraints()
4455 .expect("structural derivative constraints");
4456 assert_eq!(constraints.a.nrows(), 1);
4457 assert_eq!(constraints.a.ncols(), 1);
4458 assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
4459 assert!(constraints.b[0].abs() <= 1e-12);
4461 let state = model
4462 .update_state(&array![1e-6])
4463 .expect("small positive derivative coefficient should remain feasible");
4464 assert!(state.deviance.is_finite());
4465 }
4466
4467 #[test]
4468 fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
4469 let age_entry = array![1.0_f64];
4470 let age_exit = array![2.0_f64];
4471 let event_target = array![1u8];
4472 let event_competing = array![0u8];
4473 let sampleweight = array![1.0];
4474 let x_entry = array![[1.0, 0.0]];
4475 let x_exit = array![[1.0, 0.0]];
4476 let x_derivative = array![[0.0, 0.0]];
4477 let penalties = PenaltyBlocks::new(Vec::new());
4478 let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
4479 let eta_entry_offset = array![0.0];
4480 let eta_exit_offset = array![0.0];
4481 let derivative_offset_below_guard = array![2.0];
4482 let derivative_offset_above_guard = array![3.1];
4483 let offsets_below_guard = SurvivalBaselineOffsets {
4484 eta_entry: eta_entry_offset.view(),
4485 eta_exit: eta_exit_offset.view(),
4486 derivative_exit: derivative_offset_below_guard.view(),
4487 };
4488 let offsets_above_guard = SurvivalBaselineOffsets {
4489 eta_entry: eta_entry_offset.view(),
4490 eta_exit: eta_exit_offset.view(),
4491 derivative_exit: derivative_offset_above_guard.view(),
4492 };
4493
4494 let model_below_guard = survival_model_with_offsets(
4495 survival_inputs(
4496 &age_entry,
4497 &age_exit,
4498 &event_target,
4499 &event_competing,
4500 &sampleweight,
4501 &x_entry,
4502 &x_exit,
4503 &x_derivative,
4504 ),
4505 Some(offsets_below_guard),
4506 penalties.clone(),
4507 monotonicity,
4508 SurvivalSpec::Net,
4509 )
4510 .expect("construct model with derivative offset below guard");
4511 let err = model_below_guard
4512 .update_state(&array![0.0, 0.0])
4513 .expect_err("derivative offset below guard should be rejected");
4514 let err_text = err.to_string();
4515 assert!(
4516 err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
4517 "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
4518 );
4519
4520 let model_above_guard = survival_model_with_offsets(
4521 survival_inputs(
4522 &age_entry,
4523 &age_exit,
4524 &event_target,
4525 &event_competing,
4526 &sampleweight,
4527 &x_entry,
4528 &x_exit,
4529 &x_derivative,
4530 ),
4531 Some(offsets_above_guard),
4532 penalties,
4533 SurvivalMonotonicityPenalty { tolerance: 3.0 },
4534 SurvivalSpec::Net,
4535 )
4536 .expect("construct model with derivative offset above guard");
4537 let state = model_above_guard
4538 .update_state(&array![0.0, 0.0])
4539 .expect("derivative offset above guard should remain feasible");
4540 assert!(state.deviance.is_finite());
4541 }
4542
4543 #[test]
4544 fn structural_monotonicity_rejects_negative_derivative_offsets() {
4545 let age_entry = array![1.0_f64];
4546 let age_exit = array![2.0_f64];
4547 let event_target = array![1u8];
4548 let event_competing = array![0u8];
4549 let sampleweight = array![1.0];
4550 let x_entry = array![[0.0]];
4551 let x_exit = array![[0.2]];
4552 let x_derivative = array![[1.0]];
4553 let eta_entry = array![0.0];
4554 let eta_exit = array![0.0];
4555 let derivative_exit = array![-1e-3];
4556 let offsets = SurvivalBaselineOffsets {
4557 eta_entry: eta_entry.view(),
4558 eta_exit: eta_exit.view(),
4559 derivative_exit: derivative_exit.view(),
4560 };
4561
4562 let mut model = survival_model_with_offsets(
4563 survival_inputs(
4564 &age_entry,
4565 &age_exit,
4566 &event_target,
4567 &event_competing,
4568 &sampleweight,
4569 &x_entry,
4570 &x_exit,
4571 &x_derivative,
4572 ),
4573 Some(offsets),
4574 PenaltyBlocks::new(Vec::new()),
4575 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4576 SurvivalSpec::Net,
4577 )
4578 .expect("construct structural survival model");
4579 let err = model
4580 .set_structural_monotonicity(true, 1)
4581 .expect_err("negative derivative offsets must be rejected");
4582 assert!(
4583 err.to_string()
4584 .contains("structural monotonicity requires nonnegative derivative offsets"),
4585 "unexpected error: {err}"
4586 );
4587 }
4588
#[test]
fn structural_monotonicity_emits_coefficient_constraints() {
    // Two rows. The third design column has a zero derivative in every row, and
    // the expected constraints below touch only the first two coefficients.
    let entry_ages = array![1.0_f64, 1.5_f64];
    let exit_ages = array![2.0_f64, 3.0_f64];
    let target = array![1u8, 0u8];
    let competing = array![0u8, 0u8];
    let weights = array![1.0, 1.0];
    let entry_design = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
    let exit_design = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
    let derivative_design = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];

    let inputs = survival_inputs(
        &entry_ages,
        &exit_ages,
        &target,
        &competing,
        &weights,
        &entry_design,
        &exit_design,
        &derivative_design,
    );
    let penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    let mut model = survival_model(inputs, penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 2)
        .expect("enable structural monotonicity");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");

    // One unit row per constrained coefficient, each with a zero lower bound.
    let a = &constraints.a;
    assert_eq!(a.nrows(), 2);
    assert_eq!(a.ncols(), 3);
    assert_eq!(a.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
    assert_eq!(a.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
    assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
}
4630
#[test]
fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
    // Column 0 has a zero time derivative; column 1 carries the time dependence.
    let entry_age = array![1.0_f64];
    let exit_age = array![2.0_f64];
    let target_events = array![1u8];
    let competing_events = array![0u8];
    let weights = array![1.0];
    let design_entry = array![[1.0, 0.2]];
    let design_exit = array![[1.0, 0.6]];
    let design_derivative = array![[0.0, 1.0]];

    let inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    let mut model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 2)
        .expect("enable structural monotonicity");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");

    assert_eq!(constraints.a.nrows(), 1);
    let inactive_entry = constraints.a[[0, 0]];
    assert!(
        inactive_entry.abs() <= 1e-12,
        "inactive time column should remain unconstrained"
    );
    let active_entry = constraints.a[[0, 1]];
    assert!(
        (active_entry - 1.0).abs() <= 1e-12,
        "active time column should remain constrained"
    );
}
4676
#[test]
fn structural_monotonicity_preserves_sparse_row_patterns() {
    // Two event rows whose derivative design rows are [1, 0] and [1, 0.5].
    // Structural monotonicity must emit one unit row per coefficient instead of
    // reusing those derivative rows.
    let age_entry = array![1.0_f64, 1.5_f64];
    let age_exit = array![2.0_f64, 2.5_f64];
    let event_target = array![1u8, 1u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![1.0, 1.0];
    let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
    let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
    let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];

    let mut model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 2)
        .expect("enable structural monotonicity");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");

    // Shape: one structural row per coefficient, spanning every coefficient.
    assert_eq!(constraints.a.nrows(), 2);
    assert_eq!(constraints.a.ncols(), 2);
    assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
    assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
    // Zero tolerance and no baseline offsets: the lower bounds must vanish, as
    // in structural_monotonicity_emits_coefficient_constraints.
    assert!(
        constraints.b.iter().all(|&v| v.abs() <= 1e-12),
        "structural lower bounds should be zero: {:?}",
        constraints.b
    );
}
4716
#[test]
fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
    // One censored row (no target event) whose derivative design row is [-1.0].
    // With beta = [1.0] the exit derivative therefore points the wrong way.
    let entry_age = array![1.0_f64];
    let exit_age = array![1.1_f64];
    let target_events = array![0u8];
    let competing_events = array![0u8];
    let weights = array![1.0];
    let design_entry = array![[0.0]];
    let design_exit = array![[0.0]];
    let design_derivative = array![[-1.0]];

    let inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    let no_penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let model = survival_model(inputs, no_penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct censored survival model");

    let err = model
        .update_state(&array![1.0])
        .expect_err("censored row should still enforce monotonic derivative");
    let is_constraint_violation = matches!(err, EstimationError::ParameterConstraintViolation(_));
    assert!(is_constraint_violation, "unexpected error: {err:?}");
}
4754
4755 fn crude_risk_quadrature_error(
4756 cumulative_entry: f64,
4757 cumulative_exit: f64,
4758 hazard_exit: f64,
4759 ) -> SurvivalError {
4760 calculate_crude_risk_quadrature(
4761 1.0,
4762 2.0,
4763 &[],
4764 0.4,
4765 0.2,
4766 array![1.0].view(),
4767 array![1.0].view(),
4768 |_, design_d, deriv_d, design_m| {
4769 design_d[0] = 1.0;
4770 deriv_d[0] = 0.0;
4771 design_m[0] = 1.0;
4772 Ok((cumulative_entry, cumulative_exit, hazard_exit))
4773 },
4774 )
4775 .expect_err("invalid hazards should fail")
4776 }
4777
#[test]
fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
    // The callback triple (0.1, 0.3, 0.25) must be reported as non-monotone.
    let failure = crude_risk_quadrature_error(0.1, 0.3, 0.25);
    let is_non_monotone = matches!(failure, SurvivalError::NonMonotoneCumulativeHazard);
    assert!(is_non_monotone);
}
4783
#[test]
fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
    // The callback triple (0.0, 0.4, 0.25) must be reported as a non-positive hazard.
    let failure = crude_risk_quadrature_error(0.0, 0.4, 0.25);
    let is_non_positive = matches!(failure, SurvivalError::NonPositiveHazard);
    assert!(is_non_positive);
}
4789
#[test]
fn laml_no_penalties_matches_documentedobjective() {
    // Four rows, three coefficients, no penalty blocks.
    let age_entry = array![40.0, 45.0, 50.0, 55.0];
    let age_exit = array![44.0, 49.0, 54.0, 59.0];
    let event_target = array![1u8, 0u8, 1u8, 0u8];
    let event_competing = Array1::<u8>::zeros(4);
    let sampleweight = Array1::ones(4);
    let x_entry = array![
        [1.0, -0.2, 0.04],
        [1.0, -0.1, 0.01],
        [1.0, 0.0, 0.0],
        [1.0, 0.1, 0.01]
    ];
    let x_exit = array![
        [1.0, -0.12, 0.0144],
        [1.0, -0.02, 0.0004],
        [1.0, 0.08, 0.0064],
        [1.0, 0.18, 0.0324]
    ];
    let x_derivative = array![
        [0.0, 0.02, 0.001],
        [0.0, 0.02, 0.001],
        [0.0, 0.02, 0.001],
        [0.0, 0.02, 0.001]
    ];
    let coefficients = array![-2.0, 0.7, 0.2];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Net,
    )
    .expect("construct survival model");

    let state = model.update_state(&coefficients).expect("state at beta");
    // Log smoothing parameters of the active (positive-lambda) blocks; empty here.
    let log_lambdas = model
        .penalties
        .blocks
        .iter()
        .filter(|block| block.lambda > 0.0)
        .map(|block| block.lambda.ln())
        .collect::<Array1<f64>>();
    let (objective, rho_gradient) = model
        .unified_lamlobjective_and_rhogradient(&coefficients, &state, &log_lambdas)
        .expect("laml objective for no-penalty model");

    // Reference value: 0.5*deviance + penalty + 0.5*log|H|, where the
    // log-determinant uses the same spectral regularization as the REML engine.
    let hessian_dense = state.hessian.to_dense();
    let log_det_hessian: f64 = {
        use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
        use gam_linalg::faer_ndarray::FaerEigh;
        let (eigenvalues, _) = hessian_dense.eigh(faer::Side::Lower).expect("eigh");
        let floor = spectral_epsilon(eigenvalues.as_slice().unwrap());
        eigenvalues
            .iter()
            .map(|&sigma| spectral_regularize(sigma, floor).ln())
            .sum()
    };
    let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * log_det_hessian;

    assert_eq!(rho_gradient.len(), 0);
    let discrepancy = (objective - expected).abs();
    assert!(
        discrepancy < 1e-10,
        "no-penalty LAML objective mismatch: obj={} expected={}",
        objective,
        expected
    );
}
4870
#[test]
fn monotonicity_constraints_collapse_positive_collinearrows() {
    // Three positive multiples of e_1 sharing the same bound collapse into one row.
    let rows = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
    let bounds = array![1e-8, 1e-8, 1e-8];

    let compressed = compress_positive_collinear_constraints(&rows, &bounds);

    assert_eq!(compressed.a.nrows(), 1);
    assert_eq!(compressed.a.ncols(), 3);
    for (column, target) in [(0, 0.0), (1, 1.0), (2, 0.0)] {
        assert!((compressed.a[[0, column]] - target).abs() <= 1e-12);
    }
    // The binding bound comes from the smallest coefficient: 1e-8 / 0.125.
    assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
}
4885
#[test]
fn monotonicity_constraints_preserve_distinct_directions() {
    // Rows 0 and 2 point along x; row 1 points along y.
    let rows = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
    let bounds = array![0.2, 0.3, 0.1];

    let compressed = compress_positive_collinear_constraints(&rows, &bounds);

    assert_eq!(compressed.a.nrows(), 2);
    let (mut found_x_axis, mut found_y_axis) = (false, false);
    for i in 0..compressed.a.nrows() {
        let row = compressed.a.row(i);
        let bound = compressed.b[i];
        let along_x = (row[0] - 1.0).abs() <= 1e-12 && row[1].abs() <= 1e-12;
        let along_y = row[0].abs() <= 1e-12 && (row[1] - 1.0).abs() <= 1e-12;
        if along_x {
            found_x_axis = true;
            // max(0.2 / 1, 0.1 / 2) = 0.2
            assert!((bound - 0.2).abs() <= 1e-12);
        }
        if along_y {
            found_y_axis = true;
            assert!((bound - 0.3).abs() <= 1e-12);
        }
    }
    assert!(found_x_axis);
    assert!(found_y_axis);
}
4909
#[test]
fn monotonicity_constraints_cluster_near_collinearrows() {
    // Rows differ from 0.5 * e_1 only at the ~1e-10 level and should be
    // clustered into a single half-space.
    let rows = array![
        [0.0, 0.5, 0.0],
        [0.0, 0.50000000003, 0.0],
        [0.0, 0.49999999997, 0.0]
    ];
    let bounds = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];

    let compressed = compress_positive_collinear_constraints(&rows, &bounds);

    assert_eq!(compressed.a.nrows(), 1);
    assert_eq!(compressed.a.ncols(), 3);
    let only_row = compressed.a.row(0);
    assert!(only_row[0].abs() <= 1e-12);
    assert!((only_row[1] - 1.0).abs() <= 1e-12);
    assert!(only_row[2].abs() <= 1e-12);
    assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
}
4928
#[test]
fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
    // The first three rows are near-duplicates of one direction; the fourth
    // row is a genuinely different direction.
    let rows = array![
        [0.0, 0.401, 0.302, 0.197],
        [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
        [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
        [0.0, 0.125, 0.500, 0.375]
    ];
    let bounds = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];

    let compressed = compress_positive_collinear_constraints(&rows, &bounds);

    assert_eq!(compressed.a.nrows(), 2);
    let mut saw_cluster = false;
    let mut saw_distinct = false;
    for i in 0..compressed.a.nrows() {
        let row = compressed.a.row(i);
        let bound = compressed.b[i];
        let matches_cluster = row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49;
        if !matches_cluster {
            // Distinct face, scaled so its largest entry (0.5) becomes 1.
            saw_distinct = true;
            assert!((row[1] - 0.25).abs() <= 1e-12);
            assert!((row[2] - 1.0).abs() <= 1e-12);
            assert!((row[3] - 0.75).abs() <= 1e-12);
            assert!((bound - 6.0e-8).abs() <= 1e-18);
        } else {
            saw_cluster = true;
            assert!((bound - (2.0e-8 / 0.401)).abs() <= 1e-12);
        }
    }
    assert!(saw_cluster);
    assert!(saw_distinct);
}
4960
#[test]
fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
    // Log-time design: column 1 is ln(t) and its derivative rows are 1/t at
    // t = 2, 4, 8. The three collocation rows [0, 1/2], [0, 1/4], [0, 1/8] are
    // positively collinear and must collapse into one half-space on the slope.
    let age_entry = array![1.0_f64, 1.0, 1.0];
    let age_exit = array![2.0_f64, 4.0, 8.0];
    let event_target = array![0u8, 1u8, 0u8];
    let event_competing = array![0u8, 0u8, 0u8];
    let sampleweight = array![1.0, 1.0, 1.0];
    let x_entry = array![
        [1.0, age_entry[0].ln()],
        [1.0, age_entry[1].ln()],
        [1.0, age_entry[2].ln()]
    ];
    let x_exit = array![
        [1.0, age_exit[0].ln()],
        [1.0, age_exit[1].ln()],
        [1.0, age_exit[2].ln()]
    ];
    let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };

    let collocation_offsets = Array1::zeros(x_derivative.nrows());
    let mut inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    inputs.monotonicity_constraint_rows = Some(x_derivative.view());
    inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());

    let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
        .expect("construct linear survival model");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("monotonicity constraints");
    assert_eq!(constraints.a.nrows(), 1);
    assert_eq!(constraints.a.ncols(), 2);
    // The intercept column has zero derivative in every row and must stay
    // unconstrained; only the slope coefficient is bounded.
    assert!(
        constraints.a[[0, 0]].abs() <= 1e-12,
        "intercept should be unconstrained: {}",
        constraints.a[[0, 0]]
    );
    assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
    // The bound equals tolerance / 0.125 = 8e-8. A 1e-12 absolute slack would
    // be about 1e-5 relative to this bound, so use the 1e-18 slack that the
    // sibling compression tests apply to the same rows.
    assert!(
        (constraints.b[0] - 8e-8).abs() <= 1e-18,
        "unexpected slope bound: {}",
        constraints.b[0]
    );
}
5006
#[test]
fn monotonicity_constraints_skip_numericallyzerorows() {
    // Derivative rows: exactly zero, ~1e-16, and a genuine 0.25. Only the last
    // row should survive as a constraint.
    let entry_age = array![1.0_f64, 1.0, 1.0];
    let exit_age = array![2.0_f64, 3.0, 4.0];
    let target_events = array![0u8, 0u8, 0u8];
    let competing_events = array![0u8, 0u8, 0u8];
    let weights = array![1.0, 1.0, 1.0];
    let design_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
    let design_exit = design_entry.clone();
    let design_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
    let zero_offsets = Array1::zeros(design_derivative.nrows());

    let mut inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    inputs.monotonicity_constraint_rows = Some(design_derivative.view());
    inputs.monotonicity_constraint_offsets = Some(zero_offsets.view());

    let penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    let model = survival_model(inputs, penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct survival model");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("nonzero derivative row should remain");
    assert_eq!(constraints.a.nrows(), 1);
    assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
    assert!(constraints.b[0].abs() <= 1e-18);
}
5047
#[test]
fn censoredrows_allowzero_boundary_derivative() {
    // One censored row evaluated at beta = 0, so its exit derivative sits
    // exactly on the zero boundary. With zero tolerance this must be accepted.
    let entry_age = array![1.0_f64];
    let exit_age = array![2.0_f64];
    let target_events = array![0u8];
    let competing_events = array![0u8];
    let weights = array![1.0];
    let design_entry = array![[0.0]];
    let design_exit = array![[0.0]];
    let design_derivative = array![[1.0]];

    let inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    let penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 0.0 };
    let model = survival_model(inputs, penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct censored survival model");

    let boundary_beta = array![0.0];
    let state = model
        .update_state(&boundary_beta)
        .expect("censored boundary derivative should remain feasible with zero tolerance");
    assert!(state.deviance.is_finite());
}
5081
#[test]
fn eventrows_keep_positive_derivative_constraint() {
    // One censored row and one event row; both derivative rows are positive
    // multiples of the single coefficient.
    let entry_age = array![1.0_f64, 1.0];
    let exit_age = array![2.0_f64, 4.0];
    let target_events = array![0u8, 1u8];
    let competing_events = array![0u8, 0u8];
    let weights = array![1.0, 1.0];
    let design_entry = array![[0.0], [0.0]];
    let design_exit = array![[0.0], [0.0]];
    let design_derivative = array![[0.5], [0.25]];
    let zero_offsets = Array1::zeros(design_derivative.nrows());

    let mut inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    inputs.monotonicity_constraint_rows = Some(design_derivative.view());
    inputs.monotonicity_constraint_offsets = Some(zero_offsets.view());

    let penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let model = survival_model(inputs, penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct mixed survival model");

    let constraints = model
        .monotonicity_linear_constraints()
        .expect("event row should induce positive lower bound");
    assert_eq!(constraints.a.nrows(), 1);
    assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
    // Expected bound: 1e-8 / 0.25.
    assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
}
5122
#[test]
fn structural_monotonicity_clamps_tiny_negative_roundoff() {
    // A coefficient of -1e-8 (equal in magnitude to the tolerance) must be
    // accepted under structural monotonicity rather than rejected.
    let entry_age = array![1.0_f64];
    let exit_age = array![2.0_f64];
    let target_events = array![1u8];
    let competing_events = array![0u8];
    let weights = array![1.0];
    let design_entry = array![[0.0]];
    let design_exit = array![[0.0]];
    let design_derivative = array![[1.0]];

    let inputs = survival_inputs(
        &entry_age,
        &exit_age,
        &target_events,
        &competing_events,
        &weights,
        &design_entry,
        &design_exit,
        &design_derivative,
    );
    let penalties = PenaltyBlocks::new(Vec::new());
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let mut model = survival_model(inputs, penalties, monotonicity, SurvivalSpec::Net)
        .expect("construct survival model");
    model
        .set_structural_monotonicity(true, 1)
        .expect("enable structural monotonicity");

    let roundoff_beta = array![-1e-8];
    let state = model
        .update_state(&roundoff_beta)
        .expect("tiny structural roundoff should be clamped");
    assert!(state.deviance.is_finite());
}
5158
#[test]
fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
    // Four positively collinear rows along coefficient 1 share the bound 1e-8.
    // Their intersection is the half-space beta[1] >= 1e-8 / 0.125 = 8e-8, and
    // compression must describe exactly the same region.
    let uncompressed_constraints = LinearInequalityConstraints {
        a: array![
            [0.0, 0.5, 0.0],
            [0.0, 1.0 / 3.0, 0.0],
            [0.0, 0.2, 0.0],
            [0.0, 0.125, 0.0]
        ],
        b: Array1::from_elem(4, 1e-8),
    };
    let compressed_constraints = compress_positive_collinear_constraints(
        &uncompressed_constraints.a,
        &uncompressed_constraints.b,
    );

    // Each candidate is paired with its known feasibility, so the comparison
    // cannot pass vacuously when both representations are wrong the same way.
    let candidates = [
        (array![0.0, 1e-9, 0.0], false),
        (array![0.0, 4e-8, 0.0], false),
        // Just below the binding bound.
        (array![0.0, 7.9e-8, 0.0], false),
        // Exactly on the binding bound (0.125 * 8e-8 == 1e-8 in binary64).
        (array![0.0, 8e-8, 0.0], true),
        (array![0.0, 2e-7, 1.5], true),
        // Unconstrained coordinates must not influence feasibility.
        (array![-2.0, 1e-7, 0.0], true),
        // Wrong sign on the constrained coordinate.
        (array![0.0, -1e-7, 0.0], false),
    ];
    for (beta, expected_feasible) in candidates {
        let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
            uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
        });
        let compressed_ok = (0..compressed_constraints.a.nrows())
            .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
        assert_eq!(
            uncompressed_ok, expected_feasible,
            "uncompressed feasibility mismatch for beta = {beta}"
        );
        assert_eq!(
            compressed_ok, uncompressed_ok,
            "compressed feasibility differs from uncompressed for beta = {beta}"
        );
    }
}
5190
#[test]
fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
    // Rescaling time by c (ages * c, derivative rows / c) divides each event's
    // hazard by c. The deviance therefore shifts by 2 * sum(w_i * d_i) * ln(c),
    // while the gradient and Hessian stay the same.
    let age_entry = array![10.0_f64, 20.0, 25.0];
    let age_exit = array![15.0_f64, 30.0, 40.0];
    let event_target = array![1u8, 0u8, 1u8];
    let event_competing = array![0u8, 0u8, 0u8];
    let sampleweight = array![1.0, 2.0, 0.5];
    let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
    let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
    let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
    let beta = array![0.8, 1.1, -0.2];

    let base_inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let base_model = survival_model(
        base_inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct base survival model");
    let base_state = base_model
        .update_state(&beta)
        .expect("evaluate base survival state");

    let time_scale = 365.25;
    let scaled_age_entry = age_entry.mapv(|age| age * time_scale);
    let scaled_age_exit = age_exit.mapv(|age| age * time_scale);
    let scaled_x_derivative = x_derivative.mapv(|slope| slope / time_scale);
    let scaled_inputs = survival_inputs(
        &scaled_age_entry,
        &scaled_age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &scaled_x_derivative,
    );
    let scaled_model = survival_model(
        scaled_inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct scaled survival model");
    let scaled_state = scaled_model
        .update_state(&beta)
        .expect("evaluate scaled survival state");

    let weighted_events: f64 = event_target
        .iter()
        .zip(sampleweight.iter())
        .map(|(&event, &weight)| weight * f64::from(event))
        .sum();
    let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
    let observed_shift = scaled_state.deviance - base_state.deviance;
    assert!(
        (observed_shift - expected_deviance_shift).abs() <= 1e-10,
        "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
        scaled_state.deviance,
        base_state.deviance
    );

    let p = beta.len();
    for j in 0..p {
        let (scaled_g, base_g) = (scaled_state.gradient[j], base_state.gradient[j]);
        assert!(
            (scaled_g - base_g).abs() <= 1e-12,
            "gradient mismatch at j={j}: scaled={} base={}",
            scaled_g,
            base_g
        );
    }

    let base_hessian = base_state.hessian.to_dense();
    let scaled_hessian = scaled_state.hessian.to_dense();
    for r in 0..p {
        for c in 0..p {
            let (scaled_h, base_h) = (scaled_hessian[[r, c]], base_hessian[[r, c]]);
            assert!(
                (scaled_h - base_h).abs() <= 1e-12,
                "hessian mismatch at ({r},{c}): scaled={} base={}",
                scaled_h,
                base_h
            );
        }
    }
}
5282}