Skip to main content

gam_models/survival/
base.rs

1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3    BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4    projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9    LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
20pub enum SurvivalError {
21    #[error("input dimensions are inconsistent")]
22    DimensionMismatch,
23    #[error("inputs contain non-finite values")]
24    NonFiniteInput,
25    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
26    UnsupportedSpec(&'static str),
27    #[error("crude risk integration setup is invalid")]
28    InvalidIntegrationSetup,
29    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
30    InvalidTimeGrid,
31    #[error("cumulative hazard must be nondecreasing")]
32    NonMonotoneCumulativeHazard,
33    #[error("instantaneous hazard must stay strictly positive during integration")]
34    NonPositiveHazard,
35    #[error("{reason}")]
36    InvalidInput { reason: String },
37    #[error("{reason}")]
38    CauseSpecificDimensionMismatch { reason: String },
39    #[error("{reason}")]
40    NumericalFailure { reason: String },
41    #[error("{reason}")]
42    EventCodeInvalid { reason: String },
43    #[error("{reason}")]
44    EventDegenerate { reason: String },
45    #[error("cause-specific survival block {block}: {source}")]
46    CauseSpecificBlock {
47        block: usize,
48        #[source]
49        source: Box<SurvivalError>,
50    },
51}
52
53impl From<SurvivalError> for String {
54    fn from(err: SurvivalError) -> Self {
55        err.to_string()
56    }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60    fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61        SurvivalError::CauseSpecificDimensionMismatch {
62            reason: err.message(),
63        }
64    }
65}
66
67#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
68pub enum SurvivalSpec {
69    #[default]
70    Net,
71    Crude,
72}
73
74#[derive(Debug, Clone)]
75pub struct SurvivalEngineInputs<'a> {
76    pub age_entry: ArrayView1<'a, f64>,
77    pub age_exit: ArrayView1<'a, f64>,
78    pub event_target: ArrayView1<'a, u8>,
79    pub event_competing: ArrayView1<'a, u8>,
80    pub sampleweight: ArrayView1<'a, f64>,
81    pub x_entry: ArrayView2<'a, f64>,
82    pub x_exit: ArrayView2<'a, f64>,
83    pub x_derivative: ArrayView2<'a, f64>,
84    /// Optional global monotonicity collocation rows for the full coefficient vector.
85    /// Non-structural survival models should pass these explicitly instead of
86    /// relying on observed derivative rows.
87    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
88    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
89    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
90}
91
92#[derive(Debug, Clone)]
93pub struct SurvivalTimeCovarInputs<'a> {
94    pub age_entry: ArrayView1<'a, f64>,
95    pub age_exit: ArrayView1<'a, f64>,
96    pub event_target: ArrayView1<'a, u8>,
97    pub event_competing: ArrayView1<'a, u8>,
98    pub sampleweight: ArrayView1<'a, f64>,
99    pub time_entry: ArrayView2<'a, f64>,
100    pub time_exit: ArrayView2<'a, f64>,
101    pub time_derivative: ArrayView2<'a, f64>,
102    pub covariates: ArrayView2<'a, f64>,
103    /// Optional global monotonicity collocation rows for the full coefficient vector.
104    /// Non-structural survival models should pass these explicitly instead of
105    /// relying on observed derivative rows.
106    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
107    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
108    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
109}
110
111#[derive(Debug, Clone)]
112pub struct SurvivalBaselineOffsets<'a> {
113    /// Baseline target contribution to eta at entry time: eta_target(t_entry).
114    pub eta_entry: ArrayView1<'a, f64>,
115    /// Baseline target contribution to eta at exit time: eta_target(t_exit).
116    pub eta_exit: ArrayView1<'a, f64>,
117    /// Baseline target contribution to d eta / d t at exit: eta_target'(t_exit).
118    ///
119    /// This is used in event terms where log-hazard requires
120    /// log(d eta / d t). By threading this as an explicit offset, we get
121    /// "parametric default + spline deviation" behavior:
122    /// - strong penalty => deviation ~ 0 => model collapses to baseline target,
123    /// - weak penalty   => deviation can bend away where data supports it.
124    pub derivative_exit: ArrayView1<'a, f64>,
125}
126
127#[derive(Debug, Clone)]
128pub struct PenaltyBlock {
129    pub matrix: Array2<f64>,
130    pub lambda: f64,
131    pub range: Range<usize>,
132    /// Structural nullspace dimension of this penalty matrix.
133    /// Used for exact pseudo-logdet computation. 0 means full rank.
134    pub nullspace_dim: usize,
135}
136
137#[derive(Debug, Clone)]
138pub struct PenaltyBlocks {
139    pub blocks: Vec<PenaltyBlock>,
140}
141
142impl PenaltyBlocks {
143    pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144        Self { blocks }
145    }
146
147    pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148        let mut grad = Array1::zeros(beta.len());
149        for block in &self.blocks {
150            if block.lambda == 0.0 {
151                continue;
152            }
153            let b = beta.slice(ndarray::s![block.range.clone()]);
154            let g = block.matrix.dot(&b);
155            let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156            dst += &(block.lambda * g);
157        }
158        grad
159    }
160
161    pub fn hessian(&self, dim: usize) -> Array2<f64> {
162        let mut h = Array2::zeros((dim, dim));
163        self.addhessian_inplace(&mut h);
164        h
165    }
166
167    pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168        let mut value = 0.0;
169        for block in &self.blocks {
170            if block.lambda == 0.0 {
171                continue;
172            }
173            let b = beta.slice(ndarray::s![block.range.clone()]);
174            value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175        }
176        value
177    }
178
179    pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180        for block in &self.blocks {
181            if block.lambda == 0.0 {
182                continue;
183            }
184            let start = block.range.start;
185            let end = block.range.end;
186            h.slice_mut(ndarray::s![start..end, start..end])
187                .scaled_add(block.lambda, &block.matrix);
188        }
189    }
190}
191
/// Cut-off below which an entry age counts as "entered at the time origin".
///
/// A row whose entry age is at or below this value has no delayed-entry
/// interval: its `exp(η_entry)` cumulative-hazard term is dropped because
/// `H(0) = 0`. The Royston-Parmar baseline is `η(t) = log H(t)`, and since
/// `H(t) → 0` as `t → 0`, `log H` diverges at the origin. The small positive
/// floor lets a row that truly enters at time zero skip the entry contribution
/// rather than evaluate `log H` at a degenerate point. It is shared so that
/// every entry-detection site applies the same rule.
///
/// It is public so the fit-orchestration layer can classify a dataset as
/// genuinely left-truncated (`entry > threshold`) under the SAME origin
/// convention the likelihood engines use, and choose the
/// left-truncation-robust time anchor accordingly (issue #1790).
pub const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1.0e-8;
205
/// Fraction-to-the-boundary factor used by the cause-specific feasible-step
/// search.
///
/// If a Newton direction would push a row's derivative down to the
/// monotonicity floor, the step is limited to this fraction of the distance to
/// the boundary instead of landing exactly on it. Remaining strictly inside
/// the feasible region (the standard interior-point fraction-to-boundary rule)
/// keeps the next `1/deriv` / `deriv.ln()` evaluation clear of the singular
/// boundary, where curvature blows up.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 9.95e-1;
214
215#[derive(Debug, Clone)]
216pub struct CauseSpecificRoystonParmarBlock {
217    pub age_entry: Array1<f64>,
218    pub age_exit: Array1<f64>,
219    pub event_target: Array1<u8>,
220    pub sampleweight: Array1<f64>,
221    pub x_entry: Array2<f64>,
222    pub x_exit: Array2<f64>,
223    pub x_derivative: Array2<f64>,
224    pub offset_eta_entry: Array1<f64>,
225    pub offset_eta_exit: Array1<f64>,
226    pub offset_derivative_exit: Array1<f64>,
227    pub derivative_floor: f64,
228}
229
230/// Cause-specific competing-risks survival as a blockwise custom family.
231///
232/// Each cause is represented by one `ParameterBlockState`, so endpoint-specific
233/// coefficients, shared smoothing labels, and user-defined coefficient groups
234/// stay on the existing `CustomFamily` / `BlockwiseFitOptions` joint-fit path.
235#[derive(Debug, Clone)]
236pub struct CauseSpecificRoystonParmarFamily {
237    blocks: Vec<CauseSpecificRoystonParmarBlock>,
238}
239
240impl CauseSpecificRoystonParmarFamily {
241    pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
242        if blocks.is_empty() {
243            return Err(SurvivalError::InvalidInput {
244                reason: "cause-specific survival family requires at least one endpoint".to_string(),
245            }
246            .into());
247        }
248        for (idx, block) in blocks.iter().enumerate() {
249            validate_cause_specific_block(block).map_err(|err| {
250                SurvivalError::CauseSpecificBlock {
251                    block: idx + 1,
252                    source: Box::new(err),
253                }
254                .to_string()
255            })?;
256        }
257        Ok(Self { blocks })
258    }
259
260    pub fn cause_count(&self) -> usize {
261        self.blocks.len()
262    }
263}
264
265fn validate_cause_specific_block(
266    block: &CauseSpecificRoystonParmarBlock,
267) -> Result<(), SurvivalError> {
268    let n = block.event_target.len();
269    let p = block.x_exit.ncols();
270    if n == 0 || p == 0 {
271        bail_invalid_surv!("empty event vector or coefficient block");
272    }
273    if block.age_entry.len() != n
274        || block.age_exit.len() != n
275        || block.sampleweight.len() != n
276        || block.x_entry.nrows() != n
277        || block.x_exit.nrows() != n
278        || block.x_derivative.nrows() != n
279        || block.x_entry.ncols() != p
280        || block.x_derivative.ncols() != p
281        || block.offset_eta_entry.len() != n
282        || block.offset_eta_exit.len() != n
283        || block.offset_derivative_exit.len() != n
284    {
285        return Err(SurvivalError::CauseSpecificDimensionMismatch {
286            reason: "dimension mismatch".to_string(),
287        });
288    }
289    // A cause-specific block's `event_target` is the binary cause-k indicator
290    // produced by `cause_specific_event_indicator`; a label > 1 here means the
291    // caller passed raw multi-cause codes instead of projecting per cause. That
292    // is a valid finite label, not non-finite input, so it gets its own clear
293    // error rather than the misleading "non-finite input".
294    if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
295        return Err(SurvivalError::EventCodeInvalid {
296            reason: format!(
297                "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
298            ),
299        });
300    }
301    if block.age_entry.iter().any(|v| !v.is_finite())
302        || block.age_exit.iter().any(|v| !v.is_finite())
303        || block
304            .sampleweight
305            .iter()
306            .any(|v| !v.is_finite() || *v < 0.0)
307        || block.x_entry.iter().any(|v| !v.is_finite())
308        || block.x_exit.iter().any(|v| !v.is_finite())
309        || block.x_derivative.iter().any(|v| !v.is_finite())
310        || block.offset_eta_entry.iter().any(|v| !v.is_finite())
311        || block.offset_eta_exit.iter().any(|v| !v.is_finite())
312        || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
313        || !block.derivative_floor.is_finite()
314        || block.derivative_floor < 0.0
315    {
316        bail_invalid_surv!("non-finite input");
317    }
318    Ok(())
319}
320
321fn evaluate_cause_specific_block(
322    block: &CauseSpecificRoystonParmarBlock,
323    beta: &Array1<f64>,
324) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
325    let n = block.event_target.len();
326    let p = block.x_exit.ncols();
327    if beta.len() != p {
328        return Err(SurvivalError::CauseSpecificDimensionMismatch {
329            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
330        });
331    }
332    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
333    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
334    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
335    let mut log_likelihood = 0.0;
336    let mut w_exit = Array1::<f64>::zeros(n);
337    let mut w_entry = Array1::<f64>::zeros(n);
338    let mut w_event = Array1::<f64>::zeros(n);
339    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
340    let mut w_event_outer = Array1::<f64>::zeros(n);
341
342    for i in 0..n {
343        let weight = block.sampleweight[i];
344        if weight <= 0.0 {
345            continue;
346        }
347        if block.age_exit[i] < block.age_entry[i] {
348            bail_invalid_surv!("age_exit < age_entry at row {i}");
349        }
350        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
351        let h_exit = eta_exit[i].exp();
352        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
353        if !(h_exit.is_finite() && h_entry.is_finite()) {
354            return Err(SurvivalError::NumericalFailure {
355                reason: format!("non-finite cumulative hazard at row {i}"),
356            });
357        }
358        log_likelihood -= weight * (h_exit - h_entry);
359        w_exit[i] = weight * h_exit;
360        w_entry[i] = weight * h_entry;
361        if block.event_target[i] > 0 {
362            let deriv = derivative[i];
363            if !(deriv.is_finite() && deriv > 0.0) {
364                return Err(SurvivalError::NumericalFailure {
365                    reason: format!(
366                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
367                    ),
368                });
369            }
370            log_likelihood += weight * (eta_exit[i] + deriv.ln());
371            w_event[i] = weight;
372            w_event_inv_deriv[i] = weight / deriv;
373            w_event_outer[i] = weight / (deriv * deriv);
374        }
375    }
376
377    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
378    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
379    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
380    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
381    let gradient = -nll_gradient;
382
383    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
384    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
385    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
386    Ok((log_likelihood, gradient, hessian))
387}
388
389impl CustomFamily for CauseSpecificRoystonParmarFamily {
390    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
391    // flat-prior exact-Newton objective carries no Jeffreys term), so families
392    // that historically armed the term by default opt back in explicitly.
393    fn joint_jeffreys_term_required(&self) -> bool {
394        true
395    }
396
397    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
398        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
399            "cause-specific survival",
400            self.blocks.len(),
401            block_states.len(),
402        )?;
403        let mut log_likelihood = 0.0;
404        let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
405        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
406            let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
407            log_likelihood += ll;
408            blockworking_sets.push(BlockWorkingSet::ExactNewton {
409                gradient,
410                hessian: SymmetricMatrix::Dense(hessian),
411            });
412        }
413        Ok(FamilyEvaluation {
414            log_likelihood,
415            blockworking_sets,
416        })
417    }
418
419    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
420        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
421            "cause-specific survival",
422            self.blocks.len(),
423            block_states.len(),
424        )?;
425        let mut log_likelihood = 0.0;
426        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
427            let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
428            log_likelihood += ll;
429        }
430        Ok(log_likelihood)
431    }
432
433    fn likelihood_blocks_uncoupled(&self) -> bool {
434        true
435    }
436
437    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
438        true
439    }
440
441    fn output_channel_assignment(
442        &self,
443        specs: &[crate::custom_family::ParameterBlockSpec],
444    ) -> Option<Vec<usize>> {
445        if specs.len() != self.blocks.len() {
446            return Some((0..self.blocks.len()).collect());
447        }
448        Some((0..specs.len()).collect())
449    }
450
451    fn coefficient_hessian_cost(
452        &self,
453        specs: &[crate::custom_family::ParameterBlockSpec],
454    ) -> u64 {
455        crate::custom_family::default_coefficient_hessian_cost(specs)
456    }
457
458    fn block_linear_constraints(
459        &self,
460        _: &[ParameterBlockState],
461        block_idx: usize,
462        spec: &crate::custom_family::ParameterBlockSpec,
463    ) -> Result<Option<LinearInequalityConstraints>, String> {
464        let block = self.blocks.get(block_idx).ok_or_else(|| {
465            SurvivalError::CauseSpecificDimensionMismatch {
466                reason: format!(
467                    "cause-specific survival expected block index < {}, got {block_idx}",
468                    self.blocks.len()
469                ),
470            }
471            .to_string()
472        })?;
473        if block.x_derivative.ncols() != spec.design.ncols() {
474            return Err(SurvivalError::CauseSpecificDimensionMismatch {
475                reason: format!(
476                    "cause-specific survival derivative design has {} columns but block '{}' has {}",
477                    block.x_derivative.ncols(),
478                    spec.name,
479                    spec.design.ncols()
480                ),
481            }
482            .into());
483        }
484        let rhs = block
485            .offset_derivative_exit
486            .mapv(|offset| block.derivative_floor - offset);
487        Ok(Some(LinearInequalityConstraints {
488            a: block.x_derivative.clone(),
489            b: rhs,
490        }))
491    }
492
493    fn max_feasible_step_size(
494        &self,
495        block_states: &[ParameterBlockState],
496        block_idx: usize,
497        delta: &Array1<f64>,
498    ) -> Result<Option<f64>, String> {
499        let block = self.blocks.get(block_idx).ok_or_else(|| {
500            SurvivalError::CauseSpecificDimensionMismatch {
501                reason: format!(
502                    "cause-specific survival expected block index < {}, got {block_idx}",
503                    self.blocks.len()
504                ),
505            }
506            .to_string()
507        })?;
508        let state = block_states.get(block_idx).ok_or_else(|| {
509            SurvivalError::CauseSpecificDimensionMismatch {
510                reason: format!(
511                    "cause-specific survival expected {} block states, got {}",
512                    self.blocks.len(),
513                    block_states.len()
514                ),
515            }
516            .to_string()
517        })?;
518        if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
519            return Err(SurvivalError::CauseSpecificDimensionMismatch {
520                reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
521            }
522            .into());
523        }
524        let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
525        let derivative_delta = fast_av(&block.x_derivative, delta);
526        let mut alpha_max = 1.0_f64;
527        for i in 0..derivative.len() {
528            if block.sampleweight[i] <= 0.0 {
529                continue;
530            }
531            let current = derivative[i] - block.derivative_floor;
532            let slope = derivative_delta[i];
533            if slope < 0.0 {
534                if current <= 0.0 {
535                    return Ok(Some(0.0));
536                }
537                alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
538            }
539        }
540        Ok(Some(alpha_max.clamp(0.0, 1.0)))
541    }
542
543    fn exact_newton_hessian_directional_derivative(
544        &self,
545        block_states: &[ParameterBlockState],
546        block_idx: usize,
547        d_beta: &Array1<f64>,
548    ) -> Result<Option<Array2<f64>>, String> {
549        let block = self.blocks.get(block_idx).ok_or_else(|| {
550            SurvivalError::CauseSpecificDimensionMismatch {
551                reason: format!(
552                    "cause-specific survival expected block index < {}, got {block_idx}",
553                    self.blocks.len()
554                ),
555            }
556            .to_string()
557        })?;
558        let state = block_states.get(block_idx).ok_or_else(|| {
559            SurvivalError::CauseSpecificDimensionMismatch {
560                reason: format!(
561                    "cause-specific survival expected {} block states, got {}",
562                    self.blocks.len(),
563                    block_states.len()
564                ),
565            }
566            .to_string()
567        })?;
568        Ok(Some(cause_specific_hessian_directional_derivative(
569            block,
570            &state.beta,
571            d_beta,
572        )?))
573    }
574
575    fn exact_newton_hessian_second_directional_derivative(
576        &self,
577        block_states: &[ParameterBlockState],
578        block_idx: usize,
579        d_beta_u: &Array1<f64>,
580        d_beta_v: &Array1<f64>,
581    ) -> Result<Option<Array2<f64>>, String> {
582        let block = self.blocks.get(block_idx).ok_or_else(|| {
583            SurvivalError::CauseSpecificDimensionMismatch {
584                reason: format!(
585                    "cause-specific survival expected block index < {}, got {block_idx}",
586                    self.blocks.len()
587                ),
588            }
589            .to_string()
590        })?;
591        let state = block_states.get(block_idx).ok_or_else(|| {
592            SurvivalError::CauseSpecificDimensionMismatch {
593                reason: format!(
594                    "cause-specific survival expected {} block states, got {}",
595                    self.blocks.len(),
596                    block_states.len()
597                ),
598            }
599            .to_string()
600        })?;
601        Ok(Some(cause_specific_hessian_second_directional_derivative(
602            block,
603            &state.beta,
604            d_beta_u,
605            d_beta_v,
606        )?))
607    }
608}
609
610fn cause_specific_hessian_directional_derivative(
611    block: &CauseSpecificRoystonParmarBlock,
612    beta: &Array1<f64>,
613    d_beta: &Array1<f64>,
614) -> Result<Array2<f64>, SurvivalError> {
615    let p = block.x_exit.ncols();
616    if beta.len() != p || d_beta.len() != p {
617        return Err(SurvivalError::CauseSpecificDimensionMismatch {
618            reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
619        });
620    }
621    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
622    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
623    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
624    let d_eta_entry = fast_av(&block.x_entry, d_beta);
625    let d_eta_exit = fast_av(&block.x_exit, d_beta);
626    let d_derivative = fast_av(&block.x_derivative, d_beta);
627    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
628    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
629    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
630
631    for i in 0..block.event_target.len() {
632        let weight = block.sampleweight[i];
633        if weight <= 0.0 {
634            continue;
635        }
636        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
637        w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
638        if has_entry {
639            w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
640        }
641        if block.event_target[i] > 0 {
642            let deriv = derivative[i];
643            if !(deriv.is_finite() && deriv > 0.0) {
644                return Err(SurvivalError::NumericalFailure {
645                    reason: format!(
646                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
647                    ),
648                });
649            }
650            w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
651        }
652    }
653
654    let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
655    d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
656    d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
657    Ok(d_hessian)
658}
659
660fn cause_specific_hessian_second_directional_derivative(
661    block: &CauseSpecificRoystonParmarBlock,
662    beta: &Array1<f64>,
663    d_beta_u: &Array1<f64>,
664    d_beta_v: &Array1<f64>,
665) -> Result<Array2<f64>, SurvivalError> {
666    let p = block.x_exit.ncols();
667    if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
668        return Err(SurvivalError::CauseSpecificDimensionMismatch {
669            reason: "cause-specific survival second Hessian derivative dimension mismatch"
670                .to_string(),
671        });
672    }
673    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
674    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
675    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
676    let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
677    let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
678    let u_derivative = fast_av(&block.x_derivative, d_beta_u);
679    let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
680    let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
681    let v_derivative = fast_av(&block.x_derivative, d_beta_v);
682    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
683    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
684    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
685
686    for i in 0..block.event_target.len() {
687        let weight = block.sampleweight[i];
688        if weight <= 0.0 {
689            continue;
690        }
691        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
692        w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
693        if has_entry {
694            w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
695        }
696        if block.event_target[i] > 0 {
697            let deriv = derivative[i];
698            if !(deriv.is_finite() && deriv > 0.0) {
699                return Err(SurvivalError::NumericalFailure {
700                    reason: format!(
701                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
702                    ),
703                });
704            }
705            w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
706        }
707    }
708
709    let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
710    d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
711    d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
712    Ok(d2_hessian)
713}
714
715pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
716    const INTEGER_TOL: f64 = 1e-8;
717    const MAX_AUTO_CAUSES: u8 = 32;
718    if !value.is_finite() {
719        return Err(SurvivalError::EventCodeInvalid {
720            reason: format!(
721                "survival event value at row {} is non-finite",
722                row_index + 1
723            ),
724        }
725        .into());
726    }
727    if value < 0.0 {
728        return Err(SurvivalError::EventCodeInvalid {
729            reason: format!(
730                "survival event value at row {} is negative: {value}",
731                row_index + 1
732            ),
733        }
734        .into());
735    }
736    let rounded = value.round();
737    if (value - rounded).abs() > INTEGER_TOL {
738        return Err(SurvivalError::EventCodeInvalid {
739            reason: format!(
740                "survival event value at row {} must be an integer code with 0=censored, got {value}",
741                row_index + 1
742            ),
743        }
744        .into());
745    }
746    if rounded > f64::from(MAX_AUTO_CAUSES) {
747        return Err(SurvivalError::EventCodeInvalid {
748            reason: format!(
749                "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
750                row_index + 1
751            ),
752        }
753        .into());
754    }
755    Ok(rounded as u8)
756}
757
758pub fn cause_count_from_event_codes(
759    event_codes: ArrayView1<'_, u8>,
760) -> Result<usize, SurvivalError> {
761    let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
762    if max_code == 0 {
763        return Ok(1);
764    }
765
766    let mut present = vec![false; max_code + 1];
767    for code in event_codes.iter().copied() {
768        present[usize::from(code)] = true;
769    }
770    if (1..=max_code).any(|code| !present[code]) {
771        let actual = present
772            .iter()
773            .enumerate()
774            .skip(1)
775            .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
776            .collect::<Vec<_>>()
777            .join(", ");
778        return Err(SurvivalError::EventCodeInvalid {
779            reason: format!(
780                "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
781            ),
782        });
783    }
784
785    Ok(max_code)
786}
787
788/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
789/// onto the binary `{0, 1}` *any-event* indicator the pooled single-hazard
790/// baseline engine consumes.
791///
792/// The shared Royston-Parmar baseline working model fits one hazard across all
793/// causes; every observed event (regardless of cause) informs that baseline, so
794/// the indicator is `1` exactly when *any* cause occurred. This is one of the
795/// two cause-aware projections of the raw code vector; the other is
796/// [`cause_specific_event_indicator`]. Centralizing both here keeps the single
797/// source of truth for "how multi-cause labels become a single-hazard binary
798/// contract", so no construction path open-codes a fragile `mapv` and then
799/// trips the binary `event_target > 1` guard on the raw labels.
800pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
801    event_codes.mapv(|label| u8::from(label > 0))
802}
803
804/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
805/// onto the binary `{0, 1}` indicator for the cause-specific Royston-Parmar
806/// block of cause `cause` (1-based).
807///
808/// Within cause `cause`'s block the event of interest is `event == cause`; every
809/// competing cause is treated as censoring (indicator `0`), which is exactly the
810/// cause-specific hazard likelihood. Like [`pooled_any_event_indicator`], this
811/// yields a binary contract that satisfies the single-hazard `event_target` guard
812/// — the raw multi-cause labels are never handed to a binary engine.
813pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
814    let cause_code = cause as u8;
815    event_codes.mapv(|observed| u8::from(observed == cause_code))
816}
817
818fn compress_positive_collinear_constraints(
819    a: &Array2<f64>,
820    b: &Array1<f64>,
821) -> LinearInequalityConstraints {
822    const SCALE_TOL: f64 = 1e-14;
823    const KEY_TOL: f64 = 1e-8;
824
825    let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
826    let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
827
828    for i in 0..a.nrows() {
829        let row = a.row(i);
830        let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
831        if !scale.is_finite() || scale <= SCALE_TOL {
832            if b[i] > 0.0 {
833                fallbackrows.push((row.to_vec(), b[i]));
834            }
835            continue;
836        }
837
838        let normalizedrow: Vec<f64> = row
839            .iter()
840            .map(|&v| {
841                let scaled = v / scale;
842                if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
843            })
844            .collect();
845        let normalized_rhs = b[i] / scale;
846        let key: Vec<i64> = normalizedrow
847            .iter()
848            .map(|&v| (v / KEY_TOL).round() as i64)
849            .collect();
850
851        match grouped.get_mut(&key) {
852            Some((_, rhs_max)) => {
853                if normalized_rhs > *rhs_max {
854                    *rhs_max = normalized_rhs;
855                }
856            }
857            None => {
858                grouped.insert(key, (normalizedrow, normalized_rhs));
859            }
860        }
861    }
862
863    let nrows = grouped.len() + fallbackrows.len();
864    let n_cols = a.ncols();
865    let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
866    let mut b_out = Array1::<f64>::zeros(nrows);
867
868    let mut outrow = 0usize;
869    for (_, (row, rhs)) in grouped {
870        for (j, value) in row.into_iter().enumerate() {
871            a_out[[outrow, j]] = value;
872        }
873        b_out[outrow] = rhs;
874        outrow += 1;
875    }
876    for (row, rhs) in fallbackrows {
877        for (j, value) in row.into_iter().enumerate() {
878            a_out[[outrow, j]] = value;
879        }
880        b_out[outrow] = rhs;
881        outrow += 1;
882    }
883
884    LinearInequalityConstraints { a: a_out, b: b_out }
885}
886
/// Monotonicity settings for the survival working model.
#[derive(Debug, Clone, Copy, Default)]
pub struct SurvivalMonotonicityPenalty {
    /// Lower bound applied to the derivative guard. It is read by
    /// `WorkingModelSurvival::derivative_guard`, which clamps it at zero and
    /// ignores it entirely for structurally monotone models.
    pub tolerance: f64,
}
891
/// Storage layout of the survival design matrices.
#[derive(Debug, Clone)]
enum SurvivalDesign {
    /// Three full design matrices; the coefficient count is the column count
    /// of `x_exit`.
    Flat {
        x_entry: Array2<f64>,
        x_exit: Array2<f64>,
        x_derivative: Array2<f64>,
    },
    /// Time-basis matrices stored separately from one covariate matrix that
    /// is shared by the entry and exit evaluations. The coefficient vector is
    /// laid out as the time columns followed by the covariate columns.
    TimeCovariateShared {
        time_entry: Array2<f64>,
        time_exit: Array2<f64>,
        time_derivative: Array2<f64>,
        covariates: Array2<f64>,
    },
}
906
907impl SurvivalDesign {
908    fn p_total(&self) -> usize {
909        match self {
910            Self::Flat { x_exit, .. } => x_exit.ncols(),
911            Self::TimeCovariateShared {
912                time_exit,
913                covariates,
914                ..
915            } => time_exit.ncols() + covariates.ncols(),
916        }
917    }
918
919    fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
920        match self {
921            Self::Flat { .. } => time_mat.dot(beta),
922            Self::TimeCovariateShared { covariates, .. } => {
923                let p_time = time_mat.ncols();
924                let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
925                if covariates.ncols() > 0 {
926                    out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
927                }
928                out
929            }
930        }
931    }
932
933    fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
934        match self {
935            Self::Flat { .. } => {
936                for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
937                    *dst = src;
938                }
939            }
940            Self::TimeCovariateShared { covariates, .. } => {
941                let p_time = time_mat.ncols();
942                for j in 0..p_time {
943                    out[j] = time_mat[[i, j]];
944                }
945                for j in 0..covariates.ncols() {
946                    out[p_time + j] = covariates[[i, j]];
947                }
948            }
949        }
950    }
951}
952
/// Pre-allocated workspace buffers for `update_state` to avoid per-iteration allocations.
///
/// All five buffers are created with the same length by `new` and zero-filled
/// by `reset`.
// NOTE(review): the field names suggest per-row event weights and exit/entry
// Hessian weights; confirm against `update_state`, which is not visible here.
#[derive(Debug, Clone)]
struct SurvivalWorkspace {
    w_event: Array1<f64>,
    w_event_inv_deriv: Array1<f64>,
    w_event_outer: Array1<f64>,
    w_hess_exit: Array1<f64>,
    w_hess_entry: Array1<f64>,
}
962
963impl SurvivalWorkspace {
964    fn new(n: usize) -> Self {
965        Self {
966            w_event: Array1::zeros(n),
967            w_event_inv_deriv: Array1::zeros(n),
968            w_event_outer: Array1::zeros(n),
969            w_hess_exit: Array1::zeros(n),
970            w_hess_entry: Array1::zeros(n),
971        }
972    }
973
974    fn reset(&mut self, n: usize) {
975        if self.w_event.len() != n {
976            *self = Self::new(n);
977        } else {
978            self.w_event.fill(0.0);
979            self.w_event_inv_deriv.fill(0.0);
980            self.w_event_outer.fill(0.0);
981            self.w_hess_exit.fill(0.0);
982            self.w_hess_entry.fill(0.0);
983        }
984    }
985}
986
/// Per-observation gradients of the unpenalized survival NLL with respect
/// to each additive offset channel, at a given β. See
/// [`WorkingModelSurvival::offset_channel_residuals`] for the algebra.
///
/// Contract: all four arrays have length `n` = number of observations.
/// Rows with non-positive sampleweight are 0 in every channel. The
/// `derivative` channel is 0 in all non-event rows. The `right` channel is
/// the interval upper-bound (`R`) η-offset sensitivity and is exactly 0 for
/// every non-interval-censored model and for every non-interval row of the
/// latent interval model (only the dedicated `SurvInterval(L, R, event)`
/// latent fit populates it); the baseline-θ chain rule contracts it against
/// the `age_right`-evaluated η-partial.
#[derive(Clone, Debug)]
pub struct OffsetChannelResiduals {
    /// ∂NLL/∂o_X: w·(exp(η_exit) − δ) per row.
    pub exit: Array1<f64>,
    /// ∂NLL/∂o_E: −w·exp(η_entry) if the row has a positive entry interval, else 0.
    pub entry: Array1<f64>,
    /// ∂NLL/∂o_D: −w·δ / s (event rows only).
    pub derivative: Array1<f64>,
    /// ∂NLL/∂o_R: interval upper-bound (`R`) η-offset sensitivity,
    /// `−w·∂(log-lik)/∂q_right`. Nonzero only for interval-censored latent
    /// rows; exactly 0 for every other row and model.
    pub right: Array1<f64>,
}
1012
/// Per-observation Hessians of the unpenalized survival NLL with respect
/// to additive offset channels in `(entry, exit, derivative)` order.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    /// One 3×3 matrix per observation; both indices follow the
    /// `(entry, exit, derivative)` channel order.
    pub rows: Vec<[[f64; 3]; 3]>,
}
1019
/// Working model for the single-hazard survival engine: per-row data, design
/// matrices, offsets, penalties and monotonicity bookkeeping.
#[derive(Debug)]
pub struct WorkingModelSurvival {
    // Entry and exit ages per row; construction requires finite values with
    // entry >= 0 and exit > 0.
    age_entry: Array1<f64>,
    age_exit: Array1<f64>,
    // Per-row flag: `age_entry <= ENTRY_AT_ORIGIN_THRESHOLD`.
    entry_at_origin: Array1<bool>,
    // Binary {0, 1} event indicator (codes > 1 are rejected at construction).
    event_target: Array1<u8>,
    // Finite, non-negative row weights.
    sampleweight: Array1<f64>,
    design: SurvivalDesign,
    // Additive offsets per row; all zeros when the caller supplies none.
    offset_eta_entry: Array1<f64>,
    offset_eta_exit: Array1<f64>,
    offset_derivative_exit: Array1<f64>,
    penalties: PenaltyBlocks,
    monotonicity: SurvivalMonotonicityPenalty,
    // `finish_construction` initializes these to `false` / `0`.
    structurally_monotonic: bool,
    structural_time_columns: usize,
    // Explicit monotonicity constraint rows and offsets; validation accepts
    // them only as a pair (both present or both absent).
    monotonicity_constraint_rows: Option<Array2<f64>>,
    monotonicity_constraint_offsets: Option<Array1<f64>>,
    // Scratch buffers behind a mutex so they can be mutated through `&self`.
    workspace: std::sync::Mutex<SurvivalWorkspace>,
}
1039
1040impl Clone for WorkingModelSurvival {
1041    fn clone(&self) -> Self {
1042        let workspace = self.workspace.lock().unwrap().clone();
1043        Self {
1044            age_entry: self.age_entry.clone(),
1045            age_exit: self.age_exit.clone(),
1046            entry_at_origin: self.entry_at_origin.clone(),
1047            event_target: self.event_target.clone(),
1048            sampleweight: self.sampleweight.clone(),
1049            design: self.design.clone(),
1050            offset_eta_entry: self.offset_eta_entry.clone(),
1051            offset_eta_exit: self.offset_eta_exit.clone(),
1052            offset_derivative_exit: self.offset_derivative_exit.clone(),
1053            penalties: self.penalties.clone(),
1054            monotonicity: self.monotonicity,
1055            structurally_monotonic: self.structurally_monotonic,
1056            structural_time_columns: self.structural_time_columns,
1057            monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1058            monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1059            workspace: std::sync::Mutex::new(workspace),
1060        }
1061    }
1062}
1063
1064impl WorkingModelSurvival {
    /// Natural logarithm of `f64::MAX`; a log-magnitude above this would
    /// overflow when exponentiated.
    const LOG_F64_MAX: f64 = 709.782712893384;
1066
1067    #[inline]
1068    fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
1069        if base == 0.0 {
1070            return Ok(0.0);
1071        }
1072        let log_abs = log_scale + base.abs().ln();
1073        if !log_abs.is_finite() {
1074            crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
1075        }
1076        if log_abs > Self::LOG_F64_MAX {
1077            crate::bail_invalid_estim!(
1078                "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
1079            );
1080        }
1081        Ok(base.signum() * log_abs.exp())
1082    }
1083
    /// Number of coefficients, as reported by the design (`p_total`).
    fn coefficient_dim(&self) -> usize {
        self.design.p_total()
    }
1087
    /// Number of observations, taken from the length of `sampleweight`.
    fn nrows(&self) -> usize {
        self.sampleweight.len()
    }
1091
1092    fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1093        let time_mat = match &self.design {
1094            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1095            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1096        };
1097        self.design.design_dot(time_mat, beta)
1098    }
1099
1100    fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1101        let time_mat = match &self.design {
1102            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1103            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1104        };
1105        self.design.design_dot(time_mat, beta)
1106    }
1107
1108    fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1109        match &self.design {
1110            SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1111            SurvivalDesign::TimeCovariateShared {
1112                time_derivative, ..
1113            } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1114        }
1115    }
1116
1117    fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1118        let time_mat = match &self.design {
1119            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1120            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1121        };
1122        self.design.fill_row(time_mat, i, out);
1123    }
1124
1125    fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1126        let time_mat = match &self.design {
1127            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1128            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1129        };
1130        self.design.fill_row(time_mat, i, out);
1131    }
1132
1133    fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1134        match &self.design {
1135            SurvivalDesign::Flat { x_derivative, .. } => {
1136                for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1137                    *dst = src;
1138                }
1139            }
1140            SurvivalDesign::TimeCovariateShared {
1141                time_derivative, ..
1142            } => {
1143                let p_time = time_derivative.ncols();
1144                for j in 0..p_time {
1145                    out[j] = time_derivative[[i, j]];
1146                }
1147                for dst in out.iter_mut().skip(p_time) {
1148                    *dst = 0.0;
1149                }
1150            }
1151        }
1152    }
1153
1154    fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1155        match &self.design {
1156            SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1157            SurvivalDesign::TimeCovariateShared {
1158                time_derivative,
1159                covariates,
1160                ..
1161            } => {
1162                let p_time = time_derivative.ncols();
1163                let p_cov = covariates.ncols();
1164                let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1165                let time_block = fast_xt_diag_x(time_derivative, weights);
1166                out.slice_mut(ndarray::s![..p_time, ..p_time])
1167                    .assign(&time_block);
1168                out
1169            }
1170        }
1171    }
1172
1173    /// Compute the full p×p Hessian contribution for the interval terms:
1174    ///   H = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1175    /// using faer-accelerated BLAS on the stored design matrix blocks.
1176    fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1177        match &self.design {
1178            SurvivalDesign::Flat {
1179                x_entry, x_exit, ..
1180            } => {
1181                let mut h = fast_xt_diag_x(x_exit, w_exit);
1182                h -= &fast_xt_diag_x(x_entry, w_entry);
1183                h
1184            }
1185            SurvivalDesign::TimeCovariateShared {
1186                time_entry,
1187                time_exit,
1188                covariates,
1189                ..
1190            } => {
1191                let p_time = time_exit.ncols();
1192                let p_cov = covariates.ncols();
1193                let p = p_time + p_cov;
1194                let mut h = Array2::<f64>::zeros((p, p));
1195                // time-time block: T_exit^T W_exit T_exit - T_entry^T W_entry T_entry
1196                let tt = {
1197                    let mut block = fast_xt_diag_x(time_exit, w_exit);
1198                    block -= &fast_xt_diag_x(time_entry, w_entry);
1199                    block
1200                };
1201                h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1202                if p_cov > 0 {
1203                    // time-cov block: T_exit^T W_exit C - T_entry^T W_entry C
1204                    let tc = {
1205                        let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1206                        block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1207                        block
1208                    };
1209                    h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1210                    h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1211                    // cov-cov block: C^T (W_exit - W_entry) C
1212                    let w_diff = w_exit - w_entry;
1213                    let cc = fast_xt_diag_x(covariates, &w_diff);
1214                    h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1215                }
1216                h
1217            }
1218        }
1219    }
1220
1221    fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1222        const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1223        if !self.structurally_monotonic {
1224            return None;
1225        }
1226        if deriv >= 1e-12 {
1227            return Some(deriv);
1228        }
1229        if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1230            return Some(1e-12);
1231        }
1232        None
1233    }
1234
1235    fn validate_penalties(
1236        penalties: &PenaltyBlocks,
1237        coefficient_dim: usize,
1238    ) -> Result<(), SurvivalError> {
1239        for block in &penalties.blocks {
1240            if !block.lambda.is_finite() || block.lambda < 0.0 {
1241                return Err(SurvivalError::NonFiniteInput);
1242            }
1243            if block.range.start > block.range.end || block.range.end > coefficient_dim {
1244                return Err(SurvivalError::DimensionMismatch);
1245            }
1246            let block_dim = block.range.end - block.range.start;
1247            if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1248                return Err(SurvivalError::DimensionMismatch);
1249            }
1250            if block.matrix.iter().any(|v| !v.is_finite()) {
1251                return Err(SurvivalError::NonFiniteInput);
1252            }
1253        }
1254        Ok(())
1255    }
1256
1257    fn derivative_guard(&self) -> f64 {
1258        if self.structurally_monotonic {
1259            // I-spline basis is monotone by construction when coefficients ≥ 0.
1260            // A derivative of zero (flat hazard) is valid, so the guard only
1261            // rejects genuinely negative derivatives from floating-point noise.
1262            return 0.0;
1263        }
1264        self.monotonicity.tolerance.max(0.0)
1265    }
1266
1267    fn derivative_guard_numerical(&self) -> f64 {
1268        let derivative_guard = self.derivative_guard();
1269        if derivative_guard <= 0.0 {
1270            // For structural monotonicity (guard = 0), tiny negative derivs are
1271            // tolerated because `stabilized_structural_derivative` lifts the
1272            // value back to a small positive floor before any `ln`/`1/deriv`
1273            // use. For *non-structural* monotonicity with tolerance == 0 the
1274            // raw derivative flows straight through into the event-row
1275            // `deriv.ln()` and `1.0 / deriv`, so any non-positive value would
1276            // produce NaN / huge negative weights. Keep the slack only when
1277            // the structural stabilizer is active.
1278            if self.structurally_monotonic {
1279                -1e-10
1280            } else {
1281                1e-12
1282            }
1283        } else {
1284            (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1285        }
1286    }
1287
1288    fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1289        let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1290        1e-10 * scale
1291    }
1292
1293    fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1294        if !self.structurally_monotonic {
1295            return None;
1296        }
1297        let p = self.coefficient_dim();
1298        let time_columns = self.structural_time_columns.min(p);
1299        if time_columns == 0 {
1300            return None;
1301        }
1302        const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1303        let mut active_columns = vec![false; time_columns];
1304        let mut derivative_row = vec![0.0_f64; p];
1305        for i in 0..self.nrows() {
1306            if self.sampleweight[i] <= 0.0 {
1307                continue;
1308            }
1309            self.fill_derivative_row(i, &mut derivative_row);
1310            for j in 0..time_columns {
1311                if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1312                    active_columns[j] = true;
1313                }
1314            }
1315        }
1316        if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1317            for i in 0..rows.nrows() {
1318                for j in 0..time_columns {
1319                    if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1320                        active_columns[j] = true;
1321                    }
1322                }
1323            }
1324        }
1325        let active_columns: Vec<usize> = active_columns
1326            .into_iter()
1327            .enumerate()
1328            .filter_map(|(j, active)| active.then_some(j))
1329            .collect();
1330        if active_columns.is_empty() {
1331            return None;
1332        }
1333        let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1334        let b = Array1::<f64>::zeros(active_columns.len());
1335        for (row, &col) in active_columns.iter().enumerate() {
1336            a[[row, col]] = 1.0;
1337        }
1338        Some(LinearInequalityConstraints { a, b })
1339    }
1340
1341    pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1342        let p = self.coefficient_dim();
1343        const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1344        if p == 0 {
1345            return None;
1346        }
1347        if self.structurally_monotonic {
1348            return self.structural_time_coefficient_constraints();
1349        }
1350        if let (Some(rows), Some(offsets)) = (
1351            self.monotonicity_constraint_rows.as_ref(),
1352            self.monotonicity_constraint_offsets.as_ref(),
1353        ) {
1354            let activerows: Vec<usize> = (0..rows.nrows())
1355                .filter(|&i| {
1356                    rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1357                        > DERIVATIVE_ROW_NORM_TOL
1358                })
1359                .collect();
1360            if activerows.is_empty() {
1361                return None;
1362            }
1363            let mut a = Array2::<f64>::zeros((activerows.len(), p));
1364            let mut b = Array1::<f64>::zeros(activerows.len());
1365            for (r, &i) in activerows.iter().enumerate() {
1366                a.row_mut(r).assign(&rows.row(i));
1367                b[r] = self.derivative_guard() - offsets[i];
1368            }
1369            return Some(compress_positive_collinear_constraints(&a, &b));
1370        }
1371        None
1372    }
1373
    /// Build a working model without baseline offsets.
    ///
    /// Delegates to [`Self::from_engine_inputswith_offsets`] with `None` for
    /// the offsets and returns its result unchanged.
    pub fn from_engine_inputs(
        inputs: SurvivalEngineInputs<'_>,
        penalties: PenaltyBlocks,
        monotonicity: SurvivalMonotonicityPenalty,
        spec: SurvivalSpec,
    ) -> Result<Self, SurvivalError> {
        Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
    }
1382
1383    fn validate_offsets(
1384        offsets: Option<SurvivalBaselineOffsets<'_>>,
1385        n: usize,
1386    ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1387        if let Some(off) = offsets {
1388            if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1389            {
1390                return Err(SurvivalError::DimensionMismatch);
1391            }
1392            if off.eta_entry.iter().any(|v| !v.is_finite())
1393                || off.eta_exit.iter().any(|v| !v.is_finite())
1394                || off.derivative_exit.iter().any(|v| !v.is_finite())
1395            {
1396                return Err(SurvivalError::NonFiniteInput);
1397            }
1398            Ok((
1399                off.eta_entry.to_owned(),
1400                off.eta_exit.to_owned(),
1401                off.derivative_exit.to_owned(),
1402            ))
1403        } else {
1404            Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1405        }
1406    }
1407
1408    fn validate_common_inputs(
1409        age_entry: &ArrayView1<f64>,
1410        age_exit: &ArrayView1<f64>,
1411        event_target: &ArrayView1<u8>,
1412        event_competing: &ArrayView1<u8>,
1413        sampleweight: &ArrayView1<f64>,
1414    ) -> Result<(), SurvivalError> {
1415        if age_entry.iter().any(|v| !v.is_finite())
1416            || age_exit.iter().any(|v| !v.is_finite())
1417            || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1418        {
1419            return Err(SurvivalError::NonFiniteInput);
1420        }
1421        // The single-hazard engine's `event_target` contract is binary {0, 1}.
1422        // A code > 1 is a *valid finite multi-cause label* that simply must be
1423        // projected first (any-event for the pooled baseline, cause-specific for
1424        // each block); it is NOT a non-finite input. Report it as such so the
1425        // failure is actionable and never surfaces as the misleading "inputs
1426        // contain non-finite values".
1427        if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1428            return Err(SurvivalError::EventCodeInvalid {
1429                reason: format!(
1430                    "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1431                ),
1432            });
1433        }
1434        if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1435            return Err(SurvivalError::EventCodeInvalid {
1436                reason: format!(
1437                    "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1438                ),
1439            });
1440        }
1441        if event_target
1442            .iter()
1443            .zip(event_competing.iter())
1444            .any(|(&target, &competing)| target > 0 && competing > 0)
1445        {
1446            return Err(SurvivalError::EventCodeInvalid {
1447                reason: "a row cannot be simultaneously a target event and a competing event"
1448                    .to_string(),
1449            });
1450        }
1451        // The "must have at least one target event" requirement is a
1452        // *fittability* check, not a structural one: with all rows censored the
1453        // likelihood has no event score, so any subsequent fit cannot identify
1454        // the hazard and the optimizer spins on a flat landscape.  But the
1455        // structural integrity of the engine — its derivative-guard rejection
1456        // of decreasing cumulative hazards, its monotonicity-collocation
1457        // bookkeeping, its update_state numerics — is well-defined on
1458        // all-censored inputs, and unit tests legitimately exercise those
1459        // structural paths on censored fixtures.  Move the fittability check
1460        // out of construction; production fit dispatchers (e.g.
1461        // `solver::fit_orchestration::materialize_survival`) enforce it on the
1462        // single chokepoint that actually starts an optimization, where
1463        // the failure mode it guards against is reachable.
1464        if age_entry
1465            .iter()
1466            .zip(age_exit.iter())
1467            .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1468        {
1469            return Err(SurvivalError::NonFiniteInput);
1470        }
1471        Ok::<(), _>(())
1472    }
1473
1474    fn validate_monotonicity_constraints(
1475        rows: Option<ArrayView2<'_, f64>>,
1476        offsets: Option<ArrayView1<'_, f64>>,
1477        coefficient_dim: usize,
1478    ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1479        match (rows, offsets) {
1480            (None, None) => Ok((None, None)),
1481            (Some(rows), Some(offsets)) => {
1482                if rows.ncols() != coefficient_dim
1483                    || rows.nrows() != offsets.len()
1484                    || rows.iter().any(|v| !v.is_finite())
1485                    || offsets.iter().any(|v| !v.is_finite())
1486                {
1487                    return Err(SurvivalError::DimensionMismatch);
1488                }
1489                Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1490            }
1491            _ => Err(SurvivalError::DimensionMismatch),
1492        }
1493    }
1494
1495    fn finish_construction(
1496        age_entry: ArrayView1<f64>,
1497        age_exit: ArrayView1<f64>,
1498        event_target: ArrayView1<u8>,
1499        sampleweight: ArrayView1<f64>,
1500        design: SurvivalDesign,
1501        offset_eta_entry: Array1<f64>,
1502        offset_eta_exit: Array1<f64>,
1503        offset_derivative_exit: Array1<f64>,
1504        penalties: PenaltyBlocks,
1505        monotonicity: SurvivalMonotonicityPenalty,
1506        monotonicity_constraint_rows: Option<Array2<f64>>,
1507        monotonicity_constraint_offsets: Option<Array1<f64>>,
1508    ) -> Self {
1509        let n = age_entry.len();
1510        Self {
1511            age_entry: age_entry.to_owned(),
1512            age_exit: age_exit.to_owned(),
1513            entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1514            event_target: event_target.to_owned(),
1515            sampleweight: sampleweight.to_owned(),
1516            design,
1517            offset_eta_entry,
1518            offset_eta_exit,
1519            offset_derivative_exit,
1520            penalties,
1521            monotonicity,
1522            structurally_monotonic: false,
1523            structural_time_columns: 0,
1524            monotonicity_constraint_rows,
1525            monotonicity_constraint_offsets,
1526            workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1527        }
1528    }
1529
1530    pub fn from_engine_inputswith_offsets(
1531        inputs: SurvivalEngineInputs<'_>,
1532        offsets: Option<SurvivalBaselineOffsets<'_>>,
1533        penalties: PenaltyBlocks,
1534        monotonicity: SurvivalMonotonicityPenalty,
1535        spec: SurvivalSpec,
1536    ) -> Result<Self, SurvivalError> {
1537        if spec == SurvivalSpec::Crude {
1538            return Err(SurvivalError::UnsupportedSpec("crude"));
1539        }
1540        let n = inputs.age_entry.len();
1541        let p = inputs.x_entry.ncols();
1542        if inputs.age_exit.len() != n
1543            || inputs.event_target.len() != n
1544            || inputs.event_competing.len() != n
1545            || inputs.sampleweight.len() != n
1546            || inputs.x_entry.nrows() != n
1547            || inputs.x_exit.nrows() != n
1548            || inputs.x_derivative.nrows() != n
1549            || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1550            || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1551        {
1552            return Err(SurvivalError::DimensionMismatch);
1553        }
1554        Self::validate_penalties(&penalties, p)?;
1555        Self::validate_common_inputs(
1556            &inputs.age_entry,
1557            &inputs.age_exit,
1558            &inputs.event_target,
1559            &inputs.event_competing,
1560            &inputs.sampleweight,
1561        )?;
1562        if inputs.x_entry.iter().any(|v| !v.is_finite())
1563            || inputs.x_exit.iter().any(|v| !v.is_finite())
1564            || inputs.x_derivative.iter().any(|v| !v.is_finite())
1565        {
1566            return Err(SurvivalError::NonFiniteInput);
1567        }
1568        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1569            Self::validate_offsets(offsets, n)?;
1570        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1571            Self::validate_monotonicity_constraints(
1572                inputs.monotonicity_constraint_rows,
1573                inputs.monotonicity_constraint_offsets,
1574                p,
1575            )?;
1576
1577        Ok(Self::finish_construction(
1578            inputs.age_entry,
1579            inputs.age_exit,
1580            inputs.event_target,
1581            inputs.sampleweight,
1582            SurvivalDesign::Flat {
1583                x_entry: inputs.x_entry.to_owned(),
1584                x_exit: inputs.x_exit.to_owned(),
1585                x_derivative: inputs.x_derivative.to_owned(),
1586            },
1587            offset_eta_entry,
1588            offset_eta_exit,
1589            offset_derivative_exit,
1590            penalties,
1591            monotonicity,
1592            monotonicity_constraint_rows,
1593            monotonicity_constraint_offsets,
1594        ))
1595    }
1596
1597    pub fn from_time_covariate_inputswith_offsets(
1598        inputs: SurvivalTimeCovarInputs<'_>,
1599        offsets: Option<SurvivalBaselineOffsets<'_>>,
1600        penalties: PenaltyBlocks,
1601        monotonicity: SurvivalMonotonicityPenalty,
1602        spec: SurvivalSpec,
1603    ) -> Result<Self, SurvivalError> {
1604        if spec == SurvivalSpec::Crude {
1605            return Err(SurvivalError::UnsupportedSpec("crude"));
1606        }
1607        let n = inputs.age_entry.len();
1608        let p_time = inputs.time_entry.ncols();
1609        let p_cov = inputs.covariates.ncols();
1610        let p = p_time + p_cov;
1611        if inputs.age_exit.len() != n
1612            || inputs.event_target.len() != n
1613            || inputs.event_competing.len() != n
1614            || inputs.sampleweight.len() != n
1615            || inputs.time_entry.nrows() != n
1616            || inputs.time_exit.nrows() != n
1617            || inputs.time_derivative.nrows() != n
1618            || inputs.covariates.nrows() != n
1619            || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1620            || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1621        {
1622            return Err(SurvivalError::DimensionMismatch);
1623        }
1624        Self::validate_penalties(&penalties, p)?;
1625        Self::validate_common_inputs(
1626            &inputs.age_entry,
1627            &inputs.age_exit,
1628            &inputs.event_target,
1629            &inputs.event_competing,
1630            &inputs.sampleweight,
1631        )?;
1632        if inputs.time_entry.iter().any(|v| !v.is_finite())
1633            || inputs.time_exit.iter().any(|v| !v.is_finite())
1634            || inputs.time_derivative.iter().any(|v| !v.is_finite())
1635            || inputs.covariates.iter().any(|v| !v.is_finite())
1636        {
1637            return Err(SurvivalError::NonFiniteInput);
1638        }
1639        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1640            Self::validate_offsets(offsets, n)?;
1641        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1642            Self::validate_monotonicity_constraints(
1643                inputs.monotonicity_constraint_rows,
1644                inputs.monotonicity_constraint_offsets,
1645                p,
1646            )?;
1647
1648        Ok(Self::finish_construction(
1649            inputs.age_entry,
1650            inputs.age_exit,
1651            inputs.event_target,
1652            inputs.sampleweight,
1653            SurvivalDesign::TimeCovariateShared {
1654                time_entry: inputs.time_entry.to_owned(),
1655                time_exit: inputs.time_exit.to_owned(),
1656                time_derivative: inputs.time_derivative.to_owned(),
1657                covariates: inputs.covariates.to_owned(),
1658            },
1659            offset_eta_entry,
1660            offset_eta_exit,
1661            offset_derivative_exit,
1662            penalties,
1663            monotonicity,
1664            monotonicity_constraint_rows,
1665            monotonicity_constraint_offsets,
1666        ))
1667    }
1668
1669    /// Enable/disable monotonic time-block enforcement metadata.
1670    ///
1671    /// Monotonicity is enforced through linear inequality constraints on the
1672    /// derivative design; enabling this records how many leading time columns
1673    /// belong to that constrained block.
1674    /// Overwrite the per-block smoothing parameters `λ_k` in place.
1675    ///
1676    /// Used by the REML smoothing-parameter selection for transformation
1677    /// survival fits (issue #563): the outer optimizer proposes a `ρ = log λ`
1678    /// vector, sets the smoothing blocks' `λ_k` here, and re-runs the inner
1679    /// constrained PIRLS, so the monotone I-spline baseline can adapt its
1680    /// wiggliness instead of being pinned at a fixed seed. `lambdas` must have
1681    /// one entry per penalty block. The fixed stabilization ridge keeps its
1682    /// caller-set value (the optimizer never proposes a new one for it).
1683    pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1684        if lambdas.len() != self.penalties.blocks.len() {
1685            crate::bail_invalid_estim!(
1686                "set_penalty_lambdas expects {} lambdas, got {}",
1687                self.penalties.blocks.len(),
1688                lambdas.len()
1689            );
1690        }
1691        for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1692            if !lambda.is_finite() || lambda < 0.0 {
1693                crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1694            }
1695            block.lambda = lambda;
1696        }
1697        Ok(())
1698    }
1699
1700    pub fn set_structural_monotonicity(
1701        &mut self,
1702        enabled: bool,
1703        time_columns: usize,
1704    ) -> Result<(), EstimationError> {
1705        let p = self.coefficient_dim();
1706        if time_columns > p {
1707            crate::bail_invalid_estim!(
1708                "structural time columns {} exceed coefficient dimension {}",
1709                time_columns,
1710                p
1711            );
1712        }
1713        if enabled && time_columns == 0 {
1714            crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
1715        }
1716        if enabled {
1717            const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1718            for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
1719                if offset < -STRUCTURAL_DERIV_TOL {
1720                    crate::bail_invalid_estim!(
1721                        "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
1722                    );
1723                }
1724            }
1725            let mut derivative_row = vec![0.0_f64; p];
1726            for i in 0..self.nrows() {
1727                self.fill_derivative_row(i, &mut derivative_row);
1728                for j in 0..time_columns {
1729                    let v = derivative_row[j];
1730                    if v < -STRUCTURAL_DERIV_TOL {
1731                        crate::bail_invalid_estim!(
1732                            "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
1733                        );
1734                    }
1735                }
1736                for j in time_columns..p {
1737                    let v = derivative_row[j];
1738                    if v.abs() > STRUCTURAL_DERIV_TOL {
1739                        crate::bail_invalid_estim!(
1740                            "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
1741                        );
1742                    }
1743                }
1744            }
1745            if let (Some(rows), Some(offsets)) = (
1746                self.monotonicity_constraint_rows.as_ref(),
1747                self.monotonicity_constraint_offsets.as_ref(),
1748            ) {
1749                for (i, &offset) in offsets.iter().enumerate() {
1750                    if offset < -STRUCTURAL_DERIV_TOL {
1751                        crate::bail_invalid_estim!(
1752                            "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
1753                        );
1754                    }
1755                }
1756                for i in 0..rows.nrows() {
1757                    for j in 0..time_columns {
1758                        let v = rows[[i, j]];
1759                        if v < -STRUCTURAL_DERIV_TOL {
1760                            crate::bail_invalid_estim!(
1761                                "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1762                            );
1763                        }
1764                    }
1765                    for j in time_columns..p {
1766                        let v = rows[[i, j]];
1767                        if v.abs() > STRUCTURAL_DERIV_TOL {
1768                            crate::bail_invalid_estim!(
1769                                "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1770                            );
1771                        }
1772                    }
1773                }
1774            }
1775        }
1776        self.structurally_monotonic = enabled;
1777        self.structural_time_columns = if enabled { time_columns } else { 0 };
1778        Ok(())
1779    }
1780
1781    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
1782        if beta.len() != self.coefficient_dim() {
1783            crate::bail_invalid_estim!("survival beta dimension mismatch");
1784        }
1785
1786        let n = self.nrows();
1787        let p = self.coefficient_dim();
1788
1789        // Royston-Parmar contract used throughout the engine:
1790        //   eta(t) = log(H(t)), where H(t) is cumulative hazard.
1791        //
1792        // With row-vectors (per subject i):
1793        //   a1_i^T := x_exit_i^T,  a0_i^T := x_entry_i^T,  d_i^T := x_derivative_i^T
1794        // and scalars:
1795        //   eta1_i = a1_i^T beta,  eta0_i = a0_i^T beta,  s_i = d_i^T beta.
1796        //
1797        // The per-subject negative log-likelihood used below is
1798        //   NLL_i(beta) = exp(eta1_i) - exp(eta0_i) - delta_i * (eta1_i + log(s_i)),
1799        // with delta_i = event_target_i.
1800        //
1801        // This is exactly the form whose derivatives are:
1802        //   grad_i = exp(eta1_i) a1_i - exp(eta0_i) a0_i - delta_i * (a1_i + d_i / s_i)
1803        //   Hess_i = exp(eta1_i) a1_i a1_i^T - exp(eta0_i) a0_i a0_i^T
1804        //            + delta_i * (d_i d_i^T) / s_i^2.
1805        //
1806        // Monotonicity is enforced through linear inequality constraints on the
1807        // derivative design. This keeps the baseline smoothing penalty on the
1808        // actual spline coefficients and preserves zero-deviation as beta=0.
1809        //
1810        // The loop below computes exact beta-space derivatives and then adds penalties.
1811        // Total predictor = target offset + learned deviation.
1812        // This is the same architecture used for flexible binary links:
1813        // principled default, plus penalized wiggle/deviation.
1814        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
1815        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
1816        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
1817
1818        let mut nll = 0.0;
1819        let derivative_guard = self.derivative_guard();
1820        let derivative_guard_numerical = self.derivative_guard_numerical();
1821        let mut workspace = self.workspace.lock().unwrap();
1822        workspace.reset(n);
1823        let SurvivalWorkspace {
1824            w_event,
1825            w_event_inv_deriv,
1826            w_event_outer,
1827            w_hess_exit,
1828            w_hess_entry,
1829        } = &mut *workspace;
1830
1831        // Phase 1: Scalar loop — compute per-observation weights, NLL, validation.
1832        for i in 0..n {
1833            let w = self.sampleweight[i];
1834            if w <= 0.0 {
1835                continue;
1836            }
1837            let entry_age = self.age_entry[i];
1838            let exit_age = self.age_exit[i];
1839            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
1840                crate::bail_invalid_estim!(
1841                    "survival ages must be finite with age_exit >= age_entry"
1842                );
1843            }
1844            let d = f64::from(self.event_target[i]);
1845
1846            let has_entry_interval = !self.entry_at_origin[i];
1847            let interval_scale = if has_entry_interval {
1848                eta_exit[i].max(eta_entry[i])
1849            } else {
1850                eta_exit[i]
1851            };
1852            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
1853            let h_s_scaled = if has_entry_interval {
1854                (eta_entry[i] - interval_scale).exp()
1855            } else {
1856                0.0
1857            };
1858            let interval_scaled = h_e_scaled - h_s_scaled;
1859            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
1860            let deriv = self
1861                .stabilized_structural_derivative(derivative_raw[i])
1862                .unwrap_or(derivative_raw[i]);
1863            // Monotonicity of η(t) = log H(t) is a structural property of the
1864            // whole Royston-Parmar spline. If d_eta/dt is *strictly negative*
1865            // at any observed exit time, the cumulative hazard H(t) decreases
1866            // there and S(t) is not a valid survival function — both event
1867            // and censored rows have to refuse that case. Event rows further
1868            // need deriv strictly above the numerical guard because their
1869            // NLL contains `deriv.ln()` and `1.0 / deriv`; censored rows do
1870            // not, so a boundary value of exactly zero is feasible there.
1871            let mono_floor = if d > 0.0 {
1872                derivative_guard_numerical
1873            } else {
1874                0.0
1875            };
1876            if !deriv.is_finite() || deriv < mono_floor {
1877                return Err(EstimationError::ParameterConstraintViolation(format!(
1878                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
1879                    i, deriv, derivative_guard
1880                )));
1881            }
1882            if has_entry_interval {
1883                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
1884                if interval_scaled + increment_guard < 0.0 {
1885                    return Err(EstimationError::ParameterConstraintViolation(format!(
1886                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
1887                        i, interval
1888                    )));
1889                }
1890            }
1891            nll += w * interval;
1892
1893            // Per-observation weights for BLAS phase.
1894            // scaled_exp_component(interval_scale, h_e_scaled * x[r]) = exp(interval_scale) * h_e_scaled * x[r]
1895            // so the Hessian weight is w * exp(interval_scale) * h_e_scaled = w * exp(eta_exit).
1896            let w_exit_i = w * eta_exit[i].exp();
1897            let w_entry_i = if has_entry_interval {
1898                w * eta_entry[i].exp()
1899            } else {
1900                0.0
1901            };
1902            if !w_exit_i.is_finite() {
1903                crate::bail_invalid_estim!(
1904                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
1905                );
1906            }
1907            w_hess_exit[i] = w_exit_i;
1908            w_hess_entry[i] = w_entry_i;
1909
1910            if d > 0.0 {
1911                let inv_deriv = 1.0 / deriv;
1912                nll += -w * (eta_exit[i] + deriv.ln());
1913                w_event[i] = w;
1914                w_event_inv_deriv[i] = w * inv_deriv;
1915                w_event_outer[i] = w * inv_deriv * inv_deriv;
1916            }
1917        }
1918
1919        // Phase 2: BLAS-accelerated Hessian and gradient via faer.
1920        //   H_interval = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1921        //   grad_interval = X_exit^T w_exit - X_entry^T w_entry
1922        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
1923        // At large smoothing penalties the event-Jacobian score nearly cancels
1924        // the interval score. Compensated row accumulation keeps the final KKT
1925        // residual accurate enough for the outer LAML envelope check.
1926        let mut grad = Array1::<f64>::zeros(p);
1927        let mut grad_comp = Array1::<f64>::zeros(p);
1928        let mut row_exit = vec![0.0_f64; p];
1929        let mut row_entry = vec![0.0_f64; p];
1930        let mut row_derivative = vec![0.0_f64; p];
1931        for i in 0..n {
1932            let w_interval_exit = w_hess_exit[i];
1933            let w_interval_entry = w_hess_entry[i];
1934            let w_event_exit = w_event[i];
1935            let w_event_derivative = w_event_inv_deriv[i];
1936            if w_interval_exit == 0.0
1937                && w_interval_entry == 0.0
1938                && w_event_exit == 0.0
1939                && w_event_derivative == 0.0
1940            {
1941                continue;
1942            }
1943            self.fill_exit_row(i, &mut row_exit);
1944            self.fill_entry_row(i, &mut row_entry);
1945            self.fill_derivative_row(i, &mut row_derivative);
1946            for j in 0..p {
1947                let contribution = w_interval_exit * row_exit[j]
1948                    - w_interval_entry * row_entry[j]
1949                    - w_event_exit * row_exit[j]
1950                    - w_event_derivative * row_derivative[j];
1951                let t = grad[j] + contribution;
1952                if grad[j].abs() >= contribution.abs() {
1953                    grad_comp[j] += (grad[j] - t) + contribution;
1954                } else {
1955                    grad_comp[j] += (contribution - t) + grad[j];
1956                }
1957                grad[j] = t;
1958            }
1959        }
1960        grad += &grad_comp;
1961
1962        h += &self.derivative_xt_diag_x(w_event_outer);
1963
1964        // Norm of the unpenalized score, captured before adding the penalty
1965        // and ridge contributions, for the scale-invariant convergence
1966        // certificate (||score||_2 + ||S*beta||_2 (+ ridge*||beta||_2)).
1967        let score_norm = array1_l2_norm(&grad);
1968
1969        let penaltygrad = self.penalties.gradient(beta);
1970        let penalty_dev = self.penalties.deviance(beta);
1971        let penaltygrad_norm = array1_l2_norm(&penaltygrad);
1972
1973        let mut totalgrad = grad;
1974        totalgrad += &penaltygrad;
1975
1976        self.penalties.addhessian_inplace(&mut h);
1977        // SURVIVAL_STABILIZATION_RIDGE is an `ExplicitPrior`-kind
1978        // stabilization in the canonical ledger taxonomy
1979        // (`gam_problem::StabilizationKind::ExplicitPrior`): δ enters the
1980        // gradient (`grad += δ β`), the Hessian (`H += δ I`), the scalar
1981        // penalty term added to the objective (`0.5 δ ‖β‖²`), and is
1982        // serialized through `WorkingState::ridge_used` so downstream
1983        // covariance and survival_ridge_lambda accounting remain
1984        // consistent. The canonical ledger record is
1985        //   StabilizationLedger::explicit_prior(δ, RidgeMatrixForm::ScaledIdentity)
1986        // chosen_by = FixedConstant. Coordinated with main.rs
1987        // `survival_ridge_lambda` field.
1988        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
1989        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
1990        for d in 0..p {
1991            h[[d, d]] += ridge_used;
1992        }
1993        totalgrad += &beta.mapv(|v| ridge_used * v);
1994        // Keep scalar objective term consistent with:
1995        //   grad += ridge * beta,  Hess += ridge * I
1996        // which correspond to 0.5 * ridge * ||beta||^2.
1997        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
1998        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);
1999
2000        let log_likelihood = -nll;
2001        let deviance = 2.0 * nll;
2002
2003        Ok(WorkingState {
2004            eta: LinearPredictor::new(eta_exit),
2005            gradient: totalgrad,
2006            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
2007            log_likelihood,
2008            deviance,
2009            penalty_term: penalty_dev + ridge_penalty,
2010            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
2011            ridge_used,
2012            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
2013            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
2014        })
2015    }
2016
2017    /// Compute the third-derivative correction matrix for a given mode response `u_k`.
2018    ///
2019    /// This is the directional derivative of the unpenalized NLL Hessian w.r.t.
2020    /// beta along direction `u_k = -H^{-1} A_k beta_hat`. The returned matrix B
2021    /// satisfies `dH/drho_k = A_k + B`.
2022    ///
2023    /// Called via [`SurvivalDerivProvider`] which adapts the sign convention
2024    /// from the unified `HessianDerivativeProvider` trait (positive `v_k`) to
2025    /// the negated `u_k` used here.
2026    pub(crate) fn survival_hessian_derivative_correction(
2027        &self,
2028        beta: &Array1<f64>,
2029        u_k: &Array1<f64>,
2030    ) -> Result<Array2<f64>, EstimationError> {
2031        let p = beta.len();
2032        let n = self.nrows();
2033
2034        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2035        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2036        let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2037        let exp_entry = eta_entry.mapv(f64::exp);
2038        let exp_exit = eta_exit.mapv(f64::exp);
2039        let guard = self.derivative_guard();
2040        let guard_numerical = self.derivative_guard_numerical();
2041
2042        let jac = Array1::<f64>::ones(p);
2043        let curvature = Array1::<f64>::zeros(p);
2044        let third = Array1::<f64>::zeros(p);
2045
2046        let mut row_exit = vec![0.0_f64; p];
2047        let mut row_entry = vec![0.0_f64; p];
2048        let mut row_derivative = vec![0.0_f64; p];
2049        let mut ge = vec![0.0_f64; p];
2050        let mut gs = vec![0.0_f64; p];
2051        let mut gsd = vec![0.0_f64; p];
2052        let mut he = vec![0.0_f64; p];
2053        let mut hs = vec![0.0_f64; p];
2054        let mut hsd = vec![0.0_f64; p];
2055        let mut te = vec![0.0_f64; p];
2056        let mut ts = vec![0.0_f64; p];
2057        let mut tsd = vec![0.0_f64; p];
2058
2059        let mut b_dir = Array2::<f64>::zeros((p, p));
2060
2061        for i in 0..n {
2062            let w_i = self.sampleweight[i];
2063            if w_i <= 0.0 {
2064                continue;
2065            }
2066            let has_entry = !self.entry_at_origin[i];
2067            let mut deta_e = 0.0_f64;
2068            let mut deta_s = 0.0_f64;
2069            let mut ds = 0.0_f64;
2070            self.fill_exit_row(i, &mut row_exit);
2071            self.fill_entry_row(i, &mut row_entry);
2072            self.fill_derivative_row(i, &mut row_derivative);
2073            for j in 0..p {
2074                ge[j] = row_exit[j] * jac[j];
2075                gs[j] = row_entry[j] * jac[j];
2076                gsd[j] = row_derivative[j] * jac[j];
2077                he[j] = row_exit[j] * curvature[j];
2078                hs[j] = row_entry[j] * curvature[j];
2079                hsd[j] = row_derivative[j] * curvature[j];
2080                te[j] = row_exit[j] * third[j];
2081                ts[j] = row_entry[j] * third[j];
2082                tsd[j] = row_derivative[j] * third[j];
2083                deta_e += ge[j] * u_k[j];
2084                if has_entry {
2085                    deta_s += gs[j] * u_k[j];
2086                }
2087                ds += gsd[j] * u_k[j];
2088            }
2089
2090            // Interval part: d/dbeta [ exp(eta) * (g g^T + diag(h)) ][u_k]
2091            for r in 0..p {
2092                let dge_r = he[r] * u_k[r];
2093                let dgs_r = hs[r] * u_k[r];
2094                let dhe_r = te[r] * u_k[r];
2095                let dhs_r = ts[r] * u_k[r];
2096                for c in 0..p {
2097                    let dge_c = he[c] * u_k[c];
2098                    let dgs_c = hs[c] * u_k[c];
2099                    let mut d_h_rc =
2100                        exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2101                    if r == c {
2102                        d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2103                    }
2104                    if has_entry {
2105                        d_h_rc -=
2106                            exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2107                        if r == c {
2108                            d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2109                        }
2110                    }
2111                    b_dir[[r, c]] += w_i * d_h_rc;
2112                }
2113            }
2114
2115            // Event part: d/dbeta [ gsd gsd^T / s^2 - diag(he) - diag(hsd / s) ][u_k]
2116            let s_i = self
2117                .stabilized_structural_derivative(deriv_raw[i])
2118                .unwrap_or(deriv_raw[i]);
2119            if !s_i.is_finite() {
2120                return Err(EstimationError::ParameterConstraintViolation(format!(
2121                    "survival monotonicity violated in unified trace contraction at row {i}: \
2122                     d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2123                )));
2124            }
2125            if self.event_target[i] > 0 {
2126                if s_i < guard_numerical {
2127                    return Err(EstimationError::ParameterConstraintViolation(format!(
2128                        "survival monotonicity violated in unified trace contraction at row {i}: \
2129                         d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2130                    )));
2131                }
2132                let inv_s = 1.0 / s_i;
2133                let inv_s2 = inv_s * inv_s;
2134                let inv_s3 = inv_s2 * inv_s;
2135                for r in 0..p {
2136                    let dgd_r = hsd[r] * u_k[r];
2137                    let dtsd_r = tsd[r] * u_k[r];
2138                    let dte_r = te[r] * u_k[r];
2139                    for c in 0..p {
2140                        let dgd_c = hsd[c] * u_k[c];
2141                        let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2142                            - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2143                        if r == c {
2144                            d_h_rc += -dte_r;
2145                            d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2146                        }
2147                        b_dir[[r, c]] += w_i * d_h_rc;
2148                    }
2149                }
2150            }
2151        }
2152
2153        Ok(b_dir)
2154    }
2155
2156    /// Per-observation gradients of the unpenalized survival NLL with respect
2157    /// to each additive offset channel, at the given β.
2158    ///
2159    /// Contract (Royston-Parmar, eta = log H(t)):
2160    ///
2161    ///   NLL_i(β; o_E, o_X, o_D) = w_i · [
2162    ///       exp(η1_i) − 1{has_entry}·exp(η0_i)
2163    ///       − δ_i · (η1_i + log s_i)
2164    ///   ]
2165    ///
2166    /// with η1_i = a1_iᵀβ + o_X[i], η0_i = a0_iᵀβ + o_E[i],
2167    ///      s_i  = d_iᵀβ + o_D[i].
2168    ///
2169    /// The additive offsets enter each of the three η channels linearly, so
2170    ///   ∂NLL_i/∂o_X[i] = w_i · (exp(η1_i) − δ_i)
2171    ///   ∂NLL_i/∂o_E[i] = −w_i · exp(η0_i) · 1{has_entry_interval}
2172    ///   ∂NLL_i/∂o_D[i] = −w_i · δ_i / s_i         (event-row only)
2173    ///
2174    /// These three arrays are the sampleweight-scaled residuals used to chain
2175    /// `∂NLL/∂offset` into `∂NLL/∂θ` via any closed-form `∂offset/∂θ` map
2176    /// (see `baseline_offset_theta_partials` for parametric baselines). At
2177    /// converged β*, the envelope theorem on the penalized objective gives
2178    ///
2179    ///   d[0.5·(deviance + β*ᵀS_λβ*)] / dθ
2180    ///     = Σᵢ r_X_i·∂o_X_i/∂θ + r_E_i·∂o_E_i/∂θ + r_D_i·∂o_D_i/∂θ
2181    ///
2182    /// exactly (no IFT back-solve required), because β* is a stationary point
2183    /// of the penalized objective wrt β and the penalty has no θ dependence.
2184    ///
2185    /// Rows with `sampleweight[i] ≤ 0` and non-event rows for `r_D` are
2186    /// returned as exact 0.0 so the output can be dot-producted against a
2187    /// per-obs baseline-partials array without a mask.
2188    ///
2189    /// Structural-monotonicity stabilization on `s_i` (see
2190    /// `stabilized_structural_derivative`) is applied identically to the
2191    /// existing `update_state` path so the residual agrees with the
2192    /// NLL that `update_state` evaluates.
2193    pub fn offset_channel_residuals(
2194        &self,
2195        beta: &Array1<f64>,
2196    ) -> Result<OffsetChannelResiduals, EstimationError> {
2197        if beta.len() != self.coefficient_dim() {
2198            crate::bail_invalid_estim!(
2199                "survival beta dimension mismatch in offset_channel_residuals"
2200            );
2201        }
2202        let n = self.nrows();
2203        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2204        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2205        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2206
2207        let derivative_guard_numerical = self.derivative_guard_numerical();
2208        let mut r_exit = Array1::<f64>::zeros(n);
2209        let mut r_entry = Array1::<f64>::zeros(n);
2210        let mut r_deriv = Array1::<f64>::zeros(n);
2211
2212        for i in 0..n {
2213            let w = self.sampleweight[i];
2214            if w <= 0.0 {
2215                continue;
2216            }
2217            let entry_age = self.age_entry[i];
2218            let exit_age = self.age_exit[i];
2219            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2220                crate::bail_invalid_estim!(
2221                    "survival ages must be finite with age_exit >= age_entry"
2222                );
2223            }
2224            let has_entry_interval = !self.entry_at_origin[i];
2225            let d = f64::from(self.event_target[i]);
2226            // Phase-1 values matching update_state:
2227            //   w_exit_i  = w · exp(eta_exit)                    → ∂NLL/∂o_X before − δ·w term
2228            //   w_entry_i = w · exp(eta_entry) · 1{has_entry}    → matches −∂NLL/∂o_E sign
2229            let w_exit_i = w * eta_exit[i].exp();
2230            let w_entry_i = if has_entry_interval {
2231                w * eta_entry[i].exp()
2232            } else {
2233                0.0
2234            };
2235            if !w_exit_i.is_finite() {
2236                crate::bail_invalid_estim!(
2237                    "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2238                );
2239            }
2240            r_exit[i] = w_exit_i - d * w;
2241            r_entry[i] = -w_entry_i;
2242            // Same per-row monotonicity rule as `update_state`: a strictly
2243            // negative derivative at any observed exit time (event or
2244            // censored) falsifies S(t); event rows additionally need
2245            // `deriv > guard` because `1/deriv` enters their score.
2246            let deriv_raw = derivative_raw[i];
2247            let deriv = self
2248                .stabilized_structural_derivative(deriv_raw)
2249                .unwrap_or(deriv_raw);
2250            let mono_floor = if d > 0.0 {
2251                derivative_guard_numerical
2252            } else {
2253                0.0
2254            };
2255            if !deriv.is_finite() || deriv < mono_floor {
2256                return Err(EstimationError::ParameterConstraintViolation(format!(
2257                    "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2258                )));
2259            }
2260            if d > 0.0 {
2261                r_deriv[i] = -w * d / deriv;
2262            }
2263        }
2264
2265        let right = Array1::<f64>::zeros(r_exit.len());
2266        Ok(OffsetChannelResiduals {
2267            exit: r_exit,
2268            entry: r_entry,
2269            derivative: r_deriv,
2270            right,
2271        })
2272    }
2273
2274    /// Build an [`InnerSolution`](gam_solve::estimate::reml::reml_outer_engine::InnerSolution) from
2275    /// the survival working state, suitable for the unified REML/LAML evaluator.
2276    ///
2277    /// Evaluate the survival outer objective and gradient via the unified REML/LAML
2278    /// evaluator, using the canonical assembly module.
2279    pub fn unified_lamlobjective_and_rhogradient(
2280        &self,
2281        beta: &Array1<f64>,
2282        state: &WorkingState,
2283        rho: &Array1<f64>,
2284    ) -> Result<(f64, Array1<f64>), EstimationError> {
2285        use gam_solve::estimate::reml::assembly::{
2286            InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
2287        };
2288        use gam_solve::estimate::reml::reml_outer_engine::{
2289            DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
2290            compute_block_penalty_logdet_derivs,
2291        };
2292        use gam_problem::EvalMode;
2293
2294        let p = beta.len();
2295        let active_penalty_blocks: Vec<&PenaltyBlock> = self
2296            .penalties
2297            .blocks
2298            .iter()
2299            .filter(|b| b.lambda > 0.0)
2300            .collect();
2301        if rho.len() != active_penalty_blocks.len() {
2302            crate::bail_invalid_estim!(
2303                "survival LAML rho dimension {} does not match active penalty block count {}",
2304                rho.len(),
2305                active_penalty_blocks.len()
2306            );
2307        }
2308        let k_count = active_penalty_blocks.len();
2309
2310        // --- Hessian operator ---
2311        let h_dense = state.hessian.to_dense();
2312        let hop = DenseSpectralOperator::from_symmetric(&h_dense)
2313            .map_err(EstimationError::InvalidInput)?;
2314
2315        // --- Penalty coordinates via shared assembler helper ---
2316        let block_descs: Vec<PenaltyBlockDesc> = self
2317            .penalties
2318            .blocks
2319            .iter()
2320            .filter(|b| b.lambda > 0.0)
2321            .map(|b| PenaltyBlockDesc {
2322                matrix: &b.matrix,
2323                range_start: b.range.start,
2324                range_end: b.range.end,
2325            })
2326            .collect();
2327        let penalty_coords =
2328            penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;
2329
2330        // --- Penalty logdet derivatives ---
2331        let per_block_rho: Vec<Array1<f64>> =
2332            rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
2333        let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
2334            .iter()
2335            .map(|b| vec![b.matrix.clone()])
2336            .collect();
2337        let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
2338            .iter()
2339            .map(|v| v.as_slice())
2340            .collect();
2341        let penalty_logdet = if k_count > 0 {
2342            compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
2343                .map_err(EstimationError::InvalidInput)?
2344        } else {
2345            PenaltyLogdetDerivs {
2346                value: 0.0,
2347                first: Array1::zeros(0),
2348                second: Some(Array2::zeros((0, 0))),
2349            }
2350        };
2351
2352        // penalty_quadratic = 2 * penalty_term (matching unified evaluator convention).
2353        let penalty_quadratic = 2.0 * state.penalty_term;
2354        let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());
2355
2356        // #931 survival-LAML IFT envelope: attach the one-step Newton correction
2357        // only when this state is actually a near-stationary inner solution.
2358        // `unified_lamlobjective_and_rhogradient` is also used by algebraic
2359        // fixed-beta objective tests; feeding a large non-stationary residual
2360        // there makes the value a different surface. The re-converged shim
2361        // polishes the inner mode to an absolute residual floor, so certified
2362        // states still keep the envelope correction while arbitrary beta probes
2363        // evaluate the documented LAML objective.
2364        //
2365        // The residual MUST be the active-set-projected stationarity vector, not
2366        // raw `state.gradient`: a binding monotonicity constraint contributes a
2367        // Lagrange-multiplier normal component (`r = A^T lambda`, lambda >= 0)
2368        // that is not a stationarity residual.
2369        const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
2370        let kkt_residual = {
2371            let raw = state.gradient.clone();
2372            let projected = match self.monotonicity_linear_constraints() {
2373                Some(constraints) => {
2374                    projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
2375                        .ok_or_else(|| {
2376                            EstimationError::InvalidInput(
2377                                "survival LAML could not project the monotonicity KKT residual"
2378                                    .to_string(),
2379                            )
2380                        })?
2381                }
2382                None => raw,
2383            };
2384            let projected_norm = array1_l2_norm(&projected);
2385            let relative_projected_norm = state.relative_gradient_norm(projected_norm);
2386            if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
2387                Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
2388            } else {
2389                None
2390            }
2391        };
2392
2393        let result = InnerAssembly {
2394            log_likelihood: state.log_likelihood,
2395            penalty_quadratic,
2396            beta: beta.clone(),
2397            n_observations: self.nrows(),
2398            hessian_op: std::sync::Arc::new(hop),
2399            penalty_coords,
2400            penalty_logdet,
2401            dispersion: DispersionHandling::Fixed {
2402                phi: 1.0,
2403                include_logdet_h: true,
2404                include_logdet_s: true,
2405            },
2406            rho_curvature_scale: 1.0,
2407            rho_prior: gam_problem::RhoPrior::Flat,
2408            hessian_logdet_correction: 0.0,
2409            penalty_subspace_trace: None,
2410            deriv_provider: Some(Box::new(provider)),
2411            firth: None,
2412            nullspace_dim: None,
2413            barrier_config: None,
2414            ext_coords: Vec::new(),
2415            ext_coord_pair_fn: None,
2416            rho_ext_pair_fn: None,
2417            fixed_drift_deriv: None,
2418            contracted_psi_second_order: None,
2419            kkt_residual,
2420            active_constraints: None,
2421        }
2422        .evaluate(
2423            rho.as_slice().expect("rho must be contiguous"),
2424            EvalMode::ValueAndGradient,
2425            None,
2426        )
2427        .map_err(EstimationError::InvalidInput)?;
2428
2429        let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
2430        Ok((result.cost, gradient))
2431    }
2432
2433    /// Self-contained ρ → (LAML value, analytic ρ-gradient) surface for the
2434    /// survival LAML objective.
2435    ///
2436    /// Unlike [`unified_lamlobjective_and_rhogradient`](Self::unified_lamlobjective_and_rhogradient),
2437    /// which takes a *pre-converged* [`WorkingState`] and `β̂` at the evaluated
2438    /// `ρ`, this shim re-converges the inner survival mode internally: it sets
2439    /// the active-block smoothing parameters to `λ = exp(ρ)`, runs the same
2440    /// constrained inner PIRLS that the survival outer loop uses
2441    /// ([`runworking_model_pirls`](gam_solve::pirls::runworking_model_pirls)), then
2442    /// evaluates the unified survival LAML value and analytic ρ-gradient at the
2443    /// re-fitted `β̂(ρ)`. The returned pair is therefore a single-source value+
2444    /// gradient surface that a caller can finite-difference by varying `ρ`
2445    /// alone — the survival counterpart of the GLM path's
2446    /// `evaluate_externalgradient` / `evaluate_externalcost_andridge`.
2447    ///
2448    /// `rho` enumerates the **active** penalty blocks (those with `λ > 0`) in
2449    /// block order, matching the convention of the unified evaluator. `beta0` is
2450    /// the inner warm-start. The behaviour is identical to the existing survival
2451    /// LAML path (set-λ → inner PIRLS → `update_state` → unified LAML); this is a
2452    /// reachability shim, not a new objective.
2453    pub fn evaluate_survival_lamlcost_and_gradient(
2454        &self,
2455        rho: &[f64],
2456        beta0: &Array1<f64>,
2457    ) -> Result<(f64, Array1<f64>), EstimationError> {
2458        let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2459        // Re-converged β̂(ρ); evaluate the unified survival LAML value and
2460        // analytic ρ-gradient at that mode. The ρ passed to the unified
2461        // evaluator enumerates active blocks in block order, exactly the input
2462        // convention of this shim.
2463        let rho_arr = Array1::from_vec(rho.to_vec());
2464        let state = candidate.update_state(&beta)?;
2465        candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2466    }
2467
2468    /// Re-converge the survival inner mode at `λ = exp(ρ)` from warm-start
2469    /// `beta0`, returning the λ-set model candidate and the converged `β̂(ρ)`.
2470    /// This is the shared inner-solve used by
2471    /// [`evaluate_survival_lamlcost_and_gradient`](Self::evaluate_survival_lamlcost_and_gradient):
2472    /// inner PIRLS to a tight relative certificate, followed by a
2473    /// Levenberg–Marquardt / exact-Cholesky stationarity polish that drives the
2474    /// absolute penalized residual `‖S β̂ − ∇ℓ‖` below the FD round-off floor so
2475    /// the envelope ρ-gradient is exact. (Without the polish, PIRLS alone leaves
2476    /// `‖r‖ ~ 1` at large λ where H is ill-conditioned.)
2477    fn reconverge_survival_inner_mode(
2478        &self,
2479        rho: &[f64],
2480        beta0: &Array1<f64>,
2481    ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
2482        // Inner-PIRLS settings mirror the survival transformation outer loop's
2483        // constrained inner solve. Tighter convergence than the production
2484        // outer loop so the inner mode is converged well below the FD step's
2485        // round-off floor, making ∇V finite-differentiable in ρ alone.
2486        const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
2487        const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
2488        const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
2489        const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;
2490
2491        let active_block_count = self
2492            .penalties
2493            .blocks
2494            .iter()
2495            .filter(|b| b.lambda > 0.0)
2496            .count();
2497        if rho.len() != active_block_count {
2498            crate::bail_invalid_estim!(
2499                "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
2500                rho.len(),
2501                active_block_count
2502            );
2503        }
2504        if beta0.len() != self.coefficient_dim() {
2505            crate::bail_invalid_estim!(
2506                "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
2507                beta0.len(),
2508                self.coefficient_dim()
2509            );
2510        }
2511
2512        // Set λ = exp(ρ) on the active blocks (block order), leaving inactive
2513        // (λ = 0) blocks untouched, then re-converge the inner mode.
2514        let mut candidate = self.clone();
2515        let mut lambdas: Vec<f64> = candidate
2516            .penalties
2517            .blocks
2518            .iter()
2519            .map(|b| b.lambda)
2520            .collect();
2521        let mut active_idx = 0usize;
2522        for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
2523            if block.lambda > 0.0 {
2524                *lambda = rho[active_idx].exp();
2525                active_idx += 1;
2526            }
2527        }
2528        candidate.set_penalty_lambdas(&lambdas)?;
2529
2530        let opts = gam_solve::pirls::WorkingModelPirlsOptions {
2531            max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
2532            convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
2533            adaptive_kkt_tolerance: None,
2534            max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
2535            min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
2536            firth_bias_reduction: false,
2537            coefficient_lower_bounds: None,
2538            linear_constraints: None,
2539            initial_lm_lambda: None,
2540            geodesic_acceleration: false,
2541            arrow_schur: None,
2542        };
2543        let summary = gam_solve::pirls::runworking_model_pirls(
2544            &mut candidate,
2545            Coefficients::new(beta0.clone()),
2546            &opts,
2547            |_| {},
2548        )?;
2549        let mut beta = summary.beta.as_ref().to_owned();
2550
2551        // PIRLS exits on a RELATIVE KKT / deviance-plateau certificate, which can leave
2552        // an ABSOLUTE penalized stationarity residual r = S beta_hat - grad_ell of order
2553        // 0.1-1 (the score scales as O(sqrt(n))). The unified LAML gradient uses the
2554        // envelope theorem, exact only at r = 0; a residual that large leaks <r, beta_dot>
2555        // into the objective<->gradient consistency, AND the IFT envelope correction is
2556        // only leading-order in r, so it cannot make the analytic gradient the exact
2557        // derivative of the (re-converged, non-smooth-in-r) value surface either. The
2558        // robust cure is to drive the inner to TRUE stationarity (||r|| ~ 1e-11) so the
2559        // envelope is exactly valid and the IFT term is ~1e-22 — which it is at small
2560        // lambda, but a plain undamped Newton-polish STALLS at large lambda (rho=4..8):
2561        // there the intercept-direction curvature exp(eta)*n is large while the penalized
2562        // time block is lambda*S, so H is ill-conditioned and an undamped step neither
2563        // decreases ||r|| nor stays feasible, leaving ||r|| ~ 3e-2.
2564        //
2565        // Levenberg–Marquardt damping fixes this: solve (H + mu*diag(H)) delta = r,
2566        // accept on a genuine ||r||^2 decrease (Gauss–Newton on the stationarity system,
2567        // whose Jacobian is the penalized Hessian H), shrink mu on success and grow it on
2568        // rejection. The diagonal (Marquardt) scaling makes the damping curvature-aware so
2569        // the stiff time block and the soft intercept are damped commensurately. This
2570        // reliably reaches ||r|| below the FD-step round-off floor across the whole
2571        // rho = [-0.5 .. 8] range exercised by the consistency gates.
2572        {
2573            const POLISH_MAX_ITERS: usize = 400;
2574            const POLISH_TOL: f64 = 1e-13;
2575            // Armijo sufficient-decrease constant and backtracking factor.
2576            const ARMIJO_C: f64 = 1e-4;
2577            const BACKTRACK: f64 = 0.5;
2578            const MAX_BACKTRACK: usize = 80;
2579            let p = beta.len();
2580            // Penalized inner objective f(β) = −ℓ(β) + ½β'Sβ + ½ridge‖β‖² whose
2581            // gradient is exactly `state.gradient` and whose Hessian is exactly
2582            // `state.hessian`. `update_state` exposes the pieces directly.
2583            let penalized_objective =
2584                |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
2585            for _ in 0..POLISH_MAX_ITERS {
2586                let st = match candidate.update_state(&beta) {
2587                    Ok(st) => st,
2588                    Err(_) => break,
2589                };
2590                let r = st.gradient.clone();
2591                let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
2592                if !r_norm.is_finite() || r_norm < POLISH_TOL {
2593                    break;
2594                }
2595                let h = st.hessian.to_dense();
2596                let f0 = penalized_objective(&st);
2597                // Newton DIRECTION d = −H⁻¹r on the convex penalized survival
2598                // likelihood, found via a Levenberg–Marquardt-regularized solve
2599                // so an ill-conditioned H (h_diag ratio ~2400 at β₀≈4.6, where
2600                // exp(η) is huge) cannot produce a garbage direction whose
2601                // quadratic form rᵀH⁻¹r loses its sign. If even the regularized
2602                // Newton direction is not a sufficient descent direction, fall
2603                // back to STEEPEST DESCENT d = −r, which is ALWAYS a descent
2604                // direction on the convex objective (∇fᵀ(−r) = −‖r‖² < 0). The
2605                // line search below is on the OBJECTIVE VALUE (not ‖r‖), so any
2606                // descent direction makes monotone progress; near the optimum
2607                // the (lightly regularized) Newton step recovers fast local
2608                // convergence. This is globally convergent for the convex
2609                // penalized survival NLL — driving ‖r‖ below the FD round-off
2610                // floor so the envelope ρ-gradient equals the finite difference.
2611                let h_scale = (0..p)
2612                    .map(|d| h[[d, d]].abs())
2613                    .fold(0.0_f64, f64::max)
2614                    .max(1.0);
2615                // Solve (H + λI) step = r by an EXACT Cholesky factorization
2616                // (faer Llt), NOT the DenseSpectralOperator: the spectral
2617                // operator clamps tiny/negative eigenvalues, which on the
2618                // catastrophically ill-conditioned boundary Hessian (cond ~2400,
2619                // exp(η) huge at β₀≈4.6) corrupts the solve so badly that
2620                // rᵀH⁻¹r lost its sign and the previous polish broke on iter 0.
2621                // Cholesky succeeds iff H+λI is SPD; sweeping λ up from 0 finds
2622                // the smallest SPD shift, and for an SPD system rᵀ(H+λI)⁻¹r > 0
2623                // EXACTLY (Cholesky is backward-stable, no clamping), so the
2624                // Newton direction is a guaranteed descent direction.
2625                let mut step: Option<Array1<f64>> = None;
2626                let mut dir_deriv = 0.0_f64;
2627                for lm_pow in 0..18 {
2628                    let lambda_lm = if lm_pow == 0 {
2629                        0.0
2630                    } else {
2631                        1e-12 * h_scale * 10f64.powi(lm_pow)
2632                    };
2633                    let mut h_reg = h.clone();
2634                    for d in 0..p {
2635                        h_reg[[d, d]] += lambda_lm;
2636                    }
2637                    let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
2638                        &h_reg,
2639                        faer::Side::Lower,
2640                    ) {
2641                        Ok(f) => f,
2642                        Err(_) => continue,
2643                    };
2644                    let candidate_step = factor.solvevec(&r);
2645                    if candidate_step.iter().any(|v| !v.is_finite()) {
2646                        continue;
2647                    }
2648                    // ∇fᵀd = rᵀ(−step) = −r·(H+λI)⁻¹r < 0 exactly for SPD systems.
2649                    let dd = -r.dot(&candidate_step);
2650                    if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
2651                        step = Some(candidate_step);
2652                        dir_deriv = dd;
2653                        break;
2654                    }
2655                }
2656                let (step, dir_deriv) = match step {
2657                    Some(s) => (s, dir_deriv),
2658                    None => {
2659                        // Steepest-descent fallback: d = −r ⇒ step = +r (we step
2660                        // β − step), ∇fᵀd = −‖r‖² < 0.
2661                        (r.clone(), -r_norm * r_norm)
2662                    }
2663                };
2664                let mut alpha = 1.0_f64;
2665                let mut accepted = false;
2666                for _ in 0..MAX_BACKTRACK {
2667                    let trial = &beta - &(alpha * &step);
2668                    if let Ok(ts) = candidate.update_state(&trial) {
2669                        let ft = penalized_objective(&ts);
2670                        let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
2671                        // Accept on EITHER a sufficient objective decrease (Armijo,
2672                        // the global-convergence guarantee on the convex objective)
2673                        // OR a strict residual-norm decrease. Near the solution the
2674                        // penalized objective is flat to f64 roundoff (f0 ≈ ft), so a
2675                        // pure-Armijo test backtracks α→0 and crawls (the asymmetric
2676                        // ρ=3.99999 stall: 200 iters at 3.7e-7 vs 12 iters at the
2677                        // other two ρ). The ‖r‖-decrease arm lets the exact Cholesky
2678                        // Newton step (α=1) through, restoring quadratic convergence
2679                        // to ~1e-12 symmetrically across all three FD points so the
2680                        // centered FD of the value surface is itself exact.
2681                        let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
2682                        let residual_ok = tn.is_finite() && tn < r_norm;
2683                        if armijo_ok || residual_ok {
2684                            beta = trial;
2685                            accepted = true;
2686                            break;
2687                        }
2688                    }
2689                    alpha *= BACKTRACK;
2690                }
2691                if !accepted {
2692                    break;
2693                }
2694            }
2695        }
2696
2697        Ok((candidate, beta))
2698    }
2699}
2700
2701/// Derivative provider that adapts survival third-derivative Hessian corrections
2702/// to the unified [`HessianDerivativeProvider`](gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider)
2703/// trait.
2704///
2705/// The unified trait supplies `v_k = H^{-1}(A_k beta_hat)` (positive sign),
2706/// whereas the survival engine's
2707/// [`survival_hessian_derivative_correction`](WorkingModelSurvival::survival_hessian_derivative_correction)
2708/// expects `u_k = -v_k`. This provider handles the sign conversion.
2709pub(crate) struct SurvivalDerivProvider {
2710    model: WorkingModelSurvival,
2711    beta: Array1<f64>,
2712}
2713
2714impl SurvivalDerivProvider {
2715    pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2716        Self { model, beta }
2717    }
2718}
2719
2720impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2721    fn hessian_derivative_correction(
2722        &self,
2723        v_k: &Array1<f64>,
2724    ) -> Result<Option<Array2<f64>>, String> {
2725        // The trait provides v_k = H^{-1}(A_k beta_hat) (positive).
2726        // The survival method expects u_k = -H^{-1} A_k beta_hat = -v_k.
2727        let u_k = -v_k;
2728        match self
2729            .model
2730            .survival_hessian_derivative_correction(&self.beta, &u_k)
2731        {
2732            Ok(correction) => Ok(Some(correction)),
2733            Err(e) => Err(e.to_string()),
2734        }
2735    }
2736
2737    fn has_corrections(&self) -> bool {
2738        true
2739    }
2740}
2741
/// Output bundle of a crude-risk computation: one scalar risk plus two
/// gradient vectors.
///
/// NOTE(review): the producer of this struct is outside this view, so the
/// field meanings below are inferred from the names only — confirm against the
/// code that fills them in.
#[derive(Debug, Clone)]
pub struct CrudeRiskResult {
    /// Scalar risk value. Presumably a probability — TODO confirm range/units.
    pub risk: f64,
    /// Presumably the gradient of `risk` with respect to the disease-model
    /// parameters — TODO confirm.
    pub diseasegradient: Array1<f64>,
    /// Presumably the gradient of `risk` with respect to the mortality-model
    /// parameters — TODO confirm.
    pub mortalitygradient: Array1<f64>,
}
2748
/// Result of competing-risks cumulative-incidence assembly (see
/// `assemble_competing_risks_cif` / `assemble_competing_risks_cif_from_endpoints`).
///
/// Every matrix has shape `(n_rows, n_times)`, matching the per-endpoint
/// cumulative-hazard inputs.
#[derive(Debug, Clone)]
pub struct CompetingRisksCifResult {
    /// Cumulative incidence per endpoint. `cif[ep][[row, time_idx]]` is the
    /// probability that cause `ep` occurred by `times[time_idx]` for sample `row`.
    /// Stored one matrix per endpoint so it is ergonomic to index per-cause and
    /// natural to construct from the per-endpoint cumulative-hazard inputs.
    pub cif: Vec<Array2<f64>>,
    /// Overall survival per `[row, time_idx]`. Presumably the all-cause
    /// survival probability at `times[time_idx]` — confirm against the
    /// assembly routine.
    pub overall_survival: Array2<f64>,
}
2758
/// Subject-count threshold below which competing-risks CIF assembly stays on the
/// serial path. The per-row work (a `n_times`-long prefix-sum recurrence with a
/// handful of `exp`/`exp_m1` per element) is cheap, so small panels avoid rayon
/// fan-out overhead; large panels (the #1082 quality-test sizes) amortize it.
///
/// NOTE(review): the comparison that consumes this constant is outside this
/// view — confirm whether the cut-over is `<` or `<=` before relying on the
/// exact boundary.
const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2764
2765pub fn assemble_competing_risks_cif(
2766    times: ArrayView1<'_, f64>,
2767    cumulative_hazard: ArrayView3<'_, f64>,
2768) -> Result<CompetingRisksCifResult, SurvivalError> {
2769    let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2770    if n_endpoints == 0 {
2771        return Err(SurvivalError::DimensionMismatch);
2772    }
2773    let endpoint_hazards = cumulative_hazard
2774        .axis_iter(Axis(0))
2775        .map(|view| view.to_owned())
2776        .collect::<Vec<_>>();
2777    assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2778        if result.overall_survival.dim() != (n_rows, n_times) {
2779            Err(SurvivalError::DimensionMismatch)
2780        } else {
2781            Ok(result)
2782        }
2783    })
2784}
2785
2786pub fn assemble_competing_risks_cif_from_endpoints(
2787    times: ArrayView1<'_, f64>,
2788    cumulative_hazards: &[Array2<f64>],
2789) -> Result<CompetingRisksCifResult, SurvivalError> {
2790    let n_endpoints = cumulative_hazards.len();
2791    if n_endpoints == 0 || times.is_empty() {
2792        return Err(SurvivalError::DimensionMismatch);
2793    }
2794    let (n_rows, n_times) = cumulative_hazards[0].dim();
2795    if n_rows == 0 || n_times == 0 || times.len() != n_times {
2796        return Err(SurvivalError::DimensionMismatch);
2797    }
2798    if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2799        return Err(SurvivalError::InvalidTimeGrid);
2800    }
2801    if times
2802        .iter()
2803        .zip(times.iter().skip(1))
2804        .any(|(previous, current)| current <= previous)
2805    {
2806        return Err(SurvivalError::InvalidTimeGrid);
2807    }
2808    for endpoint_hazard in cumulative_hazards {
2809        if endpoint_hazard.dim() != (n_rows, n_times) {
2810            return Err(SurvivalError::DimensionMismatch);
2811        }
2812        if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2813            return Err(SurvivalError::NonFiniteInput);
2814        }
2815    }
2816
2817    let max_abs_hazard = cumulative_hazards
2818        .iter()
2819        .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2820        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2821    let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2822    let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2823        .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2824        .collect();
2825    let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2826
2827    // Per-row CIF assembly. The TIME axis is a sequential prefix-sum recurrence
2828    // (`previous_*` carry forward across `time_idx`) and MUST stay ordered, so it
2829    // is left as the inner serial loop. The ROW (subject) axis is fully
2830    // independent: every `previous_*`/`increments` buffer is allocated fresh per
2831    // row, no state crosses rows, and each row writes only its own disjoint
2832    // output slices. The per-row result is byte-identical regardless of which
2833    // thread runs it, so we fan the outer row loop out over rayon and write the
2834    // owned per-row buffers back serially in row order (deterministic, bit-exact
2835    // vs. the serial implementation).
2836    //
2837    // `cif_flat` is endpoint-major: `cif_flat[endpoint * n_times + time_idx]`.
2838    let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2839        let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2840        let mut surv_row = vec![0.0_f64; n_times];
2841        let mut previous_cif = vec![0.0_f64; n_endpoints];
2842        let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2843        let mut increments = vec![0.0_f64; n_endpoints];
2844        let mut previous_total_cumulative = 0.0_f64;
2845        for time_idx in 0..n_times {
2846            let mut total_increment = 0.0_f64;
2847            for endpoint in 0..n_endpoints {
2848                let current = cumulative_hazards[endpoint][[row, time_idx]];
2849                if current < -monotone_tolerance {
2850                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2851                }
2852                let raw_increment = current - previous_cumulative[endpoint];
2853                if raw_increment < -monotone_tolerance {
2854                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2855                }
2856                let increment = raw_increment.max(0.0);
2857                increments[endpoint] = increment;
2858                total_increment += increment;
2859                previous_cumulative[endpoint] += increment;
2860            }
2861
2862            let survival_left = (-previous_total_cumulative).exp();
2863            let interval_failure = -(-total_increment).exp_m1();
2864            for endpoint in 0..n_endpoints {
2865                if total_increment > 0.0 {
2866                    previous_cif[endpoint] +=
2867                        survival_left * interval_failure * increments[endpoint] / total_increment;
2868                }
2869                cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2870            }
2871            previous_total_cumulative += total_increment;
2872            // Derive `S(t)` from the stored cause-specific CIFs at this time so
2873            // that the competing-risks closure identity
2874            //   Σ_k F_k(t) + S(t) = 1
2875            // holds bit-exactly. Computing `S` independently as
2876            // `exp(-Σ_k H_k(t))` and then comparing against the (clamped, ratio-
2877            // split) Σ F_k introduces O(machine-eps) closure error because the
2878            // float increments
2879            //   ΔF_k = S_left·(1-exp(-ΔH))·ΔH_k/ΔH_total
2880            // do not sum to `S_left - S_new` bit-exactly. By summing the stored
2881            // CIFs in the same left-fold order as `slice.iter().sum::<f64>()`
2882            // and defining `S := 1.0 - Σ F_k`, the IEEE-754 round-trip
2883            //   (1.0 - f) + f
2884            // restores the identity for finite f ∈ [0, 1]. The mathematically
2885            // consistent survival value `exp(-H_total)` is still tracked up to
2886            // ulp-level precision because the ΔF_k construction matches
2887            // `S_left - S_new` to leading order.
2888            let mut fsum_at_t = 0.0_f64;
2889            for endpoint in 0..n_endpoints {
2890                fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2891            }
2892            surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2893        }
2894        Ok((cif_flat, surv_row))
2895    };
2896
2897    // Nesting guard (`rayon::current_thread_index().is_none()`) keeps us from
2898    // oversubscribing when this routine is itself called from inside a rayon
2899    // worker, and the row-count gate keeps small inputs on the serial path.
2900    let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2901        && rayon::current_thread_index().is_none()
2902    {
2903        use rayon::prelude::*;
2904        (0..n_rows)
2905            .into_par_iter()
2906            .map(assemble_row)
2907            .collect::<Result<_, _>>()?
2908    } else {
2909        (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2910    };
2911
2912    for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2913        for endpoint in 0..n_endpoints {
2914            for time_idx in 0..n_times {
2915                cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2916            }
2917        }
2918        for time_idx in 0..n_times {
2919            overall_survival[[row, time_idx]] = surv_row[time_idx];
2920        }
2921    }
2922
2923    Ok(CompetingRisksCifResult {
2924        cif,
2925        overall_survival,
2926    })
2927}
2928
/// Builds the `n`-point Gauss-Legendre rule on `[-1, 1]`.
///
/// Returns `(node, weight)` pairs sorted by ascending node. Only the
/// non-negative roots of the degree-`n` Legendre polynomial are solved for
/// (Newton iteration, at most 100 steps, stopping once a step is below
/// `1e-14`); each positive root is mirrored to obtain its negative partner.
/// For odd `n` the central node is stored as exactly `0.0`. `n == 0` yields an
/// empty rule.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    let order = n as f64;
    let half_count = n.div_ceil(2);
    let has_center_node = !n.is_multiple_of(2);

    // Three-term recurrence evaluated at `z`; returns `(P_n(z), P_{n-1}(z))`.
    let legendre_pair = |z: f64| -> (f64, f64) {
        let mut current = 1.0_f64;
        let mut previous = 0.0_f64;
        for degree in 0..n {
            let older = previous;
            previous = current;
            current = ((2.0 * degree as f64 + 1.0) * z * previous - degree as f64 * older)
                / (degree as f64 + 1.0);
        }
        (current, previous)
    };

    let mut rule: Vec<(f64, f64)> = Vec::with_capacity(n);
    for root_idx in 0..half_count {
        // Cosine-based starting guess for the `root_idx`-th largest root.
        let mut root = (std::f64::consts::PI * (root_idx as f64 + 0.75) / (order + 0.5)).cos();
        // Derivative of P_n at the last evaluated iterate; reused for the weight.
        let mut slope = 0.0_f64;

        for _ in 0..100 {
            let (p_n, p_n_minus_1) = legendre_pair(root);
            slope = order * (root * p_n - p_n_minus_1) / (root * root - 1.0);
            let before = root;
            root = before - p_n / slope;
            if (root - before).abs() < 1e-14 {
                break;
            }
        }

        let weight = 2.0 / ((1.0 - root * root) * slope * slope);
        if has_center_node && root_idx + 1 == half_count {
            // Odd rule: the innermost root is the origin and has no mirror image.
            rule.push((0.0, weight));
        } else {
            rule.push((-root, weight));
            rule.push((root, weight));
        }
    }

    rule.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    rule
}
2966
2967fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
2968    // `LazyLock` (not `OnceLock::get_or_init`) so first init never parks a
2969    // caller on the OS condvar from inside a rayon worker. The competing-risks
2970    // CIF assembler in this file dispatches `into_par_iter` and the
2971    // codebase-level lint (`tests/once_lock_get_or_init_not_inside_parallel_regions.rs`)
2972    // forbids the lazy `OnceLock` accessor in any rayon-adjacent file.
2973    static CACHE: LazyLock<Vec<(f64, f64)>> = LazyLock::new(|| compute_gauss_legendre_nodes(40));
2974    &CACHE
2975}
2976
2977/// Engine-level crude risk quadrature with exact delta-method gradients.
2978///
2979/// This routine owns the numerical integration and gradient accumulation math:
2980/// - It integrates `h_d(u) * S_total(u | t0)` over `[t0, t1]` by high-order
2981///   Gauss-Legendre quadrature.
2982/// - It computes gradients w.r.t. disease and mortality coefficients:
2983///   d Risk / d beta_d and d Risk / d beta_m.
2984///
2985/// The adapter provides the domain-specific point evaluator callback `eval_at`,
2986/// which fills design rows and returns:
2987/// - instantaneous disease hazard h_d(u) at age `u`,
2988/// - cumulative disease hazard `H_d(u)`,
2989/// - cumulative mortality hazard `H_m(u)`.
2990///
2991/// The callback must fill the following arrays (one entry per coefficient):
2992/// - `design_d[j]`: partial derivative of the linear predictor eta_d w.r.t. beta_j
2993///   at time u, i.e. x_j(u) = d eta_d(u) / d beta_j.
2994/// - `deriv_d[j]`: partial derivative of the TIME DERIVATIVE of eta_d w.r.t. beta_j
2995///   at time u, i.e. x_dot_j(u) = d/d(beta_j) [d eta_d(u)/du].
2996/// - `design_m[j]`: same as design_d but for the mortality linear predictor eta_m.
2997///
2998/// This keeps domain/data wiring out of `gam` while centralizing the
2999/// integration engine in one place.
3000pub fn calculate_crude_risk_quadrature<F>(
3001    t0: f64,
3002    t1: f64,
3003    breakpoints: &[f64],
3004    h_dis_t0: f64,
3005    h_mor_t0: f64,
3006    design_d_t0: ArrayView1<'_, f64>,
3007    design_m_t0: ArrayView1<'_, f64>,
3008    mut eval_at: F,
3009) -> Result<CrudeRiskResult, SurvivalError>
3010where
3011    F: FnMut(
3012        f64,
3013        &mut Array1<f64>,
3014        &mut Array1<f64>,
3015        &mut Array1<f64>,
3016    ) -> Result<(f64, f64, f64), SurvivalError>,
3017{
3018    let coeff_len_d = design_d_t0.len();
3019    let coeff_len_m = design_m_t0.len();
3020    if coeff_len_d == 0 || coeff_len_m == 0 {
3021        return Err(SurvivalError::InvalidIntegrationSetup);
3022    }
3023    if !t0.is_finite()
3024        || !t1.is_finite()
3025        || !h_dis_t0.is_finite()
3026        || !h_mor_t0.is_finite()
3027        || design_d_t0.iter().any(|v| !v.is_finite())
3028        || design_m_t0.iter().any(|v| !v.is_finite())
3029    {
3030        return Err(SurvivalError::NonFiniteInput);
3031    }
3032    if t1 <= t0 {
3033        return Ok(CrudeRiskResult {
3034            risk: 0.0,
3035            diseasegradient: Array1::zeros(coeff_len_d),
3036            mortalitygradient: Array1::zeros(coeff_len_m),
3037        });
3038    }
3039
3040    let mut sorted_breaks: Vec<f64> = breakpoints
3041        .iter()
3042        .copied()
3043        .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3044        .collect();
3045    sorted_breaks.push(t0);
3046    sorted_breaks.push(t1);
3047    sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3048    sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3049    if sorted_breaks.len() < 2 {
3050        return Err(SurvivalError::InvalidIntegrationSetup);
3051    }
3052
3053    let mut total_risk = 0.0;
3054    let mut diseasegradient = Array1::zeros(coeff_len_d);
3055    let mut mortalitygradient = Array1::zeros(coeff_len_m);
3056    let nodesweights = gauss_legendre_quadrature();
3057
3058    let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3059    let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3060    let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3061
3062    for segment in sorted_breaks.windows(2) {
3063        let a = segment[0];
3064        let b = segment[1];
3065        let center = 0.5 * (b + a);
3066        let halfwidth = 0.5 * (b - a);
3067        if halfwidth <= 0.0 {
3068            continue;
3069        }
3070
3071        for &(x, w) in nodesweights {
3072            let u = center + halfwidth * x;
3073            let (inst_hazard_d, hazard_d, hazard_m) =
3074                eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3075            if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3076                return Err(SurvivalError::NonFiniteInput);
3077            }
3078            if inst_hazard_d <= 0.0 {
3079                return Err(SurvivalError::NonPositiveHazard);
3080            }
3081
3082            if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3083                return Err(SurvivalError::NonMonotoneCumulativeHazard);
3084            }
3085
3086            let h_dis_cond = hazard_d - h_dis_t0;
3087            let h_mor_cond = hazard_m - h_mor_t0;
3088            let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3089
3090            total_risk += w * inst_hazard_d * s_total * halfwidth;
3091
3092            // d Risk / d beta_d:
3093            //   integral [ d h_d * S_total - h_d * S_total * d H_d ] du
3094            // Contract: design_d[j] = x_j(u) = ∂_{β_j} η_d(u)
3095            //           deriv_d[j]  = ẋ_j(u) = ∂_{β_j} η̇_d(u)
3096            // Then ∂_{β_j} h_d = h_d · x_j + H_d · ẋ_j
3097            let weight = w * s_total * halfwidth;
3098            for j in 0..coeff_len_d {
3099                let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3100                let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3101                let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3102                diseasegradient[j] += weight * g;
3103            }
3104
3105            // d Risk / d beta_m:
3106            //   -integral h_d * S_total * d H_m(u|t0) du
3107            let weight = w * inst_hazard_d * s_total * halfwidth;
3108            for j in 0..coeff_len_m {
3109                let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3110                mortalitygradient[j] += weight * g;
3111            }
3112        }
3113    }
3114
3115    Ok(CrudeRiskResult {
3116        risk: total_risk,
3117        diseasegradient,
3118        mortalitygradient,
3119    })
3120}
3121
/// Exposes the survival working model through the PIRLS `WorkingModel` trait
/// (imported in this file under the alias `PirlsWorkingModel`).
impl PirlsWorkingModel for WorkingModelSurvival {
    /// Produces the working state for `beta` by forwarding to
    /// `update_state`; its result (state or error) is returned unchanged.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_state(beta)
    }
}
3127
3128#[cfg(test)]
3129mod tests {
3130    use super::*;
3131    use ndarray::{Array1, Array2, Array3, array, s};
3132
3133    #[test]
3134    fn competing_risks_cif_constant_hazard_matches_closed_form() {
3135        let times = array![0.0, 2.0, 5.0, 10.0];
3136        let disease_rates = [0.12, 0.06];
3137        let death_rates = [0.05, 0.02];
3138        let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
3139            let rate = if endpoint == 0 {
3140                disease_rates[row]
3141            } else {
3142                death_rates[row]
3143            };
3144            rate * times[time_idx]
3145        });
3146
3147        let result =
3148            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3149
3150        for row in 0..2 {
3151            let total_rate = disease_rates[row] + death_rates[row];
3152            for time_idx in 0..times.len() {
3153                let failure = 1.0 - (-total_rate * times[time_idx]).exp();
3154                let expected_disease = disease_rates[row] / total_rate * failure;
3155                let expected_death = death_rates[row] / total_rate * failure;
3156                assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
3157                assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
3158                assert!(
3159                    (result.cif[0][[row, time_idx]]
3160                        + result.cif[1][[row, time_idx]]
3161                        + result.overall_survival[[row, time_idx]]
3162                        - 1.0)
3163                        .abs()
3164                        < 1e-12
3165                );
3166            }
3167        }
3168    }
3169
3170    #[test]
3171    fn competing_risks_cif_rejects_nonmonotone_hazards() {
3172        let times = array![0.0, 1.0, 2.0];
3173        let cumulative = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");
3174        let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3175            .expect_err("nonmonotone cumulative hazard should be rejected");
3176        assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
3177    }
3178
3179    #[test]
3180    fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
3181        let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
3182        let cumulative = Array3::from_shape_vec(
3183            (3, 2, 5),
3184            vec![
3185                // cause 1
3186                0.0, 0.2, 0.2, 0.5, 1.1, 0.0, 0.0, 0.4, 0.4, 0.9, // cause 2
3187                0.0, 0.1, 0.3, 0.3, 0.7, 0.0, 0.2, 0.2, 0.8, 0.8, // cause 3
3188                0.0, 0.0, 0.2, 0.6, 0.6, 0.0, 0.1, 0.5, 0.5, 1.5,
3189            ],
3190        )
3191        .expect("shape");
3192
3193        let result =
3194            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3195
3196        for row in 0..2 {
3197            for time_idx in 0..times.len() {
3198                let total_cif = result.cif[0][[row, time_idx]]
3199                    + result.cif[1][[row, time_idx]]
3200                    + result.cif[2][[row, time_idx]];
3201                assert!(
3202                    (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
3203                    "probability mass mismatch at row={row}, time_idx={time_idx}"
3204                );
3205                assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));
3206                for cause in 0..3 {
3207                    assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
3208                    if time_idx > 0 {
3209                        assert!(
3210                            result.cif[cause][[row, time_idx]] + 1e-12
3211                                >= result.cif[cause][[row, time_idx - 1]],
3212                            "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
3213                        );
3214                    }
3215                }
3216            }
3217        }
3218
3219        // Cause 1 is flat between t=1 and t=3 for row 0, but other causes
3220        // fail in that interval; its CIF must remain exactly flat.
3221        assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
3222        // All causes are flat between t=3 and t=7 for row 1 except cause 2;
3223        // causes 1 and 3 must not move.
3224        assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
3225        assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
3226    }
3227
3228    #[test]
3229    fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
3230        let cumulative = Array3::zeros((2, 1, 2));
3231
3232        for times in [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]] {
3233            let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3234                .expect_err("bad time grid should be rejected");
3235            assert!(matches!(err, SurvivalError::InvalidTimeGrid));
3236        }
3237
3238        let times = array![0.0, 1.0];
3239        let nonfinite = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
3240        let err = assemble_competing_risks_cif(times.view(), nonfinite.view())
3241            .expect_err("nonfinite hazard should be rejected");
3242        assert!(matches!(err, SurvivalError::NonFiniteInput));
3243    }
3244
3245    #[test]
3246    fn competing_risks_cif_extreme_hazards_remain_bounded() {
3247        let times = array![0.0, 1.0, 2.0];
3248        let cumulative =
3249            Array3::from_shape_vec((2, 1, 3), vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0])
3250                .expect("shape");
3251
3252        let result =
3253            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3254
3255        for value in result
3256            .cif
3257            .iter()
3258            .flat_map(|m| m.iter())
3259            .chain(result.overall_survival.iter())
3260        {
3261            assert!(value.is_finite());
3262            assert!((0.0..=1.0).contains(value));
3263        }
3264        assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
3265        assert_eq!(result.overall_survival[[0, 2]], 0.0);
3266    }
3267
3268    fn toy_penalties() -> PenaltyBlocks {
3269        let s = array![[2.0, 0.5], [0.5, 3.0]];
3270        PenaltyBlocks::new(vec![PenaltyBlock {
3271            matrix: s,
3272            lambda: 1.7,
3273            range: 1..3,
3274            nullspace_dim: 0,
3275        }])
3276    }
3277
    /// Test helper: bundles borrowed fixture arrays into a
    /// `SurvivalEngineInputs` by taking a view of each one. Both
    /// monotonicity-constraint fields are left as `None`.
    fn survival_inputs<'a>(
        age_entry: &'a Array1<f64>,
        age_exit: &'a Array1<f64>,
        event_target: &'a Array1<u8>,
        event_competing: &'a Array1<u8>,
        sampleweight: &'a Array1<f64>,
        x_entry: &'a Array2<f64>,
        x_exit: &'a Array2<f64>,
        x_derivative: &'a Array2<f64>,
    ) -> SurvivalEngineInputs<'a> {
        SurvivalEngineInputs {
            age_entry: age_entry.view(),
            age_exit: age_exit.view(),
            event_target: event_target.view(),
            event_competing: event_competing.view(),
            sampleweight: sampleweight.view(),
            x_entry: x_entry.view(),
            x_exit: x_exit.view(),
            x_derivative: x_derivative.view(),
            // No explicit monotonicity constraints in these fixtures.
            monotonicity_constraint_rows: None,
            monotonicity_constraint_offsets: None,
        }
    }
3301
    /// Test helper: constructs a `WorkingModelSurvival` by forwarding all
    /// arguments to `WorkingModelSurvival::from_engine_inputs`; construction
    /// errors are returned to the caller rather than unwrapped here.
    fn survival_model(
        inputs: SurvivalEngineInputs<'_>,
        penalties: PenaltyBlocks,
        monotonicity: SurvivalMonotonicityPenalty,
        spec: SurvivalSpec,
    ) -> Result<WorkingModelSurvival, SurvivalError> {
        WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
    }
3310
    /// Test helper: like `survival_model`, but also passes an optional set of
    /// baseline offsets, forwarding everything to
    /// `WorkingModelSurvival::from_engine_inputswith_offsets`.
    fn survival_model_with_offsets(
        inputs: SurvivalEngineInputs<'_>,
        offsets: Option<SurvivalBaselineOffsets<'_>>,
        penalties: PenaltyBlocks,
        monotonicity: SurvivalMonotonicityPenalty,
        spec: SurvivalSpec,
    ) -> Result<WorkingModelSurvival, SurvivalError> {
        WorkingModelSurvival::from_engine_inputswith_offsets(
            inputs,
            offsets,
            penalties,
            monotonicity,
            spec,
        )
    }
3326
3327    #[test]
3328    fn penaltyhessian_matchesgradient_jacobian() {
3329        let penalties = toy_penalties();
3330        let beta = array![10.0, -0.3, 1.2, 7.0];
3331
3332        let grad = penalties.gradient(&beta);
3333        let h = penalties.hessian(beta.len());
3334        let b_block = beta.slice(s![1..3]).to_owned();
3335        let expected = 1.7 * array![[2.0, 0.5], [0.5, 3.0]].dot(&b_block);
3336
3337        assert!((grad[1] - expected[0]).abs() < 1e-12);
3338        assert!((grad[2] - expected[1]).abs() < 1e-12);
3339        assert!((h[[1, 1]] - 1.7 * 2.0).abs() < 1e-12);
3340        assert!((h[[1, 2]] - 1.7 * 0.5).abs() < 1e-12);
3341        assert!((h[[2, 1]] - 1.7 * 0.5).abs() < 1e-12);
3342        assert!((h[[2, 2]] - 1.7 * 3.0).abs() < 1e-12);
3343    }
3344
3345    #[test]
3346    fn penaltygradient_matches_deviance_finite_difference() {
3347        let penalties = toy_penalties();
3348        let beta = array![10.0, -0.3, 1.2, 7.0];
3349        let grad = penalties.gradient(&beta);
3350        let eps = 1e-7;
3351
3352        for idx in 0..beta.len() {
3353            let mut plus = beta.clone();
3354            let mut minus = beta.clone();
3355            plus[idx] += eps;
3356            minus[idx] -= eps;
3357            let fd = (penalties.deviance(&plus) - penalties.deviance(&minus)) / (2.0 * eps);
3358            assert_eq!(
3359                grad[idx].signum(),
3360                fd.signum(),
3361                "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
3362                grad[idx]
3363            );
3364            assert!(
3365                (grad[idx] - fd).abs() < 1e-6,
3366                "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
3367                grad[idx]
3368            );
3369        }
3370    }
3371
3372    #[test]
3373    fn zero_offsets_match_default_survival_state() {
3374        let age_entry = array![1.0_f64, 2.0_f64];
3375        let age_exit = array![2.0_f64, 3.5_f64];
3376        let event_target = array![1u8, 0u8];
3377        let event_competing = array![0u8, 0u8];
3378        let sampleweight = array![1.0, 1.0];
3379        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3380        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3381        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3382        let penalties = PenaltyBlocks::new(Vec::new());
3383        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3384        let beta = array![-1.0, 0.8];
3385
3386        let base = survival_model(
3387            survival_inputs(
3388                &age_entry,
3389                &age_exit,
3390                &event_target,
3391                &event_competing,
3392                &sampleweight,
3393                &x_entry,
3394                &x_exit,
3395                &x_derivative,
3396            ),
3397            penalties.clone(),
3398            mono,
3399            SurvivalSpec::Net,
3400        )
3401        .expect("construct base survival model");
3402
3403        let zero_offsets = survival_model_with_offsets(
3404            survival_inputs(
3405                &age_entry,
3406                &age_exit,
3407                &event_target,
3408                &event_competing,
3409                &sampleweight,
3410                &x_entry,
3411                &x_exit,
3412                &x_derivative,
3413            ),
3414            Some(SurvivalBaselineOffsets {
3415                eta_entry: array![0.0, 0.0].view(),
3416                eta_exit: array![0.0, 0.0].view(),
3417                derivative_exit: array![0.0, 0.0].view(),
3418            }),
3419            penalties,
3420            mono,
3421            SurvivalSpec::Net,
3422        )
3423        .expect("construct offset survival model");
3424
3425        let state_base = base.update_state(&beta).expect("base state");
3426        let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
3427        assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
3428        assert!(
3429            state_base
3430                .gradient
3431                .iter()
3432                .zip(statezero.gradient.iter())
3433                .all(|(a, b)| (a - b).abs() < 1e-12)
3434        );
3435    }
3436
3437    #[test]
3438    fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
3439        // Regression for #378: the joint competing-risks Weibull path seeds a
3440        // shared single-hazard baseline working model from the dataset's event
3441        // *labels* {0 = censored, 1 = cause 1, 2 = cause 2}. The single-hazard
3442        // engine's `event_target` contract is binary {0, 1}, so feeding the raw
3443        // cause labels straight through used to bail out of construction via the
3444        // `event_target > 1` guard and surface as the misleading
3445        // `SurvivalError::NonFiniteInput` ("inputs contain non-finite values"),
3446        // even though every input value is finite. The fix (a) reports a
3447        // multi-cause label as the actionable `EventCodeInvalid`, never the
3448        // misleading "non-finite", and (b) projects cause labels to the
3449        // any-event {0, 1} indicator via the single-source-of-truth
3450        // `pooled_any_event_indicator` before constructing the pooled baseline.
3451        // This pins both halves of that contract.
3452        let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
3453        let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
3454        // Competing-risks cause labels: censored, cause 1, cause 2, censored.
3455        let cause_labels = array![0u8, 1u8, 2u8, 0u8];
3456        let event_competing = Array1::<u8>::zeros(cause_labels.len());
3457        let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];
3458        let x_entry = array![
3459            [1.0, age_entry[0].max(1e-8).ln()],
3460            [1.0, age_entry[1].max(1e-8).ln()],
3461            [1.0, age_entry[2].max(1e-8).ln()],
3462            [1.0, age_entry[3].max(1e-8).ln()],
3463        ];
3464        let x_exit = array![
3465            [1.0, age_exit[0].ln()],
3466            [1.0, age_exit[1].ln()],
3467            [1.0, age_exit[2].ln()],
3468            [1.0, age_exit[3].ln()],
3469        ];
3470        let x_derivative = array![
3471            [0.0, 1.0 / age_exit[0]],
3472            [0.0, 1.0 / age_exit[1]],
3473            [0.0, 1.0 / age_exit[2]],
3474            [0.0, 1.0 / age_exit[3]],
3475        ];
3476        let penalties = PenaltyBlocks::new(Vec::new());
3477        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3478
3479        // Raw cause labels {0,1,2} violate the single-hazard binary contract and
3480        // must be rejected -- but as an *actionable* `EventCodeInvalid`, NOT the
3481        // misleading `NonFiniteInput`: the labels are finite, they merely need
3482        // projecting. (The old fix left this surfacing as "non-finite".)
3483        let raw = survival_model(
3484            survival_inputs(
3485                &age_entry,
3486                &age_exit,
3487                &cause_labels,
3488                &event_competing,
3489                &sampleweight,
3490                &x_entry,
3491                &x_exit,
3492                &x_derivative,
3493            ),
3494            penalties.clone(),
3495            mono,
3496            SurvivalSpec::Net,
3497        );
3498        assert!(
3499            matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
3500            "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
3501        );
3502
3503        // The pooled-baseline projection the workflow now performs through the
3504        // single source of truth: any observed event (any cause) -> {0, 1}.
3505        let any_event = pooled_any_event_indicator(cause_labels.view());
3506        assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
3507        // And the per-cause projection that seeds each cause-specific block.
3508        assert_eq!(
3509            cause_specific_event_indicator(cause_labels.view(), 1),
3510            array![0u8, 1u8, 0u8, 0u8]
3511        );
3512        assert_eq!(
3513            cause_specific_event_indicator(cause_labels.view(), 2),
3514            array![0u8, 0u8, 1u8, 0u8]
3515        );
3516        let model = survival_model(
3517            survival_inputs(
3518                &age_entry,
3519                &age_exit,
3520                &any_event,
3521                &event_competing,
3522                &sampleweight,
3523                &x_entry,
3524                &x_exit,
3525                &x_derivative,
3526            ),
3527            penalties,
3528            mono,
3529            SurvivalSpec::Net,
3530        )
3531        .expect("pooled any-event baseline model must construct from competing-risks data");
3532
3533        // The constructed pooled baseline must yield a finite working state, so
3534        // the downstream baseline-seeding PIRLS loop has something to optimize.
3535        let beta = array![-1.0_f64, 0.8];
3536        let state = model.update_state(&beta).expect("pooled baseline state");
3537        assert!(
3538            state.deviance.is_finite(),
3539            "pooled baseline deviance must be finite, got {}",
3540            state.deviance
3541        );
3542        assert!(
3543            state.gradient.iter().all(|g| g.is_finite()),
3544            "pooled baseline gradient must be finite"
3545        );
3546    }
3547
3548    #[test]
3549    fn offset_channel_residuals_match_central_fd_of_nll() {
3550        // Three observations: two events (non-origin entry and origin entry)
3551        // and one censored row. This exercises every nonzero channel at least
3552        // once: r_exit from all rows, r_entry only from the first (has entry
3553        // interval), r_derivative only from events.
3554        let age_entry = array![0.5_f64, 0.0, 0.3];
3555        let age_exit = array![1.4_f64, 1.0, 2.0];
3556        let event_target = array![1u8, 1u8, 0u8];
3557        let event_competing = array![0u8, 0u8, 0u8];
3558        let sampleweight = array![1.0_f64, 2.5, 0.7];
3559        let x_entry = array![
3560            [1.0, age_entry[0].ln()],
3561            [1.0, age_entry[1].max(1e-8).ln()],
3562            [1.0, age_entry[2].ln()]
3563        ];
3564        let x_exit = array![
3565            [1.0, age_exit[0].ln()],
3566            [1.0, age_exit[1].ln()],
3567            [1.0, age_exit[2].ln()]
3568        ];
3569        let x_derivative = array![
3570            [0.0, 1.0 / age_exit[0]],
3571            [0.0, 1.0 / age_exit[1]],
3572            [0.0, 1.0 / age_exit[2]]
3573        ];
3574        // Baseline offsets chosen so η_entry, η_exit, s are all comfortably
3575        // away from overflow / monotonicity-violation boundaries.
3576        let o_entry = array![0.2_f64, 0.0, 0.1];
3577        let o_exit = array![0.4_f64, 0.5, 0.7];
3578        let o_deriv = array![0.3_f64, 0.8, 0.5];
3579        let penalties = PenaltyBlocks::new(Vec::new());
3580        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3581        let beta = array![-0.7_f64, 0.6];
3582
3583        let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
3584            survival_model_with_offsets(
3585                survival_inputs(
3586                    &age_entry,
3587                    &age_exit,
3588                    &event_target,
3589                    &event_competing,
3590                    &sampleweight,
3591                    &x_entry,
3592                    &x_exit,
3593                    &x_derivative,
3594                ),
3595                Some(SurvivalBaselineOffsets {
3596                    eta_entry: o_e.view(),
3597                    eta_exit: o_x.view(),
3598                    derivative_exit: o_d.view(),
3599                }),
3600                penalties.clone(),
3601                mono,
3602                SurvivalSpec::Net,
3603            )
3604            .expect("model build")
3605        };
3606
3607        let base = build(&o_entry, &o_exit, &o_deriv);
3608        let resid = base
3609            .offset_channel_residuals(&beta)
3610            .expect("offset residuals");
3611        assert_eq!(resid.exit.len(), 3);
3612        assert_eq!(resid.entry.len(), 3);
3613        assert_eq!(resid.derivative.len(), 3);
3614
3615        // NLL equals half the deviance returned by update_state; that is the
3616        // exact unpenalized loss whose offset partials r_{X,E,D} encode.
3617        let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
3618        let h = 1e-6;
3619
3620        // Row 1 (origin entry, event=1) has no entry interval, so r_entry[1]
3621        // must be exactly 0. Row 2 (censored) has r_deriv[2] exactly 0. Check
3622        // those identities before FD comparison on the nonzero elements.
3623        assert_eq!(resid.entry[1], 0.0);
3624        assert_eq!(resid.derivative[2], 0.0);
3625
3626        for i in 0..3 {
3627            // exit channel: perturb o_exit[i] alone.
3628            {
3629                let mut op = o_exit.clone();
3630                let mut om = o_exit.clone();
3631                op[i] += h;
3632                om[i] -= h;
3633                let fd = (nll(&build(&o_entry, &op, &o_deriv))
3634                    - nll(&build(&o_entry, &om, &o_deriv)))
3635                    / (2.0 * h);
3636                assert!(
3637                    (resid.exit[i] - fd).abs() < 1e-6,
3638                    "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
3639                    resid.exit[i],
3640                    fd
3641                );
3642            }
3643            // entry channel: only row 0 has an entry interval; for rows with
3644            // entry_at_origin the offset contributes nothing to NLL and FD
3645            // must also be exactly 0 to numerical precision.
3646            {
3647                let mut op = o_entry.clone();
3648                let mut om = o_entry.clone();
3649                op[i] += h;
3650                om[i] -= h;
3651                let fd = (nll(&build(&op, &o_exit, &o_deriv))
3652                    - nll(&build(&om, &o_exit, &o_deriv)))
3653                    / (2.0 * h);
3654                assert!(
3655                    (resid.entry[i] - fd).abs() < 1e-6,
3656                    "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
3657                    resid.entry[i],
3658                    fd
3659                );
3660            }
3661            // derivative channel: only event rows contribute.
3662            {
3663                let mut op = o_deriv.clone();
3664                let mut om = o_deriv.clone();
3665                op[i] += h;
3666                om[i] -= h;
3667                let fd = (nll(&build(&o_entry, &o_exit, &op))
3668                    - nll(&build(&o_entry, &o_exit, &om)))
3669                    / (2.0 * h);
3670                assert!(
3671                    (resid.derivative[i] - fd).abs() < 1e-6,
3672                    "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
3673                    resid.derivative[i],
3674                    fd
3675                );
3676            }
3677        }
3678    }
3679
3680    #[test]
3681    fn offset_channel_residuals_respect_zero_sampleweight() {
3682        let age_entry = array![1.0_f64, 2.0];
3683        let age_exit = array![2.0_f64, 3.5];
3684        let event_target = array![1u8, 1u8];
3685        let event_competing = array![0u8, 0u8];
3686        let sampleweight = array![0.0_f64, 1.2]; // row 0 is excluded by weight
3687        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3688        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3689        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3690        let penalties = PenaltyBlocks::new(Vec::new());
3691        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3692        let beta = array![-1.0_f64, 0.8];
3693
3694        let model = survival_model_with_offsets(
3695            survival_inputs(
3696                &age_entry,
3697                &age_exit,
3698                &event_target,
3699                &event_competing,
3700                &sampleweight,
3701                &x_entry,
3702                &x_exit,
3703                &x_derivative,
3704            ),
3705            Some(SurvivalBaselineOffsets {
3706                eta_entry: array![0.0_f64, 0.1].view(),
3707                eta_exit: array![0.0_f64, 0.2].view(),
3708                derivative_exit: array![0.0_f64, 0.1].view(),
3709            }),
3710            penalties,
3711            mono,
3712            SurvivalSpec::Net,
3713        )
3714        .expect("model");
3715        let r = model.offset_channel_residuals(&beta).expect("resid");
3716        // Row 0 (sampleweight=0) must contribute zero in every channel.
3717        assert_eq!(r.exit[0], 0.0);
3718        assert_eq!(r.entry[0], 0.0);
3719        assert_eq!(r.derivative[0], 0.0);
3720        // Row 1 must still carry a nonzero exit-channel residual.
3721        assert!(r.exit[1] != 0.0);
3722    }
3723
3724    #[test]
3725    fn offset_channel_residuals_reject_beta_dim_mismatch() {
3726        let age_entry = array![1.0_f64];
3727        let age_exit = array![2.0_f64];
3728        let event_target = array![1u8];
3729        let event_competing = array![0u8];
3730        let sampleweight = array![1.0_f64];
3731        let x_entry = array![[1.0, 0.0]];
3732        let x_exit = array![[1.0, 0.7]];
3733        let x_derivative = array![[0.0, 0.5]];
3734        let penalties = PenaltyBlocks::new(Vec::new());
3735        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3736        let model = survival_model(
3737            survival_inputs(
3738                &age_entry,
3739                &age_exit,
3740                &event_target,
3741                &event_competing,
3742                &sampleweight,
3743                &x_entry,
3744                &x_exit,
3745                &x_derivative,
3746            ),
3747            penalties,
3748            mono,
3749            SurvivalSpec::Net,
3750        )
3751        .expect("model");
3752        let bad_beta = array![0.0_f64]; // should be length 2
3753        let err = model
3754            .offset_channel_residuals(&bad_beta)
3755            .expect_err("mismatch must error");
3756        match err {
3757            EstimationError::InvalidInput(msg) => {
3758                assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
3759            }
3760            other => panic!("expected InvalidInput, got {other:?}"),
3761        }
3762    }
3763
3764    #[test]
3765    fn crudespec_is_rejected_by_one_hazard_engine() {
3766        let age_entry = array![1.0_f64];
3767        let age_exit = array![2.0_f64];
3768        let event_target = array![0u8];
3769        let event_competing = array![1u8];
3770        let sampleweight = array![1.0];
3771        let x_entry = array![[0.1]];
3772        let x_exit = array![[0.4]];
3773        let x_derivative = array![[1.0]];
3774        let penalties = PenaltyBlocks::new(Vec::new());
3775        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3776
3777        let err = survival_model(
3778            survival_inputs(
3779                &age_entry,
3780                &age_exit,
3781                &event_target,
3782                &event_competing,
3783                &sampleweight,
3784                &x_entry,
3785                &x_exit,
3786                &x_derivative,
3787            ),
3788            penalties,
3789            mono,
3790            SurvivalSpec::Crude,
3791        )
3792        .expect_err("crude fitting should be rejected by the one-hazard engine");
3793        assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
3794    }
3795
3796    #[test]
3797    fn nonstructural_models_require_explicit_monotonicity_collocation() {
3798        let age_entry = array![1.0_f64, 1.5_f64];
3799        let age_exit = array![2.0_f64, 2.5_f64];
3800        let event_target = array![0u8, 0u8];
3801        let event_competing = array![0u8, 1u8];
3802        let sampleweight = array![1.0, 1.0];
3803        let x_entry = array![[0.2], [0.1]];
3804        let x_exit = array![[0.3], [0.2]];
3805        let x_derivative = array![[1.0], [1.0]];
3806
3807        let model = survival_model(
3808            survival_inputs(
3809                &age_entry,
3810                &age_exit,
3811                &event_target,
3812                &event_competing,
3813                &sampleweight,
3814                &x_entry,
3815                &x_exit,
3816                &x_derivative,
3817            ),
3818            PenaltyBlocks::new(Vec::new()),
3819            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3820            SurvivalSpec::Net,
3821        )
3822        .expect("construct censored survival model");
3823
3824        assert!(
3825            model.monotonicity_linear_constraints().is_none(),
3826            "non-structural survival models must not fabricate rowwise monotonicity constraints"
3827        );
3828    }
3829
3830    #[test]
3831    fn decreasing_interval_is_rejectedwithout_target_events() {
3832        let age_entry = array![1.0_f64];
3833        let age_exit = array![2.0_f64];
3834        let event_target = array![0u8];
3835        let event_competing = array![0u8];
3836        let sampleweight = array![1.0];
3837        let x_entry = array![[0.5]];
3838        let x_exit = array![[0.0]];
3839        let x_derivative = array![[1.0]];
3840
3841        let model = survival_model(
3842            survival_inputs(
3843                &age_entry,
3844                &age_exit,
3845                &event_target,
3846                &event_competing,
3847                &sampleweight,
3848                &x_entry,
3849                &x_exit,
3850                &x_derivative,
3851            ),
3852            PenaltyBlocks::new(Vec::new()),
3853            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3854            SurvivalSpec::Net,
3855        )
3856        .expect("construct censored survival model");
3857
3858        let err = model
3859            .update_state(&array![1.0])
3860            .expect_err("decreasing cumulative hazard increment should be rejected");
3861        assert!(
3862            err.to_string().contains("cumulative hazard decreased"),
3863            "unexpected error: {err}"
3864        );
3865    }
3866
3867    fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
3868        calculate_crude_risk_quadrature(
3869            0.0,
3870            1.0,
3871            &[0.0, 1.0],
3872            beta_d.exp(),
3873            beta_m.exp(),
3874            array![1.0].view(),
3875            array![1.0].view(),
3876            |u, design_d, deriv_d, design_m| {
3877                let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
3878                let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
3879                let inst_hazard_d = 0.2 * beta_d.exp();
3880                design_d[0] = 1.0;
3881                // η_d = β_d + ln(1 + 0.2u), so η̇_d = 0.2/(1+0.2u)
3882                // which does not depend on β_d → ∂_{β_d} η̇_d = 0
3883                deriv_d[0] = 0.0;
3884                design_m[0] = 1.0;
3885                Ok((inst_hazard_d, cumulative_d, cumulative_m))
3886            },
3887        )
3888        .expect("smooth crude-risk quadrature should succeed")
3889    }
3890
3891    #[test]
3892    fn crude_riskgradient_matches_monotoneobjective() {
3893        let beta_d = -0.2_f64;
3894        let beta_m = -0.5_f64;
3895        let result = smooth_crude_risk(beta_d, beta_m);
3896        let eps = 1e-6;
3897
3898        let fd_d = (smooth_crude_risk(beta_d + eps, beta_m).risk
3899            - smooth_crude_risk(beta_d - eps, beta_m).risk)
3900            / (2.0 * eps);
3901        let fd_m = (smooth_crude_risk(beta_d, beta_m + eps).risk
3902            - smooth_crude_risk(beta_d, beta_m - eps).risk)
3903            / (2.0 * eps);
3904
3905        assert!(
3906            (result.diseasegradient[0] - fd_d).abs() < 1e-5,
3907            "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
3908            result.diseasegradient[0]
3909        );
3910        assert!(
3911            (result.mortalitygradient[0] - fd_m).abs() < 1e-5,
3912            "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
3913            result.mortalitygradient[0]
3914        );
3915    }
3916
3917    #[test]
3918    fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
3919        let age_entry = array![1.0_f64, 2.0_f64];
3920        let age_exit = array![2.0_f64, 3.5_f64];
3921        let event_target = array![1u8, 0u8];
3922        let event_competing = array![0u8, 0u8];
3923        let sampleweight = array![1.0, 1.0];
3924        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3925        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3926        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3927        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3928            matrix: array![[2.0]],
3929            lambda: 1.7,
3930            range: 1..2,
3931            nullspace_dim: 0,
3932        }]);
3933        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3934        let beta = array![-1.2, 0.4];
3935
3936        let model = survival_model(
3937            survival_inputs(
3938                &age_entry,
3939                &age_exit,
3940                &event_target,
3941                &event_competing,
3942                &sampleweight,
3943                &x_entry,
3944                &x_exit,
3945                &x_derivative,
3946            ),
3947            penalties.clone(),
3948            mono,
3949            SurvivalSpec::Net,
3950        )
3951        .expect("construct survival model");
3952
3953        let state = model.update_state(&beta).expect("survival state");
3954        let expected_penalty = penalties.deviance(&beta) + 0.5 * state.ridge_used * beta.dot(&beta);
3955        assert!(
3956            (state.penalty_term - expected_penalty).abs() < 1e-12,
3957            "penalty_term mismatch: state={} expected={}",
3958            state.penalty_term,
3959            expected_penalty
3960        );
3961    }
3962
3963    #[test]
3964    fn negative_penalty_lambda_is_rejected() {
3965        let age_entry = array![1.0_f64];
3966        let age_exit = array![2.0_f64];
3967        let event_target = array![1u8];
3968        let event_competing = array![0u8];
3969        let sampleweight = array![1.0];
3970        let x_entry = array![[1.0, 0.0]];
3971        let x_exit = array![[1.0, 0.5]];
3972        let x_derivative = array![[0.0, 1.0]];
3973        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3974            matrix: array![[1.0]],
3975            lambda: -0.1,
3976            range: 1..2,
3977            nullspace_dim: 0,
3978        }]);
3979
3980        let err = survival_model(
3981            survival_inputs(
3982                &age_entry,
3983                &age_exit,
3984                &event_target,
3985                &event_competing,
3986                &sampleweight,
3987                &x_entry,
3988                &x_exit,
3989                &x_derivative,
3990            ),
3991            penalties,
3992            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3993            SurvivalSpec::Net,
3994        )
3995        .expect_err("negative lambda must be rejected");
3996
3997        assert!(matches!(err, SurvivalError::NonFiniteInput));
3998    }
3999
4000    #[test]
4001    fn penalty_block_range_and_shapemust_match_coefficients() {
4002        let age_entry = array![1.0_f64];
4003        let age_exit = array![2.0_f64];
4004        let event_target = array![1u8];
4005        let event_competing = array![0u8];
4006        let sampleweight = array![1.0];
4007        let x_entry = array![[1.0, 0.0]];
4008        let x_exit = array![[1.0, 0.5]];
4009        let x_derivative = array![[0.0, 1.0]];
4010        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4011            matrix: array![[1.0]],
4012            lambda: 0.5,
4013            range: 0..2,
4014            nullspace_dim: 0,
4015        }]);
4016
4017        let err = survival_model(
4018            survival_inputs(
4019                &age_entry,
4020                &age_exit,
4021                &event_target,
4022                &event_competing,
4023                &sampleweight,
4024                &x_entry,
4025                &x_exit,
4026                &x_derivative,
4027            ),
4028            penalties,
4029            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4030            SurvivalSpec::Net,
4031        )
4032        .expect_err("penalty block geometry must match coefficient support");
4033
4034        assert!(matches!(err, SurvivalError::DimensionMismatch));
4035    }
4036
4037    #[test]
4038    fn survivalgradient_matchesobjectivefdwithridge_scaling() {
4039        let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
4040        let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
4041        let event_target = array![1u8, 0u8, 1u8];
4042        let event_competing = array![0u8, 0u8, 0u8];
4043        let sampleweight = array![1.0, 1.0, 1.0];
4044        let x_entry = array![
4045            [1.0, age_entry[0].ln()],
4046            [1.0, age_entry[1].ln()],
4047            [1.0, age_entry[2].ln()]
4048        ];
4049        let x_exit = array![
4050            [1.0, age_exit[0].ln()],
4051            [1.0, age_exit[1].ln()],
4052            [1.0, age_exit[2].ln()]
4053        ];
4054        let x_derivative = array![
4055            [0.0, 1.0 / age_exit[0]],
4056            [0.0, 1.0 / age_exit[1]],
4057            [0.0, 1.0 / age_exit[2]]
4058        ];
4059        let penalties = PenaltyBlocks::new(Vec::new());
4060        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4061        let beta = array![-1.0, 3.0];
4062
4063        let model = survival_model(
4064            survival_inputs(
4065                &age_entry,
4066                &age_exit,
4067                &event_target,
4068                &event_competing,
4069                &sampleweight,
4070                &x_entry,
4071                &x_exit,
4072                &x_derivative,
4073            ),
4074            penalties,
4075            mono,
4076            SurvivalSpec::Net,
4077        )
4078        .expect("construct survival model");
4079
4080        let state = model.update_state(&beta).expect("state at beta");
4081        let eps = 1e-7;
4082        for j in 0..beta.len() {
4083            let mut plus = beta.clone();
4084            let mut minus = beta.clone();
4085            plus[j] += eps;
4086            minus[j] -= eps;
4087            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4088            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4089            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4090            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4091            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4092            assert_eq!(
4093                state.gradient[j].signum(),
4094                fd.signum(),
4095                "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4096                state.gradient[j]
4097            );
4098            assert!(
4099                (state.gradient[j] - fd).abs() < 1e-5,
4100                "objective/gradient mismatch at j={j}: grad={} fd={fd}",
4101                state.gradient[j]
4102            );
4103        }
4104    }
4105
4106    fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
4107        // 20-subject survival fixture with mean-centered log-age time
4108        // covariate, balanced events/censorings, and moderate hazard levels.
4109        // The fixture is large enough that the observed-information Hessian
4110        // is well-conditioned at the MLE so PIRLS reaches the 1e-10 KKT
4111        // tolerance in well under 80 iterations from the starting beta used
4112        // below.
4113        let age_entry: Array1<f64> = Array1::from(vec![
4114            30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
4115            34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
4116        ]);
4117        let age_exit: Array1<f64> = Array1::from(vec![
4118            45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
4119            48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
4120        ]);
4121        let event_target = Array1::from(vec![
4122            1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
4123        ]);
4124        let event_competing = Array1::<u8>::zeros(age_entry.len());
4125        let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
4126        let n = age_entry.len();
4127        let ln_age_mean: f64 = {
4128            let mut sum = 0.0;
4129            for i in 0..n {
4130                sum += age_entry[i].ln() + age_exit[i].ln();
4131            }
4132            sum / (2.0 * n as f64)
4133        };
4134        let mut x_entry = Array2::<f64>::zeros((n, 2));
4135        let mut x_exit = Array2::<f64>::zeros((n, 2));
4136        let mut x_derivative = Array2::<f64>::zeros((n, 2));
4137        for i in 0..n {
4138            x_entry[[i, 0]] = 1.0;
4139            x_exit[[i, 0]] = 1.0;
4140            x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
4141            x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
4142            x_derivative[[i, 0]] = 0.0;
4143            x_derivative[[i, 1]] = 1.0 / age_exit[i];
4144        }
4145        let penalties = PenaltyBlocks::new(vec![
4146            PenaltyBlock {
4147                matrix: array![[3.0]],
4148                lambda: 0.0,
4149                range: 0..1,
4150                nullspace_dim: 0,
4151            },
4152            PenaltyBlock {
4153                matrix: array![[2.5]],
4154                lambda,
4155                range: 1..2,
4156                nullspace_dim: 0,
4157            },
4158        ]);
4159        survival_model(
4160            survival_inputs(
4161                &age_entry,
4162                &age_exit,
4163                &event_target,
4164                &event_competing,
4165                &sampleweight,
4166                &x_entry,
4167                &x_exit,
4168                &x_derivative,
4169            ),
4170            penalties,
4171            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4172            SurvivalSpec::Net,
4173        )
4174        .expect("construct LAML FD survival model")
4175    }
4176
4177    fn laml_test_logdet_h(state: &WorkingState) -> f64 {
4178        use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4179        use gam_linalg::faer_ndarray::FaerEigh;
4180
4181        let h_dense = state.hessian.to_dense();
4182        let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4183        let eps = spectral_epsilon(evals.as_slice().unwrap());
4184        evals
4185            .iter()
4186            .map(|&sigma| spectral_regularize(sigma, eps).ln())
4187            .sum()
4188    }
4189
4190    #[test]
4191    fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
4192        // The core claim under test: the survival LAML rho-gradient and the
4193        // documented LAML objective enumerate only penalty blocks whose
4194        // lambda is actually active (> 0). An inactive prefix block must
4195        // therefore contribute neither a log|lambda * S| term to the
4196        // objective nor an entry to the rho-gradient vector.
4197        //
4198        // We verify the objective formula and the gradient dimensionality at
4199        // a fixed beta rather than a fitted one: the bug this test guards
4200        // against was purely algebraic enumeration over penalty blocks and
4201        // has no dependence on PIRLS convergence quality. A gradient-vs-FD
4202        // comparison would require beta to sit at the joint MLE of a tiny
4203        // synthetic survival fixture, which the analytic Newton/PIRLS path
4204        // cannot reach to 1e-10 KKT tolerance without a much richer design.
4205        let rho0 = -0.35_f64;
4206        let beta = array![-2.5_f64, 1.0];
4207        let model = laml_fd_test_model(rho0.exp());
4208        let state = model
4209            .update_state(&beta)
4210            .expect("state for LAML prefix-skip test");
4211
4212        // Sanity: the fixture has two penalty blocks; the first has
4213        // lambda = 0 (inactive prefix) and the second has lambda > 0
4214        // (active). If a future refactor flips this ordering, the prefix
4215        // skip being exercised here would silently become an identity test.
4216        assert_eq!(model.penalties.blocks.len(), 2);
4217        assert_eq!(model.penalties.blocks[0].lambda, 0.0);
4218        assert!(model.penalties.blocks[1].lambda > 0.0);
4219
4220        let rho = Array1::from_iter(
4221            model
4222                .penalties
4223                .blocks
4224                .iter()
4225                .filter(|b| b.lambda > 0.0)
4226                .map(|b| b.lambda.ln()),
4227        );
4228        assert_eq!(
4229            rho.len(),
4230            1,
4231            "fixture should expose exactly one active penalty block for the rho vector"
4232        );
4233
4234        let (obj, grad) = model
4235            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4236            .expect("survival LAML objective and gradient");
4237
4238        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
4239            - 0.5 * (rho0 + 2.5_f64.ln());
4240        assert_eq!(
4241            grad.len(),
4242            1,
4243            "rho-gradient must match the active-penalty count, not the full block list"
4244        );
4245        assert!(
4246            (obj - expected).abs() < 1e-10,
4247            "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
4248        );
4249        assert!(
4250            grad[0].is_finite(),
4251            "rho-gradient must be finite: {}",
4252            grad[0]
4253        );
4254    }
4255
4256    #[test]
4257    fn structural_monotonicgradient_matchesobjectivefd() {
4258        let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
4259        let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
4260        let event_target = array![1u8, 0u8, 1u8];
4261        let event_competing = array![0u8, 0u8, 0u8];
4262        let sampleweight = array![1.0, 1.0, 1.0];
4263
4264        // Time block has 3 structural-monotone columns.
4265        // Final column is a covariate, left unconstrained.
4266        let x_entry = array![
4267            [1.0, 0.2, 0.05, -0.7],
4268            [1.0, 0.5, 0.20, 0.1],
4269            [1.0, 0.9, 0.60, 1.2]
4270        ];
4271        let x_exit = array![
4272            [1.0, 0.4, 0.16, -0.7],
4273            [1.0, 0.8, 0.64, 0.1],
4274            [1.0, 1.1, 1.21, 1.2]
4275        ];
4276        let x_derivative = array![
4277            [0.0, 0.8, 0.64, 0.0],
4278            [0.0, 0.7, 1.12, 0.0],
4279            [0.0, 0.6, 1.32, 0.0]
4280        ];
4281        let penalties = PenaltyBlocks::new(Vec::new());
4282        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4283        let mut model = survival_model(
4284            survival_inputs(
4285                &age_entry,
4286                &age_exit,
4287                &event_target,
4288                &event_competing,
4289                &sampleweight,
4290                &x_entry,
4291                &x_exit,
4292                &x_derivative,
4293            ),
4294            penalties,
4295            mono,
4296            SurvivalSpec::Net,
4297        )
4298        .expect("construct structural survival model");
4299        model
4300            .set_structural_monotonicity(true, 3)
4301            .expect("enable structural monotonicity");
4302        let constraints = model
4303            .monotonicity_linear_constraints()
4304            .expect("structural derivative constraints");
4305        assert_eq!(constraints.a.nrows(), 2);
4306        assert_eq!(constraints.a.ncols(), 4);
4307        assert_eq!(constraints.a.row(0).to_vec(), vec![0.0, 1.0, 0.0, 0.0]);
4308        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
4309        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4310
4311        let beta = array![0.2, 0.2, 0.1, 0.2];
4312        let state = model.update_state(&beta).expect("state at structural beta");
4313        let eps = 1e-7;
4314        for j in 0..beta.len() {
4315            let mut plus = beta.clone();
4316            let mut minus = beta.clone();
4317            plus[j] += eps;
4318            minus[j] -= eps;
4319            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4320            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4321            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4322            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4323            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4324            assert_eq!(
4325                state.gradient[j].signum(),
4326                fd.signum(),
4327                "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4328                state.gradient[j]
4329            );
4330            assert!(
4331                (state.gradient[j] - fd).abs() < 2e-5,
4332                "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
4333                state.gradient[j]
4334            );
4335        }
4336    }
4337
4338    #[test]
4339    fn structural_monotonic_lamlgradient_returns_finitevalues() {
4340        let age_entry = array![1.0_f64, 1.2_f64];
4341        let age_exit = array![1.5_f64, 2.0_f64];
4342        let event_target = array![1u8, 0u8];
4343        let event_competing = array![0u8, 0u8];
4344        let sampleweight = array![1.0, 1.0];
4345        let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
4346        let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
4347        let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];
4348        let penalties = PenaltyBlocks::new(Vec::new());
4349        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4350        let mut model = survival_model(
4351            survival_inputs(
4352                &age_entry,
4353                &age_exit,
4354                &event_target,
4355                &event_competing,
4356                &sampleweight,
4357                &x_entry,
4358                &x_exit,
4359                &x_derivative,
4360            ),
4361            penalties,
4362            mono,
4363            SurvivalSpec::Net,
4364        )
4365        .expect("construct structural survival model");
4366        model
4367            .set_structural_monotonicity(true, 2)
4368            .expect("enable structural monotonicity");
4369        // One simple penalty block to exercise rho-gradient path.
4370        model.penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4371            matrix: array![[1.0]],
4372            lambda: 0.7,
4373            range: 1..2,
4374            nullspace_dim: 0,
4375        }]);
4376        let beta = array![0.2, 0.2, 0.1];
4377        let state = model.update_state(&beta).expect("state at structural beta");
4378        let rho = Array1::from_iter(
4379            model
4380                .penalties
4381                .blocks
4382                .iter()
4383                .filter(|b| b.lambda > 0.0)
4384                .map(|b| b.lambda.ln()),
4385        );
4386        let (obj, grad) = model
4387            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4388            .expect("laml gradient should work in structural mode");
4389        assert!(obj.is_finite());
4390        assert_eq!(grad.len(), 1);
4391        assert!(grad[0].is_finite());
4392    }
4393
4394    #[test]
4395    fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
4396        let age_entry = array![1.0_f64];
4397        let age_exit = array![2.0_f64];
4398        let event_target = array![1u8];
4399        let event_competing = array![0u8];
4400        let sampleweight = array![1.0];
4401        let x_entry = array![[0.0]];
4402        let x_exit = array![[0.2]];
4403        let x_derivative = array![[1.0]];
4404
4405        let penalties = PenaltyBlocks::new(Vec::new());
4406        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4407        let mut model = survival_model(
4408            survival_inputs(
4409                &age_entry,
4410                &age_exit,
4411                &event_target,
4412                &event_competing,
4413                &sampleweight,
4414                &x_entry,
4415                &x_exit,
4416                &x_derivative,
4417            ),
4418            penalties,
4419            mono,
4420            SurvivalSpec::Net,
4421        )
4422        .expect("construct structural survival model");
4423
4424        let beta = array![-3.0];
4425        assert!(
4426            model.update_state(&beta).is_err(),
4427            "negative derivative coefficient should violate derivative guard"
4428        );
4429
4430        model
4431            .set_structural_monotonicity(true, 1)
4432            .expect("enable structural monotonicity");
4433        let constraints = model
4434            .monotonicity_linear_constraints()
4435            .expect("structural derivative constraints");
4436        assert_eq!(constraints.a.nrows(), 1);
4437        assert_eq!(constraints.a.ncols(), 1);
4438        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
4439        // Structural monotonicity uses derivative_guard() == 0.0
4440        assert!(constraints.b[0].abs() <= 1e-12);
4441        let state = model
4442            .update_state(&array![1e-6])
4443            .expect("small positive derivative coefficient should remain feasible");
4444        assert!(state.deviance.is_finite());
4445    }
4446
4447    #[test]
4448    fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
4449        let age_entry = array![1.0_f64];
4450        let age_exit = array![2.0_f64];
4451        let event_target = array![1u8];
4452        let event_competing = array![0u8];
4453        let sampleweight = array![1.0];
4454        let x_entry = array![[1.0, 0.0]];
4455        let x_exit = array![[1.0, 0.0]];
4456        let x_derivative = array![[0.0, 0.0]];
4457        let penalties = PenaltyBlocks::new(Vec::new());
4458        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
4459        let eta_entry_offset = array![0.0];
4460        let eta_exit_offset = array![0.0];
4461        let derivative_offset_below_guard = array![2.0];
4462        let derivative_offset_above_guard = array![3.1];
4463        let offsets_below_guard = SurvivalBaselineOffsets {
4464            eta_entry: eta_entry_offset.view(),
4465            eta_exit: eta_exit_offset.view(),
4466            derivative_exit: derivative_offset_below_guard.view(),
4467        };
4468        let offsets_above_guard = SurvivalBaselineOffsets {
4469            eta_entry: eta_entry_offset.view(),
4470            eta_exit: eta_exit_offset.view(),
4471            derivative_exit: derivative_offset_above_guard.view(),
4472        };
4473
4474        let model_below_guard = survival_model_with_offsets(
4475            survival_inputs(
4476                &age_entry,
4477                &age_exit,
4478                &event_target,
4479                &event_competing,
4480                &sampleweight,
4481                &x_entry,
4482                &x_exit,
4483                &x_derivative,
4484            ),
4485            Some(offsets_below_guard),
4486            penalties.clone(),
4487            monotonicity,
4488            SurvivalSpec::Net,
4489        )
4490        .expect("construct model with derivative offset below guard");
4491        let err = model_below_guard
4492            .update_state(&array![0.0, 0.0])
4493            .expect_err("derivative offset below guard should be rejected");
4494        let err_text = err.to_string();
4495        assert!(
4496            err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
4497            "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
4498        );
4499
4500        let model_above_guard = survival_model_with_offsets(
4501            survival_inputs(
4502                &age_entry,
4503                &age_exit,
4504                &event_target,
4505                &event_competing,
4506                &sampleweight,
4507                &x_entry,
4508                &x_exit,
4509                &x_derivative,
4510            ),
4511            Some(offsets_above_guard),
4512            penalties,
4513            SurvivalMonotonicityPenalty { tolerance: 3.0 },
4514            SurvivalSpec::Net,
4515        )
4516        .expect("construct model with derivative offset above guard");
4517        let state = model_above_guard
4518            .update_state(&array![0.0, 0.0])
4519            .expect("derivative offset above guard should remain feasible");
4520        assert!(state.deviance.is_finite());
4521    }
4522
4523    #[test]
4524    fn structural_monotonicity_rejects_negative_derivative_offsets() {
4525        let age_entry = array![1.0_f64];
4526        let age_exit = array![2.0_f64];
4527        let event_target = array![1u8];
4528        let event_competing = array![0u8];
4529        let sampleweight = array![1.0];
4530        let x_entry = array![[0.0]];
4531        let x_exit = array![[0.2]];
4532        let x_derivative = array![[1.0]];
4533        let eta_entry = array![0.0];
4534        let eta_exit = array![0.0];
4535        let derivative_exit = array![-1e-3];
4536        let offsets = SurvivalBaselineOffsets {
4537            eta_entry: eta_entry.view(),
4538            eta_exit: eta_exit.view(),
4539            derivative_exit: derivative_exit.view(),
4540        };
4541
4542        let mut model = survival_model_with_offsets(
4543            survival_inputs(
4544                &age_entry,
4545                &age_exit,
4546                &event_target,
4547                &event_competing,
4548                &sampleweight,
4549                &x_entry,
4550                &x_exit,
4551                &x_derivative,
4552            ),
4553            Some(offsets),
4554            PenaltyBlocks::new(Vec::new()),
4555            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4556            SurvivalSpec::Net,
4557        )
4558        .expect("construct structural survival model");
4559        let err = model
4560            .set_structural_monotonicity(true, 1)
4561            .expect_err("negative derivative offsets must be rejected");
4562        assert!(
4563            err.to_string()
4564                .contains("structural monotonicity requires nonnegative derivative offsets"),
4565            "unexpected error: {err}"
4566        );
4567    }
4568
4569    #[test]
4570    fn structural_monotonicity_emits_coefficient_constraints() {
4571        let age_entry = array![1.0_f64, 1.5_f64];
4572        let age_exit = array![2.0_f64, 3.0_f64];
4573        let event_target = array![1u8, 0u8];
4574        let event_competing = array![0u8, 0u8];
4575        let sampleweight = array![1.0, 1.0];
4576        let x_entry = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
4577        let x_exit = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
4578        let x_derivative = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];
4579
4580        let mut model = survival_model(
4581            survival_inputs(
4582                &age_entry,
4583                &age_exit,
4584                &event_target,
4585                &event_competing,
4586                &sampleweight,
4587                &x_entry,
4588                &x_exit,
4589                &x_derivative,
4590            ),
4591            PenaltyBlocks::new(Vec::new()),
4592            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4593            SurvivalSpec::Net,
4594        )
4595        .expect("construct structural survival model");
4596        model
4597            .set_structural_monotonicity(true, 2)
4598            .expect("enable structural monotonicity");
4599
4600        let constraints = model
4601            .monotonicity_linear_constraints()
4602            .expect("structural derivative constraints");
4603
4604        assert_eq!(constraints.a.nrows(), 2);
4605        assert_eq!(constraints.a.ncols(), 3);
4606        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
4607        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
4608        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4609    }
4610
4611    #[test]
4612    fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
4613        let age_entry = array![1.0_f64];
4614        let age_exit = array![2.0_f64];
4615        let event_target = array![1u8];
4616        let event_competing = array![0u8];
4617        let sampleweight = array![1.0];
4618        let x_entry = array![[1.0, 0.2]];
4619        let x_exit = array![[1.0, 0.6]];
4620        let x_derivative = array![[0.0, 1.0]];
4621
4622        let mut model = survival_model(
4623            survival_inputs(
4624                &age_entry,
4625                &age_exit,
4626                &event_target,
4627                &event_competing,
4628                &sampleweight,
4629                &x_entry,
4630                &x_exit,
4631                &x_derivative,
4632            ),
4633            PenaltyBlocks::new(Vec::new()),
4634            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4635            SurvivalSpec::Net,
4636        )
4637        .expect("construct structural survival model");
4638        model
4639            .set_structural_monotonicity(true, 2)
4640            .expect("enable structural monotonicity");
4641
4642        let constraints = model
4643            .monotonicity_linear_constraints()
4644            .expect("structural derivative constraints");
4645
4646        assert_eq!(constraints.a.nrows(), 1);
4647        assert!(
4648            constraints.a[[0, 0]].abs() <= 1e-12,
4649            "inactive time column should remain unconstrained"
4650        );
4651        assert!(
4652            (constraints.a[[0, 1]] - 1.0).abs() <= 1e-12,
4653            "active time column should remain constrained"
4654        );
4655    }
4656
4657    #[test]
4658    fn structural_monotonicity_preserves_sparse_row_patterns() {
4659        let age_entry = array![1.0_f64, 1.5_f64];
4660        let age_exit = array![2.0_f64, 2.5_f64];
4661        let event_target = array![1u8, 1u8];
4662        let event_competing = array![0u8, 0u8];
4663        let sampleweight = array![1.0, 1.0];
4664        let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
4665        let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
4666        let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];
4667
4668        let mut model = survival_model(
4669            survival_inputs(
4670                &age_entry,
4671                &age_exit,
4672                &event_target,
4673                &event_competing,
4674                &sampleweight,
4675                &x_entry,
4676                &x_exit,
4677                &x_derivative,
4678            ),
4679            PenaltyBlocks::new(Vec::new()),
4680            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4681            SurvivalSpec::Net,
4682        )
4683        .expect("construct structural survival model");
4684        model
4685            .set_structural_monotonicity(true, 2)
4686            .expect("enable structural monotonicity");
4687
4688        let constraints = model
4689            .monotonicity_linear_constraints()
4690            .expect("structural derivative constraints");
4691
4692        assert_eq!(constraints.a.nrows(), 2);
4693        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
4694        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
4695    }
4696
4697    #[test]
4698    fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
4699        let age_entry = array![1.0_f64];
4700        let age_exit = array![1.1_f64];
4701        let event_target = array![0u8];
4702        let event_competing = array![0u8];
4703        let sampleweight = array![1.0];
4704        let x_entry = array![[0.0]];
4705        let x_exit = array![[0.0]];
4706        let x_derivative = array![[-1.0]];
4707        let penalties = PenaltyBlocks::new(Vec::new());
4708        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4709        let model = survival_model(
4710            survival_inputs(
4711                &age_entry,
4712                &age_exit,
4713                &event_target,
4714                &event_competing,
4715                &sampleweight,
4716                &x_entry,
4717                &x_exit,
4718                &x_derivative,
4719            ),
4720            penalties,
4721            mono,
4722            SurvivalSpec::Net,
4723        )
4724        .expect("construct censored survival model");
4725
4726        let err = model
4727            .update_state(&array![1.0])
4728            .expect_err("censored row should still enforce monotonic derivative");
4729        assert!(
4730            matches!(err, EstimationError::ParameterConstraintViolation(_)),
4731            "unexpected error: {err:?}"
4732        );
4733    }
4734
4735    fn crude_risk_quadrature_error(
4736        cumulative_entry: f64,
4737        cumulative_exit: f64,
4738        hazard_exit: f64,
4739    ) -> SurvivalError {
4740        calculate_crude_risk_quadrature(
4741            1.0,
4742            2.0,
4743            &[],
4744            0.4,
4745            0.2,
4746            array![1.0].view(),
4747            array![1.0].view(),
4748            |_, design_d, deriv_d, design_m| {
4749                design_d[0] = 1.0;
4750                deriv_d[0] = 0.0;
4751                design_m[0] = 1.0;
4752                Ok((cumulative_entry, cumulative_exit, hazard_exit))
4753            },
4754        )
4755        .expect_err("invalid hazards should fail")
4756    }
4757
4758    #[test]
4759    fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
4760        let err = crude_risk_quadrature_error(0.1, 0.3, 0.25);
4761        assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
4762    }
4763
4764    #[test]
4765    fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
4766        let err = crude_risk_quadrature_error(0.0, 0.4, 0.25);
4767        assert!(matches!(err, SurvivalError::NonPositiveHazard));
4768    }
4769
4770    #[test]
4771    fn laml_no_penalties_matches_documentedobjective() {
4772        let age_entry = array![40.0, 45.0, 50.0, 55.0];
4773        let age_exit = array![44.0, 49.0, 54.0, 59.0];
4774        let event_target = array![1u8, 0u8, 1u8, 0u8];
4775        let event_competing = Array1::<u8>::zeros(4);
4776        let sampleweight = Array1::ones(4);
4777        let x_entry = array![
4778            [1.0, -0.2, 0.04],
4779            [1.0, -0.1, 0.01],
4780            [1.0, 0.0, 0.0],
4781            [1.0, 0.1, 0.01]
4782        ];
4783        let x_exit = array![
4784            [1.0, -0.12, 0.0144],
4785            [1.0, -0.02, 0.0004],
4786            [1.0, 0.08, 0.0064],
4787            [1.0, 0.18, 0.0324]
4788        ];
4789        let x_derivative = array![
4790            [0.0, 0.02, 0.001],
4791            [0.0, 0.02, 0.001],
4792            [0.0, 0.02, 0.001],
4793            [0.0, 0.02, 0.001]
4794        ];
4795        let penalties = PenaltyBlocks::new(Vec::new());
4796        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4797        let beta = array![-2.0, 0.7, 0.2];
4798
4799        let model = survival_model(
4800            survival_inputs(
4801                &age_entry,
4802                &age_exit,
4803                &event_target,
4804                &event_competing,
4805                &sampleweight,
4806                &x_entry,
4807                &x_exit,
4808                &x_derivative,
4809            ),
4810            penalties,
4811            mono,
4812            SurvivalSpec::Net,
4813        )
4814        .expect("construct survival model");
4815
4816        let state = model.update_state(&beta).expect("state at beta");
4817        let rho = Array1::from_iter(
4818            model
4819                .penalties
4820                .blocks
4821                .iter()
4822                .filter(|b| b.lambda > 0.0)
4823                .map(|b| b.lambda.ln()),
4824        );
4825        let (obj, grad) = model
4826            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4827            .expect("laml objective for no-penalty model");
4828
4829        let h_dense = state.hessian.to_dense();
4830        let logdet_h: f64 = {
4831            use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4832            use gam_linalg::faer_ndarray::FaerEigh;
4833            let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4834            let eps = spectral_epsilon(evals.as_slice().unwrap());
4835            evals
4836                .iter()
4837                .map(|&sigma| spectral_regularize(sigma, eps).ln())
4838                .sum()
4839        };
4840        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * logdet_h;
4841
4842        assert_eq!(grad.len(), 0);
4843        assert!(
4844            (obj - expected).abs() < 1e-10,
4845            "no-penalty LAML objective mismatch: obj={} expected={}",
4846            obj,
4847            expected
4848        );
4849    }
4850
4851    #[test]
4852    fn monotonicity_constraints_collapse_positive_collinearrows() {
4853        let a = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
4854        let b = array![1e-8, 1e-8, 1e-8];
4855
4856        let compressed = compress_positive_collinear_constraints(&a, &b);
4857
4858        assert_eq!(compressed.a.nrows(), 1);
4859        assert_eq!(compressed.a.ncols(), 3);
4860        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4861        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4862        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4863        assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
4864    }
4865
4866    #[test]
4867    fn monotonicity_constraints_preserve_distinct_directions() {
4868        let a = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
4869        let b = array![0.2, 0.3, 0.1];
4870
4871        let compressed = compress_positive_collinear_constraints(&a, &b);
4872
4873        assert_eq!(compressed.a.nrows(), 2);
4874        let mut saw_x = false;
4875        let mut saw_y = false;
4876        for i in 0..compressed.a.nrows() {
4877            if (compressed.a[[i, 0]] - 1.0).abs() <= 1e-12 && compressed.a[[i, 1]].abs() <= 1e-12 {
4878                saw_x = true;
4879                assert!((compressed.b[i] - 0.2).abs() <= 1e-12);
4880            }
4881            if compressed.a[[i, 0]].abs() <= 1e-12 && (compressed.a[[i, 1]] - 1.0).abs() <= 1e-12 {
4882                saw_y = true;
4883                assert!((compressed.b[i] - 0.3).abs() <= 1e-12);
4884            }
4885        }
4886        assert!(saw_x);
4887        assert!(saw_y);
4888    }
4889
4890    #[test]
4891    fn monotonicity_constraints_cluster_near_collinearrows() {
4892        let a = array![
4893            [0.0, 0.5, 0.0],
4894            [0.0, 0.50000000003, 0.0],
4895            [0.0, 0.49999999997, 0.0]
4896        ];
4897        let b = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];
4898
4899        let compressed = compress_positive_collinear_constraints(&a, &b);
4900
4901        assert_eq!(compressed.a.nrows(), 1);
4902        assert_eq!(compressed.a.ncols(), 3);
4903        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4904        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4905        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4906        assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
4907    }
4908
4909    #[test]
4910    fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
4911        let a = array![
4912            [0.0, 0.401, 0.302, 0.197],
4913            [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
4914            [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
4915            [0.0, 0.125, 0.500, 0.375]
4916        ];
4917        let b = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];
4918
4919        let compressed = compress_positive_collinear_constraints(&a, &b);
4920
4921        assert_eq!(compressed.a.nrows(), 2);
4922        let mut clustered_face = false;
4923        let mut distinct_face = false;
4924        for i in 0..compressed.a.nrows() {
4925            let row = compressed.a.row(i);
4926            if row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49 {
4927                clustered_face = true;
4928                assert!((compressed.b[i] - (2.0e-8 / 0.401)).abs() <= 1e-12);
4929            } else {
4930                distinct_face = true;
4931                assert!((row[1] - 0.25).abs() <= 1e-12);
4932                assert!((row[2] - 1.0).abs() <= 1e-12);
4933                assert!((row[3] - 0.75).abs() <= 1e-12);
4934                assert!((compressed.b[i] - 6.0e-8).abs() <= 1e-18);
4935            }
4936        }
4937        assert!(clustered_face);
4938        assert!(distinct_face);
4939    }
4940
4941    #[test]
4942    fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
4943        let age_entry = array![1.0_f64, 1.0, 1.0];
4944        let age_exit = array![2.0_f64, 4.0, 8.0];
4945        let event_target = array![0u8, 1u8, 0u8];
4946        let event_competing = array![0u8, 0u8, 0u8];
4947        let sampleweight = array![1.0, 1.0, 1.0];
4948        let x_entry = array![
4949            [1.0, age_entry[0].ln()],
4950            [1.0, age_entry[1].ln()],
4951            [1.0, age_entry[2].ln()]
4952        ];
4953        let x_exit = array![
4954            [1.0, age_exit[0].ln()],
4955            [1.0, age_exit[1].ln()],
4956            [1.0, age_exit[2].ln()]
4957        ];
4958        let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
4959        let penalties = PenaltyBlocks::new(Vec::new());
4960        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4961
4962        let collocation_offsets = Array1::zeros(x_derivative.nrows());
4963        let mut inputs = survival_inputs(
4964            &age_entry,
4965            &age_exit,
4966            &event_target,
4967            &event_competing,
4968            &sampleweight,
4969            &x_entry,
4970            &x_exit,
4971            &x_derivative,
4972        );
4973        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
4974        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
4975
4976        let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
4977            .expect("construct linear survival model");
4978
4979        let constraints = model
4980            .monotonicity_linear_constraints()
4981            .expect("monotonicity constraints");
4982        assert_eq!(constraints.a.nrows(), 1);
4983        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
4984        assert!((constraints.b[0] - 8e-8).abs() <= 1e-12);
4985    }
4986
4987    #[test]
4988    fn monotonicity_constraints_skip_numericallyzerorows() {
4989        let age_entry = array![1.0_f64, 1.0, 1.0];
4990        let age_exit = array![2.0_f64, 3.0, 4.0];
4991        let event_target = array![0u8, 0u8, 0u8];
4992        let event_competing = array![0u8, 0u8, 0u8];
4993        let sampleweight = array![1.0, 1.0, 1.0];
4994        let x_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
4995        let x_exit = x_entry.clone();
4996        let x_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
4997
4998        let collocation_offsets = Array1::zeros(x_derivative.nrows());
4999        let mut inputs = survival_inputs(
5000            &age_entry,
5001            &age_exit,
5002            &event_target,
5003            &event_competing,
5004            &sampleweight,
5005            &x_entry,
5006            &x_exit,
5007            &x_derivative,
5008        );
5009        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5010        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5011
5012        let model = survival_model(
5013            inputs,
5014            PenaltyBlocks::new(Vec::new()),
5015            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5016            SurvivalSpec::Net,
5017        )
5018        .expect("construct survival model");
5019
5020        let constraints = model
5021            .monotonicity_linear_constraints()
5022            .expect("nonzero derivative row should remain");
5023        assert_eq!(constraints.a.nrows(), 1);
5024        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5025        assert!(constraints.b[0].abs() <= 1e-18);
5026    }
5027
5028    #[test]
5029    fn censoredrows_allowzero_boundary_derivative() {
5030        let age_entry = array![1.0_f64];
5031        let age_exit = array![2.0_f64];
5032        let event_target = array![0u8];
5033        let event_competing = array![0u8];
5034        let sampleweight = array![1.0];
5035        let x_entry = array![[0.0]];
5036        let x_exit = array![[0.0]];
5037        let x_derivative = array![[1.0]];
5038
5039        let model = survival_model(
5040            survival_inputs(
5041                &age_entry,
5042                &age_exit,
5043                &event_target,
5044                &event_competing,
5045                &sampleweight,
5046                &x_entry,
5047                &x_exit,
5048                &x_derivative,
5049            ),
5050            PenaltyBlocks::new(Vec::new()),
5051            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5052            SurvivalSpec::Net,
5053        )
5054        .expect("construct censored survival model");
5055
5056        let state = model
5057            .update_state(&array![0.0])
5058            .expect("censored boundary derivative should remain feasible with zero tolerance");
5059        assert!(state.deviance.is_finite());
5060    }
5061
5062    #[test]
5063    fn eventrows_keep_positive_derivative_constraint() {
5064        let age_entry = array![1.0_f64, 1.0];
5065        let age_exit = array![2.0_f64, 4.0];
5066        let event_target = array![0u8, 1u8];
5067        let event_competing = array![0u8, 0u8];
5068        let sampleweight = array![1.0, 1.0];
5069        let x_entry = array![[0.0], [0.0]];
5070        let x_exit = array![[0.0], [0.0]];
5071        let x_derivative = array![[0.5], [0.25]];
5072
5073        let collocation_offsets = Array1::zeros(x_derivative.nrows());
5074        let mut inputs = survival_inputs(
5075            &age_entry,
5076            &age_exit,
5077            &event_target,
5078            &event_competing,
5079            &sampleweight,
5080            &x_entry,
5081            &x_exit,
5082            &x_derivative,
5083        );
5084        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5085        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5086
5087        let model = survival_model(
5088            inputs,
5089            PenaltyBlocks::new(Vec::new()),
5090            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5091            SurvivalSpec::Net,
5092        )
5093        .expect("construct mixed survival model");
5094
5095        let constraints = model
5096            .monotonicity_linear_constraints()
5097            .expect("event row should induce positive lower bound");
5098        assert_eq!(constraints.a.nrows(), 1);
5099        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
5100        assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
5101    }
5102
5103    #[test]
5104    fn structural_monotonicity_clamps_tiny_negative_roundoff() {
5105        let age_entry = array![1.0_f64];
5106        let age_exit = array![2.0_f64];
5107        let event_target = array![1u8];
5108        let event_competing = array![0u8];
5109        let sampleweight = array![1.0];
5110        let x_entry = array![[0.0]];
5111        let x_exit = array![[0.0]];
5112        let x_derivative = array![[1.0]];
5113        let mut model = survival_model(
5114            survival_inputs(
5115                &age_entry,
5116                &age_exit,
5117                &event_target,
5118                &event_competing,
5119                &sampleweight,
5120                &x_entry,
5121                &x_exit,
5122                &x_derivative,
5123            ),
5124            PenaltyBlocks::new(Vec::new()),
5125            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5126            SurvivalSpec::Net,
5127        )
5128        .expect("construct survival model");
5129        model
5130            .set_structural_monotonicity(true, 1)
5131            .expect("enable structural monotonicity");
5132
5133        let state = model
5134            .update_state(&array![-1e-8])
5135            .expect("tiny structural roundoff should be clamped");
5136        assert!(state.deviance.is_finite());
5137    }
5138
5139    #[test]
5140    fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
5141        let uncompressed_constraints = LinearInequalityConstraints {
5142            a: array![
5143                [0.0, 0.5, 0.0],
5144                [0.0, 1.0 / 3.0, 0.0],
5145                [0.0, 0.2, 0.0],
5146                [0.0, 0.125, 0.0]
5147            ],
5148            b: Array1::from_elem(4, 1e-8),
5149        };
5150        let compressed_constraints = compress_positive_collinear_constraints(
5151            &uncompressed_constraints.a,
5152            &uncompressed_constraints.b,
5153        );
5154
5155        let candidates = [
5156            array![0.0, 1e-9, 0.0],
5157            array![0.0, 4e-8, 0.0],
5158            array![0.0, 8e-8, 0.0],
5159            array![0.0, 2e-7, 1.5],
5160        ];
5161        for beta in candidates {
5162            let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
5163                uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
5164            });
5165            let compressed_ok = (0..compressed_constraints.a.nrows())
5166                .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
5167            assert_eq!(compressed_ok, uncompressed_ok);
5168        }
5169    }
5170
5171    #[test]
5172    fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
5173        let age_entry = array![10.0_f64, 20.0, 25.0];
5174        let age_exit = array![15.0_f64, 30.0, 40.0];
5175        let event_target = array![1u8, 0u8, 1u8];
5176        let event_competing = array![0u8, 0u8, 0u8];
5177        let sampleweight = array![1.0, 2.0, 0.5];
5178        let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
5179        let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
5180        let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
5181        let beta = array![0.8, 1.1, -0.2];
5182
5183        let base_model = survival_model(
5184            survival_inputs(
5185                &age_entry,
5186                &age_exit,
5187                &event_target,
5188                &event_competing,
5189                &sampleweight,
5190                &x_entry,
5191                &x_exit,
5192                &x_derivative,
5193            ),
5194            PenaltyBlocks::new(Vec::new()),
5195            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5196            SurvivalSpec::Net,
5197        )
5198        .expect("construct base survival model");
5199        let base_state = base_model
5200            .update_state(&beta)
5201            .expect("evaluate base survival state");
5202
5203        let time_scale = 365.25;
5204        let scaled_age_entry = age_entry.mapv(|v| v * time_scale);
5205        let scaled_age_exit = age_exit.mapv(|v| v * time_scale);
5206        let scaled_x_derivative = x_derivative.mapv(|v| v / time_scale);
5207        let scaled_model = survival_model(
5208            survival_inputs(
5209                &scaled_age_entry,
5210                &scaled_age_exit,
5211                &event_target,
5212                &event_competing,
5213                &sampleweight,
5214                &x_entry,
5215                &x_exit,
5216                &scaled_x_derivative,
5217            ),
5218            PenaltyBlocks::new(Vec::new()),
5219            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5220            SurvivalSpec::Net,
5221        )
5222        .expect("construct scaled survival model");
5223        let scaled_state = scaled_model
5224            .update_state(&beta)
5225            .expect("evaluate scaled survival state");
5226
5227        let weighted_events = sampleweight
5228            .iter()
5229            .zip(event_target.iter())
5230            .map(|(w, d)| *w * f64::from(*d))
5231            .sum::<f64>();
5232        let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
5233        assert!(
5234            (scaled_state.deviance - base_state.deviance - expected_deviance_shift).abs() <= 1e-10,
5235            "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
5236            scaled_state.deviance,
5237            base_state.deviance
5238        );
5239
5240        for j in 0..beta.len() {
5241            assert!(
5242                (scaled_state.gradient[j] - base_state.gradient[j]).abs() <= 1e-12,
5243                "gradient mismatch at j={j}: scaled={} base={}",
5244                scaled_state.gradient[j],
5245                base_state.gradient[j]
5246            );
5247        }
5248
5249        let base_hessian = base_state.hessian.to_dense();
5250        let scaled_hessian = scaled_state.hessian.to_dense();
5251        for r in 0..beta.len() {
5252            for c in 0..beta.len() {
5253                assert!(
5254                    (scaled_hessian[[r, c]] - base_hessian[[r, c]]).abs() <= 1e-12,
5255                    "hessian mismatch at ({r},{c}): scaled={} base={}",
5256                    scaled_hessian[[r, c]],
5257                    base_hessian[[r, c]]
5258                );
5259            }
5260        }
5261    }
5262}