Skip to main content

gam_models/survival/
base.rs

1use gam_linalg::faer_ndarray::{fast_atv, fast_av, fast_xt_diag_x, fast_xt_diag_y};
2use crate::custom_family::{
3    BlockWorkingSet, CustomFamily, FamilyEvaluation, ParameterBlockState,
4    projected_linear_constraint_stationarity_vector,
5};
6use gam_linalg::matrix::SymmetricMatrix;
7use crate::model_types::EstimationError;
8use gam_solve::pirls::{
9    LinearInequalityConstraints, WorkingModel as PirlsWorkingModel, WorkingState, array1_l2_norm,
10};
11use gam_problem::{Coefficients, LinearPredictor};
12use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
13use serde::{Deserialize, Serialize};
14use std::collections::BTreeMap;
15use std::ops::Range;
16use std::sync::LazyLock;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
20pub enum SurvivalError {
21    #[error("input dimensions are inconsistent")]
22    DimensionMismatch,
23    #[error("inputs contain non-finite values")]
24    NonFiniteInput,
25    #[error("survival spec '{0}' is not supported by the one-hazard survival engine")]
26    UnsupportedSpec(&'static str),
27    #[error("crude risk integration setup is invalid")]
28    InvalidIntegrationSetup,
29    #[error("survival time grid must be finite, non-negative, and strictly increasing")]
30    InvalidTimeGrid,
31    #[error("cumulative hazard must be nondecreasing")]
32    NonMonotoneCumulativeHazard,
33    #[error("instantaneous hazard must stay strictly positive during integration")]
34    NonPositiveHazard,
35    #[error("{reason}")]
36    InvalidInput { reason: String },
37    #[error("{reason}")]
38    CauseSpecificDimensionMismatch { reason: String },
39    #[error("{reason}")]
40    NumericalFailure { reason: String },
41    #[error("{reason}")]
42    EventCodeInvalid { reason: String },
43    #[error("{reason}")]
44    EventDegenerate { reason: String },
45    #[error("cause-specific survival block {block}: {source}")]
46    CauseSpecificBlock {
47        block: usize,
48        #[source]
49        source: Box<SurvivalError>,
50    },
51}
52
53impl From<SurvivalError> for String {
54    fn from(err: SurvivalError) -> Self {
55        err.to_string()
56    }
57}
58
59impl From<crate::block_layout::block_count::BlockCountMismatch> for SurvivalError {
60    fn from(err: crate::block_layout::block_count::BlockCountMismatch) -> SurvivalError {
61        SurvivalError::CauseSpecificDimensionMismatch {
62            reason: err.message(),
63        }
64    }
65}
66
67#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
68pub enum SurvivalSpec {
69    #[default]
70    Net,
71    Crude,
72}
73
74#[derive(Debug, Clone)]
75pub struct SurvivalEngineInputs<'a> {
76    pub age_entry: ArrayView1<'a, f64>,
77    pub age_exit: ArrayView1<'a, f64>,
78    pub event_target: ArrayView1<'a, u8>,
79    pub event_competing: ArrayView1<'a, u8>,
80    pub sampleweight: ArrayView1<'a, f64>,
81    pub x_entry: ArrayView2<'a, f64>,
82    pub x_exit: ArrayView2<'a, f64>,
83    pub x_derivative: ArrayView2<'a, f64>,
84    /// Optional global monotonicity collocation rows for the full coefficient vector.
85    /// Non-structural survival models should pass these explicitly instead of
86    /// relying on observed derivative rows.
87    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
88    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
89    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
90}
91
92#[derive(Debug, Clone)]
93pub struct SurvivalTimeCovarInputs<'a> {
94    pub age_entry: ArrayView1<'a, f64>,
95    pub age_exit: ArrayView1<'a, f64>,
96    pub event_target: ArrayView1<'a, u8>,
97    pub event_competing: ArrayView1<'a, u8>,
98    pub sampleweight: ArrayView1<'a, f64>,
99    pub time_entry: ArrayView2<'a, f64>,
100    pub time_exit: ArrayView2<'a, f64>,
101    pub time_derivative: ArrayView2<'a, f64>,
102    pub covariates: ArrayView2<'a, f64>,
103    /// Optional global monotonicity collocation rows for the full coefficient vector.
104    /// Non-structural survival models should pass these explicitly instead of
105    /// relying on observed derivative rows.
106    pub monotonicity_constraint_rows: Option<ArrayView2<'a, f64>>,
107    /// Baseline offsets corresponding to `monotonicity_constraint_rows`.
108    pub monotonicity_constraint_offsets: Option<ArrayView1<'a, f64>>,
109}
110
111#[derive(Debug, Clone)]
112pub struct SurvivalBaselineOffsets<'a> {
113    /// Baseline target contribution to eta at entry time: eta_target(t_entry).
114    pub eta_entry: ArrayView1<'a, f64>,
115    /// Baseline target contribution to eta at exit time: eta_target(t_exit).
116    pub eta_exit: ArrayView1<'a, f64>,
117    /// Baseline target contribution to d eta / d t at exit: eta_target'(t_exit).
118    ///
119    /// This is used in event terms where log-hazard requires
120    /// log(d eta / d t). By threading this as an explicit offset, we get
121    /// "parametric default + spline deviation" behavior:
122    /// - strong penalty => deviation ~ 0 => model collapses to baseline target,
123    /// - weak penalty   => deviation can bend away where data supports it.
124    pub derivative_exit: ArrayView1<'a, f64>,
125}
126
127#[derive(Debug, Clone)]
128pub struct PenaltyBlock {
129    pub matrix: Array2<f64>,
130    pub lambda: f64,
131    pub range: Range<usize>,
132    /// Structural nullspace dimension of this penalty matrix.
133    /// Used for exact pseudo-logdet computation. 0 means full rank.
134    pub nullspace_dim: usize,
135}
136
137#[derive(Debug, Clone)]
138pub struct PenaltyBlocks {
139    pub blocks: Vec<PenaltyBlock>,
140}
141
142impl PenaltyBlocks {
143    pub fn new(blocks: Vec<PenaltyBlock>) -> Self {
144        Self { blocks }
145    }
146
147    pub fn gradient(&self, beta: &Array1<f64>) -> Array1<f64> {
148        let mut grad = Array1::zeros(beta.len());
149        for block in &self.blocks {
150            if block.lambda == 0.0 {
151                continue;
152            }
153            let b = beta.slice(ndarray::s![block.range.clone()]);
154            let g = block.matrix.dot(&b);
155            let mut dst = grad.slice_mut(ndarray::s![block.range.clone()]);
156            dst += &(block.lambda * g);
157        }
158        grad
159    }
160
161    pub fn hessian(&self, dim: usize) -> Array2<f64> {
162        let mut h = Array2::zeros((dim, dim));
163        self.addhessian_inplace(&mut h);
164        h
165    }
166
167    pub fn deviance(&self, beta: &Array1<f64>) -> f64 {
168        let mut value = 0.0;
169        for block in &self.blocks {
170            if block.lambda == 0.0 {
171                continue;
172            }
173            let b = beta.slice(ndarray::s![block.range.clone()]);
174            value += 0.5 * block.lambda * b.dot(&block.matrix.dot(&b));
175        }
176        value
177    }
178
179    pub fn addhessian_inplace(&self, h: &mut Array2<f64>) {
180        for block in &self.blocks {
181            if block.lambda == 0.0 {
182                continue;
183            }
184            let start = block.range.start;
185            let end = block.range.end;
186            h.slice_mut(ndarray::s![start..end, start..end])
187                .scaled_add(block.lambda, &block.matrix);
188        }
189    }
190}
191
/// Entry ages at or below this value count as entry at the time origin, so
/// the row has no delayed-entry interval and its `exp(η_entry)`
/// cumulative-hazard term is dropped (`H(0) = 0`).
///
/// The Royston-Parmar baseline is `η(t) = log H(t)` with `H(t) → 0` as
/// `t → 0`, so `log H` diverges at the origin. This small positive floor lets
/// a row that truly enters at time zero skip the entry contribution rather
/// than evaluate `log H` at a degenerate point. The constant is shared so that
/// every entry-detection site uses the same rule.
const ENTRY_AT_ORIGIN_THRESHOLD: f64 = 1e-8;

/// Fraction-to-the-boundary factor used by the cause-specific feasible-step
/// search.
///
/// If a Newton direction would push a row's derivative down to the
/// monotonicity floor, the step is limited to this fraction of the distance to
/// the boundary instead of landing exactly on it. Staying strictly inside the
/// feasible region (the standard interior-point fraction-to-boundary rule)
/// keeps the next `1/deriv` / `deriv.ln()` evaluation away from the singular
/// boundary, where curvature blows up.
const DERIVATIVE_FRACTION_TO_BOUNDARY: f64 = 0.995;
209
210#[derive(Debug, Clone)]
211pub struct CauseSpecificRoystonParmarBlock {
212    pub age_entry: Array1<f64>,
213    pub age_exit: Array1<f64>,
214    pub event_target: Array1<u8>,
215    pub sampleweight: Array1<f64>,
216    pub x_entry: Array2<f64>,
217    pub x_exit: Array2<f64>,
218    pub x_derivative: Array2<f64>,
219    pub offset_eta_entry: Array1<f64>,
220    pub offset_eta_exit: Array1<f64>,
221    pub offset_derivative_exit: Array1<f64>,
222    pub derivative_floor: f64,
223}
224
225/// Cause-specific competing-risks survival as a blockwise custom family.
226///
227/// Each cause is represented by one `ParameterBlockState`, so endpoint-specific
228/// coefficients, shared smoothing labels, and user-defined coefficient groups
229/// stay on the existing `CustomFamily` / `BlockwiseFitOptions` joint-fit path.
230#[derive(Debug, Clone)]
231pub struct CauseSpecificRoystonParmarFamily {
232    blocks: Vec<CauseSpecificRoystonParmarBlock>,
233}
234
235impl CauseSpecificRoystonParmarFamily {
236    pub fn new(blocks: Vec<CauseSpecificRoystonParmarBlock>) -> Result<Self, String> {
237        if blocks.is_empty() {
238            return Err(SurvivalError::InvalidInput {
239                reason: "cause-specific survival family requires at least one endpoint".to_string(),
240            }
241            .into());
242        }
243        for (idx, block) in blocks.iter().enumerate() {
244            validate_cause_specific_block(block).map_err(|err| {
245                SurvivalError::CauseSpecificBlock {
246                    block: idx + 1,
247                    source: Box::new(err),
248                }
249                .to_string()
250            })?;
251        }
252        Ok(Self { blocks })
253    }
254
255    pub fn cause_count(&self) -> usize {
256        self.blocks.len()
257    }
258}
259
260fn validate_cause_specific_block(
261    block: &CauseSpecificRoystonParmarBlock,
262) -> Result<(), SurvivalError> {
263    let n = block.event_target.len();
264    let p = block.x_exit.ncols();
265    if n == 0 || p == 0 {
266        bail_invalid_surv!("empty event vector or coefficient block");
267    }
268    if block.age_entry.len() != n
269        || block.age_exit.len() != n
270        || block.sampleweight.len() != n
271        || block.x_entry.nrows() != n
272        || block.x_exit.nrows() != n
273        || block.x_derivative.nrows() != n
274        || block.x_entry.ncols() != p
275        || block.x_derivative.ncols() != p
276        || block.offset_eta_entry.len() != n
277        || block.offset_eta_exit.len() != n
278        || block.offset_derivative_exit.len() != n
279    {
280        return Err(SurvivalError::CauseSpecificDimensionMismatch {
281            reason: "dimension mismatch".to_string(),
282        });
283    }
284    // A cause-specific block's `event_target` is the binary cause-k indicator
285    // produced by `cause_specific_event_indicator`; a label > 1 here means the
286    // caller passed raw multi-cause codes instead of projecting per cause. That
287    // is a valid finite label, not non-finite input, so it gets its own clear
288    // error rather than the misleading "non-finite input".
289    if let Some(&label) = block.event_target.iter().find(|&&v| v > 1) {
290        return Err(SurvivalError::EventCodeInvalid {
291            reason: format!(
292                "cause-specific block event_target must be the binary cause indicator {{0, 1}}, got multi-cause label {label}; project raw codes per cause via cause_specific_event_indicator"
293            ),
294        });
295    }
296    if block.age_entry.iter().any(|v| !v.is_finite())
297        || block.age_exit.iter().any(|v| !v.is_finite())
298        || block
299            .sampleweight
300            .iter()
301            .any(|v| !v.is_finite() || *v < 0.0)
302        || block.x_entry.iter().any(|v| !v.is_finite())
303        || block.x_exit.iter().any(|v| !v.is_finite())
304        || block.x_derivative.iter().any(|v| !v.is_finite())
305        || block.offset_eta_entry.iter().any(|v| !v.is_finite())
306        || block.offset_eta_exit.iter().any(|v| !v.is_finite())
307        || block.offset_derivative_exit.iter().any(|v| !v.is_finite())
308        || !block.derivative_floor.is_finite()
309        || block.derivative_floor < 0.0
310    {
311        bail_invalid_surv!("non-finite input");
312    }
313    Ok(())
314}
315
316fn evaluate_cause_specific_block(
317    block: &CauseSpecificRoystonParmarBlock,
318    beta: &Array1<f64>,
319) -> Result<(f64, Array1<f64>, Array2<f64>), SurvivalError> {
320    let n = block.event_target.len();
321    let p = block.x_exit.ncols();
322    if beta.len() != p {
323        return Err(SurvivalError::CauseSpecificDimensionMismatch {
324            reason: format!("beta length mismatch: got {}, expected {p}", beta.len()),
325        });
326    }
327    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
328    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
329    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
330    let mut log_likelihood = 0.0;
331    let mut w_exit = Array1::<f64>::zeros(n);
332    let mut w_entry = Array1::<f64>::zeros(n);
333    let mut w_event = Array1::<f64>::zeros(n);
334    let mut w_event_inv_deriv = Array1::<f64>::zeros(n);
335    let mut w_event_outer = Array1::<f64>::zeros(n);
336
337    for i in 0..n {
338        let weight = block.sampleweight[i];
339        if weight <= 0.0 {
340            continue;
341        }
342        if block.age_exit[i] < block.age_entry[i] {
343            bail_invalid_surv!("age_exit < age_entry at row {i}");
344        }
345        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
346        let h_exit = eta_exit[i].exp();
347        let h_entry = if has_entry { eta_entry[i].exp() } else { 0.0 };
348        if !(h_exit.is_finite() && h_entry.is_finite()) {
349            return Err(SurvivalError::NumericalFailure {
350                reason: format!("non-finite cumulative hazard at row {i}"),
351            });
352        }
353        log_likelihood -= weight * (h_exit - h_entry);
354        w_exit[i] = weight * h_exit;
355        w_entry[i] = weight * h_entry;
356        if block.event_target[i] > 0 {
357            let deriv = derivative[i];
358            if !(deriv.is_finite() && deriv > 0.0) {
359                return Err(SurvivalError::NumericalFailure {
360                    reason: format!(
361                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
362                    ),
363                });
364            }
365            log_likelihood += weight * (eta_exit[i] + deriv.ln());
366            w_event[i] = weight;
367            w_event_inv_deriv[i] = weight / deriv;
368            w_event_outer[i] = weight / (deriv * deriv);
369        }
370    }
371
372    let mut nll_gradient = fast_atv(&block.x_exit, &w_exit);
373    nll_gradient -= &fast_atv(&block.x_entry, &w_entry);
374    nll_gradient -= &fast_atv(&block.x_exit, &w_event);
375    nll_gradient -= &fast_atv(&block.x_derivative, &w_event_inv_deriv);
376    let gradient = -nll_gradient;
377
378    let mut hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
379    hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
380    hessian += &fast_xt_diag_x(&block.x_derivative, &w_event_outer);
381    Ok((log_likelihood, gradient, hessian))
382}
383
384impl CustomFamily for CauseSpecificRoystonParmarFamily {
385    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
386    // flat-prior exact-Newton objective carries no Jeffreys term), so families
387    // that historically armed the term by default opt back in explicitly.
388    fn joint_jeffreys_term_required(&self) -> bool {
389        true
390    }
391
392    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
393        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
394            "cause-specific survival",
395            self.blocks.len(),
396            block_states.len(),
397        )?;
398        let mut log_likelihood = 0.0;
399        let mut blockworking_sets = Vec::with_capacity(self.blocks.len());
400        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
401            let (ll, gradient, hessian) = evaluate_cause_specific_block(block, &state.beta)?;
402            log_likelihood += ll;
403            blockworking_sets.push(BlockWorkingSet::ExactNewton {
404                gradient,
405                hessian: SymmetricMatrix::Dense(hessian),
406            });
407        }
408        Ok(FamilyEvaluation {
409            log_likelihood,
410            blockworking_sets,
411        })
412    }
413
414    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
415        crate::block_layout::block_count::validate_block_count::<SurvivalError>(
416            "cause-specific survival",
417            self.blocks.len(),
418            block_states.len(),
419        )?;
420        let mut log_likelihood = 0.0;
421        for (block, state) in self.blocks.iter().zip(block_states.iter()) {
422            let (ll, _, _) = evaluate_cause_specific_block(block, &state.beta)?;
423            log_likelihood += ll;
424        }
425        Ok(log_likelihood)
426    }
427
428    fn likelihood_blocks_uncoupled(&self) -> bool {
429        true
430    }
431
432    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
433        true
434    }
435
436    fn output_channel_assignment(
437        &self,
438        specs: &[crate::custom_family::ParameterBlockSpec],
439    ) -> Option<Vec<usize>> {
440        if specs.len() != self.blocks.len() {
441            return Some((0..self.blocks.len()).collect());
442        }
443        Some((0..specs.len()).collect())
444    }
445
446    fn coefficient_hessian_cost(
447        &self,
448        specs: &[crate::custom_family::ParameterBlockSpec],
449    ) -> u64 {
450        crate::custom_family::default_coefficient_hessian_cost(specs)
451    }
452
453    fn block_linear_constraints(
454        &self,
455        _: &[ParameterBlockState],
456        block_idx: usize,
457        spec: &crate::custom_family::ParameterBlockSpec,
458    ) -> Result<Option<LinearInequalityConstraints>, String> {
459        let block = self.blocks.get(block_idx).ok_or_else(|| {
460            SurvivalError::CauseSpecificDimensionMismatch {
461                reason: format!(
462                    "cause-specific survival expected block index < {}, got {block_idx}",
463                    self.blocks.len()
464                ),
465            }
466            .to_string()
467        })?;
468        if block.x_derivative.ncols() != spec.design.ncols() {
469            return Err(SurvivalError::CauseSpecificDimensionMismatch {
470                reason: format!(
471                    "cause-specific survival derivative design has {} columns but block '{}' has {}",
472                    block.x_derivative.ncols(),
473                    spec.name,
474                    spec.design.ncols()
475                ),
476            }
477            .into());
478        }
479        let rhs = block
480            .offset_derivative_exit
481            .mapv(|offset| block.derivative_floor - offset);
482        Ok(Some(LinearInequalityConstraints {
483            a: block.x_derivative.clone(),
484            b: rhs,
485        }))
486    }
487
488    fn max_feasible_step_size(
489        &self,
490        block_states: &[ParameterBlockState],
491        block_idx: usize,
492        delta: &Array1<f64>,
493    ) -> Result<Option<f64>, String> {
494        let block = self.blocks.get(block_idx).ok_or_else(|| {
495            SurvivalError::CauseSpecificDimensionMismatch {
496                reason: format!(
497                    "cause-specific survival expected block index < {}, got {block_idx}",
498                    self.blocks.len()
499                ),
500            }
501            .to_string()
502        })?;
503        let state = block_states.get(block_idx).ok_or_else(|| {
504            SurvivalError::CauseSpecificDimensionMismatch {
505                reason: format!(
506                    "cause-specific survival expected {} block states, got {}",
507                    self.blocks.len(),
508                    block_states.len()
509                ),
510            }
511            .to_string()
512        })?;
513        if delta.len() != state.beta.len() || block.x_derivative.ncols() != delta.len() {
514            return Err(SurvivalError::CauseSpecificDimensionMismatch {
515                reason: "cause-specific survival feasible-step dimension mismatch".to_string(),
516            }
517            .into());
518        }
519        let derivative = fast_av(&block.x_derivative, &state.beta) + &block.offset_derivative_exit;
520        let derivative_delta = fast_av(&block.x_derivative, delta);
521        let mut alpha_max = 1.0_f64;
522        for i in 0..derivative.len() {
523            if block.sampleweight[i] <= 0.0 {
524                continue;
525            }
526            let current = derivative[i] - block.derivative_floor;
527            let slope = derivative_delta[i];
528            if slope < 0.0 {
529                if current <= 0.0 {
530                    return Ok(Some(0.0));
531                }
532                alpha_max = alpha_max.min(DERIVATIVE_FRACTION_TO_BOUNDARY * current / -slope);
533            }
534        }
535        Ok(Some(alpha_max.clamp(0.0, 1.0)))
536    }
537
538    fn exact_newton_hessian_directional_derivative(
539        &self,
540        block_states: &[ParameterBlockState],
541        block_idx: usize,
542        d_beta: &Array1<f64>,
543    ) -> Result<Option<Array2<f64>>, String> {
544        let block = self.blocks.get(block_idx).ok_or_else(|| {
545            SurvivalError::CauseSpecificDimensionMismatch {
546                reason: format!(
547                    "cause-specific survival expected block index < {}, got {block_idx}",
548                    self.blocks.len()
549                ),
550            }
551            .to_string()
552        })?;
553        let state = block_states.get(block_idx).ok_or_else(|| {
554            SurvivalError::CauseSpecificDimensionMismatch {
555                reason: format!(
556                    "cause-specific survival expected {} block states, got {}",
557                    self.blocks.len(),
558                    block_states.len()
559                ),
560            }
561            .to_string()
562        })?;
563        Ok(Some(cause_specific_hessian_directional_derivative(
564            block,
565            &state.beta,
566            d_beta,
567        )?))
568    }
569
570    fn exact_newton_hessian_second_directional_derivative(
571        &self,
572        block_states: &[ParameterBlockState],
573        block_idx: usize,
574        d_beta_u: &Array1<f64>,
575        d_beta_v: &Array1<f64>,
576    ) -> Result<Option<Array2<f64>>, String> {
577        let block = self.blocks.get(block_idx).ok_or_else(|| {
578            SurvivalError::CauseSpecificDimensionMismatch {
579                reason: format!(
580                    "cause-specific survival expected block index < {}, got {block_idx}",
581                    self.blocks.len()
582                ),
583            }
584            .to_string()
585        })?;
586        let state = block_states.get(block_idx).ok_or_else(|| {
587            SurvivalError::CauseSpecificDimensionMismatch {
588                reason: format!(
589                    "cause-specific survival expected {} block states, got {}",
590                    self.blocks.len(),
591                    block_states.len()
592                ),
593            }
594            .to_string()
595        })?;
596        Ok(Some(cause_specific_hessian_second_directional_derivative(
597            block,
598            &state.beta,
599            d_beta_u,
600            d_beta_v,
601        )?))
602    }
603}
604
605fn cause_specific_hessian_directional_derivative(
606    block: &CauseSpecificRoystonParmarBlock,
607    beta: &Array1<f64>,
608    d_beta: &Array1<f64>,
609) -> Result<Array2<f64>, SurvivalError> {
610    let p = block.x_exit.ncols();
611    if beta.len() != p || d_beta.len() != p {
612        return Err(SurvivalError::CauseSpecificDimensionMismatch {
613            reason: "cause-specific survival Hessian derivative dimension mismatch".to_string(),
614        });
615    }
616    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
617    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
618    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
619    let d_eta_entry = fast_av(&block.x_entry, d_beta);
620    let d_eta_exit = fast_av(&block.x_exit, d_beta);
621    let d_derivative = fast_av(&block.x_derivative, d_beta);
622    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
623    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
624    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
625
626    for i in 0..block.event_target.len() {
627        let weight = block.sampleweight[i];
628        if weight <= 0.0 {
629            continue;
630        }
631        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
632        w_exit[i] = weight * eta_exit[i].exp() * d_eta_exit[i];
633        if has_entry {
634            w_entry[i] = weight * eta_entry[i].exp() * d_eta_entry[i];
635        }
636        if block.event_target[i] > 0 {
637            let deriv = derivative[i];
638            if !(deriv.is_finite() && deriv > 0.0) {
639                return Err(SurvivalError::NumericalFailure {
640                    reason: format!(
641                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
642                    ),
643                });
644            }
645            w_derivative[i] = -2.0 * weight * d_derivative[i] / (deriv * deriv * deriv);
646        }
647    }
648
649    let mut d_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
650    d_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
651    d_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
652    Ok(d_hessian)
653}
654
655fn cause_specific_hessian_second_directional_derivative(
656    block: &CauseSpecificRoystonParmarBlock,
657    beta: &Array1<f64>,
658    d_beta_u: &Array1<f64>,
659    d_beta_v: &Array1<f64>,
660) -> Result<Array2<f64>, SurvivalError> {
661    let p = block.x_exit.ncols();
662    if beta.len() != p || d_beta_u.len() != p || d_beta_v.len() != p {
663        return Err(SurvivalError::CauseSpecificDimensionMismatch {
664            reason: "cause-specific survival second Hessian derivative dimension mismatch"
665                .to_string(),
666        });
667    }
668    let eta_entry = fast_av(&block.x_entry, beta) + &block.offset_eta_entry;
669    let eta_exit = fast_av(&block.x_exit, beta) + &block.offset_eta_exit;
670    let derivative = fast_av(&block.x_derivative, beta) + &block.offset_derivative_exit;
671    let u_eta_entry = fast_av(&block.x_entry, d_beta_u);
672    let u_eta_exit = fast_av(&block.x_exit, d_beta_u);
673    let u_derivative = fast_av(&block.x_derivative, d_beta_u);
674    let v_eta_entry = fast_av(&block.x_entry, d_beta_v);
675    let v_eta_exit = fast_av(&block.x_exit, d_beta_v);
676    let v_derivative = fast_av(&block.x_derivative, d_beta_v);
677    let mut w_exit = Array1::<f64>::zeros(block.event_target.len());
678    let mut w_entry = Array1::<f64>::zeros(block.event_target.len());
679    let mut w_derivative = Array1::<f64>::zeros(block.event_target.len());
680
681    for i in 0..block.event_target.len() {
682        let weight = block.sampleweight[i];
683        if weight <= 0.0 {
684            continue;
685        }
686        let has_entry = block.age_entry[i] > ENTRY_AT_ORIGIN_THRESHOLD;
687        w_exit[i] = weight * eta_exit[i].exp() * u_eta_exit[i] * v_eta_exit[i];
688        if has_entry {
689            w_entry[i] = weight * eta_entry[i].exp() * u_eta_entry[i] * v_eta_entry[i];
690        }
691        if block.event_target[i] > 0 {
692            let deriv = derivative[i];
693            if !(deriv.is_finite() && deriv > 0.0) {
694                return Err(SurvivalError::NumericalFailure {
695                    reason: format!(
696                        "cause-specific survival derivative must be positive at row {i}, got {deriv}"
697                    ),
698                });
699            }
700            w_derivative[i] = 6.0 * weight * u_derivative[i] * v_derivative[i] / deriv.powi(4);
701        }
702    }
703
704    let mut d2_hessian = fast_xt_diag_x(&block.x_exit, &w_exit);
705    d2_hessian -= &fast_xt_diag_x(&block.x_entry, &w_entry);
706    d2_hessian += &fast_xt_diag_x(&block.x_derivative, &w_derivative);
707    Ok(d2_hessian)
708}
709
710pub fn survival_event_code_from_value(value: f64, row_index: usize) -> Result<u8, String> {
711    const INTEGER_TOL: f64 = 1e-8;
712    const MAX_AUTO_CAUSES: u8 = 32;
713    if !value.is_finite() {
714        return Err(SurvivalError::EventCodeInvalid {
715            reason: format!(
716                "survival event value at row {} is non-finite",
717                row_index + 1
718            ),
719        }
720        .into());
721    }
722    if value < 0.0 {
723        return Err(SurvivalError::EventCodeInvalid {
724            reason: format!(
725                "survival event value at row {} is negative: {value}",
726                row_index + 1
727            ),
728        }
729        .into());
730    }
731    let rounded = value.round();
732    if (value - rounded).abs() > INTEGER_TOL {
733        return Err(SurvivalError::EventCodeInvalid {
734            reason: format!(
735                "survival event value at row {} must be an integer code with 0=censored, got {value}",
736                row_index + 1
737            ),
738        }
739        .into());
740    }
741    if rounded > f64::from(MAX_AUTO_CAUSES) {
742        return Err(SurvivalError::EventCodeInvalid {
743            reason: format!(
744                "survival event value at row {} has code {rounded}; automatic competing-risks detection supports codes 0..={MAX_AUTO_CAUSES}",
745                row_index + 1
746            ),
747        }
748        .into());
749    }
750    Ok(rounded as u8)
751}
752
753pub fn cause_count_from_event_codes(
754    event_codes: ArrayView1<'_, u8>,
755) -> Result<usize, SurvivalError> {
756    let max_code = event_codes.iter().copied().max().map_or(0, usize::from);
757    if max_code == 0 {
758        return Ok(1);
759    }
760
761    let mut present = vec![false; max_code + 1];
762    for code in event_codes.iter().copied() {
763        present[usize::from(code)] = true;
764    }
765    if (1..=max_code).any(|code| !present[code]) {
766        let actual = present
767            .iter()
768            .enumerate()
769            .skip(1)
770            .filter_map(|(code, &seen)| seen.then_some(code.to_string()))
771            .collect::<Vec<_>>()
772            .join(", ");
773        return Err(SurvivalError::EventCodeInvalid {
774            reason: format!(
775                "survival competing-risks event codes must use contiguous positive codes; observed nonzero codes are {{{actual}}}. Remap event codes contiguously (for example, {{0,1,3}} -> {{0,1,2}}), otherwise a phantom cause is fit with no events and pollutes CIF assembly."
776            ),
777        });
778    }
779
780    Ok(max_code)
781}
782
783/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
784/// onto the binary `{0, 1}` *any-event* indicator the pooled single-hazard
785/// baseline engine consumes.
786///
787/// The shared Royston-Parmar baseline working model fits one hazard across all
788/// causes; every observed event (regardless of cause) informs that baseline, so
789/// the indicator is `1` exactly when *any* cause occurred. This is one of the
790/// two cause-aware projections of the raw code vector; the other is
791/// [`cause_specific_event_indicator`]. Centralizing both here keeps the single
792/// source of truth for "how multi-cause labels become a single-hazard binary
793/// contract", so no construction path open-codes a fragile `mapv` and then
794/// trips the binary `event_target > 1` guard on the raw labels.
795pub fn pooled_any_event_indicator(event_codes: ArrayView1<'_, u8>) -> Array1<u8> {
796    event_codes.mapv(|label| u8::from(label > 0))
797}
798
799/// Project multi-cause competing-risks event codes `{0 = censored, k = cause k}`
800/// onto the binary `{0, 1}` indicator for the cause-specific Royston-Parmar
801/// block of cause `cause` (1-based).
802///
803/// Within cause `cause`'s block the event of interest is `event == cause`; every
804/// competing cause is treated as censoring (indicator `0`), which is exactly the
805/// cause-specific hazard likelihood. Like [`pooled_any_event_indicator`], this
806/// yields a binary contract that satisfies the single-hazard `event_target` guard
807/// — the raw multi-cause labels are never handed to a binary engine.
808pub fn cause_specific_event_indicator(event_codes: ArrayView1<'_, u8>, cause: usize) -> Array1<u8> {
809    let cause_code = cause as u8;
810    event_codes.mapv(|observed| u8::from(observed == cause_code))
811}
812
813fn compress_positive_collinear_constraints(
814    a: &Array2<f64>,
815    b: &Array1<f64>,
816) -> LinearInequalityConstraints {
817    const SCALE_TOL: f64 = 1e-14;
818    const KEY_TOL: f64 = 1e-8;
819
820    let mut grouped: BTreeMap<Vec<i64>, (Vec<f64>, f64)> = BTreeMap::new();
821    let mut fallbackrows: Vec<(Vec<f64>, f64)> = Vec::new();
822
823    for i in 0..a.nrows() {
824        let row = a.row(i);
825        let scale = row.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
826        if !scale.is_finite() || scale <= SCALE_TOL {
827            if b[i] > 0.0 {
828                fallbackrows.push((row.to_vec(), b[i]));
829            }
830            continue;
831        }
832
833        let normalizedrow: Vec<f64> = row
834            .iter()
835            .map(|&v| {
836                let scaled = v / scale;
837                if scaled.abs() <= KEY_TOL { 0.0 } else { scaled }
838            })
839            .collect();
840        let normalized_rhs = b[i] / scale;
841        let key: Vec<i64> = normalizedrow
842            .iter()
843            .map(|&v| (v / KEY_TOL).round() as i64)
844            .collect();
845
846        match grouped.get_mut(&key) {
847            Some((_, rhs_max)) => {
848                if normalized_rhs > *rhs_max {
849                    *rhs_max = normalized_rhs;
850                }
851            }
852            None => {
853                grouped.insert(key, (normalizedrow, normalized_rhs));
854            }
855        }
856    }
857
858    let nrows = grouped.len() + fallbackrows.len();
859    let n_cols = a.ncols();
860    let mut a_out = Array2::<f64>::zeros((nrows, n_cols));
861    let mut b_out = Array1::<f64>::zeros(nrows);
862
863    let mut outrow = 0usize;
864    for (_, (row, rhs)) in grouped {
865        for (j, value) in row.into_iter().enumerate() {
866            a_out[[outrow, j]] = value;
867        }
868        b_out[outrow] = rhs;
869        outrow += 1;
870    }
871    for (row, rhs) in fallbackrows {
872        for (j, value) in row.into_iter().enumerate() {
873            a_out[[outrow, j]] = value;
874        }
875        b_out[outrow] = rhs;
876        outrow += 1;
877    }
878
879    LinearInequalityConstraints { a: a_out, b: b_out }
880}
881
/// Monotonicity settings carried by [`WorkingModelSurvival`].
#[derive(Debug, Clone, Copy, Default)]
pub struct SurvivalMonotonicityPenalty {
    /// Derivative tolerance for models that are not structurally monotone.
    /// `WorkingModelSurvival::derivative_guard` uses `tolerance.max(0.0)`, so
    /// negative values behave like `0.0`.
    pub tolerance: f64,
}
886
/// Storage layout for the entry / exit / derivative design matrices.
#[derive(Debug, Clone)]
enum SurvivalDesign {
    /// Three full-width matrices, one per channel; all coefficients multiply
    /// the stored columns directly.
    Flat {
        x_entry: Array2<f64>,
        x_exit: Array2<f64>,
        x_derivative: Array2<f64>,
    },
    /// Per-channel time blocks plus one covariate block shared by the entry
    /// and exit channels. Coefficients are ordered time columns first, then
    /// covariate columns; the derivative channel uses only the time columns
    /// (its covariate part is treated as zero).
    TimeCovariateShared {
        time_entry: Array2<f64>,
        time_exit: Array2<f64>,
        time_derivative: Array2<f64>,
        covariates: Array2<f64>,
    },
}
901
902impl SurvivalDesign {
903    fn p_total(&self) -> usize {
904        match self {
905            Self::Flat { x_exit, .. } => x_exit.ncols(),
906            Self::TimeCovariateShared {
907                time_exit,
908                covariates,
909                ..
910            } => time_exit.ncols() + covariates.ncols(),
911        }
912    }
913
914    fn design_dot(&self, time_mat: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
915        match self {
916            Self::Flat { .. } => time_mat.dot(beta),
917            Self::TimeCovariateShared { covariates, .. } => {
918                let p_time = time_mat.ncols();
919                let mut out = time_mat.dot(&beta.slice(ndarray::s![..p_time]));
920                if covariates.ncols() > 0 {
921                    out += &covariates.dot(&beta.slice(ndarray::s![p_time..]));
922                }
923                out
924            }
925        }
926    }
927
928    fn fill_row(&self, time_mat: &Array2<f64>, i: usize, out: &mut [f64]) {
929        match self {
930            Self::Flat { .. } => {
931                for (dst, &src) in out.iter_mut().zip(time_mat.row(i).iter()) {
932                    *dst = src;
933                }
934            }
935            Self::TimeCovariateShared { covariates, .. } => {
936                let p_time = time_mat.ncols();
937                for j in 0..p_time {
938                    out[j] = time_mat[[i, j]];
939                }
940                for j in 0..covariates.ncols() {
941                    out[p_time + j] = covariates[[i, j]];
942                }
943            }
944        }
945    }
946}
947
/// Pre-allocated workspace buffers for `update_state` to avoid per-iteration allocations.
///
/// All five buffers are created with the same length `n` (see
/// `SurvivalWorkspace::new`) and zero-filled again by `SurvivalWorkspace::reset`.
// NOTE(review): the per-buffer meanings below are inferred from the field
// names only; confirm against `update_state`.
#[derive(Debug, Clone)]
struct SurvivalWorkspace {
    // Presumably event-row weights.
    w_event: Array1<f64>,
    // Presumably event-row weights scaled by the inverse derivative.
    w_event_inv_deriv: Array1<f64>,
    // Presumably event-row weights for an outer-product term.
    w_event_outer: Array1<f64>,
    // Presumably Hessian weights for the exit channel.
    w_hess_exit: Array1<f64>,
    // Presumably Hessian weights for the entry channel.
    w_hess_entry: Array1<f64>,
}
957
958impl SurvivalWorkspace {
959    fn new(n: usize) -> Self {
960        Self {
961            w_event: Array1::zeros(n),
962            w_event_inv_deriv: Array1::zeros(n),
963            w_event_outer: Array1::zeros(n),
964            w_hess_exit: Array1::zeros(n),
965            w_hess_entry: Array1::zeros(n),
966        }
967    }
968
969    fn reset(&mut self, n: usize) {
970        if self.w_event.len() != n {
971            *self = Self::new(n);
972        } else {
973            self.w_event.fill(0.0);
974            self.w_event_inv_deriv.fill(0.0);
975            self.w_event_outer.fill(0.0);
976            self.w_hess_exit.fill(0.0);
977            self.w_hess_entry.fill(0.0);
978        }
979    }
980}
981
/// Per-observation gradients of the unpenalized survival NLL with respect
/// to each additive offset channel, at a given β. See
/// [`WorkingModelSurvival::offset_channel_residuals`] for the algebra.
///
/// Contract: all four arrays have length `n` = number of observations.
/// Rows with non-positive sampleweight are 0 in every channel. The
/// `derivative` channel is 0 in all non-event rows. The `right` channel is
/// the interval upper-bound (`R`) η-offset sensitivity and is exactly 0 for
/// every NON-interval-censored model and every non-interval row of the latent
/// interval model (only the dedicated `SurvInterval(L, R, event)` latent fit
/// populates it); the baseline-θ chain rule contracts it against the
/// `age_right`-evaluated η-partial.
#[derive(Clone, Debug)]
pub struct OffsetChannelResiduals {
    /// ∂NLL/∂o_X: w·(exp(η_exit) − δ) per row.
    pub exit: Array1<f64>,
    /// ∂NLL/∂o_E: −w·exp(η_entry) if row has a positive entry interval else 0.
    pub entry: Array1<f64>,
    /// ∂NLL/∂o_D: −w·δ / s (event-row only).
    pub derivative: Array1<f64>,
    /// ∂NLL/∂o_R: interval upper-bound (`R`) η-offset sensitivity,
    /// `−w·∂(log-lik)/∂q_right`. Nonzero only for interval-censored latent
    /// rows; exactly 0 for every other row and every other model.
    pub right: Array1<f64>,
}
1007
/// Per-observation Hessians of the unpenalized survival NLL with respect
/// to additive offset channels in `(entry, exit, derivative)` order.
#[derive(Clone, Debug)]
pub struct OffsetChannelCurvatures {
    /// One 3×3 matrix per observation, with rows and columns indexed in the
    /// `(entry, exit, derivative)` order stated above.
    pub rows: Vec<[[f64; 3]; 3]>,
}
1014
/// Working model for the one-hazard survival engine: per-row data, design
/// matrices, offsets, penalties and monotonicity bookkeeping.
///
/// `Clone` is implemented by hand because the scratch workspace sits behind a
/// `Mutex`.
#[derive(Debug)]
pub struct WorkingModelSurvival {
    // Per-row entry times; construction validation requires finite values >= 0.
    age_entry: Array1<f64>,
    // Per-row exit times; construction validation requires finite values > 0.
    age_exit: Array1<f64>,
    // `age_entry <= ENTRY_AT_ORIGIN_THRESHOLD`, precomputed at construction.
    entry_at_origin: Array1<bool>,
    // Binary {0, 1} event indicator (codes > 1 are rejected at construction).
    event_target: Array1<u8>,
    // Per-row weights; validated finite and non-negative.
    sampleweight: Array1<f64>,
    // Entry / exit / derivative design storage.
    design: SurvivalDesign,
    // Additive per-row offsets for the three channels; all zeros when the
    // caller supplies none.
    offset_eta_entry: Array1<f64>,
    offset_eta_exit: Array1<f64>,
    offset_derivative_exit: Array1<f64>,
    // Penalty blocks, checked by `validate_penalties`.
    penalties: PenaltyBlocks,
    // Tolerance settings used by `derivative_guard`.
    monotonicity: SurvivalMonotonicityPenalty,
    // When true, monotonicity is expressed as non-negativity constraints on the
    // active time coefficients; `finish_construction` initializes this to false.
    structurally_monotonic: bool,
    // Number of leading coefficient columns treated as time columns under
    // structural monotonicity (capped at the coefficient dimension on use).
    structural_time_columns: usize,
    // Optional explicit constraint rows and offsets; validation accepts them
    // only as a pair (both present or both absent).
    monotonicity_constraint_rows: Option<Array2<f64>>,
    monotonicity_constraint_offsets: Option<Array1<f64>>,
    // Reusable scratch buffers guarded by a mutex.
    workspace: std::sync::Mutex<SurvivalWorkspace>,
}
1034
1035impl Clone for WorkingModelSurvival {
1036    fn clone(&self) -> Self {
1037        let workspace = self.workspace.lock().unwrap().clone();
1038        Self {
1039            age_entry: self.age_entry.clone(),
1040            age_exit: self.age_exit.clone(),
1041            entry_at_origin: self.entry_at_origin.clone(),
1042            event_target: self.event_target.clone(),
1043            sampleweight: self.sampleweight.clone(),
1044            design: self.design.clone(),
1045            offset_eta_entry: self.offset_eta_entry.clone(),
1046            offset_eta_exit: self.offset_eta_exit.clone(),
1047            offset_derivative_exit: self.offset_derivative_exit.clone(),
1048            penalties: self.penalties.clone(),
1049            monotonicity: self.monotonicity,
1050            structurally_monotonic: self.structurally_monotonic,
1051            structural_time_columns: self.structural_time_columns,
1052            monotonicity_constraint_rows: self.monotonicity_constraint_rows.clone(),
1053            monotonicity_constraint_offsets: self.monotonicity_constraint_offsets.clone(),
1054            workspace: std::sync::Mutex::new(workspace),
1055        }
1056    }
1057}
1058
1059impl WorkingModelSurvival {
    /// Natural logarithm of `f64::MAX` (≈ 709.78); `exp` of anything larger
    /// overflows to infinity.
    const LOG_F64_MAX: f64 = 709.782712893384;
1061
1062    #[inline]
1063    fn scaled_exp_component(log_scale: f64, base: f64) -> Result<f64, EstimationError> {
1064        if base == 0.0 {
1065            return Ok(0.0);
1066        }
1067        let log_abs = log_scale + base.abs().ln();
1068        if !log_abs.is_finite() {
1069            crate::bail_invalid_estim!("survival interval term produced non-finite log-magnitude");
1070        }
1071        if log_abs > Self::LOG_F64_MAX {
1072            crate::bail_invalid_estim!(
1073                "survival interval term exceeds f64 range (log-magnitude={log_abs:.3e})"
1074            );
1075        }
1076        Ok(base.signum() * log_abs.exp())
1077    }
1078
    /// Number of model coefficients, as reported by the design layout.
    fn coefficient_dim(&self) -> usize {
        self.design.p_total()
    }
1082
    /// Number of observations, taken from the length of the weight vector.
    fn nrows(&self) -> usize {
        self.sampleweight.len()
    }
1086
1087    fn entry_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1088        let time_mat = match &self.design {
1089            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1090            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1091        };
1092        self.design.design_dot(time_mat, beta)
1093    }
1094
1095    fn exit_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1096        let time_mat = match &self.design {
1097            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1098            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1099        };
1100        self.design.design_dot(time_mat, beta)
1101    }
1102
1103    fn derivative_dot(&self, beta: &Array1<f64>) -> Array1<f64> {
1104        match &self.design {
1105            SurvivalDesign::Flat { x_derivative, .. } => x_derivative.dot(beta),
1106            SurvivalDesign::TimeCovariateShared {
1107                time_derivative, ..
1108            } => time_derivative.dot(&beta.slice(ndarray::s![..time_derivative.ncols()])),
1109        }
1110    }
1111
1112    fn fill_entry_row(&self, i: usize, out: &mut [f64]) {
1113        let time_mat = match &self.design {
1114            SurvivalDesign::Flat { x_entry, .. } => x_entry,
1115            SurvivalDesign::TimeCovariateShared { time_entry, .. } => time_entry,
1116        };
1117        self.design.fill_row(time_mat, i, out);
1118    }
1119
1120    fn fill_exit_row(&self, i: usize, out: &mut [f64]) {
1121        let time_mat = match &self.design {
1122            SurvivalDesign::Flat { x_exit, .. } => x_exit,
1123            SurvivalDesign::TimeCovariateShared { time_exit, .. } => time_exit,
1124        };
1125        self.design.fill_row(time_mat, i, out);
1126    }
1127
1128    fn fill_derivative_row(&self, i: usize, out: &mut [f64]) {
1129        match &self.design {
1130            SurvivalDesign::Flat { x_derivative, .. } => {
1131                for (dst, &src) in out.iter_mut().zip(x_derivative.row(i).iter()) {
1132                    *dst = src;
1133                }
1134            }
1135            SurvivalDesign::TimeCovariateShared {
1136                time_derivative, ..
1137            } => {
1138                let p_time = time_derivative.ncols();
1139                for j in 0..p_time {
1140                    out[j] = time_derivative[[i, j]];
1141                }
1142                for dst in out.iter_mut().skip(p_time) {
1143                    *dst = 0.0;
1144                }
1145            }
1146        }
1147    }
1148
1149    fn derivative_xt_diag_x(&self, weights: &Array1<f64>) -> Array2<f64> {
1150        match &self.design {
1151            SurvivalDesign::Flat { x_derivative, .. } => fast_xt_diag_x(x_derivative, weights),
1152            SurvivalDesign::TimeCovariateShared {
1153                time_derivative,
1154                covariates,
1155                ..
1156            } => {
1157                let p_time = time_derivative.ncols();
1158                let p_cov = covariates.ncols();
1159                let mut out = Array2::<f64>::zeros((p_time + p_cov, p_time + p_cov));
1160                let time_block = fast_xt_diag_x(time_derivative, weights);
1161                out.slice_mut(ndarray::s![..p_time, ..p_time])
1162                    .assign(&time_block);
1163                out
1164            }
1165        }
1166    }
1167
1168    /// Compute the full p×p Hessian contribution for the interval terms:
1169    ///   H = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1170    /// using faer-accelerated BLAS on the stored design matrix blocks.
1171    fn interval_hessian_blas(&self, w_exit: &Array1<f64>, w_entry: &Array1<f64>) -> Array2<f64> {
1172        match &self.design {
1173            SurvivalDesign::Flat {
1174                x_entry, x_exit, ..
1175            } => {
1176                let mut h = fast_xt_diag_x(x_exit, w_exit);
1177                h -= &fast_xt_diag_x(x_entry, w_entry);
1178                h
1179            }
1180            SurvivalDesign::TimeCovariateShared {
1181                time_entry,
1182                time_exit,
1183                covariates,
1184                ..
1185            } => {
1186                let p_time = time_exit.ncols();
1187                let p_cov = covariates.ncols();
1188                let p = p_time + p_cov;
1189                let mut h = Array2::<f64>::zeros((p, p));
1190                // time-time block: T_exit^T W_exit T_exit - T_entry^T W_entry T_entry
1191                let tt = {
1192                    let mut block = fast_xt_diag_x(time_exit, w_exit);
1193                    block -= &fast_xt_diag_x(time_entry, w_entry);
1194                    block
1195                };
1196                h.slice_mut(ndarray::s![..p_time, ..p_time]).assign(&tt);
1197                if p_cov > 0 {
1198                    // time-cov block: T_exit^T W_exit C - T_entry^T W_entry C
1199                    let tc = {
1200                        let mut block = fast_xt_diag_y(time_exit, w_exit, covariates);
1201                        block -= &fast_xt_diag_y(time_entry, w_entry, covariates);
1202                        block
1203                    };
1204                    h.slice_mut(ndarray::s![..p_time, p_time..]).assign(&tc);
1205                    h.slice_mut(ndarray::s![p_time.., ..p_time]).assign(&tc.t());
1206                    // cov-cov block: C^T (W_exit - W_entry) C
1207                    let w_diff = w_exit - w_entry;
1208                    let cc = fast_xt_diag_x(covariates, &w_diff);
1209                    h.slice_mut(ndarray::s![p_time.., p_time..]).assign(&cc);
1210                }
1211                h
1212            }
1213        }
1214    }
1215
1216    fn stabilized_structural_derivative(&self, deriv: f64) -> Option<f64> {
1217        const STRUCTURAL_MONO_ROUNDOFF_TOL: f64 = 1e-7;
1218        if !self.structurally_monotonic {
1219            return None;
1220        }
1221        if deriv >= 1e-12 {
1222            return Some(deriv);
1223        }
1224        if deriv >= -STRUCTURAL_MONO_ROUNDOFF_TOL {
1225            return Some(1e-12);
1226        }
1227        None
1228    }
1229
1230    fn validate_penalties(
1231        penalties: &PenaltyBlocks,
1232        coefficient_dim: usize,
1233    ) -> Result<(), SurvivalError> {
1234        for block in &penalties.blocks {
1235            if !block.lambda.is_finite() || block.lambda < 0.0 {
1236                return Err(SurvivalError::NonFiniteInput);
1237            }
1238            if block.range.start > block.range.end || block.range.end > coefficient_dim {
1239                return Err(SurvivalError::DimensionMismatch);
1240            }
1241            let block_dim = block.range.end - block.range.start;
1242            if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
1243                return Err(SurvivalError::DimensionMismatch);
1244            }
1245            if block.matrix.iter().any(|v| !v.is_finite()) {
1246                return Err(SurvivalError::NonFiniteInput);
1247            }
1248        }
1249        Ok(())
1250    }
1251
1252    fn derivative_guard(&self) -> f64 {
1253        if self.structurally_monotonic {
1254            // I-spline basis is monotone by construction when coefficients ≥ 0.
1255            // A derivative of zero (flat hazard) is valid, so the guard only
1256            // rejects genuinely negative derivatives from floating-point noise.
1257            return 0.0;
1258        }
1259        self.monotonicity.tolerance.max(0.0)
1260    }
1261
1262    fn derivative_guard_numerical(&self) -> f64 {
1263        let derivative_guard = self.derivative_guard();
1264        if derivative_guard <= 0.0 {
1265            // For structural monotonicity (guard = 0), tiny negative derivs are
1266            // tolerated because `stabilized_structural_derivative` lifts the
1267            // value back to a small positive floor before any `ln`/`1/deriv`
1268            // use. For *non-structural* monotonicity with tolerance == 0 the
1269            // raw derivative flows straight through into the event-row
1270            // `deriv.ln()` and `1.0 / deriv`, so any non-positive value would
1271            // produce NaN / huge negative weights. Keep the slack only when
1272            // the structural stabilizer is active.
1273            if self.structurally_monotonic {
1274                -1e-10
1275            } else {
1276                1e-12
1277            }
1278        } else {
1279            (derivative_guard - (1e-10_f64).min(0.01 * derivative_guard)).max(1e-12)
1280        }
1281    }
1282
1283    fn interval_increment_guard(&self, h_entry: f64, h_exit: f64) -> f64 {
1284        let scale = h_entry.abs().max(h_exit.abs()).max(1.0);
1285        1e-10 * scale
1286    }
1287
1288    fn structural_time_coefficient_constraints(&self) -> Option<LinearInequalityConstraints> {
1289        if !self.structurally_monotonic {
1290            return None;
1291        }
1292        let p = self.coefficient_dim();
1293        let time_columns = self.structural_time_columns.min(p);
1294        if time_columns == 0 {
1295            return None;
1296        }
1297        const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1298        let mut active_columns = vec![false; time_columns];
1299        let mut derivative_row = vec![0.0_f64; p];
1300        for i in 0..self.nrows() {
1301            if self.sampleweight[i] <= 0.0 {
1302                continue;
1303            }
1304            self.fill_derivative_row(i, &mut derivative_row);
1305            for j in 0..time_columns {
1306                if derivative_row[j] > STRUCTURAL_DERIV_TOL {
1307                    active_columns[j] = true;
1308                }
1309            }
1310        }
1311        if let Some(rows) = self.monotonicity_constraint_rows.as_ref() {
1312            for i in 0..rows.nrows() {
1313                for j in 0..time_columns {
1314                    if rows[[i, j]] > STRUCTURAL_DERIV_TOL {
1315                        active_columns[j] = true;
1316                    }
1317                }
1318            }
1319        }
1320        let active_columns: Vec<usize> = active_columns
1321            .into_iter()
1322            .enumerate()
1323            .filter_map(|(j, active)| active.then_some(j))
1324            .collect();
1325        if active_columns.is_empty() {
1326            return None;
1327        }
1328        let mut a = Array2::<f64>::zeros((active_columns.len(), p));
1329        let b = Array1::<f64>::zeros(active_columns.len());
1330        for (row, &col) in active_columns.iter().enumerate() {
1331            a[[row, col]] = 1.0;
1332        }
1333        Some(LinearInequalityConstraints { a, b })
1334    }
1335
1336    pub fn monotonicity_linear_constraints(&self) -> Option<LinearInequalityConstraints> {
1337        let p = self.coefficient_dim();
1338        const DERIVATIVE_ROW_NORM_TOL: f64 = 1e-12;
1339        if p == 0 {
1340            return None;
1341        }
1342        if self.structurally_monotonic {
1343            return self.structural_time_coefficient_constraints();
1344        }
1345        if let (Some(rows), Some(offsets)) = (
1346            self.monotonicity_constraint_rows.as_ref(),
1347            self.monotonicity_constraint_offsets.as_ref(),
1348        ) {
1349            let activerows: Vec<usize> = (0..rows.nrows())
1350                .filter(|&i| {
1351                    rows.row(i).iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
1352                        > DERIVATIVE_ROW_NORM_TOL
1353                })
1354                .collect();
1355            if activerows.is_empty() {
1356                return None;
1357            }
1358            let mut a = Array2::<f64>::zeros((activerows.len(), p));
1359            let mut b = Array1::<f64>::zeros(activerows.len());
1360            for (r, &i) in activerows.iter().enumerate() {
1361                a.row_mut(r).assign(&rows.row(i));
1362                b[r] = self.derivative_guard() - offsets[i];
1363            }
1364            return Some(compress_positive_collinear_constraints(&a, &b));
1365        }
1366        None
1367    }
1368
    /// Build a working model from engine inputs without baseline offsets.
    ///
    /// Delegates to [`Self::from_engine_inputswith_offsets`] with `None` for the
    /// offsets; all validation and error reporting comes from that
    /// constructor.
    pub fn from_engine_inputs(
        inputs: SurvivalEngineInputs<'_>,
        penalties: PenaltyBlocks,
        monotonicity: SurvivalMonotonicityPenalty,
        spec: SurvivalSpec,
    ) -> Result<Self, SurvivalError> {
        Self::from_engine_inputswith_offsets(inputs, None, penalties, monotonicity, spec)
    }
1377
1378    fn validate_offsets(
1379        offsets: Option<SurvivalBaselineOffsets<'_>>,
1380        n: usize,
1381    ) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), SurvivalError> {
1382        if let Some(off) = offsets {
1383            if off.eta_entry.len() != n || off.eta_exit.len() != n || off.derivative_exit.len() != n
1384            {
1385                return Err(SurvivalError::DimensionMismatch);
1386            }
1387            if off.eta_entry.iter().any(|v| !v.is_finite())
1388                || off.eta_exit.iter().any(|v| !v.is_finite())
1389                || off.derivative_exit.iter().any(|v| !v.is_finite())
1390            {
1391                return Err(SurvivalError::NonFiniteInput);
1392            }
1393            Ok((
1394                off.eta_entry.to_owned(),
1395                off.eta_exit.to_owned(),
1396                off.derivative_exit.to_owned(),
1397            ))
1398        } else {
1399            Ok((Array1::zeros(n), Array1::zeros(n), Array1::zeros(n)))
1400        }
1401    }
1402
1403    fn validate_common_inputs(
1404        age_entry: &ArrayView1<f64>,
1405        age_exit: &ArrayView1<f64>,
1406        event_target: &ArrayView1<u8>,
1407        event_competing: &ArrayView1<u8>,
1408        sampleweight: &ArrayView1<f64>,
1409    ) -> Result<(), SurvivalError> {
1410        if age_entry.iter().any(|v| !v.is_finite())
1411            || age_exit.iter().any(|v| !v.is_finite())
1412            || sampleweight.iter().any(|v| !v.is_finite() || *v < 0.0)
1413        {
1414            return Err(SurvivalError::NonFiniteInput);
1415        }
1416        // The single-hazard engine's `event_target` contract is binary {0, 1}.
1417        // A code > 1 is a *valid finite multi-cause label* that simply must be
1418        // projected first (any-event for the pooled baseline, cause-specific for
1419        // each block); it is NOT a non-finite input. Report it as such so the
1420        // failure is actionable and never surfaces as the misleading "inputs
1421        // contain non-finite values".
1422        if let Some(&label) = event_target.iter().find(|&&v| v > 1) {
1423            return Err(SurvivalError::EventCodeInvalid {
1424                reason: format!(
1425                    "single-hazard survival engine requires a binary {{0, 1}} event_target, got multi-cause label {label}; competing-risks codes must be projected via pooled_any_event_indicator / cause_specific_event_indicator before construction"
1426                ),
1427            });
1428        }
1429        if let Some(&label) = event_competing.iter().find(|&&v| v > 1) {
1430            return Err(SurvivalError::EventCodeInvalid {
1431                reason: format!(
1432                    "single-hazard survival engine requires a binary {{0, 1}} event_competing, got multi-cause label {label}"
1433                ),
1434            });
1435        }
1436        if event_target
1437            .iter()
1438            .zip(event_competing.iter())
1439            .any(|(&target, &competing)| target > 0 && competing > 0)
1440        {
1441            return Err(SurvivalError::EventCodeInvalid {
1442                reason: "a row cannot be simultaneously a target event and a competing event"
1443                    .to_string(),
1444            });
1445        }
1446        // The "must have at least one target event" requirement is a
1447        // *fittability* check, not a structural one: with all rows censored the
1448        // likelihood has no event score, so any subsequent fit cannot identify
1449        // the hazard and the optimizer spins on a flat landscape.  But the
1450        // structural integrity of the engine — its derivative-guard rejection
1451        // of decreasing cumulative hazards, its monotonicity-collocation
1452        // bookkeeping, its update_state numerics — is well-defined on
1453        // all-censored inputs, and unit tests legitimately exercise those
1454        // structural paths on censored fixtures.  Move the fittability check
1455        // out of construction; production fit dispatchers (e.g.
1456        // `solver::fit_orchestration::materialize_survival`) enforce it on the
1457        // single chokepoint that actually starts an optimization, where
1458        // the failure mode it guards against is reachable.
1459        if age_entry
1460            .iter()
1461            .zip(age_exit.iter())
1462            .any(|(&entry, &exit)| entry < 0.0 || exit <= 0.0)
1463        {
1464            return Err(SurvivalError::NonFiniteInput);
1465        }
1466        Ok::<(), _>(())
1467    }
1468
1469    fn validate_monotonicity_constraints(
1470        rows: Option<ArrayView2<'_, f64>>,
1471        offsets: Option<ArrayView1<'_, f64>>,
1472        coefficient_dim: usize,
1473    ) -> Result<(Option<Array2<f64>>, Option<Array1<f64>>), SurvivalError> {
1474        match (rows, offsets) {
1475            (None, None) => Ok((None, None)),
1476            (Some(rows), Some(offsets)) => {
1477                if rows.ncols() != coefficient_dim
1478                    || rows.nrows() != offsets.len()
1479                    || rows.iter().any(|v| !v.is_finite())
1480                    || offsets.iter().any(|v| !v.is_finite())
1481                {
1482                    return Err(SurvivalError::DimensionMismatch);
1483                }
1484                Ok((Some(rows.to_owned()), Some(offsets.to_owned())))
1485            }
1486            _ => Err(SurvivalError::DimensionMismatch),
1487        }
1488    }
1489
1490    fn finish_construction(
1491        age_entry: ArrayView1<f64>,
1492        age_exit: ArrayView1<f64>,
1493        event_target: ArrayView1<u8>,
1494        sampleweight: ArrayView1<f64>,
1495        design: SurvivalDesign,
1496        offset_eta_entry: Array1<f64>,
1497        offset_eta_exit: Array1<f64>,
1498        offset_derivative_exit: Array1<f64>,
1499        penalties: PenaltyBlocks,
1500        monotonicity: SurvivalMonotonicityPenalty,
1501        monotonicity_constraint_rows: Option<Array2<f64>>,
1502        monotonicity_constraint_offsets: Option<Array1<f64>>,
1503    ) -> Self {
1504        let n = age_entry.len();
1505        Self {
1506            age_entry: age_entry.to_owned(),
1507            age_exit: age_exit.to_owned(),
1508            entry_at_origin: age_entry.mapv(|t| t <= ENTRY_AT_ORIGIN_THRESHOLD),
1509            event_target: event_target.to_owned(),
1510            sampleweight: sampleweight.to_owned(),
1511            design,
1512            offset_eta_entry,
1513            offset_eta_exit,
1514            offset_derivative_exit,
1515            penalties,
1516            monotonicity,
1517            structurally_monotonic: false,
1518            structural_time_columns: 0,
1519            monotonicity_constraint_rows,
1520            monotonicity_constraint_offsets,
1521            workspace: std::sync::Mutex::new(SurvivalWorkspace::new(n)),
1522        }
1523    }
1524
1525    pub fn from_engine_inputswith_offsets(
1526        inputs: SurvivalEngineInputs<'_>,
1527        offsets: Option<SurvivalBaselineOffsets<'_>>,
1528        penalties: PenaltyBlocks,
1529        monotonicity: SurvivalMonotonicityPenalty,
1530        spec: SurvivalSpec,
1531    ) -> Result<Self, SurvivalError> {
1532        if spec == SurvivalSpec::Crude {
1533            return Err(SurvivalError::UnsupportedSpec("crude"));
1534        }
1535        let n = inputs.age_entry.len();
1536        let p = inputs.x_entry.ncols();
1537        if inputs.age_exit.len() != n
1538            || inputs.event_target.len() != n
1539            || inputs.event_competing.len() != n
1540            || inputs.sampleweight.len() != n
1541            || inputs.x_entry.nrows() != n
1542            || inputs.x_exit.nrows() != n
1543            || inputs.x_derivative.nrows() != n
1544            || inputs.x_entry.ncols() != inputs.x_exit.ncols()
1545            || inputs.x_entry.ncols() != inputs.x_derivative.ncols()
1546        {
1547            return Err(SurvivalError::DimensionMismatch);
1548        }
1549        Self::validate_penalties(&penalties, p)?;
1550        Self::validate_common_inputs(
1551            &inputs.age_entry,
1552            &inputs.age_exit,
1553            &inputs.event_target,
1554            &inputs.event_competing,
1555            &inputs.sampleweight,
1556        )?;
1557        if inputs.x_entry.iter().any(|v| !v.is_finite())
1558            || inputs.x_exit.iter().any(|v| !v.is_finite())
1559            || inputs.x_derivative.iter().any(|v| !v.is_finite())
1560        {
1561            return Err(SurvivalError::NonFiniteInput);
1562        }
1563        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1564            Self::validate_offsets(offsets, n)?;
1565        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1566            Self::validate_monotonicity_constraints(
1567                inputs.monotonicity_constraint_rows,
1568                inputs.monotonicity_constraint_offsets,
1569                p,
1570            )?;
1571
1572        Ok(Self::finish_construction(
1573            inputs.age_entry,
1574            inputs.age_exit,
1575            inputs.event_target,
1576            inputs.sampleweight,
1577            SurvivalDesign::Flat {
1578                x_entry: inputs.x_entry.to_owned(),
1579                x_exit: inputs.x_exit.to_owned(),
1580                x_derivative: inputs.x_derivative.to_owned(),
1581            },
1582            offset_eta_entry,
1583            offset_eta_exit,
1584            offset_derivative_exit,
1585            penalties,
1586            monotonicity,
1587            monotonicity_constraint_rows,
1588            monotonicity_constraint_offsets,
1589        ))
1590    }
1591
1592    pub fn from_time_covariate_inputswith_offsets(
1593        inputs: SurvivalTimeCovarInputs<'_>,
1594        offsets: Option<SurvivalBaselineOffsets<'_>>,
1595        penalties: PenaltyBlocks,
1596        monotonicity: SurvivalMonotonicityPenalty,
1597        spec: SurvivalSpec,
1598    ) -> Result<Self, SurvivalError> {
1599        if spec == SurvivalSpec::Crude {
1600            return Err(SurvivalError::UnsupportedSpec("crude"));
1601        }
1602        let n = inputs.age_entry.len();
1603        let p_time = inputs.time_entry.ncols();
1604        let p_cov = inputs.covariates.ncols();
1605        let p = p_time + p_cov;
1606        if inputs.age_exit.len() != n
1607            || inputs.event_target.len() != n
1608            || inputs.event_competing.len() != n
1609            || inputs.sampleweight.len() != n
1610            || inputs.time_entry.nrows() != n
1611            || inputs.time_exit.nrows() != n
1612            || inputs.time_derivative.nrows() != n
1613            || inputs.covariates.nrows() != n
1614            || inputs.time_entry.ncols() != inputs.time_exit.ncols()
1615            || inputs.time_entry.ncols() != inputs.time_derivative.ncols()
1616        {
1617            return Err(SurvivalError::DimensionMismatch);
1618        }
1619        Self::validate_penalties(&penalties, p)?;
1620        Self::validate_common_inputs(
1621            &inputs.age_entry,
1622            &inputs.age_exit,
1623            &inputs.event_target,
1624            &inputs.event_competing,
1625            &inputs.sampleweight,
1626        )?;
1627        if inputs.time_entry.iter().any(|v| !v.is_finite())
1628            || inputs.time_exit.iter().any(|v| !v.is_finite())
1629            || inputs.time_derivative.iter().any(|v| !v.is_finite())
1630            || inputs.covariates.iter().any(|v| !v.is_finite())
1631        {
1632            return Err(SurvivalError::NonFiniteInput);
1633        }
1634        let (offset_eta_entry, offset_eta_exit, offset_derivative_exit) =
1635            Self::validate_offsets(offsets, n)?;
1636        let (monotonicity_constraint_rows, monotonicity_constraint_offsets) =
1637            Self::validate_monotonicity_constraints(
1638                inputs.monotonicity_constraint_rows,
1639                inputs.monotonicity_constraint_offsets,
1640                p,
1641            )?;
1642
1643        Ok(Self::finish_construction(
1644            inputs.age_entry,
1645            inputs.age_exit,
1646            inputs.event_target,
1647            inputs.sampleweight,
1648            SurvivalDesign::TimeCovariateShared {
1649                time_entry: inputs.time_entry.to_owned(),
1650                time_exit: inputs.time_exit.to_owned(),
1651                time_derivative: inputs.time_derivative.to_owned(),
1652                covariates: inputs.covariates.to_owned(),
1653            },
1654            offset_eta_entry,
1655            offset_eta_exit,
1656            offset_derivative_exit,
1657            penalties,
1658            monotonicity,
1659            monotonicity_constraint_rows,
1660            monotonicity_constraint_offsets,
1661        ))
1662    }
1663
1664    /// Enable/disable monotonic time-block enforcement metadata.
1665    ///
1666    /// Monotonicity is enforced through linear inequality constraints on the
1667    /// derivative design; enabling this records how many leading time columns
1668    /// belong to that constrained block.
1669    /// Overwrite the per-block smoothing parameters `λ_k` in place.
1670    ///
1671    /// Used by the REML smoothing-parameter selection for transformation
1672    /// survival fits (issue #563): the outer optimizer proposes a `ρ = log λ`
1673    /// vector, sets the smoothing blocks' `λ_k` here, and re-runs the inner
1674    /// constrained PIRLS, so the monotone I-spline baseline can adapt its
1675    /// wiggliness instead of being pinned at a fixed seed. `lambdas` must have
1676    /// one entry per penalty block. The fixed stabilization ridge keeps its
1677    /// caller-set value (the optimizer never proposes a new one for it).
1678    pub fn set_penalty_lambdas(&mut self, lambdas: &[f64]) -> Result<(), EstimationError> {
1679        if lambdas.len() != self.penalties.blocks.len() {
1680            crate::bail_invalid_estim!(
1681                "set_penalty_lambdas expects {} lambdas, got {}",
1682                self.penalties.blocks.len(),
1683                lambdas.len()
1684            );
1685        }
1686        for (block, &lambda) in self.penalties.blocks.iter_mut().zip(lambdas.iter()) {
1687            if !lambda.is_finite() || lambda < 0.0 {
1688                crate::bail_invalid_estim!("penalty lambda must be finite and >= 0, got {lambda}");
1689            }
1690            block.lambda = lambda;
1691        }
1692        Ok(())
1693    }
1694
1695    pub fn set_structural_monotonicity(
1696        &mut self,
1697        enabled: bool,
1698        time_columns: usize,
1699    ) -> Result<(), EstimationError> {
1700        let p = self.coefficient_dim();
1701        if time_columns > p {
1702            crate::bail_invalid_estim!(
1703                "structural time columns {} exceed coefficient dimension {}",
1704                time_columns,
1705                p
1706            );
1707        }
1708        if enabled && time_columns == 0 {
1709            crate::bail_invalid_estim!("structural monotonicity requires at least one time column");
1710        }
1711        if enabled {
1712            const STRUCTURAL_DERIV_TOL: f64 = 1e-12;
1713            for (i, &offset) in self.offset_derivative_exit.iter().enumerate() {
1714                if offset < -STRUCTURAL_DERIV_TOL {
1715                    crate::bail_invalid_estim!(
1716                        "structural monotonicity requires nonnegative derivative offsets; found offset_derivative_exit[{i}]={offset:.3e}"
1717                    );
1718                }
1719            }
1720            let mut derivative_row = vec![0.0_f64; p];
1721            for i in 0..self.nrows() {
1722                self.fill_derivative_row(i, &mut derivative_row);
1723                for j in 0..time_columns {
1724                    let v = derivative_row[j];
1725                    if v < -STRUCTURAL_DERIV_TOL {
1726                        crate::bail_invalid_estim!(
1727                            "structural monotonicity requires nonnegative time-derivative basis entries; found x_derivative[{i},{j}]={v:.3e}"
1728                        );
1729                    }
1730                }
1731                for j in time_columns..p {
1732                    let v = derivative_row[j];
1733                    if v.abs() > STRUCTURAL_DERIV_TOL {
1734                        crate::bail_invalid_estim!(
1735                            "structural monotonicity requires zero derivative contribution outside the time block; found x_derivative[{i},{j}]={v:.3e}"
1736                        );
1737                    }
1738                }
1739            }
1740            if let (Some(rows), Some(offsets)) = (
1741                self.monotonicity_constraint_rows.as_ref(),
1742                self.monotonicity_constraint_offsets.as_ref(),
1743            ) {
1744                for (i, &offset) in offsets.iter().enumerate() {
1745                    if offset < -STRUCTURAL_DERIV_TOL {
1746                        crate::bail_invalid_estim!(
1747                            "structural monotonicity requires nonnegative collocation derivative offsets; found monotonicity_constraint_offsets[{i}]={offset:.3e}"
1748                        );
1749                    }
1750                }
1751                for i in 0..rows.nrows() {
1752                    for j in 0..time_columns {
1753                        let v = rows[[i, j]];
1754                        if v < -STRUCTURAL_DERIV_TOL {
1755                            crate::bail_invalid_estim!(
1756                                "structural monotonicity requires nonnegative collocation derivative basis entries; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1757                            );
1758                        }
1759                    }
1760                    for j in time_columns..p {
1761                        let v = rows[[i, j]];
1762                        if v.abs() > STRUCTURAL_DERIV_TOL {
1763                            crate::bail_invalid_estim!(
1764                                "structural monotonicity requires zero collocation derivative contribution outside the time block; found monotonicity_constraint_rows[{i},{j}]={v:.3e}"
1765                            );
1766                        }
1767                    }
1768                }
1769            }
1770        }
1771        self.structurally_monotonic = enabled;
1772        self.structural_time_columns = if enabled { time_columns } else { 0 };
1773        Ok(())
1774    }
1775
1776    pub fn update_state(&self, beta: &Array1<f64>) -> Result<WorkingState, EstimationError> {
1777        if beta.len() != self.coefficient_dim() {
1778            crate::bail_invalid_estim!("survival beta dimension mismatch");
1779        }
1780
1781        let n = self.nrows();
1782        let p = self.coefficient_dim();
1783
1784        // Royston-Parmar contract used throughout the engine:
1785        //   eta(t) = log(H(t)), where H(t) is cumulative hazard.
1786        //
1787        // With row-vectors (per subject i):
1788        //   a1_i^T := x_exit_i^T,  a0_i^T := x_entry_i^T,  d_i^T := x_derivative_i^T
1789        // and scalars:
1790        //   eta1_i = a1_i^T beta,  eta0_i = a0_i^T beta,  s_i = d_i^T beta.
1791        //
1792        // The per-subject negative log-likelihood used below is
1793        //   NLL_i(beta) = exp(eta1_i) - exp(eta0_i) - delta_i * (eta1_i + log(s_i)),
1794        // with delta_i = event_target_i.
1795        //
1796        // This is exactly the form whose derivatives are:
1797        //   grad_i = exp(eta1_i) a1_i - exp(eta0_i) a0_i - delta_i * (a1_i + d_i / s_i)
1798        //   Hess_i = exp(eta1_i) a1_i a1_i^T - exp(eta0_i) a0_i a0_i^T
1799        //            + delta_i * (d_i d_i^T) / s_i^2.
1800        //
1801        // Monotonicity is enforced through linear inequality constraints on the
1802        // derivative design. This keeps the baseline smoothing penalty on the
1803        // actual spline coefficients and preserves zero-deviation as beta=0.
1804        //
1805        // The loop below computes exact beta-space derivatives and then adds penalties.
1806        // Total predictor = target offset + learned deviation.
1807        // This is the same architecture used for flexible binary links:
1808        // principled default, plus penalized wiggle/deviation.
1809        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
1810        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
1811        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
1812
1813        let mut nll = 0.0;
1814        let derivative_guard = self.derivative_guard();
1815        let derivative_guard_numerical = self.derivative_guard_numerical();
1816        let mut workspace = self.workspace.lock().unwrap();
1817        workspace.reset(n);
1818        let SurvivalWorkspace {
1819            w_event,
1820            w_event_inv_deriv,
1821            w_event_outer,
1822            w_hess_exit,
1823            w_hess_entry,
1824        } = &mut *workspace;
1825
1826        // Phase 1: Scalar loop — compute per-observation weights, NLL, validation.
1827        for i in 0..n {
1828            let w = self.sampleweight[i];
1829            if w <= 0.0 {
1830                continue;
1831            }
1832            let entry_age = self.age_entry[i];
1833            let exit_age = self.age_exit[i];
1834            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
1835                crate::bail_invalid_estim!(
1836                    "survival ages must be finite with age_exit >= age_entry"
1837                );
1838            }
1839            let d = f64::from(self.event_target[i]);
1840
1841            let has_entry_interval = !self.entry_at_origin[i];
1842            let interval_scale = if has_entry_interval {
1843                eta_exit[i].max(eta_entry[i])
1844            } else {
1845                eta_exit[i]
1846            };
1847            let h_e_scaled = (eta_exit[i] - interval_scale).exp();
1848            let h_s_scaled = if has_entry_interval {
1849                (eta_entry[i] - interval_scale).exp()
1850            } else {
1851                0.0
1852            };
1853            let interval_scaled = h_e_scaled - h_s_scaled;
1854            let interval = Self::scaled_exp_component(interval_scale, interval_scaled)?;
1855            let deriv = self
1856                .stabilized_structural_derivative(derivative_raw[i])
1857                .unwrap_or(derivative_raw[i]);
1858            // Monotonicity of η(t) = log H(t) is a structural property of the
1859            // whole Royston-Parmar spline. If d_eta/dt is *strictly negative*
1860            // at any observed exit time, the cumulative hazard H(t) decreases
1861            // there and S(t) is not a valid survival function — both event
1862            // and censored rows have to refuse that case. Event rows further
1863            // need deriv strictly above the numerical guard because their
1864            // NLL contains `deriv.ln()` and `1.0 / deriv`; censored rows do
1865            // not, so a boundary value of exactly zero is feasible there.
1866            let mono_floor = if d > 0.0 {
1867                derivative_guard_numerical
1868            } else {
1869                0.0
1870            };
1871            if !deriv.is_finite() || deriv < mono_floor {
1872                return Err(EstimationError::ParameterConstraintViolation(format!(
1873                    "survival monotonicity violated at row {}: d_eta/dt={:.3e} <= tolerance={:.3e}",
1874                    i, deriv, derivative_guard
1875                )));
1876            }
1877            if has_entry_interval {
1878                let increment_guard = self.interval_increment_guard(h_s_scaled, h_e_scaled);
1879                if interval_scaled + increment_guard < 0.0 {
1880                    return Err(EstimationError::ParameterConstraintViolation(format!(
1881                        "survival cumulative hazard decreased over row {}: H(exit)-H(entry)={:.6e}",
1882                        i, interval
1883                    )));
1884                }
1885            }
1886            nll += w * interval;
1887
1888            // Per-observation weights for BLAS phase.
1889            // scaled_exp_component(interval_scale, h_e_scaled * x[r]) = exp(interval_scale) * h_e_scaled * x[r]
1890            // so the Hessian weight is w * exp(interval_scale) * h_e_scaled = w * exp(eta_exit).
1891            let w_exit_i = w * eta_exit[i].exp();
1892            let w_entry_i = if has_entry_interval {
1893                w * eta_entry[i].exp()
1894            } else {
1895                0.0
1896            };
1897            if !w_exit_i.is_finite() {
1898                crate::bail_invalid_estim!(
1899                    "survival interval term exceeds f64 range at row {i} (w*exp(eta_exit)={w_exit_i:.3e})"
1900                );
1901            }
1902            w_hess_exit[i] = w_exit_i;
1903            w_hess_entry[i] = w_entry_i;
1904
1905            if d > 0.0 {
1906                let inv_deriv = 1.0 / deriv;
1907                nll += -w * (eta_exit[i] + deriv.ln());
1908                w_event[i] = w;
1909                w_event_inv_deriv[i] = w * inv_deriv;
1910                w_event_outer[i] = w * inv_deriv * inv_deriv;
1911            }
1912        }
1913
1914        // Phase 2: BLAS-accelerated Hessian and gradient via faer.
1915        //   H_interval = X_exit^T diag(w_exit) X_exit - X_entry^T diag(w_entry) X_entry
1916        //   grad_interval = X_exit^T w_exit - X_entry^T w_entry
1917        let mut h = self.interval_hessian_blas(w_hess_exit, w_hess_entry);
1918        // At large smoothing penalties the event-Jacobian score nearly cancels
1919        // the interval score. Compensated row accumulation keeps the final KKT
1920        // residual accurate enough for the outer LAML envelope check.
1921        let mut grad = Array1::<f64>::zeros(p);
1922        let mut grad_comp = Array1::<f64>::zeros(p);
1923        let mut row_exit = vec![0.0_f64; p];
1924        let mut row_entry = vec![0.0_f64; p];
1925        let mut row_derivative = vec![0.0_f64; p];
1926        for i in 0..n {
1927            let w_interval_exit = w_hess_exit[i];
1928            let w_interval_entry = w_hess_entry[i];
1929            let w_event_exit = w_event[i];
1930            let w_event_derivative = w_event_inv_deriv[i];
1931            if w_interval_exit == 0.0
1932                && w_interval_entry == 0.0
1933                && w_event_exit == 0.0
1934                && w_event_derivative == 0.0
1935            {
1936                continue;
1937            }
1938            self.fill_exit_row(i, &mut row_exit);
1939            self.fill_entry_row(i, &mut row_entry);
1940            self.fill_derivative_row(i, &mut row_derivative);
1941            for j in 0..p {
1942                let contribution = w_interval_exit * row_exit[j]
1943                    - w_interval_entry * row_entry[j]
1944                    - w_event_exit * row_exit[j]
1945                    - w_event_derivative * row_derivative[j];
1946                let t = grad[j] + contribution;
1947                if grad[j].abs() >= contribution.abs() {
1948                    grad_comp[j] += (grad[j] - t) + contribution;
1949                } else {
1950                    grad_comp[j] += (contribution - t) + grad[j];
1951                }
1952                grad[j] = t;
1953            }
1954        }
1955        grad += &grad_comp;
1956
1957        h += &self.derivative_xt_diag_x(w_event_outer);
1958
1959        // Norm of the unpenalized score, captured before adding the penalty
1960        // and ridge contributions, for the scale-invariant convergence
1961        // certificate (||score||_2 + ||S*beta||_2 (+ ridge*||beta||_2)).
1962        let score_norm = array1_l2_norm(&grad);
1963
1964        let penaltygrad = self.penalties.gradient(beta);
1965        let penalty_dev = self.penalties.deviance(beta);
1966        let penaltygrad_norm = array1_l2_norm(&penaltygrad);
1967
1968        let mut totalgrad = grad;
1969        totalgrad += &penaltygrad;
1970
1971        self.penalties.addhessian_inplace(&mut h);
1972        // SURVIVAL_STABILIZATION_RIDGE is an `ExplicitPrior`-kind
1973        // stabilization in the canonical ledger taxonomy
1974        // (`gam_problem::StabilizationKind::ExplicitPrior`): δ enters the
1975        // gradient (`grad += δ β`), the Hessian (`H += δ I`), the scalar
1976        // penalty term added to the objective (`0.5 δ ‖β‖²`), and is
1977        // serialized through `WorkingState::ridge_used` so downstream
1978        // covariance and survival_ridge_lambda accounting remain
1979        // consistent. The canonical ledger record is
1980        //   StabilizationLedger::explicit_prior(δ, RidgeMatrixForm::ScaledIdentity)
1981        // chosen_by = FixedConstant. Coordinated with main.rs
1982        // `survival_ridge_lambda` field.
1983        const SURVIVAL_STABILIZATION_RIDGE: f64 = 1e-8;
1984        let ridge_used = SURVIVAL_STABILIZATION_RIDGE;
1985        for d in 0..p {
1986            h[[d, d]] += ridge_used;
1987        }
1988        totalgrad += &beta.mapv(|v| ridge_used * v);
1989        // Keep scalar objective term consistent with:
1990        //   grad += ridge * beta,  Hess += ridge * I
1991        // which correspond to 0.5 * ridge * ||beta||^2.
1992        let ridge_penalty = 0.5 * ridge_used * beta.dot(beta);
1993        let ridge_grad_norm = ridge_used * array1_l2_norm(beta);
1994
1995        let log_likelihood = -nll;
1996        let deviance = 2.0 * nll;
1997
1998        Ok(WorkingState {
1999            eta: LinearPredictor::new(eta_exit),
2000            gradient: totalgrad,
2001            hessian: gam_linalg::matrix::SymmetricMatrix::Dense(h),
2002            log_likelihood,
2003            deviance,
2004            penalty_term: penalty_dev + ridge_penalty,
2005            firth: gam_solve::pirls::FirthDiagnostics::Inactive,
2006            ridge_used,
2007            hessian_curvature: gam_solve::pirls::HessianCurvatureKind::Observed,
2008            gradient_natural_scale: score_norm + penaltygrad_norm + ridge_grad_norm,
2009        })
2010    }
2011
2012    /// Compute the third-derivative correction matrix for a given mode response `u_k`.
2013    ///
2014    /// This is the directional derivative of the unpenalized NLL Hessian w.r.t.
2015    /// beta along direction `u_k = -H^{-1} A_k beta_hat`. The returned matrix B
2016    /// satisfies `dH/drho_k = A_k + B`.
2017    ///
2018    /// Called via [`SurvivalDerivProvider`] which adapts the sign convention
2019    /// from the unified `HessianDerivativeProvider` trait (positive `v_k`) to
2020    /// the negated `u_k` used here.
2021    pub(crate) fn survival_hessian_derivative_correction(
2022        &self,
2023        beta: &Array1<f64>,
2024        u_k: &Array1<f64>,
2025    ) -> Result<Array2<f64>, EstimationError> {
2026        let p = beta.len();
2027        let n = self.nrows();
2028
2029        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2030        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2031        let deriv_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2032        let exp_entry = eta_entry.mapv(f64::exp);
2033        let exp_exit = eta_exit.mapv(f64::exp);
2034        let guard = self.derivative_guard();
2035        let guard_numerical = self.derivative_guard_numerical();
2036
2037        let jac = Array1::<f64>::ones(p);
2038        let curvature = Array1::<f64>::zeros(p);
2039        let third = Array1::<f64>::zeros(p);
2040
2041        let mut row_exit = vec![0.0_f64; p];
2042        let mut row_entry = vec![0.0_f64; p];
2043        let mut row_derivative = vec![0.0_f64; p];
2044        let mut ge = vec![0.0_f64; p];
2045        let mut gs = vec![0.0_f64; p];
2046        let mut gsd = vec![0.0_f64; p];
2047        let mut he = vec![0.0_f64; p];
2048        let mut hs = vec![0.0_f64; p];
2049        let mut hsd = vec![0.0_f64; p];
2050        let mut te = vec![0.0_f64; p];
2051        let mut ts = vec![0.0_f64; p];
2052        let mut tsd = vec![0.0_f64; p];
2053
2054        let mut b_dir = Array2::<f64>::zeros((p, p));
2055
2056        for i in 0..n {
2057            let w_i = self.sampleweight[i];
2058            if w_i <= 0.0 {
2059                continue;
2060            }
2061            let has_entry = !self.entry_at_origin[i];
2062            let mut deta_e = 0.0_f64;
2063            let mut deta_s = 0.0_f64;
2064            let mut ds = 0.0_f64;
2065            self.fill_exit_row(i, &mut row_exit);
2066            self.fill_entry_row(i, &mut row_entry);
2067            self.fill_derivative_row(i, &mut row_derivative);
2068            for j in 0..p {
2069                ge[j] = row_exit[j] * jac[j];
2070                gs[j] = row_entry[j] * jac[j];
2071                gsd[j] = row_derivative[j] * jac[j];
2072                he[j] = row_exit[j] * curvature[j];
2073                hs[j] = row_entry[j] * curvature[j];
2074                hsd[j] = row_derivative[j] * curvature[j];
2075                te[j] = row_exit[j] * third[j];
2076                ts[j] = row_entry[j] * third[j];
2077                tsd[j] = row_derivative[j] * third[j];
2078                deta_e += ge[j] * u_k[j];
2079                if has_entry {
2080                    deta_s += gs[j] * u_k[j];
2081                }
2082                ds += gsd[j] * u_k[j];
2083            }
2084
2085            // Interval part: d/dbeta [ exp(eta) * (g g^T + diag(h)) ][u_k]
2086            for r in 0..p {
2087                let dge_r = he[r] * u_k[r];
2088                let dgs_r = hs[r] * u_k[r];
2089                let dhe_r = te[r] * u_k[r];
2090                let dhs_r = ts[r] * u_k[r];
2091                for c in 0..p {
2092                    let dge_c = he[c] * u_k[c];
2093                    let dgs_c = hs[c] * u_k[c];
2094                    let mut d_h_rc =
2095                        exp_exit[i] * (deta_e * ge[r] * ge[c] + dge_r * ge[c] + ge[r] * dge_c);
2096                    if r == c {
2097                        d_h_rc += exp_exit[i] * (deta_e * he[r] + dhe_r);
2098                    }
2099                    if has_entry {
2100                        d_h_rc -=
2101                            exp_entry[i] * (deta_s * gs[r] * gs[c] + dgs_r * gs[c] + gs[r] * dgs_c);
2102                        if r == c {
2103                            d_h_rc -= exp_entry[i] * (deta_s * hs[r] + dhs_r);
2104                        }
2105                    }
2106                    b_dir[[r, c]] += w_i * d_h_rc;
2107                }
2108            }
2109
2110            // Event part: d/dbeta [ gsd gsd^T / s^2 - diag(he) - diag(hsd / s) ][u_k]
2111            let s_i = self
2112                .stabilized_structural_derivative(deriv_raw[i])
2113                .unwrap_or(deriv_raw[i]);
2114            if !s_i.is_finite() {
2115                return Err(EstimationError::ParameterConstraintViolation(format!(
2116                    "survival monotonicity violated in unified trace contraction at row {i}: \
2117                     d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2118                )));
2119            }
2120            if self.event_target[i] > 0 {
2121                if s_i < guard_numerical {
2122                    return Err(EstimationError::ParameterConstraintViolation(format!(
2123                        "survival monotonicity violated in unified trace contraction at row {i}: \
2124                         d_eta/dt={s_i:.3e} <= tolerance={guard:.3e}",
2125                    )));
2126                }
2127                let inv_s = 1.0 / s_i;
2128                let inv_s2 = inv_s * inv_s;
2129                let inv_s3 = inv_s2 * inv_s;
2130                for r in 0..p {
2131                    let dgd_r = hsd[r] * u_k[r];
2132                    let dtsd_r = tsd[r] * u_k[r];
2133                    let dte_r = te[r] * u_k[r];
2134                    for c in 0..p {
2135                        let dgd_c = hsd[c] * u_k[c];
2136                        let mut d_h_rc = (dgd_r * gsd[c] + gsd[r] * dgd_c) * inv_s2
2137                            - 2.0 * gsd[r] * gsd[c] * ds * inv_s3;
2138                        if r == c {
2139                            d_h_rc += -dte_r;
2140                            d_h_rc += -(dtsd_r * inv_s - hsd[r] * ds * inv_s2);
2141                        }
2142                        b_dir[[r, c]] += w_i * d_h_rc;
2143                    }
2144                }
2145            }
2146        }
2147
2148        Ok(b_dir)
2149    }
2150
2151    /// Per-observation gradients of the unpenalized survival NLL with respect
2152    /// to each additive offset channel, at the given β.
2153    ///
2154    /// Contract (Royston-Parmar, eta = log H(t)):
2155    ///
2156    ///   NLL_i(β; o_E, o_X, o_D) = w_i · [
2157    ///       exp(η1_i) − 1{has_entry}·exp(η0_i)
2158    ///       − δ_i · (η1_i + log s_i)
2159    ///   ]
2160    ///
2161    /// with η1_i = a1_iᵀβ + o_X[i], η0_i = a0_iᵀβ + o_E[i],
2162    ///      s_i  = d_iᵀβ + o_D[i].
2163    ///
2164    /// The additive offsets enter each of the three η channels linearly, so
2165    ///   ∂NLL_i/∂o_X[i] = w_i · (exp(η1_i) − δ_i)
2166    ///   ∂NLL_i/∂o_E[i] = −w_i · exp(η0_i) · 1{has_entry_interval}
2167    ///   ∂NLL_i/∂o_D[i] = −w_i · δ_i / s_i         (event-row only)
2168    ///
2169    /// These three arrays are the sampleweight-scaled residuals used to chain
2170    /// `∂NLL/∂offset` into `∂NLL/∂θ` via any closed-form `∂offset/∂θ` map
2171    /// (see `baseline_offset_theta_partials` for parametric baselines). At
2172    /// converged β*, the envelope theorem on the penalized objective gives
2173    ///
2174    ///   d[0.5·(deviance + β*ᵀS_λβ*)] / dθ
2175    ///     = Σᵢ r_X_i·∂o_X_i/∂θ + r_E_i·∂o_E_i/∂θ + r_D_i·∂o_D_i/∂θ
2176    ///
2177    /// exactly (no IFT back-solve required), because β* is a stationary point
2178    /// of the penalized objective wrt β and the penalty has no θ dependence.
2179    ///
2180    /// Rows with `sampleweight[i] ≤ 0` and non-event rows for `r_D` are
2181    /// returned as exact 0.0 so the output can be dot-producted against a
2182    /// per-obs baseline-partials array without a mask.
2183    ///
2184    /// Structural-monotonicity stabilization on `s_i` (see
2185    /// `stabilized_structural_derivative`) is applied identically to the
2186    /// existing `update_state` path so the residual agrees with the
2187    /// NLL that `update_state` evaluates.
2188    pub fn offset_channel_residuals(
2189        &self,
2190        beta: &Array1<f64>,
2191    ) -> Result<OffsetChannelResiduals, EstimationError> {
2192        if beta.len() != self.coefficient_dim() {
2193            crate::bail_invalid_estim!(
2194                "survival beta dimension mismatch in offset_channel_residuals"
2195            );
2196        }
2197        let n = self.nrows();
2198        let eta_entry = self.entry_dot(beta) + &self.offset_eta_entry;
2199        let eta_exit = self.exit_dot(beta) + &self.offset_eta_exit;
2200        let derivative_raw = self.derivative_dot(beta) + &self.offset_derivative_exit;
2201
2202        let derivative_guard_numerical = self.derivative_guard_numerical();
2203        let mut r_exit = Array1::<f64>::zeros(n);
2204        let mut r_entry = Array1::<f64>::zeros(n);
2205        let mut r_deriv = Array1::<f64>::zeros(n);
2206
2207        for i in 0..n {
2208            let w = self.sampleweight[i];
2209            if w <= 0.0 {
2210                continue;
2211            }
2212            let entry_age = self.age_entry[i];
2213            let exit_age = self.age_exit[i];
2214            if !entry_age.is_finite() || !exit_age.is_finite() || exit_age < entry_age {
2215                crate::bail_invalid_estim!(
2216                    "survival ages must be finite with age_exit >= age_entry"
2217                );
2218            }
2219            let has_entry_interval = !self.entry_at_origin[i];
2220            let d = f64::from(self.event_target[i]);
2221            // Phase-1 values matching update_state:
2222            //   w_exit_i  = w · exp(eta_exit)                    → ∂NLL/∂o_X before − δ·w term
2223            //   w_entry_i = w · exp(eta_entry) · 1{has_entry}    → matches −∂NLL/∂o_E sign
2224            let w_exit_i = w * eta_exit[i].exp();
2225            let w_entry_i = if has_entry_interval {
2226                w * eta_entry[i].exp()
2227            } else {
2228                0.0
2229            };
2230            if !w_exit_i.is_finite() {
2231                crate::bail_invalid_estim!(
2232                    "offset_channel_residuals: w*exp(eta_exit)={w_exit_i:.3e} non-finite at row {i}"
2233                );
2234            }
2235            r_exit[i] = w_exit_i - d * w;
2236            r_entry[i] = -w_entry_i;
2237            // Same per-row monotonicity rule as `update_state`: a strictly
2238            // negative derivative at any observed exit time (event or
2239            // censored) falsifies S(t); event rows additionally need
2240            // `deriv > guard` because `1/deriv` enters their score.
2241            let deriv_raw = derivative_raw[i];
2242            let deriv = self
2243                .stabilized_structural_derivative(deriv_raw)
2244                .unwrap_or(deriv_raw);
2245            let mono_floor = if d > 0.0 {
2246                derivative_guard_numerical
2247            } else {
2248                0.0
2249            };
2250            if !deriv.is_finite() || deriv < mono_floor {
2251                return Err(EstimationError::ParameterConstraintViolation(format!(
2252                    "offset_channel_residuals: derivative ≤ numerical guard at row {i}: {deriv:.3e}"
2253                )));
2254            }
2255            if d > 0.0 {
2256                r_deriv[i] = -w * d / deriv;
2257            }
2258        }
2259
2260        let right = Array1::<f64>::zeros(r_exit.len());
2261        Ok(OffsetChannelResiduals {
2262            exit: r_exit,
2263            entry: r_entry,
2264            derivative: r_deriv,
2265            right,
2266        })
2267    }
2268
2269    /// Build an [`InnerSolution`](gam_solve::estimate::reml::reml_outer_engine::InnerSolution) from
2270    /// the survival working state, suitable for the unified REML/LAML evaluator.
2271    ///
2272    /// Evaluate the survival outer objective and gradient via the unified REML/LAML
2273    /// evaluator, using the canonical assembly module.
2274    pub fn unified_lamlobjective_and_rhogradient(
2275        &self,
2276        beta: &Array1<f64>,
2277        state: &WorkingState,
2278        rho: &Array1<f64>,
2279    ) -> Result<(f64, Array1<f64>), EstimationError> {
2280        use gam_solve::estimate::reml::assembly::{
2281            InnerAssembly, PenaltyBlockDesc, penalty_coords_from_blocks,
2282        };
2283        use gam_solve::estimate::reml::reml_outer_engine::{
2284            DenseSpectralOperator, DispersionHandling, PenaltyLogdetDerivs,
2285            compute_block_penalty_logdet_derivs,
2286        };
2287        use gam_problem::EvalMode;
2288
2289        let p = beta.len();
2290        let active_penalty_blocks: Vec<&PenaltyBlock> = self
2291            .penalties
2292            .blocks
2293            .iter()
2294            .filter(|b| b.lambda > 0.0)
2295            .collect();
2296        if rho.len() != active_penalty_blocks.len() {
2297            crate::bail_invalid_estim!(
2298                "survival LAML rho dimension {} does not match active penalty block count {}",
2299                rho.len(),
2300                active_penalty_blocks.len()
2301            );
2302        }
2303        let k_count = active_penalty_blocks.len();
2304
2305        // --- Hessian operator ---
2306        let h_dense = state.hessian.to_dense();
2307        let hop = DenseSpectralOperator::from_symmetric(&h_dense)
2308            .map_err(EstimationError::InvalidInput)?;
2309
2310        // --- Penalty coordinates via shared assembler helper ---
2311        let block_descs: Vec<PenaltyBlockDesc> = self
2312            .penalties
2313            .blocks
2314            .iter()
2315            .filter(|b| b.lambda > 0.0)
2316            .map(|b| PenaltyBlockDesc {
2317                matrix: &b.matrix,
2318                range_start: b.range.start,
2319                range_end: b.range.end,
2320            })
2321            .collect();
2322        let penalty_coords =
2323            penalty_coords_from_blocks(&block_descs, p).map_err(EstimationError::InvalidInput)?;
2324
2325        // --- Penalty logdet derivatives ---
2326        let per_block_rho: Vec<Array1<f64>> =
2327            rho.iter().map(|&r| Array1::from_vec(vec![r])).collect();
2328        let per_block_penalty_matrices: Vec<Vec<Array2<f64>>> = active_penalty_blocks
2329            .iter()
2330            .map(|b| vec![b.matrix.clone()])
2331            .collect();
2332        let per_block_penalty_refs: Vec<&[Array2<f64>]> = per_block_penalty_matrices
2333            .iter()
2334            .map(|v| v.as_slice())
2335            .collect();
2336        let penalty_logdet = if k_count > 0 {
2337            compute_block_penalty_logdet_derivs(&per_block_rho, &per_block_penalty_refs, 0.0)
2338                .map_err(EstimationError::InvalidInput)?
2339        } else {
2340            PenaltyLogdetDerivs {
2341                value: 0.0,
2342                first: Array1::zeros(0),
2343                second: Some(Array2::zeros((0, 0))),
2344            }
2345        };
2346
2347        // penalty_quadratic = 2 * penalty_term (matching unified evaluator convention).
2348        let penalty_quadratic = 2.0 * state.penalty_term;
2349        let provider = SurvivalDerivProvider::new(self.clone(), beta.clone());
2350
2351        // #931 survival-LAML IFT envelope: attach the one-step Newton correction
2352        // only when this state is actually a near-stationary inner solution.
2353        // `unified_lamlobjective_and_rhogradient` is also used by algebraic
2354        // fixed-beta objective tests; feeding a large non-stationary residual
2355        // there makes the value a different surface. The re-converged shim
2356        // polishes the inner mode to an absolute residual floor, so certified
2357        // states still keep the envelope correction while arbitrary beta probes
2358        // evaluate the documented LAML objective.
2359        //
2360        // The residual MUST be the active-set-projected stationarity vector, not
2361        // raw `state.gradient`: a binding monotonicity constraint contributes a
2362        // Lagrange-multiplier normal component (`r = A^T lambda`, lambda >= 0)
2363        // that is not a stationarity residual.
2364        const SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE: f64 = 1.0e-8;
2365        let kkt_residual = {
2366            let raw = state.gradient.clone();
2367            let projected = match self.monotonicity_linear_constraints() {
2368                Some(constraints) => {
2369                    projected_linear_constraint_stationarity_vector(&raw, beta, &constraints, None)
2370                        .ok_or_else(|| {
2371                            EstimationError::InvalidInput(
2372                                "survival LAML could not project the monotonicity KKT residual"
2373                                    .to_string(),
2374                            )
2375                        })?
2376                }
2377                None => raw,
2378            };
2379            let projected_norm = array1_l2_norm(&projected);
2380            let relative_projected_norm = state.relative_gradient_norm(projected_norm);
2381            if relative_projected_norm <= SURVIVAL_LAML_IFT_RELATIVE_KKT_GATE {
2382                Some(crate::model_types::ProjectedKktResidual::from_active_projected(projected))
2383            } else {
2384                None
2385            }
2386        };
2387
2388        let result = InnerAssembly {
2389            log_likelihood: state.log_likelihood,
2390            penalty_quadratic,
2391            beta: beta.clone(),
2392            n_observations: self.nrows(),
2393            hessian_op: std::sync::Arc::new(hop),
2394            penalty_coords,
2395            penalty_logdet,
2396            dispersion: DispersionHandling::Fixed {
2397                phi: 1.0,
2398                include_logdet_h: true,
2399                include_logdet_s: true,
2400            },
2401            rho_curvature_scale: 1.0,
2402            rho_prior: gam_problem::RhoPrior::Flat,
2403            hessian_logdet_correction: 0.0,
2404            penalty_subspace_trace: None,
2405            deriv_provider: Some(Box::new(provider)),
2406            firth: None,
2407            nullspace_dim: None,
2408            barrier_config: None,
2409            ext_coords: Vec::new(),
2410            ext_coord_pair_fn: None,
2411            rho_ext_pair_fn: None,
2412            fixed_drift_deriv: None,
2413            contracted_psi_second_order: None,
2414            kkt_residual,
2415            active_constraints: None,
2416        }
2417        .evaluate(
2418            rho.as_slice().expect("rho must be contiguous"),
2419            EvalMode::ValueAndGradient,
2420            None,
2421        )
2422        .map_err(EstimationError::InvalidInput)?;
2423
2424        let gradient = result.gradient.unwrap_or_else(|| Array1::zeros(rho.len()));
2425        Ok((result.cost, gradient))
2426    }
2427
2428    /// Self-contained ρ → (LAML value, analytic ρ-gradient) surface for the
2429    /// survival LAML objective.
2430    ///
2431    /// Unlike [`unified_lamlobjective_and_rhogradient`](Self::unified_lamlobjective_and_rhogradient),
2432    /// which takes a *pre-converged* [`WorkingState`] and `β̂` at the evaluated
2433    /// `ρ`, this shim re-converges the inner survival mode internally: it sets
2434    /// the active-block smoothing parameters to `λ = exp(ρ)`, runs the same
2435    /// constrained inner PIRLS that the survival outer loop uses
2436    /// ([`runworking_model_pirls`](gam_solve::pirls::runworking_model_pirls)), then
2437    /// evaluates the unified survival LAML value and analytic ρ-gradient at the
2438    /// re-fitted `β̂(ρ)`. The returned pair is therefore a single-source value+
2439    /// gradient surface that a caller can finite-difference by varying `ρ`
2440    /// alone — the survival counterpart of the GLM path's
2441    /// `evaluate_externalgradient` / `evaluate_externalcost_andridge`.
2442    ///
2443    /// `rho` enumerates the **active** penalty blocks (those with `λ > 0`) in
2444    /// block order, matching the convention of the unified evaluator. `beta0` is
2445    /// the inner warm-start. The behaviour is identical to the existing survival
2446    /// LAML path (set-λ → inner PIRLS → `update_state` → unified LAML); this is a
2447    /// reachability shim, not a new objective.
2448    pub fn evaluate_survival_lamlcost_and_gradient(
2449        &self,
2450        rho: &[f64],
2451        beta0: &Array1<f64>,
2452    ) -> Result<(f64, Array1<f64>), EstimationError> {
2453        let (candidate, beta) = self.reconverge_survival_inner_mode(rho, beta0)?;
2454        // Re-converged β̂(ρ); evaluate the unified survival LAML value and
2455        // analytic ρ-gradient at that mode. The ρ passed to the unified
2456        // evaluator enumerates active blocks in block order, exactly the input
2457        // convention of this shim.
2458        let rho_arr = Array1::from_vec(rho.to_vec());
2459        let state = candidate.update_state(&beta)?;
2460        candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &rho_arr)
2461    }
2462
    /// Re-converge the survival inner mode at `λ = exp(ρ)` from warm-start
    /// `beta0`, returning the λ-set model candidate and the final `β̂(ρ)`.
    ///
    /// This is the shared inner solve used by
    /// [`evaluate_survival_lamlcost_and_gradient`](Self::evaluate_survival_lamlcost_and_gradient).
    /// It runs in two stages:
    ///
    /// 1. `runworking_model_pirls` on a clone of `self` whose active-block
    ///    smoothing parameters were replaced by `exp(rho[k])`, with tight
    ///    tolerances.
    /// 2. A damped-Newton "polish" on `update_state`'s gradient: an exact
    ///    Cholesky solve of `(H + shift·I)`, a steepest-descent fallback, and a
    ///    backtracking line search. The aim is to push the absolute residual
    ///    below the finite-difference round-off floor so the envelope
    ///    ρ-gradient is valid. (Per the original notes, PIRLS alone can leave
    ///    `‖r‖ ~ 1` at large λ where H is ill-conditioned.)
    ///
    /// The polish is best-effort: it stops silently when `update_state`
    /// fails, when no trial step is accepted, or after its iteration budget,
    /// and the function still returns `Ok` with the last accepted `beta`.
    ///
    /// # Arguments
    /// * `rho`   – log smoothing parameters, one per block with `lambda > 0`
    ///   on `self`, in block order.
    /// * `beta0` – warm start; must have length `coefficient_dim()`.
    ///
    /// # Errors
    /// Dimension mismatches on `rho` / `beta0`, and any error from
    /// `set_penalty_lambdas` or `runworking_model_pirls`.
    fn reconverge_survival_inner_mode(
        &self,
        rho: &[f64],
        beta0: &Array1<f64>,
    ) -> Result<(WorkingModelSurvival, Array1<f64>), EstimationError> {
        // Inner-PIRLS settings mirror the survival transformation outer loop's
        // constrained inner solve. Tighter convergence than the production
        // outer loop so the inner mode is converged well below the FD step's
        // round-off floor, making ∇V finite-differentiable in ρ alone.
        const SHIM_PIRLS_MAX_ITERATIONS: usize = 600;
        const SHIM_PIRLS_CONVERGENCE_TOL: f64 = 1e-12;
        const SHIM_PIRLS_MAX_STEP_HALVING: usize = 40;
        const SHIM_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;

        // ρ must enumerate exactly the blocks that are active on `self`.
        let active_block_count = self
            .penalties
            .blocks
            .iter()
            .filter(|b| b.lambda > 0.0)
            .count();
        if rho.len() != active_block_count {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: rho dimension {} does not match active penalty block count {}",
                rho.len(),
                active_block_count
            );
        }
        if beta0.len() != self.coefficient_dim() {
            crate::bail_invalid_estim!(
                "reconverge_survival_inner_mode: beta0 dimension {} does not match coefficient dimension {}",
                beta0.len(),
                self.coefficient_dim()
            );
        }

        // Set λ = exp(ρ) on the active blocks (block order), leaving inactive
        // (λ = 0) blocks untouched, then re-converge the inner mode.
        // NOTE(review): if `exp(rho[k])` underflows to 0.0 the block would stop
        // satisfying `lambda > 0` on the returned candidate — confirm whether
        // `set_penalty_lambdas` guards against that.
        let mut candidate = self.clone();
        let mut lambdas: Vec<f64> = candidate
            .penalties
            .blocks
            .iter()
            .map(|b| b.lambda)
            .collect();
        // `active_idx` walks `rho` in step with the active blocks only.
        let mut active_idx = 0usize;
        for (block, lambda) in candidate.penalties.blocks.iter().zip(lambdas.iter_mut()) {
            if block.lambda > 0.0 {
                *lambda = rho[active_idx].exp();
                active_idx += 1;
            }
        }
        candidate.set_penalty_lambdas(&lambdas)?;

        // NOTE(review): `linear_constraints` and `coefficient_lower_bounds` are
        // passed as `None` although the surrounding docs speak of a
        // "constrained" inner solve — presumably the working model supplies
        // its own constraints; confirm against `runworking_model_pirls`.
        let opts = gam_solve::pirls::WorkingModelPirlsOptions {
            max_iterations: SHIM_PIRLS_MAX_ITERATIONS,
            convergence_tolerance: SHIM_PIRLS_CONVERGENCE_TOL,
            adaptive_kkt_tolerance: None,
            max_step_halving: SHIM_PIRLS_MAX_STEP_HALVING,
            min_step_size: SHIM_PIRLS_MIN_STEP_SIZE,
            firth_bias_reduction: false,
            coefficient_lower_bounds: None,
            linear_constraints: None,
            initial_lm_lambda: None,
            geodesic_acceleration: false,
            arrow_schur: None,
        };
        // The per-iteration callback is a no-op.
        let summary = gam_solve::pirls::runworking_model_pirls(
            &mut candidate,
            Coefficients::new(beta0.clone()),
            &opts,
            |_| {},
        )?;
        let mut beta = summary.beta.as_ref().to_owned();

        // Why a polish is needed (per the original design notes): PIRLS exits
        // on a RELATIVE KKT / deviance-plateau certificate, which can leave an
        // ABSOLUTE penalized stationarity residual r = S beta_hat - grad_ell of
        // order 0.1-1 (the score scales as O(sqrt(n))). The unified LAML
        // gradient uses the envelope theorem, exact only at r = 0, and the IFT
        // envelope correction is only leading-order in r, so a residual that
        // large breaks objective<->gradient consistency. The cure is to drive
        // the inner solve to true stationarity.
        //
        // What the polish below actually does, each iteration:
        //   1. evaluate state, residual r = gradient, and dense Hessian H;
        //   2. find a Newton-type step by exact Cholesky of (H + shift*I),
        //      trying shift = 0 first and then 1e-12 * h_scale * 10^k for
        //      k = 1..=17, taking the first shift whose factorization succeeds
        //      and whose step is a descent direction;
        //   3. fall back to steepest descent when no shift qualifies;
        //   4. backtrack alpha = 1, 1/2, 1/4, ... and accept the first trial
        //      that satisfies Armijo on the objective OR strictly reduces ||r||.
        // The shift is a plain identity shift (not diag(H)-scaled) and is
        // re-swept from zero every iteration; nothing is carried between
        // iterations except `beta`.
        {
            // Iteration budget and absolute residual target for the polish.
            const POLISH_MAX_ITERS: usize = 400;
            const POLISH_TOL: f64 = 1e-13;
            // Armijo sufficient-decrease constant and backtracking factor.
            const ARMIJO_C: f64 = 1e-4;
            const BACKTRACK: f64 = 0.5;
            // Upper bound on step halvings per iteration (alpha >= 0.5^79).
            const MAX_BACKTRACK: usize = 80;
            let p = beta.len();
            // Scalar objective used by the Armijo test, built from the pieces
            // `update_state` exposes. Per the original notes this is the
            // penalized inner objective f(β) = −ℓ(β) + ½β'Sβ (+ ridge term)
            // whose gradient / Hessian are `state.gradient` / `state.hessian`
            // — NOTE(review): that correspondence lives in `update_state`,
            // which is not visible here.
            let penalized_objective =
                |st: &WorkingState| -> f64 { -st.log_likelihood + st.penalty_term };
            for _ in 0..POLISH_MAX_ITERS {
                // An evaluation failure at the current iterate ends the polish
                // quietly; `beta` keeps its last accepted value.
                let st = match candidate.update_state(&beta) {
                    Ok(st) => st,
                    Err(_) => break,
                };
                let r = st.gradient.clone();
                let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
                // Done when the residual is below tolerance (or unusable).
                if !r_norm.is_finite() || r_norm < POLISH_TOL {
                    break;
                }
                let h = st.hessian.to_dense();
                let f0 = penalized_objective(&st);
                // Shift scale: largest |diagonal| entry of H, floored at 1, so
                // the shift ladder is relative to the Hessian's magnitude.
                let h_scale = (0..p)
                    .map(|d| h[[d, d]].abs())
                    .fold(0.0_f64, f64::max)
                    .max(1.0);
                // Solve (H + shift·I) step = r by an exact Cholesky
                // factorization rather than the DenseSpectralOperator. The
                // original notes record that the spectral operator's eigenvalue
                // clamping corrupted the solve on an ill-conditioned boundary
                // Hessian badly enough that rᵀH⁻¹r lost its sign. For a system
                // that factors as SPD, rᵀ(H + shift·I)⁻¹r is positive in exact
                // arithmetic; the explicit `dd` test below guards the
                // floating-point case.
                let mut step: Option<Array1<f64>> = None;
                let mut dir_deriv = 0.0_f64;
                for lm_pow in 0..18 {
                    // First attempt is the undamped Newton system.
                    let lambda_lm = if lm_pow == 0 {
                        0.0
                    } else {
                        1e-12 * h_scale * 10f64.powi(lm_pow)
                    };
                    let mut h_reg = h.clone();
                    for d in 0..p {
                        h_reg[[d, d]] += lambda_lm;
                    }
                    // Factorization failure: try the next, larger shift.
                    let factor = match gam_linalg::faer_ndarray::FaerCholesky::cholesky(
                        &h_reg,
                        faer::Side::Lower,
                    ) {
                        Ok(f) => f,
                        Err(_) => continue,
                    };
                    let candidate_step = factor.solvevec(&r);
                    if candidate_step.iter().any(|v| !v.is_finite()) {
                        continue;
                    }
                    // Directional derivative of f along d = −step:
                    // ∇fᵀd = −rᵀ(H + shift·I)⁻¹r. Require it to be negative by
                    // a margin relative to ‖r‖² before trusting the direction.
                    let dd = -r.dot(&candidate_step);
                    if dd.is_finite() && dd < -1e-14 * r_norm * r_norm {
                        step = Some(candidate_step);
                        dir_deriv = dd;
                        break;
                    }
                }
                let (step, dir_deriv) = match step {
                    Some(s) => (s, dir_deriv),
                    None => {
                        // Steepest-descent fallback: d = −r ⇒ step = +r (we step
                        // β − step), ∇fᵀd = −‖r‖² < 0.
                        (r.clone(), -r_norm * r_norm)
                    }
                };
                let mut alpha = 1.0_f64;
                let mut accepted = false;
                for _ in 0..MAX_BACKTRACK {
                    let trial = &beta - &(alpha * &step);
                    // A trial at which `update_state` errors is treated like a
                    // rejected step: halve alpha and retry.
                    if let Ok(ts) = candidate.update_state(&trial) {
                        let ft = penalized_objective(&ts);
                        let tn = ts.gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
                        // Accept on EITHER a sufficient objective decrease
                        // (Armijo) OR a strict residual-norm decrease. Per the
                        // original notes, near the solution the objective is
                        // flat to f64 round-off (f0 ≈ ft), so a pure-Armijo
                        // test backtracks α→0 and crawls; the ‖r‖-decrease arm
                        // lets the full Newton step (α=1) through so all FD
                        // evaluation points converge to the same tight floor.
                        let armijo_ok = ft.is_finite() && ft <= f0 + ARMIJO_C * alpha * dir_deriv;
                        let residual_ok = tn.is_finite() && tn < r_norm;
                        if armijo_ok || residual_ok {
                            beta = trial;
                            accepted = true;
                            break;
                        }
                    }
                    alpha *= BACKTRACK;
                }
                // No acceptable step at any alpha: stop polishing.
                if !accepted {
                    break;
                }
            }
        }

        Ok((candidate, beta))
    }
2694}
2695
2696/// Derivative provider that adapts survival third-derivative Hessian corrections
2697/// to the unified [`HessianDerivativeProvider`](gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider)
2698/// trait.
2699///
2700/// The unified trait supplies `v_k = H^{-1}(A_k beta_hat)` (positive sign),
2701/// whereas the survival engine's
2702/// [`survival_hessian_derivative_correction`](WorkingModelSurvival::survival_hessian_derivative_correction)
2703/// expects `u_k = -v_k`. This provider handles the sign conversion.
2704pub(crate) struct SurvivalDerivProvider {
2705    model: WorkingModelSurvival,
2706    beta: Array1<f64>,
2707}
2708
2709impl SurvivalDerivProvider {
2710    pub(crate) fn new(model: WorkingModelSurvival, beta: Array1<f64>) -> Self {
2711        Self { model, beta }
2712    }
2713}
2714
2715impl gam_solve::estimate::reml::reml_outer_engine::HessianDerivativeProvider for SurvivalDerivProvider {
2716    fn hessian_derivative_correction(
2717        &self,
2718        v_k: &Array1<f64>,
2719    ) -> Result<Option<Array2<f64>>, String> {
2720        // The trait provides v_k = H^{-1}(A_k beta_hat) (positive).
2721        // The survival method expects u_k = -H^{-1} A_k beta_hat = -v_k.
2722        let u_k = -v_k;
2723        match self
2724            .model
2725            .survival_hessian_derivative_correction(&self.beta, &u_k)
2726        {
2727            Ok(correction) => Ok(Some(correction)),
2728            Err(e) => Err(e.to_string()),
2729        }
2730    }
2731
2732    fn has_corrections(&self) -> bool {
2733        true
2734    }
2735}
2736
/// Output bundle of a crude-risk computation: a scalar risk plus two gradient
/// vectors.
///
/// NOTE(review): the producer of this struct is not visible in this part of
/// the file; the field notes below are inferred from names only and should be
/// confirmed against the function that fills them in.
#[derive(Debug, Clone)]
pub struct CrudeRiskResult {
    /// Scalar crude risk value.
    pub risk: f64,
    /// Gradient associated with the disease component — presumably of `risk`
    /// with respect to the disease-model parameters; TODO confirm.
    pub diseasegradient: Array1<f64>,
    /// Gradient associated with the mortality component — presumably of `risk`
    /// with respect to the mortality-model parameters; TODO confirm.
    pub mortalitygradient: Array1<f64>,
}
2743
/// Result of competing-risks cumulative-incidence assembly: one CIF matrix per
/// endpoint plus the overall survival matrix on the same `(row, time)` grid.
#[derive(Debug, Clone)]
pub struct CompetingRisksCifResult {
    /// Cumulative incidence per endpoint. `cif[ep][[row, time_idx]]` is the
    /// probability that cause `ep` occurred by `times[time_idx]` for sample `row`.
    /// Stored one matrix per endpoint so it is ergonomic to index per-cause and
    /// natural to construct from the per-endpoint cumulative-hazard inputs.
    pub cif: Vec<Array2<f64>>,
    /// Overall survival `S(t)`, shaped `(n_rows, n_times)` and indexed as
    /// `overall_survival[[row, time_idx]]`, i.e. on the same grid as each
    /// `cif` matrix.
    pub overall_survival: Array2<f64>,
}
2753
/// Subject-count threshold below which competing-risks CIF assembly stays on the
/// serial path.
///
/// Rationale: the per-row work (an `n_times`-long prefix-sum recurrence with a
/// handful of `exp`/`exp_m1` calls per element) is cheap, so small panels avoid
/// rayon fan-out overhead, while large panels (the #1082 quality-test sizes)
/// amortize it.
const COMPETING_RISKS_CIF_PARALLEL_ROW_MIN: usize = 256;
2759
2760pub fn assemble_competing_risks_cif(
2761    times: ArrayView1<'_, f64>,
2762    cumulative_hazard: ArrayView3<'_, f64>,
2763) -> Result<CompetingRisksCifResult, SurvivalError> {
2764    let (n_endpoints, n_rows, n_times) = cumulative_hazard.dim();
2765    if n_endpoints == 0 {
2766        return Err(SurvivalError::DimensionMismatch);
2767    }
2768    let endpoint_hazards = cumulative_hazard
2769        .axis_iter(Axis(0))
2770        .map(|view| view.to_owned())
2771        .collect::<Vec<_>>();
2772    assemble_competing_risks_cif_from_endpoints(times, &endpoint_hazards).and_then(|result| {
2773        if result.overall_survival.dim() != (n_rows, n_times) {
2774            Err(SurvivalError::DimensionMismatch)
2775        } else {
2776            Ok(result)
2777        }
2778    })
2779}
2780
2781pub fn assemble_competing_risks_cif_from_endpoints(
2782    times: ArrayView1<'_, f64>,
2783    cumulative_hazards: &[Array2<f64>],
2784) -> Result<CompetingRisksCifResult, SurvivalError> {
2785    let n_endpoints = cumulative_hazards.len();
2786    if n_endpoints == 0 || times.is_empty() {
2787        return Err(SurvivalError::DimensionMismatch);
2788    }
2789    let (n_rows, n_times) = cumulative_hazards[0].dim();
2790    if n_rows == 0 || n_times == 0 || times.len() != n_times {
2791        return Err(SurvivalError::DimensionMismatch);
2792    }
2793    if times.iter().any(|time| !time.is_finite() || *time < 0.0) {
2794        return Err(SurvivalError::InvalidTimeGrid);
2795    }
2796    if times
2797        .iter()
2798        .zip(times.iter().skip(1))
2799        .any(|(previous, current)| current <= previous)
2800    {
2801        return Err(SurvivalError::InvalidTimeGrid);
2802    }
2803    for endpoint_hazard in cumulative_hazards {
2804        if endpoint_hazard.dim() != (n_rows, n_times) {
2805            return Err(SurvivalError::DimensionMismatch);
2806        }
2807        if endpoint_hazard.iter().any(|value| !value.is_finite()) {
2808            return Err(SurvivalError::NonFiniteInput);
2809        }
2810    }
2811
2812    let max_abs_hazard = cumulative_hazards
2813        .iter()
2814        .flat_map(|endpoint_hazard| endpoint_hazard.iter())
2815        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2816    let monotone_tolerance = 1.0e-10_f64 * max_abs_hazard.max(1.0);
2817    let mut cif: Vec<Array2<f64>> = (0..n_endpoints)
2818        .map(|_| Array2::<f64>::zeros((n_rows, n_times)))
2819        .collect();
2820    let mut overall_survival = Array2::<f64>::zeros((n_rows, n_times));
2821
2822    // Per-row CIF assembly. The TIME axis is a sequential prefix-sum recurrence
2823    // (`previous_*` carry forward across `time_idx`) and MUST stay ordered, so it
2824    // is left as the inner serial loop. The ROW (subject) axis is fully
2825    // independent: every `previous_*`/`increments` buffer is allocated fresh per
2826    // row, no state crosses rows, and each row writes only its own disjoint
2827    // output slices. The per-row result is byte-identical regardless of which
2828    // thread runs it, so we fan the outer row loop out over rayon and write the
2829    // owned per-row buffers back serially in row order (deterministic, bit-exact
2830    // vs. the serial implementation).
2831    //
2832    // `cif_flat` is endpoint-major: `cif_flat[endpoint * n_times + time_idx]`.
2833    let assemble_row = |row: usize| -> Result<(Vec<f64>, Vec<f64>), SurvivalError> {
2834        let mut cif_flat = vec![0.0_f64; n_endpoints * n_times];
2835        let mut surv_row = vec![0.0_f64; n_times];
2836        let mut previous_cif = vec![0.0_f64; n_endpoints];
2837        let mut previous_cumulative = vec![0.0_f64; n_endpoints];
2838        let mut increments = vec![0.0_f64; n_endpoints];
2839        let mut previous_total_cumulative = 0.0_f64;
2840        for time_idx in 0..n_times {
2841            let mut total_increment = 0.0_f64;
2842            for endpoint in 0..n_endpoints {
2843                let current = cumulative_hazards[endpoint][[row, time_idx]];
2844                if current < -monotone_tolerance {
2845                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2846                }
2847                let raw_increment = current - previous_cumulative[endpoint];
2848                if raw_increment < -monotone_tolerance {
2849                    return Err(SurvivalError::NonMonotoneCumulativeHazard);
2850                }
2851                let increment = raw_increment.max(0.0);
2852                increments[endpoint] = increment;
2853                total_increment += increment;
2854                previous_cumulative[endpoint] += increment;
2855            }
2856
2857            let survival_left = (-previous_total_cumulative).exp();
2858            let interval_failure = -(-total_increment).exp_m1();
2859            for endpoint in 0..n_endpoints {
2860                if total_increment > 0.0 {
2861                    previous_cif[endpoint] +=
2862                        survival_left * interval_failure * increments[endpoint] / total_increment;
2863                }
2864                cif_flat[endpoint * n_times + time_idx] = previous_cif[endpoint].clamp(0.0, 1.0);
2865            }
2866            previous_total_cumulative += total_increment;
2867            // Derive `S(t)` from the stored cause-specific CIFs at this time so
2868            // that the competing-risks closure identity
2869            //   Σ_k F_k(t) + S(t) = 1
2870            // holds bit-exactly. Computing `S` independently as
2871            // `exp(-Σ_k H_k(t))` and then comparing against the (clamped, ratio-
2872            // split) Σ F_k introduces O(machine-eps) closure error because the
2873            // float increments
2874            //   ΔF_k = S_left·(1-exp(-ΔH))·ΔH_k/ΔH_total
2875            // do not sum to `S_left - S_new` bit-exactly. By summing the stored
2876            // CIFs in the same left-fold order as `slice.iter().sum::<f64>()`
2877            // and defining `S := 1.0 - Σ F_k`, the IEEE-754 round-trip
2878            //   (1.0 - f) + f
2879            // restores the identity for finite f ∈ [0, 1]. The mathematically
2880            // consistent survival value `exp(-H_total)` is still tracked up to
2881            // ulp-level precision because the ΔF_k construction matches
2882            // `S_left - S_new` to leading order.
2883            let mut fsum_at_t = 0.0_f64;
2884            for endpoint in 0..n_endpoints {
2885                fsum_at_t += cif_flat[endpoint * n_times + time_idx];
2886            }
2887            surv_row[time_idx] = (1.0_f64 - fsum_at_t).clamp(0.0, 1.0);
2888        }
2889        Ok((cif_flat, surv_row))
2890    };
2891
2892    // Nesting guard (`rayon::current_thread_index().is_none()`) keeps us from
2893    // oversubscribing when this routine is itself called from inside a rayon
2894    // worker, and the row-count gate keeps small inputs on the serial path.
2895    let rows: Vec<(Vec<f64>, Vec<f64>)> = if n_rows >= COMPETING_RISKS_CIF_PARALLEL_ROW_MIN
2896        && rayon::current_thread_index().is_none()
2897    {
2898        use rayon::prelude::*;
2899        (0..n_rows)
2900            .into_par_iter()
2901            .map(assemble_row)
2902            .collect::<Result<_, _>>()?
2903    } else {
2904        (0..n_rows).map(assemble_row).collect::<Result<_, _>>()?
2905    };
2906
2907    for (row, (cif_flat, surv_row)) in rows.into_iter().enumerate() {
2908        for endpoint in 0..n_endpoints {
2909            for time_idx in 0..n_times {
2910                cif[endpoint][[row, time_idx]] = cif_flat[endpoint * n_times + time_idx];
2911            }
2912        }
2913        for time_idx in 0..n_times {
2914            overall_survival[[row, time_idx]] = surv_row[time_idx];
2915        }
2916    }
2917
2918    Ok(CompetingRisksCifResult {
2919        cif,
2920        overall_survival,
2921    })
2922}
2923
/// Builds the `n`-point Gauss-Legendre rule on `[-1, 1]`.
///
/// Returns `(node, weight)` pairs sorted by ascending node. Only the
/// `n.div_ceil(2)` non-negative-side roots are solved for; each is mirrored
/// to its negative counterpart, and for odd `n` the middle root is emitted
/// once as exactly `0.0`. `n == 0` yields an empty vector.
fn compute_gauss_legendre_nodes(n: usize) -> Vec<(f64, f64)> {
    /// Evaluates the degree-`order` Legendre polynomial and its derivative at
    /// `z` via the three-term recurrence.
    fn legendre_with_derivative(order: usize, z: f64) -> (f64, f64) {
        let mut current = 1.0;
        let mut previous = 0.0;
        for k in 0..order {
            let older = previous;
            previous = current;
            current = ((2.0 * k as f64 + 1.0) * z * previous - k as f64 * older) / (k as f64 + 1.0);
        }
        let derivative = order as f64 * (z * current - previous) / (z * z - 1.0);
        (current, derivative)
    }

    let half_count = n.div_ceil(2);
    let has_center_node = n % 2 == 1;
    let mut rule = Vec::with_capacity(n);

    for root_idx in 0..half_count {
        // Cosine starting guess for the `root_idx`-th root, counted from +1.
        let mut root =
            (std::f64::consts::PI * (root_idx as f64 + 0.75) / (n as f64 + 0.5)).cos();
        let mut slope = 0.0;

        // Newton refinement, capped at 100 sweeps; stops early once the
        // update falls below 1e-14.
        for _ in 0..100 {
            let (value, derivative) = legendre_with_derivative(n, root);
            slope = derivative;
            let refined = root - value / derivative;
            let step = (refined - root).abs();
            root = refined;
            if step < 1e-14 {
                break;
            }
        }

        // The weight pairs the final root with the derivative from the last
        // Newton sweep, i.e. the one evaluated just before the final update.
        let weight = 2.0 / ((1.0 - root * root) * slope * slope);
        if has_center_node && root_idx + 1 == half_count {
            rule.push((0.0, weight));
        } else {
            rule.push((-root, weight));
            rule.push((root, weight));
        }
    }

    rule.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    rule
}
2961
2962fn gauss_legendre_quadrature() -> &'static [(f64, f64)] {
2963    // `LazyLock` (not `OnceLock::get_or_init`) so first init never parks a
2964    // caller on the OS condvar from inside a rayon worker. The competing-risks
2965    // CIF assembler in this file dispatches `into_par_iter` and the
2966    // codebase-level lint (`tests/once_lock_get_or_init_not_inside_parallel_regions.rs`)
2967    // forbids the lazy `OnceLock` accessor in any rayon-adjacent file.
2968    static CACHE: LazyLock<Vec<(f64, f64)>> = LazyLock::new(|| compute_gauss_legendre_nodes(40));
2969    &CACHE
2970}
2971
2972/// Engine-level crude risk quadrature with exact delta-method gradients.
2973///
2974/// This routine owns the numerical integration and gradient accumulation math:
2975/// - It integrates `h_d(u) * S_total(u | t0)` over `[t0, t1]` by high-order
2976///   Gauss-Legendre quadrature.
2977/// - It computes gradients w.r.t. disease and mortality coefficients:
2978///   d Risk / d beta_d and d Risk / d beta_m.
2979///
2980/// The adapter provides the domain-specific point evaluator callback `eval_at`,
2981/// which fills design rows and returns:
2982/// - instantaneous disease hazard h_d(u) at age `u`,
2983/// - cumulative disease hazard `H_d(u)`,
2984/// - cumulative mortality hazard `H_m(u)`.
2985///
2986/// The callback must fill the following arrays (one entry per coefficient):
2987/// - `design_d[j]`: partial derivative of the linear predictor eta_d w.r.t. beta_j
2988///   at time u, i.e. x_j(u) = d eta_d(u) / d beta_j.
2989/// - `deriv_d[j]`: partial derivative of the TIME DERIVATIVE of eta_d w.r.t. beta_j
2990///   at time u, i.e. x_dot_j(u) = d/d(beta_j) [d eta_d(u)/du].
2991/// - `design_m[j]`: same as design_d but for the mortality linear predictor eta_m.
2992///
2993/// This keeps domain/data wiring out of `gam` while centralizing the
2994/// integration engine in one place.
2995pub fn calculate_crude_risk_quadrature<F>(
2996    t0: f64,
2997    t1: f64,
2998    breakpoints: &[f64],
2999    h_dis_t0: f64,
3000    h_mor_t0: f64,
3001    design_d_t0: ArrayView1<'_, f64>,
3002    design_m_t0: ArrayView1<'_, f64>,
3003    mut eval_at: F,
3004) -> Result<CrudeRiskResult, SurvivalError>
3005where
3006    F: FnMut(
3007        f64,
3008        &mut Array1<f64>,
3009        &mut Array1<f64>,
3010        &mut Array1<f64>,
3011    ) -> Result<(f64, f64, f64), SurvivalError>,
3012{
3013    let coeff_len_d = design_d_t0.len();
3014    let coeff_len_m = design_m_t0.len();
3015    if coeff_len_d == 0 || coeff_len_m == 0 {
3016        return Err(SurvivalError::InvalidIntegrationSetup);
3017    }
3018    if !t0.is_finite()
3019        || !t1.is_finite()
3020        || !h_dis_t0.is_finite()
3021        || !h_mor_t0.is_finite()
3022        || design_d_t0.iter().any(|v| !v.is_finite())
3023        || design_m_t0.iter().any(|v| !v.is_finite())
3024    {
3025        return Err(SurvivalError::NonFiniteInput);
3026    }
3027    if t1 <= t0 {
3028        return Ok(CrudeRiskResult {
3029            risk: 0.0,
3030            diseasegradient: Array1::zeros(coeff_len_d),
3031            mortalitygradient: Array1::zeros(coeff_len_m),
3032        });
3033    }
3034
3035    let mut sorted_breaks: Vec<f64> = breakpoints
3036        .iter()
3037        .copied()
3038        .filter(|x| x.is_finite() && *x >= t0 && *x <= t1)
3039        .collect();
3040    sorted_breaks.push(t0);
3041    sorted_breaks.push(t1);
3042    sorted_breaks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3043    sorted_breaks.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
3044    if sorted_breaks.len() < 2 {
3045        return Err(SurvivalError::InvalidIntegrationSetup);
3046    }
3047
3048    let mut total_risk = 0.0;
3049    let mut diseasegradient = Array1::zeros(coeff_len_d);
3050    let mut mortalitygradient = Array1::zeros(coeff_len_m);
3051    let nodesweights = gauss_legendre_quadrature();
3052
3053    let mut design_d = Array1::<f64>::zeros(coeff_len_d);
3054    let mut deriv_d = Array1::<f64>::zeros(coeff_len_d);
3055    let mut design_m = Array1::<f64>::zeros(coeff_len_m);
3056
3057    for segment in sorted_breaks.windows(2) {
3058        let a = segment[0];
3059        let b = segment[1];
3060        let center = 0.5 * (b + a);
3061        let halfwidth = 0.5 * (b - a);
3062        if halfwidth <= 0.0 {
3063            continue;
3064        }
3065
3066        for &(x, w) in nodesweights {
3067            let u = center + halfwidth * x;
3068            let (inst_hazard_d, hazard_d, hazard_m) =
3069                eval_at(u, &mut design_d, &mut deriv_d, &mut design_m)?;
3070            if !inst_hazard_d.is_finite() || !hazard_d.is_finite() || !hazard_m.is_finite() {
3071                return Err(SurvivalError::NonFiniteInput);
3072            }
3073            if inst_hazard_d <= 0.0 {
3074                return Err(SurvivalError::NonPositiveHazard);
3075            }
3076
3077            if hazard_d < h_dis_t0 || hazard_m < h_mor_t0 {
3078                return Err(SurvivalError::NonMonotoneCumulativeHazard);
3079            }
3080
3081            let h_dis_cond = hazard_d - h_dis_t0;
3082            let h_mor_cond = hazard_m - h_mor_t0;
3083            let s_total = (-(h_dis_cond + h_mor_cond)).exp();
3084
3085            total_risk += w * inst_hazard_d * s_total * halfwidth;
3086
3087            // d Risk / d beta_d:
3088            //   integral [ d h_d * S_total - h_d * S_total * d H_d ] du
3089            // Contract: design_d[j] = x_j(u) = ∂_{β_j} η_d(u)
3090            //           deriv_d[j]  = ẋ_j(u) = ∂_{β_j} η̇_d(u)
3091            // Then ∂_{β_j} h_d = h_d · x_j + H_d · ẋ_j
3092            let weight = w * s_total * halfwidth;
3093            for j in 0..coeff_len_d {
3094                let d_inst_hazard = inst_hazard_d * design_d[j] + hazard_d * deriv_d[j];
3095                let d_hazard_cond = hazard_d * design_d[j] - h_dis_t0 * design_d_t0[j];
3096                let g = d_inst_hazard - inst_hazard_d * d_hazard_cond;
3097                diseasegradient[j] += weight * g;
3098            }
3099
3100            // d Risk / d beta_m:
3101            //   -integral h_d * S_total * d H_m(u|t0) du
3102            let weight = w * inst_hazard_d * s_total * halfwidth;
3103            for j in 0..coeff_len_m {
3104                let g = -hazard_m * design_m[j] + h_mor_t0 * design_m_t0[j];
3105                mortalitygradient[j] += weight * g;
3106            }
3107        }
3108    }
3109
3110    Ok(CrudeRiskResult {
3111        risk: total_risk,
3112        diseasegradient,
3113        mortalitygradient,
3114    })
3115}
3116
3117impl PirlsWorkingModel for WorkingModelSurvival {
3118    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError> {
3119        self.update_state(beta)
3120    }
3121}
3122
3123#[cfg(test)]
3124mod tests {
3125    use super::*;
3126    use ndarray::{Array1, Array2, Array3, array, s};
3127
#[test]
fn competing_risks_cif_constant_hazard_matches_closed_form() {
    // Two endpoints (disease, death) with constant per-row rates, so each
    // cumulative hazard is `rate * t` and the CIF has the closed form
    // `rate_k / rate_total * (1 - exp(-rate_total * t))`.
    let times = array![0.0, 2.0, 5.0, 10.0];
    let disease_rates = [0.12, 0.06];
    let death_rates = [0.05, 0.02];
    let cumulative = Array3::from_shape_fn((2, 2, times.len()), |(endpoint, row, time_idx)| {
        let rate = match endpoint {
            0 => disease_rates[row],
            _ => death_rates[row],
        };
        rate * times[time_idx]
    });

    let result =
        assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");

    let per_row_rates = disease_rates.iter().zip(death_rates.iter());
    for (row, (&lambda_disease, &lambda_death)) in per_row_rates.enumerate() {
        let lambda_total = lambda_disease + lambda_death;
        for (time_idx, &t) in times.iter().enumerate() {
            let failure = 1.0 - (-lambda_total * t).exp();
            let expected_disease = lambda_disease / lambda_total * failure;
            let expected_death = lambda_death / lambda_total * failure;
            assert!((result.cif[0][[row, time_idx]] - expected_disease).abs() < 1e-12);
            assert!((result.cif[1][[row, time_idx]] - expected_death).abs() < 1e-12);
            // Probability mass must close: F_disease + F_death + S = 1.
            assert!(
                (result.cif[0][[row, time_idx]]
                    + result.cif[1][[row, time_idx]]
                    + result.overall_survival[[row, time_idx]]
                    - 1.0)
                    .abs()
                    < 1e-12
            );
        }
    }
}
3164
#[test]
fn competing_risks_cif_rejects_nonmonotone_hazards() {
    // One endpoint, one row; the cumulative hazard drops from 0.2 to 0.1 at
    // the last grid point.
    let grid = array![0.0, 1.0, 2.0];
    let decreasing = Array3::from_shape_vec((1, 1, 3), vec![0.0, 0.2, 0.1]).expect("shape");

    let outcome = assemble_competing_risks_cif(grid.view(), decreasing.view());

    let err = outcome.expect_err("nonmonotone cumulative hazard should be rejected");
    assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
}
3173
#[test]
fn competing_risks_cif_plateaus_and_three_causes_conserve_probability() {
    let times = array![0.0, 1.0, 3.0, 7.0, 12.0];
    // Layout is (cause, row, time); each line below is one (cause, row) path.
    #[rustfmt::skip]
    let hazard_paths = vec![
        // cause 1
        0.0, 0.2, 0.2, 0.5, 1.1,
        0.0, 0.0, 0.4, 0.4, 0.9,
        // cause 2
        0.0, 0.1, 0.3, 0.3, 0.7,
        0.0, 0.2, 0.2, 0.8, 0.8,
        // cause 3
        0.0, 0.0, 0.2, 0.6, 0.6,
        0.0, 0.1, 0.5, 0.5, 1.5,
    ];
    let cumulative = Array3::from_shape_vec((3, 2, 5), hazard_paths).expect("shape");

    let result =
        assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");

    for row in 0..2 {
        for time_idx in 0..times.len() {
            // Mass conservation across the three causes plus survival.
            let total_cif = result.cif[0][[row, time_idx]]
                + result.cif[1][[row, time_idx]]
                + result.cif[2][[row, time_idx]];
            assert!(
                (total_cif + result.overall_survival[[row, time_idx]] - 1.0).abs() < 1e-12,
                "probability mass mismatch at row={row}, time_idx={time_idx}"
            );
            assert!((0.0..=1.0).contains(&result.overall_survival[[row, time_idx]]));

            // Every cause-specific CIF is a probability and never decreases.
            for cause in 0..3 {
                assert!((0.0..=1.0).contains(&result.cif[cause][[row, time_idx]]));
                if time_idx == 0 {
                    continue;
                }
                assert!(
                    result.cif[cause][[row, time_idx]] + 1e-12
                        >= result.cif[cause][[row, time_idx - 1]],
                    "CIF decreased for cause={cause}, row={row}, time_idx={time_idx}"
                );
            }
        }
    }

    // Row 0: cause 1 has no hazard increment between t=1 and t=3 while other
    // causes do, so its CIF has to stay exactly where it was.
    assert_eq!(result.cif[0][[0, 1]], result.cif[0][[0, 2]]);
    // Row 1: between t=3 and t=7 only cause 2 accrues hazard, so causes 1 and
    // 3 must not move at all.
    assert_eq!(result.cif[0][[1, 2]], result.cif[0][[1, 3]]);
    assert_eq!(result.cif[2][[1, 2]], result.cif[2][[1, 3]]);
}
3222
#[test]
fn competing_risks_cif_rejects_bad_time_grids_and_nonfinite_hazards() {
    // Part 1: repeated, decreasing, and negative-start grids are all invalid.
    let zero_hazards = Array3::<f64>::zeros((2, 1, 2));
    let bad_grids = [array![0.0, 0.0], array![1.0, 0.5], array![-1.0, 1.0]];
    for grid in &bad_grids {
        let err = assemble_competing_risks_cif(grid.view(), zero_hazards.view())
            .expect_err("bad time grid should be rejected");
        assert!(matches!(err, SurvivalError::InvalidTimeGrid));
    }

    // Part 2: a valid grid paired with a NaN cumulative hazard.
    let valid_grid = array![0.0, 1.0];
    let with_nan = Array3::from_shape_vec((1, 1, 2), vec![0.0, f64::NAN]).expect("shape");
    let err = assemble_competing_risks_cif(valid_grid.view(), with_nan.view())
        .expect_err("nonfinite hazard should be rejected");
    assert!(matches!(err, SurvivalError::NonFiniteInput));
}
3239
#[test]
fn competing_risks_cif_extreme_hazards_remain_bounded() {
    // Two endpoints, one row, cumulative hazards reaching 1000 so that
    // `exp(-H)` underflows to zero.
    let times = array![0.0, 1.0, 2.0];
    let huge_hazards = vec![0.0, 500.0, 1000.0, 0.0, 250.0, 1000.0];
    let cumulative = Array3::from_shape_vec((2, 1, 3), huge_hazards).expect("shape");

    let result =
        assemble_competing_risks_cif(times.view(), cumulative.view()).expect("assemble CIF");

    // Every reported quantity must be a finite probability: first all CIF
    // entries, then the overall survival entries.
    let assert_probability = |value: &f64| {
        assert!(value.is_finite());
        assert!((0.0..=1.0).contains(value));
    };
    result
        .cif
        .iter()
        .flat_map(|m| m.iter())
        .for_each(assert_probability);
    result.overall_survival.iter().for_each(assert_probability);

    // At the last grid point all mass has failed and none survives.
    assert!((result.cif[0][[0, 2]] + result.cif[1][[0, 2]] - 1.0).abs() < 1e-12);
    assert_eq!(result.overall_survival[[0, 2]], 0.0);
}
3262
3263    fn toy_penalties() -> PenaltyBlocks {
3264        let s = array![[2.0, 0.5], [0.5, 3.0]];
3265        PenaltyBlocks::new(vec![PenaltyBlock {
3266            matrix: s,
3267            lambda: 1.7,
3268            range: 1..3,
3269            nullspace_dim: 0,
3270        }])
3271    }
3272
3273    fn survival_inputs<'a>(
3274        age_entry: &'a Array1<f64>,
3275        age_exit: &'a Array1<f64>,
3276        event_target: &'a Array1<u8>,
3277        event_competing: &'a Array1<u8>,
3278        sampleweight: &'a Array1<f64>,
3279        x_entry: &'a Array2<f64>,
3280        x_exit: &'a Array2<f64>,
3281        x_derivative: &'a Array2<f64>,
3282    ) -> SurvivalEngineInputs<'a> {
3283        SurvivalEngineInputs {
3284            age_entry: age_entry.view(),
3285            age_exit: age_exit.view(),
3286            event_target: event_target.view(),
3287            event_competing: event_competing.view(),
3288            sampleweight: sampleweight.view(),
3289            x_entry: x_entry.view(),
3290            x_exit: x_exit.view(),
3291            x_derivative: x_derivative.view(),
3292            monotonicity_constraint_rows: None,
3293            monotonicity_constraint_offsets: None,
3294        }
3295    }
3296
3297    fn survival_model(
3298        inputs: SurvivalEngineInputs<'_>,
3299        penalties: PenaltyBlocks,
3300        monotonicity: SurvivalMonotonicityPenalty,
3301        spec: SurvivalSpec,
3302    ) -> Result<WorkingModelSurvival, SurvivalError> {
3303        WorkingModelSurvival::from_engine_inputs(inputs, penalties, monotonicity, spec)
3304    }
3305
3306    fn survival_model_with_offsets(
3307        inputs: SurvivalEngineInputs<'_>,
3308        offsets: Option<SurvivalBaselineOffsets<'_>>,
3309        penalties: PenaltyBlocks,
3310        monotonicity: SurvivalMonotonicityPenalty,
3311        spec: SurvivalSpec,
3312    ) -> Result<WorkingModelSurvival, SurvivalError> {
3313        WorkingModelSurvival::from_engine_inputswith_offsets(
3314            inputs,
3315            offsets,
3316            penalties,
3317            monotonicity,
3318            spec,
3319        )
3320    }
3321
#[test]
fn penaltyhessian_matchesgradient_jacobian() {
    // Mirrors the fixture in `toy_penalties`: lambda = 1.7 and a 2x2 matrix
    // acting on coefficients 1..3.
    let lambda = 1.7_f64;
    let s_block = array![[2.0, 0.5], [0.5, 3.0]];
    let penalties = toy_penalties();
    let beta = array![10.0, -0.3, 1.2, 7.0];

    let grad = penalties.gradient(&beta);
    let hess = penalties.hessian(beta.len());

    // Gradient on the penalized block must equal lambda * S * beta[1..3].
    let penalized = beta.slice(s![1..3]).to_owned();
    let expected_grad = lambda * s_block.dot(&penalized);
    for (offset, want) in expected_grad.iter().enumerate() {
        assert!((grad[1 + offset] - want).abs() < 1e-12);
    }

    // Hessian on the penalized block must equal lambda * S entry by entry.
    for ((i, j), s_ij) in s_block.indexed_iter() {
        assert!((hess[[1 + i, 1 + j]] - lambda * s_ij).abs() < 1e-12);
    }
}
3339
#[test]
fn penaltygradient_matches_deviance_finite_difference() {
    let penalties = toy_penalties();
    let beta = array![10.0, -0.3, 1.2, 7.0];
    let grad = penalties.gradient(&beta);
    let eps = 1e-7;

    // Copy of `beta` with coordinate `idx` moved by `delta`.
    let nudged = |idx: usize, delta: f64| {
        let mut shifted = beta.clone();
        shifted[idx] += delta;
        shifted
    };

    for idx in 0..beta.len() {
        // Central difference of the penalty deviance along coordinate `idx`.
        let upper = penalties.deviance(&nudged(idx, eps));
        let lower = penalties.deviance(&nudged(idx, -eps));
        let fd = (upper - lower) / (2.0 * eps);

        assert_eq!(
            grad[idx].signum(),
            fd.signum(),
            "gradient/deviance sign mismatch at idx={idx}: grad={} fd={fd}",
            grad[idx]
        );
        assert!(
            (grad[idx] - fd).abs() < 1e-6,
            "gradient/deviance mismatch at idx={idx}: grad={} fd={fd}",
            grad[idx]
        );
    }
}
3366
#[test]
fn zero_offsets_match_default_survival_state() {
    // Two observations: one event, one censored.
    let age_entry = array![1.0_f64, 2.0_f64];
    let age_exit = array![2.0_f64, 3.5_f64];
    let event_target = array![1u8, 0u8];
    let event_competing = array![0u8, 0u8];
    let sampleweight = array![1.0, 1.0];

    // Columns are [1, ln(age)] for the designs and [0, 1/age] for the time
    // derivative at exit.
    let x_entry = Array2::from_shape_fn((2, 2), |(row, col)| match col {
        0 => 1.0,
        _ => age_entry[row].ln(),
    });
    let x_exit = Array2::from_shape_fn((2, 2), |(row, col)| match col {
        0 => 1.0,
        _ => age_exit[row].ln(),
    });
    let x_derivative = Array2::from_shape_fn((2, 2), |(row, col)| match col {
        0 => 0.0,
        _ => 1.0 / age_exit[row],
    });

    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
    let beta = array![-1.0, 0.8];

    // Both models are built from the same engine inputs.
    let engine_inputs = || {
        survival_inputs(
            &age_entry,
            &age_exit,
            &event_target,
            &event_competing,
            &sampleweight,
            &x_entry,
            &x_exit,
            &x_derivative,
        )
    };

    let base = survival_model(engine_inputs(), penalties.clone(), mono, SurvivalSpec::Net)
        .expect("construct base survival model");

    let no_shift = Array1::<f64>::zeros(2);
    let all_zero_offsets = SurvivalBaselineOffsets {
        eta_entry: no_shift.view(),
        eta_exit: no_shift.view(),
        derivative_exit: no_shift.view(),
    };
    let zero_offsets = survival_model_with_offsets(
        engine_inputs(),
        Some(all_zero_offsets),
        penalties,
        mono,
        SurvivalSpec::Net,
    )
    .expect("construct offset survival model");

    // Explicit all-zero offsets must reproduce the offset-free model's
    // deviance and gradient.
    let state_base = base.update_state(&beta).expect("base state");
    let statezero = zero_offsets.update_state(&beta).expect("zero-offset state");
    assert!((state_base.deviance - statezero.deviance).abs() < 1e-12);
    assert!(
        state_base
            .gradient
            .iter()
            .zip(statezero.gradient.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12)
    );
}
3431
#[test]
fn competing_risk_cause_labels_collapse_to_pooled_baseline_indicator() {
    // Regression for #378. The joint competing-risks Weibull path seeds a
    // shared single-hazard baseline working model from the dataset's event
    // *labels* {0 = censored, 1 = cause 1, 2 = cause 2}, while the
    // single-hazard engine's `event_target` contract is binary {0, 1}.
    // Passing the raw labels through used to trip the `event_target > 1`
    // guard during construction and surface as the misleading
    // `SurvivalError::NonFiniteInput` ("inputs contain non-finite values")
    // although every input is finite. The fix has two halves, both pinned
    // here:
    //   (a) a multi-cause label is reported as the actionable
    //       `EventCodeInvalid`, never as "non-finite";
    //   (b) cause labels are projected to the any-event {0, 1} indicator via
    //       the single-source-of-truth `pooled_any_event_indicator` before
    //       the pooled baseline is constructed.
    let age_entry = array![0.0_f64, 0.0, 0.0, 0.0];
    let age_exit = array![1.2_f64, 0.8, 2.1, 1.5];
    // Competing-risks cause labels: censored, cause 1, cause 2, censored.
    let cause_labels = array![0u8, 1u8, 2u8, 0u8];
    let n_obs = cause_labels.len();
    let event_competing = Array1::<u8>::zeros(n_obs);
    let sampleweight = array![1.0_f64, 1.0, 1.0, 1.0];

    // Designs are [1, ln(age)] (entry ages floored at 1e-8 before the log)
    // and the exit-time derivative design is [0, 1/age_exit].
    let x_entry = Array2::from_shape_fn((n_obs, 2), |(row, col)| match col {
        0 => 1.0,
        _ => age_entry[row].max(1e-8).ln(),
    });
    let x_exit = Array2::from_shape_fn((n_obs, 2), |(row, col)| match col {
        0 => 1.0,
        _ => age_exit[row].ln(),
    });
    let x_derivative = Array2::from_shape_fn((n_obs, 2), |(row, col)| match col {
        0 => 0.0,
        _ => 1.0 / age_exit[row],
    });
    let penalties = PenaltyBlocks::new(Vec::new());
    let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };

    // Half (a): raw labels {0,1,2} break the binary contract and must be
    // rejected, but as an *actionable* `EventCodeInvalid` rather than the
    // misleading `NonFiniteInput`; the labels are finite and only need
    // projecting. (The old fix left this surfacing as "non-finite".)
    let raw_inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &cause_labels,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let raw = survival_model(raw_inputs, penalties.clone(), mono, SurvivalSpec::Net);
    assert!(
        matches!(raw, Err(SurvivalError::EventCodeInvalid { .. })),
        "raw competing-risks cause labels must be rejected as EventCodeInvalid (not NonFiniteInput), got {raw:?}"
    );

    // Half (b): the pooled-baseline projection the workflow now performs
    // through the single source of truth, any observed event -> {0, 1}.
    let any_event = pooled_any_event_indicator(cause_labels.view());
    assert_eq!(any_event, array![0u8, 1u8, 1u8, 0u8]);
    // Likewise the per-cause projection that seeds each cause-specific block.
    assert_eq!(
        cause_specific_event_indicator(cause_labels.view(), 1),
        array![0u8, 1u8, 0u8, 0u8]
    );
    assert_eq!(
        cause_specific_event_indicator(cause_labels.view(), 2),
        array![0u8, 0u8, 1u8, 0u8]
    );

    let pooled_inputs = survival_inputs(
        &age_entry,
        &age_exit,
        &any_event,
        &event_competing,
        &sampleweight,
        &x_entry,
        &x_exit,
        &x_derivative,
    );
    let model = survival_model(pooled_inputs, penalties, mono, SurvivalSpec::Net)
        .expect("pooled any-event baseline model must construct from competing-risks data");

    // The pooled baseline must produce a finite working state so that the
    // downstream baseline-seeding PIRLS loop has something to optimize.
    let beta = array![-1.0_f64, 0.8];
    let state = model.update_state(&beta).expect("pooled baseline state");
    assert!(
        state.deviance.is_finite(),
        "pooled baseline deviance must be finite, got {}",
        state.deviance
    );
    assert!(
        state.gradient.iter().all(|g| g.is_finite()),
        "pooled baseline gradient must be finite"
    );
}
3542
3543    #[test]
3544    fn offset_channel_residuals_match_central_fd_of_nll() {
3545        // Three observations: two events (non-origin entry and origin entry)
3546        // and one censored row. This exercises every nonzero channel at least
3547        // once: r_exit from all rows, r_entry only from the first (has entry
3548        // interval), r_derivative only from events.
3549        let age_entry = array![0.5_f64, 0.0, 0.3];
3550        let age_exit = array![1.4_f64, 1.0, 2.0];
3551        let event_target = array![1u8, 1u8, 0u8];
3552        let event_competing = array![0u8, 0u8, 0u8];
3553        let sampleweight = array![1.0_f64, 2.5, 0.7];
3554        let x_entry = array![
3555            [1.0, age_entry[0].ln()],
3556            [1.0, age_entry[1].max(1e-8).ln()],
3557            [1.0, age_entry[2].ln()]
3558        ];
3559        let x_exit = array![
3560            [1.0, age_exit[0].ln()],
3561            [1.0, age_exit[1].ln()],
3562            [1.0, age_exit[2].ln()]
3563        ];
3564        let x_derivative = array![
3565            [0.0, 1.0 / age_exit[0]],
3566            [0.0, 1.0 / age_exit[1]],
3567            [0.0, 1.0 / age_exit[2]]
3568        ];
3569        // Baseline offsets chosen so η_entry, η_exit, s are all comfortably
3570        // away from overflow / monotonicity-violation boundaries.
3571        let o_entry = array![0.2_f64, 0.0, 0.1];
3572        let o_exit = array![0.4_f64, 0.5, 0.7];
3573        let o_deriv = array![0.3_f64, 0.8, 0.5];
3574        let penalties = PenaltyBlocks::new(Vec::new());
3575        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3576        let beta = array![-0.7_f64, 0.6];
3577
3578        let build = |o_e: &Array1<f64>, o_x: &Array1<f64>, o_d: &Array1<f64>| {
3579            survival_model_with_offsets(
3580                survival_inputs(
3581                    &age_entry,
3582                    &age_exit,
3583                    &event_target,
3584                    &event_competing,
3585                    &sampleweight,
3586                    &x_entry,
3587                    &x_exit,
3588                    &x_derivative,
3589                ),
3590                Some(SurvivalBaselineOffsets {
3591                    eta_entry: o_e.view(),
3592                    eta_exit: o_x.view(),
3593                    derivative_exit: o_d.view(),
3594                }),
3595                penalties.clone(),
3596                mono,
3597                SurvivalSpec::Net,
3598            )
3599            .expect("model build")
3600        };
3601
3602        let base = build(&o_entry, &o_exit, &o_deriv);
3603        let resid = base
3604            .offset_channel_residuals(&beta)
3605            .expect("offset residuals");
3606        assert_eq!(resid.exit.len(), 3);
3607        assert_eq!(resid.entry.len(), 3);
3608        assert_eq!(resid.derivative.len(), 3);
3609
3610        // NLL equals half the deviance returned by update_state; that is the
3611        // exact unpenalized loss whose offset partials r_{X,E,D} encode.
3612        let nll = |m: &WorkingModelSurvival| 0.5 * m.update_state(&beta).expect("state").deviance;
3613        let h = 1e-6;
3614
3615        // Row 1 (origin entry, event=1) has no entry interval, so r_entry[1]
3616        // must be exactly 0. Row 2 (censored) has r_deriv[2] exactly 0. Check
3617        // those identities before FD comparison on the nonzero elements.
3618        assert_eq!(resid.entry[1], 0.0);
3619        assert_eq!(resid.derivative[2], 0.0);
3620
3621        for i in 0..3 {
3622            // exit channel: perturb o_exit[i] alone.
3623            {
3624                let mut op = o_exit.clone();
3625                let mut om = o_exit.clone();
3626                op[i] += h;
3627                om[i] -= h;
3628                let fd = (nll(&build(&o_entry, &op, &o_deriv))
3629                    - nll(&build(&o_entry, &om, &o_deriv)))
3630                    / (2.0 * h);
3631                assert!(
3632                    (resid.exit[i] - fd).abs() < 1e-6,
3633                    "∂NLL/∂o_X[{i}]: analytic={:.6e} fd={:.6e}",
3634                    resid.exit[i],
3635                    fd
3636                );
3637            }
3638            // entry channel: only row 0 has an entry interval; for rows with
3639            // entry_at_origin the offset contributes nothing to NLL and FD
3640            // must also be exactly 0 to numerical precision.
3641            {
3642                let mut op = o_entry.clone();
3643                let mut om = o_entry.clone();
3644                op[i] += h;
3645                om[i] -= h;
3646                let fd = (nll(&build(&op, &o_exit, &o_deriv))
3647                    - nll(&build(&om, &o_exit, &o_deriv)))
3648                    / (2.0 * h);
3649                assert!(
3650                    (resid.entry[i] - fd).abs() < 1e-6,
3651                    "∂NLL/∂o_E[{i}]: analytic={:.6e} fd={:.6e}",
3652                    resid.entry[i],
3653                    fd
3654                );
3655            }
3656            // derivative channel: only event rows contribute.
3657            {
3658                let mut op = o_deriv.clone();
3659                let mut om = o_deriv.clone();
3660                op[i] += h;
3661                om[i] -= h;
3662                let fd = (nll(&build(&o_entry, &o_exit, &op))
3663                    - nll(&build(&o_entry, &o_exit, &om)))
3664                    / (2.0 * h);
3665                assert!(
3666                    (resid.derivative[i] - fd).abs() < 1e-6,
3667                    "∂NLL/∂o_D[{i}]: analytic={:.6e} fd={:.6e}",
3668                    resid.derivative[i],
3669                    fd
3670                );
3671            }
3672        }
3673    }
3674
3675    #[test]
3676    fn offset_channel_residuals_respect_zero_sampleweight() {
3677        let age_entry = array![1.0_f64, 2.0];
3678        let age_exit = array![2.0_f64, 3.5];
3679        let event_target = array![1u8, 1u8];
3680        let event_competing = array![0u8, 0u8];
3681        let sampleweight = array![0.0_f64, 1.2]; // row 0 is excluded by weight
3682        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3683        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3684        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3685        let penalties = PenaltyBlocks::new(Vec::new());
3686        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3687        let beta = array![-1.0_f64, 0.8];
3688
3689        let model = survival_model_with_offsets(
3690            survival_inputs(
3691                &age_entry,
3692                &age_exit,
3693                &event_target,
3694                &event_competing,
3695                &sampleweight,
3696                &x_entry,
3697                &x_exit,
3698                &x_derivative,
3699            ),
3700            Some(SurvivalBaselineOffsets {
3701                eta_entry: array![0.0_f64, 0.1].view(),
3702                eta_exit: array![0.0_f64, 0.2].view(),
3703                derivative_exit: array![0.0_f64, 0.1].view(),
3704            }),
3705            penalties,
3706            mono,
3707            SurvivalSpec::Net,
3708        )
3709        .expect("model");
3710        let r = model.offset_channel_residuals(&beta).expect("resid");
3711        // Row 0 (sampleweight=0) must contribute zero in every channel.
3712        assert_eq!(r.exit[0], 0.0);
3713        assert_eq!(r.entry[0], 0.0);
3714        assert_eq!(r.derivative[0], 0.0);
3715        // Row 1 must still carry a nonzero exit-channel residual.
3716        assert!(r.exit[1] != 0.0);
3717    }
3718
3719    #[test]
3720    fn offset_channel_residuals_reject_beta_dim_mismatch() {
3721        let age_entry = array![1.0_f64];
3722        let age_exit = array![2.0_f64];
3723        let event_target = array![1u8];
3724        let event_competing = array![0u8];
3725        let sampleweight = array![1.0_f64];
3726        let x_entry = array![[1.0, 0.0]];
3727        let x_exit = array![[1.0, 0.7]];
3728        let x_derivative = array![[0.0, 0.5]];
3729        let penalties = PenaltyBlocks::new(Vec::new());
3730        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3731        let model = survival_model(
3732            survival_inputs(
3733                &age_entry,
3734                &age_exit,
3735                &event_target,
3736                &event_competing,
3737                &sampleweight,
3738                &x_entry,
3739                &x_exit,
3740                &x_derivative,
3741            ),
3742            penalties,
3743            mono,
3744            SurvivalSpec::Net,
3745        )
3746        .expect("model");
3747        let bad_beta = array![0.0_f64]; // should be length 2
3748        let err = model
3749            .offset_channel_residuals(&bad_beta)
3750            .expect_err("mismatch must error");
3751        match err {
3752            EstimationError::InvalidInput(msg) => {
3753                assert!(msg.contains("beta dimension mismatch"), "msg={msg}")
3754            }
3755            other => panic!("expected InvalidInput, got {other:?}"),
3756        }
3757    }
3758
3759    #[test]
3760    fn crudespec_is_rejected_by_one_hazard_engine() {
3761        let age_entry = array![1.0_f64];
3762        let age_exit = array![2.0_f64];
3763        let event_target = array![0u8];
3764        let event_competing = array![1u8];
3765        let sampleweight = array![1.0];
3766        let x_entry = array![[0.1]];
3767        let x_exit = array![[0.4]];
3768        let x_derivative = array![[1.0]];
3769        let penalties = PenaltyBlocks::new(Vec::new());
3770        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3771
3772        let err = survival_model(
3773            survival_inputs(
3774                &age_entry,
3775                &age_exit,
3776                &event_target,
3777                &event_competing,
3778                &sampleweight,
3779                &x_entry,
3780                &x_exit,
3781                &x_derivative,
3782            ),
3783            penalties,
3784            mono,
3785            SurvivalSpec::Crude,
3786        )
3787        .expect_err("crude fitting should be rejected by the one-hazard engine");
3788        assert!(matches!(err, SurvivalError::UnsupportedSpec("crude")));
3789    }
3790
3791    #[test]
3792    fn nonstructural_models_require_explicit_monotonicity_collocation() {
3793        let age_entry = array![1.0_f64, 1.5_f64];
3794        let age_exit = array![2.0_f64, 2.5_f64];
3795        let event_target = array![0u8, 0u8];
3796        let event_competing = array![0u8, 1u8];
3797        let sampleweight = array![1.0, 1.0];
3798        let x_entry = array![[0.2], [0.1]];
3799        let x_exit = array![[0.3], [0.2]];
3800        let x_derivative = array![[1.0], [1.0]];
3801
3802        let model = survival_model(
3803            survival_inputs(
3804                &age_entry,
3805                &age_exit,
3806                &event_target,
3807                &event_competing,
3808                &sampleweight,
3809                &x_entry,
3810                &x_exit,
3811                &x_derivative,
3812            ),
3813            PenaltyBlocks::new(Vec::new()),
3814            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3815            SurvivalSpec::Net,
3816        )
3817        .expect("construct censored survival model");
3818
3819        assert!(
3820            model.monotonicity_linear_constraints().is_none(),
3821            "non-structural survival models must not fabricate rowwise monotonicity constraints"
3822        );
3823    }
3824
3825    #[test]
3826    fn decreasing_interval_is_rejectedwithout_target_events() {
3827        let age_entry = array![1.0_f64];
3828        let age_exit = array![2.0_f64];
3829        let event_target = array![0u8];
3830        let event_competing = array![0u8];
3831        let sampleweight = array![1.0];
3832        let x_entry = array![[0.5]];
3833        let x_exit = array![[0.0]];
3834        let x_derivative = array![[1.0]];
3835
3836        let model = survival_model(
3837            survival_inputs(
3838                &age_entry,
3839                &age_exit,
3840                &event_target,
3841                &event_competing,
3842                &sampleweight,
3843                &x_entry,
3844                &x_exit,
3845                &x_derivative,
3846            ),
3847            PenaltyBlocks::new(Vec::new()),
3848            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3849            SurvivalSpec::Net,
3850        )
3851        .expect("construct censored survival model");
3852
3853        let err = model
3854            .update_state(&array![1.0])
3855            .expect_err("decreasing cumulative hazard increment should be rejected");
3856        assert!(
3857            err.to_string().contains("cumulative hazard decreased"),
3858            "unexpected error: {err}"
3859        );
3860    }
3861
3862    fn smooth_crude_risk(beta_d: f64, beta_m: f64) -> CrudeRiskResult {
3863        calculate_crude_risk_quadrature(
3864            0.0,
3865            1.0,
3866            &[0.0, 1.0],
3867            beta_d.exp(),
3868            beta_m.exp(),
3869            array![1.0].view(),
3870            array![1.0].view(),
3871            |u, design_d, deriv_d, design_m| {
3872                let cumulative_d = beta_d.exp() * (1.0 + 0.2 * u);
3873                let cumulative_m = beta_m.exp() * (1.0 + 0.1 * u);
3874                let inst_hazard_d = 0.2 * beta_d.exp();
3875                design_d[0] = 1.0;
3876                // η_d = β_d + ln(1 + 0.2u), so η̇_d = 0.2/(1+0.2u)
3877                // which does not depend on β_d → ∂_{β_d} η̇_d = 0
3878                deriv_d[0] = 0.0;
3879                design_m[0] = 1.0;
3880                Ok((inst_hazard_d, cumulative_d, cumulative_m))
3881            },
3882        )
3883        .expect("smooth crude-risk quadrature should succeed")
3884    }
3885
3886    #[test]
3887    fn crude_riskgradient_matches_monotoneobjective() {
3888        let beta_d = -0.2_f64;
3889        let beta_m = -0.5_f64;
3890        let result = smooth_crude_risk(beta_d, beta_m);
3891        let eps = 1e-6;
3892
3893        let fd_d = (smooth_crude_risk(beta_d + eps, beta_m).risk
3894            - smooth_crude_risk(beta_d - eps, beta_m).risk)
3895            / (2.0 * eps);
3896        let fd_m = (smooth_crude_risk(beta_d, beta_m + eps).risk
3897            - smooth_crude_risk(beta_d, beta_m - eps).risk)
3898            / (2.0 * eps);
3899
3900        assert!(
3901            (result.diseasegradient[0] - fd_d).abs() < 1e-5,
3902            "disease gradient mismatch for monotone crude risk: analytic={} fd={fd_d}",
3903            result.diseasegradient[0]
3904        );
3905        assert!(
3906            (result.mortalitygradient[0] - fd_m).abs() < 1e-5,
3907            "mortality gradient mismatch for monotone crude risk: analytic={} fd={fd_m}",
3908            result.mortalitygradient[0]
3909        );
3910    }
3911
3912    #[test]
3913    fn survivalridge_penalty_scalar_matchesgradienthessian_scaling() {
3914        let age_entry = array![1.0_f64, 2.0_f64];
3915        let age_exit = array![2.0_f64, 3.5_f64];
3916        let event_target = array![1u8, 0u8];
3917        let event_competing = array![0u8, 0u8];
3918        let sampleweight = array![1.0, 1.0];
3919        let x_entry = array![[1.0, age_entry[0].ln()], [1.0, age_entry[1].ln()]];
3920        let x_exit = array![[1.0, age_exit[0].ln()], [1.0, age_exit[1].ln()]];
3921        let x_derivative = array![[0.0, 1.0 / age_exit[0]], [0.0, 1.0 / age_exit[1]]];
3922        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3923            matrix: array![[2.0]],
3924            lambda: 1.7,
3925            range: 1..2,
3926            nullspace_dim: 0,
3927        }]);
3928        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
3929        let beta = array![-1.2, 0.4];
3930
3931        let model = survival_model(
3932            survival_inputs(
3933                &age_entry,
3934                &age_exit,
3935                &event_target,
3936                &event_competing,
3937                &sampleweight,
3938                &x_entry,
3939                &x_exit,
3940                &x_derivative,
3941            ),
3942            penalties.clone(),
3943            mono,
3944            SurvivalSpec::Net,
3945        )
3946        .expect("construct survival model");
3947
3948        let state = model.update_state(&beta).expect("survival state");
3949        let expected_penalty = penalties.deviance(&beta) + 0.5 * state.ridge_used * beta.dot(&beta);
3950        assert!(
3951            (state.penalty_term - expected_penalty).abs() < 1e-12,
3952            "penalty_term mismatch: state={} expected={}",
3953            state.penalty_term,
3954            expected_penalty
3955        );
3956    }
3957
3958    #[test]
3959    fn negative_penalty_lambda_is_rejected() {
3960        let age_entry = array![1.0_f64];
3961        let age_exit = array![2.0_f64];
3962        let event_target = array![1u8];
3963        let event_competing = array![0u8];
3964        let sampleweight = array![1.0];
3965        let x_entry = array![[1.0, 0.0]];
3966        let x_exit = array![[1.0, 0.5]];
3967        let x_derivative = array![[0.0, 1.0]];
3968        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
3969            matrix: array![[1.0]],
3970            lambda: -0.1,
3971            range: 1..2,
3972            nullspace_dim: 0,
3973        }]);
3974
3975        let err = survival_model(
3976            survival_inputs(
3977                &age_entry,
3978                &age_exit,
3979                &event_target,
3980                &event_competing,
3981                &sampleweight,
3982                &x_entry,
3983                &x_exit,
3984                &x_derivative,
3985            ),
3986            penalties,
3987            SurvivalMonotonicityPenalty { tolerance: 0.0 },
3988            SurvivalSpec::Net,
3989        )
3990        .expect_err("negative lambda must be rejected");
3991
3992        assert!(matches!(err, SurvivalError::NonFiniteInput));
3993    }
3994
3995    #[test]
3996    fn penalty_block_range_and_shapemust_match_coefficients() {
3997        let age_entry = array![1.0_f64];
3998        let age_exit = array![2.0_f64];
3999        let event_target = array![1u8];
4000        let event_competing = array![0u8];
4001        let sampleweight = array![1.0];
4002        let x_entry = array![[1.0, 0.0]];
4003        let x_exit = array![[1.0, 0.5]];
4004        let x_derivative = array![[0.0, 1.0]];
4005        let penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4006            matrix: array![[1.0]],
4007            lambda: 0.5,
4008            range: 0..2,
4009            nullspace_dim: 0,
4010        }]);
4011
4012        let err = survival_model(
4013            survival_inputs(
4014                &age_entry,
4015                &age_exit,
4016                &event_target,
4017                &event_competing,
4018                &sampleweight,
4019                &x_entry,
4020                &x_exit,
4021                &x_derivative,
4022            ),
4023            penalties,
4024            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4025            SurvivalSpec::Net,
4026        )
4027        .expect_err("penalty block geometry must match coefficient support");
4028
4029        assert!(matches!(err, SurvivalError::DimensionMismatch));
4030    }
4031
4032    #[test]
4033    fn survivalgradient_matchesobjectivefdwithridge_scaling() {
4034        let age_entry = array![1.0_f64, 2.0_f64, 3.0_f64];
4035        let age_exit = array![2.0_f64, 3.5_f64, 4.0_f64];
4036        let event_target = array![1u8, 0u8, 1u8];
4037        let event_competing = array![0u8, 0u8, 0u8];
4038        let sampleweight = array![1.0, 1.0, 1.0];
4039        let x_entry = array![
4040            [1.0, age_entry[0].ln()],
4041            [1.0, age_entry[1].ln()],
4042            [1.0, age_entry[2].ln()]
4043        ];
4044        let x_exit = array![
4045            [1.0, age_exit[0].ln()],
4046            [1.0, age_exit[1].ln()],
4047            [1.0, age_exit[2].ln()]
4048        ];
4049        let x_derivative = array![
4050            [0.0, 1.0 / age_exit[0]],
4051            [0.0, 1.0 / age_exit[1]],
4052            [0.0, 1.0 / age_exit[2]]
4053        ];
4054        let penalties = PenaltyBlocks::new(Vec::new());
4055        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4056        let beta = array![-1.0, 3.0];
4057
4058        let model = survival_model(
4059            survival_inputs(
4060                &age_entry,
4061                &age_exit,
4062                &event_target,
4063                &event_competing,
4064                &sampleweight,
4065                &x_entry,
4066                &x_exit,
4067                &x_derivative,
4068            ),
4069            penalties,
4070            mono,
4071            SurvivalSpec::Net,
4072        )
4073        .expect("construct survival model");
4074
4075        let state = model.update_state(&beta).expect("state at beta");
4076        let eps = 1e-7;
4077        for j in 0..beta.len() {
4078            let mut plus = beta.clone();
4079            let mut minus = beta.clone();
4080            plus[j] += eps;
4081            minus[j] -= eps;
4082            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4083            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4084            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4085            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4086            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4087            assert_eq!(
4088                state.gradient[j].signum(),
4089                fd.signum(),
4090                "objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4091                state.gradient[j]
4092            );
4093            assert!(
4094                (state.gradient[j] - fd).abs() < 1e-5,
4095                "objective/gradient mismatch at j={j}: grad={} fd={fd}",
4096                state.gradient[j]
4097            );
4098        }
4099    }
4100
4101    fn laml_fd_test_model(lambda: f64) -> WorkingModelSurvival {
4102        // 20-subject survival fixture with mean-centered log-age time
4103        // covariate, balanced events/censorings, and moderate hazard levels.
4104        // The fixture is large enough that the observed-information Hessian
4105        // is well-conditioned at the MLE so PIRLS reaches the 1e-10 KKT
4106        // tolerance in well under 80 iterations from the starting beta used
4107        // below.
4108        let age_entry: Array1<f64> = Array1::from(vec![
4109            30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 62.0,
4110            34.0, 39.0, 44.0, 49.0, 54.0, 59.0,
4111        ]);
4112        let age_exit: Array1<f64> = Array1::from(vec![
4113            45.0, 48.0, 55.0, 58.0, 62.0, 66.0, 68.0, 47.0, 52.0, 53.0, 55.0, 60.0, 63.0, 70.0,
4114            48.0, 51.0, 58.0, 62.0, 66.0, 69.0,
4115        ]);
4116        let event_target = Array1::from(vec![
4117            1u8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
4118        ]);
4119        let event_competing = Array1::<u8>::zeros(age_entry.len());
4120        let sampleweight = Array1::from_elem(age_entry.len(), 1.0_f64);
4121        let n = age_entry.len();
4122        let ln_age_mean: f64 = {
4123            let mut sum = 0.0;
4124            for i in 0..n {
4125                sum += age_entry[i].ln() + age_exit[i].ln();
4126            }
4127            sum / (2.0 * n as f64)
4128        };
4129        let mut x_entry = Array2::<f64>::zeros((n, 2));
4130        let mut x_exit = Array2::<f64>::zeros((n, 2));
4131        let mut x_derivative = Array2::<f64>::zeros((n, 2));
4132        for i in 0..n {
4133            x_entry[[i, 0]] = 1.0;
4134            x_exit[[i, 0]] = 1.0;
4135            x_entry[[i, 1]] = age_entry[i].ln() - ln_age_mean;
4136            x_exit[[i, 1]] = age_exit[i].ln() - ln_age_mean;
4137            x_derivative[[i, 0]] = 0.0;
4138            x_derivative[[i, 1]] = 1.0 / age_exit[i];
4139        }
4140        let penalties = PenaltyBlocks::new(vec![
4141            PenaltyBlock {
4142                matrix: array![[3.0]],
4143                lambda: 0.0,
4144                range: 0..1,
4145                nullspace_dim: 0,
4146            },
4147            PenaltyBlock {
4148                matrix: array![[2.5]],
4149                lambda,
4150                range: 1..2,
4151                nullspace_dim: 0,
4152            },
4153        ]);
4154        survival_model(
4155            survival_inputs(
4156                &age_entry,
4157                &age_exit,
4158                &event_target,
4159                &event_competing,
4160                &sampleweight,
4161                &x_entry,
4162                &x_exit,
4163                &x_derivative,
4164            ),
4165            penalties,
4166            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
4167            SurvivalSpec::Net,
4168        )
4169        .expect("construct LAML FD survival model")
4170    }
4171
4172    fn laml_test_logdet_h(state: &WorkingState) -> f64 {
4173        use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4174        use gam_linalg::faer_ndarray::FaerEigh;
4175
4176        let h_dense = state.hessian.to_dense();
4177        let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4178        let eps = spectral_epsilon(evals.as_slice().unwrap());
4179        evals
4180            .iter()
4181            .map(|&sigma| spectral_regularize(sigma, eps).ln())
4182            .sum()
4183    }
4184
4185    #[test]
4186    fn laml_gradient_and_objective_ignore_inactive_penalty_prefix_blocks() {
4187        // The core claim under test: the survival LAML rho-gradient and the
4188        // documented LAML objective enumerate only penalty blocks whose
4189        // lambda is actually active (> 0). An inactive prefix block must
4190        // therefore contribute neither a log|lambda * S| term to the
4191        // objective nor an entry to the rho-gradient vector.
4192        //
4193        // We verify the objective formula and the gradient dimensionality at
4194        // a fixed beta rather than a fitted one: the bug this test guards
4195        // against was purely algebraic enumeration over penalty blocks and
4196        // has no dependence on PIRLS convergence quality. A gradient-vs-FD
4197        // comparison would require beta to sit at the joint MLE of a tiny
4198        // synthetic survival fixture, which the analytic Newton/PIRLS path
4199        // cannot reach to 1e-10 KKT tolerance without a much richer design.
4200        let rho0 = -0.35_f64;
4201        let beta = array![-2.5_f64, 1.0];
4202        let model = laml_fd_test_model(rho0.exp());
4203        let state = model
4204            .update_state(&beta)
4205            .expect("state for LAML prefix-skip test");
4206
4207        // Sanity: the fixture has two penalty blocks; the first has
4208        // lambda = 0 (inactive prefix) and the second has lambda > 0
4209        // (active). If a future refactor flips this ordering, the prefix
4210        // skip being exercised here would silently become an identity test.
4211        assert_eq!(model.penalties.blocks.len(), 2);
4212        assert_eq!(model.penalties.blocks[0].lambda, 0.0);
4213        assert!(model.penalties.blocks[1].lambda > 0.0);
4214
4215        let rho = Array1::from_iter(
4216            model
4217                .penalties
4218                .blocks
4219                .iter()
4220                .filter(|b| b.lambda > 0.0)
4221                .map(|b| b.lambda.ln()),
4222        );
4223        assert_eq!(
4224            rho.len(),
4225            1,
4226            "fixture should expose exactly one active penalty block for the rho vector"
4227        );
4228
4229        let (obj, grad) = model
4230            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4231            .expect("survival LAML objective and gradient");
4232
4233        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * laml_test_logdet_h(&state)
4234            - 0.5 * (rho0 + 2.5_f64.ln());
4235        assert_eq!(
4236            grad.len(),
4237            1,
4238            "rho-gradient must match the active-penalty count, not the full block list"
4239        );
4240        assert!(
4241            (obj - expected).abs() < 1e-10,
4242            "survival LAML objective mismatch with inactive prefix block: obj={obj} expected={expected}",
4243        );
4244        assert!(
4245            grad[0].is_finite(),
4246            "rho-gradient must be finite: {}",
4247            grad[0]
4248        );
4249    }
4250
4251    #[test]
4252    fn structural_monotonicgradient_matchesobjectivefd() {
4253        let age_entry = array![1.0_f64, 1.3_f64, 1.8_f64];
4254        let age_exit = array![1.6_f64, 2.1_f64, 2.7_f64];
4255        let event_target = array![1u8, 0u8, 1u8];
4256        let event_competing = array![0u8, 0u8, 0u8];
4257        let sampleweight = array![1.0, 1.0, 1.0];
4258
4259        // Time block has 3 structural-monotone columns.
4260        // Final column is a covariate, left unconstrained.
4261        let x_entry = array![
4262            [1.0, 0.2, 0.05, -0.7],
4263            [1.0, 0.5, 0.20, 0.1],
4264            [1.0, 0.9, 0.60, 1.2]
4265        ];
4266        let x_exit = array![
4267            [1.0, 0.4, 0.16, -0.7],
4268            [1.0, 0.8, 0.64, 0.1],
4269            [1.0, 1.1, 1.21, 1.2]
4270        ];
4271        let x_derivative = array![
4272            [0.0, 0.8, 0.64, 0.0],
4273            [0.0, 0.7, 1.12, 0.0],
4274            [0.0, 0.6, 1.32, 0.0]
4275        ];
4276        let penalties = PenaltyBlocks::new(Vec::new());
4277        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4278        let mut model = survival_model(
4279            survival_inputs(
4280                &age_entry,
4281                &age_exit,
4282                &event_target,
4283                &event_competing,
4284                &sampleweight,
4285                &x_entry,
4286                &x_exit,
4287                &x_derivative,
4288            ),
4289            penalties,
4290            mono,
4291            SurvivalSpec::Net,
4292        )
4293        .expect("construct structural survival model");
4294        model
4295            .set_structural_monotonicity(true, 3)
4296            .expect("enable structural monotonicity");
4297        let constraints = model
4298            .monotonicity_linear_constraints()
4299            .expect("structural derivative constraints");
4300        assert_eq!(constraints.a.nrows(), 2);
4301        assert_eq!(constraints.a.ncols(), 4);
4302        assert_eq!(constraints.a.row(0).to_vec(), vec![0.0, 1.0, 0.0, 0.0]);
4303        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
4304        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4305
4306        let beta = array![0.2, 0.2, 0.1, 0.2];
4307        let state = model.update_state(&beta).expect("state at structural beta");
4308        let eps = 1e-7;
4309        for j in 0..beta.len() {
4310            let mut plus = beta.clone();
4311            let mut minus = beta.clone();
4312            plus[j] += eps;
4313            minus[j] -= eps;
4314            let state_plus = model.update_state(&plus).expect("state at beta + eps");
4315            let state_minus = model.update_state(&minus).expect("state at beta - eps");
4316            let obj_plus = 0.5 * state_plus.deviance + state_plus.penalty_term;
4317            let obj_minus = 0.5 * state_minus.deviance + state_minus.penalty_term;
4318            let fd = (obj_plus - obj_minus) / (2.0 * eps);
4319            assert_eq!(
4320                state.gradient[j].signum(),
4321                fd.signum(),
4322                "structural objective/gradient sign mismatch at j={j}: grad={} fd={fd}",
4323                state.gradient[j]
4324            );
4325            assert!(
4326                (state.gradient[j] - fd).abs() < 2e-5,
4327                "structural objective/gradient mismatch at j={j}: grad={} fd={fd}",
4328                state.gradient[j]
4329            );
4330        }
4331    }
4332
4333    #[test]
4334    fn structural_monotonic_lamlgradient_returns_finitevalues() {
4335        let age_entry = array![1.0_f64, 1.2_f64];
4336        let age_exit = array![1.5_f64, 2.0_f64];
4337        let event_target = array![1u8, 0u8];
4338        let event_competing = array![0u8, 0u8];
4339        let sampleweight = array![1.0, 1.0];
4340        let x_entry = array![[1.0, 0.2, -0.5], [1.0, 0.4, 0.2]];
4341        let x_exit = array![[1.0, 0.5, -0.5], [1.0, 0.8, 0.2]];
4342        let x_derivative = array![[0.0, 0.9, 0.0], [0.0, 0.7, 0.0]];
4343        let penalties = PenaltyBlocks::new(Vec::new());
4344        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4345        let mut model = survival_model(
4346            survival_inputs(
4347                &age_entry,
4348                &age_exit,
4349                &event_target,
4350                &event_competing,
4351                &sampleweight,
4352                &x_entry,
4353                &x_exit,
4354                &x_derivative,
4355            ),
4356            penalties,
4357            mono,
4358            SurvivalSpec::Net,
4359        )
4360        .expect("construct structural survival model");
4361        model
4362            .set_structural_monotonicity(true, 2)
4363            .expect("enable structural monotonicity");
4364        // One simple penalty block to exercise rho-gradient path.
4365        model.penalties = PenaltyBlocks::new(vec![PenaltyBlock {
4366            matrix: array![[1.0]],
4367            lambda: 0.7,
4368            range: 1..2,
4369            nullspace_dim: 0,
4370        }]);
4371        let beta = array![0.2, 0.2, 0.1];
4372        let state = model.update_state(&beta).expect("state at structural beta");
4373        let rho = Array1::from_iter(
4374            model
4375                .penalties
4376                .blocks
4377                .iter()
4378                .filter(|b| b.lambda > 0.0)
4379                .map(|b| b.lambda.ln()),
4380        );
4381        let (obj, grad) = model
4382            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4383            .expect("laml gradient should work in structural mode");
4384        assert!(obj.is_finite());
4385        assert_eq!(grad.len(), 1);
4386        assert!(grad[0].is_finite());
4387    }
4388
4389    #[test]
4390    fn structural_monotonicity_switches_to_tiny_derivative_guard_constraints() {
4391        let age_entry = array![1.0_f64];
4392        let age_exit = array![2.0_f64];
4393        let event_target = array![1u8];
4394        let event_competing = array![0u8];
4395        let sampleweight = array![1.0];
4396        let x_entry = array![[0.0]];
4397        let x_exit = array![[0.2]];
4398        let x_derivative = array![[1.0]];
4399
4400        let penalties = PenaltyBlocks::new(Vec::new());
4401        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4402        let mut model = survival_model(
4403            survival_inputs(
4404                &age_entry,
4405                &age_exit,
4406                &event_target,
4407                &event_competing,
4408                &sampleweight,
4409                &x_entry,
4410                &x_exit,
4411                &x_derivative,
4412            ),
4413            penalties,
4414            mono,
4415            SurvivalSpec::Net,
4416        )
4417        .expect("construct structural survival model");
4418
4419        let beta = array![-3.0];
4420        assert!(
4421            model.update_state(&beta).is_err(),
4422            "negative derivative coefficient should violate derivative guard"
4423        );
4424
4425        model
4426            .set_structural_monotonicity(true, 1)
4427            .expect("enable structural monotonicity");
4428        let constraints = model
4429            .monotonicity_linear_constraints()
4430            .expect("structural derivative constraints");
4431        assert_eq!(constraints.a.nrows(), 1);
4432        assert_eq!(constraints.a.ncols(), 1);
4433        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
4434        // Structural monotonicity uses derivative_guard() == 0.0
4435        assert!(constraints.b[0].abs() <= 1e-12);
4436        let state = model
4437            .update_state(&array![1e-6])
4438            .expect("small positive derivative coefficient should remain feasible");
4439        assert!(state.deviance.is_finite());
4440    }
4441
4442    #[test]
4443    fn derivative_offset_must_clear_nonstructural_monotonicity_threshold() {
4444        let age_entry = array![1.0_f64];
4445        let age_exit = array![2.0_f64];
4446        let event_target = array![1u8];
4447        let event_competing = array![0u8];
4448        let sampleweight = array![1.0];
4449        let x_entry = array![[1.0, 0.0]];
4450        let x_exit = array![[1.0, 0.0]];
4451        let x_derivative = array![[0.0, 0.0]];
4452        let penalties = PenaltyBlocks::new(Vec::new());
4453        let monotonicity = SurvivalMonotonicityPenalty { tolerance: 3.0 };
4454        let eta_entry_offset = array![0.0];
4455        let eta_exit_offset = array![0.0];
4456        let derivative_offset_below_guard = array![2.0];
4457        let derivative_offset_above_guard = array![3.1];
4458        let offsets_below_guard = SurvivalBaselineOffsets {
4459            eta_entry: eta_entry_offset.view(),
4460            eta_exit: eta_exit_offset.view(),
4461            derivative_exit: derivative_offset_below_guard.view(),
4462        };
4463        let offsets_above_guard = SurvivalBaselineOffsets {
4464            eta_entry: eta_entry_offset.view(),
4465            eta_exit: eta_exit_offset.view(),
4466            derivative_exit: derivative_offset_above_guard.view(),
4467        };
4468
4469        let model_below_guard = survival_model_with_offsets(
4470            survival_inputs(
4471                &age_entry,
4472                &age_exit,
4473                &event_target,
4474                &event_competing,
4475                &sampleweight,
4476                &x_entry,
4477                &x_exit,
4478                &x_derivative,
4479            ),
4480            Some(offsets_below_guard),
4481            penalties.clone(),
4482            monotonicity,
4483            SurvivalSpec::Net,
4484        )
4485        .expect("construct model with derivative offset below guard");
4486        let err = model_below_guard
4487            .update_state(&array![0.0, 0.0])
4488            .expect_err("derivative offset below guard should be rejected");
4489        let err_text = err.to_string();
4490        assert!(
4491            err_text.contains("d_eta/dt=2.000e0") && err_text.contains("tolerance=3.000e0"),
4492            "expected derivative guard rejection to report the offset-driven derivative: {err_text}"
4493        );
4494
4495        let model_above_guard = survival_model_with_offsets(
4496            survival_inputs(
4497                &age_entry,
4498                &age_exit,
4499                &event_target,
4500                &event_competing,
4501                &sampleweight,
4502                &x_entry,
4503                &x_exit,
4504                &x_derivative,
4505            ),
4506            Some(offsets_above_guard),
4507            penalties,
4508            SurvivalMonotonicityPenalty { tolerance: 3.0 },
4509            SurvivalSpec::Net,
4510        )
4511        .expect("construct model with derivative offset above guard");
4512        let state = model_above_guard
4513            .update_state(&array![0.0, 0.0])
4514            .expect("derivative offset above guard should remain feasible");
4515        assert!(state.deviance.is_finite());
4516    }
4517
4518    #[test]
4519    fn structural_monotonicity_rejects_negative_derivative_offsets() {
4520        let age_entry = array![1.0_f64];
4521        let age_exit = array![2.0_f64];
4522        let event_target = array![1u8];
4523        let event_competing = array![0u8];
4524        let sampleweight = array![1.0];
4525        let x_entry = array![[0.0]];
4526        let x_exit = array![[0.2]];
4527        let x_derivative = array![[1.0]];
4528        let eta_entry = array![0.0];
4529        let eta_exit = array![0.0];
4530        let derivative_exit = array![-1e-3];
4531        let offsets = SurvivalBaselineOffsets {
4532            eta_entry: eta_entry.view(),
4533            eta_exit: eta_exit.view(),
4534            derivative_exit: derivative_exit.view(),
4535        };
4536
4537        let mut model = survival_model_with_offsets(
4538            survival_inputs(
4539                &age_entry,
4540                &age_exit,
4541                &event_target,
4542                &event_competing,
4543                &sampleweight,
4544                &x_entry,
4545                &x_exit,
4546                &x_derivative,
4547            ),
4548            Some(offsets),
4549            PenaltyBlocks::new(Vec::new()),
4550            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4551            SurvivalSpec::Net,
4552        )
4553        .expect("construct structural survival model");
4554        let err = model
4555            .set_structural_monotonicity(true, 1)
4556            .expect_err("negative derivative offsets must be rejected");
4557        assert!(
4558            err.to_string()
4559                .contains("structural monotonicity requires nonnegative derivative offsets"),
4560            "unexpected error: {err}"
4561        );
4562    }
4563
4564    #[test]
4565    fn structural_monotonicity_emits_coefficient_constraints() {
4566        let age_entry = array![1.0_f64, 1.5_f64];
4567        let age_exit = array![2.0_f64, 3.0_f64];
4568        let event_target = array![1u8, 0u8];
4569        let event_competing = array![0u8, 0u8];
4570        let sampleweight = array![1.0, 1.0];
4571        let x_entry = array![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]];
4572        let x_exit = array![[0.2, 0.4, 1.0], [0.3, 0.5, 1.0]];
4573        let x_derivative = array![[0.3, 0.2, 0.0], [0.4, 0.1, 0.0]];
4574
4575        let mut model = survival_model(
4576            survival_inputs(
4577                &age_entry,
4578                &age_exit,
4579                &event_target,
4580                &event_competing,
4581                &sampleweight,
4582                &x_entry,
4583                &x_exit,
4584                &x_derivative,
4585            ),
4586            PenaltyBlocks::new(Vec::new()),
4587            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4588            SurvivalSpec::Net,
4589        )
4590        .expect("construct structural survival model");
4591        model
4592            .set_structural_monotonicity(true, 2)
4593            .expect("enable structural monotonicity");
4594
4595        let constraints = model
4596            .monotonicity_linear_constraints()
4597            .expect("structural derivative constraints");
4598
4599        assert_eq!(constraints.a.nrows(), 2);
4600        assert_eq!(constraints.a.ncols(), 3);
4601        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
4602        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
4603        assert!(constraints.b.iter().all(|&v| v.abs() <= 1e-12));
4604    }
4605
4606    #[test]
4607    fn structural_monotonicity_preserves_inactive_time_columns_in_constraints() {
4608        let age_entry = array![1.0_f64];
4609        let age_exit = array![2.0_f64];
4610        let event_target = array![1u8];
4611        let event_competing = array![0u8];
4612        let sampleweight = array![1.0];
4613        let x_entry = array![[1.0, 0.2]];
4614        let x_exit = array![[1.0, 0.6]];
4615        let x_derivative = array![[0.0, 1.0]];
4616
4617        let mut model = survival_model(
4618            survival_inputs(
4619                &age_entry,
4620                &age_exit,
4621                &event_target,
4622                &event_competing,
4623                &sampleweight,
4624                &x_entry,
4625                &x_exit,
4626                &x_derivative,
4627            ),
4628            PenaltyBlocks::new(Vec::new()),
4629            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4630            SurvivalSpec::Net,
4631        )
4632        .expect("construct structural survival model");
4633        model
4634            .set_structural_monotonicity(true, 2)
4635            .expect("enable structural monotonicity");
4636
4637        let constraints = model
4638            .monotonicity_linear_constraints()
4639            .expect("structural derivative constraints");
4640
4641        assert_eq!(constraints.a.nrows(), 1);
4642        assert!(
4643            constraints.a[[0, 0]].abs() <= 1e-12,
4644            "inactive time column should remain unconstrained"
4645        );
4646        assert!(
4647            (constraints.a[[0, 1]] - 1.0).abs() <= 1e-12,
4648            "active time column should remain constrained"
4649        );
4650    }
4651
4652    #[test]
4653    fn structural_monotonicity_preserves_sparse_row_patterns() {
4654        let age_entry = array![1.0_f64, 1.5_f64];
4655        let age_exit = array![2.0_f64, 2.5_f64];
4656        let event_target = array![1u8, 1u8];
4657        let event_competing = array![0u8, 0u8];
4658        let sampleweight = array![1.0, 1.0];
4659        let x_entry = array![[0.0, 0.0], [0.0, 0.0]];
4660        let x_exit = array![[0.4, 0.2], [0.6, 0.3]];
4661        let x_derivative = array![[1.0, 0.0], [1.0, 0.5]];
4662
4663        let mut model = survival_model(
4664            survival_inputs(
4665                &age_entry,
4666                &age_exit,
4667                &event_target,
4668                &event_competing,
4669                &sampleweight,
4670                &x_entry,
4671                &x_exit,
4672                &x_derivative,
4673            ),
4674            PenaltyBlocks::new(Vec::new()),
4675            SurvivalMonotonicityPenalty { tolerance: 0.0 },
4676            SurvivalSpec::Net,
4677        )
4678        .expect("construct structural survival model");
4679        model
4680            .set_structural_monotonicity(true, 2)
4681            .expect("enable structural monotonicity");
4682
4683        let constraints = model
4684            .monotonicity_linear_constraints()
4685            .expect("structural derivative constraints");
4686
4687        assert_eq!(constraints.a.nrows(), 2);
4688        assert_eq!(constraints.a.row(0).to_vec(), vec![1.0, 0.0]);
4689        assert_eq!(constraints.a.row(1).to_vec(), vec![0.0, 1.0]);
4690    }
4691
4692    #[test]
4693    fn update_state_rejects_negative_exit_derivative_for_censoredrows() {
4694        let age_entry = array![1.0_f64];
4695        let age_exit = array![1.1_f64];
4696        let event_target = array![0u8];
4697        let event_competing = array![0u8];
4698        let sampleweight = array![1.0];
4699        let x_entry = array![[0.0]];
4700        let x_exit = array![[0.0]];
4701        let x_derivative = array![[-1.0]];
4702        let penalties = PenaltyBlocks::new(Vec::new());
4703        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4704        let model = survival_model(
4705            survival_inputs(
4706                &age_entry,
4707                &age_exit,
4708                &event_target,
4709                &event_competing,
4710                &sampleweight,
4711                &x_entry,
4712                &x_exit,
4713                &x_derivative,
4714            ),
4715            penalties,
4716            mono,
4717            SurvivalSpec::Net,
4718        )
4719        .expect("construct censored survival model");
4720
4721        let err = model
4722            .update_state(&array![1.0])
4723            .expect_err("censored row should still enforce monotonic derivative");
4724        assert!(
4725            matches!(err, EstimationError::ParameterConstraintViolation(_)),
4726            "unexpected error: {err:?}"
4727        );
4728    }
4729
4730    fn crude_risk_quadrature_error(
4731        cumulative_entry: f64,
4732        cumulative_exit: f64,
4733        hazard_exit: f64,
4734    ) -> SurvivalError {
4735        calculate_crude_risk_quadrature(
4736            1.0,
4737            2.0,
4738            &[],
4739            0.4,
4740            0.2,
4741            array![1.0].view(),
4742            array![1.0].view(),
4743            |_, design_d, deriv_d, design_m| {
4744                design_d[0] = 1.0;
4745                deriv_d[0] = 0.0;
4746                design_m[0] = 1.0;
4747                Ok((cumulative_entry, cumulative_exit, hazard_exit))
4748            },
4749        )
4750        .expect_err("invalid hazards should fail")
4751    }
4752
4753    #[test]
4754    fn crude_risk_quadrature_rejects_decreasing_cumulative_hazard() {
4755        let err = crude_risk_quadrature_error(0.1, 0.3, 0.25);
4756        assert!(matches!(err, SurvivalError::NonMonotoneCumulativeHazard));
4757    }
4758
4759    #[test]
4760    fn crude_risk_quadrature_rejects_nonpositive_instantaneous_hazard() {
4761        let err = crude_risk_quadrature_error(0.0, 0.4, 0.25);
4762        assert!(matches!(err, SurvivalError::NonPositiveHazard));
4763    }
4764
4765    #[test]
4766    fn laml_no_penalties_matches_documentedobjective() {
4767        let age_entry = array![40.0, 45.0, 50.0, 55.0];
4768        let age_exit = array![44.0, 49.0, 54.0, 59.0];
4769        let event_target = array![1u8, 0u8, 1u8, 0u8];
4770        let event_competing = Array1::<u8>::zeros(4);
4771        let sampleweight = Array1::ones(4);
4772        let x_entry = array![
4773            [1.0, -0.2, 0.04],
4774            [1.0, -0.1, 0.01],
4775            [1.0, 0.0, 0.0],
4776            [1.0, 0.1, 0.01]
4777        ];
4778        let x_exit = array![
4779            [1.0, -0.12, 0.0144],
4780            [1.0, -0.02, 0.0004],
4781            [1.0, 0.08, 0.0064],
4782            [1.0, 0.18, 0.0324]
4783        ];
4784        let x_derivative = array![
4785            [0.0, 0.02, 0.001],
4786            [0.0, 0.02, 0.001],
4787            [0.0, 0.02, 0.001],
4788            [0.0, 0.02, 0.001]
4789        ];
4790        let penalties = PenaltyBlocks::new(Vec::new());
4791        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4792        let beta = array![-2.0, 0.7, 0.2];
4793
4794        let model = survival_model(
4795            survival_inputs(
4796                &age_entry,
4797                &age_exit,
4798                &event_target,
4799                &event_competing,
4800                &sampleweight,
4801                &x_entry,
4802                &x_exit,
4803                &x_derivative,
4804            ),
4805            penalties,
4806            mono,
4807            SurvivalSpec::Net,
4808        )
4809        .expect("construct survival model");
4810
4811        let state = model.update_state(&beta).expect("state at beta");
4812        let rho = Array1::from_iter(
4813            model
4814                .penalties
4815                .blocks
4816                .iter()
4817                .filter(|b| b.lambda > 0.0)
4818                .map(|b| b.lambda.ln()),
4819        );
4820        let (obj, grad) = model
4821            .unified_lamlobjective_and_rhogradient(&beta, &state, &rho)
4822            .expect("laml objective for no-penalty model");
4823
4824        let h_dense = state.hessian.to_dense();
4825        let logdet_h: f64 = {
4826            use gam_solve::estimate::reml::reml_outer_engine::{spectral_epsilon, spectral_regularize};
4827            use gam_linalg::faer_ndarray::FaerEigh;
4828            let (evals, _) = h_dense.eigh(faer::Side::Lower).expect("eigh");
4829            let eps = spectral_epsilon(evals.as_slice().unwrap());
4830            evals
4831                .iter()
4832                .map(|&sigma| spectral_regularize(sigma, eps).ln())
4833                .sum()
4834        };
4835        let expected = 0.5 * state.deviance + state.penalty_term + 0.5 * logdet_h;
4836
4837        assert_eq!(grad.len(), 0);
4838        assert!(
4839            (obj - expected).abs() < 1e-10,
4840            "no-penalty LAML objective mismatch: obj={} expected={}",
4841            obj,
4842            expected
4843        );
4844    }
4845
4846    #[test]
4847    fn monotonicity_constraints_collapse_positive_collinearrows() {
4848        let a = array![[0.0, 0.5, 0.0], [0.0, 0.25, 0.0], [0.0, 0.125, 0.0]];
4849        let b = array![1e-8, 1e-8, 1e-8];
4850
4851        let compressed = compress_positive_collinear_constraints(&a, &b);
4852
4853        assert_eq!(compressed.a.nrows(), 1);
4854        assert_eq!(compressed.a.ncols(), 3);
4855        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4856        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4857        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4858        assert!((compressed.b[0] - 8e-8).abs() <= 1e-18);
4859    }
4860
4861    #[test]
4862    fn monotonicity_constraints_preserve_distinct_directions() {
4863        let a = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]];
4864        let b = array![0.2, 0.3, 0.1];
4865
4866        let compressed = compress_positive_collinear_constraints(&a, &b);
4867
4868        assert_eq!(compressed.a.nrows(), 2);
4869        let mut saw_x = false;
4870        let mut saw_y = false;
4871        for i in 0..compressed.a.nrows() {
4872            if (compressed.a[[i, 0]] - 1.0).abs() <= 1e-12 && compressed.a[[i, 1]].abs() <= 1e-12 {
4873                saw_x = true;
4874                assert!((compressed.b[i] - 0.2).abs() <= 1e-12);
4875            }
4876            if compressed.a[[i, 0]].abs() <= 1e-12 && (compressed.a[[i, 1]] - 1.0).abs() <= 1e-12 {
4877                saw_y = true;
4878                assert!((compressed.b[i] - 0.3).abs() <= 1e-12);
4879            }
4880        }
4881        assert!(saw_x);
4882        assert!(saw_y);
4883    }
4884
4885    #[test]
4886    fn monotonicity_constraints_cluster_near_collinearrows() {
4887        let a = array![
4888            [0.0, 0.5, 0.0],
4889            [0.0, 0.50000000003, 0.0],
4890            [0.0, 0.49999999997, 0.0]
4891        ];
4892        let b = array![1e-8, 1.00000000005e-8, 0.99999999995e-8];
4893
4894        let compressed = compress_positive_collinear_constraints(&a, &b);
4895
4896        assert_eq!(compressed.a.nrows(), 1);
4897        assert_eq!(compressed.a.ncols(), 3);
4898        assert!(compressed.a[[0, 0]].abs() <= 1e-12);
4899        assert!((compressed.a[[0, 1]] - 1.0).abs() <= 1e-12);
4900        assert!(compressed.a[[0, 2]].abs() <= 1e-12);
4901        assert!((compressed.b[0] - 2.0e-8).abs() <= 1e-18);
4902    }
4903
4904    #[test]
4905    fn monotonicity_constraints_cluster_spline_like_near_duplicates() {
4906        let a = array![
4907            [0.0, 0.401, 0.302, 0.197],
4908            [0.0, 0.40100000003, 0.30199999998, 0.19700000001],
4909            [0.0, 0.40099999997, 0.30200000002, 0.19699999999],
4910            [0.0, 0.125, 0.500, 0.375]
4911        ];
4912        let b = array![2.0e-8, 2.00000000004e-8, 1.99999999996e-8, 3.0e-8];
4913
4914        let compressed = compress_positive_collinear_constraints(&a, &b);
4915
4916        assert_eq!(compressed.a.nrows(), 2);
4917        let mut clustered_face = false;
4918        let mut distinct_face = false;
4919        for i in 0..compressed.a.nrows() {
4920            let row = compressed.a.row(i);
4921            if row[1] > 0.99 && row[2] > 0.7 && row[3] > 0.49 {
4922                clustered_face = true;
4923                assert!((compressed.b[i] - (2.0e-8 / 0.401)).abs() <= 1e-12);
4924            } else {
4925                distinct_face = true;
4926                assert!((row[1] - 0.25).abs() <= 1e-12);
4927                assert!((row[2] - 1.0).abs() <= 1e-12);
4928                assert!((row[3] - 0.75).abs() <= 1e-12);
4929                assert!((compressed.b[i] - 6.0e-8).abs() <= 1e-18);
4930            }
4931        }
4932        assert!(clustered_face);
4933        assert!(distinct_face);
4934    }
4935
4936    #[test]
4937    fn linear_time_monotonicity_constraints_reduce_to_single_halfspace() {
4938        let age_entry = array![1.0_f64, 1.0, 1.0];
4939        let age_exit = array![2.0_f64, 4.0, 8.0];
4940        let event_target = array![0u8, 1u8, 0u8];
4941        let event_competing = array![0u8, 0u8, 0u8];
4942        let sampleweight = array![1.0, 1.0, 1.0];
4943        let x_entry = array![
4944            [1.0, age_entry[0].ln()],
4945            [1.0, age_entry[1].ln()],
4946            [1.0, age_entry[2].ln()]
4947        ];
4948        let x_exit = array![
4949            [1.0, age_exit[0].ln()],
4950            [1.0, age_exit[1].ln()],
4951            [1.0, age_exit[2].ln()]
4952        ];
4953        let x_derivative = array![[0.0, 0.5], [0.0, 0.25], [0.0, 0.125]];
4954        let penalties = PenaltyBlocks::new(Vec::new());
4955        let mono = SurvivalMonotonicityPenalty { tolerance: 1e-8 };
4956
4957        let collocation_offsets = Array1::zeros(x_derivative.nrows());
4958        let mut inputs = survival_inputs(
4959            &age_entry,
4960            &age_exit,
4961            &event_target,
4962            &event_competing,
4963            &sampleweight,
4964            &x_entry,
4965            &x_exit,
4966            &x_derivative,
4967        );
4968        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
4969        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
4970
4971        let model = survival_model(inputs, penalties, mono, SurvivalSpec::Net)
4972            .expect("construct linear survival model");
4973
4974        let constraints = model
4975            .monotonicity_linear_constraints()
4976            .expect("monotonicity constraints");
4977        assert_eq!(constraints.a.nrows(), 1);
4978        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
4979        assert!((constraints.b[0] - 8e-8).abs() <= 1e-12);
4980    }
4981
4982    #[test]
4983    fn monotonicity_constraints_skip_numericallyzerorows() {
4984        let age_entry = array![1.0_f64, 1.0, 1.0];
4985        let age_exit = array![2.0_f64, 3.0, 4.0];
4986        let event_target = array![0u8, 0u8, 0u8];
4987        let event_competing = array![0u8, 0u8, 0u8];
4988        let sampleweight = array![1.0, 1.0, 1.0];
4989        let x_entry = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
4990        let x_exit = x_entry.clone();
4991        let x_derivative = array![[0.0, 0.0], [0.0, 1e-16], [0.0, 0.25]];
4992
4993        let collocation_offsets = Array1::zeros(x_derivative.nrows());
4994        let mut inputs = survival_inputs(
4995            &age_entry,
4996            &age_exit,
4997            &event_target,
4998            &event_competing,
4999            &sampleweight,
5000            &x_entry,
5001            &x_exit,
5002            &x_derivative,
5003        );
5004        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5005        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5006
5007        let model = survival_model(
5008            inputs,
5009            PenaltyBlocks::new(Vec::new()),
5010            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5011            SurvivalSpec::Net,
5012        )
5013        .expect("construct survival model");
5014
5015        let constraints = model
5016            .monotonicity_linear_constraints()
5017            .expect("nonzero derivative row should remain");
5018        assert_eq!(constraints.a.nrows(), 1);
5019        assert!((constraints.a[[0, 1]] - 1.0).abs() <= 1e-12);
5020        assert!(constraints.b[0].abs() <= 1e-18);
5021    }
5022
5023    #[test]
5024    fn censoredrows_allowzero_boundary_derivative() {
5025        let age_entry = array![1.0_f64];
5026        let age_exit = array![2.0_f64];
5027        let event_target = array![0u8];
5028        let event_competing = array![0u8];
5029        let sampleweight = array![1.0];
5030        let x_entry = array![[0.0]];
5031        let x_exit = array![[0.0]];
5032        let x_derivative = array![[1.0]];
5033
5034        let model = survival_model(
5035            survival_inputs(
5036                &age_entry,
5037                &age_exit,
5038                &event_target,
5039                &event_competing,
5040                &sampleweight,
5041                &x_entry,
5042                &x_exit,
5043                &x_derivative,
5044            ),
5045            PenaltyBlocks::new(Vec::new()),
5046            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5047            SurvivalSpec::Net,
5048        )
5049        .expect("construct censored survival model");
5050
5051        let state = model
5052            .update_state(&array![0.0])
5053            .expect("censored boundary derivative should remain feasible with zero tolerance");
5054        assert!(state.deviance.is_finite());
5055    }
5056
5057    #[test]
5058    fn eventrows_keep_positive_derivative_constraint() {
5059        let age_entry = array![1.0_f64, 1.0];
5060        let age_exit = array![2.0_f64, 4.0];
5061        let event_target = array![0u8, 1u8];
5062        let event_competing = array![0u8, 0u8];
5063        let sampleweight = array![1.0, 1.0];
5064        let x_entry = array![[0.0], [0.0]];
5065        let x_exit = array![[0.0], [0.0]];
5066        let x_derivative = array![[0.5], [0.25]];
5067
5068        let collocation_offsets = Array1::zeros(x_derivative.nrows());
5069        let mut inputs = survival_inputs(
5070            &age_entry,
5071            &age_exit,
5072            &event_target,
5073            &event_competing,
5074            &sampleweight,
5075            &x_entry,
5076            &x_exit,
5077            &x_derivative,
5078        );
5079        inputs.monotonicity_constraint_rows = Some(x_derivative.view());
5080        inputs.monotonicity_constraint_offsets = Some(collocation_offsets.view());
5081
5082        let model = survival_model(
5083            inputs,
5084            PenaltyBlocks::new(Vec::new()),
5085            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5086            SurvivalSpec::Net,
5087        )
5088        .expect("construct mixed survival model");
5089
5090        let constraints = model
5091            .monotonicity_linear_constraints()
5092            .expect("event row should induce positive lower bound");
5093        assert_eq!(constraints.a.nrows(), 1);
5094        assert!((constraints.a[[0, 0]] - 1.0).abs() <= 1e-12);
5095        assert!((constraints.b[0] - 4e-8).abs() <= 1e-18);
5096    }
5097
5098    #[test]
5099    fn structural_monotonicity_clamps_tiny_negative_roundoff() {
5100        let age_entry = array![1.0_f64];
5101        let age_exit = array![2.0_f64];
5102        let event_target = array![1u8];
5103        let event_competing = array![0u8];
5104        let sampleweight = array![1.0];
5105        let x_entry = array![[0.0]];
5106        let x_exit = array![[0.0]];
5107        let x_derivative = array![[1.0]];
5108        let mut model = survival_model(
5109            survival_inputs(
5110                &age_entry,
5111                &age_exit,
5112                &event_target,
5113                &event_competing,
5114                &sampleweight,
5115                &x_entry,
5116                &x_exit,
5117                &x_derivative,
5118            ),
5119            PenaltyBlocks::new(Vec::new()),
5120            SurvivalMonotonicityPenalty { tolerance: 1e-8 },
5121            SurvivalSpec::Net,
5122        )
5123        .expect("construct survival model");
5124        model
5125            .set_structural_monotonicity(true, 1)
5126            .expect("enable structural monotonicity");
5127
5128        let state = model
5129            .update_state(&array![-1e-8])
5130            .expect("tiny structural roundoff should be clamped");
5131        assert!(state.deviance.is_finite());
5132    }
5133
5134    #[test]
5135    fn compressed_monotonicity_constraints_preserve_uncompressed_feasible_region() {
5136        let uncompressed_constraints = LinearInequalityConstraints {
5137            a: array![
5138                [0.0, 0.5, 0.0],
5139                [0.0, 1.0 / 3.0, 0.0],
5140                [0.0, 0.2, 0.0],
5141                [0.0, 0.125, 0.0]
5142            ],
5143            b: Array1::from_elem(4, 1e-8),
5144        };
5145        let compressed_constraints = compress_positive_collinear_constraints(
5146            &uncompressed_constraints.a,
5147            &uncompressed_constraints.b,
5148        );
5149
5150        let candidates = [
5151            array![0.0, 1e-9, 0.0],
5152            array![0.0, 4e-8, 0.0],
5153            array![0.0, 8e-8, 0.0],
5154            array![0.0, 2e-7, 1.5],
5155        ];
5156        for beta in candidates {
5157            let uncompressed_ok = (0..uncompressed_constraints.a.nrows()).all(|i| {
5158                uncompressed_constraints.a.row(i).dot(&beta) >= uncompressed_constraints.b[i]
5159            });
5160            let compressed_ok = (0..compressed_constraints.a.nrows())
5161                .all(|i| compressed_constraints.a.row(i).dot(&beta) >= compressed_constraints.b[i]);
5162            assert_eq!(compressed_ok, uncompressed_ok);
5163        }
5164    }
5165
5166    #[test]
5167    fn exact_survival_derivatives_are_time_unit_invariant_up_to_constant_shift() {
5168        let age_entry = array![10.0_f64, 20.0, 25.0];
5169        let age_exit = array![15.0_f64, 30.0, 40.0];
5170        let event_target = array![1u8, 0u8, 1u8];
5171        let event_competing = array![0u8, 0u8, 0u8];
5172        let sampleweight = array![1.0, 2.0, 0.5];
5173        let x_entry = array![[0.1, 0.2, 1.0], [0.3, 0.4, 1.0], [0.2, 0.6, 1.0]];
5174        let x_exit = array![[0.2, 0.3, 1.0], [0.5, 0.7, 1.0], [0.4, 0.8, 1.0]];
5175        let x_derivative = array![[0.04, 0.02, 0.0], [0.03, 0.01, 0.0], [0.02, 0.03, 0.0]];
5176        let beta = array![0.8, 1.1, -0.2];
5177
5178        let base_model = survival_model(
5179            survival_inputs(
5180                &age_entry,
5181                &age_exit,
5182                &event_target,
5183                &event_competing,
5184                &sampleweight,
5185                &x_entry,
5186                &x_exit,
5187                &x_derivative,
5188            ),
5189            PenaltyBlocks::new(Vec::new()),
5190            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5191            SurvivalSpec::Net,
5192        )
5193        .expect("construct base survival model");
5194        let base_state = base_model
5195            .update_state(&beta)
5196            .expect("evaluate base survival state");
5197
5198        let time_scale = 365.25;
5199        let scaled_age_entry = age_entry.mapv(|v| v * time_scale);
5200        let scaled_age_exit = age_exit.mapv(|v| v * time_scale);
5201        let scaled_x_derivative = x_derivative.mapv(|v| v / time_scale);
5202        let scaled_model = survival_model(
5203            survival_inputs(
5204                &scaled_age_entry,
5205                &scaled_age_exit,
5206                &event_target,
5207                &event_competing,
5208                &sampleweight,
5209                &x_entry,
5210                &x_exit,
5211                &scaled_x_derivative,
5212            ),
5213            PenaltyBlocks::new(Vec::new()),
5214            SurvivalMonotonicityPenalty { tolerance: 0.0 },
5215            SurvivalSpec::Net,
5216        )
5217        .expect("construct scaled survival model");
5218        let scaled_state = scaled_model
5219            .update_state(&beta)
5220            .expect("evaluate scaled survival state");
5221
5222        let weighted_events = sampleweight
5223            .iter()
5224            .zip(event_target.iter())
5225            .map(|(w, d)| *w * f64::from(*d))
5226            .sum::<f64>();
5227        let expected_deviance_shift = 2.0 * weighted_events * time_scale.ln();
5228        assert!(
5229            (scaled_state.deviance - base_state.deviance - expected_deviance_shift).abs() <= 1e-10,
5230            "deviance shift mismatch: scaled={} base={} expected_shift={expected_deviance_shift}",
5231            scaled_state.deviance,
5232            base_state.deviance
5233        );
5234
5235        for j in 0..beta.len() {
5236            assert!(
5237                (scaled_state.gradient[j] - base_state.gradient[j]).abs() <= 1e-12,
5238                "gradient mismatch at j={j}: scaled={} base={}",
5239                scaled_state.gradient[j],
5240                base_state.gradient[j]
5241            );
5242        }
5243
5244        let base_hessian = base_state.hessian.to_dense();
5245        let scaled_hessian = scaled_state.hessian.to_dense();
5246        for r in 0..beta.len() {
5247            for c in 0..beta.len() {
5248                assert!(
5249                    (scaled_hessian[[r, c]] - base_hessian[[r, c]]).abs() <= 1e-12,
5250                    "hessian mismatch at ({r},{c}): scaled={} base={}",
5251                    scaled_hessian[[r, c]],
5252                    base_hessian[[r, c]]
5253                );
5254            }
5255        }
5256    }
5257}