Skip to main content

gam_models/survival/
base.rs

1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3    BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4    projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9    LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
/// Errors raised by the survival likelihood engines in this module.
///
/// Variants carrying a `reason` display that text verbatim; unit variants
/// display a fixed message. [`SurvivalError::CauseSpecificBlock`] wraps the
/// error of a single cause-specific block together with its 1-based index.
#[derive(Debug, Error)]
pub enum SurvivalError {
    /// Input arrays have inconsistent dimensions.
    #[error("input dimensions are inconsistent")]
    DimensionMismatch,
    /// At least one input value is NaN or infinite.
    #[error("inputs contain non-finite values")]
    NonFiniteInput,
    /// The named survival spec is not handled by the one-hazard engine.
    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
    UnsupportedSpec(&'static str),
    /// The crude-risk integration was configured inconsistently.
    #[error("crude risk integration setup is invalid")]
    InvalidIntegrationSetup,
    /// A time grid was not finite, non-negative, and strictly increasing.
    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
    InvalidTimeGrid,
    /// A cumulative hazard decreased somewhere along the time axis.
    #[error("cumulative hazard must be nondecreasing")]
    NonMonotoneCumulativeHazard,
    /// The instantaneous hazard hit zero or went negative during integration.
    #[error("instantaneous hazard must stay strictly positive during integration")]
    NonPositiveHazard,
    /// Generic invalid input, explained by `reason`.
    #[error("{reason}")]
    InvalidInput { reason: String },
    /// Shape or count mismatch on the cause-specific path (for example a wrong
    /// `beta` length, a block-index out of range, or a block-count mismatch).
    #[error("{reason}")]
    CauseSpecificDimensionMismatch { reason: String },
    /// Numerical breakdown such as a non-finite cumulative hazard or a
    /// non-positive hazard derivative on an event row.
    #[error("{reason}")]
    NumericalFailure { reason: String },
    /// An event code/label is malformed (non-finite, negative, non-integer,
    /// out of range, non-contiguous, or non-binary where binary is required).
    #[error("{reason}")]
    EventCodeInvalid { reason: String },
    /// Degenerate event data, explained by `reason`.
    #[error("{reason}")]
    EventDegenerate { reason: String },
    /// Error from one cause-specific block; `block` is the 1-based block index
    /// and `source` the underlying error.
    #[error("cause-specific survival block {block}: {source}")]
    CauseSpecificBlock {
        block: usize,
        #[source]
        source: Box<SurvivalError>,
    },
}
52
53impl From<SurvivalError> for String {
54    fn from(err: SurvivalError) -> Self {
55        err.to_string()
56    }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60    fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61        SurvivalError::CauseSpecificDimensionMismatch {
62            reason: err.message(),
63        }
64    }
65}
66
/// Survival specification selector; [`SurvivalSpec::Net`] is the default.
///
/// NOTE(review): `Crude` presumably selects the crude-risk integration path
/// referenced by `SurvivalError::InvalidIntegrationSetup` — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SurvivalSpec {
    /// Net survival specification (default).
    #[default]
    Net,
    /// Crude survival specification.
    Crude,
}
73
/// Borrowed inputs for a survival fit whose design matrices are already
/// evaluated at entry time, at exit time, and for the time derivative.
///
/// NOTE(review): the row-indexed views are presumably aligned to the same
/// observations, and the design views to the same coefficient columns —
/// confirm where these inputs are validated.
#[derive(Debug, Clone)]
pub struct SurvivalEngineInputs<'a> {
    /// Entry age per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-event code per row.
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-event code per row.
    pub event_competing: ArrayView1<'a, u8>,
    /// Observation weight per row.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Design rows evaluated at each row's entry time.
    pub x_entry: ArrayView2<'a, f64>,
    /// Design rows evaluated at each row's exit time.
    pub x_exit: ArrayView2<'a, f64>,
    /// Design rows for the time derivative of the linear predictor at exit.
    pub x_derivative: ArrayView2<'a, f64>,
    /// Optional global monotonicity collocation rows for the full coefficient vector.
    /// Non-structural survival models should pass these explicitly instead of
    /// relying on observed derivative rows.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
91
/// Borrowed inputs for a survival fit whose time basis and covariates are
/// supplied separately rather than as one pre-combined design.
///
/// NOTE(review): the engine presumably combines `time_*` with `covariates`
/// to form the full design; the row-indexed views are presumably aligned to
/// the same observations — confirm at the call site.
#[derive(Debug, Clone)]
pub struct SurvivalTimeCovarInputs<'a> {
    /// Entry age per row.
    pub age_entry: ArrayView1<'a, f64>,
    /// Exit age per row.
    pub age_exit: ArrayView1<'a, f64>,
    /// Target-event code per row.
    pub event_target: ArrayView1<'a, u8>,
    /// Competing-event code per row.
    pub event_competing: ArrayView1<'a, u8>,
    /// Observation weight per row.
    pub sampleweight: ArrayView1<'a, f64>,
    /// Time-basis rows evaluated at each row's entry time.
    pub time_entry: ArrayView2<'a, f64>,
    /// Time-basis rows evaluated at each row's exit time.
    pub time_exit: ArrayView2<'a, f64>,
    /// Time-basis derivative rows at each row's exit time.
    pub time_derivative: ArrayView2<'a, f64>,
    /// Covariate rows.
    pub covariates: ArrayView2<'a, f64>,
    /// Optional global monotonicity collocation rows for the full coefficient vector.
    /// Non-structural survival models should pass these explicitly instead of
    /// relying on observed derivative rows.
    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
}
110
/// Parametric baseline target evaluated per row, threaded into the
/// likelihood as fixed offsets so the spline terms model a deviation from it.
#[derive(Debug, Clone)]
pub struct SurvivalBaselineOffsets<'a> {
    /// Baseline target contribution to eta at entry time: eta_target(t_entry).
    pub eta_entry: ArrayView1<'a, f64>,
    /// Baseline target contribution to eta at exit time: eta_target(t_exit).
    pub eta_exit: ArrayView1<'a, f64>,
    /// Baseline target contribution to d eta / d t at exit: eta_target'(t_exit).
    ///
    /// This is used in event terms where log-hazard requires
    /// log(d eta / d t). By threading this as an explicit offset, we get
    /// "parametric default + spline deviation" behavior:
    /// - strong penalty => deviation ~ 0 => model collapses to baseline target,
    /// - weak penalty   => deviation can bend away where data supports it.
    pub derivative_exit: ArrayView1<'a, f64>,
}
126
/// One quadratic smoothing penalty `0.5 · λ · β_rᵀ S β_r` acting on the
/// coefficient slice `β_r = beta[range]`.
#[derive(Debug, Clone)]
pub struct PenaltyBlock {
    /// Penalty matrix `S`; the `PenaltyBlocks` methods require it to be
    /// `range.len() × range.len()`.
    pub matrix: Array2<f64>,
    /// Smoothing parameter `λ`; blocks with `λ == 0.0` are skipped entirely.
    pub lambda: f64,
    /// Coefficient indices the penalty acts on.
    pub range: Range<usize>,
    /// Structural nullspace dimension of this penalty matrix.
    /// Used for exact pseudo-logdet computation. 0 means full rank.
    pub nullspace_dim: usize,
}
136
/// A set of block-local quadratic penalties whose total is
/// `Σ_k 0.5 · λ_k · β_kᵀ S_k β_k` (see `PenaltyBlocks::deviance`).
#[derive(Debug, Clone)]
pub struct PenaltyBlocks {
    /// The individual penalties; overlapping ranges simply add up.
    pub blocks: Vec<PenaltyBlock>,
}
141
142impl PenaltyBlocks {
143    pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144        Self { blocks }
145    }
146
147    pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148        let mut grad = Array1::zeros(beta.len());
149        for block in &self.blocks {
150            if block.lambda == 0.0 {
151                continue;
152            }
153            let b = beta.slice(ndarray::s![block.range.clone()]);
154            let g = block.matrix.dot(&b);
155            let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156            dst += &(block.lambda * g);
157        }
158        grad
159    }
160
161    pub fn hessian(&self, dim: usize) -> Array2<f64> {
162        let mut h = Array2::zeros((dim, dim));
163        self.addhessian_inplace(&mut h);
164        h
165    }
166
167    pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168        let mut value = 0.0;
169        for block in &self.blocks {
170            if block.lambda == 0.0 {
171                continue;
172            }
173            let b = beta.slice(ndarray::s![block.range.clone()]);
174            value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175        }
176        value
177    }
178
179    pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180        for block in &self.blocks {
181            if block.lambda == 0.0 {
182                continue;
183            }
184            let start = block.range.start;
185            let end = block.range.end;
186            h.slice_mut(ndarray::s![start..end, start..end])
187                .scaled_add(block.lambda, &block.matrix);
188        }
189    }
190}
191
/// Entry ages at or below this value are treated as entering at the time
/// origin, i.e. the row has no delayed-entry (left-truncation) interval. For
/// such rows the cumulative-hazard entry term `exp(η_entry)` is dropped,
/// because `H(0) = 0`.
///
/// The Royston-Parmar baseline is `η(t) = log H(t)` with `H(t) → 0` as
/// `t → 0`, so `log H` diverges at the origin. This small positive floor lets a
/// row that genuinely enters at time zero skip the entry contribution instead
/// of evaluating `log H` at a degenerate point. Entry-detection sites test
/// `age_entry > ENTRY_AT_ORIGIN_THRESHOLD`; sharing the constant keeps them in
/// lockstep.
///
/// Public so the fit-orchestration layer can classify a dataset as genuinely
/// left-truncated (`entry > threshold`) with the SAME origin convention the
/// likelihood engines use, and pick the left-truncation-robust time anchor
/// accordingly (issue #1790).
pub const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1e-8;
205
/// Fraction-to-the-boundary factor for the cause-specific feasible-step search
/// (`max_feasible_step_size`).
///
/// When a Newton direction would drive a row's derivative down to the
/// monotonicity floor, the step is capped at this fraction of the distance to
/// the boundary rather than landing exactly on it. Staying strictly inside the
/// feasible region (the standard interior-point fraction-to-boundary rule)
/// keeps the next `1/deriv` / `deriv.ln()` evaluation away from the singular
/// boundary, where curvature blows up.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 0.995;
214
/// Data for one cause (endpoint) of a cause-specific Royston-Parmar model.
///
/// Row-indexed fields share the length `n = event_target.len()`, and the three
/// design matrices share the column count `p = x_exit.ncols()`.
/// `CauseSpecificRoystonParmarFamily::new` rejects blocks that violate this.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarBlock {
    /// Entry age per row. Values at or below [`ENTRY_AT_ORIGIN_THRESHOLD`]
    /// contribute no entry cumulative-hazard term.
    pub age_entry: Array1<f64>,
    /// Exit age per row. It must not be below `age_entry` on rows with positive
    /// weight; this is checked during evaluation.
    pub age_exit: Array1<f64>,
    /// Binary cause indicator in `{0, 1}`: `1` when this cause was observed,
    /// `0` for censoring or a competing cause.
    pub event_target: Array1<u8>,
    /// Finite, non-negative row weight. Rows with weight `<= 0` are skipped.
    pub sampleweight: Array1<f64>,
    /// Design for the log cumulative hazard `η` at entry time.
    pub x_entry: Array2<f64>,
    /// Design for `η` at exit time.
    pub x_exit: Array2<f64>,
    /// Design for the time derivative `η'` at exit time.
    pub x_derivative: Array2<f64>,
    /// Fixed offset added to `x_entry · β`.
    pub offset_eta_entry: Array1<f64>,
    /// Fixed offset added to `x_exit · β`.
    pub offset_eta_exit: Array1<f64>,
    /// Fixed offset added to `x_derivative · β`.
    pub offset_derivative_exit: Array1<f64>,
    /// Finite, non-negative lower bound on `η'`. It is enforced through the
    /// block's linear inequality constraints and the feasible-step search.
    pub derivative_floor: f64,
}
229
/// Cause-specific competing-risks survival as a blockwise custom family.
///
/// Each cause is represented by one `ParameterBlockState`, so endpoint-specific
/// coefficients, shared smoothing labels, and user-defined coefficient groups
/// stay on the existing `CustomFamily` / `BlockwiseFitOptions` joint-fit path.
/// Block `k` pairs with parameter-block state `k`. Causes share no
/// coefficients, so the joint log-likelihood is the sum of the per-cause terms.
#[derive(Debug, Clone)]
pub struct CauseSpecificRoystonParmarFamily {
    /// One data block per cause, in parameter-block order (validated by `new`).
    blocks: Vec<CauseSpecificRoystonParmarBlock>,
}
239
240impl CauseSpecificRoystonParmarFamily {
241    pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
242        if blocks.is_empty() {
243            return Err(SurvivalError::InvalidInput {
244                reason: "cause-specific survival family requires at least one endpoint".to_string(),
245            }
246            .into());
247        }
248        for (idx, block) in blocks.iter().enumerate() {
249            validate_cause_specific_block(block).map_err(|err| {
250                SurvivalError::CauseSpecificBlock {
251                    block: idx + 1,
252                    source: Box::new(err),
253                }
254                .to_string()
255            })?;
256        }
257        Ok(Self { blocks })
258    }
259
260    pub fn cause_count(&self) -> usize {
261        self.blocks.len()
262    }
263}
264
265fn validate_cause_specific_block(
266    block: &CauseSpecificRoystonParmarBlock,
267) -> Result<(), SurvivalError> {
268    let n = block.event_target.len();
269    let p = block.x_exit.ncols();
270    if n == 0 || p == 0 {
271        bail_invalid_surv!("empty event vector or coefficient block");
272    }
273    if block.age_entry.len() != n
274        || block.age_exit.len() != n
275        || block.sampleweight.len() != n
276        || block.x_entry.nrows() != n
277        || block.x_exit.nrows() != n
278        || block.x_derivative.nrows() != n
279        || block.x_entry.ncols() != p
280        || block.x_derivative.ncols() != p
281        || block.offset_eta_entry.len() != n
282        || block.offset_eta_exit.len() != n
283        || block.offset_derivative_exit.len() != n
284    {
285        return Err(SurvivalError::CauseSpecificDimensionMismatch {
286            reason: "dimension mismatch".to_string(),
287        });
288    }
289    // A cause-specific block's `event_target` is the binary cause-k indicator
290    // produced by `cause_specific_event_indicator`; a label > 1 here means the
291    // caller passed raw multi-cause codes instead of projecting per cause. That
292    // is a valid finite label, not non-finite input, so it gets its own clear
293    // error rather than the misleading "non-finite input".
294    if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
295        return Err(SurvivalError::EventCodeInvalid {
296            reason: format!(
297                "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
298            ),
299        });
300    }
301    if block.age_entry.iter().any(|v| !v.is_finite())
302        || block.age_exit.iter().any(|v| !v.is_finite())
303        || block
304            .sampleweight
305            .iter()
306            .any(|v| !v.is_finite() || *v < 0.0)
307        || block.x_entry.iter().any(|v| !v.is_finite())
308        || block.x_exit.iter().any(|v| !v.is_finite())
309        || block.x_derivative.iter().any(|v| !v.is_finite())
310        || block.offset_eta_entry.iter().any(|v| !v.is_finite())
311        || block.offset_eta_exit.iter().any(|v| !v.is_finite())
312        || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
313        || !block.derivative_floor.is_finite()
314        || block.derivative_floor < 0.0
315    {
316        bail_invalid_surv!("non-finite input");
317    }
318    Ok(())
319}
320
/// Evaluates one cause-specific Royston-Parmar block at coefficients `beta`.
///
/// With `η = X β + offset` and cumulative hazard `H = exp(η)`, the weighted
/// log-likelihood is
/// `ℓ = Σ_i w_i [ −(H_i(exit) − H_i(entry)) + δ_i (η_i(exit) + ln η'_i) ]`.
/// Here:
/// - `δ_i = [event_target[i] > 0]`;
/// - `η'_i = x_derivative_i · β + offset_derivative_exit[i]`;
/// - `H_i(entry) = 0` for rows whose entry age is at or below
///   [`ENTRY_AT_ORIGIN_THRESHOLD`].
///
/// Rows with `sampleweight <= 0` are skipped.
///
/// Returns `(ℓ, ∇ℓ, −∇²ℓ)`. The gradient is that of the log-likelihood, while
/// the matrix is the Hessian of the *negative* log-likelihood.
///
/// # Errors
/// - `CauseSpecificDimensionMismatch` if `beta.len() != x_exit.ncols()`.
/// - The `bail_invalid_surv!` error if a weighted row has `age_exit < age_entry`.
/// - `NumericalFailure` if a cumulative hazard is non-finite, or an event row's
///   derivative is not finite and strictly positive.
fn evaluate_cause_specific_block(
    block: &CauseSpecificRoystonParmarBlock,
    beta: &Array1<f64>,
) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
    let n = block.event_target.len();
    let p = block.x_exit.ncols();
    if beta.len() != p {
        return Err(SurvivalError::CauseSpecificDimensionMismatch {
            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
        });
    }
    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
    let mut log_likelihood = 0.0;
    // Per-row weights for the Xᵀv / Xᵀ diag(w) X assembly after the loop.
    let mut w_exit = Array1::<f64>::zeros(n);
    let mut w_entry = Array1::<f64>::zeros(n);
    let mut w_event = Array1::<f64>::zeros(n);
    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
    let mut w_event_outer = Array1::<f64>::zeros(n);

    for i in 0..n {
        let weight = block.sampleweight[i];
        if weight <= 0.0 {
            continue;
        }
        if block.age_exit[i] < block.age_entry[i] {
            bail_invalid_surv!("age_exit < age_entry at row {i}");
        }
        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
        let h_exit = eta_exit[i].exp();
        // Rows entering at the origin have H(entry) = 0, so no entry term.
        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
        if !(h_exit.is_finite() && h_entry.is_finite()) {
            return Err(SurvivalError::NumericalFailure {
                reason: format!("non-finite cumulative hazard at row {i}"),
            });
        }
        log_likelihood -= weight * (h_exit - h_entry);
        w_exit[i] = weight * h_exit;
        w_entry[i] = weight * h_entry;
        if block.event_target[i] > 0 {
            let deriv = derivative[i];
            // log hazard = η + ln η', which needs η' > 0.
            if !(deriv.is_finite() && deriv > 0.0) {
                return Err(SurvivalError::NumericalFailure {
                    reason: format!(
                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
                    ),
                });
            }
            log_likelihood += weight * (eta_exit[i] + deriv.ln());
            w_event[i] = weight;
            w_event_inv_deriv[i] = weight / deriv;
            // Curvature of −ln η' along x_derivative is 1 / η'².
            w_event_outer[i] = weight / (deriv * deriv);
        }
    }

    // Gradient of the negative log-likelihood; negated below.
    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
    let gradient = -nll_gradient;

    // Hessian of the negative log-likelihood.
    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
    Ok((log_likelihood, gradient, hessian))
}
388
389impl CustomFamily for CauseSpecificRoystonParmarFamily {
390    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
391    // flat-prior exact-Newton objective carries no Jeffreys term), so families
392    // that historically armed the term by default opt back in explicitly.
393    fn joint_jeffreys_term_required(&self) -> bool {
394        true
395    }
396
397    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
398        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
399            "cause-specific survival",
400            self.blocks.len(),
401            block_states.len(),
402        )?;
403        let mut log_likelihood = 0.0;
404        let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
405        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
406            let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
407            log_likelihood += ll;
408            blockworking_sets.push(BlockWorkingSet::ExactNewton {
409                gradient,
410                hessian: SymmetricMatrix::Dense(hessian),
411            });
412        }
413        Ok(FamilyEvaluation {
414            log_likelihood,
415            blockworking_sets,
416        })
417    }
418
419    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
420        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
421            "cause-specific survival",
422            self.blocks.len(),
423            block_states.len(),
424        )?;
425        let mut log_likelihood = 0.0;
426        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
427            let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
428            log_likelihood += ll;
429        }
430        Ok(log_likelihood)
431    }
432
433    fn likelihood_blocks_uncoupled(&self) -> bool {
434        true
435    }
436
437    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
438        true
439    }
440
441    fn output_channel_assignment(
442        &self,
443        specs: &[crate::custom_family::ParameterBlockSpec],
444    ) -> Option<Vec<usize>> {
445        if specs.len() != self.blocks.len() {
446            return Some((0..self.blocks.len()).collect());
447        }
448        Some((0..specs.len()).collect())
449    }
450
451    fn coefficient_hessian_cost(
452        &self,
453        specs: &[crate::custom_family::ParameterBlockSpec],
454    ) -> u64 {
455        crate::custom_family::default_coefficient_hessian_cost(specs)
456    }
457
458    fn block_linear_constraints(
459        &self,
460        _: &[ParameterBlockState],
461        block_idx: usize,
462        spec: &crate::custom_family::ParameterBlockSpec,
463    ) -> Result<Option<LinearInequalityConstraints>, String> {
464        let block = self.blocks.get(block_idx).ok_or_else(|| {
465            SurvivalError::CauseSpecificDimensionMismatch {
466                reason: format!(
467                    "cause-specific survival expected block index < {}, got {block_idx}",
468                    self.blocks.len()
469                ),
470            }
471            .to_string()
472        })?;
473        if block.x_derivative.ncols() != spec.design.ncols() {
474            return Err(SurvivalError::CauseSpecificDimensionMismatch {
475                reason: format!(
476                    "cause-specific survival derivative design has {} columns but block '{}' has {}",
477                    block.x_derivative.ncols(),
478                    spec.name,
479                    spec.design.ncols()
480                ),
481            }
482            .into());
483        }
484        let rhs = block
485            .offset_derivative_exit
486            .mapv(|offset| block.derivative_floor - offset);
487        Ok(Some(LinearInequalityConstraints {
488            a: block.x_derivative.clone(),
489            b: rhs,
490        }))
491    }
492
493    fn max_feasible_step_size(
494        &self,
495        block_states: &[ParameterBlockState],
496        block_idx: usize,
497        delta: &Array1<f64>,
498    ) -> Result<Option<f64>, String> {
499        let block = self.blocks.get(block_idx).ok_or_else(|| {
500            SurvivalError::CauseSpecificDimensionMismatch {
501                reason: format!(
502                    "cause-specific survival expected block index < {}, got {block_idx}",
503                    self.blocks.len()
504                ),
505            }
506            .to_string()
507        })?;
508        let state = block_states.get(block_idx).ok_or_else(|| {
509            SurvivalError::CauseSpecificDimensionMismatch {
510                reason: format!(
511                    "cause-specific survival expected {} block states, got {}",
512                    self.blocks.len(),
513                    block_states.len()
514                ),
515            }
516            .to_string()
517        })?;
518        if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
519            return Err(SurvivalError::CauseSpecificDimensionMismatch {
520                reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
521            }
522            .into());
523        }
524        let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
525        let derivative_delta = fast_av(&block.x_derivative, delta);
526        let mut alpha_max = 1.0_f64;
527        for i in 0..derivative.len() {
528            if block.sampleweight[i] <= 0.0 {
529                continue;
530            }
531            let current = derivative[i] - block.derivative_floor;
532            let slope = derivative_delta[i];
533            if slope < 0.0 {
534                if current <= 0.0 {
535                    return Ok(Some(0.0));
536                }
537                alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
538            }
539        }
540        Ok(Some(alpha_max.clamp(0.0, 1.0)))
541    }
542
543    fn exact_newton_hessian_directional_derivative(
544        &self,
545        block_states: &[ParameterBlockState],
546        block_idx: usize,
547        d_beta: &Array1<f64>,
548    ) -> Result<Option<Array2<f64>>, String> {
549        let block = self.blocks.get(block_idx).ok_or_else(|| {
550            SurvivalError::CauseSpecificDimensionMismatch {
551                reason: format!(
552                    "cause-specific survival expected block index < {}, got {block_idx}",
553                    self.blocks.len()
554                ),
555            }
556            .to_string()
557        })?;
558        let state = block_states.get(block_idx).ok_or_else(|| {
559            SurvivalError::CauseSpecificDimensionMismatch {
560                reason: format!(
561                    "cause-specific survival expected {} block states, got {}",
562                    self.blocks.len(),
563                    block_states.len()
564                ),
565            }
566            .to_string()
567        })?;
568        Ok(Some(cause_specific_hessian_directional_derivative(
569            block,
570            &state.beta,
571            d_beta,
572        )?))
573    }
574
575    fn exact_newton_hessian_second_directional_derivative(
576        &self,
577        block_states: &[ParameterBlockState],
578        block_idx: usize,
579        d_beta_u: &Array1<f64>,
580        d_beta_v: &Array1<f64>,
581    ) -> Result<Option<Array2<f64>>, String> {
582        let block = self.blocks.get(block_idx).ok_or_else(|| {
583            SurvivalError::CauseSpecificDimensionMismatch {
584                reason: format!(
585                    "cause-specific survival expected block index < {}, got {block_idx}",
586                    self.blocks.len()
587                ),
588            }
589            .to_string()
590        })?;
591        let state = block_states.get(block_idx).ok_or_else(|| {
592            SurvivalError::CauseSpecificDimensionMismatch {
593                reason: format!(
594                    "cause-specific survival expected {} block states, got {}",
595                    self.blocks.len(),
596                    block_states.len()
597                ),
598            }
599            .to_string()
600        })?;
601        Ok(Some(cause_specific_hessian_second_directional_derivative(
602            block,
603            &state.beta,
604            d_beta_u,
605            d_beta_v,
606        )?))
607    }
608}
609
610fn cause_specific_hessian_directional_derivative(
611    block: &CauseSpecificRoystonParmarBlock,
612    beta: &Array1<f64>,
613    d_beta: &Array1<f64>,
614) -> Result<Array2<f64>, SurvivalError> {
615    let p = block.x_exit.ncols();
616    if beta.len() != p || d_beta.len() != p {
617        return Err(SurvivalError::CauseSpecificDimensionMismatch {
618            reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
619        });
620    }
621    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
622    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
623    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
624    let d_eta_entry = fast_av(&block.x_entry, d_beta);
625    let d_eta_exit = fast_av(&block.x_exit, d_beta);
626    let d_derivative = fast_av(&block.x_derivative, d_beta);
627    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
628    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
629    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
630
631    for i in 0..block.event_target.len() {
632        let weight = block.sampleweight[i];
633        if weight <= 0.0 {
634            continue;
635        }
636        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
637        w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
638        if has_entry {
639            w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
640        }
641        if block.event_target[i] > 0 {
642            let deriv = derivative[i];
643            if !(deriv.is_finite() && deriv > 0.0) {
644                return Err(SurvivalError::NumericalFailure {
645                    reason: format!(
646                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
647                    ),
648                });
649            }
650            w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
651        }
652    }
653
654    let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
655    d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
656    d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
657    Ok(d_hessian)
658}
659
660fn cause_specific_hessian_second_directional_derivative(
661    block: &CauseSpecificRoystonParmarBlock,
662    beta: &Array1<f64>,
663    d_beta_u: &Array1<f64>,
664    d_beta_v: &Array1<f64>,
665) -> Result<Array2<f64>, SurvivalError> {
666    let p = block.x_exit.ncols();
667    if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
668        return Err(SurvivalError::CauseSpecificDimensionMismatch {
669            reason: "cause-specific survival second Hessian derivative dimension mismatch"
670                .to_string(),
671        });
672    }
673    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
674    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
675    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
676    let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
677    let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
678    let u_derivative = fast_av(&block.x_derivative, d_beta_u);
679    let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
680    let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
681    let v_derivative = fast_av(&block.x_derivative, d_beta_v);
682    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
683    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
684    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
685
686    for i in 0..block.event_target.len() {
687        let weight = block.sampleweight[i];
688        if weight <= 0.0 {
689            continue;
690        }
691        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
692        w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
693        if has_entry {
694            w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
695        }
696        if block.event_target[i] > 0 {
697            let deriv = derivative[i];
698            if !(deriv.is_finite() && deriv > 0.0) {
699                return Err(SurvivalError::NumericalFailure {
700                    reason: format!(
701                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
702                    ),
703                });
704            }
705            w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
706        }
707    }
708
709    let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
710    d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
711    d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
712    Ok(d2_hessian)
713}
714
715pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
716    const INTEGER_TOL: f64 = 1e-8;
717    const MAX_AUTO_CAUSES: u8 = 32;
718    if !value.is_finite() {
719        return Err(SurvivalError::EventCodeInvalid {
720            reason: format!(
721                "survival event value at row {} is non-finite",
722                row_index + 1
723            ),
724        }
725        .into());
726    }
727    if value < 0.0 {
728        return Err(SurvivalError::EventCodeInvalid {
729            reason: format!(
730                "survival event value at row {} is negative: {value}",
731                row_index + 1
732            ),
733        }
734        .into());
735    }
736    let rounded = value.round();
737    if (value - rounded).abs() > INTEGER_TOL {
738        return Err(SurvivalError::EventCodeInvalid {
739            reason: format!(
740                "survival event value at row {} must be an integer code with 0=censored, got {value}",
741                row_index + 1
742            ),
743        }
744        .into());
745    }
746    if rounded > f64::from(MAX_AUTO_CAUSES) {
747        return Err(SurvivalError::EventCodeInvalid {
748            reason: format!(
749                "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
750                row_index + 1
751            ),
752        }
753        .into());
754    }
755    Ok(rounded as u8)
756}
757
758pub fn cause_count_from_event_codes(
759    event_codes: ArrayView1<'_, u8>,
760) -> Result<usize, SurvivalError> {
761    let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
762    if max_code == 0 {
763        return Ok(1);
764    }
765
766    let mut present = vec![false; max_code + 1];
767    for code in event_codes.iter().copied() {
768        present[usize::from(code)] = true;
769    }
770    if (1..=max_code).any(|code| !present[code]) {
771        let actual = present
772            .iter()
773            .enumerate()
774            .skip(1)
775            .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
776            .collect::<Vec<_>>()
777            .join(", ");
778        return Err(SurvivalError::EventCodeInvalid {
779            reason: format!(
780                "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
781            ),
782        });
783    }
784
785    Ok(max_code)
786}
787
788/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
789/// onto the binary `{0, 1}` *any-event* indicator the pooled single-hazard
790/// baseline engine consumes.
791///
792/// The shared Royston-Parmar baseline working model fits one hazard across all
793/// causes; every observed event (regardless of cause) informs that baseline, so
794/// the indicator is `1` exactly when *any* cause occurred. This is one of the
795/// two cause-aware projections of the raw code vector; the other is
796/// [`cause_specific_event_indicator`]. Centralizing both here keeps the single
797/// source of truth for "how multi-cause labels become a single-hazard binary
798/// contract", so no construction path open-codes a fragile `mapv` and then
799/// trips the binary `event_target > 1` guard on the raw labels.
800pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
801    event_codes.mapv(|label| u8::from(label > 0))
802}
803
804/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
805/// onto the binary `{0, 1}` indicator for the cause-specific Royston-Parmar
806/// block of cause `cause` (1-based).
807///
808/// Within cause `cause`'s block the event of interest is `event == cause`; every
809/// competing cause is treated as censoring (indicator `0`), which is exactly the
810/// cause-specific hazard likelihood. Like [`pooled_any_event_indicator`], this
811/// yields a binary contract that satisfies the single-hazard `event_target` guard
812/// — the raw multi-cause labels are never handed to a binary engine.
813pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
814    let cause_code = cause as u8;
815    event_codes.mapv(|observed| u8::from(observed == cause_code))
816}
817
818fn compress_positive_collinear_constraints(
819    a: &Array2<f64>,
820    b: &Array1<f64>,
821) -> LinearInequalityConstraints {
822    const SCALE_TOL: f64 = 1e-14;
823    const KEY_TOL: f64 = 1e-8;
824
825    let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
826    let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
827
828    for i in 0..a.nrows() {
829        let row = a.row(i);
830        let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
831        if !scale.is_finite() || scale <= SCALE_TOL {
832            if b[i] > 0.0 {
833                fallbackrows.push((row.to_vec(), b[i]));
834            }
835            continue;
836        }
837
838        let normalizedrow: Vec<f64> = row
839            .iter()
840            .map(|&v| {
841                let scaled = v / scale;
842                if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
843            })
844            .collect();
845        let normalized_rhs = b[i] / scale;
846        let key: Vec<i64> = normalizedrow
847            .iter()
848            .map(|&v| (v / KEY_TOL).round() as i64)
849            .collect();
850
851        match grouped.get_mut(&key) {
852            Some((_, rhs_max)) => {
853                if normalized_rhs > *rhs_max {
854                    *rhs_max = normalized_rhs;
855                }
856            }
857            None => {
858                grouped.insert(key, (normalizedrow, normalized_rhs));
859            }
860        }
861    }
862
863    let nrows = grouped.len() + fallbackrows.len();
864    let n_cols = a.ncols();
865    let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
866    let mut b_out = Array1::<f64>::zeros(nrows);
867
868    let mut outrow = 0usize;
869    for (_, (row, rhs)) in grouped {
870        for (j, value) in row.into_iter().enumerate() {
871            a_out[[outrow, j]] = value;
872        }
873        b_out[outrow] = rhs;
874        outrow += 1;
875    }
876    for (row, rhs) in fallbackrows {
877        for (j, value) in row.into_iter().enumerate() {
878            a_out[[outrow, j]] = value;
879        }
880        b_out[outrow] = rhs;
881        outrow += 1;
882    }
883
884    LinearInequalityConstraints { a: a_out, b: b_out }
885}
886
/// Monotonicity settings for the survival working model.
///
/// `tolerance` is the minimum derivative value the monotonicity guard demands
/// when the model is *not* structurally monotonic. The guard clamps negative
/// tolerances to `0.0`. The default tolerance is `0.0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct SurvivalMonotonicityPenalty {
    pub tolerance: f64,
}
891
892#[derive(Debug, Clone)]
893enum SurvivalDesign {
894    Flat {
895        x_entry: Array2<f64>,
896        x_exit: Array2<f64>,
897        x_derivative: Array2<f64>,
898    },
899    TimeCovariateShared {
900        time_entry: Array2<f64>,
901        time_exit: Array2<f64>,
902        time_derivative: Array2<f64>,
903        covariates: Array2<f64>,
904    },
905}
906
907impl SurvivalDesign {
908    fn p_total(&self) -> usize {
909        match self {
910            Self::Flat { x_exit, .. } => x_exit.ncols(),
911            Self::TimeCovariateShared {
912                time_exit,
913                covariates,
914                ..
915            } => time_exit.ncols() + covariates.ncols(),
916        }
917    }
918
919    fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
920        match self {
921            Self::Flat { .. } => time_mat.dot(beta),
922            Self::TimeCovariateShared { covariates, .. } => {
923                let p_time = time_mat.ncols();
924                let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
925                if covariates.ncols() > 0 {
926                    out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
927                }
928                out
929            }
930        }
931    }
932
933    fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
934        match self {
935            Self::Flat { .. } => {
936                for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
937                    *dst = src;
938                }
939            }
940            Self::TimeCovariateShared { covariates, .. } => {
941                let p_time = time_mat.ncols();
942                for j in 0..p_time {
943                    out[j] = time_mat[[i, j]];
944                }
945                for j in 0..covariates.ncols() {
946                    out[p_time + j] = covariates[[i, j]];
947                }
948            }
949        }
950    }
951}
952
953/// Pre-allocated workspace buffers for `update_state` to avoid per-iteration allocations.
954#[derive(Debug, Clone)]
955struct SurvivalWorkspace {
956    w_event: Array1<f64>,
957    w_event_inv_deriv: Array1<f64>,
958    w_event_outer: Array1<f64>,
959    w_hess_exit: Array1<f64>,
960    w_hess_entry: Array1<f64>,
961}
962
963impl SurvivalWorkspace {
964    fn new(n: usize) -> Self {
965        Self {
966            w_event: Array1::zeros(n),
967            w_event_inv_deriv: Array1::zeros(n),
968            w_event_outer: Array1::zeros(n),
969            w_hess_exit: Array1::zeros(n),
970            w_hess_entry: Array1::zeros(n),
971        }
972    }
973
974    fn reset(&mut self, n: usize) {
975        if self.w_event.len() != n {
976            *self = Self::new(n);
977        } else {
978            self.w_event.fill(0.0);
979            self.w_event_inv_deriv.fill(0.0);
980            self.w_event_outer.fill(0.0);
981            self.w_hess_exit.fill(0.0);
982            self.w_hess_entry.fill(0.0);
983        }
984    }
985}
986
987/// Per-observation gradients of the unpenalized survival NLL with respect
988/// to each additive offset channel, at a given β. See
989/// [`WorkingModelSurvival::offset_channel_residuals`] for the algebra.
990///
991/// Contract: all four arrays have length `n` = number of observations.
992/// Rows with non-positive sampleweight are 0 in every channel. The
993/// `derivative` channel is 0 in all non-event rows. The `right` channel is
994/// the interval upper-bound (`R`) η-offset sensitivity and is exactly 0 for
995/// every NON-interval-censored model and every non-interval row of the latent
996/// interval model (only the dedicated `SurvInterval(L, R, event)` latent fit
997/// populates it); the baseline-θ chain rule contracts it against the
998/// `age_right`-evaluated η-partial.
999#[derive(Clone, Debug)]
1000pub struct OffsetChannelResiduals {
1001    /// ∂NLL/∂o_X: w·(exp(η_exit) − δ) per row.
1002    pub exit: Array1<f64>,
1003    /// ∂NLL/∂o_E: −w·exp(η_entry) if row has a positive entry interval else 0.
1004    pub entry: Array1<f64>,
1005    /// ∂NLL/∂o_D: −w·δ / s (event-row only).
1006    pub derivative: Array1<f64>,
1007    /// ∂NLL/∂o_R: interval upper-bound (`R`) η-offset sensitivity,
1008    /// `−w·∂(log-lik)/∂q_right`. Nonzero only for interval-censored latent
1009    /// rows; exactly 0 for every other channel/model.
1010    pub right: Array1<f64>,
1011}
1012
/// Row-wise second derivatives of the unpenalized survival negative
/// log-likelihood with respect to the additive offset channels.
///
/// Both axes of every 3×3 matrix are ordered `(entry, exit, derivative)`.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    /// One 3×3 Hessian per observation.
    pub rows: Vec<[[f64; 3]; 3]>,
}
1019
1020#[derive(Debug)]
1021pub struct WorkingModelSurvival {
1022    age_entry: Array1<f64>,
1023    age_exit: Array1<f64>,
1024    entry_at_origin: Array1<bool>,
1025    event_target: Array1<u8>,
1026    sampleweight: Array1<f64>,
1027    design: SurvivalDesign,
1028    offset_eta_entry: Array1<f64>,
1029    offset_eta_exit: Array1<f64>,
1030    offset_derivative_exit: Array1<f64>,
1031    penalties: PenaltyBlocks,
1032    monotonicity: SurvivalMonotonicityPenalty,
1033    structurally_monotonic: bool,
1034    structural_time_columns: usize,
1035    monotonicity_constraint_rows: Option<Array2<f64>>,
1036    monotonicity_constraint_offsets: Option<Array1<f64>>,
1037    workspace: std::sync::Mutex<SurvivalWorkspace>,
1038}
1039
1040impl Clone for WorkingModelSurvival {
1041    fn clone(&self) -> Self {
1042        let workspace = self.workspace.lock().unwrap().clone();
1043        Self {
1044            age_entry: self.age_entry.clone(),
1045            age_exit: self.age_exit.clone(),
1046            entry_at_origin: self.entry_at_origin.clone(),
1047            event_target: self.event_target.clone(),
1048            sampleweight: self.sampleweight.clone(),
1049            design: self.design.clone(),
1050            offset_eta_entry: self.offset_eta_entry.clone(),
1051            offset_eta_exit: self.offset_eta_exit.clone(),
1052            offset_derivative_exit: self.offset_derivative_exit.clone(),
1053            penalties: self.penalties.clone(),
1054            monotonicity: self.monotonicity,
1055            structurally_monotonic: self.structurally_monotonic,
1056            structural_time_columns: self.structural_time_columns,
1057            monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1058            monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1059            workspace: std::sync::Mutex::new(workspace),
1060        }
1061    }
1062}
1063
1064impl WorkingModelSurvival {
1065    const LOG_F64_MAX: f64 = 709.782712893384;
1066
1067    #[inline]
1068    fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
1069        if base == 0.0 {
1070            return Ok(0.0);
1071        }
1072        let log_abs = log_scale + base.abs().ln();
1073        if !log_abs.is_finite() {
1074            crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
1075        }
1076        if log_abs > Self::LOG_F64_MAX {
1077            crate::bail_invalid_estim!(
1078                "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
1079            );
1080        }
1081        Ok(base.signum() * log_abs.exp())
1082    }
1083
1084    fn coefficient_dim(&self) -> usize {
1085        self.design.p_total()
1086    }
1087
1088    fn nrows(&self) -> usize {
1089        self.sampleweight.len()
1090    }
1091
1092    fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1093        let time_mat = match &self.design {
1094            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1095            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1096        };
1097        self.design.design_dot(time_mat, beta)
1098    }
1099
1100    fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1101        let time_mat = match &self.design {
1102            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1103            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1104        };
1105        self.design.design_dot(time_mat, beta)
1106    }
1107
1108    fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1109        match &self.design {
1110            SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1111            SurvivalDesign::TimeCovariateShared {
1112                time_derivative, ..
1113            } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1114        }
1115    }
1116
1117    fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1118        let time_mat = match &self.design {
1119            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1120            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1121        };
1122        self.design.fill_row(time_mat, i, out);
1123    }
1124
1125    fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1126        let time_mat = match &self.design {
1127            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1128            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1129        };
1130        self.design.fill_row(time_mat, i, out);
1131    }
1132
1133    fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1134        match &self.design {
1135            SurvivalDesign::Flat { x_derivative, .. } => {
1136                for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1137                    *dst = src;
1138                }
1139            }
1140            SurvivalDesign::TimeCovariateShared {
1141                time_derivative, ..
1142            } => {
1143                let p_time = time_derivative.ncols();
1144                for j in 0..p_time {
1145                    out[j] = time_derivative[[i, j]];
1146                }
1147                for dst in out.iter_mut().skip(p_time) {
1148                    *dst = 0.0;
1149                }
1150            }
1151        }
1152    }
1153
1154    fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1155        match &self.design {
1156            SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1157            SurvivalDesign::TimeCovariateShared {
1158                time_derivative,
1159                covariates,
1160                ..
1161            } => {
1162                let p_time = time_derivative.ncols();
1163                let p_cov = covariates.ncols();
1164                let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1165                let time_block = fast_xt_diag_x(time_derivative, weights);
1166                out.slice_mut(ndarray::s![..p_time, ..p_time])
1167                    .assign(&time_block);
1168                out
1169            }
1170        }
1171    }
1172
1173    /// Compute the full p×p Hessian contribution for the interval terms:
1174    ///   H = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1175    /// using faer-accelerated BLAS on the stored design matrix blocks.
1176    fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1177        match &self.design {
1178            SurvivalDesign::Flat {
1179                x_entry, x_exit, ..
1180            } => {
1181                let mut h = fast_xt_diag_x(x_exit, w_exit);
1182                h -= &fast_xt_diag_x(x_entry, w_entry);
1183                h
1184            }
1185            SurvivalDesign::TimeCovariateShared {
1186                time_entry,
1187                time_exit,
1188                covariates,
1189                ..
1190            } => {
1191                let p_time = time_exit.ncols();
1192                let p_cov = covariates.ncols();
1193                let p = p_time + p_cov;
1194                let mut h = Array2::<f64>::zeros((p, p));
1195                // time-time block: T_exit^T W_exit T_exit - T_entry^T W_entry T_entry
1196                let tt = {
1197                    let mut block = fast_xt_diag_x(time_exit, w_exit);
1198                    block -= &fast_xt_diag_x(time_entry, w_entry);
1199                    block
1200                };
1201                h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1202                if p_cov > 0 {
1203                    // time-cov block: T_exit^T W_exit C - T_entry^T W_entry C
1204                    let tc = {
1205                        let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1206                        block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1207                        block
1208                    };
1209                    h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1210                    h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1211                    // cov-cov block: C^T (W_exit - W_entry) C
1212                    let w_diff = w_exit - w_entry;
1213                    let cc = fast_xt_diag_x(covariates, &w_diff);
1214                    h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1215                }
1216                h
1217            }
1218        }
1219    }
1220
1221    fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1222        const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1223        if !self.structurally_monotonic {
1224            return None;
1225        }
1226        if deriv >= 1e-12 {
1227            return Some(deriv);
1228        }
1229        if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1230            return Some(1e-12);
1231        }
1232        None
1233    }
1234
1235    fn validate_penalties(
1236        penalties: &PenaltyBlocks,
1237        coefficient_dim: usize,
1238    ) -> Result<(), SurvivalError> {
1239        for block in &penalties.blocks {
1240            if !block.lambda.is_finite() || block.lambda < 0.0 {
1241                return Err(SurvivalError::NonFiniteInput);
1242            }
1243            if block.range.start > block.range.end || block.range.end > coefficient_dim {
1244                return Err(SurvivalError::DimensionMismatch);
1245            }
1246            let block_dim = block.range.end - block.range.start;
1247            if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1248                return Err(SurvivalError::DimensionMismatch);
1249            }
1250            if block.matrix.iter().any(|v| !v.is_finite()) {
1251                return Err(SurvivalError::NonFiniteInput);
1252            }
1253        }
1254        Ok(())
1255    }
1256
1257    fn derivative_guard(&self) -> f64 {
1258        if self.structurally_monotonic {
1259            // I-spline basis is monotone by construction when coefficients ≥ 0.
1260            // A derivative of zero (flat hazard) is valid, so the guard only
1261            // rejects genuinely negative derivatives from floating-point noise.
1262            return 0.0;
1263        }
1264        self.monotonicity.tolerance.max(0.0)
1265    }
1266
1267    fn derivative_guard_numerical(&self) -> f64 {
1268        let derivative_guard = self.derivative_guard();
1269        if derivative_guard <= 0.0 {
1270            // For structural monotonicity (guard = 0), tiny negative derivs are
1271            // tolerated because `stabilized_structural_derivative` lifts the
1272            // value back to a small positive floor before any `ln`/`1/deriv`
1273            // use. For *non-structural* monotonicity with tolerance == 0 the
1274            // raw derivative flows straight through into the event-row
1275            // `deriv.ln()` and `1.0 / deriv`, so any non-positive value would
1276            // produce NaN / huge negative weights. Keep the slack only when
1277            // the structural stabilizer is active.
1278            if self.structurally_monotonic {
1279                -1e-10
1280            } else {
1281                1e-12
1282            }
1283        } else {
1284            (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1285        }
1286    }
1287
1288    fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1289        let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1290        1e-10 * scale
1291    }
1292
1293    fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1294        if !self.structurally_monotonic {
1295            return None;
1296        }
1297        let p = self.coefficient_dim();
1298        let time_columns = self.structural_time_columns.min(p);
1299        if time_columns == 0 {
1300            return None;
1301        }
1302        const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1303        let mut active_columns = vec![false; time_columns];
1304        let mut derivative_row = vec![0.0_f64; p];
1305        for i in 0..self.nrows() {
1306            if self.sampleweight[i] <= 0.0 {
1307                continue;
1308            }
1309            self.fill_derivative_row(i, &mut derivative_row);
1310            for j in 0..time_columns {
1311                if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1312                    active_columns[j] = true;
1313                }
1314            }
1315        }
1316        if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1317            for i in 0..rows.nrows() {
1318                for j in 0..time_columns {
1319                    if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1320                        active_columns[j] = true;
1321                    }
1322                }
1323            }
1324        }
1325        let active_columns: Vec<usize> = active_columns
1326            .into_iter()
1327            .enumerate()
1328            .filter_map(|(j, active)| active.then_some(j))
1329            .collect();
1330        if active_columns.is_empty() {
1331            return None;
1332        }
1333        let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1334        let b = Array1::<f64>::zeros(active_columns.len());
1335        for (row, &col) in active_columns.iter().enumerate() {
1336            a[[row, col]] = 1.0;
1337        }
1338        Some(LinearInequalityConstraints { a, b })
1339    }
1340
1341    pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1342        let p = self.coefficient_dim();
1343        const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1344        if p == 0 {
1345            return None;
1346        }
1347        if self.structurally_monotonic {
1348            return self.structural_time_coefficient_constraints();
1349        }
1350        if let (Some(rows), Some(offsets)) = (
1351            self.monotonicity_constraint_rows.as_ref(),
1352            self.monotonicity_constraint_offsets.as_ref(),
1353        ) {
1354            let activerows: Vec<usize> = (0..rows.nrows())
1355                .filter(|&i| {
1356                    rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1357                        > DERIVATIVE_ROW_NORM_TOL
1358                })
1359                .collect();
1360            if activerows.is_empty() {
1361                return None;
1362            }
1363            let mut a = Array2::<f64>::zeros((activerows.len(), p));
1364            let mut b = Array1::<f64>::zeros(activerows.len());
1365            for (r, &i) in activerows.iter().enumerate() {
1366                a.row_mut(r).assign(&rows.row(i));
1367                b[r] = self.derivative_guard() - offsets[i];
1368            }
1369            return Some(compress_positive_collinear_constraints(&a, &b));
1370        }
1371        None
1372    }
1373
1374    pub fn from_engine_inputs(
1375        inputs: SurvivalEngineInputs<'_>,
1376        penalties: PenaltyBlocks,
1377        monotonicity: SurvivalMonotonicityPenalty,
1378        spec: SurvivalSpec,
1379    ) -> Result<Self, SurvivalError> {
1380        Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
1381    }
1382
1383    fn validate_offsets(
1384        offsets: Option<SurvivalBaselineOffsets<'_>>,
1385        n: usize,
1386    ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1387        if let Some(off) = offsets {
1388            if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1389            {
1390                return Err(SurvivalError::DimensionMismatch);
1391            }
1392            if off.eta_entry.iter().any(|v| !v.is_finite())
1393                || off.eta_exit.iter().any(|v| !v.is_finite())
1394                || off.derivative_exit.iter().any(|v| !v.is_finite())
1395            {
1396                return Err(SurvivalError::NonFiniteInput);
1397            }
1398            Ok((
1399                off.eta_entry.to_owned(),
1400                off.eta_exit.to_owned(),
1401                off.derivative_exit.to_owned(),
1402            ))
1403        } else {
1404            Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1405        }
1406    }
1407
1408    fn validate_common_inputs(
1409        age_entry: &ArrayView1<f64>,
1410        age_exit: &ArrayView1<f64>,
1411        event_target: &ArrayView1<u8>,
1412        event_competing: &ArrayView1<u8>,
1413        sampleweight: &ArrayView1<f64>,
1414    ) -> Result<(), SurvivalError> {
1415        if age_entry.iter().any(|v| !v.is_finite())
1416            || age_exit.iter().any(|v| !v.is_finite())
1417            || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1418        {
1419            return Err(SurvivalError::NonFiniteInput);
1420        }
1421        // The single-hazard engine's `event_target` contract is binary {0, 1}.
1422        // A code > 1 is a *valid finite multi-cause label* that simply must be
1423        // projected first (any-event for the pooled baseline, cause-specific for
1424        // each block); it is NOT a non-finite input. Report it as such so the
1425        // failure is actionable and never surfaces as the misleading "inputs
1426        // contain non-finite values".
1427        if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1428            return Err(SurvivalError::EventCodeInvalid {
1429                reason: format!(
1430                    "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1431                ),
1432            });
1433        }
1434        if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1435            return Err(SurvivalError::EventCodeInvalid {
1436                reason: format!(
1437                    "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1438                ),
1439            });
1440        }
1441        if event_target
1442            .iter()
1443            .zip(event_competing.iter())
1444            .any(|(&target, &competing)| target > 0 && competing > 0)
1445        {
1446            return Err(SurvivalError::EventCodeInvalid {
1447                reason: "a row cannot be simultaneously a target event and a competing event"
1448                    .to_string(),
1449            });
1450        }
1451        // The "must have at least one target event" requirement is a
1452        // *fittability* check, not a structural one: with all rows censored the
1453        // likelihood has no event score, so any subsequent fit cannot identify
1454        // the hazard and the optimizer spins on a flat landscape.  But the
1455        // structural integrity of the engine — its derivative-guard rejection
1456        // of decreasing cumulative hazards, its monotonicity-collocation
1457        // bookkeeping, its update_state numerics — is well-defined on
1458        // all-censored inputs, and unit tests legitimately exercise those
1459        // structural paths on censored fixtures.  Move the fittability check
1460        // out of construction; production fit dispatchers (e.g.
1461        // `solver::fit_orchestration::materialize_survival`) enforce it on the
1462        // single chokepoint that actually starts an optimization, where
1463        // the failure mode it guards against is reachable.
1464        if age_entry
1465            .iter()
1466            .zip(age_exit.iter())
1467            .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1468        {
1469            return Err(SurvivalError::NonFiniteInput);
1470        }
1471        Ok::<(), _>(())
1472    }
1473
1474    fn validate_monotonicity_constraints(
1475        rows: Option<ArrayView2<'_, f64>>,
1476        offsets: Option<ArrayView1<'_, f64>>,
1477        coefficient_dim: usize,
1478    ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1479        match (rows, offsets) {
1480            (None, None) => Ok((None, None)),
1481            (Some(rows), Some(offsets)) => {
1482                if rows.ncols() != coefficient_dim
1483                    || rows.nrows() != offsets.len()
1484                    || rows.iter().any(|v| !v.is_finite())
1485                    || offsets.iter().any(|v| !v.is_finite())
1486                {
1487                    return Err(SurvivalError::DimensionMismatch);
1488                }
1489                Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1490            }
1491            _ => Err(SurvivalError::DimensionMismatch),
1492        }
1493    }
1494
1495    fn finish_construction(
1496        age_entry: ArrayView1<f64>,
1497        age_exit: ArrayView1<f64>,
1498        event_target: ArrayView1<u8>,
1499        sampleweight: ArrayView1<f64>,
1500        design: SurvivalDesign,
1501        offset_eta_entry: Array1<f64>,
1502        offset_eta_exit: Array1<f64>,
1503        offset_derivative_exit: Array1<f64>,
1504        penalties: PenaltyBlocks,
1505        monotonicity: SurvivalMonotonicityPenalty,
1506        monotonicity_constraint_rows: Option<Array2<f64>>,
1507        monotonicity_constraint_offsets: Option<Array1<f64>>,
1508    ) -> Self {
1509        let n = age_entry.len();
1510        Self {
1511            age_entry: age_entry.to_owned(),
1512            age_exit: age_exit.to_owned(),
1513            entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1514            event_target: event_target.to_owned(),
1515            sampleweight: sampleweight.to_owned(),
1516            design,
1517            offset_eta_entry,
1518            offset_eta_exit,
1519            offset_derivative_exit,
1520            penalties,
1521            monotonicity,
1522            structurally_monotonic: false,
1523            structural_time_columns: 0,
1524            monotonicity_constraint_rows,
1525            monotonicity_constraint_offsets,
1526            workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1527        }
1528    }
1529
1530    pub fn from_engine_inputswith_offsets(
1531        inputs: SurvivalEngineInputs<'_>,
1532        offsets: Option<SurvivalBaselineOffsets<'_>>,
1533        penalties: PenaltyBlocks,
1534        monotonicity: SurvivalMonotonicityPenalty,
1535        spec: SurvivalSpec,
1536    ) -> Result<Self, SurvivalError> {
1537        if spec == SurvivalSpec::Crude {
1538            return Err(SurvivalError::UnsupportedSpec("crude"));
1539        }
1540        let n = inputs.age_entry.len();
1541        let p = inputs.x_entry.ncols();
1542        if inputs.age_exit.len() != n
1543            || inputs.event_target.len() != n
1544            || inputs.event_competing.len() != n
1545            || inputs.sampleweight.len() != n
1546            || inputs.x_entry.nrows() != n
1547            || inputs.x_exit.nrows() != n
1548            || inputs.x_derivative.nrows() != n
1549            || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1550            || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1551        {
1552            return Err(SurvivalError::DimensionMismatch);
1553        }
1554        Self::validate_penalties(&penalties, p)?;
1555        Self::validate_common_inputs(
1556            &inputs.age_entry,
1557            &inputs.age_exit,
1558            &inputs.event_target,
1559            &inputs.event_competing,
1560            &inputs.sampleweight,
1561        )?;
1562        if inputs.x_entry.iter().any(|v| !v.is_finite())
1563            || inputs.x_exit.iter().any(|v| !v.is_finite())
1564            || inputs.x_derivative.iter().any(|v| !v.is_finite())
1565        {
1566            return Err(SurvivalError::NonFiniteInput);
1567        }
1568        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1569            Self::validate_offsets(offsets, n)?;
1570        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1571            Self::validate_monotonicity_constraints(
1572                inputs.monotonicity_constraint_rows,
1573                inputs.monotonicity_constraint_offsets,
1574                p,
1575            )?;
1576
1577        Ok(Self::finish_construction(
1578            inputs.age_entry,
1579            inputs.age_exit,
1580            inputs.event_target,
1581            inputs.sampleweight,
1582            SurvivalDesign::Flat {
1583                x_entry: inputs.x_entry.to_owned(),
1584                x_exit: inputs.x_exit.to_owned(),
1585                x_derivative: inputs.x_derivative.to_owned(),
1586            },
1587            offset_eta_entry,
1588            offset_eta_exit,
1589            offset_derivative_exit,
1590            penalties,
1591            monotonicity,
1592            monotonicity_constraint_rows,
1593            monotonicity_constraint_offsets,
1594        ))
1595    }
1596
1597    pub fn from_time_covariate_inputswith_offsets(
1598        inputs: SurvivalTimeCovarInputs<'_>,
1599        offsets: Option<SurvivalBaselineOffsets<'_>>,
1600        penalties: PenaltyBlocks,
1601        monotonicity: SurvivalMonotonicityPenalty,
1602        spec: SurvivalSpec,
1603    ) -> Result<Self, SurvivalError> {
1604        if spec == SurvivalSpec::Crude {
1605            return Err(SurvivalError::UnsupportedSpec("crude"));
1606        }
1607        let n = inputs.age_entry.len();
1608        let p_time = inputs.time_entry.ncols();
1609        let p_cov = inputs.covariates.ncols();
1610        let p = p_time + p_cov;
1611        if inputs.age_exit.len() != n
1612            || inputs.event_target.len() != n
1613            || inputs.event_competing.len() != n
1614            || inputs.sampleweight.len() != n
1615            || inputs.time_entry.nrows() != n
1616            || inputs.time_exit.nrows() != n
1617            || inputs.time_derivative.nrows() != n
1618            || inputs.covariates.nrows() != n
1619            || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1620            || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1621        {
1622            return Err(SurvivalError::DimensionMismatch);
1623        }
1624        Self::validate_penalties(&penalties, p)?;
1625        Self::validate_common_inputs(
1626            &inputs.age_entry,
1627            &inputs.age_exit,
1628            &inputs.event_target,
1629            &inputs.event_competing,
1630            &inputs.sampleweight,
1631        )?;
1632        if inputs.time_entry.iter().any(|v| !v.is_finite())
1633            || inputs.time_exit.iter().any(|v| !v.is_finite())
1634            || inputs.time_derivative.iter().any(|v| !v.is_finite())
1635            || inputs.covariates.iter().any(|v| !v.is_finite())
1636        {
1637            return Err(SurvivalError::NonFiniteInput);
1638        }
1639        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1640            Self::validate_offsets(offsets, n)?;
1641        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1642            Self::validate_monotonicity_constraints(
1643                inputs.monotonicity_constraint_rows,
1644                inputs.monotonicity_constraint_offsets,
1645                p,
1646            )?;
1647
1648        Ok(Self::finish_construction(
1649            inputs.age_entry,
1650            inputs.age_exit,
1651            inputs.event_target,
1652            inputs.sampleweight,
1653            SurvivalDesign::TimeCovariateShared {
1654                time_entry: inputs.time_entry.to_owned(),
1655                time_exit: inputs.time_exit.to_owned(),
1656                time_derivative: inputs.time_derivative.to_owned(),
1657                covariates: inputs.covariates.to_owned(),
1658            },
1659            offset_eta_entry,
1660            offset_eta_exit,
1661            offset_derivative_exit,
1662            penalties,
1663            monotonicity,
1664            monotonicity_constraint_rows,
1665            monotonicity_constraint_offsets,
1666        ))
1667    }
1668
1669    /// Enable/disable monotonic time-block enforcement metadata.
1670    ///
1671    /// Monotonicity is enforced through linear inequality constraints on the
1672    /// derivative design; enabling this records how many leading time columns
1673    /// belong to that constrained block.
1674    /// Overwrite the per-block smoothing parameters `λ_k` in place.
1675    ///
1676    /// Used by the REML smoothing-parameter selection for transformation
1677    /// survival fits (issue #563): the outer optimizer proposes a `ρ = log λ`
1678    /// vector, sets the smoothing blocks' `λ_k` here, and re-runs the inner
1679    /// constrained PIRLS, so the monotone I-spline baseline can adapt its
1680    /// wiggliness instead of being pinned at a fixed seed. `lambdas` must have
1681    /// one entry per penalty block. The fixed stabilization ridge keeps its
1682    /// caller-set value (the optimizer never proposes a new one for it).
1683    pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1684        if lambdas.len() != self.penalties.blocks.len() {
1685            crate::bail_invalid_estim!(
1686                "set_penalty_lambdas expects {} lambdas, got {}",
1687                self.penalties.blocks.len(),
1688                lambdas.len()
1689            );
1690        }
1691        for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1692            if !lambda.is_finite() || lambda < 0.0 {
1693                crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1694            }
1695            block.lambda = lambda;
1696        }
1697        Ok(())
1698    }
1699
1700    pub fn set_structural_monotonicity(
1701        &mut self,
1702        enabled: bool,
1703        time_columns: usize,
1704    ) -> Result<(), EstimationError> {
1705        let p = self.coefficient_dim();
1706        if time_columns > p {
1707            crate::bail_invalid_estim!(
1708                "structural time columns {} exceed coefficient dimension {}",
1709                time_columns,
1710                p
1711            );
1712        }
1713        if enabled && time_columns == 0 {
1714            crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
1715        }
1716        if enabled {
1717            const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1718            for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
1719                if offset < -STRUCTURAL_DERIV_TOL {
1720                    crate::bail_invalid_estim!(
1721                        "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
1722                    );
1723                }
1724            }
1725            let mut derivative_row = vec![0.0_f64; p];
1726            for i in 0..self.nrows() {
1727                self.fill_derivative_row(i, &mut derivative_row);
1728                for j in 0..time_columns {
1729                    let v = derivative_row[j];
1730                    if v < -STRUCTURAL_DERIV_TOL {
1731                        crate::bail_invalid_estim!(
1732                            "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
1733                        );
1734                    }
1735                }
1736                for j in time_columns..p {
1737                    let v = derivative_row[j];
1738                    if v.abs() > STRUCTURAL_DERIV_TOL {
1739                        crate::bail_invalid_estim!(
1740                            "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
1741                        );
1742                    }
1743                }
1744            }
1745            if let (Some(rows), Some(offsets)) = (
1746                self.monotonicity_constraint_rows.as_ref(),
1747                self.monotonicity_constraint_offsets.as_ref(),
1748            ) {
1749                for (i, &offset) in offsets.iter().enumerate() {
1750                    if offset < -STRUCTURAL_DERIV_TOL {
1751                        crate::bail_invalid_estim!(
1752                            "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
1753                        );
1754                    }
1755                }
1756                for i in 0..rows.nrows() {
1757                    for j in 0..time_columns {
1758                        let v = rows[[i, j]];
1759                        if v < -STRUCTURAL_DERIV_TOL {
1760                            crate::bail_invalid_estim!(
1761                                "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1762                            );
1763                        }
1764                    }
1765                    for j in time_columns..p {
1766                        let v = rows[[i, j]];
1767                        if v.abs() > STRUCTURAL_DERIV_TOL {
1768                            crate::bail_invalid_estim!(
1769                                "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1770                            );
1771                        }
1772                    }
1773                }
1774            }
1775        }
1776        self.structurally_monotonic = enabled;
1777        self.structural_time_columns = if enabled { time_columns } else { 0 };
1778        Ok(())
1779    }
1780
1781    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
1782        if beta.len() != self.coefficient_dim() {
1783            crate::bail_invalid_estim!("survival beta dimension mismatch");
1784        }
1785
1786        let n = self.nrows();
1787        let p = self.coefficient_dim();
1788
1789        // Royston-Parmar contract used throughout the engine:
1790        //   eta(t) = log(H(t)), where H(t) is cumulative hazard.
1791        //
1792        // With row-vectors (per subject i):
1793        //   a1_i^T := x_exit_i^T,  a0_i^T := x_entry_i^T,  d_i^T := x_derivative_i^T
1794        // and scalars:
1795        //   eta1_i = a1_i^T beta,  eta0_i = a0_i^T beta,  s_i = d_i^T beta.
1796        //
1797        // The per-subject negative log-likelihood used below is
1798        //   NLL_i(beta) = exp(eta1_i) - exp(eta0_i) - delta_i * (eta1_i + log(s_i)),
1799        // with delta_i = event_target_i.
1800        //
1801        // This is exactly the form whose derivatives are:
1802        //   grad_i = exp(eta1_i) a1_i - exp(eta0_i) a0_i - delta_i * (a1_i + d_i / s_i)
1803        //   Hess_i = exp(eta1_i) a1_i a1_i^T - exp(eta0_i) a0_i a0_i^T
1804        //            + delta_i * (d_i d_i^T) / s_i^2.
1805        //
1806        // Monotonicity is enforced through linear inequality constraints on the
1807        // derivative design. This keeps the baseline smoothing penalty on the
1808        // actual spline coefficients and preserves zero-deviation as beta=0.
1809        //
1810        // The loop below computes exact beta-space derivatives and then adds penalties.
1811        // Total predictor = target offset + learned deviation.
1812        // This is the same architecture used for flexible binary links:
1813        // principled default, plus penalized wiggle/deviation.
1814        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
1815        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
1816        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
1817
1818        let mut nll = 0.0;
1819        let derivative_guard = self.derivative_guard();
1820        let derivative_guard_numerical = self.derivative_guard_numerical();
1821        let mut workspace = self.workspace.lock().unwrap();
1822        workspace.reset(n);
1823        let SurvivalWorkspace {
1824            w_event,
1825            w_event_inv_deriv,
1826            w_event_outer,
1827            w_hess_exit,
1828            w_hess_entry,
1829        } = &mut *workspace;
1830
1831        // Phase 1: Scalar loop — compute per-observation weights, NLL, validation.
1832        for i in 0..n {
1833            let w = self.sampleweight[i];
1834            if w <= 0.0 {
1835                continue;
1836            }
1837            let entry_age = self.age_entry[i];
1838            let exit_age = self.age_exit[i];
1839            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
1840                crate::bail_invalid_estim!(
1841                    "survival ages must be finite with age_exit >= age_entry"
1842                );
1843            }
1844            let d = f64::from(self.event_target[i]);
1845
1846            let has_entry_interval = !self.entry_at_origin[i];
1847            let interval_scale = if has_entry_interval {
1848                eta_exit[i].max(eta_entry[i])
1849            } else {
1850                eta_exit[i]
1851            };
1852            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
1853            let h_s_scaled = if has_entry_interval {
1854                (eta_entry[i] - interval_scale).exp()
1855            } else {
1856                0.0
1857            };
1858            let interval_scaled = h_e_scaled - h_s_scaled;
1859            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
1860            let deriv = self
1861                .stabilized_structural_derivative(derivative_raw[i])
1862                .unwrap_or(derivative_raw[i]);
1863            // Monotonicity of η(t) = log H(t) is a structural property of the
1864            // whole Royston-Parmar spline. If d_eta/dt is *strictly negative*
1865            // at any observed exit time, the cumulative hazard H(t) decreases
1866            // there and S(t) is not a valid survival function — both event
1867            // and censored rows have to refuse that case. Event rows further
1868            // need deriv strictly above the numerical guard because their
1869            // NLL contains `deriv.ln()` and `1.0 / deriv`; censored rows do
1870            // not, so a boundary value of exactly zero is feasible there.
1871            let mono_floor = if d > 0.0 {
1872                derivative_guard_numerical
1873            } else {
1874                0.0
1875            };
1876            if !deriv.is_finite() || deriv < mono_floor {
1877                return Err(EstimationError::ParameterConstraintViolation(format!(
1878                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
1879                    i, deriv, derivative_guard
1880                )));
1881            }
1882            if has_entry_interval {
1883                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
1884                if interval_scaled + increment_guard < 0.0 {
1885                    return Err(EstimationError::ParameterConstraintViolation(format!(
1886                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
1887                        i, interval
1888                    )));
1889                }
1890            }
1891            nll += w * interval;
1892
1893            // Per-observation weights for BLAS phase.
1894            // scaled_exp_component(interval_scale, h_e_scaled * x[r]) = exp(interval_scale) * h_e_scaled * x[r]
1895            // so the Hessian weight is w * exp(interval_scale) * h_e_scaled = w * exp(eta_exit).
1896            let w_exit_i = w * eta_exit[i].exp();
1897            let w_entry_i = if has_entry_interval {
1898                w * eta_entry[i].exp()
1899            } else {
1900                0.0
1901            };
1902            if !w_exit_i.is_finite() {
1903                crate::bail_invalid_estim!(
1904                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
1905                );
1906            }
1907            w_hess_exit[i] = w_exit_i;
1908            w_hess_entry[i] = w_entry_i;
1909
1910            if d > 0.0 {
1911                let inv_deriv = 1.0 / deriv;
1912                nll += -w * (eta_exit[i] + deriv.ln());
1913                w_event[i] = w;
1914                w_event_inv_deriv[i] = w * inv_deriv;
1915                w_event_outer[i] = w * inv_deriv * inv_deriv;
1916            }
1917        }
1918
1919        // Phase 2: BLAS-accelerated Hessian and gradient via faer.
1920        //   H_interval = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1921        //   grad_interval = X_exit^T w_exit - X_entry^T w_entry
1922        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
1923        // At large smoothing penalties the event-Jacobian score nearly cancels
1924        // the interval score. Compensated row accumulation keeps the final KKT
1925        // residual accurate enough for the outer LAML envelope check.
1926        let mut grad = Array1::<f64>::zeros(p);
1927        let mut grad_comp = Array1::<f64>::zeros(p);
1928        let mut row_exit = vec![0.0_f64; p];
1929        let mut row_entry = vec![0.0_f64; p];
1930        let mut row_derivative = vec![0.0_f64; p];
1931        for i in 0..n {
1932            let w_interval_exit = w_hess_exit[i];
1933            let w_interval_entry = w_hess_entry[i];
1934            let w_event_exit = w_event[i];
1935            let w_event_derivative = w_event_inv_deriv[i];
1936            if w_interval_exit == 0.0
1937                && w_interval_entry == 0.0
1938                && w_event_exit == 0.0
1939                && w_event_derivative == 0.0
1940            {
1941                continue;
1942            }
1943            self.fill_exit_row(i, &mut row_exit);
1944            self.fill_entry_row(i, &mut row_entry);
1945            self.fill_derivative_row(i, &mut row_derivative);
1946            for j in 0..p {
1947                let contribution = w_interval_exit * row_exit[j]
1948                    - w_interval_entry * row_entry[j]
1949                    - w_event_exit * row_exit[j]
1950                    - w_event_derivative * row_derivative[j];
1951                let t = grad[j] + contribution;
1952                if grad[j].abs() >= contribution.abs() {
1953                    grad_comp[j] += (grad[j] - t) + contribution;
1954                } else {
1955                    grad_comp[j] += (contribution - t) + grad[j];
1956                }
1957                grad[j] = t;
1958            }
1959        }
1960        grad += &grad_comp;
1961
1962        h += &self.derivative_xt_diag_x(w_event_outer);
1963
1964        // Norm of the unpenalized score, captured before adding the penalty
1965        // and ridge contributions, for the scale-invariant convergence
1966        // certificate (||score||_2 + ||S*beta||_2 (+ ridge*||beta||_2)).
1967        let score_norm = array1_l2_norm(&grad);
1968
1969        let penaltygrad = self.penalties.gradient(beta);
1970        let penalty_dev = self.penalties.deviance(beta);
1971        let penaltygrad_norm = array1_l2_norm(&penaltygrad);
1972
1973        let mut totalgrad = grad;
1974        totalgrad += &penaltygrad;
1975
1976        self.penalties.addhessian_inplace(&mut h);
1977        // SURVIVAL_STABILIZATION_RIDGE is an `ExplicitPrior`-kind
1978        // stabilization in the canonical ledger taxonomy
1979        // (`gam_problem::StabilizationKind::ExplicitPrior`): δ enters the
1980        // gradient (`grad += δ β`), the Hessian (`H += δ I`), the scalar
1981        // penalty term added to the objective (`0.5 δ ‖β‖²`), and is
1982        // serialized through `WorkingState::ridge_used` so downstream
1983        // covariance and survival_ridge_lambda accounting remain
1984        // consistent. The canonical ledger record is
1985        //   StabilizationLedger::explicit_prior(δ, RidgeMatrixForm::ScaledIdentity)
1986        // chosen_by = FixedConstant. Coordinated with main.rs
1987        // `survival_ridge_lambda` field.
1988        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
1989        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
1990        for d in 0..p {
1991            h[[d, d]] += ridge_used;
1992        }
1993        totalgrad += &beta.mapv(|v| ridge_used * v);
1994        // Keep scalar objective term consistent with:
1995        //   grad += ridge * beta,  Hess += ridge * I
1996        // which correspond to 0.5 * ridge * ||beta||^2.
1997        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
1998        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);
1999
2000        let log_likelihood = -nll;
2001        let deviance = 2.0 * nll;
2002
2003        Ok(WorkingState {
2004            eta: LinearPredictor::new(eta_exit),
2005            gradient: totalgrad,
2006            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
2007            log_likelihood,
2008            deviance,
2009            penalty_term: penalty_dev + ridge_penalty,
2010            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
2011            ridge_used,
2012            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
2013            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
2014        })
2015    }
2016
2017    /// Compute the third-derivative correction matrix for a given mode response `u_k`.
2018    ///
2019    /// This is the directional derivative of the unpenalized NLL Hessian w.r.t.
2020    /// beta along direction `u_k = -H^{-1} A_k beta_hat`. The returned matrix B
2021    /// satisfies `dH/drho_k = A_k + B`.
2022    ///
2023    /// Called via [`SurvivalDerivProvider`] which adapts the sign convention
2024    /// from the unified `HessianDerivativeProvider` trait (positive `v_k`) to
2025    /// the negated `u_k` used here.
2026    pub(crate) fn survival_hessian_derivative_correction(
2027        &self,
2028        beta: &Array1<f64>,
2029        u_k: &Array1<f64>,
2030    ) -> Result<Array2<f64>, EstimationError> {
2031        let p = beta.len();
2032        let n = self.nrows();
2033
2034        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2035        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2036        let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2037        let exp_entry = eta_entry.mapv(f64::exp);
2038        let exp_exit = eta_exit.mapv(f64::exp);
2039        let guard = self.derivative_guard();
2040        let guard_numerical = self.derivative_guard_numerical();
2041
2042        let jac = Array1::<f64>::ones(p);
2043        let curvature = Array1::<f64>::zeros(p);
2044        let third = Array1::<f64>::zeros(p);
2045
2046        let mut row_exit = vec![0.0_f64; p];
2047        let mut row_entry = vec![0.0_f64; p];
2048        let mut row_derivative = vec![0.0_f64; p];
2049        let mut ge = vec![0.0_f64; p];
2050        let mut gs = vec![0.0_f64; p];
2051        let mut gsd = vec![0.0_f64; p];
2052        let mut he = vec![0.0_f64; p];
2053        let mut hs = vec![0.0_f64; p];
2054        let mut hsd = vec![0.0_f64; p];
2055        let mut te = vec![0.0_f64; p];
2056        let mut ts = vec![0.0_f64; p];
2057        let mut tsd = vec![0.0_f64; p];
2058
2059        let mut b_dir = Array2::<f64>::zeros((p, p));
2060
2061        for i in 0..n {
2062            let w_i = self.sampleweight[i];
2063            if w_i <= 0.0 {
2064                continue;
2065            }
2066            let has_entry = !self.entry_at_origin[i];
2067            let mut deta_e = 0.0_f64;
2068            let mut deta_s = 0.0_f64;
2069            let mut ds = 0.0_f64;
2070            self.fill_exit_row(i, &mut row_exit);
2071            self.fill_entry_row(i, &mut row_entry);
2072            self.fill_derivative_row(i, &mut row_derivative);
2073            for j in 0..p {
2074                ge[j] = row_exit[j] * jac[j];
2075                gs[j] = row_entry[j] * jac[j];
2076                gsd[j] = row_derivative[j] * jac[j];
2077                he[j] = row_exit[j] * curvature[j];
2078                hs[j] = row_entry[j] * curvature[j];
2079                hsd[j] = row_derivative[j] * curvature[j];
2080                te[j] = row_exit[j] * third[j];
2081                ts[j] = row_entry[j] * third[j];
2082                tsd[j] = row_derivative[j] * third[j];
2083                deta_e += ge[j] * u_k[j];
2084                if has_entry {
2085                    deta_s += gs[j] * u_k[j];
2086                }
2087                ds += gsd[j] * u_k[j];
2088            }
2089
2090            // Interval part: d/dbeta [ exp(eta) * (g g^T + diag(h)) ][u_k]
2091            for r in 0..p {
2092                let dge_r = he[r] * u_k[r];
2093                let dgs_r = hs[r] * u_k[r];
2094                let dhe_r = te[r] * u_k[r];
2095                let dhs_r = ts[r] * u_k[r];
2096                for c in 0..p {
2097                    let dge_c = he[c] * u_k[c];
2098                    let dgs_c = hs[c] * u_k[c];
2099                    let mut d_h_rc =
2100                        exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2101                    if r == c {
2102                        d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2103                    }
2104                    if has_entry {
2105                        d_h_rc -=
2106                            exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2107                        if r == c {
2108                            d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2109                        }
2110                    }
2111                    b_dir[[r, c]] += w_i * d_h_rc;
2112                }
2113            }
2114
2115            // Event part: d/dbeta [ gsd gsd^T / s^2 - diag(he) - diag(hsd / s) ][u_k]
2116            let s_i = self
2117                .stabilized_structural_derivative(deriv_raw[i])
2118                .unwrap_or(deriv_raw[i]);
2119            if !s_i.is_finite() {
2120                return Err(EstimationError::ParameterConstraintViolation(format!(
2121                    "survival monotonicity violated in unified trace contraction at row {i}: \
2122                     d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2123                )));
2124            }
2125            if self.event_target[i] > 0 {
2126                if s_i < guard_numerical {
2127                    return Err(EstimationError::ParameterConstraintViolation(format!(
2128                        "survival monotonicity violated in unified trace contraction at row {i}: \
2129                         d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2130                    )));
2131                }
2132                let inv_s = 1.0 / s_i;
2133                let inv_s2 = inv_s * inv_s;
2134                let inv_s3 = inv_s2 * inv_s;
2135                for r in 0..p {
2136                    let dgd_r = hsd[r] * u_k[r];
2137                    let dtsd_r = tsd[r] * u_k[r];
2138                    let dte_r = te[r] * u_k[r];
2139                    for c in 0..p {
2140                        let dgd_c = hsd[c] * u_k[c];
2141                        let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2142                            - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2143                        if r == c {
2144                            d_h_rc += -dte_r;
2145                            d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2146                        }
2147                        b_dir[[r, c]] += w_i * d_h_rc;
2148                    }
2149                }
2150            }
2151        }
2152
2153        Ok(b_dir)
2154    }
2155
2156    /// Per-observation gradients of the unpenalized survival NLL with respect
2157    /// to each additive offset channel, at the given β.
2158    ///
2159    /// Contract (Royston-Parmar, eta = log H(t)):
2160    ///
2161    ///   NLL_i(β; o_E, o_X, o_D) = w_i · [
2162    ///       exp(η1_i) − 1{has_entry}·exp(η0_i)
2163    ///       − δ_i · (η1_i + log s_i)
2164    ///   ]
2165    ///
2166    /// with η1_i = a1_iᵀβ + o_X[i], η0_i = a0_iᵀβ + o_E[i],
2167    ///      s_i  = d_iᵀβ + o_D[i].
2168    ///
2169    /// The additive offsets enter each of the three η channels linearly, so
2170    ///   ∂NLL_i/∂o_X[i] = w_i · (exp(η1_i) − δ_i)
2171    ///   ∂NLL_i/∂o_E[i] = −w_i · exp(η0_i) · 1{has_entry_interval}
2172    ///   ∂NLL_i/∂o_D[i] = −w_i · δ_i / s_i         (event-row only)
2173    ///
2174    /// These three arrays are the sampleweight-scaled residuals used to chain
2175    /// `∂NLL/∂offset` into `∂NLL/∂θ` via any closed-form `∂offset/∂θ` map
2176    /// (see `baseline_offset_theta_partials` for parametric baselines). At
2177    /// converged β*, the envelope theorem on the penalized objective gives
2178    ///
2179    ///   d[0.5·(deviance + β*ᵀS_λβ*)] / dθ
2180    ///     = Σᵢ r_X_i·∂o_X_i/∂θ + r_E_i·∂o_E_i/∂θ + r_D_i·∂o_D_i/∂θ
2181    ///
2182    /// exactly (no IFT back-solve required), because β* is a stationary point
2183    /// of the penalized objective wrt β and the penalty has no θ dependence.
2184    ///
2185    /// Rows with `sampleweight[i] ≤ 0` and non-event rows for `r_D` are
2186    /// returned as exact 0.0 so the output can be dot-producted against a
2187    /// per-obs baseline-partials array without a mask.
2188    ///
2189    /// Structural-monotonicity stabilization on `s_i` (see
2190    /// `stabilized_structural_derivative`) is applied identically to the
2191    /// existing `update_state` path so the residual agrees with the
2192    /// NLL that `update_state` evaluates.
2193    pub fn offset_channel_residuals(
2194        &self,
2195        beta: &Array1<f64>,
2196    ) -> Result<OffsetChannelResiduals, EstimationError> {
2197        if beta.len() != self.coefficient_dim() {
2198            crate::bail_invalid_estim!(
2199                "survival beta dimension mismatch in offset_channel_residuals"
2200            );
2201        }
2202        let n = self.nrows();
2203        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2204        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2205        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2206
2207        let derivative_guard_numerical = self.derivative_guard_numerical();
2208        let mut r_exit = Array1::<f64>::zeros(n);
2209        let mut r_entry = Array1::<f64>::zeros(n);
2210        let mut r_deriv = Array1::<f64>::zeros(n);
2211
2212        for i in 0..n {
2213            let w = self.sampleweight[i];
2214            if w <= 0.0 {
2215                continue;
2216            }
2217            let entry_age = self.age_entry[i];
2218            let exit_age = self.age_exit[i];
2219            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2220                crate::bail_invalid_estim!(
2221                    "survival ages must be finite with age_exit >= age_entry"
2222                );
2223            }
2224            let has_entry_interval = !self.entry_at_origin[i];
2225            let d = f64::from(self.event_target[i]);
2226            // Phase-1 values matching update_state:
2227            //   w_exit_i  = w · exp(eta_exit)                    → ∂NLL/∂o_X before − δ·w term
2228            //   w_entry_i = w · exp(eta_entry) · 1{has_entry}    → matches −∂NLL/∂o_E sign
2229            let w_exit_i = w * eta_exit[i].exp();
2230            let w_entry_i = if has_entry_interval {
2231                w * eta_entry[i].exp()
2232            } else {
2233                0.0
2234            };
2235            if !w_exit_i.is_finite() {
2236                crate::bail_invalid_estim!(
2237                    "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2238                );
2239            }
2240            r_exit[i] = w_exit_i - d * w;
2241            r_entry[i] = -w_entry_i;
2242            // Same per-row monotonicity rule as `update_state`: a strictly
2243            // negative derivative at any observed exit time (event or
2244            // censored) falsifies S(t); event rows additionally need
2245            // `deriv > guard` because `1/deriv` enters their score.
2246            let deriv_raw = derivative_raw[i];
2247            let deriv = self
2248                .stabilized_structural_derivative(deriv_raw)
2249                .unwrap_or(deriv_raw);
2250            let mono_floor = if d > 0.0 {
2251                derivative_guard_numerical
2252            } else {
2253                0.0
2254            };
2255            if !deriv.is_finite() || deriv < mono_floor {
2256                return Err(EstimationError::ParameterConstraintViolation(format!(
2257                    "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2258                )));
2259            }
2260            if d > 0.0 {
2261                r_deriv[i] = -w * d / deriv;
2262            }
2263        }
2264
2265        let right = Array1::<f64>::zeros(r_exit.len());
2266        Ok(OffsetChannelResiduals {
2267            exit: r_exit,
2268            entry: r_entry,
2269            derivative: r_deriv,
2270            right,
2271        })
2272    }
2273
2274    /// Build an [`InnerSolution`](gam_solve::estimate::reml::reml_outer_engine::InnerSolution) from
2275    /// the survival working state, suitable for the unified REML/LAML evaluator.
2276    ///
2277    /// Evaluate the survival outer objective and gradient via the unified REML/LAML
2278    /// evaluator, using the canonical assembly module.
2279    pub fn unified_lamlobjective_and_rhogradient(
2280        &self,
2281        beta: &Array1<f64>,
2282        state: &WorkingState,
2283        rho: &Array1<f64>,
2284    ) -> Result<(f64, Array1<f64>), EstimationError> {
2285        use gam_solve::estimate::reml::assembly::{
2286            InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
2287        };
2288        use gam_solve::estimate::reml::reml_outer_engine::{
2289            DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
2290            compute_block_penalty_logdet_derivs,
2291        };
2292        use gam_problem::{EvalMode, PseudoLogdetMode};
2293
2294        let p = beta.len();
2295        let active_penalty_blocks: Vec<&PenaltyBlock> = self
2296            .penalties
2297            .blocks
2298            .iter()
2299            .filter(|b| b.lambda > 0.0)
2300            .collect();
2301        if rho.len() != active_penalty_blocks.len() {
2302            crate::bail_invalid_estim!(
2303                "survival LAML rho dimension {} does not match active penalty block count {}",
2304                rho.len(),
2305                active_penalty_blocks.len()
2306            );
2307        }
2308        let k_count = active_penalty_blocks.len();
2309
2310        // --- Hessian operator ---
2311        let h_dense = state.hessian.to_dense();
2312        let has_left_truncation = self
2313            .age_entry
2314            .iter()
2315            .any(|&t| t > ENTRY_AT_ORIGIN_THRESHOLD);
2316        // Transformation-survival uses observed information in the LAML logdet.
2317        // With delayed entry the likelihood contains +H(entry), so the observed
2318        // NLL curvature includes a genuine negative
2319        // -X_entry' diag(exp(eta_entry)) X_entry block. The shared smooth
2320        // pseudo-logdet is a PSD-contract regularizer, not a licence to reward
2321        // negative observed-curvature directions: a negative eigenvalue maps to
2322        // a tiny positive regularized value and can make the outer smoothing
2323        // objective prefer under-smoothed, nearly singular baselines. For the
2324        // delayed-entry observed-information path, use the identified positive
2325        // subspace logdet/pseudoinverse instead; right-censored fits keep the
2326        // historical smooth full-spectrum convention.
2327        let hessian_logdet_mode = if has_left_truncation {
2328            PseudoLogdetMode::HardPseudo
2329        } else {
2330            PseudoLogdetMode::Smooth
2331        };
2332        let hop = DenseSpectralOperator::from_symmetric_with_mode(&h_dense, hessian_logdet_mode)
2333            .map_err(EstimationError::InvalidInput)?;
2334
2335        // --- Penalty coordinates via shared assembler helper ---
2336        let block_descs: Vec<PenaltyBlockDesc> = self
2337            .penalties
2338            .blocks
2339            .iter()
2340            .filter(|b| b.lambda > 0.0)
2341            .map(|b| PenaltyBlockDesc {
2342                matrix: &b.matrix,
2343                range_start: b.range.start,
2344                range_end: b.range.end,
2345            })
2346            .collect();
2347        let penalty_coords =
2348            penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;
2349
2350        // --- Penalty logdet derivatives ---
2351        let per_block_rho: Vec<Array1<f64>> =
2352            rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
2353        let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
2354            .iter()
2355            .map(|b| vec![b.matrix.clone()])
2356            .collect();
2357        let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
2358            .iter()
2359            .map(|v| v.as_slice())
2360            .collect();
2361        let penalty_logdet = if k_count > 0 {
2362            compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
2363                .map_err(EstimationError::InvalidInput)?
2364        } else {
2365            PenaltyLogdetDerivs {
2366                value: 0.0,
2367                first: Array1::zeros(0),
2368                second: Some(Array2::zeros((0, 0))),
2369            }
2370        };
2371
2372        // penalty_quadratic = 2 * penalty_term (matching unified evaluator convention).
2373        let penalty_quadratic = 2.0 * state.penalty_term;
2374        let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());
2375
2376        // #931 survival-LAML IFT envelope: attach the one-step Newton correction
2377        // only when this state is actually a near-stationary inner solution.
2378        // `unified_lamlobjective_and_rhogradient` is also used by algebraic
2379        // fixed-beta objective tests; feeding a large non-stationary residual
2380        // there makes the value a different surface. The re-converged shim
2381        // polishes the inner mode to an absolute residual floor, so certified
2382        // states still keep the envelope correction while arbitrary beta probes
2383        // evaluate the documented LAML objective.
2384        //
2385        // The residual MUST be the active-set-projected stationarity vector, not
2386        // raw `state.gradient`: a binding monotonicity constraint contributes a
2387        // Lagrange-multiplier normal component (`r = A^T lambda`, lambda >= 0)
2388        // that is not a stationarity residual.
2389        const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
2390        let kkt_residual = {
2391            let raw = state.gradient.clone();
2392            let projected = match self.monotonicity_linear_constraints() {
2393                Some(constraints) => {
2394                    projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
2395                        .ok_or_else(|| {
2396                            EstimationError::InvalidInput(
2397                                "survival LAML could not project the monotonicity KKT residual"
2398                                    .to_string(),
2399                            )
2400                        })?
2401                }
2402                None => raw,
2403            };
2404            let projected_norm = array1_l2_norm(&projected);
2405            let relative_projected_norm = state.relative_gradient_norm(projected_norm);
2406            if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
2407                Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
2408            } else {
2409                None
2410            }
2411        };
2412
2413        let result = InnerAssembly {
2414            log_likelihood: state.log_likelihood,
2415            penalty_quadratic,
2416            beta: beta.clone(),
2417            n_observations: self.nrows(),
2418            hessian_op: std::sync::Arc::new(hop),
2419            penalty_coords,
2420            penalty_logdet,
2421            dispersion: DispersionHandling::Fixed {
2422                phi: 1.0,
2423                include_logdet_h: true,
2424                include_logdet_s: true,
2425            },
2426            rho_curvature_scale: 1.0,
2427            rho_prior: gam_problem::RhoPrior::Flat,
2428            hessian_logdet_correction: 0.0,
2429            penalty_subspace_trace: None,
2430            deriv_provider: Some(Box::new(provider)),
2431            firth: None,
2432            nullspace_dim: None,
2433            barrier_config: None,
2434            ext_coords: Vec::new(),
2435            ext_coord_pair_fn: None,
2436            rho_ext_pair_fn: None,
2437            fixed_drift_deriv: None,
2438            contracted_psi_second_order: None,
2439            kkt_residual,
2440            active_constraints: None,
2441        }
2442        .evaluate(
2443            rho.as_slice().expect("rho must be contiguous"),
2444            EvalMode::ValueAndGradient,
2445            None,
2446        )
2447        .map_err(EstimationError::InvalidInput)?;
2448
2449        let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
2450        Ok((result.cost, gradient))
2451    }
2452
2453    /// Self-contained ρ → (LAML value, analytic ρ-gradient) surface for the
2454    /// survival LAML objective.
2455    ///
2456    /// Unlike [`unified_lamlobjective_and_rhogradient`](Self::unified_lamlobjective_and_rhogradient),
2457    /// which takes a *pre-converged* [`WorkingState`] and `β̂` at the evaluated
2458    /// `ρ`, this shim re-converges the inner survival mode internally: it sets
2459    /// the active-block smoothing parameters to `λ = exp(ρ)`, runs the same
2460    /// constrained inner PIRLS that the survival outer loop uses
2461    /// ([`runworking_model_pirls`](gam_solve::pirls::runworking_model_pirls)), then
2462    /// evaluates the unified survival LAML value and analytic ρ-gradient at the
2463    /// re-fitted `β̂(ρ)`. The returned pair is therefore a single-source value+
2464    /// gradient surface that a caller can finite-difference by varying `ρ`
2465    /// alone — the survival counterpart of the GLM path's
2466    /// `evaluate_externalgradient` / `evaluate_externalcost_andridge`.
2467    ///
2468    /// `rho` enumerates the **active** penalty blocks (those with `λ > 0`) in
2469    /// block order, matching the convention of the unified evaluator. `beta0` is
2470    /// the inner warm-start. The behaviour is identical to the existing survival
2471    /// LAML path (set-λ → inner PIRLS → `update_state` → unified LAML); this is a
2472    /// reachability shim, not a new objective.
2473    pub fn evaluate_survival_lamlcost_and_gradient(
2474        &self,
2475        rho: &[f64],
2476        beta0: &Array1<f64>,
2477    ) -> Result<(f64, Array1<f64>), EstimationError> {
2478        let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2479        // Re-converged β̂(ρ); evaluate the unified survival LAML value and
2480        // analytic ρ-gradient at that mode. The ρ passed to the unified
2481        // evaluator enumerates active blocks in block order, exactly the input
2482        // convention of this shim.
2483        let rho_arr = Array1::from_vec(rho.to_vec());
2484        let state = candidate.update_state(&beta)?;
2485        candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2486    }
2487
    /// Re-converge the survival inner mode at `λ = exp(ρ)` from warm-start
    /// `beta0`, returning the λ-updated model clone and the converged `β̂(ρ)`.
    ///
    /// This is the shared inner solve behind
    /// [`evaluate_survival_lamlcost_and_gradient`](Self::evaluate_survival_lamlcost_and_gradient).
    /// It runs in two stages:
    /// 1. `runworking_model_pirls` from `beta0` with tight settings
    ///    (600 iterations, tolerance 1e-12).
    /// 2. A stationarity polish that aims to drive the absolute penalized
    ///    residual `‖state.gradient‖` below 1e-13, so the envelope ρ-gradient
    ///    is accurate. Each polish step uses an exact Cholesky solve on an
    ///    isotropically shifted Hessian `H + λI`, falls back to steepest
    ///    descent when needed, and finishes with a backtracking line search.
    ///
    /// `rho` enumerates the active blocks (`λ > 0` on `self`) in block order.
    /// Inactive blocks keep their current λ.
    ///
    /// # Errors
    /// * Bails via `crate::bail_invalid_estim!` if `rho.len()` differs from the
    ///   active block count, or `beta0.len()` from `coefficient_dim()`.
    /// * Propagates errors from `set_penalty_lambdas` and
    ///   `runworking_model_pirls`.
    ///
    /// The polish stage never returns an error. It stops and returns the
    /// current `β` unchanged in any of these cases:
    /// * `update_state` fails;
    /// * the residual is non-finite or below tolerance;
    /// * no trial step is accepted;
    /// * 400 iterations are exhausted.
    fn reconverge_survival_inner_mode(
        &self,
        rho: &[f64],
        beta0: &Array1<f64>,
    ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
        // Inner-PIRLS settings are tighter than the production outer loop, so
        // the inner mode is converged well below the FD step's round-off floor
        // and ∇V can be finite-differenced in ρ alone.
        // NOTE(review): `opts` below passes no `linear_constraints` or
        // `coefficient_lower_bounds`. Any constraints must therefore come from
        // the working model itself. Confirm this matches the production
        // survival inner solve.
        const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
        const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
        const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
        const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;

        let active_block_count = self
            .penalties
            .blocks
            .iter()
            .filter(|b| b.lambda > 0.0)
            .count();
        if rho.len() != active_block_count {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
                rho.len(),
                active_block_count
            );
        }
        if beta0.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
                beta0.len(),
                self.coefficient_dim()
            );
        }

        // Set λ = exp(ρ) on the active blocks (block order), leaving inactive
        // (λ = 0) blocks untouched, then re-converge the inner mode.
        let mut candidate = self.clone();
        let mut lambdas: Vec<f64> = candidate
            .penalties
            .blocks
            .iter()
            .map(|b| b.lambda)
            .collect();
        let mut active_idx = 0usize;
        for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
            if block.lambda > 0.0 {
                *lambda = rho[active_idx].exp();
                active_idx += 1;
            }
        }
        candidate.set_penalty_lambdas(&lambdas)?;

        let opts = gam_solve::pirls::WorkingModelPirlsOptions {
            max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
            convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
            adaptive_kkt_tolerance: None,
            max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
            min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
            firth_bias_reduction: false,
            coefficient_lower_bounds: None,
            linear_constraints: None,
            initial_lm_lambda: None,
            geodesic_acceleration: false,
            arrow_schur: None,
        };
        let summary = gam_solve::pirls::runworking_model_pirls(
            &mut candidate,
            Coefficients::new(beta0.clone()),
            &opts,
            |_| {},
        )?;
        let mut beta = summary.beta.as_ref().to_owned();

        // Why a polish is needed. PIRLS exits on a RELATIVE KKT /
        // deviance-plateau certificate. That can leave an ABSOLUTE penalized
        // stationarity residual r = S beta_hat - grad_ell of order 0.1-1 (the
        // score scales as O(sqrt(n))). The unified LAML gradient relies on the
        // envelope theorem, which is exact only at r = 0. A residual that large
        // leaks <r, beta_dot> into the objective<->gradient consistency, and
        // the unified evaluator's IFT envelope correction is only leading-order
        // in r. The remedy is to drive the inner solve to true stationarity
        // here.
        //
        // What the loop below does on each iteration:
        //   1. Recompute the state at beta. Stop if update_state fails, if
        //      ||r|| (r = state.gradient) is non-finite, or if ||r|| is below
        //      POLISH_TOL.
        //   2. Choose a direction using an exact Cholesky solve of
        //      (H + lambda*I) step = r. This is an isotropic shift, not
        //      Marquardt diag(H) scaling. The shift ladder is lambda = 0, then
        //      lambda = 1e-12 * s * 10^k for k = 1..17, with
        //      s = max(max_d |H_dd|, 1).
        //   3. Use the first factorization whose step is finite and gives a
        //      sufficient-descent directional derivative
        //      (-r.step < -1e-14 ||r||^2). If none qualifies, fall back to
        //      steepest descent (step = r).
        //   4. Backtrack on beta - alpha*step, halving alpha up to
        //      MAX_BACKTRACK times. Accept on EITHER an Armijo decrease of the
        //      penalized objective OR a strict decrease of ||r||. Stop if no
        //      trial is accepted.
        //
        // The shift ladder restarts at lambda = 0 on every iteration; there is
        // no persistent damping parameter.
        {
            const POLISH_MAX_ITERS: usize = 400;
            const POLISH_TOL: f64 = 1e-13;
            // Armijo sufficient-decrease constant and backtracking factor.
            const ARMIJO_C: f64 = 1e-4;
            const BACKTRACK: f64 = 0.5;
            const MAX_BACKTRACK: usize = 80;
            let p = beta.len();
            // Penalized inner objective f(β) = −ℓ(β) + penalty_term, which is
            // the quantity the Armijo test measures.
            // NOTE(review): this assumes `state.gradient` and `state.hessian`
            // are the gradient and Hessian of exactly this quantity, i.e.
            // penalty_term = ½β'Sβ plus any ridge term. Confirm against
            // `update_state`.
            let penalized_objective =
                |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
            for _ in 0..POLISH_MAX_ITERS {
                let st = match candidate.update_state(&beta) {
                    Ok(st) => st,
                    Err(_) => break,
                };
                let r = st.gradient.clone();
                let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
                if !r_norm.is_finite() || r_norm < POLISH_TOL {
                    break;
                }
                let h = st.hessian.to_dense();
                let f0 = penalized_objective(&st);
                // The Newton direction d = −H⁻¹r is computed on a shifted system
                // so an ill-conditioned H cannot yield a direction whose
                // quadratic form rᵀH⁻¹r has the wrong sign. The notes recorded
                // h_diag ratios of ~2400 at β₀≈4.6, where exp(η) is huge.
                // Whenever (H + λI) is SPD, rᵀ(H + λI)⁻¹r > 0, so the step is a
                // descent direction for f.
                // Steepest descent d = −r is always a descent direction
                // (∇fᵀ(−r) = −‖r‖² < 0). The line search tests the objective
                // value (and ‖r‖), so an accepted step makes progress on one of
                // them.
                // The objective is convex for right-censored fits. With delayed
                // entry the observed Hessian can be indefinite; in that case the
                // λ > 0 rungs restore positive definiteness.
                // Scale for the shift ladder: largest |H_dd|, floored at 1.
                let h_scale = (0..p)
                    .map(|d| h[[d, d]].abs())
                    .fold(0.0_f64, f64::max)
                    .max(1.0);
                // Solve (H + λI) step = r with an EXACT Cholesky factorization
                // (faer Llt), not the DenseSpectralOperator. The spectral
                // operator clamps tiny or negative eigenvalues, which corrupted
                // this solve on the ill-conditioned boundary Hessian. Cholesky
                // succeeds iff H + λI is SPD, so sweeping λ up from 0 finds
                // the smallest SPD shift on the ladder.
                let mut step: Option<Array1<f64>> = None;
                let mut dir_deriv = 0.0_f64;
                for lm_pow in 0..18 {
                    let lambda_lm = if lm_pow == 0 {
                        0.0
                    } else {
                        1e-12 * h_scale * 10f64.powi(lm_pow)
                    };
                    let mut h_reg = h.clone();
                    for d in 0..p {
                        h_reg[[d, d]] += lambda_lm;
                    }
                    let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
                        &h_reg,
                        faer::Side::Lower,
                    ) {
                        Ok(f) => f,
                        Err(_) => continue,
                    };
                    let candidate_step = factor.solvevec(&r);
                    if candidate_step.iter().any(|v| !v.is_finite()) {
                        continue;
                    }
                    // ∇fᵀd = rᵀ(−step) = −r·(H+λI)⁻¹r < 0 exactly for SPD systems.
                    let dd = -r.dot(&candidate_step);
                    if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
                        step = Some(candidate_step);
                        dir_deriv = dd;
                        break;
                    }
                }
                let (step, dir_deriv) = match step {
                    Some(s) => (s, dir_deriv),
                    None => {
                        // Steepest-descent fallback: d = −r ⇒ step = +r (we step
                        // β − step), ∇fᵀd = −‖r‖² < 0.
                        (r.clone(), -r_norm * r_norm)
                    }
                };
                let mut alpha = 1.0_f64;
                let mut accepted = false;
                for _ in 0..MAX_BACKTRACK {
                    let trial = &beta - &(alpha * &step);
                    // Trials where update_state fails (e.g. infeasible β) are
                    // treated as rejected, and alpha is halved.
                    if let Ok(ts) = candidate.update_state(&trial) {
                        let ft = penalized_objective(&ts);
                        let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
                        // Accept on EITHER a sufficient objective decrease
                        // (Armijo) OR a strict residual-norm decrease.
                        //
                        // Near the solution the penalized objective is flat to
                        // f64 round-off (f0 ≈ ft). A pure-Armijo test would then
                        // backtrack α→0 and crawl; the notes recorded a ρ=3.99999
                        // stall of 200 iterations at 3.7e-7, against 12
                        // iterations at neighbouring ρ.
                        //
                        // The ‖r‖-decrease arm lets the exact Cholesky Newton
                        // step (α=1) through, restoring fast local convergence
                        // so the centered FD of the value surface stays
                        // consistent across FD points.
                        let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
                        let residual_ok = tn.is_finite() && tn < r_norm;
                        if armijo_ok || residual_ok {
                            beta = trial;
                            accepted = true;
                            break;
                        }
                    }
                    alpha *= BACKTRACK;
                }
                if !accepted {
                    break;
                }
            }
        }

        Ok((candidate, beta))
    }
2719}
2720
2721/// Derivative provider that adapts survival third-derivative Hessian corrections
2722/// to the unified [`HessianDerivativeProvider`](gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider)
2723/// trait.
2724///
2725/// The unified trait supplies `v_k = H^{-1}(A_k beta_hat)` (positive sign),
2726/// whereas the survival engine's
2727/// [`survival_hessian_derivative_correction`](WorkingModelSurvival::survival_hessian_derivative_correction)
2728/// expects `u_k = -v_k`. This provider handles the sign conversion.
2729pub(crate) struct SurvivalDerivProvider {
2730    model: WorkingModelSurvival,
2731    beta: Array1<f64>,
2732}
2733
2734impl SurvivalDerivProvider {
2735    pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2736        Self { model, beta }
2737    }
2738}
2739
2740impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2741    fn hessian_derivative_correction(
2742        &self,
2743        v_k: &Array1<f64>,
2744    ) -> Result<Option<Array2<f64>>, String> {
2745        // The trait provides v_k = H^{-1}(A_k beta_hat) (positive).
2746        // The survival method expects u_k = -H^{-1} A_k beta_hat = -v_k.
2747        let u_k = -v_k;
2748        match self
2749            .model
2750            .survival_hessian_derivative_correction(&self.beta, &u_k)
2751        {
2752            Ok(correction) => Ok(Some(correction)),
2753            Err(e) => Err(e.to_string()),
2754        }
2755    }
2756
2757    fn has_corrections(&self) -> bool {
2758        true
2759    }
2760}
2761
2762#[derive(Debug, Clone)]
2763pub struct CrudeRiskResult {
2764    pub risk: f64,
2765    pub diseasegradient: Array1<f64>,
2766    pub mortalitygradient: Array1<f64>,
2767}
2768
2769#[derive(Debug, Clone)]
2770pub struct CompetingRisksCifResult {
2771    /// Cumulative incidence per endpoint. `cif[ep][[row, time_idx]]` is the
2772    /// probability that cause `ep` occurred by `times[time_idx]` for sample `row`.
2773    /// Stored one matrix per endpoint so it is ergonomic to index per-cause and
2774    /// natural to construct from the per-endpoint cumulative-hazard inputs.
2775    pub cif: Vec<Array2<f64>>,
2776    pub overall_survival: Array2<f64>,
2777}
2778
/// Minimum number of subjects (rows) before competing-risks CIF assembly fans
/// out over rayon; smaller panels stay serial.
///
/// Each row only runs an `n_times`-long prefix-sum recurrence with a few
/// `exp`/`exp_m1` calls per element. That work is cheap, so parallel dispatch
/// overhead would dominate on small panels. Large panels (the #1082
/// quality-test sizes) amortize it.
const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2784
2785pub fn assemble_competing_risks_cif(
2786    times: ArrayView1<'_, f64>,
2787    cumulative_hazard: ArrayView3<'_, f64>,
2788) -> Result<CompetingRisksCifResult, SurvivalError> {
2789    let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2790    if n_endpoints == 0 {
2791        return Err(SurvivalError::DimensionMismatch);
2792    }
2793    let endpoint_hazards = cumulative_hazard
2794        .axis_iter(Axis(0))
2795        .map(|view| view.to_owned())
2796        .collect::<Vec<_>>();
2797    assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2798        if result.overall_survival.dim() != (n_rows, n_times) {
2799            Err(SurvivalError::DimensionMismatch)
2800        } else {
2801            Ok(result)
2802        }
2803    })
2804}
2805
2806pub fn assemble_competing_risks_cif_from_endpoints(
2807    times: ArrayView1<'_, f64>,
2808    cumulative_hazards: &[Array2<f64>],
2809) -> Result<CompetingRisksCifResult, SurvivalError> {
2810    let n_endpoints = cumulative_hazards.len();
2811    if n_endpoints == 0 || times.is_empty() {
2812        return Err(SurvivalError::DimensionMismatch);
2813    }
2814    let (n_rows, n_times) = cumulative_hazards[0].dim();
2815    if n_rows == 0 || n_times == 0 || times.len() != n_times {
2816        return Err(SurvivalError::DimensionMismatch);
2817    }
2818    if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2819        return Err(SurvivalError::InvalidTimeGrid);
2820    }
2821    if times
2822        .iter()
2823        .zip(times.iter().skip(1))
2824        .any(|(previous, current)| current <= previous)
2825    {
2826        return Err(SurvivalError::InvalidTimeGrid);
2827    }
2828    for endpoint_hazard in cumulative_hazards {
2829        if endpoint_hazard.dim() != (n_rows, n_times) {
2830            return Err(SurvivalError::DimensionMismatch);
2831        }
2832        if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2833            return Err(SurvivalError::NonFiniteInput);
2834        }
2835    }
2836
2837    let max_abs_hazard = cumulative_hazards
2838        .iter()
2839        .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2840        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2841    let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2842    let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2843        .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2844        .collect();
2845    let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2846
2847    // Per-row CIF assembly. The TIME axis is a sequential prefix-sum recurrence
2848    // (`previous_*` carry forward across `time_idx`) and MUST stay ordered, so it
2849    // is left as the inner serial loop. The ROW (subject) axis is fully
2850    // independent: every `previous_*`/`increments` buffer is allocated fresh per
2851    // row, no state crosses rows, and each row writes only its own disjoint
2852    // output slices. The per-row result is byte-identical regardless of which
2853    // thread runs it, so we fan the outer row loop out over rayon and write the
2854    // owned per-row buffers back serially in row order (deterministic, bit-exact
2855    // vs. the serial implementation).
2856    //
2857    // `cif_flat` is endpoint-major: `cif_flat[endpoint * n_times + time_idx]`.
2858    let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2859        let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2860        let mut surv_row = vec![0.0_f64; n_times];
2861        let mut previous_cif = vec![0.0_f64; n_endpoints];
2862        let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2863        let mut increments = vec![0.0_f64; n_endpoints];
2864        let mut previous_total_cumulative = 0.0_f64;
2865        for time_idx in 0..n_times {
2866            let mut total_increment = 0.0_f64;
2867            for endpoint in 0..n_endpoints {
2868                let current = cumulative_hazards[endpoint][[row, time_idx]];
2869                if current < -monotone_tolerance {
2870                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2871                }
2872                let raw_increment = current - previous_cumulative[endpoint];
2873                if raw_increment < -monotone_tolerance {
2874                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2875                }
2876                let increment = raw_increment.max(0.0);
2877                increments[endpoint] = increment;
2878                total_increment += increment;
2879                previous_cumulative[endpoint] += increment;
2880            }
2881
2882            let survival_left = (-previous_total_cumulative).exp();
2883            let interval_failure = -(-total_increment).exp_m1();
2884            for endpoint in 0..n_endpoints {
2885                if total_increment > 0.0 {
2886                    previous_cif[endpoint] +=
2887                        survival_left * interval_failure * increments[endpoint] / total_increment;
2888                }
2889                cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2890            }
2891            previous_total_cumulative += total_increment;
2892            // Derive `S(t)` from the stored cause-specific CIFs at this time so
2893            // that the competing-risks closure identity
2894            //   Σ_k F_k(t) + S(t) = 1
2895            // holds bit-exactly. Computing `S` independently as
2896            // `exp(-Σ_k H_k(t))` and then comparing against the (clamped, ratio-
2897            // split) Σ F_k introduces O(machine-eps) closure error because the
2898            // float increments
2899            //   ΔF_k = S_left·(1-exp(-ΔH))·ΔH_k/ΔH_total
2900            // do not sum to `S_left - S_new` bit-exactly. By summing the stored
2901            // CIFs in the same left-fold order as `slice.iter().sum::<f64>()`
2902            // and defining `S := 1.0 - Σ F_k`, the IEEE-754 round-trip
2903            //   (1.0 - f) + f
2904            // restores the identity for finite f ∈ [0, 1]. The mathematically
2905            // consistent survival value `exp(-H_total)` is still tracked up to
2906            // ulp-level precision because the ΔF_k construction matches
2907            // `S_left - S_new` to leading order.
2908            let mut fsum_at_t = 0.0_f64;
2909            for endpoint in 0..n_endpoints {
2910                fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2911            }
2912            surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2913        }
2914        Ok((cif_flat, surv_row))
2915    };
2916
2917    // Nesting guard (`rayon::current_thread_index().is_none()`) keeps us from
2918    // oversubscribing when this routine is itself called from inside a rayon
2919    // worker, and the row-count gate keeps small inputs on the serial path.
2920    let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2921        && rayon::current_thread_index().is_none()
2922    {
2923        use rayon::prelude::*;
2924        (0..n_rows)
2925            .into_par_iter()
2926            .map(assemble_row)
2927            .collect::<Result<_, _>>()?
2928    } else {
2929        (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2930    };
2931
2932    for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2933        for endpoint in 0..n_endpoints {
2934            for time_idx in 0..n_times {
2935                cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2936            }
2937        }
2938        for time_idx in 0..n_times {
2939            overall_survival[[row, time_idx]] = surv_row[time_idx];
2940        }
2941    }
2942
2943    Ok(CompetingRisksCifResult {
2944        cif,
2945        overall_survival,
2946    })
2947}
2948
/// Computes the `n`-point Gauss–Legendre rule on `[-1, 1]`.
///
/// Returns `(node, weight)` pairs sorted by ascending node. Only the
/// non-negative half of the roots of the Legendre polynomial `P_n` is solved
/// for directly. Each root comes from Newton's method started at
/// `cos(π (i + 3/4) / (n + 1/2))`, with at most 100 steps, stopping once a
/// step is below `1e-14`. Each root is then mirrored to its negative, except
/// the central root `0` when `n` is odd. `n == 0` yields an empty rule.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    let order = n as f64;

    // Returns (P_n(z), P_{n-1}(z)) via the three-term recurrence
    // (j + 1) P_{j+1} = (2j + 1) z P_j - j P_{j-1}.
    let legendre_pair = |z: f64| -> (f64, f64) {
        let (mut p_curr, mut p_prev) = (1.0_f64, 0.0_f64);
        for j in 0..n {
            let jf = j as f64;
            let p_next = ((2.0 * jf + 1.0) * z * p_curr - jf * p_prev) / (jf + 1.0);
            p_prev = p_curr;
            p_curr = p_next;
        }
        (p_curr, p_prev)
    };

    let root_count = n.div_ceil(2);
    let mut rule = Vec::with_capacity(n);
    for i in 0..root_count {
        let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (order + 0.5)).cos();
        // P_n'(z) from the last Newton step; reused for the weight below.
        let mut derivative = 0.0;
        for _ in 0..100 {
            let (p_n, p_n_minus_1) = legendre_pair(z);
            derivative = order * (z * p_n - p_n_minus_1) / (z * z - 1.0);
            let z_next = z - p_n / derivative;
            let step = (z_next - z).abs();
            z = z_next;
            if step < 1e-14 {
                break;
            }
        }

        let weight = 2.0 / ((1.0 - z * z) * derivative * derivative);
        let is_central_root = n % 2 == 1 && i + 1 == root_count;
        if is_central_root {
            rule.push((0.0, weight));
        } else {
            rule.push((-z, weight));
            rule.push((z, weight));
        }
    }

    rule.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal));
    rule
}
2986
2987fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
2988    // `LazyLock` (not `OnceLock::get_or_init`) so first init never parks a
2989    // caller on the OS condvar from inside a rayon worker. The competing-risks
2990    // CIF assembler in this file dispatches `into_par_iter` and the
2991    // codebase-level lint (`tests/once_lock_get_or_init_not_inside_parallel_regions.rs`)
2992    // forbids the lazy `OnceLock` accessor in any rayon-adjacent file.
2993    static CACHE: LazyLock<Vec<(f64, f64)>> = LazyLock::new(|| compute_gauss_legendre_nodes(40));
2994    &CACHE
2995}
2996
2997/// Engine-level crude risk quadrature with exact delta-method gradients.
2998///
2999/// This routine owns the numerical integration and gradient accumulation math:
3000/// - It integrates `h_d(u) * S_total(u | t0)` over `[t0, t1]` by high-order
3001///   Gauss-Legendre quadrature.
3002/// - It computes gradients w.r.t. disease and mortality coefficients:
3003///   d Risk / d beta_d and d Risk / d beta_m.
3004///
3005/// The adapter provides the domain-specific point evaluator callback `eval_at`,
3006/// which fills design rows and returns:
3007/// - instantaneous disease hazard h_d(u) at age `u`,
3008/// - cumulative disease hazard `H_d(u)`,
3009/// - cumulative mortality hazard `H_m(u)`.
3010///
3011/// The callback must fill the following arrays (one entry per coefficient):
3012/// - `design_d[j]`: partial derivative of the linear predictor eta_d w.r.t. beta_j
3013///   at time u, i.e. x_j(u) = d eta_d(u) / d beta_j.
3014/// - `deriv_d[j]`: partial derivative of the TIME DERIVATIVE of eta_d w.r.t. beta_j
3015///   at time u, i.e. x_dot_j(u) = d/d(beta_j) [d eta_d(u)/du].
3016/// - `design_m[j]`: same as design_d but for the mortality linear predictor eta_m.
3017///
3018/// This keeps domain/data wiring out of `gam` while centralizing the
3019/// integration engine in one place.
3020pub fn calculate_crude_risk_quadrature<F>(
3021    t0: f64,
3022    t1: f64,
3023    breakpoints: &[f64],
3024    h_dis_t0: f64,
3025    h_mor_t0: f64,
3026    design_d_t0: ArrayView1<'_, f64>,
3027    design_m_t0: ArrayView1<'_, f64>,
3028    mut eval_at: F,
3029) -> Result<CrudeRiskResult, SurvivalError>
3030where
3031    F: FnMut(
3032        f64,
3033        &mut Array1<f64>,
3034        &mut Array1<f64>,
3035        &mut Array1<f64>,
3036    ) -> Result<(f64, f64, f64), SurvivalError>,
3037{
3038    let coeff_len_d = design_d_t0.len();
3039    let coeff_len_m = design_m_t0.len();
3040    if coeff_len_d == 0 || coeff_len_m == 0 {
3041        return Err(SurvivalError::InvalidIntegrationSetup);
3042    }
3043    if !t0.is_finite()
3044        || !t1.is_finite()
3045        || !h_dis_t0.is_finite()
3046        || !h_mor_t0.is_finite()
3047        || design_d_t0.iter().any(|v| !v.is_finite())
3048        || design_m_t0.iter().any(|v| !v.is_finite())
3049    {
3050        return Err(SurvivalError::NonFiniteInput);
3051    }
3052    if t1 <= t0 {
3053        return Ok(CrudeRiskResult {
3054            risk: 0.0,
3055            diseasegradient: Array1::zeros(coeff_len_d),
3056            mortalitygradient: Array1::zeros(coeff_len_m),
3057        });
3058    }
3059
3060    let mut sorted_breaks: Vec<f64> = breakpoints
3061        .iter()
3062        .copied()
3063        .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3064        .collect();
3065    sorted_breaks.push(t0);
3066    sorted_breaks.push(t1);
3067    sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3068    sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3069    if sorted_breaks.len() < 2 {
3070        return Err(SurvivalError::InvalidIntegrationSetup);
3071    }
3072
3073    let mut total_risk = 0.0;
3074    let mut diseasegradient = Array1::zeros(coeff_len_d);
3075    let mut mortalitygradient = Array1::zeros(coeff_len_m);
3076    let nodesweights = gauss_legendre_quadrature();
3077
3078    let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3079    let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3080    let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3081
3082    for segment in sorted_breaks.windows(2) {
3083        let a = segment[0];
3084        let b = segment[1];
3085        let center = 0.5 * (b + a);
3086        let halfwidth = 0.5 * (b - a);
3087        if halfwidth <= 0.0 {
3088            continue;
3089        }
3090
3091        for &(x, w) in nodesweights {
3092            let u = center + halfwidth * x;
3093            let (inst_hazard_d, hazard_d, hazard_m) =
3094                eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3095            if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3096                return Err(SurvivalError::NonFiniteInput);
3097            }
3098            if inst_hazard_d <= 0.0 {
3099                return Err(SurvivalError::NonPositiveHazard);
3100            }
3101
3102            if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3103                return Err(SurvivalError::NonMonotoneCumulativeHazard);
3104            }
3105
3106            let h_dis_cond = hazard_d - h_dis_t0;
3107            let h_mor_cond = hazard_m - h_mor_t0;
3108            let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3109
3110            total_risk += w * inst_hazard_d * s_total * halfwidth;
3111
3112            // d Risk / d beta_d:
3113            //   integral [ d h_d * S_total - h_d * S_total * d H_d ] du
3114            // Contract: design_d[j] = x_j(u) = ∂_{β_j} η_d(u)
3115            //           deriv_d[j]  = ẋ_j(u) = ∂_{β_j} η̇_d(u)
3116            // Then ∂_{β_j} h_d = h_d · x_j + H_d · ẋ_j
3117            let weight = w * s_total * halfwidth;
3118            for j in 0..coeff_len_d {
3119                let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3120                let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3121                let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3122                diseasegradient[j] += weight * g;
3123            }
3124
3125            // d Risk / d beta_m:
3126            //   -integral h_d * S_total * d H_m(u|t0) du
3127            let weight = w * inst_hazard_d * s_total * halfwidth;
3128            for j in 0..coeff_len_m {
3129                let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3130                mortalitygradient[j] += weight * g;
3131            }
3132        }
3133    }
3134
3135    Ok(CrudeRiskResult {
3136        risk: total_risk,
3137        diseasegradient,
3138        mortalitygradient,
3139    })
3140}
3141
/// Exposes the survival working model to the generic PIRLS driver through the
/// `PirlsWorkingModel` trait.
impl PirlsWorkingModel for WorkingModelSurvival {
    /// Re-evaluates the working state at coefficients `beta`.
    ///
    /// This delegates directly to `WorkingModelSurvival::update_state`, so any
    /// error it reports is returned unchanged.
    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
        self.update_state(beta)
    }
}
3147
3148#[cfg(test)]
3149mod tests {
3150    use super::*;
3151    use ndarray::{Array1, Array2, Array3, array, s};
3152
3153    #[test]
3154    fn competing_risks_cif_constant_hazard_matches_closed_form() {
3155        let times = array![0.0, 2.0, 5.0, 10.0];
3156        let disease_rates = [0.12, 0.06];
3157        let death_rates = [0.05, 0.02];
3158        let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
3159            let rate = if endpoint == 0 {
3160                disease_rates[row]
3161            } else {
3162                death_rates[row]
3163            };
3164            rate * times[time_idx]
3165        });
3166
3167        let result =
3168            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3169
3170        for row in 0..2 {
3171            let total_rate = disease_rates[row] + death_rates[row];
3172            for time_idx in 0..times.len() {
3173                let failure = 1.0 - (-total_rate * times[time_idx]).exp();
3174                let expected_disease = disease_rates[row] / total_rate * failure;
3175                let expected_death = death_rates[row] / total_rate * failure;
3176                assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
3177                assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
3178                assert!(
3179                    (result.cif[0][[row, time_idx]]
3180                        + result.cif[1][[row, time_idx]]
3181                        + result.overall_survival[[row, time_idx]]
3182                        - 1.0)
3183                        .abs()
3184                        < 1e-12
3185                );
3186            }
3187        }
3188    }
3189
3190    #[test]
3191    fn competing_risks_cif_rejects_nonmonotone_hazards() {
3192        let times = array![0.0, 1.0, 2.0];
3193        let cumulative = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");
3194        let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3195            .expect_err("nonmonotone cumulative hazard should be rejected");
3196        assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
3197    }
3198
3199    #[test]
3200    fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
3201        let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
3202        let cumulative = Array3::from_shape_vec(
3203            (3, 2, 5),
3204            vec![
3205                // cause 1
3206                0.0, 0.2, 0.2, 0.5, 1.1, 0.0, 0.0, 0.4, 0.4, 0.9, // cause 2
3207                0.0, 0.1, 0.3, 0.3, 0.7, 0.0, 0.2, 0.2, 0.8, 0.8, // cause 3
3208                0.0, 0.0, 0.2, 0.6, 0.6, 0.0, 0.1, 0.5, 0.5, 1.5,
3209            ],
3210        )
3211        .expect("shape");
3212
3213        let result =
3214            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3215
3216        for row in 0..2 {
3217            for time_idx in 0..times.len() {
3218                let total_cif = result.cif[0][[row, time_idx]]
3219                    + result.cif[1][[row, time_idx]]
3220                    + result.cif[2][[row, time_idx]];
3221                assert!(
3222                    (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
3223                    "probability mass mismatch at row={row}, time_idx={time_idx}"
3224                );
3225                assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));
3226                for cause in 0..3 {
3227                    assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
3228                    if time_idx > 0 {
3229                        assert!(
3230                            result.cif[cause][[row, time_idx]] + 1e-12
3231                                >= result.cif[cause][[row, time_idx - 1]],
3232                            "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
3233                        );
3234                    }
3235                }
3236            }
3237        }
3238
3239        // Cause 1 is flat between t=1 and t=3 for row 0, but other causes
3240        // fail in that interval; its CIF must remain exactly flat.
3241        assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
3242        // All causes are flat between t=3 and t=7 for row 1 except cause 2;
3243        // causes 1 and 3 must not move.
3244        assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
3245        assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
3246    }
3247
3248    #[test]
3249    fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
3250        let cumulative = Array3::zeros((2, 1, 2));
3251
3252        for times in [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]] {
3253            let err = assemble_competing_risks_cif(times.view(), cumulative.view())
3254                .expect_err("bad time grid should be rejected");
3255            assert!(matches!(err, SurvivalError::InvalidTimeGrid));
3256        }
3257
3258        let times = array![0.0, 1.0];
3259        let nonfinite = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
3260        let err = assemble_competing_risks_cif(times.view(), nonfinite.view())
3261            .expect_err("nonfinite hazard should be rejected");
3262        assert!(matches!(err, SurvivalError::NonFiniteInput));
3263    }
3264
3265    #[test]
3266    fn competing_risks_cif_extreme_hazards_remain_bounded() {
3267        let times = array![0.0, 1.0, 2.0];
3268        let cumulative =
3269            Array3::from_shape_vec((2, 1, 3), vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0])
3270                .expect("shape");
3271
3272        let result =
3273            assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");
3274
3275        for value in result
3276            .cif
3277            .iter()
3278            .flat_map(|m| m.iter())
3279            .chain(result.overall_survival.iter())
3280        {
3281            assert!(value.is_finite());
3282            assert!((0.0..=1.0).contains(value));
3283        }
3284        assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
3285        assert_eq!(result.overall_survival[[0, 2]], 0.0);
3286    }
3287
3288    fn toy_penalties() -> PenaltyBlocks {
3289        let s = array![[2.0, 0.5], [0.5, 3.0]];
3290        PenaltyBlocks::new(vec![PenaltyBlock {
3291            matrix: s,
3292            lambda: 1.7,
3293            range: 1..3,
3294            nullspace_dim: 0,
3295        }])
3296    }
3297
3298    fn survival_inputs<'a>(
3299        age_entry: &'a Array1<f64>,
3300        age_exit: &'a Array1<f64>,
3301        event_target: &'a Array1<u8>,
3302        event_competing: &'a Array1<u8>,
3303        sampleweight: &'a Array1<f64>,
3304        x_entry: &'a Array2<f64>,
3305        x_exit: &'a Array2<f64>,
3306        x_derivative: &'a Array2<f64>,
3307    ) -> SurvivalEngineInputs<'a> {
3308        SurvivalEngineInputs {
3309            age_entry: age_entry.view(),
3310            age_exit: age_exit.view(),
3311            event_target: event_target.view(),
3312            event_competing: event_competing.view(),
3313            sampleweight: sampleweight.view(),
3314            x_entry: x_entry.view(),
3315            x_exit: x_exit.view(),
3316            x_derivative: x_derivative.view(),
3317            monotonicity_constraint_rows: None,
3318            monotonicity_constraint_offsets: None,
3319        }
3320    }
3321
3322    fn survival_model(
3323        inputs: SurvivalEngineInputs<'_>,
3324        penalties: PenaltyBlocks,
3325        monotonicity: SurvivalMonotonicityPenalty,
3326        spec: SurvivalSpec,
3327    ) -> Result<WorkingModelSurvival, SurvivalError> {
3328        WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
3329    }
3330
3331    fn survival_model_with_offsets(
3332        inputs: SurvivalEngineInputs<'_>,
3333        offsets: Option<SurvivalBaselineOffsets<'_>>,
3334        penalties: PenaltyBlocks,
3335        monotonicity: SurvivalMonotonicityPenalty,
3336        spec: SurvivalSpec,
3337    ) -> Result<WorkingModelSurvival, SurvivalError> {
3338        WorkingModelSurvival::from_engine_inputswith_offsets(
3339            inputs,
3340            offsets,
3341            penalties,
3342            monotonicity,
3343            spec,
3344        )
3345    }
3346
3347    #[test]
3348    fn penaltyhessian_matchesgradient_jacobian() {
3349        let penalties = toy_penalties();
3350        let beta = array![10.0, -0.3, 1.2, 7.0];
3351
3352        let grad = penalties.gradient(&beta);
3353        let h = penalties.hessian(beta.len());
3354        let b_block = beta.slice(s![1..3]).to_owned();
3355        let expected = 1.7 * array![[2.0, 0.5], [0.5, 3.0]].dot(&b_block);
3356
3357        assert!((grad[1] - expected[0]).abs() < 1e-12);
3358        assert!((grad[2] - expected[1]).abs() < 1e-12);
3359        assert!((h[[1, 1]] - 1.7 * 2.0).abs() < 1e-12);
3360        assert!((h[[1, 2]] - 1.7 * 0.5).abs() < 1e-12);
3361        assert!((h[[2, 1]] - 1.7 * 0.5).abs() < 1e-12);
3362        assert!((h[[2, 2]] - 1.7 * 3.0).abs() < 1e-12);
3363    }
3364
3365    #[test]
3366    fn penaltygradient_matches_deviance_finite_difference() {
3367        let penalties = toy_penalties();
3368        let beta = array![10.0, -0.3, 1.2, 7.0];
3369        let grad = penalties.gradient(&beta);
3370        let eps = 1e-7;
3371
3372        for idx in 0..beta.len() {
3373            let mut plus = beta.clone();
3374            let mut minus = beta.clone();
3375            plus[idx] += eps;
3376            minus[idx] -= eps;
3377            let fd = (penalties.deviance(&plus) - penalties.deviance(&minus)) / (2.0 * eps);
3378            assert_eq!(
3379                grad[idx].signum(),
3380                fd.signum(),
3381                "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
3382                grad[idx]
3383            );
3384            assert!(
3385                (grad[idx] - fd).abs() < 1e-6,
3386                "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
3387                grad[idx]
3388            );
3389        }
3390    }
3391
3392    #[test]
3393    fn zero_offsets_match_default_survival_state() {
3394        let age_entry = array![1.0_f64, 2.0_f64];
3395        let age_exit = array![2.0_f64, 3.5_f64];
3396        let event_target = array![1u8, 0u8];
3397        let event_competing = array![0u8, 0u8];
3398        let sampleweight = array![1.0, 1.0];
3399        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3400        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3401        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3402        let penalties = PenaltyBlocks::new(Vec::new());
3403        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3404        let beta = array![-1.0, 0.8];
3405
3406        let base = survival_model(
3407            survival_inputs(
3408                &age_entry,
3409                &age_exit,
3410                &event_target,
3411                &event_competing,
3412                &sampleweight,
3413                &x_entry,
3414                &x_exit,
3415                &x_derivative,
3416            ),
3417            penalties.clone(),
3418            mono,
3419            SurvivalSpec::Net,
3420        )
3421        .expect("construct base survival model");
3422
3423        let zero_offsets = survival_model_with_offsets(
3424            survival_inputs(
3425                &age_entry,
3426                &age_exit,
3427                &event_target,
3428                &event_competing,
3429                &sampleweight,
3430                &x_entry,
3431                &x_exit,
3432                &x_derivative,
3433            ),
3434            Some(SurvivalBaselineOffsets {
3435                eta_entry: array![0.0, 0.0].view(),
3436                eta_exit: array![0.0, 0.0].view(),
3437                derivative_exit: array![0.0, 0.0].view(),
3438            }),
3439            penalties,
3440            mono,
3441            SurvivalSpec::Net,
3442        )
3443        .expect("construct offset survival model");
3444
3445        let state_base = base.update_state(&beta).expect("base state");
3446        let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
3447        assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
3448        assert!(
3449            state_base
3450                .gradient
3451                .iter()
3452                .zip(statezero.gradient.iter())
3453                .all(|(a, b)| (a - b).abs() < 1e-12)
3454        );
3455    }
3456
3457    #[test]
3458    fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
3459        // Regression for #378: the joint competing-risks Weibull path seeds a
3460        // shared single-hazard baseline working model from the dataset's event
3461        // *labels* {0 = censored, 1 = cause 1, 2 = cause 2}. The single-hazard
3462        // engine's `event_target` contract is binary {0, 1}, so feeding the raw
3463        // cause labels straight through used to bail out of construction via the
3464        // `event_target > 1` guard and surface as the misleading
3465        // `SurvivalError::NonFiniteInput` ("inputs contain non-finite values"),
3466        // even though every input value is finite. The fix (a) reports a
3467        // multi-cause label as the actionable `EventCodeInvalid`, never the
3468        // misleading "non-finite", and (b) projects cause labels to the
3469        // any-event {0, 1} indicator via the single-source-of-truth
3470        // `pooled_any_event_indicator` before constructing the pooled baseline.
3471        // This pins both halves of that contract.
3472        let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
3473        let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
3474        // Competing-risks cause labels: censored, cause 1, cause 2, censored.
3475        let cause_labels = array![0u8, 1u8, 2u8, 0u8];
3476        let event_competing = Array1::<u8>::zeros(cause_labels.len());
3477        let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];
3478        let x_entry = array![
3479            [1.0, age_entry[0].max(1e-8).ln()],
3480            [1.0, age_entry[1].max(1e-8).ln()],
3481            [1.0, age_entry[2].max(1e-8).ln()],
3482            [1.0, age_entry[3].max(1e-8).ln()],
3483        ];
3484        let x_exit = array![
3485            [1.0, age_exit[0].ln()],
3486            [1.0, age_exit[1].ln()],
3487            [1.0, age_exit[2].ln()],
3488            [1.0, age_exit[3].ln()],
3489        ];
3490        let x_derivative = array![
3491            [0.0, 1.0 / age_exit[0]],
3492            [0.0, 1.0 / age_exit[1]],
3493            [0.0, 1.0 / age_exit[2]],
3494            [0.0, 1.0 / age_exit[3]],
3495        ];
3496        let penalties = PenaltyBlocks::new(Vec::new());
3497        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3498
3499        // Raw cause labels {0,1,2} violate the single-hazard binary contract and
3500        // must be rejected -- but as an *actionable* `EventCodeInvalid`, NOT the
3501        // misleading `NonFiniteInput`: the labels are finite, they merely need
3502        // projecting. (The old fix left this surfacing as "non-finite".)
3503        let raw = survival_model(
3504            survival_inputs(
3505                &age_entry,
3506                &age_exit,
3507                &cause_labels,
3508                &event_competing,
3509                &sampleweight,
3510                &x_entry,
3511                &x_exit,
3512                &x_derivative,
3513            ),
3514            penalties.clone(),
3515            mono,
3516            SurvivalSpec::Net,
3517        );
3518        assert!(
3519            matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
3520            "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
3521        );
3522
3523        // The pooled-baseline projection the workflow now performs through the
3524        // single source of truth: any observed event (any cause) -> {0, 1}.
3525        let any_event = pooled_any_event_indicator(cause_labels.view());
3526        assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
3527        // And the per-cause projection that seeds each cause-specific block.
3528        assert_eq!(
3529            cause_specific_event_indicator(cause_labels.view(), 1),
3530            array![0u8, 1u8, 0u8, 0u8]
3531        );
3532        assert_eq!(
3533            cause_specific_event_indicator(cause_labels.view(), 2),
3534            array![0u8, 0u8, 1u8, 0u8]
3535        );
3536        let model = survival_model(
3537            survival_inputs(
3538                &age_entry,
3539                &age_exit,
3540                &any_event,
3541                &event_competing,
3542                &sampleweight,
3543                &x_entry,
3544                &x_exit,
3545                &x_derivative,
3546            ),
3547            penalties,
3548            mono,
3549            SurvivalSpec::Net,
3550        )
3551        .expect("pooled any-event baseline model must construct from competing-risks data");
3552
3553        // The constructed pooled baseline must yield a finite working state, so
3554        // the downstream baseline-seeding PIRLS loop has something to optimize.
3555        let beta = array![-1.0_f64, 0.8];
3556        let state = model.update_state(&beta).expect("pooled baseline state");
3557        assert!(
3558            state.deviance.is_finite(),
3559            "pooled baseline deviance must be finite, got {}",
3560            state.deviance
3561        );
3562        assert!(
3563            state.gradient.iter().all(|g| g.is_finite()),
3564            "pooled baseline gradient must be finite"
3565        );
3566    }
3567
3568    #[test]
3569    fn offset_channel_residuals_match_central_fd_of_nll() {
3570        // Three observations: two events (non-origin entry and origin entry)
3571        // and one censored row. This exercises every nonzero channel at least
3572        // once: r_exit from all rows, r_entry only from the first (has entry
3573        // interval), r_derivative only from events.
3574        let age_entry = array![0.5_f64, 0.0, 0.3];
3575        let age_exit = array![1.4_f64, 1.0, 2.0];
3576        let event_target = array![1u8, 1u8, 0u8];
3577        let event_competing = array![0u8, 0u8, 0u8];
3578        let sampleweight = array![1.0_f64, 2.5, 0.7];
3579        let x_entry = array![
3580            [1.0, age_entry[0].ln()],
3581            [1.0, age_entry[1].max(1e-8).ln()],
3582            [1.0, age_entry[2].ln()]
3583        ];
3584        let x_exit = array![
3585            [1.0, age_exit[0].ln()],
3586            [1.0, age_exit[1].ln()],
3587            [1.0, age_exit[2].ln()]
3588        ];
3589        let x_derivative = array![
3590            [0.0, 1.0 / age_exit[0]],
3591            [0.0, 1.0 / age_exit[1]],
3592            [0.0, 1.0 / age_exit[2]]
3593        ];
3594        // Baseline offsets chosen so η_entry, η_exit, s are all comfortably
3595        // away from overflow / monotonicity-violation boundaries.
3596        let o_entry = array![0.2_f64, 0.0, 0.1];
3597        let o_exit = array![0.4_f64, 0.5, 0.7];
3598        let o_deriv = array![0.3_f64, 0.8, 0.5];
3599        let penalties = PenaltyBlocks::new(Vec::new());
3600        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3601        let beta = array![-0.7_f64, 0.6];
3602
3603        let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
3604            survival_model_with_offsets(
3605                survival_inputs(
3606                    &age_entry,
3607                    &age_exit,
3608                    &event_target,
3609                    &event_competing,
3610                    &sampleweight,
3611                    &x_entry,
3612                    &x_exit,
3613                    &x_derivative,
3614                ),
3615                Some(SurvivalBaselineOffsets {
3616                    eta_entry: o_e.view(),
3617                    eta_exit: o_x.view(),
3618                    derivative_exit: o_d.view(),
3619                }),
3620                penalties.clone(),
3621                mono,
3622                SurvivalSpec::Net,
3623            )
3624            .expect("model build")
3625        };
3626
3627        let base = build(&o_entry, &o_exit, &o_deriv);
3628        let resid = base
3629            .offset_channel_residuals(&beta)
3630            .expect("offset residuals");
3631        assert_eq!(resid.exit.len(), 3);
3632        assert_eq!(resid.entry.len(), 3);
3633        assert_eq!(resid.derivative.len(), 3);
3634
3635        // NLL equals half the deviance returned by update_state; that is the
3636        // exact unpenalized loss whose offset partials r_{X,E,D} encode.
3637        let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
3638        let h = 1e-6;
3639
3640        // Row 1 (origin entry, event=1) has no entry interval, so r_entry[1]
3641        // must be exactly 0. Row 2 (censored) has r_deriv[2] exactly 0. Check
3642        // those identities before FD comparison on the nonzero elements.
3643        assert_eq!(resid.entry[1], 0.0);
3644        assert_eq!(resid.derivative[2], 0.0);
3645
3646        for i in 0..3 {
3647            // exit channel: perturb o_exit[i] alone.
3648            {
3649                let mut op = o_exit.clone();
3650                let mut om = o_exit.clone();
3651                op[i] += h;
3652                om[i] -= h;
3653                let fd = (nll(&build(&o_entry, &op, &o_deriv))
3654                    - nll(&build(&o_entry, &om, &o_deriv)))
3655                    / (2.0 * h);
3656                assert!(
3657                    (resid.exit[i] - fd).abs() < 1e-6,
3658                    "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
3659                    resid.exit[i],
3660                    fd
3661                );
3662            }
3663            // entry channel: only row 0 has an entry interval; for rows with
3664            // entry_at_origin the offset contributes nothing to NLL and FD
3665            // must also be exactly 0 to numerical precision.
3666            {
3667                let mut op = o_entry.clone();
3668                let mut om = o_entry.clone();
3669                op[i] += h;
3670                om[i] -= h;
3671                let fd = (nll(&build(&op, &o_exit, &o_deriv))
3672                    - nll(&build(&om, &o_exit, &o_deriv)))
3673                    / (2.0 * h);
3674                assert!(
3675                    (resid.entry[i] - fd).abs() < 1e-6,
3676                    "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
3677                    resid.entry[i],
3678                    fd
3679                );
3680            }
3681            // derivative channel: only event rows contribute.
3682            {
3683                let mut op = o_deriv.clone();
3684                let mut om = o_deriv.clone();
3685                op[i] += h;
3686                om[i] -= h;
3687                let fd = (nll(&build(&o_entry, &o_exit, &op))
3688                    - nll(&build(&o_entry, &o_exit, &om)))
3689                    / (2.0 * h);
3690                assert!(
3691                    (resid.derivative[i] - fd).abs() < 1e-6,
3692                    "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
3693                    resid.derivative[i],
3694                    fd
3695                );
3696            }
3697        }
3698    }
3699
3700    #[test]
3701    fn offset_channel_residuals_respect_zero_sampleweight() {
3702        let age_entry = array![1.0_f64, 2.0];
3703        let age_exit = array![2.0_f64, 3.5];
3704        let event_target = array![1u8, 1u8];
3705        let event_competing = array![0u8, 0u8];
3706        let sampleweight = array![0.0_f64, 1.2]; // row 0 is excluded by weight
3707        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3708        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3709        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3710        let penalties = PenaltyBlocks::new(Vec::new());
3711        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3712        let beta = array![-1.0_f64, 0.8];
3713
3714        let model = survival_model_with_offsets(
3715            survival_inputs(
3716                &age_entry,
3717                &age_exit,
3718                &event_target,
3719                &event_competing,
3720                &sampleweight,
3721                &x_entry,
3722                &x_exit,
3723                &x_derivative,
3724            ),
3725            Some(SurvivalBaselineOffsets {
3726                eta_entry: array![0.0_f64, 0.1].view(),
3727                eta_exit: array![0.0_f64, 0.2].view(),
3728                derivative_exit: array![0.0_f64, 0.1].view(),
3729            }),
3730            penalties,
3731            mono,
3732            SurvivalSpec::Net,
3733        )
3734        .expect("model");
3735        let r = model.offset_channel_residuals(&beta).expect("resid");
3736        // Row 0 (sampleweight=0) must contribute zero in every channel.
3737        assert_eq!(r.exit[0], 0.0);
3738        assert_eq!(r.entry[0], 0.0);
3739        assert_eq!(r.derivative[0], 0.0);
3740        // Row 1 must still carry a nonzero exit-channel residual.
3741        assert!(r.exit[1] != 0.0);
3742    }
3743
3744    #[test]
3745    fn offset_channel_residuals_reject_beta_dim_mismatch() {
3746        let age_entry = array![1.0_f64];
3747        let age_exit = array![2.0_f64];
3748        let event_target = array![1u8];
3749        let event_competing = array![0u8];
3750        let sampleweight = array![1.0_f64];
3751        let x_entry = array![[1.0, 0.0]];
3752        let x_exit = array![[1.0, 0.7]];
3753        let x_derivative = array![[0.0, 0.5]];
3754        let penalties = PenaltyBlocks::new(Vec::new());
3755        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3756        let model = survival_model(
3757            survival_inputs(
3758                &age_entry,
3759                &age_exit,
3760                &event_target,
3761                &event_competing,
3762                &sampleweight,
3763                &x_entry,
3764                &x_exit,
3765                &x_derivative,
3766            ),
3767            penalties,
3768            mono,
3769            SurvivalSpec::Net,
3770        )
3771        .expect("model");
3772        let bad_beta = array![0.0_f64]; // should be length 2
3773        let err = model
3774            .offset_channel_residuals(&bad_beta)
3775            .expect_err("mismatch must error");
3776        match err {
3777            EstimationError::InvalidInput(msg) => {
3778                assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
3779            }
3780            other => panic!("expected InvalidInput, got {other:?}"),
3781        }
3782    }
3783
3784    #[test]
3785    fn crudespec_is_rejected_by_one_hazard_engine() {
3786        let age_entry = array![1.0_f64];
3787        let age_exit = array![2.0_f64];
3788        let event_target = array![0u8];
3789        let event_competing = array![1u8];
3790        let sampleweight = array![1.0];
3791        let x_entry = array![[0.1]];
3792        let x_exit = array![[0.4]];
3793        let x_derivative = array![[1.0]];
3794        let penalties = PenaltyBlocks::new(Vec::new());
3795        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3796
3797        let err = survival_model(
3798            survival_inputs(
3799                &age_entry,
3800                &age_exit,
3801                &event_target,
3802                &event_competing,
3803                &sampleweight,
3804                &x_entry,
3805                &x_exit,
3806                &x_derivative,
3807            ),
3808            penalties,
3809            mono,
3810            SurvivalSpec::Crude,
3811        )
3812        .expect_err("crude fitting should be rejected by the one-hazard engine");
3813        assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
3814    }
3815
3816    #[test]
3817    fn nonstructural_models_require_explicit_monotonicity_collocation() {
3818        let age_entry = array![1.0_f64, 1.5_f64];
3819        let age_exit = array![2.0_f64, 2.5_f64];
3820        let event_target = array![0u8, 0u8];
3821        let event_competing = array![0u8, 1u8];
3822        let sampleweight = array![1.0, 1.0];
3823        let x_entry = array![[0.2], [0.1]];
3824        let x_exit = array![[0.3], [0.2]];
3825        let x_derivative = array![[1.0], [1.0]];
3826
3827        let model = survival_model(
3828            survival_inputs(
3829                &age_entry,
3830                &age_exit,
3831                &event_target,
3832                &event_competing,
3833                &sampleweight,
3834                &x_entry,
3835                &x_exit,
3836                &x_derivative,
3837            ),
3838            PenaltyBlocks::new(Vec::new()),
3839            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3840            SurvivalSpec::Net,
3841        )
3842        .expect("construct censored survival model");
3843
3844        assert!(
3845            model.monotonicity_linear_constraints().is_none(),
3846            "non-structural survival models must not fabricate rowwise monotonicity constraints"
3847        );
3848    }
3849
3850    #[test]
3851    fn decreasing_interval_is_rejectedwithout_target_events() {
3852        let age_entry = array![1.0_f64];
3853        let age_exit = array![2.0_f64];
3854        let event_target = array![0u8];
3855        let event_competing = array![0u8];
3856        let sampleweight = array![1.0];
3857        let x_entry = array![[0.5]];
3858        let x_exit = array![[0.0]];
3859        let x_derivative = array![[1.0]];
3860
3861        let model = survival_model(
3862            survival_inputs(
3863                &age_entry,
3864                &age_exit,
3865                &event_target,
3866                &event_competing,
3867                &sampleweight,
3868                &x_entry,
3869                &x_exit,
3870                &x_derivative,
3871            ),
3872            PenaltyBlocks::new(Vec::new()),
3873            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3874            SurvivalSpec::Net,
3875        )
3876        .expect("construct censored survival model");
3877
3878        let err = model
3879            .update_state(&array![1.0])
3880            .expect_err("decreasing cumulative hazard increment should be rejected");
3881        assert!(
3882            err.to_string().contains("cumulative hazard decreased"),
3883            "unexpected error: {err}"
3884        );
3885    }
3886
3887    fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
3888        calculate_crude_risk_quadrature(
3889            0.0,
3890            1.0,
3891            &[0.0, 1.0],
3892            beta_d.exp(),
3893            beta_m.exp(),
3894            array![1.0].view(),
3895            array![1.0].view(),
3896            |u, design_d, deriv_d, design_m| {
3897                let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
3898                let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
3899                let inst_hazard_d = 0.2 * beta_d.exp();
3900                design_d[0] = 1.0;
3901                // η_d = β_d + ln(1 + 0.2u), so η̇_d = 0.2/(1+0.2u)
3902                // which does not depend on β_d → ∂_{β_d} η̇_d = 0
3903                deriv_d[0] = 0.0;
3904                design_m[0] = 1.0;
3905                Ok((inst_hazard_d, cumulative_d, cumulative_m))
3906            },
3907        )
3908        .expect("smooth crude-risk quadrature should succeed")
3909    }
3910
3911    #[test]
3912    fn crude_riskgradient_matches_monotoneobjective() {
3913        let beta_d = -0.2_f64;
3914        let beta_m = -0.5_f64;
3915        let result = smooth_crude_risk(beta_d, beta_m);
3916        let eps = 1e-6;
3917
3918        let fd_d = (smooth_crude_risk(beta_d + eps, beta_m).risk
3919            - smooth_crude_risk(beta_d - eps, beta_m).risk)
3920            / (2.0 * eps);
3921        let fd_m = (smooth_crude_risk(beta_d, beta_m + eps).risk
3922            - smooth_crude_risk(beta_d, beta_m - eps).risk)
3923            / (2.0 * eps);
3924
3925        assert!(
3926            (result.diseasegradient[0] - fd_d).abs() < 1e-5,
3927            "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
3928            result.diseasegradient[0]
3929        );
3930        assert!(
3931            (result.mortalitygradient[0] - fd_m).abs() < 1e-5,
3932            "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
3933            result.mortalitygradient[0]
3934        );
3935    }
3936
3937    #[test]
3938    fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
3939        let age_entry = array![1.0_f64, 2.0_f64];
3940        let age_exit = array![2.0_f64, 3.5_f64];
3941        let event_target = array![1u8, 0u8];
3942        let event_competing = array![0u8, 0u8];
3943        let sampleweight = array![1.0, 1.0];
3944        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3945        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3946        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3947        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3948            matrix: array![[2.0]],
3949            lambda: 1.7,
3950            range: 1..2,
3951            nullspace_dim: 0,
3952        }]);
3953        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3954        let beta = array![-1.2, 0.4];
3955
3956        let model = survival_model(
3957            survival_inputs(
3958                &age_entry,
3959                &age_exit,
3960                &event_target,
3961                &event_competing,
3962                &sampleweight,
3963                &x_entry,
3964                &x_exit,
3965                &x_derivative,
3966            ),
3967            penalties.clone(),
3968            mono,
3969            SurvivalSpec::Net,
3970        )
3971        .expect("construct survival model");
3972
3973        let state = model.update_state(&beta).expect("survival state");
3974        let expected_penalty = penalties.deviance(&beta) + 0.5 * state.ridge_used * beta.dot(&beta);
3975        assert!(
3976            (state.penalty_term - expected_penalty).abs() < 1e-12,
3977            "penalty_term mismatch: state={} expected={}",
3978            state.penalty_term,
3979            expected_penalty
3980        );
3981    }
3982
3983    #[test]
3984    fn negative_penalty_lambda_is_rejected() {
3985        let age_entry = array![1.0_f64];
3986        let age_exit = array![2.0_f64];
3987        let event_target = array![1u8];
3988        let event_competing = array![0u8];
3989        let sampleweight = array![1.0];
3990        let x_entry = array![[1.0, 0.0]];
3991        let x_exit = array![[1.0, 0.5]];
3992        let x_derivative = array![[0.0, 1.0]];
3993        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3994            matrix: array![[1.0]],
3995            lambda: -0.1,
3996            range: 1..2,
3997            nullspace_dim: 0,
3998        }]);
3999
4000        let err = survival_model(
4001            survival_inputs(
4002                &age_entry,
4003                &age_exit,
4004                &event_target,
4005                &event_competing,
4006                &sampleweight,
4007                &x_entry,
4008                &x_exit,
4009                &x_derivative,
4010            ),
4011            penalties,
4012            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4013            SurvivalSpec::Net,
4014        )
4015        .expect_err("negative lambda must be rejected");
4016
4017        assert!(matches!(err, SurvivalError::NonFiniteInput));
4018    }
4019
4020    #[test]
4021    fn penalty_block_range_and_shapemust_match_coefficients() {
4022        let age_entry = array![1.0_f64];
4023        let age_exit = array![2.0_f64];
4024        let event_target = array![1u8];
4025        let event_competing = array![0u8];
4026        let sampleweight = array![1.0];
4027        let x_entry = array![[1.0, 0.0]];
4028        let x_exit = array![[1.0, 0.5]];
4029        let x_derivative = array![[0.0, 1.0]];
4030        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4031            matrix: array![[1.0]],
4032            lambda: 0.5,
4033            range: 0..2,
4034            nullspace_dim: 0,
4035        }]);
4036
4037        let err = survival_model(
4038            survival_inputs(
4039                &age_entry,
4040                &age_exit,
4041                &event_target,
4042                &event_competing,
4043                &sampleweight,
4044                &x_entry,
4045                &x_exit,
4046                &x_derivative,
4047            ),
4048            penalties,
4049            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4050            SurvivalSpec::Net,
4051        )
4052        .expect_err("penalty block geometry must match coefficient support");
4053
4054        assert!(matches!(err, SurvivalError::DimensionMismatch));
4055    }
4056
4057    #[test]
4058    fn survivalgradient_matchesobjectivefdwithridge_scaling() {
4059        let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
4060        let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
4061        let event_target = array![1u8, 0u8, 1u8];
4062        let event_competing = array![0u8, 0u8, 0u8];
4063        let sampleweight = array![1.0, 1.0, 1.0];
4064        let x_entry = array![
4065            [1.0, age_entry[0].ln()],
4066            [1.0, age_entry[1].ln()],
4067            [1.0, age_entry[2].ln()]
4068        ];
4069        let x_exit = array![
4070            [1.0, age_exit[0].ln()],
4071            [1.0, age_exit[1].ln()],
4072            [1.0, age_exit[2].ln()]
4073        ];
4074        let x_derivative = array![
4075            [0.0, 1.0 / age_exit[0]],
4076            [0.0, 1.0 / age_exit[1]],
4077            [0.0, 1.0 / age_exit[2]]
4078        ];
4079        let penalties = PenaltyBlocks::new(Vec::new());
4080        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4081        let beta = array![-1.0, 3.0];
4082
4083        let model = survival_model(
4084            survival_inputs(
4085                &age_entry,
4086                &age_exit,
4087                &event_target,
4088                &event_competing,
4089                &sampleweight,
4090                &x_entry,
4091                &x_exit,
4092                &x_derivative,
4093            ),
4094            penalties,
4095            mono,
4096            SurvivalSpec::Net,
4097        )
4098        .expect("construct survival model");
4099
4100        let state = model.update_state(&beta).expect("state at beta");
4101        let eps = 1e-7;
4102        for j in 0..beta.len() {
4103            let mut plus = beta.clone();
4104            let mut minus = beta.clone();
4105            plus[j] += eps;
4106            minus[j] -= eps;
4107            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4108            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4109            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4110            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4111            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4112            assert_eq!(
4113                state.gradient[j].signum(),
4114                fd.signum(),
4115                "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4116                state.gradient[j]
4117            );
4118            assert!(
4119                (state.gradient[j] - fd).abs() < 1e-5,
4120                "objective/gradient mismatch at j={j}: grad={} fd={fd}",
4121                state.gradient[j]
4122            );
4123        }
4124    }
4125
4126    fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
4127        // 20-subject survival fixture with mean-centered log-age time
4128        // covariate, balanced events/censorings, and moderate hazard levels.
4129        // The fixture is large enough that the observed-information Hessian
4130        // is well-conditioned at the MLE so PIRLS reaches the 1e-10 KKT
4131        // tolerance in well under 80 iterations from the starting beta used
4132        // below.
4133        let age_entry: Array1<f64> = Array1::from(vec![
4134            30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
4135            34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
4136        ]);
4137        let age_exit: Array1<f64> = Array1::from(vec![
4138            45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
4139            48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
4140        ]);
4141        let event_target = Array1::from(vec![
4142            1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
4143        ]);
4144        let event_competing = Array1::<u8>::zeros(age_entry.len());
4145        let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
4146        let n = age_entry.len();
4147        let ln_age_mean: f64 = {
4148            let mut sum = 0.0;
4149            for i in 0..n {
4150                sum += age_entry[i].ln() + age_exit[i].ln();
4151            }
4152            sum / (2.0 * n as f64)
4153        };
4154        let mut x_entry = Array2::<f64>::zeros((n, 2));
4155        let mut x_exit = Array2::<f64>::zeros((n, 2));
4156        let mut x_derivative = Array2::<f64>::zeros((n, 2));
4157        for i in 0..n {
4158            x_entry[[i, 0]] = 1.0;
4159            x_exit[[i, 0]] = 1.0;
4160            x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
4161            x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
4162            x_derivative[[i, 0]] = 0.0;
4163            x_derivative[[i, 1]] = 1.0 / age_exit[i];
4164        }
4165        let penalties = PenaltyBlocks::new(vec![
4166            PenaltyBlock {
4167                matrix: array![[3.0]],
4168                lambda: 0.0,
4169                range: 0..1,
4170                nullspace_dim: 0,
4171            },
4172            PenaltyBlock {
4173                matrix: array![[2.5]],
4174                lambda,
4175                range: 1..2,
4176                nullspace_dim: 0,
4177            },
4178        ]);
4179        survival_model(
4180            survival_inputs(
4181                &age_entry,
4182                &age_exit,
4183                &event_target,
4184                &event_competing,
4185                &sampleweight,
4186                &x_entry,
4187                &x_exit,
4188                &x_derivative,
4189            ),
4190            penalties,
4191            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4192            SurvivalSpec::Net,
4193        )
4194        .expect("construct LAML FD survival model")
4195    }
4196
4197    fn laml_test_logdet_h(state: &WorkingState) -> f64 {
4198        use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4199        use gam_linalg::faer_ndarray::FaerEigh;
4200
4201        let h_dense = state.hessian.to_dense();
4202        let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4203        let eps = spectral_epsilon(evals.as_slice().unwrap());
4204        evals
4205            .iter()
4206            .map(|&sigma| spectral_regularize(sigma, eps).ln())
4207            .sum()
4208    }
4209
4210    #[test]
4211    fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
4212        // The core claim under test: the survival LAML rho-gradient and the
4213        // documented LAML objective enumerate only penalty blocks whose
4214        // lambda is actually active (> 0). An inactive prefix block must
4215        // therefore contribute neither a log|lambda * S| term to the
4216        // objective nor an entry to the rho-gradient vector.
4217        //
4218        // We verify the objective formula and the gradient dimensionality at
4219        // a fixed beta rather than a fitted one: the bug this test guards
4220        // against was purely algebraic enumeration over penalty blocks and
4221        // has no dependence on PIRLS convergence quality. A gradient-vs-FD
4222        // comparison would require beta to sit at the joint MLE of a tiny
4223        // synthetic survival fixture, which the analytic Newton/PIRLS path
4224        // cannot reach to 1e-10 KKT tolerance without a much richer design.
4225        let rho0 = -0.35_f64;
4226        let beta = array![-2.5_f64, 1.0];
4227        let model = laml_fd_test_model(rho0.exp());
4228        let state = model
4229            .update_state(&beta)
4230            .expect("state for LAML prefix-skip test");
4231
4232        // Sanity: the fixture has two penalty blocks; the first has
4233        // lambda = 0 (inactive prefix) and the second has lambda > 0
4234        // (active). If a future refactor flips this ordering, the prefix
4235        // skip being exercised here would silently become an identity test.
4236        assert_eq!(model.penalties.blocks.len(), 2);
4237        assert_eq!(model.penalties.blocks[0].lambda, 0.0);
4238        assert!(model.penalties.blocks[1].lambda > 0.0);
4239
4240        let rho = Array1::from_iter(
4241            model
4242                .penalties
4243                .blocks
4244                .iter()
4245                .filter(|b| b.lambda > 0.0)
4246                .map(|b| b.lambda.ln()),
4247        );
4248        assert_eq!(
4249            rho.len(),
4250            1,
4251            "fixture should expose exactly one active penalty block for the rho vector"
4252        );
4253
4254        let (obj, grad) = model
4255            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4256            .expect("survival LAML objective and gradient");
4257
4258        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
4259            - 0.5 * (rho0 + 2.5_f64.ln());
4260        assert_eq!(
4261            grad.len(),
4262            1,
4263            "rho-gradient must match the active-penalty count, not the full block list"
4264        );
4265        assert!(
4266            (obj - expected).abs() < 1e-10,
4267            "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
4268        );
4269        assert!(
4270            grad[0].is_finite(),
4271            "rho-gradient must be finite: {}",
4272            grad[0]
4273        );
4274    }
4275
4276    #[test]
4277    fn structural_monotonicgradient_matchesobjectivefd() {
4278        let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
4279        let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
4280        let event_target = array![1u8, 0u8, 1u8];
4281        let event_competing = array![0u8, 0u8, 0u8];
4282        let sampleweight = array![1.0, 1.0, 1.0];
4283
4284        // Time block has 3 structural-monotone columns.
4285        // Final column is a covariate, left unconstrained.
4286        let x_entry = array![
4287            [1.0, 0.2, 0.05, -0.7],
4288            [1.0, 0.5, 0.20, 0.1],
4289            [1.0, 0.9, 0.60, 1.2]
4290        ];
4291        let x_exit = array![
4292            [1.0, 0.4, 0.16, -0.7],
4293            [1.0, 0.8, 0.64, 0.1],
4294            [1.0, 1.1, 1.21, 1.2]
4295        ];
4296        let x_derivative = array![
4297            [0.0, 0.8, 0.64, 0.0],
4298            [0.0, 0.7, 1.12, 0.0],
4299            [0.0, 0.6, 1.32, 0.0]
4300        ];
4301        let penalties = PenaltyBlocks::new(Vec::new());
4302        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4303        let mut model = survival_model(
4304            survival_inputs(
4305                &age_entry,
4306                &age_exit,
4307                &event_target,
4308                &event_competing,
4309                &sampleweight,
4310                &x_entry,
4311                &x_exit,
4312                &x_derivative,
4313            ),
4314            penalties,
4315            mono,
4316            SurvivalSpec::Net,
4317        )
4318        .expect("construct structural survival model");
4319        model
4320            .set_structural_monotonicity(true, 3)
4321            .expect("enable structural monotonicity");
4322        let constraints = model
4323            .monotonicity_linear_constraints()
4324            .expect("structural derivative constraints");
4325        assert_eq!(constraints.a.nrows(), 2);
4326        assert_eq!(constraints.a.ncols(), 4);
4327        assert_eq!(constraints.a.row(0).to_vec(), vec![0.0, 1.0, 0.0, 0.0]);
4328        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
4329        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4330
4331        let beta = array![0.2, 0.2, 0.1, 0.2];
4332        let state = model.update_state(&beta).expect("state at structural beta");
4333        let eps = 1e-7;
4334        for j in 0..beta.len() {
4335            let mut plus = beta.clone();
4336            let mut minus = beta.clone();
4337            plus[j] += eps;
4338            minus[j] -= eps;
4339            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4340            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4341            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4342            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4343            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4344            assert_eq!(
4345                state.gradient[j].signum(),
4346                fd.signum(),
4347                "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4348                state.gradient[j]
4349            );
4350            assert!(
4351                (state.gradient[j] - fd).abs() < 2e-5,
4352                "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
4353                state.gradient[j]
4354            );
4355        }
4356    }
4357
4358    #[test]
4359    fn structural_monotonic_lamlgradient_returns_finitevalues() {
4360        let age_entry = array![1.0_f64, 1.2_f64];
4361        let age_exit = array![1.5_f64, 2.0_f64];
4362        let event_target = array![1u8, 0u8];
4363        let event_competing = array![0u8, 0u8];
4364        let sampleweight = array![1.0, 1.0];
4365        let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
4366        let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
4367        let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];
4368        let penalties = PenaltyBlocks::new(Vec::new());
4369        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4370        let mut model = survival_model(
4371            survival_inputs(
4372                &age_entry,
4373                &age_exit,
4374                &event_target,
4375                &event_competing,
4376                &sampleweight,
4377                &x_entry,
4378                &x_exit,
4379                &x_derivative,
4380            ),
4381            penalties,
4382            mono,
4383            SurvivalSpec::Net,
4384        )
4385        .expect("construct structural survival model");
4386        model
4387            .set_structural_monotonicity(true, 2)
4388            .expect("enable structural monotonicity");
4389        // One simple penalty block to exercise rho-gradient path.
4390        model.penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4391            matrix: array![[1.0]],
4392            lambda: 0.7,
4393            range: 1..2,
4394            nullspace_dim: 0,
4395        }]);
4396        let beta = array![0.2, 0.2, 0.1];
4397        let state = model.update_state(&beta).expect("state at structural beta");
4398        let rho = Array1::from_iter(
4399            model
4400                .penalties
4401                .blocks
4402                .iter()
4403                .filter(|b| b.lambda > 0.0)
4404                .map(|b| b.lambda.ln()),
4405        );
4406        let (obj, grad) = model
4407            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4408            .expect("laml gradient should work in structural mode");
4409        assert!(obj.is_finite());
4410        assert_eq!(grad.len(), 1);
4411        assert!(grad[0].is_finite());
4412    }
4413
4414    #[test]
4415    fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
4416        let age_entry = array![1.0_f64];
4417        let age_exit = array![2.0_f64];
4418        let event_target = array![1u8];
4419        let event_competing = array![0u8];
4420        let sampleweight = array![1.0];
4421        let x_entry = array![[0.0]];
4422        let x_exit = array![[0.2]];
4423        let x_derivative = array![[1.0]];
4424
4425        let penalties = PenaltyBlocks::new(Vec::new());
4426        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4427        let mut model = survival_model(
4428            survival_inputs(
4429                &age_entry,
4430                &age_exit,
4431                &event_target,
4432                &event_competing,
4433                &sampleweight,
4434                &x_entry,
4435                &x_exit,
4436                &x_derivative,
4437            ),
4438            penalties,
4439            mono,
4440            SurvivalSpec::Net,
4441        )
4442        .expect("construct structural survival model");
4443
4444        let beta = array![-3.0];
4445        assert!(
4446            model.update_state(&beta).is_err(),
4447            "negative derivative coefficient should violate derivative guard"
4448        );
4449
4450        model
4451            .set_structural_monotonicity(true, 1)
4452            .expect("enable structural monotonicity");
4453        let constraints = model
4454            .monotonicity_linear_constraints()
4455            .expect("structural derivative constraints");
4456        assert_eq!(constraints.a.nrows(), 1);
4457        assert_eq!(constraints.a.ncols(), 1);
4458        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
4459        // Structural monotonicity uses derivative_guard() == 0.0
4460        assert!(constraints.b[0].abs() <= 1e-12);
4461        let state = model
4462            .update_state(&array![1e-6])
4463            .expect("small positive derivative coefficient should remain feasible");
4464        assert!(state.deviance.is_finite());
4465    }
4466
4467    #[test]
4468    fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
4469        let age_entry = array![1.0_f64];
4470        let age_exit = array![2.0_f64];
4471        let event_target = array![1u8];
4472        let event_competing = array![0u8];
4473        let sampleweight = array![1.0];
4474        let x_entry = array![[1.0, 0.0]];
4475        let x_exit = array![[1.0, 0.0]];
4476        let x_derivative = array![[0.0, 0.0]];
4477        let penalties = PenaltyBlocks::new(Vec::new());
4478        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
4479        let eta_entry_offset = array![0.0];
4480        let eta_exit_offset = array![0.0];
4481        let derivative_offset_below_guard = array![2.0];
4482        let derivative_offset_above_guard = array![3.1];
4483        let offsets_below_guard = SurvivalBaselineOffsets {
4484            eta_entry: eta_entry_offset.view(),
4485            eta_exit: eta_exit_offset.view(),
4486            derivative_exit: derivative_offset_below_guard.view(),
4487        };
4488        let offsets_above_guard = SurvivalBaselineOffsets {
4489            eta_entry: eta_entry_offset.view(),
4490            eta_exit: eta_exit_offset.view(),
4491            derivative_exit: derivative_offset_above_guard.view(),
4492        };
4493
4494        let model_below_guard = survival_model_with_offsets(
4495            survival_inputs(
4496                &age_entry,
4497                &age_exit,
4498                &event_target,
4499                &event_competing,
4500                &sampleweight,
4501                &x_entry,
4502                &x_exit,
4503                &x_derivative,
4504            ),
4505            Some(offsets_below_guard),
4506            penalties.clone(),
4507            monotonicity,
4508            SurvivalSpec::Net,
4509        )
4510        .expect("construct model with derivative offset below guard");
4511        let err = model_below_guard
4512            .update_state(&array![0.0, 0.0])
4513            .expect_err("derivative offset below guard should be rejected");
4514        let err_text = err.to_string();
4515        assert!(
4516            err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
4517            "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
4518        );
4519
4520        let model_above_guard = survival_model_with_offsets(
4521            survival_inputs(
4522                &age_entry,
4523                &age_exit,
4524                &event_target,
4525                &event_competing,
4526                &sampleweight,
4527                &x_entry,
4528                &x_exit,
4529                &x_derivative,
4530            ),
4531            Some(offsets_above_guard),
4532            penalties,
4533            SurvivalMonotonicityPenalty { tolerance: 3.0 },
4534            SurvivalSpec::Net,
4535        )
4536        .expect("construct model with derivative offset above guard");
4537        let state = model_above_guard
4538            .update_state(&array![0.0, 0.0])
4539            .expect("derivative offset above guard should remain feasible");
4540        assert!(state.deviance.is_finite());
4541    }
4542
4543    #[test]
4544    fn structural_monotonicity_rejects_negative_derivative_offsets() {
4545        let age_entry = array![1.0_f64];
4546        let age_exit = array![2.0_f64];
4547        let event_target = array![1u8];
4548        let event_competing = array![0u8];
4549        let sampleweight = array![1.0];
4550        let x_entry = array![[0.0]];
4551        let x_exit = array![[0.2]];
4552        let x_derivative = array![[1.0]];
4553        let eta_entry = array![0.0];
4554        let eta_exit = array![0.0];
4555        let derivative_exit = array![-1e-3];
4556        let offsets = SurvivalBaselineOffsets {
4557            eta_entry: eta_entry.view(),
4558            eta_exit: eta_exit.view(),
4559            derivative_exit: derivative_exit.view(),
4560        };
4561
4562        let mut model = survival_model_with_offsets(
4563            survival_inputs(
4564                &age_entry,
4565                &age_exit,
4566                &event_target,
4567                &event_competing,
4568                &sampleweight,
4569                &x_entry,
4570                &x_exit,
4571                &x_derivative,
4572            ),
4573            Some(offsets),
4574            PenaltyBlocks::new(Vec::new()),
4575            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4576            SurvivalSpec::Net,
4577        )
4578        .expect("construct structural survival model");
4579        let err = model
4580            .set_structural_monotonicity(true, 1)
4581            .expect_err("negative derivative offsets must be rejected");
4582        assert!(
4583            err.to_string()
4584                .contains("structural monotonicity requires nonnegative derivative offsets"),
4585            "unexpected error: {err}"
4586        );
4587    }
4588
4589    #[test]
4590    fn structural_monotonicity_emits_coefficient_constraints() {
4591        let age_entry = array![1.0_f64, 1.5_f64];
4592        let age_exit = array![2.0_f64, 3.0_f64];
4593        let event_target = array![1u8, 0u8];
4594        let event_competing = array![0u8, 0u8];
4595        let sampleweight = array![1.0, 1.0];
4596        let x_entry = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
4597        let x_exit = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
4598        let x_derivative = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];
4599
4600        let mut model = survival_model(
4601            survival_inputs(
4602                &age_entry,
4603                &age_exit,
4604                &event_target,
4605                &event_competing,
4606                &sampleweight,
4607                &x_entry,
4608                &x_exit,
4609                &x_derivative,
4610            ),
4611            PenaltyBlocks::new(Vec::new()),
4612            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4613            SurvivalSpec::Net,
4614        )
4615        .expect("construct structural survival model");
4616        model
4617            .set_structural_monotonicity(true, 2)
4618            .expect("enable structural monotonicity");
4619
4620        let constraints = model
4621            .monotonicity_linear_constraints()
4622            .expect("structural derivative constraints");
4623
4624        assert_eq!(constraints.a.nrows(), 2);
4625        assert_eq!(constraints.a.ncols(), 3);
4626        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
4627        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
4628        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4629    }
4630
4631    #[test]
4632    fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
4633        let age_entry = array![1.0_f64];
4634        let age_exit = array![2.0_f64];
4635        let event_target = array![1u8];
4636        let event_competing = array![0u8];
4637        let sampleweight = array![1.0];
4638        let x_entry = array![[1.0, 0.2]];
4639        let x_exit = array![[1.0, 0.6]];
4640        let x_derivative = array![[0.0, 1.0]];
4641
4642        let mut model = survival_model(
4643            survival_inputs(
4644                &age_entry,
4645                &age_exit,
4646                &event_target,
4647                &event_competing,
4648                &sampleweight,
4649                &x_entry,
4650                &x_exit,
4651                &x_derivative,
4652            ),
4653            PenaltyBlocks::new(Vec::new()),
4654            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4655            SurvivalSpec::Net,
4656        )
4657        .expect("construct structural survival model");
4658        model
4659            .set_structural_monotonicity(true, 2)
4660            .expect("enable structural monotonicity");
4661
4662        let constraints = model
4663            .monotonicity_linear_constraints()
4664            .expect("structural derivative constraints");
4665
4666        assert_eq!(constraints.a.nrows(), 1);
4667        assert!(
4668            constraints.a[[0, 0]].abs() <= 1e-12,
4669            "inactive time column should remain unconstrained"
4670        );
4671        assert!(
4672            (constraints.a[[0, 1]] - 1.0).abs() <= 1e-12,
4673            "active time column should remain constrained"
4674        );
4675    }
4676
4677    #[test]
4678    fn structural_monotonicity_preserves_sparse_row_patterns() {
4679        let age_entry = array![1.0_f64, 1.5_f64];
4680        let age_exit = array![2.0_f64, 2.5_f64];
4681        let event_target = array![1u8, 1u8];
4682        let event_competing = array![0u8, 0u8];
4683        let sampleweight = array![1.0, 1.0];
4684        let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
4685        let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
4686        let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];
4687
4688        let mut model = survival_model(
4689            survival_inputs(
4690                &age_entry,
4691                &age_exit,
4692                &event_target,
4693                &event_competing,
4694                &sampleweight,
4695                &x_entry,
4696                &x_exit,
4697                &x_derivative,
4698            ),
4699            PenaltyBlocks::new(Vec::new()),
4700            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4701            SurvivalSpec::Net,
4702        )
4703        .expect("construct structural survival model");
4704        model
4705            .set_structural_monotonicity(true, 2)
4706            .expect("enable structural monotonicity");
4707
4708        let constraints = model
4709            .monotonicity_linear_constraints()
4710            .expect("structural derivative constraints");
4711
4712        assert_eq!(constraints.a.nrows(), 2);
4713        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
4714        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
4715    }
4716
4717    #[test]
4718    fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
4719        let age_entry = array![1.0_f64];
4720        let age_exit = array![1.1_f64];
4721        let event_target = array![0u8];
4722        let event_competing = array![0u8];
4723        let sampleweight = array![1.0];
4724        let x_entry = array![[0.0]];
4725        let x_exit = array![[0.0]];
4726        let x_derivative = array![[-1.0]];
4727        let penalties = PenaltyBlocks::new(Vec::new());
4728        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4729        let model = survival_model(
4730            survival_inputs(
4731                &age_entry,
4732                &age_exit,
4733                &event_target,
4734                &event_competing,
4735                &sampleweight,
4736                &x_entry,
4737                &x_exit,
4738                &x_derivative,
4739            ),
4740            penalties,
4741            mono,
4742            SurvivalSpec::Net,
4743        )
4744        .expect("construct censored survival model");
4745
4746        let err = model
4747            .update_state(&array![1.0])
4748            .expect_err("censored row should still enforce monotonic derivative");
4749        assert!(
4750            matches!(err, EstimationError::ParameterConstraintViolation(_)),
4751            "unexpected error: {err:?}"
4752        );
4753    }
4754
4755    fn crude_risk_quadrature_error(
4756        cumulative_entry: f64,
4757        cumulative_exit: f64,
4758        hazard_exit: f64,
4759    ) -> SurvivalError {
4760        calculate_crude_risk_quadrature(
4761            1.0,
4762            2.0,
4763            &[],
4764            0.4,
4765            0.2,
4766            array![1.0].view(),
4767            array![1.0].view(),
4768            |_, design_d, deriv_d, design_m| {
4769                design_d[0] = 1.0;
4770                deriv_d[0] = 0.0;
4771                design_m[0] = 1.0;
4772                Ok((cumulative_entry, cumulative_exit, hazard_exit))
4773            },
4774        )
4775        .expect_err("invalid hazards should fail")
4776    }
4777
4778    #[test]
4779    fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
4780        let err = crude_risk_quadrature_error(0.1, 0.3, 0.25);
4781        assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
4782    }
4783
4784    #[test]
4785    fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
4786        let err = crude_risk_quadrature_error(0.0, 0.4, 0.25);
4787        assert!(matches!(err, SurvivalError::NonPositiveHazard));
4788    }
4789
4790    #[test]
4791    fn laml_no_penalties_matches_documentedobjective() {
4792        let age_entry = array![40.0, 45.0, 50.0, 55.0];
4793        let age_exit = array![44.0, 49.0, 54.0, 59.0];
4794        let event_target = array![1u8, 0u8, 1u8, 0u8];
4795        let event_competing = Array1::<u8>::zeros(4);
4796        let sampleweight = Array1::ones(4);
4797        let x_entry = array![
4798            [1.0, -0.2, 0.04],
4799            [1.0, -0.1, 0.01],
4800            [1.0, 0.0, 0.0],
4801            [1.0, 0.1, 0.01]
4802        ];
4803        let x_exit = array![
4804            [1.0, -0.12, 0.0144],
4805            [1.0, -0.02, 0.0004],
4806            [1.0, 0.08, 0.0064],
4807            [1.0, 0.18, 0.0324]
4808        ];
4809        let x_derivative = array![
4810            [0.0, 0.02, 0.001],
4811            [0.0, 0.02, 0.001],
4812            [0.0, 0.02, 0.001],
4813            [0.0, 0.02, 0.001]
4814        ];
4815        let penalties = PenaltyBlocks::new(Vec::new());
4816        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4817        let beta = array![-2.0, 0.7, 0.2];
4818
4819        let model = survival_model(
4820            survival_inputs(
4821                &age_entry,
4822                &age_exit,
4823                &event_target,
4824                &event_competing,
4825                &sampleweight,
4826                &x_entry,
4827                &x_exit,
4828                &x_derivative,
4829            ),
4830            penalties,
4831            mono,
4832            SurvivalSpec::Net,
4833        )
4834        .expect("construct survival model");
4835
4836        let state = model.update_state(&beta).expect("state at beta");
4837        let rho = Array1::from_iter(
4838            model
4839                .penalties
4840                .blocks
4841                .iter()
4842                .filter(|b| b.lambda > 0.0)
4843                .map(|b| b.lambda.ln()),
4844        );
4845        let (obj, grad) = model
4846            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4847            .expect("laml objective for no-penalty model");
4848
4849        let h_dense = state.hessian.to_dense();
4850        let logdet_h: f64 = {
4851            use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4852            use gam_linalg::faer_ndarray::FaerEigh;
4853            let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4854            let eps = spectral_epsilon(evals.as_slice().unwrap());
4855            evals
4856                .iter()
4857                .map(|&sigma| spectral_regularize(sigma, eps).ln())
4858                .sum()
4859        };
4860        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * logdet_h;
4861
4862        assert_eq!(grad.len(), 0);
4863        assert!(
4864            (obj - expected).abs() < 1e-10,
4865            "no-penalty LAML objective mismatch: obj={} expected={}",
4866            obj,
4867            expected
4868        );
4869    }
4870
4871    #[test]
4872    fn monotonicity_constraints_collapse_positive_collinearrows() {
4873        let a = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
4874        let b = array![1e-8, 1e-8, 1e-8];
4875
4876        let compressed = compress_positive_collinear_constraints(&a, &b);
4877
4878        assert_eq!(compressed.a.nrows(), 1);
4879        assert_eq!(compressed.a.ncols(), 3);
4880        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4881        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4882        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4883        assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
4884    }
4885
4886    #[test]
4887    fn monotonicity_constraints_preserve_distinct_directions() {
4888        let a = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
4889        let b = array![0.2, 0.3, 0.1];
4890
4891        let compressed = compress_positive_collinear_constraints(&a, &b);
4892
4893        assert_eq!(compressed.a.nrows(), 2);
4894        let mut saw_x = false;
4895        let mut saw_y = false;
4896        for i in 0..compressed.a.nrows() {
4897            if (compressed.a[[i, 0]] - 1.0).abs() <= 1e-12 && compressed.a[[i, 1]].abs() <= 1e-12 {
4898                saw_x = true;
4899                assert!((compressed.b[i] - 0.2).abs() <= 1e-12);
4900            }
4901            if compressed.a[[i, 0]].abs() <= 1e-12 && (compressed.a[[i, 1]] - 1.0).abs() <= 1e-12 {
4902                saw_y = true;
4903                assert!((compressed.b[i] - 0.3).abs() <= 1e-12);
4904            }
4905        }
4906        assert!(saw_x);
4907        assert!(saw_y);
4908    }
4909
4910    #[test]
4911    fn monotonicity_constraints_cluster_near_collinearrows() {
4912        let a = array![
4913            [0.0, 0.5, 0.0],
4914            [0.0, 0.50000000003, 0.0],
4915            [0.0, 0.49999999997, 0.0]
4916        ];
4917        let b = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];
4918
4919        let compressed = compress_positive_collinear_constraints(&a, &b);
4920
4921        assert_eq!(compressed.a.nrows(), 1);
4922        assert_eq!(compressed.a.ncols(), 3);
4923        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4924        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4925        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4926        assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
4927    }
4928
4929    #[test]
4930    fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
4931        let a = array![
4932            [0.0, 0.401, 0.302, 0.197],
4933            [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
4934            [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
4935            [0.0, 0.125, 0.500, 0.375]
4936        ];
4937        let b = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];
4938
4939        let compressed = compress_positive_collinear_constraints(&a, &b);
4940
4941        assert_eq!(compressed.a.nrows(), 2);
4942        let mut clustered_face = false;
4943        let mut distinct_face = false;
4944        for i in 0..compressed.a.nrows() {
4945            let row = compressed.a.row(i);
4946            if row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49 {
4947                clustered_face = true;
4948                assert!((compressed.b[i] - (2.0e-8 / 0.401)).abs() <= 1e-12);
4949            } else {
4950                distinct_face = true;
4951                assert!((row[1] - 0.25).abs() <= 1e-12);
4952                assert!((row[2] - 1.0).abs() <= 1e-12);
4953                assert!((row[3] - 0.75).abs() <= 1e-12);
4954                assert!((compressed.b[i] - 6.0e-8).abs() <= 1e-18);
4955            }
4956        }
4957        assert!(clustered_face);
4958        assert!(distinct_face);
4959    }
4960
4961    #[test]
4962    fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
4963        let age_entry = array![1.0_f64, 1.0, 1.0];
4964        let age_exit = array![2.0_f64, 4.0, 8.0];
4965        let event_target = array![0u8, 1u8, 0u8];
4966        let event_competing = array![0u8, 0u8, 0u8];
4967        let sampleweight = array![1.0, 1.0, 1.0];
4968        let x_entry = array![
4969            [1.0, age_entry[0].ln()],
4970            [1.0, age_entry[1].ln()],
4971            [1.0, age_entry[2].ln()]
4972        ];
4973        let x_exit = array![
4974            [1.0, age_exit[0].ln()],
4975            [1.0, age_exit[1].ln()],
4976            [1.0, age_exit[2].ln()]
4977        ];
4978        let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
4979        let penalties = PenaltyBlocks::new(Vec::new());
4980        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4981
4982        let collocation_offsets = Array1::zeros(x_derivative.nrows());
4983        let mut inputs = survival_inputs(
4984            &age_entry,
4985            &age_exit,
4986            &event_target,
4987            &event_competing,
4988            &sampleweight,
4989            &x_entry,
4990            &x_exit,
4991            &x_derivative,
4992        );
4993        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
4994        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
4995
4996        let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
4997            .expect("construct linear survival model");
4998
4999        let constraints = model
5000            .monotonicity_linear_constraints()
5001            .expect("monotonicity constraints");
5002        assert_eq!(constraints.a.nrows(), 1);
5003        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5004        assert!((constraints.b[0] - 8e-8).abs() <= 1e-12);
5005    }
5006
5007    #[test]
5008    fn monotonicity_constraints_skip_numericallyzerorows() {
5009        let age_entry = array![1.0_f64, 1.0, 1.0];
5010        let age_exit = array![2.0_f64, 3.0, 4.0];
5011        let event_target = array![0u8, 0u8, 0u8];
5012        let event_competing = array![0u8, 0u8, 0u8];
5013        let sampleweight = array![1.0, 1.0, 1.0];
5014        let x_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
5015        let x_exit = x_entry.clone();
5016        let x_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
5017
5018        let collocation_offsets = Array1::zeros(x_derivative.nrows());
5019        let mut inputs = survival_inputs(
5020            &age_entry,
5021            &age_exit,
5022            &event_target,
5023            &event_competing,
5024            &sampleweight,
5025            &x_entry,
5026            &x_exit,
5027            &x_derivative,
5028        );
5029        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5030        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5031
5032        let model = survival_model(
5033            inputs,
5034            PenaltyBlocks::new(Vec::new()),
5035            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5036            SurvivalSpec::Net,
5037        )
5038        .expect("construct survival model");
5039
5040        let constraints = model
5041            .monotonicity_linear_constraints()
5042            .expect("nonzero derivative row should remain");
5043        assert_eq!(constraints.a.nrows(), 1);
5044        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5045        assert!(constraints.b[0].abs() <= 1e-18);
5046    }
5047
5048    #[test]
5049    fn censoredrows_allowzero_boundary_derivative() {
5050        let age_entry = array![1.0_f64];
5051        let age_exit = array![2.0_f64];
5052        let event_target = array![0u8];
5053        let event_competing = array![0u8];
5054        let sampleweight = array![1.0];
5055        let x_entry = array![[0.0]];
5056        let x_exit = array![[0.0]];
5057        let x_derivative = array![[1.0]];
5058
5059        let model = survival_model(
5060            survival_inputs(
5061                &age_entry,
5062                &age_exit,
5063                &event_target,
5064                &event_competing,
5065                &sampleweight,
5066                &x_entry,
5067                &x_exit,
5068                &x_derivative,
5069            ),
5070            PenaltyBlocks::new(Vec::new()),
5071            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5072            SurvivalSpec::Net,
5073        )
5074        .expect("construct censored survival model");
5075
5076        let state = model
5077            .update_state(&array![0.0])
5078            .expect("censored boundary derivative should remain feasible with zero tolerance");
5079        assert!(state.deviance.is_finite());
5080    }
5081
5082    #[test]
5083    fn eventrows_keep_positive_derivative_constraint() {
5084        let age_entry = array![1.0_f64, 1.0];
5085        let age_exit = array![2.0_f64, 4.0];
5086        let event_target = array![0u8, 1u8];
5087        let event_competing = array![0u8, 0u8];
5088        let sampleweight = array![1.0, 1.0];
5089        let x_entry = array![[0.0], [0.0]];
5090        let x_exit = array![[0.0], [0.0]];
5091        let x_derivative = array![[0.5], [0.25]];
5092
5093        let collocation_offsets = Array1::zeros(x_derivative.nrows());
5094        let mut inputs = survival_inputs(
5095            &age_entry,
5096            &age_exit,
5097            &event_target,
5098            &event_competing,
5099            &sampleweight,
5100            &x_entry,
5101            &x_exit,
5102            &x_derivative,
5103        );
5104        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5105        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5106
5107        let model = survival_model(
5108            inputs,
5109            PenaltyBlocks::new(Vec::new()),
5110            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5111            SurvivalSpec::Net,
5112        )
5113        .expect("construct mixed survival model");
5114
5115        let constraints = model
5116            .monotonicity_linear_constraints()
5117            .expect("event row should induce positive lower bound");
5118        assert_eq!(constraints.a.nrows(), 1);
5119        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
5120        assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
5121    }
5122
5123    #[test]
5124    fn structural_monotonicity_clamps_tiny_negative_roundoff() {
5125        let age_entry = array![1.0_f64];
5126        let age_exit = array![2.0_f64];
5127        let event_target = array![1u8];
5128        let event_competing = array![0u8];
5129        let sampleweight = array![1.0];
5130        let x_entry = array![[0.0]];
5131        let x_exit = array![[0.0]];
5132        let x_derivative = array![[1.0]];
5133        let mut model = survival_model(
5134            survival_inputs(
5135                &age_entry,
5136                &age_exit,
5137                &event_target,
5138                &event_competing,
5139                &sampleweight,
5140                &x_entry,
5141                &x_exit,
5142                &x_derivative,
5143            ),
5144            PenaltyBlocks::new(Vec::new()),
5145            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5146            SurvivalSpec::Net,
5147        )
5148        .expect("construct survival model");
5149        model
5150            .set_structural_monotonicity(true, 1)
5151            .expect("enable structural monotonicity");
5152
5153        let state = model
5154            .update_state(&array![-1e-8])
5155            .expect("tiny structural roundoff should be clamped");
5156        assert!(state.deviance.is_finite());
5157    }
5158
5159    #[test]
5160    fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
5161        let uncompressed_constraints = LinearInequalityConstraints {
5162            a: array![
5163                [0.0, 0.5, 0.0],
5164                [0.0, 1.0 / 3.0, 0.0],
5165                [0.0, 0.2, 0.0],
5166                [0.0, 0.125, 0.0]
5167            ],
5168            b: Array1::from_elem(4, 1e-8),
5169        };
5170        let compressed_constraints = compress_positive_collinear_constraints(
5171            &uncompressed_constraints.a,
5172            &uncompressed_constraints.b,
5173        );
5174
5175        let candidates = [
5176            array![0.0, 1e-9, 0.0],
5177            array![0.0, 4e-8, 0.0],
5178            array![0.0, 8e-8, 0.0],
5179            array![0.0, 2e-7, 1.5],
5180        ];
5181        for beta in candidates {
5182            let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
5183                uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
5184            });
5185            let compressed_ok = (0..compressed_constraints.a.nrows())
5186                .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
5187            assert_eq!(compressed_ok, uncompressed_ok);
5188        }
5189    }
5190
5191    #[test]
5192    fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
5193        let age_entry = array![10.0_f64, 20.0, 25.0];
5194        let age_exit = array![15.0_f64, 30.0, 40.0];
5195        let event_target = array![1u8, 0u8, 1u8];
5196        let event_competing = array![0u8, 0u8, 0u8];
5197        let sampleweight = array![1.0, 2.0, 0.5];
5198        let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
5199        let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
5200        let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
5201        let beta = array![0.8, 1.1, -0.2];
5202
5203        let base_model = survival_model(
5204            survival_inputs(
5205                &age_entry,
5206                &age_exit,
5207                &event_target,
5208                &event_competing,
5209                &sampleweight,
5210                &x_entry,
5211                &x_exit,
5212                &x_derivative,
5213            ),
5214            PenaltyBlocks::new(Vec::new()),
5215            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5216            SurvivalSpec::Net,
5217        )
5218        .expect("construct base survival model");
5219        let base_state = base_model
5220            .update_state(&beta)
5221            .expect("evaluate base survival state");
5222
5223        let time_scale = 365.25;
5224        let scaled_age_entry = age_entry.mapv(|v| v * time_scale);
5225        let scaled_age_exit = age_exit.mapv(|v| v * time_scale);
5226        let scaled_x_derivative = x_derivative.mapv(|v| v / time_scale);
5227        let scaled_model = survival_model(
5228            survival_inputs(
5229                &scaled_age_entry,
5230                &scaled_age_exit,
5231                &event_target,
5232                &event_competing,
5233                &sampleweight,
5234                &x_entry,
5235                &x_exit,
5236                &scaled_x_derivative,
5237            ),
5238            PenaltyBlocks::new(Vec::new()),
5239            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5240            SurvivalSpec::Net,
5241        )
5242        .expect("construct scaled survival model");
5243        let scaled_state = scaled_model
5244            .update_state(&beta)
5245            .expect("evaluate scaled survival state");
5246
5247        let weighted_events = sampleweight
5248            .iter()
5249            .zip(event_target.iter())
5250            .map(|(w, d)| *w * f64::from(*d))
5251            .sum::<f64>();
5252        let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
5253        assert!(
5254            (scaled_state.deviance - base_state.deviance - expected_deviance_shift).abs() <= 1e-10,
5255            "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
5256            scaled_state.deviance,
5257            base_state.deviance
5258        );
5259
5260        for j in 0..beta.len() {
5261            assert!(
5262                (scaled_state.gradient[j] - base_state.gradient[j]).abs() <= 1e-12,
5263                "gradient mismatch at j={j}: scaled={} base={}",
5264                scaled_state.gradient[j],
5265                base_state.gradient[j]
5266            );
5267        }
5268
5269        let base_hessian = base_state.hessian.to_dense();
5270        let scaled_hessian = scaled_state.hessian.to_dense();
5271        for r in 0..beta.len() {
5272            for c in 0..beta.len() {
5273                assert!(
5274                    (scaled_hessian[[r, c]] - base_hessian[[r, c]]).abs() <= 1e-12,
5275                    "hessian mismatch at ({r},{c}): scaled={} base={}",
5276                    scaled_hessian[[r, c]],
5277                    base_hessian[[r, c]]
5278                );
5279            }
5280        }
5281    }
5282}