1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3 BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4 projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9 LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
/// Errors produced by the survival likelihood engine and its input validation.
///
/// Variants carrying a `reason` render that string verbatim as their `Display` text.
#[derive(Debug, Error)]
pub enum SurvivalError {
    /// Input arrays have inconsistent lengths or shapes.
    #[error("input dimensions are inconsistent")]
    DimensionMismatch,
    /// At least one input value is NaN or infinite.
    #[error("inputs contain non-finite values")]
    NonFiniteInput,
    /// The named survival spec is not handled by the one-hazard engine.
    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
    UnsupportedSpec(&'static str),
    /// The integration setup used for crude-risk computation is invalid.
    #[error("crude risk integration setup is invalid")]
    InvalidIntegrationSetup,
    /// A time grid was not finite, non-negative, and strictly increasing.
    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
    InvalidTimeGrid,
    /// A cumulative hazard decreased somewhere along its grid.
    #[error("cumulative hazard must be nondecreasing")]
    NonMonotoneCumulativeHazard,
    /// An instantaneous hazard reached zero or went negative during integration.
    #[error("instantaneous hazard must stay strictly positive during integration")]
    NonPositiveHazard,
    /// Free-form invalid-input error.
    #[error("{reason}")]
    InvalidInput { reason: String },
    /// Shape or count mismatch in cause-specific (multi-endpoint) inputs.
    #[error("{reason}")]
    CauseSpecificDimensionMismatch { reason: String },
    /// A numerical failure, such as an overflowing hazard or a non-positive derivative.
    #[error("{reason}")]
    NumericalFailure { reason: String },
    /// An event code was not a valid non-negative integer label.
    #[error("{reason}")]
    EventCodeInvalid { reason: String },
    /// NOTE(review): presumably raised for degenerate event data (e.g. no events); not
    /// constructed in this part of the file — confirm.
    #[error("{reason}")]
    EventDegenerate { reason: String },
    /// Wraps an error raised for one cause-specific block.
    ///
    /// `CauseSpecificRoystonParmarFamily::new` fills `block` with a 1-based position.
    #[error("cause-specific survival block {block}: {source}")]
    CauseSpecificBlock {
        block: usize,
        #[source]
        source: Box<SurvivalError>,
    },
}
52
53impl From<SurvivalError> for String {
54 fn from(err: SurvivalError) -> Self {
55 err.to_string()
56 }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60 fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61 SurvivalError::CauseSpecificDimensionMismatch {
62 reason: err.message(),
63 }
64 }
65}
66
/// Which survival quantity the engine targets; `Net` is the default.
///
/// NOTE(review): presumably `Net` is single-hazard (net) survival and `Crude` is
/// competing-risk crude survival (see `SurvivalError::InvalidIntegrationSetup`) — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SurvivalSpec {
    #[default]
    Net,
    Crude,
}
73
/// Borrowed inputs for the survival engine when one flat design matrix is used per
/// time point.
///
/// NOTE(review): the field descriptions below are inferred from names and from the
/// parallel fields of `CauseSpecificRoystonParmarBlock`; confirm them against the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalEngineInputs<'a> {
    /// Entry age (left-truncation time) per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age (event or censoring time) per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Indicator (or code) for the target event per row.
    pub event_target: ArrayView1<'a, u8>,
    /// Indicator (or code) for a competing event per row.
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Design matrix evaluated at the entry age.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design matrix evaluated at the exit age.
    pub x_exit: ArrayView2<'a, f64>,
    /// Time-derivative design at the exit age.
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional extra linear monotonicity constraint rows.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets paired with `monotonicity_constraint_rows`.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
91
/// Borrowed inputs for the survival engine when the time basis and the covariates are
/// supplied separately (compare `SurvivalDesign::TimeCovariateShared`).
///
/// NOTE(review): the field descriptions below are inferred from names; confirm them
/// against the consumer.
#[derive(Debug, Clone)]
pub struct SurvivalTimeCovarInputs<'a> {
    /// Entry age per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-event indicator per row.
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-event indicator per row.
    pub event_competing: ArrayView1<'a, u8>,
    /// Per-row sample weights.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Time basis evaluated at the entry age.
    pub time_entry: ArrayView2<'a, f64>,
    /// Time basis evaluated at the exit age.
    pub time_exit: ArrayView2<'a, f64>,
    /// Derivative of the time basis at the exit age.
    pub time_derivative: ArrayView2<'a, f64>,
    /// Time-invariant covariate design shared by entry and exit.
    pub covariates: ArrayView2<'a, f64>,
    /// Optional extra linear monotonicity constraint rows.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Optional offsets paired with `monotonicity_constraint_rows`.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
110
/// Borrowed per-row offsets for the survival linear predictors.
///
/// NOTE(review): presumably added to the entry/exit linear predictors and to the exit
/// derivative, as the `offset_*` fields of `CauseSpecificRoystonParmarBlock` are — confirm.
#[derive(Debug, Clone)]
pub struct SurvivalBaselineOffsets<'a> {
    pub eta_entry: ArrayView1<'a, f64>,
    pub eta_exit: ArrayView1<'a, f64>,
    pub derivative_exit: ArrayView1<'a, f64>,
}
126
/// One quadratic penalty `0.5 * lambda * b' S b`, where `b = beta[range]`.
#[derive(Debug, Clone)]
pub struct PenaltyBlock {
    /// Penalty matrix `S`; square, with side `range.len()`. `PenaltyBlocks::gradient`
    /// uses `S b`, which is the true gradient only if `S` is symmetric.
    pub matrix: Array2<f64>,
    /// Smoothing parameter; blocks with `lambda == 0.0` are skipped entirely.
    pub lambda: f64,
    /// Slice of the full coefficient vector that this penalty acts on.
    pub range: Range<usize>,
    /// Null-space dimension of `S`; not read by `PenaltyBlocks` itself.
    pub nullspace_dim: usize,
}
136
/// A collection of block-diagonal quadratic penalties over one coefficient vector.
#[derive(Debug, Clone)]
pub struct PenaltyBlocks {
    pub blocks: Vec<PenaltyBlock>,
}
141
142impl PenaltyBlocks {
143 pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144 Self { blocks }
145 }
146
147 pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148 let mut grad = Array1::zeros(beta.len());
149 for block in &self.blocks {
150 if block.lambda == 0.0 {
151 continue;
152 }
153 let b = beta.slice(ndarray::s![block.range.clone()]);
154 let g = block.matrix.dot(&b);
155 let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156 dst += &(block.lambda * g);
157 }
158 grad
159 }
160
161 pub fn hessian(&self, dim: usize) -> Array2<f64> {
162 let mut h = Array2::zeros((dim, dim));
163 self.addhessian_inplace(&mut h);
164 h
165 }
166
167 pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168 let mut value = 0.0;
169 for block in &self.blocks {
170 if block.lambda == 0.0 {
171 continue;
172 }
173 let b = beta.slice(ndarray::s![block.range.clone()]);
174 value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175 }
176 value
177 }
178
179 pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180 for block in &self.blocks {
181 if block.lambda == 0.0 {
182 continue;
183 }
184 let start = block.range.start;
185 let end = block.range.end;
186 h.slice_mut(ndarray::s![start..end, start..end])
187 .scaled_add(block.lambda, &block.matrix);
188 }
189 }
190}
191
/// Entry ages at or below this value are treated as starting at the time origin: the
/// entry cumulative hazard and its derivatives are taken as zero.
const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1e-8;

/// Fraction-to-boundary factor for `max_feasible_step_size`: a step may cover at most
/// this share of the remaining distance to the derivative floor.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 0.995;
209
/// Data for one cause-specific Royston–Parmar endpoint.
///
/// All per-row arrays share a length `n`, and all design matrices share a column count
/// `p`; `validate_cause_specific_block` enforces both.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarBlock {
    /// Entry age per row; values at or below `ENTRY_AT_ORIGIN_THRESHOLD` mean no entry term.
    pub age_entry: Array1<f64>,
    /// Exit age per row; must not be below `age_entry` on rows with positive weight.
    pub age_exit: Array1<f64>,
    /// Binary indicator for this cause (`0` or `1` only).
    pub event_target: Array1<u8>,
    /// Non-negative, finite weights; rows with weight `<= 0` are skipped.
    pub sampleweight: Array1<f64>,
    /// Design for the entry log cumulative hazard.
    pub x_entry: Array2<f64>,
    /// Design for the exit log cumulative hazard.
    pub x_exit: Array2<f64>,
    /// Design for the time derivative of the exit linear predictor.
    pub x_derivative: Array2<f64>,
    /// Offset added to `x_entry · beta`.
    pub offset_eta_entry: Array1<f64>,
    /// Offset added to `x_exit · beta`.
    pub offset_eta_exit: Array1<f64>,
    /// Offset added to `x_derivative · beta`.
    pub offset_derivative_exit: Array1<f64>,
    /// Lower bound that the derivative must respect (finite, `>= 0`).
    pub derivative_floor: f64,
}
224
/// Custom family holding one independent Royston–Parmar block per cause.
///
/// The family reports its blocks as likelihood-uncoupled.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarFamily {
    blocks: Vec<CauseSpecificRoystonParmarBlock>,
}
234
235impl CauseSpecificRoystonParmarFamily {
236 pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
237 if blocks.is_empty() {
238 return Err(SurvivalError::InvalidInput {
239 reason: "cause-specific survival family requires at least one endpoint".to_string(),
240 }
241 .into());
242 }
243 for (idx, block) in blocks.iter().enumerate() {
244 validate_cause_specific_block(block).map_err(|err| {
245 SurvivalError::CauseSpecificBlock {
246 block: idx + 1,
247 source: Box::new(err),
248 }
249 .to_string()
250 })?;
251 }
252 Ok(Self { blocks })
253 }
254
255 pub fn cause_count(&self) -> usize {
256 self.blocks.len()
257 }
258}
259
260fn validate_cause_specific_block(
261 block: &CauseSpecificRoystonParmarBlock,
262) -> Result<(), SurvivalError> {
263 let n = block.event_target.len();
264 let p = block.x_exit.ncols();
265 if n == 0 || p == 0 {
266 bail_invalid_surv!("empty event vector or coefficient block");
267 }
268 if block.age_entry.len() != n
269 || block.age_exit.len() != n
270 || block.sampleweight.len() != n
271 || block.x_entry.nrows() != n
272 || block.x_exit.nrows() != n
273 || block.x_derivative.nrows() != n
274 || block.x_entry.ncols() != p
275 || block.x_derivative.ncols() != p
276 || block.offset_eta_entry.len() != n
277 || block.offset_eta_exit.len() != n
278 || block.offset_derivative_exit.len() != n
279 {
280 return Err(SurvivalError::CauseSpecificDimensionMismatch {
281 reason: "dimension mismatch".to_string(),
282 });
283 }
284 if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
290 return Err(SurvivalError::EventCodeInvalid {
291 reason: format!(
292 "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
293 ),
294 });
295 }
296 if block.age_entry.iter().any(|v| !v.is_finite())
297 || block.age_exit.iter().any(|v| !v.is_finite())
298 || block
299 .sampleweight
300 .iter()
301 .any(|v| !v.is_finite() || *v < 0.0)
302 || block.x_entry.iter().any(|v| !v.is_finite())
303 || block.x_exit.iter().any(|v| !v.is_finite())
304 || block.x_derivative.iter().any(|v| !v.is_finite())
305 || block.offset_eta_entry.iter().any(|v| !v.is_finite())
306 || block.offset_eta_exit.iter().any(|v| !v.is_finite())
307 || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
308 || !block.derivative_floor.is_finite()
309 || block.derivative_floor < 0.0
310 {
311 bail_invalid_surv!("non-finite input");
312 }
313 Ok(())
314}
315
/// Evaluates one cause-specific Royston–Parmar block at coefficients `beta`.
///
/// Linear predictors include offsets: `eta_entry = X_entry b + o_entry`,
/// `eta_exit = X_exit b + o_exit`, and `d = X_deriv b + o_deriv`. For a row with
/// weight `w > 0`, `H_exit = exp(eta_exit)`. `H_entry` is `exp(eta_entry)`, or zero when
/// the entry age is at or below `ENTRY_AT_ORIGIN_THRESHOLD`. The row's contribution is
///
/// `ll_i = -w (H_exit - H_entry) + event_i * w (eta_exit + ln d)`.
///
/// Returns `(log_likelihood, gradient, hessian)`. `gradient` is the gradient of the
/// log-likelihood. `hessian` is the Hessian of the *negative* log-likelihood:
/// `X_exit' W_exit X_exit - X_entry' W_entry X_entry + X_deriv' diag(w·event/d²) X_deriv`.
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if `beta.len()` differs from the block's column count.
/// - The `bail_invalid_surv!` error when `age_exit < age_entry` on a positively weighted row.
/// - `NumericalFailure` when a cumulative hazard is non-finite, or when an event row's
///   derivative is not finite and positive.
fn evaluate_cause_specific_block(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
    let n = block.event_target.len();
    let p = block.x_exit.ncols();
    if beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let mut log_likelihood = 0.0;
    // Per-row weights for the X' diag(w) v and X' diag(w) X products assembled below.
    let mut w_exit = Array1::<f64>::zeros(n);
    let mut w_entry = Array1::<f64>::zeros(n);
    let mut w_event = Array1::<f64>::zeros(n);
    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
    let mut w_event_outer = Array1::<f64>::zeros(n);

    for i in 0..n {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        if block.age_exit[i] < block.age_entry[i] {
            bail_invalid_surv!("age_exit < age_entry at row {i}");
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        let h_exit = eta_exit[i].exp();
        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
        if !(h_exit.is_finite() && h_entry.is_finite()) {
            return Err(SurvivalError::NumericalFailure {
                reason: format!("non-finite cumulative hazard at row {i}"),
            });
        }
        log_likelihood -= weight * (h_exit - h_entry);
        w_exit[i] = weight * h_exit;
        w_entry[i] = weight * h_entry;
        if block.event_target[i] > 0 {
            // ln(d) only exists for d > 0; event rows therefore need a positive derivative.
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            log_likelihood += weight * (eta_exit[i] + deriv.ln());
            w_event[i] = weight;
            w_event_inv_deriv[i] = weight / deriv;
            w_event_outer[i] = weight / (deriv * deriv);
        }
    }

    // Build the NLL gradient, then negate it to return the log-likelihood gradient.
    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
    let gradient = -nll_gradient;

    // The event term eta_exit is linear in beta, so it adds no curvature; only ln(d) does.
    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
    Ok((log_likelihood, gradient, hessian))
}
383
384impl CustomFamily for CauseSpecificRoystonParmarFamily {
385 fn joint_jeffreys_term_required(&self) -> bool {
389 true
390 }
391
392 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
393 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
394 "cause-specific survival",
395 self.blocks.len(),
396 block_states.len(),
397 )?;
398 let mut log_likelihood = 0.0;
399 let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
400 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
401 let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
402 log_likelihood += ll;
403 blockworking_sets.push(BlockWorkingSet::ExactNewton {
404 gradient,
405 hessian: SymmetricMatrix::Dense(hessian),
406 });
407 }
408 Ok(FamilyEvaluation {
409 log_likelihood,
410 blockworking_sets,
411 })
412 }
413
414 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
415 crate::block_layout::block_count::validate_block_count::<SurvivalError>(
416 "cause-specific survival",
417 self.blocks.len(),
418 block_states.len(),
419 )?;
420 let mut log_likelihood = 0.0;
421 for (block, state) in self.blocks.iter().zip(block_states.iter()) {
422 let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
423 log_likelihood += ll;
424 }
425 Ok(log_likelihood)
426 }
427
428 fn likelihood_blocks_uncoupled(&self) -> bool {
429 true
430 }
431
432 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
433 true
434 }
435
436 fn output_channel_assignment(
437 &self,
438 specs: &[crate::custom_family::ParameterBlockSpec],
439 ) -> Option<Vec<usize>> {
440 if specs.len() != self.blocks.len() {
441 return Some((0..self.blocks.len()).collect());
442 }
443 Some((0..specs.len()).collect())
444 }
445
446 fn coefficient_hessian_cost(
447 &self,
448 specs: &[crate::custom_family::ParameterBlockSpec],
449 ) -> u64 {
450 crate::custom_family::default_coefficient_hessian_cost(specs)
451 }
452
453 fn block_linear_constraints(
454 &self,
455 _: &[ParameterBlockState],
456 block_idx: usize,
457 spec: &crate::custom_family::ParameterBlockSpec,
458 ) -> Result<Option<LinearInequalityConstraints>, String> {
459 let block = self.blocks.get(block_idx).ok_or_else(|| {
460 SurvivalError::CauseSpecificDimensionMismatch {
461 reason: format!(
462 "cause-specific survival expected block index < {}, got {block_idx}",
463 self.blocks.len()
464 ),
465 }
466 .to_string()
467 })?;
468 if block.x_derivative.ncols() != spec.design.ncols() {
469 return Err(SurvivalError::CauseSpecificDimensionMismatch {
470 reason: format!(
471 "cause-specific survival derivative design has {} columns but block '{}' has {}",
472 block.x_derivative.ncols(),
473 spec.name,
474 spec.design.ncols()
475 ),
476 }
477 .into());
478 }
479 let rhs = block
480 .offset_derivative_exit
481 .mapv(|offset| block.derivative_floor - offset);
482 Ok(Some(LinearInequalityConstraints {
483 a: block.x_derivative.clone(),
484 b: rhs,
485 }))
486 }
487
488 fn max_feasible_step_size(
489 &self,
490 block_states: &[ParameterBlockState],
491 block_idx: usize,
492 delta: &Array1<f64>,
493 ) -> Result<Option<f64>, String> {
494 let block = self.blocks.get(block_idx).ok_or_else(|| {
495 SurvivalError::CauseSpecificDimensionMismatch {
496 reason: format!(
497 "cause-specific survival expected block index < {}, got {block_idx}",
498 self.blocks.len()
499 ),
500 }
501 .to_string()
502 })?;
503 let state = block_states.get(block_idx).ok_or_else(|| {
504 SurvivalError::CauseSpecificDimensionMismatch {
505 reason: format!(
506 "cause-specific survival expected {} block states, got {}",
507 self.blocks.len(),
508 block_states.len()
509 ),
510 }
511 .to_string()
512 })?;
513 if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
514 return Err(SurvivalError::CauseSpecificDimensionMismatch {
515 reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
516 }
517 .into());
518 }
519 let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
520 let derivative_delta = fast_av(&block.x_derivative, delta);
521 let mut alpha_max = 1.0_f64;
522 for i in 0..derivative.len() {
523 if block.sampleweight[i] <= 0.0 {
524 continue;
525 }
526 let current = derivative[i] - block.derivative_floor;
527 let slope = derivative_delta[i];
528 if slope < 0.0 {
529 if current <= 0.0 {
530 return Ok(Some(0.0));
531 }
532 alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
533 }
534 }
535 Ok(Some(alpha_max.clamp(0.0, 1.0)))
536 }
537
538 fn exact_newton_hessian_directional_derivative(
539 &self,
540 block_states: &[ParameterBlockState],
541 block_idx: usize,
542 d_beta: &Array1<f64>,
543 ) -> Result<Option<Array2<f64>>, String> {
544 let block = self.blocks.get(block_idx).ok_or_else(|| {
545 SurvivalError::CauseSpecificDimensionMismatch {
546 reason: format!(
547 "cause-specific survival expected block index < {}, got {block_idx}",
548 self.blocks.len()
549 ),
550 }
551 .to_string()
552 })?;
553 let state = block_states.get(block_idx).ok_or_else(|| {
554 SurvivalError::CauseSpecificDimensionMismatch {
555 reason: format!(
556 "cause-specific survival expected {} block states, got {}",
557 self.blocks.len(),
558 block_states.len()
559 ),
560 }
561 .to_string()
562 })?;
563 Ok(Some(cause_specific_hessian_directional_derivative(
564 block,
565 &state.beta,
566 d_beta,
567 )?))
568 }
569
570 fn exact_newton_hessian_second_directional_derivative(
571 &self,
572 block_states: &[ParameterBlockState],
573 block_idx: usize,
574 d_beta_u: &Array1<f64>,
575 d_beta_v: &Array1<f64>,
576 ) -> Result<Option<Array2<f64>>, String> {
577 let block = self.blocks.get(block_idx).ok_or_else(|| {
578 SurvivalError::CauseSpecificDimensionMismatch {
579 reason: format!(
580 "cause-specific survival expected block index < {}, got {block_idx}",
581 self.blocks.len()
582 ),
583 }
584 .to_string()
585 })?;
586 let state = block_states.get(block_idx).ok_or_else(|| {
587 SurvivalError::CauseSpecificDimensionMismatch {
588 reason: format!(
589 "cause-specific survival expected {} block states, got {}",
590 self.blocks.len(),
591 block_states.len()
592 ),
593 }
594 .to_string()
595 })?;
596 Ok(Some(cause_specific_hessian_second_directional_derivative(
597 block,
598 &state.beta,
599 d_beta_u,
600 d_beta_v,
601 )?))
602 }
603}
604
/// Directional derivative, along `d_beta`, of the negative log-likelihood Hessian
/// returned by `evaluate_cause_specific_block`.
///
/// For rows with positive weight, the derivative of each weight is:
/// - exit term: `w exp(eta_exit) (X_exit d)`;
/// - entry term: `w exp(eta_entry) (X_entry d)`, only when the entry age exceeds
///   `ENTRY_AT_ORIGIN_THRESHOLD`;
/// - event term: `-2 w (X_deriv d) / d³`.
///
/// Unlike the evaluator, this function does not check age ordering or overflow of `exp`.
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if `beta` or `d_beta` has the wrong length.
/// - `NumericalFailure` if an event row's derivative is not finite and positive.
fn cause_specific_hessian_directional_derivative(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
    d_beta: &Array1<f64>,
) -> Result<Array2<f64>, SurvivalError> {
    let p = block.x_exit.ncols();
    if beta.len() != p || d_beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    // Directional changes of the three linear predictors (offsets are constant in beta).
    let d_eta_entry = fast_av(&block.x_entry, d_beta);
    let d_eta_exit = fast_av(&block.x_exit, d_beta);
    let d_derivative = fast_av(&block.x_derivative, d_beta);
    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());

    for i in 0..block.event_target.len() {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
        if has_entry {
            w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
        }
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            // d/dt (w / d²) = -2 w d' / d³
            w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
        }
    }

    let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
    Ok(d_hessian)
}
654
/// Second directional derivative, along `d_beta_u` and `d_beta_v`, of the negative
/// log-likelihood Hessian returned by `evaluate_cause_specific_block`.
///
/// For rows with positive weight, the second derivative of each weight is:
/// - exit term: `w exp(eta_exit) (X_exit u)(X_exit v)`;
/// - entry term: `w exp(eta_entry) (X_entry u)(X_entry v)`, only past
///   `ENTRY_AT_ORIGIN_THRESHOLD`;
/// - event term: `6 w (X_deriv u)(X_deriv v) / d⁴`.
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if any input vector has the wrong length.
/// - `NumericalFailure` if an event row's derivative is not finite and positive.
fn cause_specific_hessian_second_directional_derivative(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
    d_beta_u: &Array1<f64>,
    d_beta_v: &Array1<f64>,
) -> Result<Array2<f64>, SurvivalError> {
    let p = block.x_exit.ncols();
    if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: "cause-specific survival second Hessian derivative dimension mismatch"
                .to_string(),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
    let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
    let u_derivative = fast_av(&block.x_derivative, d_beta_u);
    let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
    let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
    let v_derivative = fast_av(&block.x_derivative, d_beta_v);
    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());

    for i in 0..block.event_target.len() {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
        if has_entry {
            w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
        }
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            // Differentiating -2 w u' / d³ once more along v gives 6 w u' v' / d⁴.
            w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
        }
    }

    let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
    Ok(d2_hessian)
}
709
710pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
711 const INTEGER_TOL: f64 = 1e-8;
712 const MAX_AUTO_CAUSES: u8 = 32;
713 if !value.is_finite() {
714 return Err(SurvivalError::EventCodeInvalid {
715 reason: format!(
716 "survival event value at row {} is non-finite",
717 row_index + 1
718 ),
719 }
720 .into());
721 }
722 if value < 0.0 {
723 return Err(SurvivalError::EventCodeInvalid {
724 reason: format!(
725 "survival event value at row {} is negative: {value}",
726 row_index + 1
727 ),
728 }
729 .into());
730 }
731 let rounded = value.round();
732 if (value - rounded).abs() > INTEGER_TOL {
733 return Err(SurvivalError::EventCodeInvalid {
734 reason: format!(
735 "survival event value at row {} must be an integer code with 0=censored, got {value}",
736 row_index + 1
737 ),
738 }
739 .into());
740 }
741 if rounded > f64::from(MAX_AUTO_CAUSES) {
742 return Err(SurvivalError::EventCodeInvalid {
743 reason: format!(
744 "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
745 row_index + 1
746 ),
747 }
748 .into());
749 }
750 Ok(rounded as u8)
751}
752
753pub fn cause_count_from_event_codes(
754 event_codes: ArrayView1<'_, u8>,
755) -> Result<usize, SurvivalError> {
756 let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
757 if max_code == 0 {
758 return Ok(1);
759 }
760
761 let mut present = vec![false; max_code + 1];
762 for code in event_codes.iter().copied() {
763 present[usize::from(code)] = true;
764 }
765 if (1..=max_code).any(|code| !present[code]) {
766 let actual = present
767 .iter()
768 .enumerate()
769 .skip(1)
770 .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
771 .collect::<Vec<_>>()
772 .join(", ");
773 return Err(SurvivalError::EventCodeInvalid {
774 reason: format!(
775 "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
776 ),
777 });
778 }
779
780 Ok(max_code)
781}
782
783pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
796 event_codes.mapv(|label| u8::from(label > 0))
797}
798
799pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
809 let cause_code = cause as u8;
810 event_codes.mapv(|observed| u8::from(observed == cause_code))
811}
812
813fn compress_positive_collinear_constraints(
814 a: &Array2<f64>,
815 b: &Array1<f64>,
816) -> LinearInequalityConstraints {
817 const SCALE_TOL: f64 = 1e-14;
818 const KEY_TOL: f64 = 1e-8;
819
820 let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
821 let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
822
823 for i in 0..a.nrows() {
824 let row = a.row(i);
825 let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
826 if !scale.is_finite() || scale <= SCALE_TOL {
827 if b[i] > 0.0 {
828 fallbackrows.push((row.to_vec(), b[i]));
829 }
830 continue;
831 }
832
833 let normalizedrow: Vec<f64> = row
834 .iter()
835 .map(|&v| {
836 let scaled = v / scale;
837 if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
838 })
839 .collect();
840 let normalized_rhs = b[i] / scale;
841 let key: Vec<i64> = normalizedrow
842 .iter()
843 .map(|&v| (v / KEY_TOL).round() as i64)
844 .collect();
845
846 match grouped.get_mut(&key) {
847 Some((_, rhs_max)) => {
848 if normalized_rhs > *rhs_max {
849 *rhs_max = normalized_rhs;
850 }
851 }
852 None => {
853 grouped.insert(key, (normalizedrow, normalized_rhs));
854 }
855 }
856 }
857
858 let nrows = grouped.len() + fallbackrows.len();
859 let n_cols = a.ncols();
860 let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
861 let mut b_out = Array1::<f64>::zeros(nrows);
862
863 let mut outrow = 0usize;
864 for (_, (row, rhs)) in grouped {
865 for (j, value) in row.into_iter().enumerate() {
866 a_out[[outrow, j]] = value;
867 }
868 b_out[outrow] = rhs;
869 outrow += 1;
870 }
871 for (row, rhs) in fallbackrows {
872 for (j, value) in row.into_iter().enumerate() {
873 a_out[[outrow, j]] = value;
874 }
875 b_out[outrow] = rhs;
876 outrow += 1;
877 }
878
879 LinearInequalityConstraints { a: a_out, b: b_out }
880}
881
/// Settings for the survival monotonicity safeguard.
///
/// NOTE(review): `tolerance` is not read in this part of the file; presumably it is the
/// slack allowed before the derivative counts as non-monotone — confirm at the use site.
#[derive(Debug, Clone, Copy, Default)]
pub struct SurvivalMonotonicityPenalty {
    pub tolerance: f64,
}
886
/// Owned design representation used by `WorkingModelSurvival`.
#[derive(Debug, Clone)]
enum SurvivalDesign {
    /// One full design matrix per evaluation point; all share the coefficient layout.
    Flat {
        x_entry: Array2<f64>,
        x_exit: Array2<f64>,
        x_derivative: Array2<f64>,
    },
    /// Time-basis matrices plus one covariate matrix shared by entry and exit.
    ///
    /// Coefficients are laid out as `[time columns.., covariate columns..]`. The
    /// derivative uses only the time columns (see `WorkingModelSurvival::derivative_dot`).
    TimeCovariateShared {
        time_entry: Array2<f64>,
        time_exit: Array2<f64>,
        time_derivative: Array2<f64>,
        covariates: Array2<f64>,
    },
}
901
902impl SurvivalDesign {
903 fn p_total(&self) -> usize {
904 match self {
905 Self::Flat { x_exit, .. } => x_exit.ncols(),
906 Self::TimeCovariateShared {
907 time_exit,
908 covariates,
909 ..
910 } => time_exit.ncols() + covariates.ncols(),
911 }
912 }
913
914 fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
915 match self {
916 Self::Flat { .. } => time_mat.dot(beta),
917 Self::TimeCovariateShared { covariates, .. } => {
918 let p_time = time_mat.ncols();
919 let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
920 if covariates.ncols() > 0 {
921 out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
922 }
923 out
924 }
925 }
926 }
927
928 fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
929 match self {
930 Self::Flat { .. } => {
931 for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
932 *dst = src;
933 }
934 }
935 Self::TimeCovariateShared { covariates, .. } => {
936 let p_time = time_mat.ncols();
937 for j in 0..p_time {
938 out[j] = time_mat[[i, j]];
939 }
940 for j in 0..covariates.ncols() {
941 out[p_time + j] = covariates[[i, j]];
942 }
943 }
944 }
945 }
946}
947
/// Reusable per-row weight buffers, all of the same length.
///
/// `SurvivalWorkspace::reset` zero-fills them and reallocates only when the row count changes.
/// NOTE(review): the meaning of each buffer is set by code outside this view; the names
/// mirror the locals of `evaluate_cause_specific_block`.
#[derive(Debug, Clone)]
struct SurvivalWorkspace {
    w_event: Array1<f64>,
    w_event_inv_deriv: Array1<f64>,
    w_event_outer: Array1<f64>,
    w_hess_exit: Array1<f64>,
    w_hess_entry: Array1<f64>,
}
957
958impl SurvivalWorkspace {
959 fn new(n: usize) -> Self {
960 Self {
961 w_event: Array1::zeros(n),
962 w_event_inv_deriv: Array1::zeros(n),
963 w_event_outer: Array1::zeros(n),
964 w_hess_exit: Array1::zeros(n),
965 w_hess_entry: Array1::zeros(n),
966 }
967 }
968
969 fn reset(&mut self, n: usize) {
970 if self.w_event.len() != n {
971 *self = Self::new(n);
972 } else {
973 self.w_event.fill(0.0);
974 self.w_event_inv_deriv.fill(0.0);
975 self.w_event_outer.fill(0.0);
976 self.w_hess_exit.fill(0.0);
977 self.w_hess_entry.fill(0.0);
978 }
979 }
980}
981
/// Per-row residual vectors, one per offset channel.
///
/// NOTE(review): nothing in this part of the file constructs or reads this type.
/// Presumably `exit`, `entry` and `derivative` align with the `offset_eta_exit`,
/// `offset_eta_entry` and `offset_derivative_exit` channels; the meaning of `right` is
/// unclear from here — confirm.
#[derive(Clone, Debug)]
pub struct OffsetChannelResiduals {
    pub exit: Array1<f64>,
    pub entry: Array1<f64>,
    pub derivative: Array1<f64>,
    pub right: Array1<f64>,
}
1007
/// Per-row 3x3 curvature matrices over offset channels.
///
/// NOTE(review): presumably indexed by the (exit, entry, derivative) channels of
/// `OffsetChannelResiduals`; not used in this part of the file — confirm.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    pub rows: Vec<[[f64; 3]; 3]>,
}
1014
/// Owned survival working model: data, design, offsets, penalties, and scratch space.
///
/// `Clone` is implemented by hand because the scratch workspace sits behind a `Mutex`.
#[derive(Debug)]
pub struct WorkingModelSurvival {
    age_entry: Array1<f64>,
    age_exit: Array1<f64>,
    // NOTE(review): presumably flags rows whose entry age is at or below
    // ENTRY_AT_ORIGIN_THRESHOLD — confirm where it is built.
    entry_at_origin: Array1<bool>,
    event_target: Array1<u8>,
    // Its length defines the model's row count (see `nrows`).
    sampleweight: Array1<f64>,
    design: SurvivalDesign,
    offset_eta_entry: Array1<f64>,
    offset_eta_exit: Array1<f64>,
    offset_derivative_exit: Array1<f64>,
    penalties: PenaltyBlocks,
    monotonicity: SurvivalMonotonicityPenalty,
    // NOTE(review): the two fields below are not read in this part of the file.
    structurally_monotonic: bool,
    structural_time_columns: usize,
    monotonicity_constraint_rows: Option<Array2<f64>>,
    monotonicity_constraint_offsets: Option<Array1<f64>>,
    // Scratch buffers reused across evaluations; copied by value in `clone`.
    workspace: std::sync::Mutex<SurvivalWorkspace>,
}
1034
1035impl Clone for WorkingModelSurvival {
1036 fn clone(&self) -> Self {
1037 let workspace = self.workspace.lock().unwrap().clone();
1038 Self {
1039 age_entry: self.age_entry.clone(),
1040 age_exit: self.age_exit.clone(),
1041 entry_at_origin: self.entry_at_origin.clone(),
1042 event_target: self.event_target.clone(),
1043 sampleweight: self.sampleweight.clone(),
1044 design: self.design.clone(),
1045 offset_eta_entry: self.offset_eta_entry.clone(),
1046 offset_eta_exit: self.offset_eta_exit.clone(),
1047 offset_derivative_exit: self.offset_derivative_exit.clone(),
1048 penalties: self.penalties.clone(),
1049 monotonicity: self.monotonicity,
1050 structurally_monotonic: self.structurally_monotonic,
1051 structural_time_columns: self.structural_time_columns,
1052 monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1053 monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1054 workspace: std::sync::Mutex::new(workspace),
1055 }
1056 }
1057}
1058
1059impl WorkingModelSurvival {
/// ln(f64::MAX) ≈ 709.78; `exp` of any larger argument overflows to infinity.
const LOG_F64_MAX: f64 = 709.782712893384;

/// Computes `base * exp(log_scale)` in log space, so an overflowing product is reported
/// as an error instead of silently becoming infinite.
///
/// Returns `Ok(0.0)` whenever `base == 0.0`, without looking at `log_scale`.
///
/// # Errors
/// Raises via `bail_invalid_estim!` when `log_scale + ln|base|` is non-finite or
/// exceeds `LOG_F64_MAX`.
/// NOTE(review): `log_scale = -inf` with a non-zero `base` is rejected rather than
/// mapped to 0 — confirm callers never pass it.
#[inline]
fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
    if base == 0.0 {
        return Ok(0.0);
    }
    let log_abs = log_scale + base.abs().ln();
    if !log_abs.is_finite() {
        crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
    }
    if log_abs > Self::LOG_F64_MAX {
        crate::bail_invalid_estim!(
            "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
        );
    }
    // Restore the sign lost by taking |base|.
    Ok(base.signum() * log_abs.exp())
}

/// Total coefficient count, delegated to `SurvivalDesign::p_total`.
fn coefficient_dim(&self) -> usize {
    self.design.p_total()
}

/// Number of observations, taken from the length of the sample-weight vector.
fn nrows(&self) -> usize {
    self.sampleweight.len()
}
1086
1087 fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1088 let time_mat = match &self.design {
1089 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1090 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1091 };
1092 self.design.design_dot(time_mat, beta)
1093 }
1094
1095 fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1096 let time_mat = match &self.design {
1097 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1098 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1099 };
1100 self.design.design_dot(time_mat, beta)
1101 }
1102
1103 fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1104 match &self.design {
1105 SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1106 SurvivalDesign::TimeCovariateShared {
1107 time_derivative, ..
1108 } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1109 }
1110 }
1111
1112 fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1113 let time_mat = match &self.design {
1114 SurvivalDesign::Flat { x_entry, .. } => x_entry,
1115 SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1116 };
1117 self.design.fill_row(time_mat, i, out);
1118 }
1119
1120 fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1121 let time_mat = match &self.design {
1122 SurvivalDesign::Flat { x_exit, .. } => x_exit,
1123 SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1124 };
1125 self.design.fill_row(time_mat, i, out);
1126 }
1127
1128 fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1129 match &self.design {
1130 SurvivalDesign::Flat { x_derivative, .. } => {
1131 for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1132 *dst = src;
1133 }
1134 }
1135 SurvivalDesign::TimeCovariateShared {
1136 time_derivative, ..
1137 } => {
1138 let p_time = time_derivative.ncols();
1139 for j in 0..p_time {
1140 out[j] = time_derivative[[i, j]];
1141 }
1142 for dst in out.iter_mut().skip(p_time) {
1143 *dst = 0.0;
1144 }
1145 }
1146 }
1147 }
1148
1149 fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1150 match &self.design {
1151 SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1152 SurvivalDesign::TimeCovariateShared {
1153 time_derivative,
1154 covariates,
1155 ..
1156 } => {
1157 let p_time = time_derivative.ncols();
1158 let p_cov = covariates.ncols();
1159 let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1160 let time_block = fast_xt_diag_x(time_derivative, weights);
1161 out.slice_mut(ndarray::s![..p_time, ..p_time])
1162 .assign(&time_block);
1163 out
1164 }
1165 }
1166 }
1167
1168 fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1172 match &self.design {
1173 SurvivalDesign::Flat {
1174 x_entry, x_exit, ..
1175 } => {
1176 let mut h = fast_xt_diag_x(x_exit, w_exit);
1177 h -= &fast_xt_diag_x(x_entry, w_entry);
1178 h
1179 }
1180 SurvivalDesign::TimeCovariateShared {
1181 time_entry,
1182 time_exit,
1183 covariates,
1184 ..
1185 } => {
1186 let p_time = time_exit.ncols();
1187 let p_cov = covariates.ncols();
1188 let p = p_time + p_cov;
1189 let mut h = Array2::<f64>::zeros((p, p));
1190 let tt = {
1192 let mut block = fast_xt_diag_x(time_exit, w_exit);
1193 block -= &fast_xt_diag_x(time_entry, w_entry);
1194 block
1195 };
1196 h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1197 if p_cov > 0 {
1198 let tc = {
1200 let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1201 block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1202 block
1203 };
1204 h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1205 h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1206 let w_diff = w_exit - w_entry;
1208 let cc = fast_xt_diag_x(covariates, &w_diff);
1209 h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1210 }
1211 h
1212 }
1213 }
1214 }
1215
1216 fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1217 const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1218 if !self.structurally_monotonic {
1219 return None;
1220 }
1221 if deriv >= 1e-12 {
1222 return Some(deriv);
1223 }
1224 if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1225 return Some(1e-12);
1226 }
1227 None
1228 }
1229
1230 fn validate_penalties(
1231 penalties: &PenaltyBlocks,
1232 coefficient_dim: usize,
1233 ) -> Result<(), SurvivalError> {
1234 for block in &penalties.blocks {
1235 if !block.lambda.is_finite() || block.lambda < 0.0 {
1236 return Err(SurvivalError::NonFiniteInput);
1237 }
1238 if block.range.start > block.range.end || block.range.end > coefficient_dim {
1239 return Err(SurvivalError::DimensionMismatch);
1240 }
1241 let block_dim = block.range.end - block.range.start;
1242 if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1243 return Err(SurvivalError::DimensionMismatch);
1244 }
1245 if block.matrix.iter().any(|v| !v.is_finite()) {
1246 return Err(SurvivalError::NonFiniteInput);
1247 }
1248 }
1249 Ok(())
1250 }
1251
1252 fn derivative_guard(&self) -> f64 {
1253 if self.structurally_monotonic {
1254 return 0.0;
1258 }
1259 self.monotonicity.tolerance.max(0.0)
1260 }
1261
1262 fn derivative_guard_numerical(&self) -> f64 {
1263 let derivative_guard = self.derivative_guard();
1264 if derivative_guard <= 0.0 {
1265 if self.structurally_monotonic {
1274 -1e-10
1275 } else {
1276 1e-12
1277 }
1278 } else {
1279 (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1280 }
1281 }
1282
1283 fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1284 let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1285 1e-10 * scale
1286 }
1287
1288 fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1289 if !self.structurally_monotonic {
1290 return None;
1291 }
1292 let p = self.coefficient_dim();
1293 let time_columns = self.structural_time_columns.min(p);
1294 if time_columns == 0 {
1295 return None;
1296 }
1297 const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1298 let mut active_columns = vec![false; time_columns];
1299 let mut derivative_row = vec![0.0_f64; p];
1300 for i in 0..self.nrows() {
1301 if self.sampleweight[i] <= 0.0 {
1302 continue;
1303 }
1304 self.fill_derivative_row(i, &mut derivative_row);
1305 for j in 0..time_columns {
1306 if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1307 active_columns[j] = true;
1308 }
1309 }
1310 }
1311 if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1312 for i in 0..rows.nrows() {
1313 for j in 0..time_columns {
1314 if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1315 active_columns[j] = true;
1316 }
1317 }
1318 }
1319 }
1320 let active_columns: Vec<usize> = active_columns
1321 .into_iter()
1322 .enumerate()
1323 .filter_map(|(j, active)| active.then_some(j))
1324 .collect();
1325 if active_columns.is_empty() {
1326 return None;
1327 }
1328 let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1329 let b = Array1::<f64>::zeros(active_columns.len());
1330 for (row, &col) in active_columns.iter().enumerate() {
1331 a[[row, col]] = 1.0;
1332 }
1333 Some(LinearInequalityConstraints { a, b })
1334 }
1335
1336 pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1337 let p = self.coefficient_dim();
1338 const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1339 if p == 0 {
1340 return None;
1341 }
1342 if self.structurally_monotonic {
1343 return self.structural_time_coefficient_constraints();
1344 }
1345 if let (Some(rows), Some(offsets)) = (
1346 self.monotonicity_constraint_rows.as_ref(),
1347 self.monotonicity_constraint_offsets.as_ref(),
1348 ) {
1349 let activerows: Vec<usize> = (0..rows.nrows())
1350 .filter(|&i| {
1351 rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1352 > DERIVATIVE_ROW_NORM_TOL
1353 })
1354 .collect();
1355 if activerows.is_empty() {
1356 return None;
1357 }
1358 let mut a = Array2::<f64>::zeros((activerows.len(), p));
1359 let mut b = Array1::<f64>::zeros(activerows.len());
1360 for (r, &i) in activerows.iter().enumerate() {
1361 a.row_mut(r).assign(&rows.row(i));
1362 b[r] = self.derivative_guard() - offsets[i];
1363 }
1364 return Some(compress_positive_collinear_constraints(&a, &b));
1365 }
1366 None
1367 }
1368
1369 pub fn from_engine_inputs(
1370 inputs: SurvivalEngineInputs<'_>,
1371 penalties: PenaltyBlocks,
1372 monotonicity: SurvivalMonotonicityPenalty,
1373 spec: SurvivalSpec,
1374 ) -> Result<Self, SurvivalError> {
1375 Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
1376 }
1377
1378 fn validate_offsets(
1379 offsets: Option<SurvivalBaselineOffsets<'_>>,
1380 n: usize,
1381 ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1382 if let Some(off) = offsets {
1383 if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1384 {
1385 return Err(SurvivalError::DimensionMismatch);
1386 }
1387 if off.eta_entry.iter().any(|v| !v.is_finite())
1388 || off.eta_exit.iter().any(|v| !v.is_finite())
1389 || off.derivative_exit.iter().any(|v| !v.is_finite())
1390 {
1391 return Err(SurvivalError::NonFiniteInput);
1392 }
1393 Ok((
1394 off.eta_entry.to_owned(),
1395 off.eta_exit.to_owned(),
1396 off.derivative_exit.to_owned(),
1397 ))
1398 } else {
1399 Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1400 }
1401 }
1402
1403 fn validate_common_inputs(
1404 age_entry: &ArrayView1<f64>,
1405 age_exit: &ArrayView1<f64>,
1406 event_target: &ArrayView1<u8>,
1407 event_competing: &ArrayView1<u8>,
1408 sampleweight: &ArrayView1<f64>,
1409 ) -> Result<(), SurvivalError> {
1410 if age_entry.iter().any(|v| !v.is_finite())
1411 || age_exit.iter().any(|v| !v.is_finite())
1412 || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1413 {
1414 return Err(SurvivalError::NonFiniteInput);
1415 }
1416 if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1423 return Err(SurvivalError::EventCodeInvalid {
1424 reason: format!(
1425 "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1426 ),
1427 });
1428 }
1429 if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1430 return Err(SurvivalError::EventCodeInvalid {
1431 reason: format!(
1432 "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1433 ),
1434 });
1435 }
1436 if event_target
1437 .iter()
1438 .zip(event_competing.iter())
1439 .any(|(&target, &competing)| target > 0 && competing > 0)
1440 {
1441 return Err(SurvivalError::EventCodeInvalid {
1442 reason: "a row cannot be simultaneously a target event and a competing event"
1443 .to_string(),
1444 });
1445 }
1446 if age_entry
1460 .iter()
1461 .zip(age_exit.iter())
1462 .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1463 {
1464 return Err(SurvivalError::NonFiniteInput);
1465 }
1466 Ok::<(), _>(())
1467 }
1468
1469 fn validate_monotonicity_constraints(
1470 rows: Option<ArrayView2<'_, f64>>,
1471 offsets: Option<ArrayView1<'_, f64>>,
1472 coefficient_dim: usize,
1473 ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1474 match (rows, offsets) {
1475 (None, None) => Ok((None, None)),
1476 (Some(rows), Some(offsets)) => {
1477 if rows.ncols() != coefficient_dim
1478 || rows.nrows() != offsets.len()
1479 || rows.iter().any(|v| !v.is_finite())
1480 || offsets.iter().any(|v| !v.is_finite())
1481 {
1482 return Err(SurvivalError::DimensionMismatch);
1483 }
1484 Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1485 }
1486 _ => Err(SurvivalError::DimensionMismatch),
1487 }
1488 }
1489
1490 fn finish_construction(
1491 age_entry: ArrayView1<f64>,
1492 age_exit: ArrayView1<f64>,
1493 event_target: ArrayView1<u8>,
1494 sampleweight: ArrayView1<f64>,
1495 design: SurvivalDesign,
1496 offset_eta_entry: Array1<f64>,
1497 offset_eta_exit: Array1<f64>,
1498 offset_derivative_exit: Array1<f64>,
1499 penalties: PenaltyBlocks,
1500 monotonicity: SurvivalMonotonicityPenalty,
1501 monotonicity_constraint_rows: Option<Array2<f64>>,
1502 monotonicity_constraint_offsets: Option<Array1<f64>>,
1503 ) -> Self {
1504 let n = age_entry.len();
1505 Self {
1506 age_entry: age_entry.to_owned(),
1507 age_exit: age_exit.to_owned(),
1508 entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1509 event_target: event_target.to_owned(),
1510 sampleweight: sampleweight.to_owned(),
1511 design,
1512 offset_eta_entry,
1513 offset_eta_exit,
1514 offset_derivative_exit,
1515 penalties,
1516 monotonicity,
1517 structurally_monotonic: false,
1518 structural_time_columns: 0,
1519 monotonicity_constraint_rows,
1520 monotonicity_constraint_offsets,
1521 workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1522 }
1523 }
1524
1525 pub fn from_engine_inputswith_offsets(
1526 inputs: SurvivalEngineInputs<'_>,
1527 offsets: Option<SurvivalBaselineOffsets<'_>>,
1528 penalties: PenaltyBlocks,
1529 monotonicity: SurvivalMonotonicityPenalty,
1530 spec: SurvivalSpec,
1531 ) -> Result<Self, SurvivalError> {
1532 if spec == SurvivalSpec::Crude {
1533 return Err(SurvivalError::UnsupportedSpec("crude"));
1534 }
1535 let n = inputs.age_entry.len();
1536 let p = inputs.x_entry.ncols();
1537 if inputs.age_exit.len() != n
1538 || inputs.event_target.len() != n
1539 || inputs.event_competing.len() != n
1540 || inputs.sampleweight.len() != n
1541 || inputs.x_entry.nrows() != n
1542 || inputs.x_exit.nrows() != n
1543 || inputs.x_derivative.nrows() != n
1544 || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1545 || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1546 {
1547 return Err(SurvivalError::DimensionMismatch);
1548 }
1549 Self::validate_penalties(&penalties, p)?;
1550 Self::validate_common_inputs(
1551 &inputs.age_entry,
1552 &inputs.age_exit,
1553 &inputs.event_target,
1554 &inputs.event_competing,
1555 &inputs.sampleweight,
1556 )?;
1557 if inputs.x_entry.iter().any(|v| !v.is_finite())
1558 || inputs.x_exit.iter().any(|v| !v.is_finite())
1559 || inputs.x_derivative.iter().any(|v| !v.is_finite())
1560 {
1561 return Err(SurvivalError::NonFiniteInput);
1562 }
1563 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1564 Self::validate_offsets(offsets, n)?;
1565 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1566 Self::validate_monotonicity_constraints(
1567 inputs.monotonicity_constraint_rows,
1568 inputs.monotonicity_constraint_offsets,
1569 p,
1570 )?;
1571
1572 Ok(Self::finish_construction(
1573 inputs.age_entry,
1574 inputs.age_exit,
1575 inputs.event_target,
1576 inputs.sampleweight,
1577 SurvivalDesign::Flat {
1578 x_entry: inputs.x_entry.to_owned(),
1579 x_exit: inputs.x_exit.to_owned(),
1580 x_derivative: inputs.x_derivative.to_owned(),
1581 },
1582 offset_eta_entry,
1583 offset_eta_exit,
1584 offset_derivative_exit,
1585 penalties,
1586 monotonicity,
1587 monotonicity_constraint_rows,
1588 monotonicity_constraint_offsets,
1589 ))
1590 }
1591
1592 pub fn from_time_covariate_inputswith_offsets(
1593 inputs: SurvivalTimeCovarInputs<'_>,
1594 offsets: Option<SurvivalBaselineOffsets<'_>>,
1595 penalties: PenaltyBlocks,
1596 monotonicity: SurvivalMonotonicityPenalty,
1597 spec: SurvivalSpec,
1598 ) -> Result<Self, SurvivalError> {
1599 if spec == SurvivalSpec::Crude {
1600 return Err(SurvivalError::UnsupportedSpec("crude"));
1601 }
1602 let n = inputs.age_entry.len();
1603 let p_time = inputs.time_entry.ncols();
1604 let p_cov = inputs.covariates.ncols();
1605 let p = p_time + p_cov;
1606 if inputs.age_exit.len() != n
1607 || inputs.event_target.len() != n
1608 || inputs.event_competing.len() != n
1609 || inputs.sampleweight.len() != n
1610 || inputs.time_entry.nrows() != n
1611 || inputs.time_exit.nrows() != n
1612 || inputs.time_derivative.nrows() != n
1613 || inputs.covariates.nrows() != n
1614 || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1615 || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1616 {
1617 return Err(SurvivalError::DimensionMismatch);
1618 }
1619 Self::validate_penalties(&penalties, p)?;
1620 Self::validate_common_inputs(
1621 &inputs.age_entry,
1622 &inputs.age_exit,
1623 &inputs.event_target,
1624 &inputs.event_competing,
1625 &inputs.sampleweight,
1626 )?;
1627 if inputs.time_entry.iter().any(|v| !v.is_finite())
1628 || inputs.time_exit.iter().any(|v| !v.is_finite())
1629 || inputs.time_derivative.iter().any(|v| !v.is_finite())
1630 || inputs.covariates.iter().any(|v| !v.is_finite())
1631 {
1632 return Err(SurvivalError::NonFiniteInput);
1633 }
1634 let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1635 Self::validate_offsets(offsets, n)?;
1636 let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1637 Self::validate_monotonicity_constraints(
1638 inputs.monotonicity_constraint_rows,
1639 inputs.monotonicity_constraint_offsets,
1640 p,
1641 )?;
1642
1643 Ok(Self::finish_construction(
1644 inputs.age_entry,
1645 inputs.age_exit,
1646 inputs.event_target,
1647 inputs.sampleweight,
1648 SurvivalDesign::TimeCovariateShared {
1649 time_entry: inputs.time_entry.to_owned(),
1650 time_exit: inputs.time_exit.to_owned(),
1651 time_derivative: inputs.time_derivative.to_owned(),
1652 covariates: inputs.covariates.to_owned(),
1653 },
1654 offset_eta_entry,
1655 offset_eta_exit,
1656 offset_derivative_exit,
1657 penalties,
1658 monotonicity,
1659 monotonicity_constraint_rows,
1660 monotonicity_constraint_offsets,
1661 ))
1662 }
1663
1664 pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1679 if lambdas.len() != self.penalties.blocks.len() {
1680 crate::bail_invalid_estim!(
1681 "set_penalty_lambdas expects {} lambdas, got {}",
1682 self.penalties.blocks.len(),
1683 lambdas.len()
1684 );
1685 }
1686 for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1687 if !lambda.is_finite() || lambda < 0.0 {
1688 crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1689 }
1690 block.lambda = lambda;
1691 }
1692 Ok(())
1693 }
1694
1695 pub fn set_structural_monotonicity(
1696 &mut self,
1697 enabled: bool,
1698 time_columns: usize,
1699 ) -> Result<(), EstimationError> {
1700 let p = self.coefficient_dim();
1701 if time_columns > p {
1702 crate::bail_invalid_estim!(
1703 "structural time columns {} exceed coefficient dimension {}",
1704 time_columns,
1705 p
1706 );
1707 }
1708 if enabled && time_columns == 0 {
1709 crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
1710 }
1711 if enabled {
1712 const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1713 for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
1714 if offset < -STRUCTURAL_DERIV_TOL {
1715 crate::bail_invalid_estim!(
1716 "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
1717 );
1718 }
1719 }
1720 let mut derivative_row = vec![0.0_f64; p];
1721 for i in 0..self.nrows() {
1722 self.fill_derivative_row(i, &mut derivative_row);
1723 for j in 0..time_columns {
1724 let v = derivative_row[j];
1725 if v < -STRUCTURAL_DERIV_TOL {
1726 crate::bail_invalid_estim!(
1727 "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
1728 );
1729 }
1730 }
1731 for j in time_columns..p {
1732 let v = derivative_row[j];
1733 if v.abs() > STRUCTURAL_DERIV_TOL {
1734 crate::bail_invalid_estim!(
1735 "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
1736 );
1737 }
1738 }
1739 }
1740 if let (Some(rows), Some(offsets)) = (
1741 self.monotonicity_constraint_rows.as_ref(),
1742 self.monotonicity_constraint_offsets.as_ref(),
1743 ) {
1744 for (i, &offset) in offsets.iter().enumerate() {
1745 if offset < -STRUCTURAL_DERIV_TOL {
1746 crate::bail_invalid_estim!(
1747 "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
1748 );
1749 }
1750 }
1751 for i in 0..rows.nrows() {
1752 for j in 0..time_columns {
1753 let v = rows[[i, j]];
1754 if v < -STRUCTURAL_DERIV_TOL {
1755 crate::bail_invalid_estim!(
1756 "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1757 );
1758 }
1759 }
1760 for j in time_columns..p {
1761 let v = rows[[i, j]];
1762 if v.abs() > STRUCTURAL_DERIV_TOL {
1763 crate::bail_invalid_estim!(
1764 "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1765 );
1766 }
1767 }
1768 }
1769 }
1770 }
1771 self.structurally_monotonic = enabled;
1772 self.structural_time_columns = if enabled { time_columns } else { 0 };
1773 Ok(())
1774 }
1775
    /// Evaluates the penalized working state (value, gradient, observed Hessian) at `beta`.
    ///
    /// Each row with positive weight `w` contributes
    /// `w * (exp(eta_exit) - exp(eta_entry))` to the negative log-likelihood; the entry
    /// term is dropped when the row enters at the origin. Target-event rows also add
    /// `-w * (eta_exit + ln(d eta / dt))`.
    ///
    /// The returned `WorkingState` holds:
    /// - the gradient and observed Hessian of that objective, plus the penalty blocks
    ///   and a fixed `1e-8` stabilization ridge;
    /// - `eta_exit` as `eta`;
    /// - `-nll` as the log-likelihood and `2 * nll` as the deviance (both excluding
    ///   penalties).
    ///
    /// # Errors
    /// - Via `bail_invalid_estim!`, when:
    ///   - `beta` has the wrong length;
    ///   - a weighted row has non-finite ages or `age_exit < age_entry`; or
    ///   - an interval term overflows `f64`.
    /// - `ParameterConstraintViolation`, when:
    ///   - the time derivative falls below the monotonicity floor; or
    ///   - the cumulative hazard decreases over a row's interval beyond round-off.
    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
        if beta.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!("survival beta dimension mismatch");
        }

        let n = self.nrows();
        let p = self.coefficient_dim();

        // Linear predictors for the three channels, including baseline offsets.
        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;

        let mut nll = 0.0;
        let derivative_guard = self.derivative_guard();
        let derivative_guard_numerical = self.derivative_guard_numerical();
        // Per-row weight buffers are reused across calls; `reset(n)` prepares them for
        // this evaluation.
        // NOTE(review): `unwrap()` panics if a previous holder of this lock panicked.
        let mut workspace = self.workspace.lock().unwrap();
        workspace.reset(n);
        let SurvivalWorkspace {
            w_event,
            w_event_inv_deriv,
            w_event_outer,
            w_hess_exit,
            w_hess_entry,
        } = &mut *workspace;

        // Pass 1: accumulate the objective and the per-row weights used to build the
        // gradient and Hessian.
        for i in 0..n {
            let w = self.sampleweight[i];
            if w <= 0.0 {
                continue;
            }
            let entry_age = self.age_entry[i];
            let exit_age = self.age_exit[i];
            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
                crate::bail_invalid_estim!(
                    "survival ages must be finite with age_exit >= age_entry"
                );
            }
            let d = f64::from(self.event_target[i]);

            let has_entry_interval = !self.entry_at_origin[i];
            // Factor out exp(max eta) so the interval difference is formed on values in
            // (0, 1]. `scaled_exp_component` then restores the scale in log space.
            let interval_scale = if has_entry_interval {
                eta_exit[i].max(eta_entry[i])
            } else {
                eta_exit[i]
            };
            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
            let h_s_scaled = if has_entry_interval {
                (eta_entry[i] - interval_scale).exp()
            } else {
                0.0
            };
            let interval_scaled = h_e_scaled - h_s_scaled;
            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
            let deriv = self
                .stabilized_structural_derivative(derivative_raw[i])
                .unwrap_or(derivative_raw[i]);
            // Only event rows take ln(deriv), so only they need the (strictly positive
            // outside structural mode) numerical floor; other rows just must be >= 0.
            let mono_floor = if d > 0.0 {
                derivative_guard_numerical
            } else {
                0.0
            };
            if !deriv.is_finite() || deriv < mono_floor {
                return Err(EstimationError::ParameterConstraintViolation(format!(
                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
                    i, deriv, derivative_guard
                )));
            }
            if has_entry_interval {
                // Allow a tiny negative increment on the rescaled values (round-off)
                // before declaring the cumulative hazard non-monotone.
                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
                if interval_scaled + increment_guard < 0.0 {
                    return Err(EstimationError::ParameterConstraintViolation(format!(
                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
                        i, interval
                    )));
                }
            }
            nll += w * interval;

            // Curvature weights of the interval term: d²/deta² of w*exp(eta).
            let w_exit_i = w * eta_exit[i].exp();
            let w_entry_i = if has_entry_interval {
                w * eta_entry[i].exp()
            } else {
                0.0
            };
            if !w_exit_i.is_finite() {
                crate::bail_invalid_estim!(
                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
                );
            }
            w_hess_exit[i] = w_exit_i;
            w_hess_entry[i] = w_entry_i;

            if d > 0.0 {
                // Event term -w*(eta_exit + ln s): gradient weights w (exit channel) and
                // w/s (derivative channel); Hessian weight w/s² on the derivative design.
                let inv_deriv = 1.0 / deriv;
                nll += -w * (eta_exit[i] + deriv.ln());
                w_event[i] = w;
                w_event_inv_deriv[i] = w * inv_deriv;
                w_event_outer[i] = w * inv_deriv * inv_deriv;
            }
        }

        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
        // Pass 2: assemble the gradient row by row with Neumaier-compensated summation
        // (`grad_comp` collects the lost low-order bits).
        let mut grad = Array1::<f64>::zeros(p);
        let mut grad_comp = Array1::<f64>::zeros(p);
        let mut row_exit = vec![0.0_f64; p];
        let mut row_entry = vec![0.0_f64; p];
        let mut row_derivative = vec![0.0_f64; p];
        for i in 0..n {
            let w_interval_exit = w_hess_exit[i];
            let w_interval_entry = w_hess_entry[i];
            let w_event_exit = w_event[i];
            let w_event_derivative = w_event_inv_deriv[i];
            if w_interval_exit == 0.0
                && w_interval_entry == 0.0
                && w_event_exit == 0.0
                && w_event_derivative == 0.0
            {
                continue;
            }
            self.fill_exit_row(i, &mut row_exit);
            self.fill_entry_row(i, &mut row_entry);
            self.fill_derivative_row(i, &mut row_derivative);
            for j in 0..p {
                let contribution = w_interval_exit * row_exit[j]
                    - w_interval_entry * row_entry[j]
                    - w_event_exit * row_exit[j]
                    - w_event_derivative * row_derivative[j];
                let t = grad[j] + contribution;
                if grad[j].abs() >= contribution.abs() {
                    grad_comp[j] += (grad[j] - t) + contribution;
                } else {
                    grad_comp[j] += (contribution - t) + grad[j];
                }
                grad[j] = t;
            }
        }
        grad += &grad_comp;

        // Event-term curvature on the derivative design.
        h += &self.derivative_xt_diag_x(w_event_outer);

        let score_norm = array1_l2_norm(&grad);

        let penaltygrad = self.penalties.gradient(beta);
        let penalty_dev = self.penalties.deviance(beta);
        let penaltygrad_norm = array1_l2_norm(&penaltygrad);

        let mut totalgrad = grad;
        totalgrad += &penaltygrad;

        self.penalties.addhessian_inplace(&mut h);
        // Fixed ridge for numerical stabilization, applied consistently to the
        // Hessian diagonal, the gradient and the penalty term (0.5*ridge*||beta||²).
        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
        for d in 0..p {
            h[[d, d]] += ridge_used;
        }
        totalgrad += &beta.mapv(|v| ridge_used * v);
        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);

        let log_likelihood = -nll;
        let deviance = 2.0 * nll;

        Ok(WorkingState {
            eta: LinearPredictor::new(eta_exit),
            gradient: totalgrad,
            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
            log_likelihood,
            deviance,
            penalty_term: penalty_dev + ridge_penalty,
            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
            ridge_used,
            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
        })
    }
2011
2012 pub(crate) fn survival_hessian_derivative_correction(
2022 &self,
2023 beta: &Array1<f64>,
2024 u_k: &Array1<f64>,
2025 ) -> Result<Array2<f64>, EstimationError> {
2026 let p = beta.len();
2027 let n = self.nrows();
2028
2029 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2030 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2031 let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2032 let exp_entry = eta_entry.mapv(f64::exp);
2033 let exp_exit = eta_exit.mapv(f64::exp);
2034 let guard = self.derivative_guard();
2035 let guard_numerical = self.derivative_guard_numerical();
2036
2037 let jac = Array1::<f64>::ones(p);
2038 let curvature = Array1::<f64>::zeros(p);
2039 let third = Array1::<f64>::zeros(p);
2040
2041 let mut row_exit = vec![0.0_f64; p];
2042 let mut row_entry = vec![0.0_f64; p];
2043 let mut row_derivative = vec![0.0_f64; p];
2044 let mut ge = vec![0.0_f64; p];
2045 let mut gs = vec![0.0_f64; p];
2046 let mut gsd = vec![0.0_f64; p];
2047 let mut he = vec![0.0_f64; p];
2048 let mut hs = vec![0.0_f64; p];
2049 let mut hsd = vec![0.0_f64; p];
2050 let mut te = vec![0.0_f64; p];
2051 let mut ts = vec![0.0_f64; p];
2052 let mut tsd = vec![0.0_f64; p];
2053
2054 let mut b_dir = Array2::<f64>::zeros((p, p));
2055
2056 for i in 0..n {
2057 let w_i = self.sampleweight[i];
2058 if w_i <= 0.0 {
2059 continue;
2060 }
2061 let has_entry = !self.entry_at_origin[i];
2062 let mut deta_e = 0.0_f64;
2063 let mut deta_s = 0.0_f64;
2064 let mut ds = 0.0_f64;
2065 self.fill_exit_row(i, &mut row_exit);
2066 self.fill_entry_row(i, &mut row_entry);
2067 self.fill_derivative_row(i, &mut row_derivative);
2068 for j in 0..p {
2069 ge[j] = row_exit[j] * jac[j];
2070 gs[j] = row_entry[j] * jac[j];
2071 gsd[j] = row_derivative[j] * jac[j];
2072 he[j] = row_exit[j] * curvature[j];
2073 hs[j] = row_entry[j] * curvature[j];
2074 hsd[j] = row_derivative[j] * curvature[j];
2075 te[j] = row_exit[j] * third[j];
2076 ts[j] = row_entry[j] * third[j];
2077 tsd[j] = row_derivative[j] * third[j];
2078 deta_e += ge[j] * u_k[j];
2079 if has_entry {
2080 deta_s += gs[j] * u_k[j];
2081 }
2082 ds += gsd[j] * u_k[j];
2083 }
2084
2085 for r in 0..p {
2087 let dge_r = he[r] * u_k[r];
2088 let dgs_r = hs[r] * u_k[r];
2089 let dhe_r = te[r] * u_k[r];
2090 let dhs_r = ts[r] * u_k[r];
2091 for c in 0..p {
2092 let dge_c = he[c] * u_k[c];
2093 let dgs_c = hs[c] * u_k[c];
2094 let mut d_h_rc =
2095 exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2096 if r == c {
2097 d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2098 }
2099 if has_entry {
2100 d_h_rc -=
2101 exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2102 if r == c {
2103 d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2104 }
2105 }
2106 b_dir[[r, c]] += w_i * d_h_rc;
2107 }
2108 }
2109
2110 let s_i = self
2112 .stabilized_structural_derivative(deriv_raw[i])
2113 .unwrap_or(deriv_raw[i]);
2114 if !s_i.is_finite() {
2115 return Err(EstimationError::ParameterConstraintViolation(format!(
2116 "survival monotonicity violated in unified trace contraction at row {i}: \
2117 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2118 )));
2119 }
2120 if self.event_target[i] > 0 {
2121 if s_i < guard_numerical {
2122 return Err(EstimationError::ParameterConstraintViolation(format!(
2123 "survival monotonicity violated in unified trace contraction at row {i}: \
2124 d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2125 )));
2126 }
2127 let inv_s = 1.0 / s_i;
2128 let inv_s2 = inv_s * inv_s;
2129 let inv_s3 = inv_s2 * inv_s;
2130 for r in 0..p {
2131 let dgd_r = hsd[r] * u_k[r];
2132 let dtsd_r = tsd[r] * u_k[r];
2133 let dte_r = te[r] * u_k[r];
2134 for c in 0..p {
2135 let dgd_c = hsd[c] * u_k[c];
2136 let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2137 - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2138 if r == c {
2139 d_h_rc += -dte_r;
2140 d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2141 }
2142 b_dir[[r, c]] += w_i * d_h_rc;
2143 }
2144 }
2145 }
2146 }
2147
2148 Ok(b_dir)
2149 }
2150
2151 pub fn offset_channel_residuals(
2189 &self,
2190 beta: &Array1<f64>,
2191 ) -> Result<OffsetChannelResiduals, EstimationError> {
2192 if beta.len() != self.coefficient_dim() {
2193 crate::bail_invalid_estim!(
2194 "survival beta dimension mismatch in offset_channel_residuals"
2195 );
2196 }
2197 let n = self.nrows();
2198 let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2199 let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2200 let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2201
2202 let derivative_guard_numerical = self.derivative_guard_numerical();
2203 let mut r_exit = Array1::<f64>::zeros(n);
2204 let mut r_entry = Array1::<f64>::zeros(n);
2205 let mut r_deriv = Array1::<f64>::zeros(n);
2206
2207 for i in 0..n {
2208 let w = self.sampleweight[i];
2209 if w <= 0.0 {
2210 continue;
2211 }
2212 let entry_age = self.age_entry[i];
2213 let exit_age = self.age_exit[i];
2214 if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2215 crate::bail_invalid_estim!(
2216 "survival ages must be finite with age_exit >= age_entry"
2217 );
2218 }
2219 let has_entry_interval = !self.entry_at_origin[i];
2220 let d = f64::from(self.event_target[i]);
2221 let w_exit_i = w * eta_exit[i].exp();
2225 let w_entry_i = if has_entry_interval {
2226 w * eta_entry[i].exp()
2227 } else {
2228 0.0
2229 };
2230 if !w_exit_i.is_finite() {
2231 crate::bail_invalid_estim!(
2232 "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2233 );
2234 }
2235 r_exit[i] = w_exit_i - d * w;
2236 r_entry[i] = -w_entry_i;
2237 let deriv_raw = derivative_raw[i];
2242 let deriv = self
2243 .stabilized_structural_derivative(deriv_raw)
2244 .unwrap_or(deriv_raw);
2245 let mono_floor = if d > 0.0 {
2246 derivative_guard_numerical
2247 } else {
2248 0.0
2249 };
2250 if !deriv.is_finite() || deriv < mono_floor {
2251 return Err(EstimationError::ParameterConstraintViolation(format!(
2252 "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2253 )));
2254 }
2255 if d > 0.0 {
2256 r_deriv[i] = -w * d / deriv;
2257 }
2258 }
2259
2260 let right = Array1::<f64>::zeros(r_exit.len());
2261 Ok(OffsetChannelResiduals {
2262 exit: r_exit,
2263 entry: r_entry,
2264 derivative: r_deriv,
2265 right,
2266 })
2267 }
2268
2269 pub fn unified_lamlobjective_and_rhogradient(
2275 &self,
2276 beta: &Array1<f64>,
2277 state: &WorkingState,
2278 rho: &Array1<f64>,
2279 ) -> Result<(f64, Array1<f64>), EstimationError> {
2280 use gam_solve::estimate::reml::assembly::{
2281 InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
2282 };
2283 use gam_solve::estimate::reml::reml_outer_engine::{
2284 DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
2285 compute_block_penalty_logdet_derivs,
2286 };
2287 use gam_problem::EvalMode;
2288
2289 let p = beta.len();
2290 let active_penalty_blocks: Vec<&PenaltyBlock> = self
2291 .penalties
2292 .blocks
2293 .iter()
2294 .filter(|b| b.lambda > 0.0)
2295 .collect();
2296 if rho.len() != active_penalty_blocks.len() {
2297 crate::bail_invalid_estim!(
2298 "survival LAML rho dimension {} does not match active penalty block count {}",
2299 rho.len(),
2300 active_penalty_blocks.len()
2301 );
2302 }
2303 let k_count = active_penalty_blocks.len();
2304
2305 let h_dense = state.hessian.to_dense();
2307 let hop = DenseSpectralOperator::from_symmetric(&h_dense)
2308 .map_err(EstimationError::InvalidInput)?;
2309
2310 let block_descs: Vec<PenaltyBlockDesc> = self
2312 .penalties
2313 .blocks
2314 .iter()
2315 .filter(|b| b.lambda > 0.0)
2316 .map(|b| PenaltyBlockDesc {
2317 matrix: &b.matrix,
2318 range_start: b.range.start,
2319 range_end: b.range.end,
2320 })
2321 .collect();
2322 let penalty_coords =
2323 penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;
2324
2325 let per_block_rho: Vec<Array1<f64>> =
2327 rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
2328 let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
2329 .iter()
2330 .map(|b| vec![b.matrix.clone()])
2331 .collect();
2332 let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
2333 .iter()
2334 .map(|v| v.as_slice())
2335 .collect();
2336 let penalty_logdet = if k_count > 0 {
2337 compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
2338 .map_err(EstimationError::InvalidInput)?
2339 } else {
2340 PenaltyLogdetDerivs {
2341 value: 0.0,
2342 first: Array1::zeros(0),
2343 second: Some(Array2::zeros((0, 0))),
2344 }
2345 };
2346
2347 let penalty_quadratic = 2.0 * state.penalty_term;
2349 let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());
2350
2351 const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
2365 let kkt_residual = {
2366 let raw = state.gradient.clone();
2367 let projected = match self.monotonicity_linear_constraints() {
2368 Some(constraints) => {
2369 projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
2370 .ok_or_else(|| {
2371 EstimationError::InvalidInput(
2372 "survival LAML could not project the monotonicity KKT residual"
2373 .to_string(),
2374 )
2375 })?
2376 }
2377 None => raw,
2378 };
2379 let projected_norm = array1_l2_norm(&projected);
2380 let relative_projected_norm = state.relative_gradient_norm(projected_norm);
2381 if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
2382 Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
2383 } else {
2384 None
2385 }
2386 };
2387
2388 let result = InnerAssembly {
2389 log_likelihood: state.log_likelihood,
2390 penalty_quadratic,
2391 beta: beta.clone(),
2392 n_observations: self.nrows(),
2393 hessian_op: std::sync::Arc::new(hop),
2394 penalty_coords,
2395 penalty_logdet,
2396 dispersion: DispersionHandling::Fixed {
2397 phi: 1.0,
2398 include_logdet_h: true,
2399 include_logdet_s: true,
2400 },
2401 rho_curvature_scale: 1.0,
2402 rho_prior: gam_problem::RhoPrior::Flat,
2403 hessian_logdet_correction: 0.0,
2404 penalty_subspace_trace: None,
2405 deriv_provider: Some(Box::new(provider)),
2406 firth: None,
2407 nullspace_dim: None,
2408 barrier_config: None,
2409 ext_coords: Vec::new(),
2410 ext_coord_pair_fn: None,
2411 rho_ext_pair_fn: None,
2412 fixed_drift_deriv: None,
2413 contracted_psi_second_order: None,
2414 kkt_residual,
2415 active_constraints: None,
2416 }
2417 .evaluate(
2418 rho.as_slice().expect("rho must be contiguous"),
2419 EvalMode::ValueAndGradient,
2420 None,
2421 )
2422 .map_err(EstimationError::InvalidInput)?;
2423
2424 let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
2425 Ok((result.cost, gradient))
2426 }
2427
2428 pub fn evaluate_survival_lamlcost_and_gradient(
2449 &self,
2450 rho: &[f64],
2451 beta0: &Array1<f64>,
2452 ) -> Result<(f64, Array1<f64>), EstimationError> {
2453 let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2454 let rho_arr = Array1::from_vec(rho.to_vec());
2459 let state = candidate.update_state(&beta)?;
2460 candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2461 }
2462
2463 fn reconverge_survival_inner_mode(
2473 &self,
2474 rho: &[f64],
2475 beta0: &Array1<f64>,
2476 ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
2477 const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
2482 const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
2483 const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
2484 const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;
2485
2486 let active_block_count = self
2487 .penalties
2488 .blocks
2489 .iter()
2490 .filter(|b| b.lambda > 0.0)
2491 .count();
2492 if rho.len() != active_block_count {
2493 crate::bail_invalid_estim!(
2494 "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
2495 rho.len(),
2496 active_block_count
2497 );
2498 }
2499 if beta0.len() != self.coefficient_dim() {
2500 crate::bail_invalid_estim!(
2501 "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
2502 beta0.len(),
2503 self.coefficient_dim()
2504 );
2505 }
2506
2507 let mut candidate = self.clone();
2510 let mut lambdas: Vec<f64> = candidate
2511 .penalties
2512 .blocks
2513 .iter()
2514 .map(|b| b.lambda)
2515 .collect();
2516 let mut active_idx = 0usize;
2517 for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
2518 if block.lambda > 0.0 {
2519 *lambda = rho[active_idx].exp();
2520 active_idx += 1;
2521 }
2522 }
2523 candidate.set_penalty_lambdas(&lambdas)?;
2524
2525 let opts = gam_solve::pirls::WorkingModelPirlsOptions {
2526 max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
2527 convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
2528 adaptive_kkt_tolerance: None,
2529 max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
2530 min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
2531 firth_bias_reduction: false,
2532 coefficient_lower_bounds: None,
2533 linear_constraints: None,
2534 initial_lm_lambda: None,
2535 geodesic_acceleration: false,
2536 arrow_schur: None,
2537 };
2538 let summary = gam_solve::pirls::runworking_model_pirls(
2539 &mut candidate,
2540 Coefficients::new(beta0.clone()),
2541 &opts,
2542 |_| {},
2543 )?;
2544 let mut beta = summary.beta.as_ref().to_owned();
2545
2546 {
2568 const POLISH_MAX_ITERS: usize = 400;
2569 const POLISH_TOL: f64 = 1e-13;
2570 const ARMIJO_C: f64 = 1e-4;
2572 const BACKTRACK: f64 = 0.5;
2573 const MAX_BACKTRACK: usize = 80;
2574 let p = beta.len();
2575 let penalized_objective =
2579 |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
2580 for _ in 0..POLISH_MAX_ITERS {
2581 let st = match candidate.update_state(&beta) {
2582 Ok(st) => st,
2583 Err(_) => break,
2584 };
2585 let r = st.gradient.clone();
2586 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
2587 if !r_norm.is_finite() || r_norm < POLISH_TOL {
2588 break;
2589 }
2590 let h = st.hessian.to_dense();
2591 let f0 = penalized_objective(&st);
2592 let h_scale = (0..p)
2607 .map(|d| h[[d, d]].abs())
2608 .fold(0.0_f64, f64::max)
2609 .max(1.0);
2610 let mut step: Option<Array1<f64>> = None;
2621 let mut dir_deriv = 0.0_f64;
2622 for lm_pow in 0..18 {
2623 let lambda_lm = if lm_pow == 0 {
2624 0.0
2625 } else {
2626 1e-12 * h_scale * 10f64.powi(lm_pow)
2627 };
2628 let mut h_reg = h.clone();
2629 for d in 0..p {
2630 h_reg[[d, d]] += lambda_lm;
2631 }
2632 let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
2633 &h_reg,
2634 faer::Side::Lower,
2635 ) {
2636 Ok(f) => f,
2637 Err(_) => continue,
2638 };
2639 let candidate_step = factor.solvevec(&r);
2640 if candidate_step.iter().any(|v| !v.is_finite()) {
2641 continue;
2642 }
2643 let dd = -r.dot(&candidate_step);
2645 if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
2646 step = Some(candidate_step);
2647 dir_deriv = dd;
2648 break;
2649 }
2650 }
2651 let (step, dir_deriv) = match step {
2652 Some(s) => (s, dir_deriv),
2653 None => {
2654 (r.clone(), -r_norm * r_norm)
2657 }
2658 };
2659 let mut alpha = 1.0_f64;
2660 let mut accepted = false;
2661 for _ in 0..MAX_BACKTRACK {
2662 let trial = &beta - &(alpha * &step);
2663 if let Ok(ts) = candidate.update_state(&trial) {
2664 let ft = penalized_objective(&ts);
2665 let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
2666 let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
2677 let residual_ok = tn.is_finite() && tn < r_norm;
2678 if armijo_ok || residual_ok {
2679 beta = trial;
2680 accepted = true;
2681 break;
2682 }
2683 }
2684 alpha *= BACKTRACK;
2685 }
2686 if !accepted {
2687 break;
2688 }
2689 }
2690 }
2691
2692 Ok((candidate, beta))
2693 }
2694}
2695
2696pub(crate) struct SurvivalDerivProvider {
2705 model: WorkingModelSurvival,
2706 beta: Array1<f64>,
2707}
2708
2709impl SurvivalDerivProvider {
2710 pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2711 Self { model, beta }
2712 }
2713}
2714
2715impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2716 fn hessian_derivative_correction(
2717 &self,
2718 v_k: &Array1<f64>,
2719 ) -> Result<Option<Array2<f64>>, String> {
2720 let u_k = -v_k;
2723 match self
2724 .model
2725 .survival_hessian_derivative_correction(&self.beta, &u_k)
2726 {
2727 Ok(correction) => Ok(Some(correction)),
2728 Err(e) => Err(e.to_string()),
2729 }
2730 }
2731
2732 fn has_corrections(&self) -> bool {
2733 true
2734 }
2735}
2736
2737#[derive(Debug, Clone)]
2738pub struct CrudeRiskResult {
2739 pub risk: f64,
2740 pub diseasegradient: Array1<f64>,
2741 pub mortalitygradient: Array1<f64>,
2742}
2743
2744#[derive(Debug, Clone)]
2745pub struct CompetingRisksCifResult {
2746 pub cif: Vec<Array2<f64>>,
2751 pub overall_survival: Array2<f64>,
2752}
2753
2754const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2759
2760pub fn assemble_competing_risks_cif(
2761 times: ArrayView1<'_, f64>,
2762 cumulative_hazard: ArrayView3<'_, f64>,
2763) -> Result<CompetingRisksCifResult, SurvivalError> {
2764 let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2765 if n_endpoints == 0 {
2766 return Err(SurvivalError::DimensionMismatch);
2767 }
2768 let endpoint_hazards = cumulative_hazard
2769 .axis_iter(Axis(0))
2770 .map(|view| view.to_owned())
2771 .collect::<Vec<_>>();
2772 assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2773 if result.overall_survival.dim() != (n_rows, n_times) {
2774 Err(SurvivalError::DimensionMismatch)
2775 } else {
2776 Ok(result)
2777 }
2778 })
2779}
2780
2781pub fn assemble_competing_risks_cif_from_endpoints(
2782 times: ArrayView1<'_, f64>,
2783 cumulative_hazards: &[Array2<f64>],
2784) -> Result<CompetingRisksCifResult, SurvivalError> {
2785 let n_endpoints = cumulative_hazards.len();
2786 if n_endpoints == 0 || times.is_empty() {
2787 return Err(SurvivalError::DimensionMismatch);
2788 }
2789 let (n_rows, n_times) = cumulative_hazards[0].dim();
2790 if n_rows == 0 || n_times == 0 || times.len() != n_times {
2791 return Err(SurvivalError::DimensionMismatch);
2792 }
2793 if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2794 return Err(SurvivalError::InvalidTimeGrid);
2795 }
2796 if times
2797 .iter()
2798 .zip(times.iter().skip(1))
2799 .any(|(previous, current)| current <= previous)
2800 {
2801 return Err(SurvivalError::InvalidTimeGrid);
2802 }
2803 for endpoint_hazard in cumulative_hazards {
2804 if endpoint_hazard.dim() != (n_rows, n_times) {
2805 return Err(SurvivalError::DimensionMismatch);
2806 }
2807 if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2808 return Err(SurvivalError::NonFiniteInput);
2809 }
2810 }
2811
2812 let max_abs_hazard = cumulative_hazards
2813 .iter()
2814 .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2815 .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2816 let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2817 let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2818 .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2819 .collect();
2820 let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2821
2822 let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2834 let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2835 let mut surv_row = vec![0.0_f64; n_times];
2836 let mut previous_cif = vec![0.0_f64; n_endpoints];
2837 let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2838 let mut increments = vec![0.0_f64; n_endpoints];
2839 let mut previous_total_cumulative = 0.0_f64;
2840 for time_idx in 0..n_times {
2841 let mut total_increment = 0.0_f64;
2842 for endpoint in 0..n_endpoints {
2843 let current = cumulative_hazards[endpoint][[row, time_idx]];
2844 if current < -monotone_tolerance {
2845 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2846 }
2847 let raw_increment = current - previous_cumulative[endpoint];
2848 if raw_increment < -monotone_tolerance {
2849 return Err(SurvivalError::NonMonotoneCumulativeHazard);
2850 }
2851 let increment = raw_increment.max(0.0);
2852 increments[endpoint] = increment;
2853 total_increment += increment;
2854 previous_cumulative[endpoint] += increment;
2855 }
2856
2857 let survival_left = (-previous_total_cumulative).exp();
2858 let interval_failure = -(-total_increment).exp_m1();
2859 for endpoint in 0..n_endpoints {
2860 if total_increment > 0.0 {
2861 previous_cif[endpoint] +=
2862 survival_left * interval_failure * increments[endpoint] / total_increment;
2863 }
2864 cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2865 }
2866 previous_total_cumulative += total_increment;
2867 let mut fsum_at_t = 0.0_f64;
2884 for endpoint in 0..n_endpoints {
2885 fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2886 }
2887 surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2888 }
2889 Ok((cif_flat, surv_row))
2890 };
2891
2892 let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2896 && rayon::current_thread_index().is_none()
2897 {
2898 use rayon::prelude::*;
2899 (0..n_rows)
2900 .into_par_iter()
2901 .map(assemble_row)
2902 .collect::<Result<_, _>>()?
2903 } else {
2904 (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2905 };
2906
2907 for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2908 for endpoint in 0..n_endpoints {
2909 for time_idx in 0..n_times {
2910 cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2911 }
2912 }
2913 for time_idx in 0..n_times {
2914 overall_survival[[row, time_idx]] = surv_row[time_idx];
2915 }
2916 }
2917
2918 Ok(CompetingRisksCifResult {
2919 cif,
2920 overall_survival,
2921 })
2922}
2923
/// Builds the `n`-point Gauss–Legendre rule on `[-1, 1]`.
///
/// Returns `(node, weight)` pairs sorted by ascending node; `n == 0` gives an
/// empty rule. Only `ceil(n / 2)` roots of `P_n` are refined — by Newton's
/// method, with `P_n` evaluated through the three-term Legendre recurrence —
/// and each is mirrored to `±x`. For odd `n` the last refined root is the
/// central one and is emitted once as exactly `0.0`. Weights follow
/// `w = 2 / ((1 - x²) P_n'(x)²)`.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    let order = n as f64;
    let half = n.div_ceil(2);
    let has_center = n % 2 == 1;
    let mut rule: Vec<(f64, f64)> = Vec::with_capacity(n);

    // (P_n(x), P_{n-1}(x)) via (j + 1) P_{j+1} = (2j + 1) x P_j - j P_{j-1}.
    let legendre_pair = |x: f64| -> (f64, f64) {
        let mut current = 1.0_f64;
        let mut previous = 0.0_f64;
        for j in 0..n {
            let jf = j as f64;
            let next = ((2.0 * jf + 1.0) * x * current - jf * previous) / (jf + 1.0);
            previous = current;
            current = next;
        }
        (current, previous)
    };

    for k in 0..half {
        // Classical initial guess cos(pi (k + 3/4) / (n + 1/2)).
        let mut x = (std::f64::consts::PI * (k as f64 + 0.75) / (order + 0.5)).cos();
        let mut dp = 0.0;
        for _ in 0..100 {
            let (p_n, p_nm1) = legendre_pair(x);
            dp = order * (x * p_n - p_nm1) / (x * x - 1.0);
            let x_next = x - p_n / dp;
            let converged = (x_next - x).abs() < 1e-14;
            x = x_next;
            if converged {
                break;
            }
        }

        let w = 2.0 / ((1.0 - x * x) * dp * dp);
        if has_center && k + 1 == half {
            rule.push((0.0, w));
        } else {
            rule.extend([(-x, w), (x, w)]);
        }
    }

    rule.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    rule
}

/// Process-wide 40-point Gauss–Legendre rule on `[-1, 1]`, built lazily on
/// first use and shared afterwards.
fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
    const ORDER: usize = 40;
    static RULE: LazyLock<Vec<(f64, f64)>> =
        LazyLock::new(|| compute_gauss_legendre_nodes(ORDER));
    RULE.as_slice()
}
2971
2972pub fn calculate_crude_risk_quadrature<F>(
2996 t0: f64,
2997 t1: f64,
2998 breakpoints: &[f64],
2999 h_dis_t0: f64,
3000 h_mor_t0: f64,
3001 design_d_t0: ArrayView1<'_, f64>,
3002 design_m_t0: ArrayView1<'_, f64>,
3003 mut eval_at: F,
3004) -> Result<CrudeRiskResult, SurvivalError>
3005where
3006 F: FnMut(
3007 f64,
3008 &mut Array1<f64>,
3009 &mut Array1<f64>,
3010 &mut Array1<f64>,
3011 ) -> Result<(f64, f64, f64), SurvivalError>,
3012{
3013 let coeff_len_d = design_d_t0.len();
3014 let coeff_len_m = design_m_t0.len();
3015 if coeff_len_d == 0 || coeff_len_m == 0 {
3016 return Err(SurvivalError::InvalidIntegrationSetup);
3017 }
3018 if !t0.is_finite()
3019 || !t1.is_finite()
3020 || !h_dis_t0.is_finite()
3021 || !h_mor_t0.is_finite()
3022 || design_d_t0.iter().any(|v| !v.is_finite())
3023 || design_m_t0.iter().any(|v| !v.is_finite())
3024 {
3025 return Err(SurvivalError::NonFiniteInput);
3026 }
3027 if t1 <= t0 {
3028 return Ok(CrudeRiskResult {
3029 risk: 0.0,
3030 diseasegradient: Array1::zeros(coeff_len_d),
3031 mortalitygradient: Array1::zeros(coeff_len_m),
3032 });
3033 }
3034
3035 let mut sorted_breaks: Vec<f64> = breakpoints
3036 .iter()
3037 .copied()
3038 .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3039 .collect();
3040 sorted_breaks.push(t0);
3041 sorted_breaks.push(t1);
3042 sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3043 sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3044 if sorted_breaks.len() < 2 {
3045 return Err(SurvivalError::InvalidIntegrationSetup);
3046 }
3047
3048 let mut total_risk = 0.0;
3049 let mut diseasegradient = Array1::zeros(coeff_len_d);
3050 let mut mortalitygradient = Array1::zeros(coeff_len_m);
3051 let nodesweights = gauss_legendre_quadrature();
3052
3053 let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3054 let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3055 let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3056
3057 for segment in sorted_breaks.windows(2) {
3058 let a = segment[0];
3059 let b = segment[1];
3060 let center = 0.5 * (b + a);
3061 let halfwidth = 0.5 * (b - a);
3062 if halfwidth <= 0.0 {
3063 continue;
3064 }
3065
3066 for &(x, w) in nodesweights {
3067 let u = center + halfwidth * x;
3068 let (inst_hazard_d, hazard_d, hazard_m) =
3069 eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3070 if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3071 return Err(SurvivalError::NonFiniteInput);
3072 }
3073 if inst_hazard_d <= 0.0 {
3074 return Err(SurvivalError::NonPositiveHazard);
3075 }
3076
3077 if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3078 return Err(SurvivalError::NonMonotoneCumulativeHazard);
3079 }
3080
3081 let h_dis_cond = hazard_d - h_dis_t0;
3082 let h_mor_cond = hazard_m - h_mor_t0;
3083 let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3084
3085 total_risk += w * inst_hazard_d * s_total * halfwidth;
3086
3087 let weight = w * s_total * halfwidth;
3093 for j in 0..coeff_len_d {
3094 let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3095 let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3096 let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3097 diseasegradient[j] += weight * g;
3098 }
3099
3100 let weight = w * inst_hazard_d * s_total * halfwidth;
3103 for j in 0..coeff_len_m {
3104 let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3105 mortalitygradient[j] += weight * g;
3106 }
3107 }
3108 }
3109
3110 Ok(CrudeRiskResult {
3111 risk: total_risk,
3112 diseasegradient,
3113 mortalitygradient,
3114 })
3115}
3116
3117impl PirlsWorkingModel for WorkingModelSurvival {
3118 fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
3119 self.update_state(beta)
3120 }
3121}
3122
3123#[cfg(test)]
3124mod tests {
3125 use super::*;
3126 use ndarray::{Array1, Array2, Array3, array, s};
3127
3128 #[test]
3129 fn competing_risks_cif_constant_hazard_matches_closed_form() {
3130 let times = array![0.0, 2.0, 5.0, 10.0];
3131 let disease_rates = [0.12, 0.06];
3132 let death_rates = [0.05, 0.02];
3133 let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
3134 let rate = if endpoint == 0 {
3135 disease_rates[row]
3136 } else {
3137 death_rates[row]
3138 };
3139 rate * times[time_idx]
3140 });
3141
3142 let result =
3143 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3144
3145 for row in 0..2 {
3146 let total_rate = disease_rates[row] + death_rates[row];
3147 for time_idx in 0..times.len() {
3148 let failure = 1.0 - (-total_rate * times[time_idx]).exp();
3149 let expected_disease = disease_rates[row] / total_rate * failure;
3150 let expected_death = death_rates[row] / total_rate * failure;
3151 assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
3152 assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
3153 assert!(
3154 (result.cif[0][[row, time_idx]]
3155 + result.cif[1][[row, time_idx]]
3156 + result.overall_survival[[row, time_idx]]
3157 - 1.0)
3158 .abs()
3159 < 1e-12
3160 );
3161 }
3162 }
3163 }
3164
3165 #[test]
3166 fn competing_risks_cif_rejects_nonmonotone_hazards() {
3167 let times = array![0.0, 1.0, 2.0];
3168 let cumulative = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");
3169 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3170 .expect_err("nonmonotone cumulative hazard should be rejected");
3171 assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
3172 }
3173
3174 #[test]
3175 fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
3176 let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
3177 let cumulative = Array3::from_shape_vec(
3178 (3, 2, 5),
3179 vec![
3180 0.0, 0.2, 0.2, 0.5, 1.1, 0.0, 0.0, 0.4, 0.4, 0.9, 0.0, 0.1, 0.3, 0.3, 0.7, 0.0, 0.2, 0.2, 0.8, 0.8, 0.0, 0.0, 0.2, 0.6, 0.6, 0.0, 0.1, 0.5, 0.5, 1.5,
3184 ],
3185 )
3186 .expect("shape");
3187
3188 let result =
3189 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3190
3191 for row in 0..2 {
3192 for time_idx in 0..times.len() {
3193 let total_cif = result.cif[0][[row, time_idx]]
3194 + result.cif[1][[row, time_idx]]
3195 + result.cif[2][[row, time_idx]];
3196 assert!(
3197 (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
3198 "probability mass mismatch at row={row}, time_idx={time_idx}"
3199 );
3200 assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));
3201 for cause in 0..3 {
3202 assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
3203 if time_idx > 0 {
3204 assert!(
3205 result.cif[cause][[row, time_idx]] + 1e-12
3206 >= result.cif[cause][[row, time_idx - 1]],
3207 "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
3208 );
3209 }
3210 }
3211 }
3212 }
3213
3214 assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
3217 assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
3220 assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
3221 }
3222
3223 #[test]
3224 fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
3225 let cumulative = Array3::zeros((2, 1, 2));
3226
3227 for times in [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]] {
3228 let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3229 .expect_err("bad time grid should be rejected");
3230 assert!(matches!(err, SurvivalError::InvalidTimeGrid));
3231 }
3232
3233 let times = array![0.0, 1.0];
3234 let nonfinite = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
3235 let err = assemble_competing_risks_cif(times.view(), nonfinite.view())
3236 .expect_err("nonfinite hazard should be rejected");
3237 assert!(matches!(err, SurvivalError::NonFiniteInput));
3238 }
3239
3240 #[test]
3241 fn competing_risks_cif_extreme_hazards_remain_bounded() {
3242 let times = array![0.0, 1.0, 2.0];
3243 let cumulative =
3244 Array3::from_shape_vec((2, 1, 3), vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0])
3245 .expect("shape");
3246
3247 let result =
3248 assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3249
3250 for value in result
3251 .cif
3252 .iter()
3253 .flat_map(|m| m.iter())
3254 .chain(result.overall_survival.iter())
3255 {
3256 assert!(value.is_finite());
3257 assert!((0.0..=1.0).contains(value));
3258 }
3259 assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
3260 assert_eq!(result.overall_survival[[0, 2]], 0.0);
3261 }
3262
3263 fn toy_penalties() -> PenaltyBlocks {
3264 let s = array![[2.0, 0.5], [0.5, 3.0]];
3265 PenaltyBlocks::new(vec![PenaltyBlock {
3266 matrix: s,
3267 lambda: 1.7,
3268 range: 1..3,
3269 nullspace_dim: 0,
3270 }])
3271 }
3272
3273 fn survival_inputs<'a>(
3274 age_entry: &'a Array1<f64>,
3275 age_exit: &'a Array1<f64>,
3276 event_target: &'a Array1<u8>,
3277 event_competing: &'a Array1<u8>,
3278 sampleweight: &'a Array1<f64>,
3279 x_entry: &'a Array2<f64>,
3280 x_exit: &'a Array2<f64>,
3281 x_derivative: &'a Array2<f64>,
3282 ) -> SurvivalEngineInputs<'a> {
3283 SurvivalEngineInputs {
3284 age_entry: age_entry.view(),
3285 age_exit: age_exit.view(),
3286 event_target: event_target.view(),
3287 event_competing: event_competing.view(),
3288 sampleweight: sampleweight.view(),
3289 x_entry: x_entry.view(),
3290 x_exit: x_exit.view(),
3291 x_derivative: x_derivative.view(),
3292 monotonicity_constraint_rows: None,
3293 monotonicity_constraint_offsets: None,
3294 }
3295 }
3296
3297 fn survival_model(
3298 inputs: SurvivalEngineInputs<'_>,
3299 penalties: PenaltyBlocks,
3300 monotonicity: SurvivalMonotonicityPenalty,
3301 spec: SurvivalSpec,
3302 ) -> Result<WorkingModelSurvival, SurvivalError> {
3303 WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
3304 }
3305
3306 fn survival_model_with_offsets(
3307 inputs: SurvivalEngineInputs<'_>,
3308 offsets: Option<SurvivalBaselineOffsets<'_>>,
3309 penalties: PenaltyBlocks,
3310 monotonicity: SurvivalMonotonicityPenalty,
3311 spec: SurvivalSpec,
3312 ) -> Result<WorkingModelSurvival, SurvivalError> {
3313 WorkingModelSurvival::from_engine_inputswith_offsets(
3314 inputs,
3315 offsets,
3316 penalties,
3317 monotonicity,
3318 spec,
3319 )
3320 }
3321
3322 #[test]
3323 fn penaltyhessian_matchesgradient_jacobian() {
3324 let penalties = toy_penalties();
3325 let beta = array![10.0, -0.3, 1.2, 7.0];
3326
3327 let grad = penalties.gradient(&beta);
3328 let h = penalties.hessian(beta.len());
3329 let b_block = beta.slice(s![1..3]).to_owned();
3330 let expected = 1.7 * array![[2.0, 0.5], [0.5, 3.0]].dot(&b_block);
3331
3332 assert!((grad[1] - expected[0]).abs() < 1e-12);
3333 assert!((grad[2] - expected[1]).abs() < 1e-12);
3334 assert!((h[[1, 1]] - 1.7 * 2.0).abs() < 1e-12);
3335 assert!((h[[1, 2]] - 1.7 * 0.5).abs() < 1e-12);
3336 assert!((h[[2, 1]] - 1.7 * 0.5).abs() < 1e-12);
3337 assert!((h[[2, 2]] - 1.7 * 3.0).abs() < 1e-12);
3338 }
3339
3340 #[test]
3341 fn penaltygradient_matches_deviance_finite_difference() {
3342 let penalties = toy_penalties();
3343 let beta = array![10.0, -0.3, 1.2, 7.0];
3344 let grad = penalties.gradient(&beta);
3345 let eps = 1e-7;
3346
3347 for idx in 0..beta.len() {
3348 let mut plus = beta.clone();
3349 let mut minus = beta.clone();
3350 plus[idx] += eps;
3351 minus[idx] -= eps;
3352 let fd = (penalties.deviance(&plus) - penalties.deviance(&minus)) / (2.0 * eps);
3353 assert_eq!(
3354 grad[idx].signum(),
3355 fd.signum(),
3356 "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
3357 grad[idx]
3358 );
3359 assert!(
3360 (grad[idx] - fd).abs() < 1e-6,
3361 "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
3362 grad[idx]
3363 );
3364 }
3365 }
3366
3367 #[test]
3368 fn zero_offsets_match_default_survival_state() {
3369 let age_entry = array![1.0_f64, 2.0_f64];
3370 let age_exit = array![2.0_f64, 3.5_f64];
3371 let event_target = array![1u8, 0u8];
3372 let event_competing = array![0u8, 0u8];
3373 let sampleweight = array![1.0, 1.0];
3374 let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3375 let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3376 let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3377 let penalties = PenaltyBlocks::new(Vec::new());
3378 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3379 let beta = array![-1.0, 0.8];
3380
3381 let base = survival_model(
3382 survival_inputs(
3383 &age_entry,
3384 &age_exit,
3385 &event_target,
3386 &event_competing,
3387 &sampleweight,
3388 &x_entry,
3389 &x_exit,
3390 &x_derivative,
3391 ),
3392 penalties.clone(),
3393 mono,
3394 SurvivalSpec::Net,
3395 )
3396 .expect("construct base survival model");
3397
3398 let zero_offsets = survival_model_with_offsets(
3399 survival_inputs(
3400 &age_entry,
3401 &age_exit,
3402 &event_target,
3403 &event_competing,
3404 &sampleweight,
3405 &x_entry,
3406 &x_exit,
3407 &x_derivative,
3408 ),
3409 Some(SurvivalBaselineOffsets {
3410 eta_entry: array![0.0, 0.0].view(),
3411 eta_exit: array![0.0, 0.0].view(),
3412 derivative_exit: array![0.0, 0.0].view(),
3413 }),
3414 penalties,
3415 mono,
3416 SurvivalSpec::Net,
3417 )
3418 .expect("construct offset survival model");
3419
3420 let state_base = base.update_state(&beta).expect("base state");
3421 let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
3422 assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
3423 assert!(
3424 state_base
3425 .gradient
3426 .iter()
3427 .zip(statezero.gradient.iter())
3428 .all(|(a, b)| (a - b).abs() < 1e-12)
3429 );
3430 }
3431
#[test]
fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
    // Four subjects, all entering at age 0. In `cause_labels`, 0 is censoring and 1/2 are
    // two distinct causes, so the raw labels are not a 0/1 target indicator.
    let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
    let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
    let cause_labels = array![0u8, 1u8, 2u8, 0u8];
    let n_rows = cause_labels.len();
    let event_competing = Array1::<u8>::zeros(n_rows);
    let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];

    // Intercept + log-age design. Entry ages are zero, so they are clamped to 1e-8 before
    // the log; the derivative column is 1/t evaluated at exit.
    let x_entry = Array2::from_shape_fn((n_rows, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            age_entry[row].max(1e-8).ln()
        }
    });
    let x_exit = Array2::from_shape_fn((n_rows, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            age_exit[row].ln()
        }
    });
    let x_derivative = Array2::from_shape_fn((n_rows, 2), |(row, col)| {
        if col == 0 {
            0.0
        } else {
            1.0 / age_exit[row]
        }
    });
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };

    // Passing the multi-cause labels straight through as the target indicator must fail.
    let raw = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &cause_labels,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        penalties.clone(),
        mono,
        SurvivalSpec::Net,
    );
    assert!(
        matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
        "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
    );

    // Pooling marks any nonzero label as an event; cause-specific indicators keep one cause.
    let any_event = pooled_any_event_indicator(cause_labels.view());
    assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
    let per_cause_expectations = [
        (1, array![0u8, 1u8, 0u8, 0u8]),
        (2, array![0u8, 0u8, 1u8, 0u8]),
    ];
    for (cause, expected) in per_cause_expectations.iter() {
        assert_eq!(
            cause_specific_event_indicator(cause_labels.view(), *cause),
            *expected
        );
    }

    // The pooled 0/1 indicator is a valid target and yields a well-defined baseline fit.
    let model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &any_event,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        penalties,
        mono,
        SurvivalSpec::Net,
    )
    .expect("pooled any-event baseline model must construct from competing-risks data");

    let beta = array![-1.0_f64, 0.8];
    let state = model.update_state(&beta).expect("pooled baseline state");
    let deviance = state.deviance;
    assert!(
        deviance.is_finite(),
        "pooled baseline deviance must be finite, got {}",
        deviance
    );
    assert!(
        state.gradient.iter().all(|g| g.is_finite()),
        "pooled baseline gradient must be finite"
    );
}
3542
#[test]
fn offset_channel_residuals_match_central_fd_of_nll() {
    // Three subjects: row 1 enters at age 0, and row 2 is censored (event_target = 0).
    let age_entry = array![0.5_f64, 0.0, 0.3];
    let age_exit = array![1.4_f64, 1.0, 2.0];
    let event_target = array![1u8, 1u8, 0u8];
    let event_competing = array![0u8, 0u8, 0u8];
    let sampleweight = array![1.0_f64, 2.5, 0.7];
    let x_entry = array![
        [1.0, age_entry[0].ln()],
        [1.0, age_entry[1].max(1e-8).ln()],
        [1.0, age_entry[2].ln()]
    ];
    let x_exit = Array2::from_shape_fn((3, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            age_exit[row].ln()
        }
    });
    let x_derivative = Array2::from_shape_fn((3, 2), |(row, col)| {
        if col == 0 {
            0.0
        } else {
            1.0 / age_exit[row]
        }
    });
    // Baseline offsets for the entry predictor, exit predictor, and exit derivative.
    let o_entry = array![0.2_f64, 0.0, 0.1];
    let o_exit = array![0.4_f64, 0.5, 0.7];
    let o_deriv = array![0.3_f64, 0.8, 0.5];
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let beta = array![-0.7_f64, 0.6];

    // Build a fresh model for a given offset triple, so finite differences re-run the
    // full construction path.
    let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
        let offsets = SurvivalBaselineOffsets {
            eta_entry: o_e.view(),
            eta_exit: o_x.view(),
            derivative_exit: o_d.view(),
        };
        survival_model_with_offsets(
            survival_inputs(
                &age_entry,
                &age_exit,
                &event_target,
                &event_competing,
                &sampleweight,
                &x_entry,
                &x_exit,
                &x_derivative,
            ),
            Some(offsets),
            penalties.clone(),
            mono,
            SurvivalSpec::Net,
        )
        .expect("model build")
    };

    let base = build(&o_entry, &o_exit, &o_deriv);
    let resid = base
        .offset_channel_residuals(&beta)
        .expect("offset residuals");
    for channel_len in [resid.exit.len(), resid.entry.len(), resid.derivative.len()] {
        assert_eq!(channel_len, 3);
    }

    // Negative log-likelihood is half the deviance at the fixed beta.
    let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
    let h = 1e-6;

    // The test pins exact zeros for row 1's entry channel and row 2's derivative channel.
    assert_eq!(resid.entry[1], 0.0);
    assert_eq!(resid.derivative[2], 0.0);

    // Central difference of the NLL along coordinate `i` of one offset channel; `rebuild`
    // turns a perturbed copy of that channel into a model.
    let central_fd = |channel: &Array1<f64>,
                      i: usize,
                      rebuild: &dyn Fn(&Array1<f64>) -> WorkingModelSurvival| {
        let mut up = channel.clone();
        let mut down = channel.clone();
        up[i] += h;
        down[i] -= h;
        (nll(&rebuild(&up)) - nll(&rebuild(&down))) / (2.0 * h)
    };

    for i in 0..3 {
        let fd_exit = central_fd(&o_exit, i, &|o_x: &Array1<f64>| {
            build(&o_entry, o_x, &o_deriv)
        });
        assert!(
            (resid.exit[i] - fd_exit).abs() < 1e-6,
            "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
            resid.exit[i],
            fd_exit
        );

        let fd_entry = central_fd(&o_entry, i, &|o_e: &Array1<f64>| {
            build(o_e, &o_exit, &o_deriv)
        });
        assert!(
            (resid.entry[i] - fd_entry).abs() < 1e-6,
            "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
            resid.entry[i],
            fd_entry
        );

        let fd_deriv = central_fd(&o_deriv, i, &|o_d: &Array1<f64>| {
            build(&o_entry, &o_exit, o_d)
        });
        assert!(
            (resid.derivative[i] - fd_deriv).abs() < 1e-6,
            "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
            resid.derivative[i],
            fd_deriv
        );
    }
}
3674
#[test]
fn offset_channel_residuals_respect_zero_sampleweight() {
    // Row 0 has zero sample weight, so every one of its offset-channel residuals must be
    // exactly zero. Row 1 is a weighted event and must produce a finite, nonzero exit
    // residual.
    let age_entry = array![1.0_f64, 2.0];
    let age_exit = array![2.0_f64, 3.5];
    let event_target = array![1u8, 1u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![0.0_f64, 1.2];
    let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
    let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
    let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let beta = array![-1.0_f64, 0.8];

    // Nonzero offsets on the weighted row only.
    let eta_entry_offset = array![0.0_f64, 0.1];
    let eta_exit_offset = array![0.0_f64, 0.2];
    let derivative_offset = array![0.0_f64, 0.1];

    let model = survival_model_with_offsets(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        Some(SurvivalBaselineOffsets {
            eta_entry: eta_entry_offset.view(),
            eta_exit: eta_exit_offset.view(),
            derivative_exit: derivative_offset.view(),
        }),
        penalties,
        mono,
        SurvivalSpec::Net,
    )
    .expect("model");
    let r = model.offset_channel_residuals(&beta).expect("resid");
    assert_eq!(r.exit[0], 0.0, "zero-weight row must have a zero exit residual");
    assert_eq!(r.entry[0], 0.0, "zero-weight row must have a zero entry residual");
    assert_eq!(
        r.derivative[0], 0.0,
        "zero-weight row must have a zero derivative residual"
    );
    // `!= 0.0` alone would also accept NaN, so require finiteness explicitly.
    assert!(
        r.exit[1].is_finite() && r.exit[1] != 0.0,
        "weighted event row must have a finite nonzero exit residual, got {}",
        r.exit[1]
    );
}
3718
#[test]
fn offset_channel_residuals_reject_beta_dim_mismatch() {
    // The design has two columns, so a length-1 beta must be rejected with an
    // `InvalidInput` error whose message names the dimension mismatch.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0_f64];
    let x_entry = array![[1.0, 0.0]];
    let x_exit = array![[1.0, 0.7]];
    let x_derivative = array![[0.0, 0.5]];
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        penalties,
        mono,
        SurvivalSpec::Net,
    )
    .expect("model");

    let bad_beta = array![0.0_f64];
    let err = model
        .offset_channel_residuals(&bad_beta)
        .expect_err("mismatch must error");
    match err {
        EstimationError::InvalidInput(msg) => {
            assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
        }
        other => panic!("expected InvalidInput, got {other:?}"),
    }
}
3758
#[test]
fn crudespec_is_rejected_by_one_hazard_engine() {
    // One row with a competing event. Requesting the crude spec must fail at construction.
    let (age_entry, age_exit) = (array![1.0_f64], array![2.0_f64]);
    let (event_target, event_competing) = (array![0u8], array![1u8]);
    let sampleweight = array![1.0];
    let (x_entry, x_exit, x_derivative) = (array![[0.1]], array![[0.4]], array![[1.0]]);

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let outcome = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Crude,
    );
    let err = outcome.expect_err("crude fitting should be rejected by the one-hazard engine");
    assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
}
3790
#[test]
fn nonstructural_models_require_explicit_monotonicity_collocation() {
    // Two censored rows; structural monotonicity is never enabled on this model.
    let age_entry = array![1.0_f64, 1.5_f64];
    let age_exit = array![2.0_f64, 2.5_f64];
    let event_target = array![0u8, 0u8];
    let event_competing = array![0u8, 1u8];
    let sampleweight = array![1.0, 1.0];
    let x_entry = array![[0.2], [0.1]];
    let x_exit = array![[0.3], [0.2]];
    let x_derivative = array![[1.0], [1.0]];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct censored survival model");

    // Without structural mode, the model must not synthesize any linear constraints.
    let fabricated = model.monotonicity_linear_constraints();
    assert!(
        fabricated.is_none(),
        "non-structural survival models must not fabricate rowwise monotonicity constraints"
    );
}
3824
#[test]
fn decreasing_interval_is_rejectedwithout_target_events() {
    // Single censored row whose entry design value (0.5) exceeds its exit value (0.0).
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![0u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[0.5]];
    let x_exit = array![[0.0]];
    let x_derivative = array![[1.0]];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct censored survival model");

    // With beta = 1 the entry predictor (0.5) is above the exit predictor (0.0). This must be
    // rejected even though the row carries no target event.
    let beta = array![1.0];
    let outcome = model.update_state(&beta);
    let err = outcome.expect_err("decreasing cumulative hazard increment should be rejected");
    let rendered = err.to_string();
    assert!(
        rendered.contains("cumulative hazard decreased"),
        "unexpected error: {err}"
    );
}
3861
3862 fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
3863 calculate_crude_risk_quadrature(
3864 0.0,
3865 1.0,
3866 &[0.0, 1.0],
3867 beta_d.exp(),
3868 beta_m.exp(),
3869 array![1.0].view(),
3870 array![1.0].view(),
3871 |u, design_d, deriv_d, design_m| {
3872 let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
3873 let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
3874 let inst_hazard_d = 0.2 * beta_d.exp();
3875 design_d[0] = 1.0;
3876 deriv_d[0] = 0.0;
3879 design_m[0] = 1.0;
3880 Ok((inst_hazard_d, cumulative_d, cumulative_m))
3881 },
3882 )
3883 .expect("smooth crude-risk quadrature should succeed")
3884 }
3885
#[test]
fn crude_riskgradient_matches_monotoneobjective() {
    let beta_d = -0.2_f64;
    let beta_m = -0.5_f64;
    let eps = 1e-6;

    // Crude risk as a function of both coefficients, used for central differences.
    let risk_at = |d: f64, m: f64| smooth_crude_risk(d, m).risk;
    let fd_d = (risk_at(beta_d + eps, beta_m) - risk_at(beta_d - eps, beta_m)) / (2.0 * eps);
    let fd_m = (risk_at(beta_d, beta_m + eps) - risk_at(beta_d, beta_m - eps)) / (2.0 * eps);

    let result = smooth_crude_risk(beta_d, beta_m);
    let analytic_d = result.diseasegradient[0];
    let analytic_m = result.mortalitygradient[0];
    assert!(
        (analytic_d - fd_d).abs() < 1e-5,
        "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
        analytic_d
    );
    assert!(
        (analytic_m - fd_m).abs() < 1e-5,
        "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
        analytic_m
    );
}
3911
#[test]
fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
    let age_entry = array![1.0_f64, 2.0_f64];
    let age_exit = array![2.0_f64, 3.5_f64];
    let event_target = array![1u8, 0u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![1.0, 1.0];
    let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
    let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
    let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
    // One smoothing block acting on coefficient 1 only.
    let smoothing_block = PenaltyBlock {
        matrix: array![[2.0]],
        lambda: 1.7,
        range: 1..2,
        nullspace_dim: 0,
    };
    let penalties = PenaltyBlocks::new(vec![smoothing_block]);
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let beta = array![-1.2_f64, 0.4];

    let model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        penalties.clone(),
        mono,
        SurvivalSpec::Net,
    )
    .expect("construct survival model");

    let state = model.update_state(&beta).expect("survival state");
    // The reported penalty is the smoothing penalty plus half the ridge the engine used,
    // applied to the whole coefficient vector.
    let smoothing_part = penalties.deviance(&beta);
    let ridge_part = 0.5 * state.ridge_used * beta.dot(&beta);
    let expected_penalty = smoothing_part + ridge_part;
    let gap = (state.penalty_term - expected_penalty).abs();
    assert!(
        gap < 1e-12,
        "penalty_term mismatch: state={} expected={}",
        state.penalty_term,
        expected_penalty
    );
}
3957
#[test]
fn negative_penalty_lambda_is_rejected() {
    // A smoothing parameter below zero must be refused when the model is constructed.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[1.0, 0.0]];
    let x_exit = array![[1.0, 0.5]];
    let x_derivative = array![[0.0, 1.0]];
    let negative_block = PenaltyBlock {
        matrix: array![[1.0]],
        lambda: -0.1,
        range: 1..2,
        nullspace_dim: 0,
    };

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let outcome = survival_model(
        inputs,
        PenaltyBlocks::new(vec![negative_block]),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    );
    let err = outcome.expect_err("negative lambda must be rejected");

    assert!(matches!(err, SurvivalError::NonFiniteInput));
}
3994
#[test]
fn penalty_block_range_and_shapemust_match_coefficients() {
    // The block claims coefficients 0..2 (two entries) but carries a 1x1 matrix.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[1.0, 0.0]];
    let x_exit = array![[1.0, 0.5]];
    let x_derivative = array![[0.0, 1.0]];
    let mismatched_block = PenaltyBlock {
        matrix: array![[1.0]],
        lambda: 0.5,
        range: 0..2,
        nullspace_dim: 0,
    };

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let outcome = survival_model(
        inputs,
        PenaltyBlocks::new(vec![mismatched_block]),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Net,
    );
    let err = outcome.expect_err("penalty block geometry must match coefficient support");

    assert!(matches!(err, SurvivalError::DimensionMismatch));
}
4031
#[test]
fn survivalgradient_matchesobjectivefdwithridge_scaling() {
    let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
    let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
    let event_target = array![1u8, 0u8, 1u8];
    let event_competing = array![0u8, 0u8, 0u8];
    let sampleweight = array![1.0, 1.0, 1.0];
    // Intercept + log-age design; the derivative column is 1/t at exit.
    let x_entry = Array2::from_shape_fn((3, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            age_entry[row].ln()
        }
    });
    let x_exit = Array2::from_shape_fn((3, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            age_exit[row].ln()
        }
    });
    let x_derivative = Array2::from_shape_fn((3, 2), |(row, col)| {
        if col == 0 {
            0.0
        } else {
            1.0 / age_exit[row]
        }
    });
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let beta = array![-1.0, 3.0];

    let model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        penalties,
        mono,
        SurvivalSpec::Net,
    )
    .expect("construct survival model");

    // Penalized objective whose gradient `update_state` reports (ridge included).
    let objective = |s: &WorkingState| 0.5 * s.deviance + s.penalty_term;
    let state = model.update_state(&beta).expect("state at beta");
    let eps = 1e-7;
    for j in 0..beta.len() {
        let mut plus = beta.clone();
        plus[j] += eps;
        let mut minus = beta.clone();
        minus[j] -= eps;
        let obj_plus = objective(&model.update_state(&plus).expect("state at beta + eps"));
        let obj_minus = objective(&model.update_state(&minus).expect("state at beta - eps"));
        let fd = (obj_plus - obj_minus) / (2.0 * eps);
        let analytic = state.gradient[j];
        assert_eq!(
            analytic.signum(),
            fd.signum(),
            "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
            analytic
        );
        assert!(
            (analytic - fd).abs() < 1e-5,
            "objective/gradient mismatch at j={j}: grad={} fd={fd}",
            analytic
        );
    }
}
4100
4101 fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
4102 let age_entry: Array1<f64> = Array1::from(vec![
4109 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
4110 34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
4111 ]);
4112 let age_exit: Array1<f64> = Array1::from(vec![
4113 45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
4114 48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
4115 ]);
4116 let event_target = Array1::from(vec![
4117 1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
4118 ]);
4119 let event_competing = Array1::<u8>::zeros(age_entry.len());
4120 let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
4121 let n = age_entry.len();
4122 let ln_age_mean: f64 = {
4123 let mut sum = 0.0;
4124 for i in 0..n {
4125 sum += age_entry[i].ln() + age_exit[i].ln();
4126 }
4127 sum / (2.0 * n as f64)
4128 };
4129 let mut x_entry = Array2::<f64>::zeros((n, 2));
4130 let mut x_exit = Array2::<f64>::zeros((n, 2));
4131 let mut x_derivative = Array2::<f64>::zeros((n, 2));
4132 for i in 0..n {
4133 x_entry[[i, 0]] = 1.0;
4134 x_exit[[i, 0]] = 1.0;
4135 x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
4136 x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
4137 x_derivative[[i, 0]] = 0.0;
4138 x_derivative[[i, 1]] = 1.0 / age_exit[i];
4139 }
4140 let penalties = PenaltyBlocks::new(vec![
4141 PenaltyBlock {
4142 matrix: array![[3.0]],
4143 lambda: 0.0,
4144 range: 0..1,
4145 nullspace_dim: 0,
4146 },
4147 PenaltyBlock {
4148 matrix: array![[2.5]],
4149 lambda,
4150 range: 1..2,
4151 nullspace_dim: 0,
4152 },
4153 ]);
4154 survival_model(
4155 survival_inputs(
4156 &age_entry,
4157 &age_exit,
4158 &event_target,
4159 &event_competing,
4160 &sampleweight,
4161 &x_entry,
4162 &x_exit,
4163 &x_derivative,
4164 ),
4165 penalties,
4166 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4167 SurvivalSpec::Net,
4168 )
4169 .expect("construct LAML FD survival model")
4170 }
4171
4172 fn laml_test_logdet_h(state: &WorkingState) -> f64 {
4173 use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4174 use gam_linalg::faer_ndarray::FaerEigh;
4175
4176 let h_dense = state.hessian.to_dense();
4177 let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4178 let eps = spectral_epsilon(evals.as_slice().unwrap());
4179 evals
4180 .iter()
4181 .map(|&sigma| spectral_regularize(sigma, eps).ln())
4182 .sum()
4183 }
4184
#[test]
fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
    let rho0 = -0.35_f64;
    let beta = array![-2.5_f64, 1.0];
    let model = laml_fd_test_model(rho0.exp());
    let state = model
        .update_state(&beta)
        .expect("state for LAML prefix-skip test");

    // Fixture shape: an inactive (lambda = 0) block precedes the single active block.
    assert_eq!(model.penalties.blocks.len(), 2);
    assert_eq!(model.penalties.blocks[0].lambda, 0.0);
    assert!(model.penalties.blocks[1].lambda > 0.0);

    // rho holds log-lambdas of the active blocks only.
    let rho: Array1<f64> = model
        .penalties
        .blocks
        .iter()
        .filter_map(|block| (block.lambda > 0.0).then(|| block.lambda.ln()))
        .collect();
    assert_eq!(
        rho.len(),
        1,
        "fixture should expose exactly one active penalty block for the rho vector"
    );

    let (obj, grad) = model
        .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
        .expect("survival LAML objective and gradient");

    // The active block is lambda * [[2.5]], so log|S| = rho0 + ln(2.5).
    let log_det_penalty = rho0 + 2.5_f64.ln();
    let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
        - 0.5 * log_det_penalty;
    assert_eq!(
        grad.len(),
        1,
        "rho-gradient must match the active-penalty count, not the full block list"
    );
    assert!(
        (obj - expected).abs() < 1e-10,
        "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
    );
    assert!(
        grad[0].is_finite(),
        "rho-gradient must be finite: {}",
        grad[0]
    );
}
4250
#[test]
fn structural_monotonicgradient_matchesobjectivefd() {
    let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
    let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
    let event_target = array![1u8, 0u8, 1u8];
    let event_competing = array![0u8, 0u8, 0u8];
    let sampleweight = array![1.0, 1.0, 1.0];

    // Columns: intercept, two time-basis columns (the only ones with nonzero derivative
    // entries), and a covariate that has the same value at entry and exit.
    let x_entry = array![
        [1.0, 0.2, 0.05, -0.7],
        [1.0, 0.5, 0.20, 0.1],
        [1.0, 0.9, 0.60, 1.2]
    ];
    let x_exit = array![
        [1.0, 0.4, 0.16, -0.7],
        [1.0, 0.8, 0.64, 0.1],
        [1.0, 1.1, 1.21, 1.2]
    ];
    let x_derivative = array![
        [0.0, 0.8, 0.64, 0.0],
        [0.0, 0.7, 1.12, 0.0],
        [0.0, 0.6, 1.32, 0.0]
    ];
    let mut model = survival_model(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 3)
        .expect("enable structural monotonicity");

    // Structural mode emits one unit row per time column with a nonzero derivative, each
    // with a zero bound.
    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");
    assert_eq!(constraints.a.nrows(), 2);
    assert_eq!(constraints.a.ncols(), 4);
    let expected_rows = [vec![0.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 0.0]];
    for (row_idx, expected) in expected_rows.iter().enumerate() {
        assert_eq!(constraints.a.row(row_idx).to_vec(), *expected);
    }
    assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));

    // Penalized objective whose gradient `update_state` reports.
    let objective = |s: &WorkingState| 0.5 * s.deviance + s.penalty_term;
    let beta = array![0.2, 0.2, 0.1, 0.2];
    let state = model.update_state(&beta).expect("state at structural beta");
    let eps = 1e-7;
    for j in 0..beta.len() {
        let mut plus = beta.clone();
        plus[j] += eps;
        let mut minus = beta.clone();
        minus[j] -= eps;
        let obj_plus = objective(&model.update_state(&plus).expect("state at beta + eps"));
        let obj_minus = objective(&model.update_state(&minus).expect("state at beta - eps"));
        let fd = (obj_plus - obj_minus) / (2.0 * eps);
        let analytic = state.gradient[j];
        assert_eq!(
            analytic.signum(),
            fd.signum(),
            "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
            analytic
        );
        assert!(
            (analytic - fd).abs() < 2e-5,
            "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
            analytic
        );
    }
}
4332
#[test]
fn structural_monotonic_lamlgradient_returns_finitevalues() {
    let age_entry = array![1.0_f64, 1.2_f64];
    let age_exit = array![1.5_f64, 2.0_f64];
    let event_target = array![1u8, 0u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![1.0, 1.0];
    let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
    let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
    let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let mut model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 2)
        .expect("enable structural monotonicity");

    // Attach one active smoothing block on coefficient 1 directly to the model.
    let active_block = PenaltyBlock {
        matrix: array![[1.0]],
        lambda: 0.7,
        range: 1..2,
        nullspace_dim: 0,
    };
    model.penalties = PenaltyBlocks::new(vec![active_block]);

    let beta = array![0.2, 0.2, 0.1];
    let state = model.update_state(&beta).expect("state at structural beta");
    let rho: Array1<f64> = model
        .penalties
        .blocks
        .iter()
        .filter_map(|block| (block.lambda > 0.0).then(|| block.lambda.ln()))
        .collect();
    let (obj, grad) = model
        .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
        .expect("laml gradient should work in structural mode");
    assert!(obj.is_finite());
    assert_eq!(grad.len(), 1);
    assert!(grad[0].is_finite());
}
4388
#[test]
fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
    // One event row whose only design column drives the exit derivative directly.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[0.0]];
    let x_exit = array![[0.2]];
    let x_derivative = array![[1.0]];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let mut model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 1e-8 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");

    // Before structural mode: a negative derivative coefficient fails the guard.
    let infeasible_beta = array![-3.0];
    assert!(
        model.update_state(&infeasible_beta).is_err(),
        "negative derivative coefficient should violate derivative guard"
    );

    // Structural mode replaces the guard with a single coefficient constraint (1 * beta vs 0).
    model
        .set_structural_monotonicity(true, 1)
        .expect("enable structural monotonicity");
    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");
    assert_eq!(constraints.a.nrows(), 1);
    assert_eq!(constraints.a.ncols(), 1);
    assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
    assert!(constraints.b[0].abs() <= 1e-12);

    // A tiny positive coefficient must then remain feasible.
    let tiny_beta = array![1e-6];
    let state = model
        .update_state(&tiny_beta)
        .expect("small positive derivative coefficient should remain feasible");
    assert!(state.deviance.is_finite());
}
4441
#[test]
fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
    // The design contributes nothing to the exit derivative (all-zero derivative row, and
    // beta = 0), so the derivative offset alone is checked against the non-structural
    // guard tolerance of 3.0. An offset of 2.0 is rejected and 3.1 is accepted.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[1.0, 0.0]];
    let x_exit = array![[1.0, 0.0]];
    let x_derivative = array![[0.0, 0.0]];
    let penalties = PenaltyBlocks::new(Vec::new());
    // Shared by both fixtures so the accepted and rejected cases always use one tolerance.
    let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
    let eta_entry_offset = array![0.0];
    let eta_exit_offset = array![0.0];
    let derivative_offset_below_guard = array![2.0];
    let derivative_offset_above_guard = array![3.1];
    let offsets_below_guard = SurvivalBaselineOffsets {
        eta_entry: eta_entry_offset.view(),
        eta_exit: eta_exit_offset.view(),
        derivative_exit: derivative_offset_below_guard.view(),
    };
    let offsets_above_guard = SurvivalBaselineOffsets {
        eta_entry: eta_entry_offset.view(),
        eta_exit: eta_exit_offset.view(),
        derivative_exit: derivative_offset_above_guard.view(),
    };

    let model_below_guard = survival_model_with_offsets(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        Some(offsets_below_guard),
        penalties.clone(),
        monotonicity,
        SurvivalSpec::Net,
    )
    .expect("construct model with derivative offset below guard");
    let err = model_below_guard
        .update_state(&array![0.0, 0.0])
        .expect_err("derivative offset below guard should be rejected");
    let err_text = err.to_string();
    assert!(
        err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
        "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
    );

    let model_above_guard = survival_model_with_offsets(
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        ),
        Some(offsets_above_guard),
        penalties,
        monotonicity,
        SurvivalSpec::Net,
    )
    .expect("construct model with derivative offset above guard");
    let state = model_above_guard
        .update_state(&array![0.0, 0.0])
        .expect("derivative offset above guard should remain feasible");
    assert!(state.deviance.is_finite());
}
4517
#[test]
fn structural_monotonicity_rejects_negative_derivative_offsets() {
    // A slightly negative derivative offset must block the switch into structural mode.
    let age_entry = array![1.0_f64];
    let age_exit = array![2.0_f64];
    let event_target = array![1u8];
    let event_competing = array![0u8];
    let sampleweight = array![1.0];
    let x_entry = array![[0.0]];
    let x_exit = array![[0.2]];
    let x_derivative = array![[1.0]];
    let (eta_entry, eta_exit) = (array![0.0], array![0.0]);
    let derivative_exit = array![-1e-3];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let offsets = SurvivalBaselineOffsets {
        eta_entry: eta_entry.view(),
        eta_exit: eta_exit.view(),
        derivative_exit: derivative_exit.view(),
    };
    let mut model = survival_model_with_offsets(
        inputs,
        Some(offsets),
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");

    let outcome = model.set_structural_monotonicity(true, 1);
    let err = outcome.expect_err("negative derivative offsets must be rejected");
    let rendered = err.to_string();
    assert!(
        rendered.contains("structural monotonicity requires nonnegative derivative offsets"),
        "unexpected error: {err}"
    );
}
4563
#[test]
fn structural_monotonicity_emits_coefficient_constraints() {
    // Columns 0 and 1 carry the time derivative; column 2 has a zero derivative entry.
    let age_entry = array![1.0_f64, 1.5_f64];
    let age_exit = array![2.0_f64, 3.0_f64];
    let event_target = array![1u8, 0u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![1.0, 1.0];
    let x_entry = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
    let x_exit = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
    let x_derivative = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];

    let inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &event_target,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let mut model = survival_model(
        inputs,
        PenaltyBlocks::new(Vec::new()),
        SurvivalMonotonicityPenalty { tolerance: 0.0 },
        SurvivalSpec::Net,
    )
    .expect("construct structural survival model");
    model
        .set_structural_monotonicity(true, 2)
        .expect("enable structural monotonicity");

    // Expect one unit row per time column, each with a zero bound.
    let constraints = model
        .monotonicity_linear_constraints()
        .expect("structural derivative constraints");
    assert_eq!(constraints.a.nrows(), 2);
    assert_eq!(constraints.a.ncols(), 3);
    let expected_rows = [vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
    for (row_idx, expected) in expected_rows.iter().enumerate() {
        assert_eq!(constraints.a.row(row_idx).to_vec(), *expected);
    }
    assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
}
4605
4606 #[test]
4607 fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
4608 let age_entry = array![1.0_f64];
4609 let age_exit = array![2.0_f64];
4610 let event_target = array![1u8];
4611 let event_competing = array![0u8];
4612 let sampleweight = array![1.0];
4613 let x_entry = array![[1.0, 0.2]];
4614 let x_exit = array![[1.0, 0.6]];
4615 let x_derivative = array![[0.0, 1.0]];
4616
4617 let mut model = survival_model(
4618 survival_inputs(
4619 &age_entry,
4620 &age_exit,
4621 &event_target,
4622 &event_competing,
4623 &sampleweight,
4624 &x_entry,
4625 &x_exit,
4626 &x_derivative,
4627 ),
4628 PenaltyBlocks::new(Vec::new()),
4629 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4630 SurvivalSpec::Net,
4631 )
4632 .expect("construct structural survival model");
4633 model
4634 .set_structural_monotonicity(true, 2)
4635 .expect("enable structural monotonicity");
4636
4637 let constraints = model
4638 .monotonicity_linear_constraints()
4639 .expect("structural derivative constraints");
4640
4641 assert_eq!(constraints.a.nrows(), 1);
4642 assert!(
4643 constraints.a[[0, 0]].abs() <= 1e-12,
4644 "inactive time column should remain unconstrained"
4645 );
4646 assert!(
4647 (constraints.a[[0, 1]] - 1.0).abs() <= 1e-12,
4648 "active time column should remain constrained"
4649 );
4650 }
4651
4652 #[test]
4653 fn structural_monotonicity_preserves_sparse_row_patterns() {
4654 let age_entry = array![1.0_f64, 1.5_f64];
4655 let age_exit = array![2.0_f64, 2.5_f64];
4656 let event_target = array![1u8, 1u8];
4657 let event_competing = array![0u8, 0u8];
4658 let sampleweight = array![1.0, 1.0];
4659 let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
4660 let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
4661 let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];
4662
4663 let mut model = survival_model(
4664 survival_inputs(
4665 &age_entry,
4666 &age_exit,
4667 &event_target,
4668 &event_competing,
4669 &sampleweight,
4670 &x_entry,
4671 &x_exit,
4672 &x_derivative,
4673 ),
4674 PenaltyBlocks::new(Vec::new()),
4675 SurvivalMonotonicityPenalty { tolerance: 0.0 },
4676 SurvivalSpec::Net,
4677 )
4678 .expect("construct structural survival model");
4679 model
4680 .set_structural_monotonicity(true, 2)
4681 .expect("enable structural monotonicity");
4682
4683 let constraints = model
4684 .monotonicity_linear_constraints()
4685 .expect("structural derivative constraints");
4686
4687 assert_eq!(constraints.a.nrows(), 2);
4688 assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
4689 assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
4690 }
4691
4692 #[test]
4693 fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
4694 let age_entry = array![1.0_f64];
4695 let age_exit = array![1.1_f64];
4696 let event_target = array![0u8];
4697 let event_competing = array![0u8];
4698 let sampleweight = array![1.0];
4699 let x_entry = array![[0.0]];
4700 let x_exit = array![[0.0]];
4701 let x_derivative = array![[-1.0]];
4702 let penalties = PenaltyBlocks::new(Vec::new());
4703 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4704 let model = survival_model(
4705 survival_inputs(
4706 &age_entry,
4707 &age_exit,
4708 &event_target,
4709 &event_competing,
4710 &sampleweight,
4711 &x_entry,
4712 &x_exit,
4713 &x_derivative,
4714 ),
4715 penalties,
4716 mono,
4717 SurvivalSpec::Net,
4718 )
4719 .expect("construct censored survival model");
4720
4721 let err = model
4722 .update_state(&array![1.0])
4723 .expect_err("censored row should still enforce monotonic derivative");
4724 assert!(
4725 matches!(err, EstimationError::ParameterConstraintViolation(_)),
4726 "unexpected error: {err:?}"
4727 );
4728 }
4729
    /// Runs `calculate_crude_risk_quadrature` with a stub evaluator and returns
    /// the error it reports.
    ///
    /// The stub ignores its first argument. It writes single-entry rows
    /// (`design_d[0] = 1.0`, `deriv_d[0] = 0.0`, `design_m[0] = 1.0`) and returns
    /// the caller-supplied `(cumulative_entry, cumulative_exit, hazard_exit)`
    /// triple on every evaluation, so each test controls exactly which quantity
    /// is invalid.
    ///
    /// NOTE(review): `calculate_crude_risk_quadrature` interprets the positional
    /// arguments `1.0, 2.0, &[], 0.4, 0.2` and the two unit coefficient views.
    /// They are presumably the interval bounds, breakpoints, entry-time quantities
    /// and coefficients. Confirm against its signature, and confirm the tuple
    /// order matches the callback contract.
    ///
    /// # Panics
    /// Panics if the quadrature unexpectedly succeeds.
    fn crude_risk_quadrature_error(
        cumulative_entry: f64,
        cumulative_exit: f64,
        hazard_exit: f64,
    ) -> SurvivalError {
        calculate_crude_risk_quadrature(
            1.0,
            2.0,
            &[],
            0.4,
            0.2,
            array![1.0].view(),
            array![1.0].view(),
            |_, design_d, deriv_d, design_m| {
                // Stub evaluator: fixed unit design rows, zero derivative row.
                design_d[0] = 1.0;
                deriv_d[0] = 0.0;
                design_m[0] = 1.0;
                Ok((cumulative_entry, cumulative_exit, hazard_exit))
            },
        )
        .expect_err("invalid hazards should fail")
    }
4752
4753 #[test]
4754 fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
4755 let err = crude_risk_quadrature_error(0.1, 0.3, 0.25);
4756 assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
4757 }
4758
4759 #[test]
4760 fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
4761 let err = crude_risk_quadrature_error(0.0, 0.4, 0.25);
4762 assert!(matches!(err, SurvivalError::NonPositiveHazard));
4763 }
4764
4765 #[test]
4766 fn laml_no_penalties_matches_documentedobjective() {
4767 let age_entry = array![40.0, 45.0, 50.0, 55.0];
4768 let age_exit = array![44.0, 49.0, 54.0, 59.0];
4769 let event_target = array![1u8, 0u8, 1u8, 0u8];
4770 let event_competing = Array1::<u8>::zeros(4);
4771 let sampleweight = Array1::ones(4);
4772 let x_entry = array![
4773 [1.0, -0.2, 0.04],
4774 [1.0, -0.1, 0.01],
4775 [1.0, 0.0, 0.0],
4776 [1.0, 0.1, 0.01]
4777 ];
4778 let x_exit = array![
4779 [1.0, -0.12, 0.0144],
4780 [1.0, -0.02, 0.0004],
4781 [1.0, 0.08, 0.0064],
4782 [1.0, 0.18, 0.0324]
4783 ];
4784 let x_derivative = array![
4785 [0.0, 0.02, 0.001],
4786 [0.0, 0.02, 0.001],
4787 [0.0, 0.02, 0.001],
4788 [0.0, 0.02, 0.001]
4789 ];
4790 let penalties = PenaltyBlocks::new(Vec::new());
4791 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4792 let beta = array![-2.0, 0.7, 0.2];
4793
4794 let model = survival_model(
4795 survival_inputs(
4796 &age_entry,
4797 &age_exit,
4798 &event_target,
4799 &event_competing,
4800 &sampleweight,
4801 &x_entry,
4802 &x_exit,
4803 &x_derivative,
4804 ),
4805 penalties,
4806 mono,
4807 SurvivalSpec::Net,
4808 )
4809 .expect("construct survival model");
4810
4811 let state = model.update_state(&beta).expect("state at beta");
4812 let rho = Array1::from_iter(
4813 model
4814 .penalties
4815 .blocks
4816 .iter()
4817 .filter(|b| b.lambda > 0.0)
4818 .map(|b| b.lambda.ln()),
4819 );
4820 let (obj, grad) = model
4821 .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4822 .expect("laml objective for no-penalty model");
4823
4824 let h_dense = state.hessian.to_dense();
4825 let logdet_h: f64 = {
4826 use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4827 use gam_linalg::faer_ndarray::FaerEigh;
4828 let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4829 let eps = spectral_epsilon(evals.as_slice().unwrap());
4830 evals
4831 .iter()
4832 .map(|&sigma| spectral_regularize(sigma, eps).ln())
4833 .sum()
4834 };
4835 let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * logdet_h;
4836
4837 assert_eq!(grad.len(), 0);
4838 assert!(
4839 (obj - expected).abs() < 1e-10,
4840 "no-penalty LAML objective mismatch: obj={} expected={}",
4841 obj,
4842 expected
4843 );
4844 }
4845
4846 #[test]
4847 fn monotonicity_constraints_collapse_positive_collinearrows() {
4848 let a = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
4849 let b = array![1e-8, 1e-8, 1e-8];
4850
4851 let compressed = compress_positive_collinear_constraints(&a, &b);
4852
4853 assert_eq!(compressed.a.nrows(), 1);
4854 assert_eq!(compressed.a.ncols(), 3);
4855 assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4856 assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4857 assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4858 assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
4859 }
4860
4861 #[test]
4862 fn monotonicity_constraints_preserve_distinct_directions() {
4863 let a = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
4864 let b = array![0.2, 0.3, 0.1];
4865
4866 let compressed = compress_positive_collinear_constraints(&a, &b);
4867
4868 assert_eq!(compressed.a.nrows(), 2);
4869 let mut saw_x = false;
4870 let mut saw_y = false;
4871 for i in 0..compressed.a.nrows() {
4872 if (compressed.a[[i, 0]] - 1.0).abs() <= 1e-12 && compressed.a[[i, 1]].abs() <= 1e-12 {
4873 saw_x = true;
4874 assert!((compressed.b[i] - 0.2).abs() <= 1e-12);
4875 }
4876 if compressed.a[[i, 0]].abs() <= 1e-12 && (compressed.a[[i, 1]] - 1.0).abs() <= 1e-12 {
4877 saw_y = true;
4878 assert!((compressed.b[i] - 0.3).abs() <= 1e-12);
4879 }
4880 }
4881 assert!(saw_x);
4882 assert!(saw_y);
4883 }
4884
4885 #[test]
4886 fn monotonicity_constraints_cluster_near_collinearrows() {
4887 let a = array![
4888 [0.0, 0.5, 0.0],
4889 [0.0, 0.50000000003, 0.0],
4890 [0.0, 0.49999999997, 0.0]
4891 ];
4892 let b = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];
4893
4894 let compressed = compress_positive_collinear_constraints(&a, &b);
4895
4896 assert_eq!(compressed.a.nrows(), 1);
4897 assert_eq!(compressed.a.ncols(), 3);
4898 assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4899 assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4900 assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4901 assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
4902 }
4903
4904 #[test]
4905 fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
4906 let a = array![
4907 [0.0, 0.401, 0.302, 0.197],
4908 [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
4909 [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
4910 [0.0, 0.125, 0.500, 0.375]
4911 ];
4912 let b = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];
4913
4914 let compressed = compress_positive_collinear_constraints(&a, &b);
4915
4916 assert_eq!(compressed.a.nrows(), 2);
4917 let mut clustered_face = false;
4918 let mut distinct_face = false;
4919 for i in 0..compressed.a.nrows() {
4920 let row = compressed.a.row(i);
4921 if row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49 {
4922 clustered_face = true;
4923 assert!((compressed.b[i] - (2.0e-8 / 0.401)).abs() <= 1e-12);
4924 } else {
4925 distinct_face = true;
4926 assert!((row[1] - 0.25).abs() <= 1e-12);
4927 assert!((row[2] - 1.0).abs() <= 1e-12);
4928 assert!((row[3] - 0.75).abs() <= 1e-12);
4929 assert!((compressed.b[i] - 6.0e-8).abs() <= 1e-18);
4930 }
4931 }
4932 assert!(clustered_face);
4933 assert!(distinct_face);
4934 }
4935
4936 #[test]
4937 fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
4938 let age_entry = array![1.0_f64, 1.0, 1.0];
4939 let age_exit = array![2.0_f64, 4.0, 8.0];
4940 let event_target = array![0u8, 1u8, 0u8];
4941 let event_competing = array![0u8, 0u8, 0u8];
4942 let sampleweight = array![1.0, 1.0, 1.0];
4943 let x_entry = array![
4944 [1.0, age_entry[0].ln()],
4945 [1.0, age_entry[1].ln()],
4946 [1.0, age_entry[2].ln()]
4947 ];
4948 let x_exit = array![
4949 [1.0, age_exit[0].ln()],
4950 [1.0, age_exit[1].ln()],
4951 [1.0, age_exit[2].ln()]
4952 ];
4953 let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
4954 let penalties = PenaltyBlocks::new(Vec::new());
4955 let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4956
4957 let collocation_offsets = Array1::zeros(x_derivative.nrows());
4958 let mut inputs = survival_inputs(
4959 &age_entry,
4960 &age_exit,
4961 &event_target,
4962 &event_competing,
4963 &sampleweight,
4964 &x_entry,
4965 &x_exit,
4966 &x_derivative,
4967 );
4968 inputs.monotonicity_constraint_rows = Some(x_derivative.view());
4969 inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
4970
4971 let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
4972 .expect("construct linear survival model");
4973
4974 let constraints = model
4975 .monotonicity_linear_constraints()
4976 .expect("monotonicity constraints");
4977 assert_eq!(constraints.a.nrows(), 1);
4978 assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
4979 assert!((constraints.b[0] - 8e-8).abs() <= 1e-12);
4980 }
4981
4982 #[test]
4983 fn monotonicity_constraints_skip_numericallyzerorows() {
4984 let age_entry = array![1.0_f64, 1.0, 1.0];
4985 let age_exit = array![2.0_f64, 3.0, 4.0];
4986 let event_target = array![0u8, 0u8, 0u8];
4987 let event_competing = array![0u8, 0u8, 0u8];
4988 let sampleweight = array![1.0, 1.0, 1.0];
4989 let x_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
4990 let x_exit = x_entry.clone();
4991 let x_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
4992
4993 let collocation_offsets = Array1::zeros(x_derivative.nrows());
4994 let mut inputs = survival_inputs(
4995 &age_entry,
4996 &age_exit,
4997 &event_target,
4998 &event_competing,
4999 &sampleweight,
5000 &x_entry,
5001 &x_exit,
5002 &x_derivative,
5003 );
5004 inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5005 inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5006
5007 let model = survival_model(
5008 inputs,
5009 PenaltyBlocks::new(Vec::new()),
5010 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5011 SurvivalSpec::Net,
5012 )
5013 .expect("construct survival model");
5014
5015 let constraints = model
5016 .monotonicity_linear_constraints()
5017 .expect("nonzero derivative row should remain");
5018 assert_eq!(constraints.a.nrows(), 1);
5019 assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5020 assert!(constraints.b[0].abs() <= 1e-18);
5021 }
5022
5023 #[test]
5024 fn censoredrows_allowzero_boundary_derivative() {
5025 let age_entry = array![1.0_f64];
5026 let age_exit = array![2.0_f64];
5027 let event_target = array![0u8];
5028 let event_competing = array![0u8];
5029 let sampleweight = array![1.0];
5030 let x_entry = array![[0.0]];
5031 let x_exit = array![[0.0]];
5032 let x_derivative = array![[1.0]];
5033
5034 let model = survival_model(
5035 survival_inputs(
5036 &age_entry,
5037 &age_exit,
5038 &event_target,
5039 &event_competing,
5040 &sampleweight,
5041 &x_entry,
5042 &x_exit,
5043 &x_derivative,
5044 ),
5045 PenaltyBlocks::new(Vec::new()),
5046 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5047 SurvivalSpec::Net,
5048 )
5049 .expect("construct censored survival model");
5050
5051 let state = model
5052 .update_state(&array![0.0])
5053 .expect("censored boundary derivative should remain feasible with zero tolerance");
5054 assert!(state.deviance.is_finite());
5055 }
5056
    /// Mixed censored (row 0) / event (row 1) data with explicit collocation rows
    /// equal to the derivative design and tolerance 1e-8. Both rows lie along the
    /// same direction, so a single face with a strictly positive bound is expected.
    ///
    /// NOTE(review): after rescaling, row 0 would give 1e-8 / 0.5 = 2e-8 and row 1
    /// would give 1e-8 / 0.25 = 4e-8. The asserted 4e-8 is therefore the tightest
    /// bound whether or not the censored row carries the tolerance, so this
    /// fixture may not distinguish event-only bounds from all-row bounds. Consider
    /// swapping the two derivative values if that distinction is meant to be tested.
    #[test]
    fn eventrows_keep_positive_derivative_constraint() {
        let age_entry = array![1.0_f64, 1.0];
        let age_exit = array![2.0_f64, 4.0];
        let event_target = array![0u8, 1u8];
        let event_competing = array![0u8, 0u8];
        let sampleweight = array![1.0, 1.0];
        let x_entry = array![[0.0], [0.0]];
        let x_exit = array![[0.0], [0.0]];
        let x_derivative = array![[0.5], [0.25]];

        // Collocation constraints reuse the derivative rows with zero offsets.
        let collocation_offsets = Array1::zeros(x_derivative.nrows());
        let mut inputs = survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        );
        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());

        let model = survival_model(
            inputs,
            PenaltyBlocks::new(Vec::new()),
            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
            SurvivalSpec::Net,
        )
        .expect("construct mixed survival model");

        let constraints = model
            .monotonicity_linear_constraints()
            .expect("event row should induce positive lower bound");
        assert_eq!(constraints.a.nrows(), 1);
        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
        // 1e-8 / 0.25: the tightest rescaled bound among the two rows.
        assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
    }
5097
5098 #[test]
5099 fn structural_monotonicity_clamps_tiny_negative_roundoff() {
5100 let age_entry = array![1.0_f64];
5101 let age_exit = array![2.0_f64];
5102 let event_target = array![1u8];
5103 let event_competing = array![0u8];
5104 let sampleweight = array![1.0];
5105 let x_entry = array![[0.0]];
5106 let x_exit = array![[0.0]];
5107 let x_derivative = array![[1.0]];
5108 let mut model = survival_model(
5109 survival_inputs(
5110 &age_entry,
5111 &age_exit,
5112 &event_target,
5113 &event_competing,
5114 &sampleweight,
5115 &x_entry,
5116 &x_exit,
5117 &x_derivative,
5118 ),
5119 PenaltyBlocks::new(Vec::new()),
5120 SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5121 SurvivalSpec::Net,
5122 )
5123 .expect("construct survival model");
5124 model
5125 .set_structural_monotonicity(true, 1)
5126 .expect("enable structural monotonicity");
5127
5128 let state = model
5129 .update_state(&array![-1e-8])
5130 .expect("tiny structural roundoff should be clamped");
5131 assert!(state.deviance.is_finite());
5132 }
5133
5134 #[test]
5135 fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
5136 let uncompressed_constraints = LinearInequalityConstraints {
5137 a: array![
5138 [0.0, 0.5, 0.0],
5139 [0.0, 1.0 / 3.0, 0.0],
5140 [0.0, 0.2, 0.0],
5141 [0.0, 0.125, 0.0]
5142 ],
5143 b: Array1::from_elem(4, 1e-8),
5144 };
5145 let compressed_constraints = compress_positive_collinear_constraints(
5146 &uncompressed_constraints.a,
5147 &uncompressed_constraints.b,
5148 );
5149
5150 let candidates = [
5151 array![0.0, 1e-9, 0.0],
5152 array![0.0, 4e-8, 0.0],
5153 array![0.0, 8e-8, 0.0],
5154 array![0.0, 2e-7, 1.5],
5155 ];
5156 for beta in candidates {
5157 let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
5158 uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
5159 });
5160 let compressed_ok = (0..compressed_constraints.a.nrows())
5161 .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
5162 assert_eq!(compressed_ok, uncompressed_ok);
5163 }
5164 }
5165
5166 #[test]
5167 fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
5168 let age_entry = array![10.0_f64, 20.0, 25.0];
5169 let age_exit = array![15.0_f64, 30.0, 40.0];
5170 let event_target = array![1u8, 0u8, 1u8];
5171 let event_competing = array![0u8, 0u8, 0u8];
5172 let sampleweight = array![1.0, 2.0, 0.5];
5173 let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
5174 let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
5175 let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
5176 let beta = array![0.8, 1.1, -0.2];
5177
5178 let base_model = survival_model(
5179 survival_inputs(
5180 &age_entry,
5181 &age_exit,
5182 &event_target,
5183 &event_competing,
5184 &sampleweight,
5185 &x_entry,
5186 &x_exit,
5187 &x_derivative,
5188 ),
5189 PenaltyBlocks::new(Vec::new()),
5190 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5191 SurvivalSpec::Net,
5192 )
5193 .expect("construct base survival model");
5194 let base_state = base_model
5195 .update_state(&beta)
5196 .expect("evaluate base survival state");
5197
5198 let time_scale = 365.25;
5199 let scaled_age_entry = age_entry.mapv(|v| v * time_scale);
5200 let scaled_age_exit = age_exit.mapv(|v| v * time_scale);
5201 let scaled_x_derivative = x_derivative.mapv(|v| v / time_scale);
5202 let scaled_model = survival_model(
5203 survival_inputs(
5204 &scaled_age_entry,
5205 &scaled_age_exit,
5206 &event_target,
5207 &event_competing,
5208 &sampleweight,
5209 &x_entry,
5210 &x_exit,
5211 &scaled_x_derivative,
5212 ),
5213 PenaltyBlocks::new(Vec::new()),
5214 SurvivalMonotonicityPenalty { tolerance: 0.0 },
5215 SurvivalSpec::Net,
5216 )
5217 .expect("construct scaled survival model");
5218 let scaled_state = scaled_model
5219 .update_state(&beta)
5220 .expect("evaluate scaled survival state");
5221
5222 let weighted_events = sampleweight
5223 .iter()
5224 .zip(event_target.iter())
5225 .map(|(w, d)| *w * f64::from(*d))
5226 .sum::<f64>();
5227 let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
5228 assert!(
5229 (scaled_state.deviance - base_state.deviance - expected_deviance_shift).abs() <= 1e-10,
5230 "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
5231 scaled_state.deviance,
5232 base_state.deviance
5233 );
5234
5235 for j in 0..beta.len() {
5236 assert!(
5237 (scaled_state.gradient[j] - base_state.gradient[j]).abs() <= 1e-12,
5238 "gradient mismatch at j={j}: scaled={} base={}",
5239 scaled_state.gradient[j],
5240 base_state.gradient[j]
5241 );
5242 }
5243
5244 let base_hessian = base_state.hessian.to_dense();
5245 let scaled_hessian = scaled_state.hessian.to_dense();
5246 for r in 0..beta.len() {
5247 for c in 0..beta.len() {
5248 assert!(
5249 (scaled_hessian[[r, c]] - base_hessian[[r, c]]).abs() <= 1e-12,
5250 "hessian mismatch at ({r},{c}): scaled={} base={}",
5251 scaled_hessian[[r, c]],
5252 base_hessian[[r, c]]
5253 );
5254 }
5255 }
5256 }
5257}