Skip to main content

gam_models/survival/
construction.rs

1//! Survival model construction helpers.
2//!
3//! Types and functions for building survival model components:
4//! - Baseline hazard targets (Weibull, Gompertz, Gompertz-Makeham)
5//! - Time basis construction (I-spline on log-time)
6//! - Baseline offset computation
7//! - Time wiggle construction
8//!
9//! These are the building blocks a library consumer needs to construct
10//! a `FitRequest::SurvivalLocationScale` without going through the CLI.
11
12use gam_terms::basis::{
13    BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14    BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15    create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18    DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19    SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24    WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25    monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::outer_subsample::RowSet;
31use gam_problem::{InverseLink, StandardLink};
32use ndarray::{Array1, Array2, array, s};
33use rayon::prelude::*;
34
35// ---------------------------------------------------------------------------
36// Typed error
37// ---------------------------------------------------------------------------
38
/// Structured failure surface for survival-model construction helpers
/// (`parse_*`, baseline-config builders, time-basis construction).
///
/// Every variant carries a free-form `reason: String` payload. The
/// `impl_reason_error_boilerplate!` invocation that follows generates the
/// shared plumbing. Call sites in this module rely on it providing
/// `From<SurvivalConstructionError> for String` (they return
/// `Err(SurvivalConstructionError::… { reason }.into())` from functions whose
/// error type is `String`). NOTE(review): the macro body is defined elsewhere.
/// Confirm there that `Display` / the `String` conversion emit `reason`
/// verbatim.
///
/// The public CLI-input parsers (`parse_survival_distribution`,
/// `parse_survival_likelihood_mode`, `parse_survival_baseline_config`)
/// keep their `Result<_, String>` signatures, since a string is the natural
/// failure type for free-form user input. Internally they build one of these
/// variants and convert it. A few call sites in this module (e.g. the
/// optimizer-run failure path) still format `String` errors directly.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// User-supplied configuration is malformed or out of range (knot
    /// counts, anchor offsets, derivative guards, ranks).
    InvalidConfig { reason: String },
    /// A required column or block of metadata is absent (e.g. saved
    /// survival ispline keep_cols, baseline target on a saved fit).
    MissingColumn { reason: String },
    /// Per-row / per-column shape disagreement (entry/exit lengths,
    /// penalty rank vs basis width, basis vs coefficient counts).
    IncompatibleDimensions { reason: String },
    /// Numeric / domain rejection: non-finite ratios, non-positive
    /// survival times, monotonicity violations, ispline-derivative
    /// underflow.
    DataValidationFailed { reason: String },
    /// Underlying basis / penalty builder rejected the construction
    /// request (invalid spline order, ispline keep_cols out of range,
    /// internal empty ispline time basis).
    BasisConstructionFailed { reason: String },
    /// User-named distribution / likelihood-mode / baseline target /
    /// time-basis kind is not one we recognise.
    UnsupportedDistribution { reason: String },
}
74
// Generates the shared "reason"-payload boilerplate for every listed variant
// of `SurvivalConstructionError`. NOTE(review): the macro is defined outside
// this file. Callers in this module depend on it providing
// `From<SurvivalConstructionError> for String` (used via `.into()`).
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
85
86// ---------------------------------------------------------------------------
87// Types
88// ---------------------------------------------------------------------------
89
/// Parametric baseline the survival time effect is anchored to.
///
/// Selected on the CLI with `--baseline-target`. The accepted names
/// (`linear`, `weibull`, `gompertz`, `gompertz-makeham`) are the ones
/// returned by `survival_baseline_targetname`. In the variant docs below,
/// `H0(t)` is the parametric cumulative baseline hazard.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// No additional parametric target:
    /// eta_target(t) = 0, so regularized model defaults to linear log-cumulative
    /// hazard from the existing time basis.
    Linear,
    /// Parametric target: Weibull baseline (parameters `scale`, `shape`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Weibull,
    /// Parametric target: Gompertz baseline (parameters `rate`, `shape`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Gompertz,
    /// Parametric target: Gompertz-Makeham baseline (parameters `rate`,
    /// `shape`, `makeham`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    GompertzMakeham,
}
112
/// Baseline-target configuration: the chosen target plus its parameters.
///
/// When produced by `parse_survival_baseline_config`, the populated fields
/// depend on `target`:
/// * `Linear`: none;
/// * `Weibull`: `scale` and `shape` (both finite and > 0);
/// * `Gompertz`: `rate` (finite, > 0) and `shape` (finite);
/// * `GompertzMakeham`: `rate` (finite, > 0), `shape` (finite) and
///   `makeham` (finite, > 0).
///
/// Fields are public, so a hand-built value is not guaranteed to satisfy
/// these invariants.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    /// Which parametric baseline these parameters belong to.
    pub target: SurvivalBaselineTarget,
    /// Weibull scale.
    pub scale: Option<f64>,
    /// Weibull or Gompertz(-Makeham) shape.
    pub shape: Option<f64>,
    /// Gompertz(-Makeham) rate.
    pub rate: Option<f64>,
    /// Gompertz-Makeham additive Makeham term.
    pub makeham: Option<f64>,
}
121
/// Specification of the basis used for the survival time effect.
///
/// The per-variant fields (`degree`, `knots`, `keep_cols`, `smooth_lambda`)
/// correspond by name to the optional metadata fields of
/// [`SavedSurvivalTimeBasis`]. NOTE(review): the variant ↔ `basisname`
/// mapping lives outside this view.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No time basis. NOTE(review): downstream handling is not visible here.
    None,
    /// Variant that carries no spline parameters (degree / knots / smoothing).
    Linear,
    /// B-spline time basis.
    BSpline {
        /// Spline degree.
        degree: usize,
        /// Knot vector.
        knots: Array1<f64>,
        /// Smoothing-penalty weight. Presumably only a starting value for the
        /// smoothing search (cf. `SURVIVAL_TIME_SMOOTH_LAMBDA_SEED`); confirm
        /// at the construction site.
        smooth_lambda: f64,
    },
    /// I-spline value rows on the `log(t)` axis with non-negative
    /// coefficients (`γ ≥ 0`) enforcing structural monotonicity of
    /// `q(t) = I_basis(log t) · γ`. This replaces the row-wise
    /// `D β + o ≥ guard` derivative-guard constraints the marginal-slope
    /// family previously relied on.
    ///
    /// The design builder is the `SurvivalTimeBasisConfig::ISpline` arm of
    /// `_build_time_block` (NOTE(review): outside this view). It exposes:
    ///
    /// * `x_entry_time` / `x_exit_time` — I-spline value rows on the
    ///   `log(t)` axis. Non-negative entries plus `γ ≥ 0` give a
    ///   monotone-non-decreasing `q(t)`, the structural property the
    ///   marginal-slope family needs.
    /// * `x_derivative_time` — right-cumulative B-spline-derivative on
    ///   `log(t)` scaled by `1/t`, again non-negative with `γ ≥ 0`, so
    ///   `q'(t) ≥ 0` pointwise. The `derivative_guard` constant is added
    ///   externally by [`add_survival_time_derivative_guard_offset`],
    ///   leaving the derivative guarantee `q'(t) ≥ guard` exact.
    /// * 2nd-difference penalty on the underlying degree-`(k+1)` B-spline
    ///   coefficients, filtered through `keep_cols` for identifiability.
    ///
    /// `TimeBlockInput::time_monotonicity` tells the consuming family how
    /// monotonicity is enforced. The marginal-slope construction site sets it
    /// to
    /// [`crate::survival::location_scale::TimeBlockMonotonicity::StructuralISpline`]
    /// so the family skips row-wise `D β + o ≥ guard` constraint
    /// generation and treats `γ ≥ 0` as the sole derivative-guard
    /// mechanism. The universal `validate_time_qd1_feasible` safety net
    /// runs regardless.
    ///
    /// An earlier iteration proposed a separate C-spline antiderivative
    /// parameterization that put `q'(t)` in the I-spline space and `q(t)`
    /// in the integral-of-I-spline space. That was mathematically
    /// equivalent but a strictly worse fit for the codebase (extra basis
    /// degree, an extra antiderivative builder, an extra identifiability
    /// path, an extra penalty); it was removed in favor of the canonical
    /// I-spline-value path here.
    ISpline {
        /// I-spline degree.
        degree: usize,
        /// Knot vector on the `log(t)` axis.
        knots: Array1<f64>,
        /// Basis columns retained for identifiability.
        keep_cols: Vec<usize>,
        /// Smoothing-penalty weight (see the `BSpline` note).
        smooth_lambda: f64,
    },
}
175
/// Persistable snapshot of the time-basis state used by a survival fit.
///
/// Every survival family routes through [`SurvivalTimeBuildOutput`] during
/// the fit, but the FFI save path needs only the metadata, not the full
/// design matrices. This struct is the single source of truth that flows
/// from the workflow-level basis construction, through the family-specific
/// fit result, into the saved-model payload via
/// [`crate::inference::model::FittedModelPayload::apply_survival_time_basis`].
///
/// Threading this snapshot end-to-end eliminates an earlier bug pattern in
/// which each FFI builder had to reconstruct the metadata from `fit_config`
/// plus the formula. That risked silent drift, and one builder forgetting to
/// do so caused the marginal-slope save→load break.
///
/// Built with [`SavedSurvivalTimeBasis::from_build`], which copies the
/// metadata fields of the same name from a [`SurvivalTimeBuildOutput`].
#[derive(Clone, Debug, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Basis identifier (e.g. `ispline`).
    pub basisname: String,
    /// Spline degree, if the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, if the basis has one.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns, if the basis filters columns.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing-penalty weight, if recorded.
    pub smooth_lambda: Option<f64>,
    /// Entry anchor used during the fit.
    pub anchor: f64,
}
198
199impl SavedSurvivalTimeBasis {
200    /// Build a snapshot from the realised time-basis state and the entry
201    /// anchor that was used during the fit.
202    pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
203        Self {
204            basisname: build.basisname.clone(),
205            degree: build.degree,
206            knots: build.knots.clone(),
207            keep_cols: build.keep_cols.clone(),
208            smooth_lambda: build.smooth_lambda,
209            anchor,
210        }
211    }
212}
213
/// Realised survival time basis: design matrices, penalties, and the
/// metadata that `SavedSurvivalTimeBasis::from_build` persists.
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time-basis design rows evaluated at the entry times.
    pub x_entry_time: DesignMatrix,
    /// Time-basis design rows evaluated at the exit times.
    pub x_exit_time: DesignMatrix,
    /// Time-derivative design rows. For the I-spline basis these are
    /// non-negative (see `SurvivalTimeBasisConfig::ISpline`).
    pub x_derivative_time: DesignMatrix,
    /// Smoothing penalty matrices for the time coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Basis identifier (e.g. `ispline`).
    pub basisname: String,
    /// Spline degree; presumably `None` for bases without one.
    pub degree: Option<usize>,
    /// Knot vector; presumably `None` for bases without knots.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns; presumably `None` when no filtering applies.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing-penalty weight, if the basis carries one.
    pub smooth_lambda: Option<f64>,
}
228
/// Smallest survival time the construction helpers work with.
///
/// `normalize_survival_time_pair` clamps entry times up to this value and
/// raises too-small exit times to sit just above the entry.
/// `positive_survival_time_seed` uses it as the lower bound of its
/// mean-time seed.
pub const SURVIVAL_TIME_FLOOR: f64 = 1e-9;

/// Entry ages above this value mark genuine left truncation (delayed entry):
/// the row's cumulative-hazard interval starts at a positive left-tail time
/// rather than at the time origin. Kept in lockstep with the working-model's
/// `ENTRY_AT_ORIGIN_THRESHOLD` (defined elsewhere) so that the "this row has
/// an entry interval" and "the data is left-truncated" decisions agree.
pub const SURVIVAL_DELAYED_ENTRY_THRESHOLD: f64 = 1e-8;

/// Seed smoothing penalty `λ` used when a survival time basis is reconstructed
/// from a build (or saved model) that did not carry an explicit `smooth_lambda`.
/// It is only an initial value for the REML smoothing search, not a fixed
/// policy. A small positive seed keeps the baseline spline lightly regularized
/// at the start, so the outer optimizer begins from a well-conditioned point
/// and then adapts `λ` to the data. Kept in one place so the b-spline and
/// i-spline reconstruction paths cannot drift apart.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1e-2;

/// Default initial Gompertz / Gompertz-Makeham shape parameter when the user
/// does not supply `--baseline-shape`. The Gompertz hazard is
/// `h(t) = rate · exp(shape · t)`. A near-zero shape seeds the baseline at an
/// almost-flat (exponential-like) hazard, letting the fit grow the
/// age-acceleration term from the data rather than committing to a strong
/// curvature up front. Shared by the parse and fit-seed paths so both start
/// from the same neutral shape.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
255
/// Survival likelihood family selected with `--survival-likelihood`.
///
/// CLI spellings, as accepted by `parse_survival_likelihood_mode` and
/// returned by `survival_likelihood_modename`: `transformation`, `weibull`,
/// `location-scale`, `marginal-slope`, `latent`, `latent-binary`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    Transformation,
    Weibull,
    LocationScale,
    MarginalSlope,
    Latent,
    LatentBinary,
}
265
/// Penalty and basis metadata for a survival time-wiggle block.
///
/// NOTE(review): the builder that fills this struct is outside this view.
/// The field notes below are inferred from the field names and from the
/// parallel fields of `SurvivalTimeBuildOutput`; confirm at the builder.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices for the wiggle coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Presumably the structural nullspace dimension of each entry of
    /// `penalties` (cf. `SurvivalTimeBuildOutput::nullspace_dims`).
    pub nullspace_dims: Vec<usize>,
    /// Spline knot vector.
    pub knots: Array1<f64>,
    /// Spline degree.
    pub degree: usize,
    /// Number of basis columns.
    pub ncols: usize,
}
273
274// ---------------------------------------------------------------------------
275// Time normalization
276// ---------------------------------------------------------------------------
277
278pub fn normalize_survival_time_pair(
279    entry_raw: f64,
280    exit_raw: f64,
281    row_index: usize,
282) -> Result<(f64, f64), String> {
283    if !entry_raw.is_finite() || !exit_raw.is_finite() {
284        return Err(SurvivalConstructionError::DataValidationFailed {
285            reason: format!("non-finite survival times at row {}", row_index + 1),
286        }
287        .into());
288    }
289    if entry_raw < 0.0 || exit_raw < 0.0 {
290        return Err(SurvivalConstructionError::DataValidationFailed {
291            reason: format!("negative survival times at row {}", row_index + 1),
292        }
293        .into());
294    }
295
296    let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
297    let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
298    Ok((entry, exit))
299}
300
301// ---------------------------------------------------------------------------
302// Basis monotonicity helpers
303// ---------------------------------------------------------------------------
304
/// Whether `basisname` names a time basis that is monotone by construction.
///
/// Only the I-spline basis qualifies. The comparison is ASCII
/// case-insensitive and does not trim whitespace.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURAL_BASIS: &str = "ispline";
    STRUCTURAL_BASIS.eq_ignore_ascii_case(basisname)
}
308
309pub fn require_structural_survival_time_basis(
310    basisname: &str,
311    context: &str,
312) -> Result<(), String> {
313    if survival_basis_supports_structural_monotonicity(basisname) {
314        return Ok(());
315    }
316    Err(SurvivalConstructionError::UnsupportedDistribution {
317        reason: format!(
318            "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
319Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
320`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
321Re-run with `--time-basis ispline`."
322        ),
323    }
324    .into())
325}
326
327// ---------------------------------------------------------------------------
328// Baseline config parsing
329// ---------------------------------------------------------------------------
330
331pub fn parse_survival_baseline_config(
332    target_raw: &str,
333    scale: Option<f64>,
334    shape: Option<f64>,
335    rate: Option<f64>,
336    makeham: Option<f64>,
337) -> Result<SurvivalBaselineConfig, String> {
338    let target = match target_raw.to_ascii_lowercase().as_str() {
339        "linear" => SurvivalBaselineTarget::Linear,
340        "weibull" => SurvivalBaselineTarget::Weibull,
341        "gompertz" => SurvivalBaselineTarget::Gompertz,
342        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
343        other => {
344            return Err(SurvivalConstructionError::UnsupportedDistribution {
345                reason: format!(
346                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
347                ),
348            }
349            .into());
350        }
351    };
352
353    match target {
354        SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
355            target,
356            scale: None,
357            shape: None,
358            rate: None,
359            makeham: None,
360        }),
361        SurvivalBaselineTarget::Weibull => {
362            let scale = scale.ok_or_else(|| {
363                "--baseline-target weibull requires --baseline-scale > 0".to_string()
364            })?;
365            let shape = shape.ok_or_else(|| {
366                "--baseline-target weibull requires --baseline-shape > 0".to_string()
367            })?;
368            if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
369                return Err(
370                    "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
371                        .to_string(),
372                );
373            }
374            Ok(SurvivalBaselineConfig {
375                target,
376                scale: Some(scale),
377                shape: Some(shape),
378                rate: None,
379                makeham: None,
380            })
381        }
382        SurvivalBaselineTarget::Gompertz => {
383            let rate = rate.unwrap_or(1.0);
384            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
385            if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
386                return Err(
387                    "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
388                        .to_string(),
389                );
390            }
391            Ok(SurvivalBaselineConfig {
392                target,
393                scale: None,
394                shape: Some(shape),
395                rate: Some(rate),
396                makeham: None,
397            })
398        }
399        SurvivalBaselineTarget::GompertzMakeham => {
400            let rate = rate.unwrap_or(0.5);
401            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
402            let makeham = makeham.unwrap_or(0.5);
403            if !rate.is_finite()
404                || rate <= 0.0
405                || !shape.is_finite()
406                || !makeham.is_finite()
407                || makeham <= 0.0
408            {
409                return Err(
410                    "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
411                        .to_string(),
412                );
413            }
414            Ok(SurvivalBaselineConfig {
415                target,
416                scale: None,
417                shape: Some(shape),
418                rate: Some(rate),
419                makeham: Some(makeham),
420            })
421        }
422    }
423}
424
425// ---------------------------------------------------------------------------
426// Likelihood mode / distribution parsing
427// ---------------------------------------------------------------------------
428
429pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
430    match raw.to_ascii_lowercase().as_str() {
431        "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
432        "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
433        "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
434        "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
435        "latent" => Ok(SurvivalLikelihoodMode::Latent),
436        "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
437        other => Err(SurvivalConstructionError::UnsupportedDistribution {
438            reason: format!(
439                "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
440            ),
441        }
442        .into()),
443    }
444}
445
446pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
447    match mode {
448        SurvivalLikelihoodMode::Transformation => "transformation",
449        SurvivalLikelihoodMode::Weibull => "weibull",
450        SurvivalLikelihoodMode::LocationScale => "location-scale",
451        SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
452        SurvivalLikelihoodMode::Latent => "latent",
453        SurvivalLikelihoodMode::LatentBinary => "latent-binary",
454    }
455}
456
457pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
458    match raw.to_ascii_lowercase().as_str() {
459        "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
460        "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
461        "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
462        other => Err(SurvivalConstructionError::UnsupportedDistribution {
463            reason: format!(
464                "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
465            ),
466        }
467        .into()),
468    }
469}
470
471pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
472    match target {
473        SurvivalBaselineTarget::Linear => "linear",
474        SurvivalBaselineTarget::Weibull => "weibull",
475        SurvivalBaselineTarget::Gompertz => "gompertz",
476        SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
477    }
478}
479
480pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
481    let sum = age_exit
482        .iter()
483        .copied()
484        .filter(|value| value.is_finite() && *value > 0.0)
485        .sum::<f64>();
486    let count = age_exit
487        .iter()
488        .filter(|value| value.is_finite() && **value > 0.0)
489        .count()
490        .max(1);
491    (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
492}
493
494pub fn initial_survival_baseline_config_for_fit(
495    target_raw: &str,
496    scale: Option<f64>,
497    shape: Option<f64>,
498    rate: Option<f64>,
499    makeham: Option<f64>,
500    age_exit: &Array1<f64>,
501) -> Result<SurvivalBaselineConfig, String> {
502    let target = match target_raw.trim().to_ascii_lowercase().as_str() {
503        "linear" => SurvivalBaselineTarget::Linear,
504        "weibull" => SurvivalBaselineTarget::Weibull,
505        "gompertz" => SurvivalBaselineTarget::Gompertz,
506        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
507        other => {
508            return Err(SurvivalConstructionError::UnsupportedDistribution {
509                reason: format!(
510                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
511                ),
512            }
513            .into());
514        }
515    };
516    let time_scale_seed = positive_survival_time_seed(age_exit);
517    let cfg = match target {
518        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
519            target,
520            scale: None,
521            shape: None,
522            rate: None,
523            makeham: None,
524        },
525        SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
526            target,
527            scale: Some(scale.unwrap_or(time_scale_seed)),
528            shape: Some(shape.unwrap_or(1.0)),
529            rate: None,
530            makeham: None,
531        },
532        SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
533            target,
534            scale: None,
535            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
536            rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
537            makeham: None,
538        },
539        SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
540            target,
541            scale: None,
542            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
543            rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
544            makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
545        },
546    };
547    parse_survival_baseline_config(
548        survival_baseline_targetname(cfg.target),
549        cfg.scale,
550        cfg.shape,
551        cfg.rate,
552        cfg.makeham,
553    )
554}
555
556fn survival_baseline_theta_from_config(
557    cfg: &SurvivalBaselineConfig,
558) -> Result<Option<Array1<f64>>, String> {
559    Ok(match cfg.target {
560        SurvivalBaselineTarget::Linear => None,
561        SurvivalBaselineTarget::Weibull => Some(array![
562            cfg.scale
563                .ok_or_else(|| "missing weibull baseline scale".to_string())?
564                .ln(),
565            cfg.shape
566                .ok_or_else(|| "missing weibull baseline shape".to_string())?
567                .ln(),
568        ]),
569        SurvivalBaselineTarget::Gompertz => Some(array![
570            cfg.rate
571                .ok_or_else(|| "missing gompertz baseline rate".to_string())?
572                .ln(),
573            cfg.shape
574                .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
575        ]),
576        SurvivalBaselineTarget::GompertzMakeham => Some(array![
577            cfg.rate
578                .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
579                .ln(),
580            cfg.shape
581                .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
582            cfg.makeham
583                .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
584                .ln(),
585        ]),
586    })
587}
588
589fn survival_baseline_config_from_theta(
590    target: SurvivalBaselineTarget,
591    theta: &Array1<f64>,
592) -> Result<SurvivalBaselineConfig, String> {
593    let cfg = match target {
594        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
595            target,
596            scale: None,
597            shape: None,
598            rate: None,
599            makeham: None,
600        },
601        SurvivalBaselineTarget::Weibull => {
602            if theta.len() != 2 {
603                return Err(SurvivalConstructionError::IncompatibleDimensions {
604                    reason: format!(
605                        "weibull baseline parameter dimension mismatch: expected 2, got {}",
606                        theta.len()
607                    ),
608                }
609                .into());
610            }
611            SurvivalBaselineConfig {
612                target,
613                scale: Some(theta[0].exp()),
614                shape: Some(theta[1].exp()),
615                rate: None,
616                makeham: None,
617            }
618        }
619        SurvivalBaselineTarget::Gompertz => {
620            if theta.len() != 2 {
621                return Err(SurvivalConstructionError::IncompatibleDimensions {
622                    reason: format!(
623                        "gompertz baseline parameter dimension mismatch: expected 2, got {}",
624                        theta.len()
625                    ),
626                }
627                .into());
628            }
629            SurvivalBaselineConfig {
630                target,
631                scale: None,
632                shape: Some(theta[1]),
633                rate: Some(theta[0].exp()),
634                makeham: None,
635            }
636        }
637        SurvivalBaselineTarget::GompertzMakeham => {
638            if theta.len() != 3 {
639                return Err(SurvivalConstructionError::IncompatibleDimensions {
640                    reason: format!(
641                        "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
642                        theta.len()
643                    ),
644                }
645                .into());
646            }
647            SurvivalBaselineConfig {
648                target,
649                scale: None,
650                shape: Some(theta[1]),
651                rate: Some(theta[0].exp()),
652                makeham: Some(theta[2].exp()),
653            }
654        }
655    };
656    parse_survival_baseline_config(
657        survival_baseline_targetname(cfg.target),
658        cfg.scale,
659        cfg.shape,
660        cfg.rate,
661        cfg.makeham,
662    )
663}
664
/// Derivative contract for the shared baseline-θ outer optimizer.
///
/// The two public baseline optimizers (`…_with_gradient_only`,
/// `…_with_gradient`) differ in exactly one respect: how much derivative
/// information the objective closure supplies, and therefore which
/// curvature declaration the `OuterProblem` must advertise. Both contracts
/// declare an analytic gradient, so both route to a gradient-based solver.
///
/// Everything else is identical and lives once in
/// [`run_baseline_theta_optimizer`]: θ↔config conversion, the ±6 box around
/// the θ seed, the single-seed config, and the
/// `run`/convergence/error-formatting boilerplate. This enum selects only
/// the per-contract `OuterProblem` configuration (see `configure`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Cost + analytic gradient, no analytic Hessian. Intended to route to
    /// BFGS, which builds its own quasi-Newton curvature from successive
    /// gradients.
    GradientOnly,
    /// Cost + analytic gradient + analytic Hessian. Intended to route to the
    /// primary second-order outer solver, which may use either the analytic
    /// Hessian or a BFGS approximation depending on the planner.
    GradientHessian,
}
687
688impl BaselineDerivativeContract {
689    /// Apply this contract's derivative declaration, solver class, tolerance,
690    /// and iteration budget to a freshly-constructed `OuterProblem`. The
691    /// bounds, initial ρ, and seed config are contract-independent and applied
692    /// by [`run_baseline_theta_optimizer`].
693    fn configure(
694        self,
695        problem: gam_solve::rho_optimizer::OuterProblem,
696    ) -> gam_solve::rho_optimizer::OuterProblem {
697        use gam_problem::{DeclaredHessianForm, Derivative};
698        match self {
699            // BFGS on a 2–3 dim problem with an exact gradient typically
700            // converges in 5–10 outer evaluations.
701            BaselineDerivativeContract::GradientOnly => problem
702                .with_gradient(Derivative::Analytic)
703                .with_hessian(DeclaredHessianForm::Unavailable)
704                .with_tolerance(1e-4)
705                .with_max_iter(240),
706            BaselineDerivativeContract::GradientHessian => problem
707                .with_gradient(Derivative::Analytic)
708                .with_hessian(DeclaredHessianForm::Either)
709                .with_tolerance(1e-4)
710                .with_max_iter(240),
711        }
712    }
713}
714
/// Shared engine behind the derivative-carrying baseline-config optimizers.
///
/// Owns every step that is identical across the two
/// [`BaselineDerivativeContract`] variants (gradient-only and
/// gradient+Hessian):
/// * config→θ seeding, with an early return for the parameter-free linear
///   target;
/// * the ±6 box around the θ seed;
/// * the single-seed `OuterProblem` skeleton and derivative-contract
///   configuration;
/// * `build_objective` wiring and `run`;
/// * the convergence check with error formatting, and θ→config.
///
/// The contract-specific pieces are the caller-supplied `cost_fn`/`eval_fn`
/// closures and the `contract` that selects the `OuterProblem` derivative
/// declaration. NOTE(review): the callers that build those closures are
/// outside this view.
///
/// # Errors
/// * seeding errors from `survival_baseline_theta_from_config`;
/// * `"{context} failed: …"` when the outer solver itself returns an error;
/// * an `InvalidConfig` "did not converge" error when the solver stops
///   without converging;
/// * validation errors from `survival_baseline_config_from_theta`.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    // Linear target: no free parameters, so there is nothing to optimize.
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // ±6 box around the seed in θ-space. θ holds log-scale parameters except
    // the Gompertz / Gompertz-Makeham shape, which is on its natural scale.
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    // Exactly one start point: the seed itself. NOTE(review): confirm the
    // meaning of `num_auxiliary_trailing: dim` against `SeedConfig`.
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    // The two trailing optional callbacks of `build_objective` are left unset
    // (the last one would return an `EfsEval`).
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    // A non-converged run is treated as an error rather than returning a
    // partially optimized config.
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    survival_baseline_config_from_theta(target, &result.rho)
}
785
786/// Shared engine for the two derivative-carrying baseline-config optimizers.
787///
788/// Both `…_with_gradient_only` and `…_with_gradient` route an objective that
789/// returns a fully-populated [`OuterEval`](gam_problem::OuterEval)
790/// (cost + analytic gradient, optionally + analytic Hessian) for a given
791/// config. Everything downstream of that — the `Rc<RefCell>` sharing that lets
792/// the same user closure back both the `cost_fn` and `eval_fn`, the θ→config
793/// conversion, and deriving the scalar `cost_fn` from the eval result — is
794/// identical, so it lives here once. The contract-specific axis is only which
795/// `HessianResult` the objective embeds, which the wrapper has already encoded
796/// in the returned `OuterEval`, so this helper is contract-agnostic beyond the
797/// `contract` it forwards to [`run_baseline_theta_optimizer`].
798fn run_baseline_theta_optimizer_with_eval<F>(
799    initial: &SurvivalBaselineConfig,
800    context: &str,
801    contract: BaselineDerivativeContract,
802    objective: F,
803) -> Result<SurvivalBaselineConfig, String>
804where
805    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
806{
807    let target = initial.target;
808    let engine_context = context.to_string();
809    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
810    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
811                        theta: &Array1<f64>|
812          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
813        let cfg = survival_baseline_config_from_theta(target, theta)
814            .map_err(crate::model_types::EstimationError::InvalidInput)?;
815        let eval =
816            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
817        if eval.gradient.len() != theta.len() {
818            return Err(crate::model_types::EstimationError::InvalidInput(format!(
819                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
820                eval.gradient.len(),
821                theta.len()
822            )));
823        }
824        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
825            if h.nrows() != theta.len() || h.ncols() != theta.len() {
826                return Err(crate::model_types::EstimationError::InvalidInput(format!(
827                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
828                    h.nrows(),
829                    h.ncols(),
830                    theta.len(),
831                    theta.len()
832                )));
833            }
834        }
835        Ok(eval)
836    };
837    let cost_objective = std::rc::Rc::clone(&objective);
838    let cost_eval = eval_at.clone();
839    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
840        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
841    };
842    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
843    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
844}
845
846/// Gradient-only outer baseline-config optimizer. Thin adapter over
847/// [`run_baseline_theta_optimizer`] under the
848/// [`BaselineDerivativeContract::GradientOnly`] contract, which advertises
849/// `DeclaredHessianForm::Unavailable`, so the planner routes to BFGS and
850/// builds its own quasi-Newton curvature from successive gradient
851/// evaluations. Used by the survival location-scale path which has a
852/// closed-form θ-gradient (`baseline_chain_rule_gradient` /
853/// `marginal_slope_baseline_chain_rule_gradient`) but no native analytic
854/// θ-Hessian; BFGS on a 2–3 dim problem with an exact gradient typically
855/// converges in 5–10 outer evaluations.
856pub fn optimize_survival_baseline_config_with_gradient_only<F>(
857    initial: &SurvivalBaselineConfig,
858    context: &str,
859    mut objective: F,
860) -> Result<SurvivalBaselineConfig, String>
861where
862    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
863{
864    use gam_problem::{HessianResult, OuterEval};
865    run_baseline_theta_optimizer_with_eval(
866        initial,
867        context,
868        BaselineDerivativeContract::GradientOnly,
869        move |cfg| {
870            let (cost, gradient) = objective(cfg)?;
871            Ok(OuterEval {
872                cost,
873                gradient,
874                hessian: HessianResult::Unavailable,
875                inner_beta_hint: None,
876            })
877        },
878    )
879}
880
881/// Gradient + Hessian outer baseline-config optimizer. Thin adapter over
882/// [`run_baseline_theta_optimizer`] under the
883/// [`BaselineDerivativeContract::GradientHessian`] contract, which advertises
884/// an analytic θ-Hessian so the primary second-order outer solver can use it.
885pub fn optimize_survival_baseline_config_with_gradient<F>(
886    initial: &SurvivalBaselineConfig,
887    context: &str,
888    mut objective: F,
889) -> Result<SurvivalBaselineConfig, String>
890where
891    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
892{
893    use gam_problem::{HessianResult, OuterEval};
894    run_baseline_theta_optimizer_with_eval(
895        initial,
896        context,
897        BaselineDerivativeContract::GradientHessian,
898        move |cfg| {
899            let (cost, gradient, hessian) = objective(cfg)?;
900            Ok(OuterEval {
901                cost,
902                gradient,
903                hessian: HessianResult::Analytic(hessian),
904                inner_beta_hint: None,
905            })
906        },
907    )
908}
909
910// ---------------------------------------------------------------------------
911// Time basis config (library-friendly: takes primitives, not CLI args)
912// ---------------------------------------------------------------------------
913
914pub fn parse_survival_time_basis_config(
915    time_basis: &str,
916    time_degree: usize,
917    time_num_internal_knots: usize,
918    time_smooth_lambda: f64,
919) -> Result<SurvivalTimeBasisConfig, String> {
920    match time_basis.to_ascii_lowercase().as_str() {
921        "none" => Ok(SurvivalTimeBasisConfig::None),
922        "ispline" => {
923            if time_degree < 1 {
924                return Err(
925                    "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
926                        .to_string(),
927                );
928            }
929            if time_num_internal_knots == 0 {
930                return Err(
931                    "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
932                        .to_string(),
933                );
934            }
935            if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
936                return Err(
937                    "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
938                        .to_string(),
939                );
940            }
941            Ok(SurvivalTimeBasisConfig::ISpline {
942                degree: time_degree,
943                knots: Array1::zeros(0),
944                keep_cols: Vec::new(),
945                smooth_lambda: time_smooth_lambda,
946            })
947        }
948        "linear" | "bspline" => {
949            // Forward to the shared structural-basis check so error text
950            // stays consistent with every other call site. `linear` /
951            // `bspline` are not structural, so this always returns Err;
952            // we map a (currently impossible) `Ok` to an explicit error
953            // string instead of `unreachable!`, keeping the match total
954            // without relying on a never-executes claim.
955            match require_structural_survival_time_basis(time_basis, "survival model configuration")
956            {
957                Err(e) => Err(e),
958                Ok(()) => Err(format!(
959                    "internal: structural-basis check accepted non-structural \
960                     survival time basis '{time_basis}'"
961                )),
962            }
963        }
964        other => Err(format!(
965            "unsupported --time-basis '{other}'; accepted values: ispline, none"
966        )),
967    }
968}
969
970// ---------------------------------------------------------------------------
971// Time basis construction
972// ---------------------------------------------------------------------------
973
974pub fn build_survival_time_basis(
975    age_entry: &Array1<f64>,
976    age_exit: &Array1<f64>,
977    cfg: SurvivalTimeBasisConfig,
978    infer_knots_if_needed: Option<(usize, f64)>,
979) -> Result<SurvivalTimeBuildOutput, String> {
980    fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
981        if let Some(row) = times.iter().position(|t| !t.is_finite()) {
982            return Err(SurvivalConstructionError::DataValidationFailed {
983                reason: format!(
984                    "survival time basis requires finite {label} times (row {})",
985                    row + 1
986                ),
987            }
988            .into());
989        }
990        if let Some(row) = times.iter().position(|t| *t < 0.0) {
991            return Err(SurvivalConstructionError::DataValidationFailed {
992                reason: format!(
993                    "survival time basis requires non-negative {label} times (row {})",
994                    row + 1
995                ),
996            }
997            .into());
998        }
999        Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
1000    }
1001
1002    let n = age_entry.len();
1003    if n != age_exit.len() {
1004        return Err(SurvivalConstructionError::IncompatibleDimensions {
1005            reason: "survival time basis requires matching entry/exit lengths".to_string(),
1006        }
1007        .into());
1008    }
1009    for i in 0..n {
1010        if age_exit[i] < age_entry[i] {
1011            return Err(format!(
1012                "survival time basis requires exit times >= entry times (row {})",
1013                i + 1
1014            ));
1015        }
1016    }
1017    let log_entry = checked_log_survival_times(age_entry, "entry")?;
1018    let log_exit = checked_log_survival_times(age_exit, "exit")?;
1019
1020    fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1021        let n = log_entry.len();
1022        let entry_range = log_entry
1023            .iter()
1024            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1025                (lo.min(v), hi.max(v))
1026            });
1027        let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1028        if entry_degenerate {
1029            log_exit.clone()
1030        } else {
1031            let mut combined = Array1::<f64>::zeros(2 * n);
1032            for i in 0..n {
1033                combined[i] = log_entry[i];
1034                combined[n + i] = log_exit[i];
1035            }
1036            combined
1037        }
1038    }
1039
1040    /// Cap the requested monotone-baseline internal-knot count to what the
1041    /// observed time resolution can actually support.
1042    ///
1043    /// The survival location-scale baseline is a degree-`d` I-spline with
1044    /// `num_internal_knots + d` shape-varying columns. Its smoothing parameter
1045    /// is informed *only* by the distinct interior log-time points: with fewer
1046    /// distinct interior times than requested knots the baseline is
1047    /// rank-deficient, and the REML/LAML profile in the time smoothing
1048    /// parameter becomes a flat ridge — the exact-joint outer search then
1049    /// probes that ridge indefinitely (each inner constrained Newton burns its
1050    /// whole cycle budget without certifying convergence) and the fit never
1051    /// terminates. This is the survival analogue of the standard
1052    /// "df must not exceed the data resolution" guard (`mgcv` caps `k` at the
1053    /// number of unique covariate values; `flexsurv`/`rstpm2` use a handful of
1054    /// baseline knots): we never place more interior knots than there are
1055    /// distinct interior points, and we keep the total baseline dimension a
1056    /// bounded fraction of the sample so the smoothing profile stays curved.
1057    ///
1058    /// This clamp lives in the shared knot-inference routine so the fit and any
1059    /// independent rebuild of the time basis (e.g. a predictor reconstructing
1060    /// `design · β` at fresh covariates) resolve to the *same* knot vector from
1061    /// the same data — there is no raw/active dimension drift.
1062    fn data_capped_internal_knots(
1063        combined: &Array1<f64>,
1064        degree: usize,
1065        requested_internal_knots: usize,
1066    ) -> usize {
1067        if requested_internal_knots == 0 {
1068            return 0;
1069        }
1070        let mut sorted: Vec<f64> = combined.iter().copied().collect();
1071        sorted.sort_by(f64::total_cmp);
1072        let minval = sorted.first().copied().unwrap_or(0.0);
1073        let maxval = sorted.last().copied().unwrap_or(minval);
1074        if minval == maxval {
1075            // Degenerate (single distinct time): no interior structure to fit.
1076            return 1.min(requested_internal_knots);
1077        }
1078        let scale = (maxval - minval).abs().max(1.0);
1079        let tol = 1e-12 * scale;
1080        // Count distinct strictly-interior points (knots can only live strictly
1081        // between the data extremes).
1082        let mut distinct_interior = 0usize;
1083        let mut last: Option<f64> = None;
1084        for &x in &sorted {
1085            if x <= minval + tol || x >= maxval - tol {
1086                continue;
1087            }
1088            if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1089                continue;
1090            }
1091            distinct_interior += 1;
1092            last = Some(x);
1093        }
1094        // Distinct-point ceiling: cannot place more interior knots than there
1095        // are distinct interior values.
1096        let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1097        // Dimension-vs-resolution ceiling: keep the total baseline column count
1098        // `cap + degree` below ~1/4 of the distinct sample points so the
1099        // smoothing-parameter profile retains curvature (the data must be able
1100        // to identify the baseline shape, not just interpolate it). `n_distinct`
1101        // counts all distinct points (interior + the two extremes).
1102        let n_distinct = {
1103            let mut count = 0usize;
1104            let mut last: Option<f64> = None;
1105            for &x in &sorted {
1106                if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1107                    continue;
1108                }
1109                count += 1;
1110                last = Some(x);
1111            }
1112            count
1113        };
1114        let dim_budget = n_distinct / 4;
1115        let dim_cap = dim_budget.saturating_sub(degree);
1116        cap = cap.min(dim_cap.max(1));
1117        cap.max(1)
1118    }
1119
1120    fn infer_survival_time_knots(
1121        combined: &Array1<f64>,
1122        knot_degree: usize,
1123        validation_degree: usize,
1124        num_internal_knots: usize,
1125        basis_options: BasisOptions,
1126    ) -> Result<Array1<f64>, String> {
1127        // Identifiability/termination guard: never request more baseline
1128        // internal knots than the observed time resolution supports. See
1129        // `data_capped_internal_knots` for the full rationale (a flat smoothing
1130        // ridge on an over-parameterized baseline is what makes the survival
1131        // location-scale exact-joint outer search fail to terminate).
1132        let num_internal_knots =
1133            data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1134
1135        fn quantile_knot_inference_needs_uniform_fallback(
1136            combined: &Array1<f64>,
1137            num_internal_knots: usize,
1138        ) -> bool {
1139            if num_internal_knots == 0 || combined.is_empty() {
1140                return false;
1141            }
1142
1143            let mut sorted: Vec<f64> = combined.iter().copied().collect();
1144            sorted.sort_by(f64::total_cmp);
1145            let minval = sorted[0];
1146            let maxval = *sorted.last().unwrap_or(&minval);
1147            if minval == maxval {
1148                return false;
1149            }
1150
1151            let scale = (maxval - minval).abs().max(1.0);
1152            let tol = 1e-12 * scale;
1153            let mut support = Vec::with_capacity(sorted.len());
1154            let mut last: Option<f64> = None;
1155            for &x in &sorted {
1156                if x <= minval + tol || x >= maxval - tol {
1157                    continue;
1158                }
1159                if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1160                    continue;
1161                }
1162                support.push(x);
1163                last = Some(x);
1164            }
1165            if support.is_empty() {
1166                return true;
1167            }
1168
1169            let n = support.len();
1170            let mut prev_q = minval;
1171            for j in 1..=num_internal_knots {
1172                let p = j as f64 / (num_internal_knots + 1) as f64;
1173                let pos = p * (n.saturating_sub(1) as f64);
1174                let lo = pos.floor() as usize;
1175                let hi = pos.ceil() as usize;
1176                let frac = pos - lo as f64;
1177                let q = if lo == hi {
1178                    support[lo]
1179                } else {
1180                    support[lo] * (1.0 - frac) + support[hi] * frac
1181                }
1182                .clamp(minval, maxval);
1183                if q <= prev_q + tol || q >= maxval - tol {
1184                    return true;
1185                }
1186                prev_q = q;
1187            }
1188
1189            false
1190        }
1191
1192        let inferwith =
1193            |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1194                let built = build_bspline_basis_1d(
1195                    combined.view(),
1196                    &BSplineBasisSpec {
1197                        degree: knot_degree,
1198                        penalty_order: 2,
1199                        knotspec: BSplineKnotSpec::Automatic {
1200                            num_internal_knots: Some(num_internal_knots),
1201                            placement,
1202                        },
1203                        double_penalty: false,
1204                        identifiability: BSplineIdentifiability::None,
1205                        boundary: OneDimensionalBoundary::Open,
1206                        boundary_conditions: BSplineBoundaryConditions::default(),
1207                    },
1208                )
1209                .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1210                let knots = match built.metadata {
1211                    BasisMetadata::BSpline1D { knots, .. } => knots,
1212                    _ => {
1213                        return Err(
1214                            "internal error: expected BSpline1D metadata for survival time basis"
1215                                .to_string(),
1216                        );
1217                    }
1218                };
1219                // `knot_degree` is the clamped B-spline degree used to size
1220                // the knot vector. `validation_degree` is the public basis
1221                // degree passed to the final evaluator. They differ for
1222                // I-splines because `create_basis(..., BasisOptions::i_spline())`
1223                // internally raises the public degree by one to its working
1224                // B-spline antiderivative degree. Validating with
1225                // `knot_degree` here would raise a second time and reject the
1226                // coherent knot vector we just inferred.
1227                create_basis::<Dense>(
1228                    combined.view(),
1229                    KnotSource::Provided(knots.view()),
1230                    validation_degree,
1231                    basis_options,
1232                )
1233                .map_err(|e| e.to_string())?;
1234                Ok(knots)
1235            };
1236
1237        if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1238            inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1239        } else {
1240            inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1241        }
1242    }
1243
1244    match cfg {
1245        SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1246            x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1247            x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1248            x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1249            penalties: Vec::new(),
1250            nullspace_dims: Vec::new(),
1251            basisname: "none".to_string(),
1252            degree: None,
1253            knots: None,
1254            keep_cols: None,
1255            smooth_lambda: None,
1256        }),
1257        SurvivalTimeBasisConfig::Linear => {
1258            let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1259            let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1260            let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1261            for i in 0..n {
1262                x_entry_time[[i, 0]] = 1.0;
1263                x_exit_time[[i, 0]] = 1.0;
1264                x_entry_time[[i, 1]] = log_entry[i];
1265                x_exit_time[[i, 1]] = log_exit[i];
1266                x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1267            }
1268            Ok(SurvivalTimeBuildOutput {
1269                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1270                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1271                x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1272                penalties: Vec::new(),
1273                nullspace_dims: Vec::new(),
1274                basisname: "linear".to_string(),
1275                degree: None,
1276                knots: None,
1277                keep_cols: None,
1278                smooth_lambda: None,
1279            })
1280        }
1281        SurvivalTimeBasisConfig::BSpline {
1282            degree,
1283            knots,
1284            smooth_lambda,
1285        } => {
1286            let knotvec = if knots.is_empty() {
1287                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1288                    "internal error: bspline time basis requested without knot source".to_string()
1289                })?;
1290                let combined = survival_time_knot_input(&log_entry, &log_exit);
1291                infer_survival_time_knots(
1292                    &combined,
1293                    degree,
1294                    degree,
1295                    num_internal_knots,
1296                    BasisOptions::value(),
1297                )?
1298            } else {
1299                knots
1300            };
1301
1302            let entry_basis = build_bspline_basis_1d(
1303                log_entry.view(),
1304                &BSplineBasisSpec {
1305                    degree,
1306                    penalty_order: 2,
1307                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1308                    double_penalty: false,
1309                    identifiability: BSplineIdentifiability::None,
1310                    boundary: OneDimensionalBoundary::Open,
1311                    boundary_conditions: BSplineBoundaryConditions::default(),
1312                },
1313            )
1314            .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1315            let exit_basis = build_bspline_basis_1d(
1316                log_exit.view(),
1317                &BSplineBasisSpec {
1318                    degree,
1319                    penalty_order: 2,
1320                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1321                    double_penalty: false,
1322                    identifiability: BSplineIdentifiability::None,
1323                    boundary: OneDimensionalBoundary::Open,
1324                    boundary_conditions: BSplineBoundaryConditions::default(),
1325                },
1326            )
1327            .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1328
1329            let p_time = exit_basis.design.ncols();
1330            // Build derivative basis as sparse triplets — B-spline derivatives
1331            // have the same local support as the basis itself (at most degree+1
1332            // nonzeros per row), so building dense first wastes memory.
1333            let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1334            let mut deriv_buf = vec![0.0_f64; p_time];
1335            for i in 0..n {
1336                deriv_buf.fill(0.0);
1337                evaluate_bspline_derivative_scalar(
1338                    log_exit[i],
1339                    knotvec.view(),
1340                    degree,
1341                    &mut deriv_buf,
1342                )
1343                .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1344                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1345                for j in 0..p_time {
1346                    let v = deriv_buf[j] * chain;
1347                    if v.abs() > 1e-15 {
1348                        deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1349                    }
1350                }
1351            }
1352            let x_derivative_time =
1353                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1354                {
1355                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1356                    Err(_) => {
1357                        // Fallback: build dense
1358                        let mut dense = Array2::<f64>::zeros((n, p_time));
1359                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1360                            dense[[row, col]] = val;
1361                        }
1362                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1363                    }
1364                };
1365
1366            Ok(SurvivalTimeBuildOutput {
1367                x_entry_time: entry_basis.design,
1368                x_exit_time: exit_basis.design,
1369                x_derivative_time,
1370                nullspace_dims: entry_basis.nullspace_dims,
1371                penalties: entry_basis.penalties,
1372                basisname: "bspline".to_string(),
1373                degree: Some(degree),
1374                knots: Some(knotvec.to_vec()),
1375                keep_cols: None,
1376                smooth_lambda: Some(smooth_lambda),
1377            })
1378        }
1379        SurvivalTimeBasisConfig::ISpline {
1380            degree,
1381            knots,
1382            keep_cols,
1383            smooth_lambda,
1384        } => {
1385            let bspline_degree = degree
1386                .checked_add(1)
1387                .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1388            let knotvec = if knots.is_empty() {
1389                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1390                    "internal error: ispline time basis requested without knot source".to_string()
1391                })?;
1392                let combined = survival_time_knot_input(&log_entry, &log_exit);
1393                infer_survival_time_knots(
1394                    &combined,
1395                    bspline_degree,
1396                    degree,
1397                    num_internal_knots,
1398                    BasisOptions::i_spline(),
1399                )?
1400            } else {
1401                knots
1402            };
1403
1404            let (db_exit_arc, _) = create_basis::<Dense>(
1405                log_exit.view(),
1406                KnotSource::Provided(knotvec.view()),
1407                bspline_degree,
1408                BasisOptions::first_derivative(),
1409            )
1410            .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1411
1412            // Build full-width I-spline bases inside a block scope so the
1413            // large Arc allocations are freed when the block ends.
1414            let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1415                let (entry_arc, _) = create_basis::<Dense>(
1416                    log_entry.view(),
1417                    KnotSource::Provided(knotvec.view()),
1418                    degree,
1419                    BasisOptions::i_spline(),
1420                )
1421                .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1422                let (exit_arc, _) = create_basis::<Dense>(
1423                    log_exit.view(),
1424                    KnotSource::Provided(knotvec.view()),
1425                    degree,
1426                    BasisOptions::i_spline(),
1427                )
1428                .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1429
1430                let x_entry_full = entry_arc.as_ref();
1431                let x_exit_full = exit_arc.as_ref();
1432                let p_time_full = x_exit_full.ncols();
1433                if p_time_full == 0 {
1434                    return Err(SurvivalConstructionError::BasisConstructionFailed {
1435                        reason: "internal error: empty ispline time basis".to_string(),
1436                    }
1437                    .into());
1438                }
1439                let db_exit = db_exit_arc.as_ref();
1440                if db_exit.ncols() != p_time_full + 1 {
1441                    return Err(
1442                        "internal error: ispline derivative basis width must exceed basis width by one"
1443                            .to_string(),
1444                    );
1445                }
1446
1447                let keep_cols = if keep_cols.is_empty() {
1448                    let constant_tol = 1e-12_f64;
1449                    let mut inferred_keep_cols: Vec<usize> = Vec::new();
1450                    for j in 0..p_time_full {
1451                        let mut minv = f64::INFINITY;
1452                        let mut maxv = f64::NEG_INFINITY;
1453                        for i in 0..n {
1454                            let ve = x_exit_full[[i, j]];
1455                            let vs = x_entry_full[[i, j]];
1456                            minv = minv.min(ve.min(vs));
1457                            maxv = maxv.max(ve.max(vs));
1458                        }
1459                        if (maxv - minv) > constant_tol {
1460                            inferred_keep_cols.push(j);
1461                        }
1462                    }
1463                    inferred_keep_cols
1464                } else {
1465                    keep_cols
1466                };
1467                if keep_cols.is_empty() {
1468                    return Err(
1469                        "internal error: ispline basis has no shape-varying time columns"
1470                            .to_string(),
1471                    );
1472                }
1473                if keep_cols.iter().any(|&j| j >= p_time_full) {
1474                    return Err(SurvivalConstructionError::MissingColumn {
1475                        reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1476                    }
1477                    .into());
1478                }
1479
1480                let p_time = keep_cols.len();
1481                let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1482                let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1483                // entry_arc and exit_arc go out of scope here, freeing the
1484                // full-width bases before derivative computation below.
1485                (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1486            };
1487            let db_exit = db_exit_arc.as_ref();
1488
1489            // Build I-spline derivative as sparse triplets.  The derivative
1490            // is a cumulative sum of B-spline derivatives and typically has
1491            // more nonzeros per row than a plain B-spline, but still much
1492            // fewer than p_time for modest bases.
1493            let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1494            let mut found_nonfinite: Option<(usize, usize)> = None;
1495            for i in 0..n {
1496                let mut running = 0.0_f64;
1497                let mut d_i_log_full = vec![0.0_f64; p_time_full];
1498                for j in (1..db_exit.ncols()).rev() {
1499                    let term = db_exit[[i, j]];
1500                    if term.is_finite() {
1501                        running += term;
1502                    }
1503                    d_i_log_full[j - 1] = running;
1504                }
1505                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1506                for (j_new, &j_old) in keep_cols.iter().enumerate() {
1507                    let raw_v = d_i_log_full[j_old] * chain;
1508                    let v = if (-1e-12..0.0).contains(&raw_v) {
1509                        0.0
1510                    } else {
1511                        raw_v
1512                    };
1513                    if !v.is_finite() {
1514                        found_nonfinite = Some((i, j_new));
1515                    }
1516                    if v < -1e-12 {
1517                        return Err(format!(
1518                            "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1519                            i + 1,
1520                            j_new + 1,
1521                            v
1522                        ));
1523                    }
1524                    if v.abs() > 1e-15 {
1525                        deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1526                    }
1527                }
1528            }
1529            if let Some((row, col)) = found_nonfinite {
1530                return Err(format!(
1531                    "survival ispline derivative basis produced non-finite value at row {}, column {}",
1532                    row + 1,
1533                    col + 1
1534                ));
1535            }
1536            let x_derivative_time =
1537                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1538                {
1539                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1540                    Err(_) => {
1541                        let mut dense = Array2::<f64>::zeros((n, p_time));
1542                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1543                            dense[[row, col]] = val;
1544                        }
1545                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1546                    }
1547                };
1548
1549            let penalty_basis = build_bspline_basis_1d(
1550                log_exit.view(),
1551                &BSplineBasisSpec {
1552                    degree: bspline_degree,
1553                    penalty_order: 2,
1554                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1555                    double_penalty: false,
1556                    identifiability: BSplineIdentifiability::None,
1557                    boundary: OneDimensionalBoundary::Open,
1558                    boundary_conditions: BSplineBoundaryConditions::default(),
1559                },
1560            )
1561            .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1562            if penalty_basis.design.ncols() != p_time_full + 1 {
1563                return Err("internal error: ispline penalty dimension mismatch".to_string());
1564            }
1565            // I-spline curvature penalty in the *value* space of the baseline
1566            // log-cumulative-hazard, restricted to the retained (non-dropped)
1567            // coefficient block.
1568            //
1569            // The I-spline coefficient γ is the consecutive increment of the B-spline
1570            // value coefficients `c`: `c_0 = 0`, `c_k = Σ_{j<k} γ_j = (L γ)_k`, where
1571            // `L` is the `p_time × p_time` lower-triangular cumsum matrix. The
1572            // second-difference penalty on the B-spline values is `S_B = D₂ᵀD₂`
1573            // (the `penalty_basis.penalties` block). The correct curvature penalty
1574            // on γ is the **value-space congruence transform**
1575            //
1576            //   `S_I = Lᵀ S_B[1:,1:] L`,
1577            //
1578            // which satisfies `γᵀ S_I γ = (Lγ)ᵀ S_B[1:,1:] (Lγ)`.
1579            //
1580            // A constant γ (γ_k = γ₀ ∀k) maps to the linear value sequence
1581            // `c_k = k·γ₀`, which is annihilated by D₂: `D₂c = 0`. Therefore
1582            // `γᵀ S_I γ = 0` for constant γ, i.e. the **affine trend lies in the
1583            // penalty null space**. REML does not penalize the baseline slope
1584            // `d(log Λ)/d(log t)` or the overall level, so it correctly lets the
1585            // data determine these quantities without bias. The previous increment-
1586            // space form `S_B[1:,1:]` (applied directly to γ instead of Lγ) did NOT
1587            // have constant γ in its null space and therefore over-penalized affine
1588            // baselines, causing the fitted log-cumulative-hazard to lose its tail
1589            // slope to the penalty and fail quality tests (#1076).
1590            //
1591            // The value-space form has a 1-dimensional null space (span{(1,…,1)}),
1592            // declared via `nullspace_dims` so the REML generalized-logdet picks it
1593            // up. The penalized inner PIRLS is well-conditioned because the
1594            // likelihood Hessian H_lik has O(n_events) curvature along the affine
1595            // direction (the overall baseline level is identified by the data), and
1596            // the global stabilization ridge (ridge_lambda) provides an absolute
1597            // positive-definite floor.
1598            let mut penalties = Vec::<Array2<f64>>::new();
1599            for s_mat in &penalty_basis.penalties {
1600                if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1601                    continue;
1602                }
1603                // I-spline value-space penalty, computed in the CORRECT order
1604                // (gam#979). The B-spline value coefficients are the cumulative
1605                // sum of the I-spline increment coefficients, `c = L γ_full`, where
1606                // `L` is the FULL `p_time_full × p_time_full` LOWER-triangular
1607                // all-ones cumsum matrix (`L[i,j] = 1 iff j ≤ i`, so
1608                // `c_i = Σ_{j≤i} γ_j`). The value-space curvature penalty on the
1609                // full increment vector is the symmetric congruence
1610                //
1611                //   `S_I_full = Lᵀ · S_B[1:,1:] · L`,
1612                //
1613                // which is PSD because `S_B[1:,1:]` is a principal submatrix of the
1614                // PSD `S_B = D₂ᵀD₂` and congruence by any matrix preserves PSD.
1615                //
1616                // CRITICAL ORDERING (the gam#979 indefiniteness bug): the retained
1617                // columns `keep_cols` must be selected as a PRINCIPAL SUBMATRIX of
1618                // the FULL congruence `S_I_full` — i.e. congruence FIRST, selection
1619                // SECOND. The previous code selected `keep_cols` from `S_B[1:,1:]`
1620                // first and then applied a `p_time × p_time` cumsum to that
1621                // already-reduced block. Because the cumsum `L` couples every
1622                // increment, restricting the increment index set BEFORE the cumsum
1623                // does NOT commute with it: the reduced operator is a different,
1624                // generally INDEFINITE matrix (measured `s0_min_eval = −9.8e7`),
1625                // which makes `½γᵀS_Iγ` unbounded below and the penalized survival
1626                // NLL diverge (β drifts up the negative-eigenvalue mode, the inner
1627                // joint-Newton follows the unbounded objective, the outer REML never
1628                // terminates — the #979 hang). Doing the congruence on the full γ
1629                // and then taking the `keep_cols` principal submatrix restores the
1630                // PSD guarantee (a principal submatrix of a PSD matrix is PSD).
1631                let s_increment = s_mat.slice(s![1.., 1..]);
1632                if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1633                    return Err(format!(
1634                        "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1635                        s_increment.nrows(),
1636                        s_increment.ncols(),
1637                    ));
1638                }
1639                // Symmetrize the (already-symmetric) source with the shared
1640                // matrix utility. The survival builder's value-space
1641                // congruence is domain-specific; only the low-level symmetric
1642                // cleanup is common with the generic and SAE construction code.
1643                let mut s_full = s_increment.to_owned();
1644                symmetrize_in_place(&mut s_full);
1645                // S_mid = S_B[1:,1:] · L  (right-multiply by lower-triangular
1646                // cumsum): (S·L)[i,j] = Σ_k S[i,k]·L[k,j] = Σ_{k≥j} S[i,k]
1647                // because L[k,j] = 1 iff j ≤ k.
1648                let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1649                for i in 0..p_time_full {
1650                    for j in 0..p_time_full {
1651                        let mut v = 0.0;
1652                        for k in j..p_time_full {
1653                            v += s_full[[i, k]];
1654                        }
1655                        s_mid_full[[i, j]] = v;
1656                    }
1657                }
1658                // S_I_full = Lᵀ · S_mid = Lᵀ · S · L:
1659                // (Lᵀ·S_mid)[i,j] = Σ_k Lᵀ[i,k]·S_mid[k,j] = Σ_{k≥i} S_mid[k,j]
1660                // because Lᵀ[i,k] = L[k,i] = 1 iff i ≤ k.
1661                let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1662                for i in 0..p_time_full {
1663                    for j in 0..p_time_full {
1664                        let mut v = 0.0;
1665                        for k in i..p_time_full {
1666                            v += s_mid_full[[k, j]];
1667                        }
1668                        s_full_congruent[[i, j]] = v;
1669                    }
1670                }
1671                // Principal submatrix on the retained (shape-varying) columns.
1672                let mut local = Array2::<f64>::zeros((p_time, p_time));
1673                for (i_new, &i_old) in keep_cols.iter().enumerate() {
1674                    for (j_new, &j_old) in keep_cols.iter().enumerate() {
1675                        // Symmetrize on the way out to absorb residual
1676                        // floating-point asymmetry.
1677                        local[[i_new, j_new]] = 0.5
1678                            * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1679                    }
1680                }
1681                penalties.push(local);
1682            }
1683
1684            // PSD contract (gam#979). The value-space congruence Lᵀ S_B[1:,1:] L,
1685            // restricted to a principal submatrix, is positive semidefinite by
1686            // construction. A negative eigenvalue here means the construction has
1687            // regressed to the increment-space / wrong-ordering form that made the
1688            // penalized survival NLL unbounded below (the #979 divergence). Verify
1689            // it here, at construction, so the defect can never silently reach the
1690            // inner solver again. The tolerance is the same relative scale the
1691            // nullspace detection below uses; a numerically tiny negative (round-off
1692            // on the genuine 1-D null direction) is allowed, a structural one is not.
1693            for (idx, s_mat) in penalties.iter().enumerate() {
1694                let p = s_mat.nrows();
1695                if p == 0 {
1696                    continue;
1697                }
1698                if let Ok((evals, _)) =
1699                    gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1700                {
1701                    let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1702                        "internal error: ispline penalty eigenvalues not contiguous".to_string()
1703                    })?;
1704                    let max_ev = evals_slice
1705                        .iter()
1706                        .copied()
1707                        .fold(0.0_f64, |a, b| a.max(b.abs()))
1708                        .max(1.0);
1709                    let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1710                    let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1711                    if min_ev < neg_tol {
1712                        return Err(format!(
1713                            "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1714                             indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1715                             {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1716                        ));
1717                    }
1718                }
1719            }
1720
1721            // The value-space penalty S_I = L^T S_B[1:,1:] L has a 1-dimensional
1722            // null space (constant γ ↦ affine c ↦ D₂c = 0). Detect it spectrally
1723            // so the REML uses the generalized logdet over the penalized subspace.
1724            let nullspace_dims: Vec<usize> = penalties
1725                .iter()
1726                .map(|s_mat| {
1727                    let p = s_mat.nrows();
1728                    if p == 0 {
1729                        return 0;
1730                    }
1731                    match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1732                        Ok((evals, _)) => {
1733                            let evals_slice: &[f64] = evals.as_slice().unwrap();
1734                            let max_ev = evals_slice
1735                                .iter()
1736                                .copied()
1737                                .fold(0.0_f64, |a, b| a.max(b.abs()))
1738                                .max(1.0);
1739                            let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1740                            evals_slice.iter().filter(|&&e| e <= threshold).count()
1741                        }
1742                        Err(_) => 0,
1743                    }
1744                })
1745                .collect();
1746            Ok(SurvivalTimeBuildOutput {
1747                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1748                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1749                x_derivative_time,
1750                penalties,
1751                nullspace_dims,
1752                basisname: "ispline".to_string(),
1753                degree: Some(degree),
1754                knots: Some(knotvec.to_vec()),
1755                keep_cols: Some(keep_cols),
1756                smooth_lambda: Some(smooth_lambda),
1757            })
1758        }
1759    }
1760}
1761
1762pub fn resolved_survival_time_basis_config_from_build(
1763    basisname: &str,
1764    degree: Option<usize>,
1765    knots: Option<&Vec<f64>>,
1766    keep_cols: Option<&Vec<usize>>,
1767    smooth_lambda: Option<f64>,
1768) -> Result<SurvivalTimeBasisConfig, String> {
1769    match basisname {
1770        "none" => Ok(SurvivalTimeBasisConfig::None),
1771        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1772        "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1773            degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1774            knots: Array1::from_vec(
1775                knots
1776                    .cloned()
1777                    .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1778            ),
1779            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1780        }),
1781        "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1782            degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1783            knots: Array1::from_vec(
1784                knots
1785                    .cloned()
1786                    .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1787            ),
1788            keep_cols: keep_cols
1789                .cloned()
1790                .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1791            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1792        }),
1793        other => Err(format!("unsupported survival time basis '{other}'")),
1794    }
1795}
1796
1797pub fn resolve_survival_time_anchor_value(
1798    age_entry: &Array1<f64>,
1799    time_anchor: Option<f64>,
1800) -> Result<f64, String> {
1801    if age_entry.is_empty() {
1802        return Err("survival time anchor requires non-empty entry times".to_string());
1803    }
1804    let anchor = match time_anchor {
1805        Some(t_anchor) => {
1806            if !t_anchor.is_finite() || t_anchor < 0.0 {
1807                return Err(format!(
1808                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1809                ));
1810            }
1811            t_anchor
1812        }
1813        None => age_entry
1814            .iter()
1815            .copied()
1816            .min_by(f64::total_cmp)
1817            .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1818    };
1819    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1820}
1821
1822/// Marginal-slope centering anchor: a robust *interior* time on the **exit**
1823/// scale rather than the earliest entry age.
1824///
1825/// `center_survival_time_designs_at_anchor` subtracts the time-basis row at the
1826/// anchor from every entry/exit design row, so the anchor sets the origin of
1827/// the baseline-hazard I-spline's affine reparameterization. The
1828/// location-scale path anchors at the minimum entry age
1829/// ([`resolve_survival_time_anchor_value`]); for right-censored-only data that
1830/// minimum is ≈ the time origin, so centering is nearly a no-op.
1831///
1832/// Under **left truncation** the minimum entry age is a genuine positive
1833/// *left-tail* point, and centering there leaves the centered linear-trend
1834/// column `X(exit) − X(anchor)` large and one-signed across all rows (exit
1835/// times sit far to the right of the earliest entry). That column is the
1836/// unpenalized polynomial null space of the 2nd-difference time penalty, so the
1837/// inflated, one-signed column multiplies the marginal-slope time-block score
1838/// at the `γ = 0` monotone-cone seed up by hundreds — the constrained joint
1839/// Newton cannot certify KKT on it and REML rejects every seed (issue #751).
1840///
1841/// Centering instead at a robust interior location on the *exit* scale — the
1842/// **median exit age**, where the at-risk mass concentrates — keeps the
1843/// centered column small and two-signed (some exits below the median, some
1844/// above), so the exit-event likelihood pins the linear trend and the seed
1845/// score stays bounded. Re-centering is an exact affine reparameterization of
1846/// the baseline offset: the fitted `q(t)` and the REML objective are unchanged,
1847/// only the seed conditioning improves. The median is chosen (over the mean)
1848/// for robustness to the heavy right tail of survival times.
1849///
1850/// An explicit `--survival-time-anchor` is honored verbatim (same validation as
1851/// the location-scale path) so the user retains full control; the saved
1852/// `survival_time_anchor` scalar round-trips to predict unchanged.
1853pub fn resolve_survival_marginal_slope_time_anchor_value(
1854    age_entry: &Array1<f64>,
1855    age_exit: &Array1<f64>,
1856    time_anchor: Option<f64>,
1857) -> Result<f64, String> {
1858    if age_entry.is_empty() || age_exit.is_empty() {
1859        return Err(
1860            "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1861        );
1862    }
1863    let anchor = match time_anchor {
1864        Some(t_anchor) => {
1865            if !t_anchor.is_finite() || t_anchor < 0.0 {
1866                return Err(format!(
1867                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1868                ));
1869            }
1870            t_anchor
1871        }
1872        None => robust_interior_exit_anchor(age_exit),
1873    };
1874    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1875}
1876
1877/// Median exit age — a robust interior time on the exit scale, where the
1878/// at-risk mass concentrates. Used as the survival time-basis centering anchor
1879/// whenever the earliest entry is a positive left-tail point (delayed entry):
1880/// centering there keeps the reparameterized linear-trend column small and
1881/// two-signed instead of large and one-signed. The median is chosen over the
1882/// mean for robustness to the heavy right tail of survival times.
1883fn robust_interior_exit_anchor(age_exit: &Array1<f64>) -> f64 {
1884    let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1885    sorted.sort_by(f64::total_cmp);
1886    let m = sorted.len();
1887    if m == 0 {
1888        return SURVIVAL_TIME_FLOOR;
1889    }
1890    if m % 2 == 1 {
1891        sorted[m / 2]
1892    } else {
1893        0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1894    }
1895}
1896
1897/// Centering anchor for the default transformation (Royston-Parmar) survival
1898/// baseline.
1899///
1900/// For right-censored-only data the earliest entry age is ≈ the time origin, so
1901/// [`resolve_survival_time_anchor_value`] (min entry) is nearly a no-op and is
1902/// used unchanged. Under **left truncation** (every row enters at a positive
1903/// delayed-entry time) that minimum is a genuine left-tail point far below the
1904/// exit mass, and centering the I-spline time basis there leaves the
1905/// unpenalized linear-trend column `X(exit) − X(anchor)` large and one-signed
1906/// across all rows. That column is the null space of the 2nd-difference time
1907/// penalty, so the inflated one-signed column blows up the transformation
1908/// smoothing-parameter selection: it rails a penalty direction and collapses the
1909/// baseline to a covariate-independent, cumulative-hazard-inflated degenerate
1910/// fit (issue #1790 — the transformation-model analogue of the marginal-slope
1911/// #751 defect). Anchoring instead at the robust interior **median exit age**
1912/// keeps the centered column small and two-signed so the exit-event likelihood
1913/// pins the linear trend. Re-centering is an exact affine reparameterization of
1914/// the baseline offset — the fitted `q(t)` and REML objective are unchanged,
1915/// only the seed conditioning improves. An explicit `time_anchor` is honored
1916/// verbatim.
1917pub fn resolve_survival_transformation_time_anchor_value(
1918    age_entry: &Array1<f64>,
1919    age_exit: &Array1<f64>,
1920    time_anchor: Option<f64>,
1921) -> Result<f64, String> {
1922    if time_anchor.is_some() {
1923        return resolve_survival_time_anchor_value(age_entry, time_anchor);
1924    }
1925    if age_exit.is_empty() {
1926        return Err(
1927            "survival transformation time anchor requires non-empty exit times".to_string(),
1928        );
1929    }
1930    let min_entry = age_entry
1931        .iter()
1932        .copied()
1933        .fold(f64::INFINITY, f64::min);
1934    if min_entry > SURVIVAL_DELAYED_ENTRY_THRESHOLD {
1935        Ok(robust_interior_exit_anchor(age_exit).max(SURVIVAL_TIME_FLOOR))
1936    } else {
1937        resolve_survival_time_anchor_value(age_entry, None)
1938    }
1939}
1940
1941pub fn evaluate_survival_time_basis_row(
1942    age: f64,
1943    cfg: &SurvivalTimeBasisConfig,
1944) -> Result<Array1<f64>, String> {
1945    if !age.is_finite() || age < 0.0 {
1946        return Err(format!(
1947            "survival time basis row requires finite non-negative age, got {age}"
1948        ));
1949    }
1950    let age = age.max(SURVIVAL_TIME_FLOOR);
1951    let log_age = array![age.ln()];
1952    match cfg {
1953        SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1954        SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1955        SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1956            if knots.is_empty() {
1957                return Err(
1958                    "survival BSpline anchor evaluation requires resolved knot metadata"
1959                        .to_string(),
1960                );
1961            }
1962            let built = build_bspline_basis_1d(
1963                log_age.view(),
1964                &BSplineBasisSpec {
1965                    degree: *degree,
1966                    penalty_order: 2,
1967                    knotspec: BSplineKnotSpec::Provided(knots.clone()),
1968                    double_penalty: false,
1969                    identifiability: BSplineIdentifiability::None,
1970                    boundary: OneDimensionalBoundary::Open,
1971                    boundary_conditions: BSplineBoundaryConditions::default(),
1972                },
1973            )
1974            .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1975            Ok(built.design.to_dense().row(0).to_owned())
1976        }
1977        SurvivalTimeBasisConfig::ISpline {
1978            degree,
1979            knots,
1980            keep_cols,
1981            ..
1982        } => {
1983            if knots.is_empty() {
1984                return Err(
1985                    "survival ISpline anchor evaluation requires resolved knot metadata"
1986                        .to_string(),
1987                );
1988            }
1989            let (basis_arc, _) = create_basis::<Dense>(
1990                log_age.view(),
1991                KnotSource::Provided(knots.view()),
1992                *degree,
1993                BasisOptions::i_spline(),
1994            )
1995            .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1996            let basis = basis_arc.as_ref();
1997            let row = basis.row(0);
1998            if keep_cols.is_empty() {
1999                return Ok(row.to_owned());
2000            }
2001            if keep_cols.iter().any(|&j| j >= row.len()) {
2002                return Err(SurvivalConstructionError::MissingColumn {
2003                    reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
2004                }
2005                .into());
2006            }
2007            Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
2008        }
2009    }
2010}
2011
2012pub fn center_survival_time_designs_at_anchor(
2013    design_entry: &mut DesignMatrix,
2014    design_exit: &mut DesignMatrix,
2015    anchor_row: &Array1<f64>,
2016) -> Result<(), String> {
2017    if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
2018        return Err(format!(
2019            "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
2020            design_entry.ncols(),
2021            design_exit.ncols(),
2022            anchor_row.len()
2023        ));
2024    }
2025    // Centering destroys sparsity (every row gets a dense offset), so
2026    // materialize to dense.  This only runs once at construction time.
2027    fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
2028        let mut dense = dm.to_dense();
2029        for mut row in dense.rows_mut() {
2030            row -= &anchor.view();
2031        }
2032        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
2033    }
2034    center_dense(design_entry, anchor_row);
2035    center_dense(design_exit, anchor_row);
2036    Ok(())
2037}
2038
2039// ---------------------------------------------------------------------------
2040// Baseline evaluation (Gompertz, Weibull, Gompertz-Makeham)
2041// ---------------------------------------------------------------------------
2042
2043/// Partial derivatives of the baseline offsets `(eta_target, d_eta_target/dt)`
2044/// with respect to the θ-parameters in the same parameterization that
2045/// [`survival_baseline_theta_from_config`] / [`survival_baseline_config_from_theta`]
2046/// use:
2047///
2048/// - **Weibull**: θ = (log_scale, log_shape).  `eta = shape·(log t − log scale)`,
2049///   `o_D = shape/t`.
2050/// - **Gompertz**: θ = (log_rate, shape).  `eta = log H_G(t)` with
2051///   `H_G(t) = (rate/shape)·(exp(shape·t) − 1)`, `o_D = h_G(t)/H_G(t) =
2052///   shape·E/(E−1)` where `E = exp(shape·t)`.
2053/// - **Gompertz–Makeham**: θ = (log_rate, shape, log_makeham).
2054///   `eta = log H(t)` with `H(t) = makeham·t + H_G(t)`,
2055///   `o_D = (makeham + h_G(t)) / H(t)`.
2056///
2057/// Returns a flat `(d_eta/dθ_k, d_oD/dθ_k)` pair for each component of θ,
2058/// in the same order as `survival_baseline_theta_from_config`.  Linear has
2059/// no θ-parameters so returns `Ok(None)`.
2060///
2061/// The `eta`-channel derivatives are closed-form for every branch.  The
2062/// `o_D`-channel derivatives use the log-derivative identity
2063/// `∂o_D/∂θ = o_D · ∂log(o_D)/∂θ` which is more numerically stable near
2064/// the small-shape limit (shape·t → 0).  Near shape = 0 we fall back to
2065/// a third-order Taylor expansion with the same 1e-10 pivot that
2066/// `evaluate_survival_baseline` uses, keeping the value/derivative pair
2067/// continuous and agreement with the linear-hazard limit exact at shape=0.
2068pub fn baseline_offset_theta_partials(
2069    age: f64,
2070    cfg: &SurvivalBaselineConfig,
2071) -> Result<Option<Vec<(f64, f64)>>, String> {
2072    let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2073    else {
2074        return Ok(None);
2075    };
2076
2077    match params {
2078        ValidatedBaselineTarget::Weibull { scale, shape } => {
2079            // eta = shape·(log t − log scale)
2080            //     = shape·log t − shape·log scale
2081            // o_D = shape / t
2082            //
2083            // θ = (log_scale, log_shape):
2084            //   ∂eta/∂log_scale  = −shape          ∂o_D/∂log_scale = 0
2085            //   ∂eta/∂log_shape  = shape·(log t − log scale) = eta
2086            //   ∂o_D/∂log_shape  = shape / t = o_D
2087            let eta = shape * (age.ln() - scale.ln());
2088            let o_d = shape / age;
2089            let d_eta_d_log_scale = -shape;
2090            let d_od_d_log_scale = 0.0;
2091            let d_eta_d_log_shape = eta;
2092            let d_od_d_log_shape = o_d;
2093            Ok(Some(vec![
2094                (d_eta_d_log_scale, d_od_d_log_scale),
2095                (d_eta_d_log_shape, d_od_d_log_shape),
2096            ]))
2097        }
2098        ValidatedBaselineTarget::Gompertz { shape, .. } => {
2099            // θ = (log_rate, shape):
2100            //   Rate cancels in o_D = h/H for Gompertz, so ∂o_D/∂log_rate = 0
2101            //   and ∂eta/∂log_rate = 1. The shape channel uses
2102            //     ∂eta/∂shape   = −1/shape + t·E/(E−1)
2103            //     ∂log(o_D)/∂shape = 1/shape − t/(E−1)
2104            //     ∂o_D/∂shape  = o_D · ∂log(o_D)/∂shape
2105            //   Near shape=0 both numerators are 1/shape cancellations. Use
2106            //   Taylor expansions with the same 1e-10 pivot that
2107            //   gompertz_components uses in evaluate_survival_baseline.
2108            let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2109            Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2110        }
2111        ValidatedBaselineTarget::GompertzMakeham {
2112            rate,
2113            shape,
2114            makeham,
2115        } => {
2116            // H(t) = M·t + H_G(t),   H_G(t) = (rate/shape)·(E−1),  E = exp(shape·t)
2117            // h(t) = M + h_G(t),     h_G(t) = rate·E
2118            // o_D  = h/H
2119            //
2120            // θ = (log_rate, shape, log_makeham):
2121            //   ∂H/∂log_rate    = rate · ∂H/∂rate = H_G               (scales with rate)
2122            //   ∂H/∂shape       = H_G_shape                            (closed form below)
2123            //   ∂H/∂log_makeham = makeham · t                          (linear in makeham)
2124            //   ∂h/∂log_rate    = rate · ∂h/∂rate = h_G
2125            //   ∂h/∂shape       = h_G_shape = rate·t·E + 0              (= rate·t·E)
2126            //   ∂h/∂log_makeham = makeham
2127            //   ∂eta/∂θ = (∂H/∂θ) / H
2128            //   ∂o_D/∂θ = (∂h/∂θ − o_D · ∂H/∂θ) / H
2129            //           = (∂h/∂θ)/H − o_D · (∂H/∂θ)/H
2130            let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2131            let cum_total = makeham * age + cum_g;
2132            if cum_total <= 0.0 || !cum_total.is_finite() {
2133                return Err(SurvivalConstructionError::DataValidationFailed {
2134                    reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2135                }
2136                .into());
2137            }
2138            let inst_total = makeham + inst_g;
2139            let o_d = inst_total / cum_total;
2140            let inv_cum = 1.0 / cum_total;
2141            // Each channel: ∂cum/∂θ and ∂inst/∂θ → ∂eta/∂θ = ∂cum/∂θ / cum
2142            //                                       ∂o_D/∂θ = (∂inst/∂θ − o_D·∂cum/∂θ) / cum
2143            // log_rate channel: cum is linear in rate through H_G; ∂cum/∂rate = H_G/rate,
2144            //   so ∂cum/∂log_rate = H_G (= cum_g here). Similarly ∂inst/∂log_rate = h_G (= inst_g).
2145            let d_cum_dlr = cum_g;
2146            let d_inst_dlr = inst_g;
2147            let d_eta_dlr = d_cum_dlr * inv_cum;
2148            let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2149            // shape channel: only H_G and h_G have shape dependence.
2150            let (d_cum_dshape, d_inst_dshape) =
2151                gompertz_cumulative_shape_derivative(age, rate, shape);
2152            let d_eta_dshape = d_cum_dshape * inv_cum;
2153            let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2154            // log_makeham channel: cum contributes M·t, inst contributes M.
2155            //   ∂cum/∂log_makeham = makeham·t,  ∂inst/∂log_makeham = makeham.
2156            let d_cum_dlm = makeham * age;
2157            let d_inst_dlm = makeham;
2158            let d_eta_dlm = d_cum_dlm * inv_cum;
2159            let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2160            Ok(Some(vec![
2161                (d_eta_dlr, d_od_dlr),
2162                (d_eta_dshape, d_od_dshape),
2163                (d_eta_dlm, d_od_dlm),
2164            ]))
2165        }
2166    }
2167}
2168
2169/// Shared chain-rule θ-gradient contraction for baseline offsets.
2170///
2171/// Both [`baseline_chain_rule_gradient`] (RP eta offsets) and
2172/// [`marginal_slope_baseline_chain_rule_gradient`] (probit q-offsets) reduce to
2173/// the same contraction of [`OffsetChannelResiduals`] against per-age baseline
2174/// θ-partials; only the `partials` provider differs. This engine owns the length
2175/// checks, the θ-dim probe, the parallel per-row reduction, the entry gating, and
2176/// the error handling. Each provider returns, per age, a length-`theta_dim` vector
2177/// of `(∂eta/∂θ_k, ∂(d eta/dt)/∂θ_k)` pairs (or `(∂q/∂θ_k, ∂(dq/dt)/∂θ_k)` for the
2178/// probit channel), and `None` when `cfg` has no θ-parameters (`Linear` target).
2179///
2180/// Contract (envelope theorem at converged β; the penalty has no θ dependence):
2181///
2182///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2183///     = Σᵢ r_X[i]·(∂o_X_i/∂θ_k) + r_D[i]·(∂o_D_i/∂θ_k) + r_E[i]·(∂o_E_i/∂θ_k)
2184///       + r_R[i]·(∂o_R_i/∂θ_k)
2185///
2186/// where `r_X = residuals.exit`, `r_D = residuals.derivative`, `r_E =
2187/// residuals.entry`, `r_R = residuals.right` (all sampleweight-scaled already).
2188/// Exit and derivative partials both come from the `age_exit[i]` evaluation;
2189/// the entry partial from `age_entry[i]`; the interval upper-bound (`R`)
2190/// η-partial from `age_right[i]`. Origin-entry rows have `r_E[i] == 0` exactly
2191/// and non-interval rows have `r_R[i] == 0` exactly, so those partials are
2192/// skipped for those rows (avoiding the `age > 0` precondition failure when an
2193/// inactive boundary age is 0 / a placeholder).
2194///
2195/// Returns `Ok(None)` when the provider reports no θ-parameters.
2196fn baseline_chain_rule_gradient_with_partials<F>(
2197    label: &'static str,
2198    age_entry: ndarray::ArrayView1<'_, f64>,
2199    age_exit: ndarray::ArrayView1<'_, f64>,
2200    age_right: ndarray::ArrayView1<'_, f64>,
2201    cfg: &SurvivalBaselineConfig,
2202    residuals: &crate::survival::OffsetChannelResiduals,
2203    partials: F,
2204) -> Result<Option<Array1<f64>>, String>
2205where
2206    F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2207{
2208    let n = age_exit.len();
2209    if age_entry.len() != n
2210        || age_right.len() != n
2211        || residuals.exit.len() != n
2212        || residuals.entry.len() != n
2213        || residuals.derivative.len() != n
2214        || residuals.right.len() != n
2215    {
2216        return Err(format!(
2217            "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2218            age_entry.len(),
2219            n,
2220            age_right.len(),
2221            residuals.exit.len(),
2222            residuals.entry.len(),
2223            residuals.derivative.len(),
2224            residuals.right.len(),
2225        ));
2226    }
2227    // Probe θ-dim via any valid positive age. If the provider returns None the
2228    // config carries no θ-parameters (Linear target) and there is no θ-gradient.
2229    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2230    let theta_dim = match probe_age {
2231        Some(t) => match partials(t, cfg)? {
2232            None => return Ok(None),
2233            Some(v) => v.len(),
2234        },
2235        None => {
2236            return Err(format!("{label}: no valid positive age for dim probe"));
2237        }
2238    };
2239    // Per-row partial contractions are independent, but each row's
2240    // contribution is a `theta_dim`-vector of `O(theta_dim · partial_cost)`
2241    // flops — small enough that the rayon parallel reduction's split
2242    // overhead dominates for any plausible `theta_dim`, *and* the
2243    // non-associative IEEE-754 sum order across thread chunks made the
2244    // engine drift in the low-order bits from row to row. The serial
2245    // accumulator below mirrors the inline reference exactly (and remains
2246    // ~memory-bandwidth-bound at large-scale `n`), so the engine is now a
2247    // bit-for-bit replacement for the legacy path, not just a
2248    // floating-point-noise-equivalent one.
2249    let mut grad = Array1::<f64>::zeros(theta_dim);
2250    for i in 0..n {
2251        // Exit + derivative partials both come from the age_exit evaluation.
2252        let partials_exit = partials(age_exit[i], cfg)?
2253            .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2254        if partials_exit.len() != theta_dim {
2255            return Err(format!(
2256                "{label}: theta_dim drifted ({} != {})",
2257                partials_exit.len(),
2258                theta_dim
2259            ));
2260        }
2261        let r_x = residuals.exit[i];
2262        let r_d = residuals.derivative[i];
2263        for k in 0..theta_dim {
2264            let (d_eta_dk, d_od_dk) = partials_exit[k];
2265            grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2266        }
2267        // Entry channel is nonzero only for rows with a positive entry
2268        // interval; for origin-entry rows age_entry may be 0 and calling
2269        // the provider would error. Gate on residual==0.
2270        let r_e = residuals.entry[i];
2271        if r_e != 0.0 {
2272            let partials_entry = partials(age_entry[i], cfg)?
2273                .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2274            for k in 0..theta_dim {
2275                grad[k] += r_e * partials_entry[k].0;
2276            }
2277        }
2278        // Interval upper-bound (`R`) channel: `q_right = X_time(R)·β + o_R(θ)`
2279        // carries its own baseline-θ η-offset evaluated at `age_right[i]`. It is
2280        // an η-level offset with NO time-derivative channel (the interval
2281        // likelihood `log[S(L) − S(R)]` has no hazard-derivative term), so it
2282        // contracts against the η-partial `.0` only. Nonzero only for
2283        // interval-censored latent rows; for every other channel/model
2284        // `r_right[i] == 0` exactly, so the (possibly placeholder) `age_right[i]`
2285        // partial is never consulted.
2286        let r_r = residuals.right[i];
2287        if r_r != 0.0 {
2288            let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2289                format!("{label}: unexpected None from partials at right boundary")
2290            })?;
2291            if partials_right.len() != theta_dim {
2292                return Err(format!(
2293                    "{label}: theta_dim drifted at right boundary ({} != {})",
2294                    partials_right.len(),
2295                    theta_dim
2296                ));
2297            }
2298            for k in 0..theta_dim {
2299                grad[k] += r_r * partials_right[k].0;
2300            }
2301        }
2302    }
2303    Ok(Some(grad))
2304}
2305
2306/// Contract `OffsetChannelResiduals` against `baseline_offset_theta_partials`
2307/// to produce the closed-form θ-gradient of the unpenalized NLL at converged β.
2308///
2309/// Derivation (envelope theorem on the penalized objective, β* minimizes the
2310/// same cost wrt β and the penalty has no θ dependence):
2311///
2312///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2313///     = d[NLL(β*; o(θ))] / dθ_k
2314///     = Σᵢ (∂NLL_i/∂o_X[i])·(∂o_X_i/∂θ_k)
2315///       + (∂NLL_i/∂o_E[i])·(∂o_E_i/∂θ_k)
2316///       + (∂NLL_i/∂o_D[i])·(∂o_D_i/∂θ_k)
2317///       + (∂NLL_i/∂o_R[i])·(∂o_R_i/∂θ_k)
2318///
2319/// The four `∂NLL_i/∂o_channel` terms are the `exit`, `entry`, `derivative`,
2320/// `right` fields of [`OffsetChannelResiduals`] (sampleweight-scaled already).
2321/// The `∂o/∂θ_k` terms come from [`baseline_offset_theta_partials`] per obs at
2322/// the appropriate age.
2323///
2324/// Per the RP offset convention:
2325///   o_E[i] = eta_target(age_entry[i])
2326///   o_X[i] = eta_target(age_exit[i])
2327///   o_D[i] = d/dt eta_target(t) |_{t=age_exit[i]}
2328///   o_R[i] = eta_target(age_right[i])   (interval upper bound `R`; η-level only)
2329///
2330/// so the exit and derivative partials are both evaluated at `age_exit[i]`,
2331/// the entry partial at `age_entry[i]`, and the interval-right η-partial at
2332/// `age_right[i]`. The origin-entry case (`entry_at_origin[i]`) has
2333/// `r_entry[i] = 0` exactly and every non-interval row has `r_right[i] = 0`
2334/// exactly, so we skip the `baseline_offset_theta_partials(age, ..)` call for
2335/// those rows (avoiding the `age > 0` precondition failure when an inactive
2336/// boundary age is 0 / a placeholder).
2337///
2338/// Returns `Ok(None)` when `cfg.target == Linear` (no θ-parameters).
2339pub fn baseline_chain_rule_gradient(
2340    age_entry: ndarray::ArrayView1<'_, f64>,
2341    age_exit: ndarray::ArrayView1<'_, f64>,
2342    age_right: ndarray::ArrayView1<'_, f64>,
2343    cfg: &SurvivalBaselineConfig,
2344    residuals: &crate::survival::OffsetChannelResiduals,
2345) -> Result<Option<Array1<f64>>, String> {
2346    baseline_chain_rule_gradient_with_partials(
2347        "baseline_chain_rule_gradient",
2348        age_entry,
2349        age_exit,
2350        age_right,
2351        cfg,
2352        residuals,
2353        baseline_offset_theta_partials,
2354    )
2355}
2356
2357/// Chain-rule θ-gradient for marginal-slope probit baseline offsets.
2358///
2359/// This is the probit-survival counterpart of [`baseline_chain_rule_gradient`].
2360/// It contracts residuals against
2361/// [`marginal_slope_baseline_offset_theta_partials`], so the offset channels
2362/// are `(q_entry, q_exit, dq_exit/dt)` with `Phi(-q(t)) = exp(-H0(t))`.
2363pub fn marginal_slope_baseline_chain_rule_gradient(
2364    age_entry: ndarray::ArrayView1<'_, f64>,
2365    age_exit: ndarray::ArrayView1<'_, f64>,
2366    cfg: &SurvivalBaselineConfig,
2367    residuals: &crate::survival::OffsetChannelResiduals,
2368) -> Result<Option<Array1<f64>>, String> {
2369    // Marginal-slope has no interval upper-bound channel; `residuals.right` is
2370    // all-zero, so the right channel never contracts and `age_exit` serves as an
2371    // unconsulted placeholder for the (unused) `age_right` argument.
2372    baseline_chain_rule_gradient_with_partials(
2373        "marginal_slope_baseline_chain_rule_gradient",
2374        age_entry,
2375        age_exit,
2376        age_exit,
2377        cfg,
2378        residuals,
2379        marginal_slope_baseline_offset_theta_partials,
2380    )
2381}
2382
/// Gompertz cumulative and instantaneous hazards `(H_G(t), h_G(t))`, where
/// `H_G(t) = (rate/shape)·(exp(shape·t) − 1)` and `h_G(t) = rate·exp(shape·t)`.
///
/// - For `|shape| < 1e-10`, the `rate/shape` division is avoided by using the
///   second-order series in `x = shape·t`. This is exact at `shape = 0`, where
///   `H_G = rate·t` and `h_G = rate`.
/// - Otherwise, `expm1` keeps `E − 1` accurate for small `x`.
///
/// Called directly by the baseline value and θ-partial evaluators in this
/// section.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if shape.abs() < 1e-10 {
        // H_G ≈ rate·t·(1 + x/2 + x²/6),  h_G ≈ rate·(1 + x + x²/2)
        return (
            rate * age * (1.0 + 0.5 * x + x * x / 6.0),
            rate * (1.0 + x + 0.5 * x * x),
        );
    }
    let cumulative = (rate / shape) * x.exp_m1();
    let instantaneous = rate * x.exp();
    (cumulative, instantaneous)
}
2403
/// Shape-partials `(∂H_G/∂shape, ∂h_G/∂shape)` of the Gompertz hazards.
///
/// With `x = shape·t`, `E = exp(x)`, `H_G = (rate/shape)·(E − 1)` and
/// `h_G = rate·E`:
///
///   ∂H_G/∂shape = rate·[t·E·shape − (E − 1)] / shape²
///   ∂h_G/∂shape = rate·t·E
///
/// The numerator of the first expression equals `x·E − (E − 1)`. Its series is
/// `Σ_{k≥1} xᵏ·(k−1)/k!`, i.e. `x²/2 + x³/3 + x⁴/8 + …`. Dividing by `shape²`
/// gives
///
///   ∂H_G/∂shape = rate·t²·(1/2 + x/3 + x²/8 + O(x³)).
///
/// The exact form is a difference of two `O(1/shape)` terms whose leading
/// parts cancel, so its accuracy depends on `x`, not on `shape` alone. We use
/// the series for `|x| < 1e-4`; its truncation error there is well under 1e-9
/// relative. The exact form is used otherwise.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let growth = x.exp();
    let d_inst = rate * age * growth;
    let d_cum = if x.abs() < 1e-4 {
        rate * age * age * (0.5 + x / 3.0 + x * x / 8.0)
    } else {
        rate * (age * growth * shape - x.exp_m1()) / (shape * shape)
    };
    (d_cum, d_inst)
}
2444
/// Partials `(∂eta/∂shape, ∂o_D/∂shape)` for the pure Gompertz baseline.
///
/// With `x = shape·t` and `E = exp(x)`:
///
///   eta              = log H_G = log rate − log shape + log(E − 1)
///   ∂eta/∂shape      = −1/shape + t·E/(E − 1)
///   o_D              = h_G/H_G = shape·E/(E − 1)
///   ∂log(o_D)/∂shape = 1/shape − t/(E − 1)
///   ∂o_D/∂shape      = o_D · ∂log(o_D)/∂shape
///
/// Rate cancels in `o_D`, and `∂eta/∂log_rate = 1`, so this helper only
/// covers the shape channel.
///
/// Both exact shape expressions subtract two `O(1/shape) = O(t/x)` terms whose
/// leading parts cancel. Their error is therefore ≈ ε/|x| relative, and the
/// switch must be on the dimensionless product `x`, not on `shape` alone.
/// (Pivoting on `shape` let, e.g., shape = 2e-10 with t = 50 lose ~1e-6
/// absolute in `∂eta/∂shape ≈ 25`.)
///
/// For `|x| < 1e-4` we use the series
/// `t·E/(E − 1) = 1/shape + t/2 + shape·t²/12 − O(t·x³)`:
///
///   ∂eta/∂shape      ≈ t/2 + shape·t²/12
///   ∂log(o_D)/∂shape ≈ t/2 − shape·t²/12
///   o_D              ≈ 1/t + shape/2 + shape²·t/12
///
/// Within that range these are accurate to ~x³/360 ≲ 3e-15 relative, and they
/// are exact at `shape = 0`. This is the same `x`-based threshold as
/// `gompertz_cumulative_shape_derivative`. Above it, the cancellation loss of
/// the exact form is ≲ 1e-11 relative.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    const SERIES_THRESHOLD: f64 = 1e-4;
    let x = shape * age;
    if x.abs() < SERIES_THRESHOLD {
        let t = age;
        let d_eta = 0.5 * t + shape * t * t / 12.0;
        let dlog_od = 0.5 * t - shape * t * t / 12.0;
        let o_d = 1.0 / t + 0.5 * shape + shape * shape * t / 12.0;
        (d_eta, o_d * dlog_od)
    } else {
        let e = x.exp();
        let em1 = x.exp_m1(); // E − 1 via expm1 for accuracy at moderate x
        let d_eta = -1.0 / shape + age * e / em1;
        let o_d = shape * e / em1;
        let dlog_od = 1.0 / shape - age / em1;
        (d_eta, o_d * dlog_od)
    }
}
2477
/// Per-target baseline parameters that passed `validated_baseline_params`: the
/// shared age guard, the per-target required-field extraction, and the
/// finiteness/positivity checks.
///
/// This is the single source of truth for *which* config fields each baseline
/// target requires and *what* domain each must satisfy. It is consumed by:
/// - `baseline_offset_theta_partials`;
/// - `survival_hazard_theta_partials`;
/// - `survival_cumulative_and_instant_hazard`;
/// - `evaluate_survival_baseline`.
///
/// These differ only in how they assemble their outputs (values vs. derivatives)
/// from these checked scalars. The `Linear` target has no parameters and is
/// represented by `None` at the call sites rather than by a variant here.
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    /// `scale` and `shape` are both finite and strictly positive.
    Weibull { scale: f64, shape: f64 },
    /// `rate` is finite and strictly positive; `shape` is finite (any sign).
    Gompertz { rate: f64, shape: f64 },
    /// `rate` and `makeham` are finite and strictly positive; `shape` is
    /// finite (any sign).
    GompertzMakeham { rate: f64, shape: f64, makeham: f64 },
}
2492
2493/// Shared prologue for the survival baseline hazard evaluators: validate the age,
2494/// then extract and domain-check the per-target parameters from `cfg`.
2495///
2496/// `Ok(None)` is the `Linear` target (no parametric baseline). `context` is woven
2497/// into the age-guard error so each caller keeps its specific phrasing.
2498fn validated_baseline_params(
2499    age: f64,
2500    cfg: &SurvivalBaselineConfig,
2501    context: &str,
2502) -> Result<Option<ValidatedBaselineTarget>, String> {
2503    if !age.is_finite() || age <= 0.0 {
2504        return Err(format!(
2505            "survival ages must be finite and positive for {context}"
2506        ));
2507    }
2508
2509    match cfg.target {
2510        SurvivalBaselineTarget::Linear => Ok(None),
2511        SurvivalBaselineTarget::Weibull => {
2512            let scale = cfg
2513                .scale
2514                .ok_or_else(|| "weibull missing scale".to_string())?;
2515            let shape = cfg
2516                .shape
2517                .ok_or_else(|| "weibull missing shape".to_string())?;
2518            if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2519                return Err(SurvivalConstructionError::InvalidConfig {
2520                    reason: "weibull baseline requires finite positive scale and shape".to_string(),
2521                }
2522                .into());
2523            }
2524            Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2525        }
2526        SurvivalBaselineTarget::Gompertz => {
2527            let rate = cfg
2528                .rate
2529                .ok_or_else(|| "gompertz missing rate".to_string())?;
2530            let shape = cfg
2531                .shape
2532                .ok_or_else(|| "gompertz missing shape".to_string())?;
2533            if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2534                return Err(
2535                    "gompertz baseline requires finite positive rate and finite shape".to_string(),
2536                );
2537            }
2538            Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2539        }
2540        SurvivalBaselineTarget::GompertzMakeham => {
2541            let rate = cfg
2542                .rate
2543                .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2544            let shape = cfg
2545                .shape
2546                .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2547            let makeham = cfg
2548                .makeham
2549                .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2550            if !(rate.is_finite()
2551                && shape.is_finite()
2552                && makeham.is_finite()
2553                && rate > 0.0
2554                && makeham > 0.0)
2555            {
2556                return Err(
2557                    "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2558                        .to_string(),
2559                );
2560            }
2561            Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2562                rate,
2563                shape,
2564                makeham,
2565            }))
2566        }
2567    }
2568}
2569
2570fn survival_hazard_theta_partials(
2571    age: f64,
2572    cfg: &SurvivalBaselineConfig,
2573) -> Result<Option<Vec<(f64, f64)>>, String> {
2574    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2575        return Ok(None);
2576    };
2577
2578    match params {
2579        ValidatedBaselineTarget::Weibull { scale, shape } => {
2580            let log_time_ratio = age.ln() - scale.ln();
2581            let cumulative_hazard = (age / scale).powf(shape);
2582            let instant_hazard = shape * cumulative_hazard / age;
2583            let eta = shape * log_time_ratio;
2584            Ok(Some(vec![
2585                (-shape * cumulative_hazard, -shape * instant_hazard),
2586                (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2587            ]))
2588        }
2589        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2590            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2591            let (d_cum_dshape, d_inst_dshape) =
2592                gompertz_cumulative_shape_derivative(age, rate, shape);
2593            Ok(Some(vec![
2594                (cumulative_hazard, instant_hazard),
2595                (d_cum_dshape, d_inst_dshape),
2596            ]))
2597        }
2598        ValidatedBaselineTarget::GompertzMakeham {
2599            rate,
2600            shape,
2601            makeham,
2602        } => {
2603            let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2604            let (d_cum_dshape, d_inst_dshape) =
2605                gompertz_cumulative_shape_derivative(age, rate, shape);
2606            Ok(Some(vec![
2607                (cum_gompertz, inst_gompertz),
2608                (d_cum_dshape, d_inst_dshape),
2609                (makeham * age, makeham),
2610            ]))
2611        }
2612    }
2613}
2614
2615fn survival_cumulative_and_instant_hazard(
2616    age: f64,
2617    cfg: &SurvivalBaselineConfig,
2618) -> Result<Option<(f64, f64)>, String> {
2619    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2620        return Ok(None);
2621    };
2622
2623    match params {
2624        ValidatedBaselineTarget::Weibull { scale, shape } => {
2625            let cumulative_hazard = (age / scale).powf(shape);
2626            let instant_hazard = shape * cumulative_hazard / age;
2627            Ok(Some((cumulative_hazard, instant_hazard)))
2628        }
2629        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2630            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2631            Ok(Some((cumulative_hazard, instant_hazard)))
2632        }
2633        ValidatedBaselineTarget::GompertzMakeham {
2634            rate,
2635            shape,
2636            makeham,
2637        } => {
2638            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2639            Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2640        }
2641    }
2642}
2643
/// Probit-scale baseline quantities at a single age, as produced by
/// `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Baseline instantaneous hazard `h(t)` from
    /// `survival_cumulative_and_instant_hazard` (checked finite and > 0).
    instant_hazard: f64,
    /// Probit offset satisfying `Φ(−q) = S(t) = exp(−H(t))`.
    q: f64,
    /// Time derivative `dq/dt = h(t)·S(t)/φ(q)`, obtained by differentiating
    /// `Φ(−q(t)) = exp(−H(t))`.
    q_t: f64,
}
2650
2651fn evaluate_marginal_slope_baseline_point(
2652    age: f64,
2653    cfg: &SurvivalBaselineConfig,
2654) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2655    let Some((cumulative_hazard, instant_hazard)) =
2656        survival_cumulative_and_instant_hazard(age, cfg)?
2657    else {
2658        return Ok(None);
2659    };
2660    if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2661        return Err(format!(
2662            "{} marginal-slope baseline produced non-positive cumulative hazard",
2663            survival_baseline_targetname(cfg.target)
2664        ));
2665    }
2666    if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2667        return Err(format!(
2668            "{} marginal-slope baseline produced non-positive instant hazard",
2669            survival_baseline_targetname(cfg.target)
2670        ));
2671    }
2672    let survival = (-cumulative_hazard).exp();
2673    if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2674        return Err(format!(
2675            "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2676            survival_baseline_targetname(cfg.target)
2677        ));
2678    }
2679    let q = -standard_normal_quantile(survival).map_err(|e| {
2680        format!(
2681            "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2682            survival_baseline_targetname(cfg.target)
2683        )
2684    })?;
2685    let phi_q = normal_pdf(q);
2686    if !(phi_q.is_finite() && phi_q > 0.0) {
2687        return Err(format!(
2688            "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2689            survival_baseline_targetname(cfg.target)
2690        ));
2691    }
2692    Ok(Some(MarginalSlopeBaselinePoint {
2693        instant_hazard,
2694        q,
2695        q_t: instant_hazard * survival / phi_q,
2696    }))
2697}
2698
2699/// Evaluate the parametric baseline target at a given age.
2700/// Returns `(eta_target(age), d eta_target / d age)` on the log-cumulative-hazard scale.
2701pub fn evaluate_survival_baseline(
2702    age: f64,
2703    cfg: &SurvivalBaselineConfig,
2704) -> Result<(f64, f64), String> {
2705    if !age.is_finite() || age < 0.0 {
2706        return Err(
2707            "survival ages must be finite and non-negative for baseline target evaluation"
2708                .to_string(),
2709        );
2710    }
2711
2712    // At t = 0 every parametric cumulative-hazard target satisfies H(0) = 0
2713    // exactly (this is the defining property of a cumulative hazard:
2714    // S(0) = 1 ⇒ H(0) = -log S(0) = 0). The log-cumulative-hazard offset is
2715    // therefore eta(0) = log H(0) = -inf, and we report a zero log-derivative
2716    // since `exp(eta(0)) = H(0) = 0` is the only physically valid value.
2717    // Returning `Ok((-inf, 0.0))` keeps the baseline cumulative hazard exactly
2718    // zero at the origin; downstream callers that need to multiply this offset
2719    // into a linear predictor are responsible for handling the origin row via
2720    // the `entry_at_origin` / `exit_at_origin` gating already wired through the
2721    // engine.
2722    if age == 0.0 {
2723        return match cfg.target {
2724            SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2725            SurvivalBaselineTarget::Weibull
2726            | SurvivalBaselineTarget::Gompertz
2727            | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2728        };
2729    }
2730
2731    let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2732        return Ok((0.0, 0.0));
2733    };
2734
2735    match params {
2736        ValidatedBaselineTarget::Weibull { scale, shape } => {
2737            let eta = shape * (age.ln() - scale.ln());
2738            let derivative = shape / age;
2739            Ok((eta, derivative))
2740        }
2741        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2742            let (h, inst) = gompertz_hazard_components(age, rate, shape);
2743            if h <= 0.0 || !h.is_finite() {
2744                return Err(if shape.abs() < 1e-10 {
2745                    "invalid gompertz baseline at near-zero shape".to_string()
2746                } else {
2747                    "gompertz baseline produced non-positive cumulative hazard".to_string()
2748                });
2749            }
2750            let derivative = inst / h;
2751            Ok((h.ln(), derivative))
2752        }
2753        ValidatedBaselineTarget::GompertzMakeham {
2754            rate,
2755            shape,
2756            makeham,
2757        } => {
2758            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2759            let h = makeham * age + h_gompertz;
2760            if h <= 0.0 || !h.is_finite() {
2761                return Err(
2762                    "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2763                );
2764            }
2765            let inst = makeham + inst_gompertz;
2766            let derivative = inst / h;
2767            Ok((h.ln(), derivative))
2768        }
2769    }
2770}
2771
2772/// Evaluate the parametric baseline as the probit index whose marginal
2773/// survival is the true hazard survival `exp(-H0(t))`.
2774///
2775/// Returns `(q(age), dq / d age)` such that `Phi(-q(age)) = exp(-H0(age))`.
2776/// The derivative is `h0(t) * exp(-H0(t)) / phi(q(t))`.
2777pub fn evaluate_survival_marginal_slope_baseline(
2778    age: f64,
2779    cfg: &SurvivalBaselineConfig,
2780) -> Result<(f64, f64), String> {
2781    // Survival-curve origin. Every cumulative-hazard baseline satisfies
2782    // `H0(0) = 0` (`S0(0) = exp(-H0(0)) = 1`), so the probit index
2783    // `q(0) = -Phi^{-1}(S0(0)) = -Phi^{-1}(1) = -inf`: there is no *finite*
2784    // probit-survival offset at the origin. The survival surface anchors
2785    // `S(0) = 1` directly (see the `t <= 0` origin handling in the survival
2786    // predict paths), so the baseline contributes nothing here — report the
2787    // zero offset rather than aborting in the `age <= 0` hazard guard. This
2788    // mirrors `evaluate_survival_baseline`'s explicit `age == 0` branch on the
2789    // log-cumulative-hazard channel; without it the probit/marginal-slope
2790    // baseline path (location-scale + marginal-slope likelihoods) could not be
2791    // evaluated on a prediction grid whose first node is the origin (#1024).
2792    if age == 0.0 {
2793        return Ok((0.0, 0.0));
2794    }
2795    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2796        return Ok((0.0, 0.0));
2797    };
2798    Ok((point.q, point.q_t))
2799}
2800
2801/// Partial derivatives of the true survival marginal-slope probit offsets
2802/// `(q(t), dq(t)/dt)` with respect to the baseline θ-parameters.
2803///
2804/// The returned channels match `survival_baseline_theta_from_config`.  For
2805/// Gompertz-Makeham, θ is `(log_rate, shape, log_makeham)`.  If
2806/// `S(t)=exp(-H(t))`, `q(t)=-Phi^-1(S(t))`, `A(t)=S(t)/phi(q(t))`, and
2807/// `h(t)=dH/dt`, then
2808///
2809///   dq/dθ      = A * dH/dθ
2810///   d(q')/dθ   = A * (dh/dθ + h * (q*A - 1) * dH/dθ)
2811///
2812/// which keeps the probit transform and the hazard baseline analytically tied.
2813pub fn marginal_slope_baseline_offset_theta_partials(
2814    age: f64,
2815    cfg: &SurvivalBaselineConfig,
2816) -> Result<Option<Vec<(f64, f64)>>, String> {
2817    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2818        return Ok(None);
2819    };
2820    let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2821        .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2822    let a = point.q_t / point.instant_hazard;
2823    let a_log_derivative_factor = point.q * a - 1.0;
2824    Ok(Some(
2825        hazard_partials
2826            .into_iter()
2827            .map(|(d_h_cum, d_h_inst)| {
2828                (
2829                    a * d_h_cum,
2830                    a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2831                )
2832            })
2833            .collect(),
2834    ))
2835}
2836
2837/// Contract marginal-slope offset residuals and channel curvatures into the
2838/// exact Hessian with respect to baseline θ-parameters.
2839pub fn marginal_slope_baseline_chain_rule_hessian(
2840    age_entry: ndarray::ArrayView1<'_, f64>,
2841    age_exit: ndarray::ArrayView1<'_, f64>,
2842    cfg: &SurvivalBaselineConfig,
2843    residuals: &crate::survival::OffsetChannelResiduals,
2844    curvatures: &crate::survival::OffsetChannelCurvatures,
2845) -> Result<Option<Array2<f64>>, String> {
2846    let n = age_exit.len();
2847    if age_entry.len() != n
2848        || residuals.exit.len() != n
2849        || residuals.entry.len() != n
2850        || residuals.derivative.len() != n
2851        || curvatures.rows.len() != n
2852    {
2853        return Err(format!(
2854            "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2855            age_entry.len(),
2856            n,
2857            residuals.exit.len(),
2858            residuals.entry.len(),
2859            residuals.derivative.len(),
2860            curvatures.rows.len(),
2861        ));
2862    }
2863    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2864    let dim = match probe_age {
2865        Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2866            None => return Ok(None),
2867            Some(parts) => parts.first.len(),
2868        },
2869        None => {
2870            return Err(
2871                "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2872                    .to_string(),
2873            );
2874        }
2875    };
2876    // Per-row Hessian contractions are independent. Each row contributes a
2877    // dim×dim increment combining second partials (exit/entry channels) with
2878    // the curvature-weighted outer product of the (entry, exit, derivative)
2879    // first-partial Jacobians. Fixed row chunks are combined in chunk-index
2880    // order so floating-point addition stays deterministic across Rayon
2881    // scheduling decisions.
2882    let hessian = RowSet::All.par_try_reduce_fold(
2883        n,
2884        || Array2::<f64>::zeros((dim, dim)),
2885        |mut acc, i, _row_weight| -> Result<Array2<f64>, String> {
2886            let exit_parts =
2887                marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2888                    .ok_or_else(|| {
2889                        "unexpected None from marginal-slope second partials at exit".to_string()
2890                    })?;
2891            if exit_parts.first.len() != dim {
2892                return Err(
2893                    "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2894                );
2895            }
2896            let mut entry_parts = None;
2897            if residuals.entry[i] != 0.0 {
2898                entry_parts = Some(
2899                    marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2900                        .ok_or_else(|| {
2901                            "unexpected None from marginal-slope second partials at entry"
2902                                .to_string()
2903                        })?,
2904                );
2905            }
2906            for a in 0..dim {
2907                for b in 0..dim {
2908                    let j_exit_a = exit_parts.first[a].0;
2909                    let j_exit_b = exit_parts.first[b].0;
2910                    let j_deriv_a = exit_parts.first[a].1;
2911                    let j_deriv_b = exit_parts.first[b].1;
2912                    let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2913                        + residuals.derivative[i] * exit_parts.second[a][b].1;
2914                    if let Some(parts) = entry_parts.as_ref() {
2915                        value += residuals.entry[i] * parts.second[a][b].0;
2916                    }
2917                    let curv = curvatures.rows[i];
2918                    let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2919                    let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2920                    let ja = [j_entry_a, j_exit_a, j_deriv_a];
2921                    let jb = [j_entry_b, j_exit_b, j_deriv_b];
2922                    for u in 0..3 {
2923                        for v in 0..3 {
2924                            value += ja[u] * curv[u][v] * jb[v];
2925                        }
2926                    }
2927                    acc[[a, b]] += value;
2928                }
2929            }
2930            Ok(acc)
2931        },
2932        |a, b| Ok(a + b),
2933    )?;
2934    Ok(Some(hessian))
2935}
2936
/// First and second θ-partials of the probit-survival offsets `(q, q')`
/// evaluated at a single age.
struct MarginalSlopeThetaSecondPartials {
    /// `first[i] = (∂q/∂θ_i, ∂q'/∂θ_i)`.
    first: Vec<(f64, f64)>,
    /// `second[i][j] = (∂²q/∂θ_i∂θ_j, ∂²q'/∂θ_i∂θ_j)`; a dim×dim table.
    second: Vec<Vec<(f64, f64)>>,
}
2941
2942fn marginal_slope_baseline_offset_theta_second_partials(
2943    age: f64,
2944    cfg: &SurvivalBaselineConfig,
2945) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
2946    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2947        return Ok(None);
2948    };
2949    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
2950        return Ok(None);
2951    };
2952    let (cum_hazard, instant_hazard) = hazard;
2953    let survival = (-cum_hazard).exp();
2954    let a = survival / normal_pdf(point.q);
2955    let b = point.q * a - 1.0;
2956    let b_factor = a + point.q * b;
2957    let dim = first.len();
2958    let mut first_out = Vec::with_capacity(dim);
2959    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
2960    for i in 0..dim {
2961        let (h_i, inst_i) = first[i];
2962        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
2963    }
2964    for i in 0..dim {
2965        for j in 0..dim {
2966            let (h_i, inst_i) = first[i];
2967            let (h_j, inst_j) = first[j];
2968            let (h_ij, inst_ij) = second[i][j];
2969            let a_j = a * b * h_j;
2970            let b_j = a * h_j * b_factor;
2971            let q_ij = a * h_ij + a * b * h_i * h_j;
2972            let qt_inner_i = inst_i + instant_hazard * b * h_i;
2973            let qt_ij = a_j * qt_inner_i
2974                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
2975            second_out[i][j] = (q_ij, qt_ij);
2976        }
2977    }
2978    Ok(Some(MarginalSlopeThetaSecondPartials {
2979        first: first_out,
2980        second: second_out,
2981    }))
2982}
2983
/// Hazard pair plus its θ-derivatives: `((H, h), first, second)` with
/// `first[i] = (∂H/∂θ_i, ∂h/∂θ_i)` and
/// `second[i][j] = (∂²H/∂θ_i∂θ_j, ∂²h/∂θ_i∂θ_j)`.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);

/// First and second θ-partials of the baseline hazard pair `(H(t), h(t))`.
///
/// The first partials come from `survival_hazard_theta_partials`. This
/// function fills in the dim×dim second-order table for each target.
/// Entries that are not assigned stay zero.
///
/// Returns `Ok(None)` when `survival_cumulative_and_instant_hazard` returns
/// `None`, and for the `Linear` target.
///
/// # Errors
/// Propagates evaluation errors. Fails when the first partials are missing or
/// a required parameter is absent from `cfg`.
fn survival_hazard_theta_first_second(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<HazardFirstSecond>, String> {
    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
        return Ok(None);
    };
    let first = survival_hazard_theta_partials(age, cfg)?
        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
    let dim = first.len();
    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
    match cfg.target {
        SurvivalBaselineTarget::Linear => return Ok(None),
        SurvivalBaselineTarget::Weibull => {
            // With eta = log H = shape * log(t / scale) and h = shape * H / t,
            // the entries below are the second derivatives for
            // θ = (log scale, log shape).
            // NOTE(review): confirm against `survival_baseline_theta_from_config`.
            let scale = cfg
                .scale
                .ok_or_else(|| "weibull missing scale".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "weibull missing shape".to_string())?;
            let log_time_ratio = age.ln() - scale.ln();
            let cumulative_hazard = hazard.0;
            let instant_hazard = hazard.1;
            let eta = shape * log_time_ratio;
            second[0][0] = (
                shape * shape * cumulative_hazard,
                shape * shape * instant_hazard,
            );
            second[0][1] = (
                -shape * cumulative_hazard * (1.0 + eta),
                -shape * instant_hazard * (2.0 + eta),
            );
            second[1][0] = second[0][1];
            second[1][1] = (
                eta * cumulative_hazard * (1.0 + eta),
                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
            );
        }
        SurvivalBaselineTarget::Gompertz => {
            // Assumes θ = (log_rate, shape). H and h are proportional to rate,
            // so every derivative taken once more w.r.t. log_rate reproduces
            // the corresponding first partial. Only shape–shape needs its own
            // formula.
            let rate = cfg
                .rate
                .ok_or_else(|| "gompertz missing rate".to_string())?;
            let shape = cfg
                .shape
                .ok_or_else(|| "gompertz missing shape".to_string())?;
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
        }
        SurvivalBaselineTarget::GompertzMakeham => {
            // Assumes θ = (log_rate, shape, log_makeham). The Gompertz block
            // matches the plain Gompertz case. The Makeham term makeham·t is
            // additive and linear in makeham, so its log_makeham diagonal
            // equals its first partial and its cross partials stay zero.
            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
            second[0][0] = first[0];
            second[0][1] = first[1];
            second[1][0] = first[1];
            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
            second[2][2] = first[2];
        }
    }
    Ok(Some((hazard, first, second)))
}
3048
/// Second `shape`-derivatives of the Gompertz hazard pair at `age`.
///
/// With `E = exp(shape·t)`, `H_G(t) = (rate / shape)·(E − 1)` and
/// `h_G(t) = rate·E`, this returns `(∂²H_G/∂shape², ∂²h_G/∂shape²)`:
///
///   ∂²H_G/∂shape² = rate·[t²·E/shape − 2·(shape·t·E − (E−1))/shape³]
///   ∂²h_G/∂shape² = rate·t²·E
///
/// The cumulative closed form subtracts O(1/shape³) terms whose leading parts
/// cancel, so its floating-point accuracy depends on `x = shape·age`, not on
/// `shape` alone. At `x = 1e-9` the closed form has roughly 98% relative
/// error, so the switch is made on `x`. For `|x| < 1e-3` the series through
/// O(x²),
///
///   ∂²H_G/∂shape² ≈ rate·t³·(1/3 + x/4 + x²/10)
///   ∂²h_G/∂shape² ≈ rate·t²·(1 + x + x²/2)
///
/// is accurate to better than 1e-8 relative. The closed form is clean above
/// that threshold.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let use_series = x.abs() < 1e-3;
    if !use_series {
        let growth = x.exp();
        let growth_minus_one = x.exp_m1();
        let numerator = shape * age * growth - growth_minus_one;
        let shape_cubed = shape * shape * shape;
        return (
            rate * (age * age * growth / shape - 2.0 * numerator / shape_cubed),
            rate * age * age * growth,
        );
    }
    let cumulative_series = 1.0 / 3.0 + x / 4.0 + x * x / 10.0;
    let instant_series = 1.0 + x + 0.5 * x * x;
    (
        rate * age * age * age * cumulative_series,
        rate * age * age * instant_series,
    )
}
3079
3080// ---------------------------------------------------------------------------
3081// Baseline offsets
3082// ---------------------------------------------------------------------------
3083
3084#[derive(Clone, Copy)]
3085enum BaselineOffsetEvaluator {
3086    LogCumulativeHazard,
3087    ProbitSurvival,
3088}
3089
3090impl BaselineOffsetEvaluator {
3091    fn length_error(self) -> String {
3092        match self {
3093            Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3094                reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3095            }
3096            .into(),
3097            Self::ProbitSurvival => {
3098                "survival probit baseline offsets require matching entry/exit lengths".to_string()
3099            }
3100        }
3101    }
3102
3103    fn finite_error(self) -> &'static str {
3104        match self {
3105            Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3106            Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3107        }
3108    }
3109
3110    fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3111        match self {
3112            Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3113            Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3114        }
3115    }
3116
3117    fn exit_is_finite(self, value: f64, age: f64) -> bool {
3118        match self {
3119            Self::LogCumulativeHazard => {
3120                value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3121            }
3122            Self::ProbitSurvival => value.is_finite(),
3123        }
3124    }
3125}
3126
3127fn build_survival_offsets_with_evaluator(
3128    age_entry: &Array1<f64>,
3129    age_exit: &Array1<f64>,
3130    cfg: &SurvivalBaselineConfig,
3131    evaluator: BaselineOffsetEvaluator,
3132) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3133    if age_entry.len() != age_exit.len() {
3134        return Err(evaluator.length_error());
3135    }
3136    let n = age_entry.len();
3137    // Each row's three offsets are independent across i. Compute the triplets
3138    // in parallel, then unpack into three Array1 outputs preserving order.
3139    let triples: Vec<(f64, f64, f64)> = (0..n)
3140        .into_par_iter()
3141        .map(|i| -> Result<(f64, f64, f64), String> {
3142            // Origin-entry rows are multiplied out by the survival engines, so
3143            // keep their entry channel finite even when the evaluator's natural
3144            // value at t=0 is undefined or -inf.
3145            let entry_age = age_entry[i];
3146            let e0 = if !entry_age.is_finite() {
3147                return Err(SurvivalConstructionError::DataValidationFailed {
3148                    reason: format!("non-finite entry age at row {i}"),
3149                }
3150                .into());
3151            } else if entry_age <= 0.0 {
3152                0.0
3153            } else {
3154                evaluator.evaluate(entry_age, cfg)?.0
3155            };
3156            let exit_age = age_exit[i];
3157            let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3158            if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3159                return Err(SurvivalConstructionError::DataValidationFailed {
3160                    reason: evaluator.finite_error().to_string(),
3161                }
3162                .into());
3163            }
3164            Ok((e0, e1, d1))
3165        })
3166        .collect::<Result<Vec<_>, String>>()?;
3167    let mut eta_entry = Array1::<f64>::zeros(n);
3168    let mut eta_exit = Array1::<f64>::zeros(n);
3169    let mut derivative_exit = Array1::<f64>::zeros(n);
3170    for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3171        eta_entry[i] = e0;
3172        eta_exit[i] = e1;
3173        derivative_exit[i] = d1;
3174    }
3175    Ok((eta_entry, eta_exit, derivative_exit))
3176}
3177
3178/// Compute baseline target offsets for all observations.
3179/// Returns `(eta_entry, eta_exit, derivative_exit)`.
3180pub fn build_survival_baseline_offsets(
3181    age_entry: &Array1<f64>,
3182    age_exit: &Array1<f64>,
3183    cfg: &SurvivalBaselineConfig,
3184) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3185    build_survival_offsets_with_evaluator(
3186        age_entry,
3187        age_exit,
3188        cfg,
3189        BaselineOffsetEvaluator::LogCumulativeHazard,
3190    )
3191}
3192
3193/// Compute probit-survival baseline target offsets for all observations.
3194/// Returns `(q_entry, q_exit, q_derivative_exit)` where `Phi(-q(t)) = exp(-H0(t))`.
3195pub fn build_survival_marginal_slope_baseline_offsets(
3196    age_entry: &Array1<f64>,
3197    age_exit: &Array1<f64>,
3198    cfg: &SurvivalBaselineConfig,
3199) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3200    build_survival_offsets_with_evaluator(
3201        age_entry,
3202        age_exit,
3203        cfg,
3204        BaselineOffsetEvaluator::ProbitSurvival,
3205    )
3206}
3207
3208pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3209    matches!(
3210        inverse_link,
3211        Some(
3212            InverseLink::Standard(StandardLink::Probit)
3213                | InverseLink::LatentCLogLog(_)
3214                | InverseLink::Sas(_)
3215                | InverseLink::BetaLogistic(_)
3216                | InverseLink::Mixture(_)
3217        )
3218    )
3219}
3220
3221pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3222    match likelihood_mode {
3223        SurvivalLikelihoodMode::LocationScale
3224        | SurvivalLikelihoodMode::Latent
3225        | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3226        SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3227        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3228    }
3229}
3230
3231pub fn build_survival_time_offsets_for_likelihood(
3232    age_entry: &Array1<f64>,
3233    age_exit: &Array1<f64>,
3234    baseline_cfg: &SurvivalBaselineConfig,
3235    likelihood_mode: SurvivalLikelihoodMode,
3236    inverse_link: Option<&InverseLink>,
3237) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3238    if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3239        || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3240            && location_scale_uses_probit_survival_baseline(inverse_link))
3241    {
3242        build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3243    } else {
3244        build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3245    }
3246}
3247
3248pub fn add_survival_time_derivative_guard_offset(
3249    age_entry: &Array1<f64>,
3250    age_exit: &Array1<f64>,
3251    anchor_time: f64,
3252    derivative_guard: f64,
3253    eta_offset_entry: &mut Array1<f64>,
3254    eta_offset_exit: &mut Array1<f64>,
3255    derivative_offset_exit: &mut Array1<f64>,
3256) -> Result<(), String> {
3257    if derivative_guard <= 0.0 {
3258        return Ok(());
3259    }
3260    let n = age_entry.len();
3261    if age_exit.len() != n
3262        || eta_offset_entry.len() != n
3263        || eta_offset_exit.len() != n
3264        || derivative_offset_exit.len() != n
3265    {
3266        return Err(SurvivalConstructionError::IncompatibleDimensions {
3267            reason: "survival derivative-guard offset lengths must match".to_string(),
3268        }
3269        .into());
3270    }
3271    for i in 0..n {
3272        eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3273        eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3274        derivative_offset_exit[i] += derivative_guard;
3275    }
3276    Ok(())
3277}
3278
/// Per-row baseline offsets for latent survival models, as produced by
/// `build_latent_survival_baseline_offsets`. The "loaded" channel holds a
/// component on the log-cumulative-hazard scale. The "unloaded" channel holds
/// an additive cumulative-hazard component.
///
/// Under `HazardLoading::Full` the loaded channel is the whole baseline and
/// the unloaded channel is zero. Under `HazardLoading::LoadedVsUnloaded` the
/// loaded channel is the Gompertz part and the unloaded channel is the
/// Makeham part.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    /// Log cumulative hazard of the loaded component at entry.
    pub loaded_eta_entry: Array1<f64>,
    /// Log cumulative hazard of the loaded component at exit.
    pub loaded_eta_exit: Array1<f64>,
    /// Time derivative of the loaded log cumulative hazard at exit (`h / H`).
    pub loaded_derivative_exit: Array1<f64>,
    /// Unloaded cumulative hazard at entry (`makeham * entry`, or zero).
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded cumulative hazard at exit (`makeham * exit`, or zero).
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded instantaneous hazard at exit (`makeham`, or zero).
    pub unloaded_hazard_exit: Array1<f64>,
}
3288
3289pub fn build_latent_survival_baseline_offsets(
3290    age_entry: &Array1<f64>,
3291    age_exit: &Array1<f64>,
3292    cfg: &SurvivalBaselineConfig,
3293    loading: HazardLoading,
3294) -> Result<LatentSurvivalBaselineOffsets, String> {
3295    if age_entry.len() != age_exit.len() {
3296        return Err(
3297            "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3298        );
3299    }
3300
3301    fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3302        if shape.abs() < 1e-10 {
3303            // Taylor at shape=0 matching `gompertz_hazard_components`:
3304            //   H_G(t) = rate·t·(1 + (shape·t)/2 + (shape·t)²/6)
3305            //   h_G(t) = rate·(1 + shape·t + (shape·t)²/2)
3306            // Dropping the higher-order `shape*t` corrections silently
3307            // diverges this helper from its sibling for non-zero shape near
3308            // the cutoff and gives inconsistent loaded-vs-unloaded offsets.
3309            let x = shape * age;
3310            return (
3311                rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3312                rate * (1.0 + x + 0.5 * x * x),
3313            );
3314        }
3315        let shape_age = shape * age;
3316        let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3317        let instant_hazard = rate * shape_age.exp();
3318        (cumulative_hazard, instant_hazard)
3319    }
3320
3321    let n = age_entry.len();
3322
3323    // Per-row 6-tuple is independent. Evaluate in parallel into a Vec and then
3324    // unpack into the six Array1 outputs in original order.
3325    let rows: Vec<[f64; 6]> = (0..n)
3326        .into_par_iter()
3327        .map(|i| -> Result<[f64; 6], String> {
3328            let entry = age_entry[i];
3329            let exit = age_exit[i];
3330            if !entry.is_finite()
3331                || !exit.is_finite()
3332                || entry <= 0.0
3333                || exit <= 0.0
3334                || exit < entry
3335            {
3336                return Err(format!(
3337                    "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3338                    i + 1
3339                ));
3340            }
3341            match loading {
3342                HazardLoading::Full => {
3343                    let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3344                    let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3345                    Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3346                }
3347                HazardLoading::LoadedVsUnloaded => {
3348                    if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3349                        return Err(format!(
3350                            "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3351                            survival_baseline_targetname(cfg.target)
3352                        ));
3353                    }
3354                    let rate = cfg.rate.ok_or_else(|| {
3355                        "gompertz-makeham latent survival is missing baseline rate".to_string()
3356                    })?;
3357                    let shape = cfg.shape.ok_or_else(|| {
3358                        "gompertz-makeham latent survival is missing baseline shape".to_string()
3359                    })?;
3360                    let makeham = cfg.makeham.ok_or_else(|| {
3361                        "gompertz-makeham latent survival is missing baseline makeham".to_string()
3362                    })?;
3363                    let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3364                    let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3365                    if !(loaded_entry.is_finite()
3366                        && loaded_entry > 0.0
3367                        && loaded_exit.is_finite()
3368                        && loaded_exit > 0.0
3369                        && loaded_hazard.is_finite()
3370                        && loaded_hazard > 0.0)
3371                    {
3372                        return Err(format!(
3373                            "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3374                            i + 1
3375                        ));
3376                    }
3377                    Ok([
3378                        loaded_entry.ln(),
3379                        loaded_exit.ln(),
3380                        loaded_hazard / loaded_exit,
3381                        makeham * entry,
3382                        makeham * exit,
3383                        makeham,
3384                    ])
3385                }
3386            }
3387        })
3388        .collect::<Result<Vec<_>, String>>()?;
3389
3390    let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3391    let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3392    let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3393    let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3394    let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3395    let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3396    for (i, row) in rows.into_iter().enumerate() {
3397        loaded_eta_entry[i] = row[0];
3398        loaded_eta_exit[i] = row[1];
3399        loaded_derivative_exit[i] = row[2];
3400        unloaded_mass_entry[i] = row[3];
3401        unloaded_mass_exit[i] = row[4];
3402        unloaded_hazard_exit[i] = row[5];
3403    }
3404
3405    Ok(LatentSurvivalBaselineOffsets {
3406        loaded_eta_entry,
3407        loaded_eta_exit,
3408        loaded_derivative_exit,
3409        unloaded_mass_entry,
3410        unloaded_mass_exit,
3411        unloaded_hazard_exit,
3412    })
3413}
3414
3415// ---------------------------------------------------------------------------
3416// Time wiggle construction
3417// ---------------------------------------------------------------------------
3418
3419pub fn build_survival_timewiggle_derivative_design(
3420    eta_exit: &Array1<f64>,
3421    derivative_exit: &Array1<f64>,
3422    knots: &Array1<f64>,
3423    degree: usize,
3424) -> Result<Array2<f64>, String> {
3425    let mut design_derivative_exit =
3426        monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3427    for i in 0..design_derivative_exit.nrows() {
3428        let chain = derivative_exit[i];
3429        for j in 0..design_derivative_exit.ncols() {
3430            design_derivative_exit[[i, j]] *= chain;
3431        }
3432    }
3433    Ok(design_derivative_exit)
3434}
3435
3436/// Build the dynamic "baseline as prior" timewiggle runtime.
3437///
3438/// The baseline offsets are used only to initialize the wiggle knot placement
3439/// on a stable scalar scale.  The exact survival family evaluates the resulting
3440/// monotone wiggle dynamically on the current time predictor h0(t):
3441///
3442///   h(t) = g(h0(t)),   g(z) = z + w(z).
3443///
3444/// No fixed `B(eta_baseline)` design is constructed here.
3445pub fn build_survival_timewiggle_from_baseline(
3446    eta_entry: &Array1<f64>,
3447    eta_exit: &Array1<f64>,
3448    derivative_exit: &Array1<f64>,
3449    cfg: &LinkWiggleFormulaSpec,
3450) -> Result<SurvivalTimeWiggleBuild, String> {
3451    if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3452        return Err(
3453            "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3454        );
3455    }
3456    // Guard: if baseline offsets are all zero (linear baseline), the timewiggle
3457    // construction is degenerate — it adds only a constant, not time-varying structure.
3458    let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3459        && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3460        && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3461    if all_zero {
3462        return Err(
3463            "timewiggle requires a non-linear scalar survival baseline target; \
3464             the provided baseline offsets are all zero (linear baseline)"
3465                .to_string(),
3466        );
3467    }
3468    let n = eta_exit.len();
3469    let mut seed = Array1::<f64>::zeros(2 * n);
3470    for i in 0..n {
3471        seed[i] = eta_entry[i];
3472        seed[n + i] = eta_exit[i];
3473    }
3474    // Use the smallest requested positive penalty order as the primary
3475    // coefficient-space penalty so the fitted wiggle penalty system matches
3476    // the public formula exactly, including the slope (`order = 1`) case.
3477    let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3478    let wiggle_cfg = WiggleBlockConfig {
3479        degree: cfg.degree,
3480        num_internal_knots: cfg.num_internal_knots,
3481        penalty_order: primary_order,
3482        double_penalty: cfg.double_penalty,
3483    };
3484    let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3485    append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3486    let ncols = combined_block.design.ncols();
3487    Ok(SurvivalTimeWiggleBuild {
3488        nullspace_dims: combined_block.nullspace_dims.clone(),
3489        penalties: {
3490            combined_block
3491                .penalties
3492                .into_iter()
3493                .map(|ps| ps.to_global(ncols))
3494                .collect()
3495        },
3496        knots,
3497        degree: cfg.degree,
3498        ncols,
3499    })
3500}
3501
3502pub fn append_zero_tail_columns(
3503    x_entry: &mut DesignMatrix,
3504    x_exit: &mut DesignMatrix,
3505    x_derivative: &mut DesignMatrix,
3506    tail_cols: usize,
3507) {
3508    if tail_cols == 0 {
3509        return;
3510    }
3511    // Wiggle tail columns are dense, so materialize everything to dense.
3512    // This only runs once at construction time when time-wiggles are active.
3513    fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3514        let old = dm.to_dense();
3515        let n = old.nrows();
3516        let p_base = old.ncols();
3517        let mut out = Array2::<f64>::zeros((n, p_base + tail));
3518        out.slice_mut(s![.., 0..p_base]).assign(&old);
3519        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3520    }
3521    append_dense(x_entry, tail_cols);
3522    append_dense(x_exit, tail_cols);
3523    append_dense(x_derivative, tail_cols);
3524}
3525
3526// ---------------------------------------------------------------------------
3527// Resolved config (from build output back to config for serialization)
3528// ---------------------------------------------------------------------------
3529
3530// ---------------------------------------------------------------------------
3531// Time-varying covariate template
3532// ---------------------------------------------------------------------------
3533
3534/// Build a time-varying covariate block by tensoring the covariate design
3535/// with a 1D B-spline basis on log(time).
3536pub fn build_time_varying_survival_covariate_template(
3537    age_entry: &Array1<f64>,
3538    age_exit: &Array1<f64>,
3539    time_k: usize,
3540    time_degree: usize,
3541    block_name: &str,
3542) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3543    if time_k < time_degree + 1 {
3544        return Err(format!(
3545            "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3546            time_degree + 1
3547        ));
3548    }
3549    let num_internal_knots = time_k - (time_degree + 1);
3550
3551    let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3552    let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3553
3554    let time_spec = BSplineBasisSpec {
3555        degree: time_degree,
3556        penalty_order: 2,
3557        knotspec: BSplineKnotSpec::Automatic {
3558            num_internal_knots: Some(num_internal_knots),
3559            placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3560        },
3561        double_penalty: false,
3562        identifiability: BSplineIdentifiability::None,
3563        boundary: OneDimensionalBoundary::Open,
3564        boundary_conditions: BSplineBoundaryConditions::default(),
3565    };
3566
3567    let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3568        .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3569    let time_design_exit = time_build.design.to_dense();
3570
3571    let knots = match &time_build.metadata {
3572        BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3573        _ => {
3574            return Err(format!(
3575                "{block_name} time-margin basis returned unexpected metadata type"
3576            ));
3577        }
3578    };
3579
3580    let time_build_entry = build_bspline_basis_1d(
3581        log_entry.view(),
3582        &BSplineBasisSpec {
3583            degree: time_degree,
3584            penalty_order: 2,
3585            knotspec: BSplineKnotSpec::Provided(knots.clone()),
3586            double_penalty: false,
3587            identifiability: BSplineIdentifiability::None,
3588            boundary: OneDimensionalBoundary::Open,
3589            boundary_conditions: BSplineBoundaryConditions::default(),
3590        },
3591    )
3592    .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3593    let time_design_entry = time_build_entry.design.to_dense();
3594    let p_time = time_design_exit.ncols();
3595    let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3596    // Per-row derivative-basis evaluation is independent; each row owns its
3597    // own small `deriv_buf`. par_chunks_mut over the (n × p_time) output rows
3598    // hands disjoint mutable row-slices to rayon workers.
3599    time_design_derivative_exit
3600        .as_slice_mut()
3601        .expect("zeros are contiguous")
3602        .par_chunks_mut(p_time)
3603        .enumerate()
3604        .try_for_each(|(i, row_out)| -> Result<(), String> {
3605            let mut deriv_buf = vec![0.0_f64; p_time];
3606            evaluate_bspline_derivative_scalar(
3607                log_exit[i],
3608                knots.view(),
3609                time_degree,
3610                &mut deriv_buf,
3611            )
3612            .map_err(|e| {
3613                format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3614            })?;
3615            let chain = 1.0 / age_exit[i].max(1e-12);
3616            for j in 0..p_time {
3617                row_out[j] = deriv_buf[j] * chain;
3618            }
3619            Ok(())
3620        })?;
3621
3622    Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3623        time_basis_entry: time_design_entry,
3624        time_basis_exit: time_design_exit,
3625        time_basis_derivative_exit: time_design_derivative_exit,
3626        time_penalties: time_build.penalties,
3627    })
3628}
3629
3630#[cfg(test)]
3631mod tests {
3632    use super::{
3633        SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3634        baseline_chain_rule_gradient, baseline_offset_theta_partials,
3635        build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3636        build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3637        evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3638        gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3639        marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3640        marginal_slope_baseline_offset_theta_partials,
3641        optimize_survival_baseline_config_with_gradient,
3642        optimize_survival_baseline_config_with_gradient_only,
3643        resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3644        survival_baseline_theta_from_config,
3645    };
3646    use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3647    use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3648    use crate::probability::normal_cdf;
3649    use ndarray::{Array1, Array2, array};
3650
3651    #[test]
3652    fn survival_timewiggle_keeps_requested_order_one_penalty() {
3653        let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3654        let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3655        let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3656        let cfg = LinkWiggleFormulaSpec {
3657            degree: 3,
3658            num_internal_knots: 4,
3659            penalty_orders: vec![1, 2, 3],
3660            double_penalty: false,
3661        };
3662
3663        let build =
3664            build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3665                .expect("build survival timewiggle");
3666
3667        assert_eq!(build.penalties.len(), 3);
3668        assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3669        assert!(build.ncols > 0);
3670    }
3671
3672    #[test]
3673    fn marginal_slope_time_anchor_defaults_to_median_exit() {
3674        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3675        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3676        let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3677            .expect("resolve marginal-slope default time anchor");
3678
3679        assert!(
3680            (anchor - 19.0).abs() <= 1e-12,
3681            "marginal-slope default anchor should be median exit, got {anchor}"
3682        );
3683    }
3684
3685    #[test]
3686    fn marginal_slope_time_anchor_honors_explicit_value() {
3687        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3688        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3689        let anchor =
3690            resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3691                .expect("resolve explicit marginal-slope time anchor");
3692
3693        assert!(
3694            (anchor - 7.5).abs() <= 1e-12,
3695            "explicit marginal-slope anchor should round-trip, got {anchor}"
3696        );
3697    }
3698
3699    /// Derivative-contract parity for the two public baseline optimizers.
3700    ///
3701    /// After the unification onto `run_baseline_theta_optimizer`, the
3702    /// gradient-only and gradient+Hessian entry points differ *only* in how
3703    /// much derivative information they hand the outer solver — not in the
3704    /// surface they minimize. We exercise that invariant on a known
3705    /// strictly-convex quadratic in θ-space (Weibull baseline: θ = (ln scale,
3706    /// ln shape)) whose unique minimizer is `theta_star`, supplying the same
3707    /// objective as `(f, ∇f)` and as `(f, ∇f, ∇²f)`. Both contracts must
3708    /// recover the same minimizer config, not weakened to pass.
3709    #[test]
3710    fn baseline_optimizer_contracts_agree_on_shared_surface() {
3711        // SPD curvature and interior minimizer in θ-space. A is well away from
3712        // singular so both the analytic-Hessian and BFGS paths see the same
3713        // unambiguous bowl; θ* sits comfortably inside the ±6 box around the
3714        // θ=(0,0) seed below.
3715        let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3716        let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3717
3718        // Seed config at θ=(0,0) (scale=shape=1). The Linear early-return path
3719        // is not exercised here; Weibull has a genuine 2-dim θ to optimize.
3720        let initial = SurvivalBaselineConfig {
3721            target: SurvivalBaselineTarget::Weibull,
3722            scale: Some(1.0),
3723            shape: Some(1.0),
3724            rate: None,
3725            makeham: None,
3726        };
3727
3728        // θ recovered from a returned Weibull config, via the exact inverse of
3729        // the config→θ map the optimizers use internally.
3730        let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3731            survival_baseline_theta_from_config(cfg)
3732                .expect("config→θ")
3733                .expect("Weibull config has a θ")
3734        };
3735
3736        // Shared quadratic surface, evaluated by mapping config→θ so every
3737        // contract sees the identical objective.
3738        let curvature_cost = curvature.clone();
3739        let star_cost = theta_star.clone();
3740        let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3741            let theta = survival_baseline_theta_from_config(cfg)?
3742                .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3743            let d = &theta - &star_cost;
3744            let ad = curvature_cost.dot(&d);
3745            Ok(0.5 * d.dot(&ad))
3746        };
3747
3748        let curvature_grad = curvature.clone();
3749        let star_grad = theta_star.clone();
3750        let cost_for_grad = cost_at.clone();
3751        let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3752            &initial,
3753            "baseline parity (gradient-only)",
3754            move |cfg| {
3755                let cost = cost_for_grad(cfg)?;
3756                let theta = survival_baseline_theta_from_config(cfg)?
3757                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3758                let gradient = curvature_grad.dot(&(&theta - &star_grad));
3759                Ok((cost, gradient))
3760            },
3761        )
3762        .expect("gradient-only baseline optimization converges");
3763
3764        let curvature_hess = curvature.clone();
3765        let star_hess = theta_star.clone();
3766        let cost_for_hess = cost_at.clone();
3767        let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3768            &initial,
3769            "baseline parity (gradient+Hessian)",
3770            move |cfg| {
3771                let cost = cost_for_hess(cfg)?;
3772                let theta = survival_baseline_theta_from_config(cfg)?
3773                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3774                let gradient = curvature_hess.dot(&(&theta - &star_hess));
3775                Ok((cost, gradient, curvature_hess.clone()))
3776            },
3777        )
3778        .expect("gradient+Hessian baseline optimization converges");
3779
3780        let theta_grad_only = recovered_theta(&result_grad_only);
3781        let theta_grad_hess = recovered_theta(&result_grad_hess);
3782
3783        // Each contract recovers the true minimizer. 2e-3 is a safe,
3784        // un-weakened bound; both gradient paths land far tighter.
3785        for (label, theta) in [
3786            ("gradient-only", &theta_grad_only),
3787            ("gradient+Hessian", &theta_grad_hess),
3788        ] {
3789            let err = (theta - &theta_star)
3790                .mapv(f64::abs)
3791                .fold(0.0_f64, |a, &v| a.max(v));
3792            assert!(
3793                err <= 2e-3,
3794                "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3795            );
3796        }
3797
3798        // Cross-contract agreement: the three results must coincide, since the
3799        // only difference between the entry points is the derivative contract,
3800        // never the surface they minimize.
3801        let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3802            (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3803        };
3804        assert!(
3805            pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3806            "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3807        );
3808    }
3809
3810    #[test]
3811    fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
3812        let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3813        let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3814        let requested_degree = 3;
3815        let num_internal_knots = 1;
3816
3817        let built = build_survival_time_basis(
3818            &age_entry,
3819            &age_exit,
3820            SurvivalTimeBasisConfig::ISpline {
3821                degree: requested_degree,
3822                knots: Array1::zeros(0),
3823                keep_cols: Vec::new(),
3824                smooth_lambda: 1e-2,
3825            },
3826            Some((num_internal_knots, 1e-2)),
3827        )
3828        .expect("automatic cubic ispline with one interior knot builds");
3829
3830        let working_degree = requested_degree + 1;
3831        let knots = built.knots.expect("resolved ispline knots");
3832        assert_eq!(
3833            knots.len(),
3834            num_internal_knots + 2 * (working_degree + 1),
3835            "I-spline automatic knots must be clamped for the working B-spline degree"
3836        );
3837        assert_eq!(built.degree, Some(requested_degree));
3838        assert!(built.x_exit_time.ncols() > 0);
3839        assert_eq!(built.x_entry_time.ncols(), built.x_exit_time.ncols());
3840        assert_eq!(built.x_derivative_time.ncols(), built.x_exit_time.ncols());
3841    }
3842
3843    #[test]
3844    fn ispline_time_derivative_is_nonzero_at_right_boundary() {
3845        let age_entry = array![1.0_f64, 1.0, 1.0];
3846        let age_exit = array![4.0_f64, 4.0, 4.0];
3847        let left = 1.0_f64.ln();
3848        let right = 4.0_f64.ln();
3849        let mid = left + 0.5 * (right - left);
3850        let knots = array![left, left, left, left, mid, right, right, right, right];
3851
3852        let built = build_survival_time_basis(
3853            &age_entry,
3854            &age_exit,
3855            SurvivalTimeBasisConfig::ISpline {
3856                degree: 2,
3857                knots,
3858                keep_cols: Vec::new(),
3859                smooth_lambda: 1e-2,
3860            },
3861            None,
3862        )
3863        .expect("build right-boundary ispline time basis");
3864
3865        let derivative = built.x_derivative_time.as_dense_cow();
3866        let max_abs = derivative.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
3867        assert!(
3868            max_abs > 1e-8,
3869            "right-boundary I-spline derivative must use the left-hand endpoint slope"
3870        );
3871        for row in derivative.rows() {
3872            assert!(
3873                row.iter().any(|v| *v > 1e-8),
3874                "each row at the right boundary needs a positive hazard derivative"
3875            );
3876        }
3877    }
3878
3879    #[test]
3880    fn ispline_time_penalty_is_psd_under_nontrivial_keep_cols() {
3881        // PSD-invariant forward guard for the gam#979 survival hang. The I-spline
3882        // value-space curvature penalty on the increment coefficients is the
3883        // congruence `S_I = Lᵀ S_B[1:,1:] L`. When identifiability drops columns,
3884        // the retained block MUST be taken as a PRINCIPAL SUBMATRIX of the FULL
3885        // congruence (congruence first, column selection second). The historical
3886        // regression assembled the reduced penalty in the wrong order, producing
3887        // a strongly INDEFINITE matrix (measured `s0_min_eval = −9.8e7`); an
3888        // indefinite time penalty makes `½γᵀ S_I γ` unbounded below, the inner
3889        // joint-Newton follows the divergence, and the outer REML never
3890        // terminates — the survival marginal-slope hang.
3891        //
3892        // This test exercises the reduction with a NON-TRIVIAL `keep_cols`
3893        // (a proper subset, an interior column dropped) and asserts the assembled
3894        // penalty satisfies the PSD contract the fix guarantees. It locks the
3895        // invariant on the shipped code path so a future reassembly that
3896        // reintroduces an indefinite reduction is caught at construction rather
3897        // than silently as an outer-loop hang. (It is a forward invariant lock,
3898        // not a bit-exact replay of the removed buggy assembly.)
3899        let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3900        let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3901        let left = 1.0_f64.ln();
3902        let right = 21.0_f64.ln();
3903        let q1 = left + 0.25 * (right - left);
3904        let mid = left + 0.5 * (right - left);
3905        let q3 = left + 0.75 * (right - left);
3906        // Degree-2 I-spline with three interior knots -> a value-space basis wide
3907        // enough to drop an interior column and still leave the reduction
3908        // non-trivial (p_time < p_time_full).
3909        let knots = array![
3910            left, left, left, left, q1, mid, q3, right, right, right, right
3911        ];
3912
3913        // Discover the full basis width by building with all columns retained.
3914        let full = build_survival_time_basis(
3915            &age_entry,
3916            &age_exit,
3917            SurvivalTimeBasisConfig::ISpline {
3918                degree: 2,
3919                knots: knots.clone(),
3920                keep_cols: Vec::new(),
3921                smooth_lambda: 1e-2,
3922            },
3923            None,
3924        )
3925        .expect("build full-width ispline time basis");
3926        let p_time_full = full
3927            .keep_cols
3928            .as_ref()
3929            .map(|k| k.len())
3930            .unwrap_or_else(|| full.x_exit_time.ncols());
3931        assert!(
3932            p_time_full >= 3,
3933            "test needs at least 3 shape-varying columns to drop an interior one; got {p_time_full}"
3934        );
3935
3936        // Retain everything except one interior column, forcing the
3937        // principal-submatrix-of-the-full-congruence path.
3938        let keep_cols: Vec<usize> = (0..p_time_full).filter(|&j| j != 1).collect();
3939
3940        let built = build_survival_time_basis(
3941            &age_entry,
3942            &age_exit,
3943            SurvivalTimeBasisConfig::ISpline {
3944                degree: 2,
3945                knots,
3946                keep_cols: keep_cols.clone(),
3947                smooth_lambda: 1e-2,
3948            },
3949            None,
3950        )
3951        .expect(
3952            "reduced ispline penalty must build (PSD contract must accept the \
3953             congruence-first / select-second ordering)",
3954        );
3955
3956        assert_eq!(
3957            built.penalties.len(),
3958            1,
3959            "the ispline time basis should carry exactly one curvature penalty"
3960        );
3961        let s = &built.penalties[0];
3962        assert_eq!(s.nrows(), keep_cols.len());
3963        assert_eq!(s.ncols(), keep_cols.len());
3964
3965        let (evals, _) =
3966            gam_linalg::faer_ndarray::FaerEigh::eigh(s, faer::Side::Lower).expect("eigh of penalty");
3967        let evals_slice = evals.as_slice().expect("contiguous eigenvalues");
3968        let max_abs = evals_slice
3969            .iter()
3970            .copied()
3971            .fold(0.0_f64, |a, b| a.max(b.abs()))
3972            .max(1.0);
3973        let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
3974        let tol = -100.0 * (s.nrows() as f64) * f64::EPSILON * max_abs;
3975        assert!(
3976            min_ev >= tol,
3977            "reduced I-spline time penalty must be PSD (gam#979): min eigenvalue \
3978             {min_ev:.3e} < tol {tol:.3e}, max|eig| {max_abs:.3e}"
3979        );
3980    }
3981
3982    #[test]
3983    fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
3984        let cfg = SurvivalBaselineConfig {
3985            target: SurvivalBaselineTarget::GompertzMakeham,
3986            scale: None,
3987            shape: Some(0.07),
3988            rate: Some(0.012),
3989            makeham: Some(0.003),
3990        };
3991        let age = 11.5;
3992        let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
3993            .expect("evaluate marginal-slope gompertz-makeham baseline");
3994        let shape = cfg.shape.expect("shape");
3995        let rate = cfg.rate.expect("rate");
3996        let makeham = cfg.makeham.expect("makeham");
3997        let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
3998        let instant_hazard = makeham + rate * (shape * age).exp();
3999        let expected_survival = (-cumulative_hazard).exp();
4000        let actual_survival = normal_cdf(-q);
4001        assert!((actual_survival - expected_survival).abs() <= 1e-12);
4002
4003        let h = 1e-5;
4004        let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
4005            .expect("q plus")
4006            .0;
4007        let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
4008            .expect("q minus")
4009            .0;
4010        let fd = (q_plus - q_minus) / (2.0 * h);
4011        assert!((q_derivative - fd).abs() <= 1e-7);
4012        assert!(instant_hazard > 0.0);
4013    }
4014
4015    #[test]
4016    fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
4017        // Regression for #1024: the probit/marginal-slope baseline evaluator must
4018        // be defined at the survival-curve origin t = 0 (where S0(0) = 1, so the
4019        // probit index q(0) = -Phi^{-1}(1) = -inf and there is no finite offset),
4020        // exactly like its log-cumulative-hazard sibling `evaluate_survival_baseline`.
4021        // Before the fix the shared `age <= 0` hazard guard aborted, so a survival
4022        // prediction grid whose first node is the origin (the `Surv(time, event)`
4023        // right-censored shorthand) could not be evaluated for the location-scale /
4024        // marginal-slope likelihoods.
4025        let configs = [
4026            SurvivalBaselineConfig {
4027                target: SurvivalBaselineTarget::Linear,
4028                scale: None,
4029                shape: None,
4030                rate: None,
4031                makeham: None,
4032            },
4033            SurvivalBaselineConfig {
4034                target: SurvivalBaselineTarget::Weibull,
4035                scale: Some(2.5),
4036                shape: Some(1.3),
4037                rate: None,
4038                makeham: None,
4039            },
4040            SurvivalBaselineConfig {
4041                target: SurvivalBaselineTarget::Gompertz,
4042                scale: None,
4043                shape: Some(0.05),
4044                rate: Some(0.01),
4045                makeham: None,
4046            },
4047            SurvivalBaselineConfig {
4048                target: SurvivalBaselineTarget::GompertzMakeham,
4049                scale: None,
4050                shape: Some(0.07),
4051                rate: Some(0.012),
4052                makeham: Some(0.003),
4053            },
4054        ];
4055        for cfg in &configs {
4056            // The probit baseline returns a finite zero offset at the origin for
4057            // every target (the survival surface anchors S(0) = 1 directly).
4058            let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
4059                .expect("marginal-slope baseline must be evaluable at the origin");
4060            assert_eq!(q0, 0.0);
4061            assert_eq!(q0_derivative, 0.0);
4062
4063            // The log-cumulative-hazard sibling is likewise finite at the origin —
4064            // this parity is the whole point (the transformation likelihood already
4065            // worked because it rides this evaluator).
4066            let (eta0, eta0_derivative) =
4067                evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
4068            assert!(eta0_derivative.is_finite());
4069            assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);
4070
4071            // The batched offset builder must not abort when a query exit age is the
4072            // origin (this is the exact call the location-scale predict path makes on
4073            // the default surface grid). Entry stays at the origin, exit spans 0 -> t.
4074            let age_entry = array![0.0, 0.0];
4075            let age_exit = array![0.0, 1.5];
4076            let (entry, exit, derivative) =
4077                build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
4078                    .expect("probit baseline offsets must build through the origin");
4079            assert!(entry.iter().all(|v| v.is_finite()));
4080            assert!(exit.iter().all(|v| v.is_finite()));
4081            assert!(derivative.iter().all(|v| v.is_finite()));
4082            // The origin exit column carries no probit offset.
4083            assert_eq!(exit[0], 0.0);
4084        }
4085    }
4086
4087    #[test]
4088    fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
4089        let cfg = SurvivalBaselineConfig {
4090            target: SurvivalBaselineTarget::GompertzMakeham,
4091            scale: None,
4092            shape: Some(0.03),
4093            rate: Some(0.01),
4094            makeham: Some(0.002),
4095        };
4096        let age_entry = array![2.0, 4.0];
4097        let age_exit = array![5.0, 9.0];
4098        let (entry, exit, derivative) =
4099            build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
4100                .expect("marginal-slope baseline offsets");
4101        for i in 0..age_entry.len() {
4102            let entry_h = cfg.makeham.expect("makeham") * age_entry[i]
4103                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4104                    * ((cfg.shape.expect("shape") * age_entry[i]).exp() - 1.0);
4105            let exit_h = cfg.makeham.expect("makeham") * age_exit[i]
4106                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4107                    * ((cfg.shape.expect("shape") * age_exit[i]).exp() - 1.0);
4108            assert!((normal_cdf(-entry[i]) - (-entry_h).exp()).abs() <= 1e-12);
4109            assert!((normal_cdf(-exit[i]) - (-exit_h).exp()).abs() <= 1e-12);
4110            assert!(derivative[i].is_finite() && derivative[i] > 0.0);
4111        }
4112    }
4113
4114    fn fd_marginal_slope_baseline_offset(
4115        age: f64,
4116        cfg: &SurvivalBaselineConfig,
4117        steps: &[f64],
4118    ) -> Vec<(f64, f64)> {
4119        let theta = survival_baseline_theta_from_config(cfg)
4120            .expect("theta")
4121            .expect("non-linear baseline");
4122        assert_eq!(
4123            steps.len(),
4124            theta.len(),
4125            "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
4126        );
4127        (0..theta.len())
4128            .map(|k| {
4129                let h = steps[k];
4130                let mut theta_plus = theta.clone();
4131                theta_plus[k] += h;
4132                let mut theta_minus = theta.clone();
4133                theta_minus[k] -= h;
4134                let cfg_plus =
4135                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4136                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4137                    .expect("minus cfg");
4138                let (q_p, qt_p) =
4139                    evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
4140                let (q_m, qt_m) =
4141                    evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
4142                ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
4143            })
4144            .collect()
4145    }
4146
4147    #[test]
4148    fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
4149        let cfg = SurvivalBaselineConfig {
4150            target: SurvivalBaselineTarget::GompertzMakeham,
4151            scale: None,
4152            shape: Some(0.04),
4153            rate: Some(0.013),
4154            makeham: Some(0.002),
4155        };
4156        let age = 17.0;
4157        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4158            .expect("partials")
4159            .expect("nonlinear");
4160        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-5, 1e-5]);
4161        assert_eq!(analytic.len(), fd.len());
4162        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4163            assert_close(*aq, *fq, 1e-6, &format!("gm-probit q theta[{k}]"));
4164            assert_close(*aqt, *fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
4165        }
4166    }
4167
4168    #[test]
4169    fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
4170        let cfg = SurvivalBaselineConfig {
4171            target: SurvivalBaselineTarget::GompertzMakeham,
4172            scale: None,
4173            shape: Some(1e-14),
4174            rate: Some(0.013),
4175            makeham: Some(0.002),
4176        };
4177        let age = 17.0;
4178        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4179            .expect("partials")
4180            .expect("nonlinear");
4181        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-11, 1e-5]);
4182        assert_eq!(analytic.len(), fd.len());
4183        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4184            assert_close(*aq, *fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
4185            assert_close(
4186                *aqt,
4187                *fqt,
4188                1e-5,
4189                &format!("near-zero gm-probit q' theta[{k}]"),
4190            );
4191        }
4192    }
4193
4194    fn shifted_quadratic_offset_residuals(
4195        age_entry: ndarray::ArrayView1<'_, f64>,
4196        age_exit: ndarray::ArrayView1<'_, f64>,
4197        base_cfg: &SurvivalBaselineConfig,
4198        candidate_cfg: &SurvivalBaselineConfig,
4199        base: &OffsetChannelResiduals,
4200        curvatures: &OffsetChannelCurvatures,
4201    ) -> OffsetChannelResiduals {
4202        let n = age_exit.len();
4203        let mut entry = base.entry.clone();
4204        let mut exit = base.exit.clone();
4205        let mut derivative = base.derivative.clone();
4206        for row in 0..n {
4207            let (_, base_exit, base_deriv) =
4208                baseline_marginal_slope_channels(age_exit[row], base_cfg);
4209            let (_, cand_exit, cand_deriv) =
4210                baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
4211            let base_entry = if base.entry[row] == 0.0 {
4212                0.0
4213            } else {
4214                baseline_marginal_slope_channels(age_entry[row], base_cfg).1
4215            };
4216            let cand_entry = if base.entry[row] == 0.0 {
4217                0.0
4218            } else {
4219                baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
4220            };
4221            let delta = [
4222                cand_entry - base_entry,
4223                cand_exit - base_exit,
4224                cand_deriv - base_deriv,
4225            ];
4226            let mut shift = [0.0; 3];
4227            for i in 0..3 {
4228                for j in 0..3 {
4229                    shift[i] += curvatures.rows[row][i][j] * delta[j];
4230                }
4231            }
4232            if base.entry[row] != 0.0 {
4233                entry[row] += shift[0];
4234            }
4235            exit[row] += shift[1];
4236            derivative[row] += shift[2];
4237        }
4238        OffsetChannelResiduals {
4239            entry,
4240            exit,
4241            derivative,
4242            right: base.right.clone(),
4243        }
4244    }
4245
4246    fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4247        let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4248        (q, q, q_t)
4249    }
4250
4251    #[test]
4252    fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
4253        let cfg = SurvivalBaselineConfig {
4254            target: SurvivalBaselineTarget::GompertzMakeham,
4255            scale: None,
4256            shape: Some(0.025),
4257            rate: Some(0.012),
4258            makeham: Some(0.003),
4259        };
4260        let theta = survival_baseline_theta_from_config(&cfg)
4261            .expect("theta")
4262            .expect("nonlinear");
4263        let age_entry = array![2.5, 0.0, 5.0];
4264        let age_exit = array![7.5, 11.0, 15.0];
4265        let base_residuals = OffsetChannelResiduals {
4266            entry: array![0.2, 0.0, -0.1],
4267            exit: array![0.6, -0.3, 0.4],
4268            derivative: array![-0.5, 0.25, 0.15],
4269            right: Array1::<f64>::zeros(3),
4270        };
4271        let curvatures = OffsetChannelCurvatures {
4272            rows: vec![
4273                [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
4274                [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
4275                [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
4276            ],
4277        };
4278        let analytic = marginal_slope_baseline_chain_rule_hessian(
4279            age_entry.view(),
4280            age_exit.view(),
4281            &cfg,
4282            &base_residuals,
4283            &curvatures,
4284        )
4285        .expect("hessian")
4286        .expect("nonlinear");
4287
4288        let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
4289            let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
4290                .expect("candidate cfg");
4291            let residuals = shifted_quadratic_offset_residuals(
4292                age_entry.view(),
4293                age_exit.view(),
4294                &cfg,
4295                &candidate,
4296                &base_residuals,
4297                &curvatures,
4298            );
4299            marginal_slope_baseline_chain_rule_gradient(
4300                age_entry.view(),
4301                age_exit.view(),
4302                &candidate,
4303                &residuals,
4304            )
4305            .expect("gradient")
4306            .expect("nonlinear")
4307        };
4308
4309        for j in 0..theta.len() {
4310            let step = if j == 1 { 2e-5 } else { 1e-5 };
4311            let mut plus = theta.clone();
4312            plus[j] += step;
4313            let mut minus = theta.clone();
4314            minus[j] -= step;
4315            let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
4316            for i in 0..theta.len() {
4317                assert_close(
4318                    analytic[[i, j]],
4319                    fd_col[i],
4320                    2e-5,
4321                    &format!("baseline Hessian ({i},{j})"),
4322                );
4323            }
4324        }
4325    }
4326
4327    #[test]
4328    fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
4329        let cfg = SurvivalBaselineConfig {
4330            target: SurvivalBaselineTarget::GompertzMakeham,
4331            scale: None,
4332            shape: Some(0.03),
4333            rate: Some(0.01),
4334            makeham: Some(0.002),
4335        };
4336        let age_entry = array![3.0, 6.0];
4337        let age_exit = array![8.0, 12.0];
4338        let residuals = OffsetChannelResiduals {
4339            exit: array![0.7, -0.2],
4340            entry: array![0.1, 0.4],
4341            derivative: array![1.3, -0.6],
4342            right: Array1::<f64>::zeros(2),
4343        };
4344        let grad = marginal_slope_baseline_chain_rule_gradient(
4345            age_entry.view(),
4346            age_exit.view(),
4347            &cfg,
4348            &residuals,
4349        )
4350        .expect("gradient")
4351        .expect("nonlinear");
4352
4353        let mut expected = Array1::<f64>::zeros(3);
4354        for i in 0..age_exit.len() {
4355            let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
4356                .expect("exit partials")
4357                .expect("nonlinear");
4358            let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
4359                .expect("entry partials")
4360                .expect("nonlinear");
4361            for k in 0..3 {
4362                expected[k] += residuals.exit[i] * exit_partials[k].0
4363                    + residuals.derivative[i] * exit_partials[k].1
4364                    + residuals.entry[i] * entry_partials[k].0;
4365            }
4366        }
4367        for k in 0..3 {
4368            assert_close(
4369                grad[k],
4370                expected[k],
4371                1e-12,
4372                &format!("gm-probit chain gradient theta[{k}]"),
4373            );
4374        }
4375    }
4376
    /// Parity guard for the shared `baseline_chain_rule_gradient_with_partials`
    /// engine (issue #429): both public gradient functions delegate to it with a
    /// different partials provider. This test reimplements the pre-unification
    /// inline contraction (the serial reference) and asserts bit-for-bit equality
    /// against the unified engine's output for BOTH providers on the same data —
    /// the RP-eta provider (`baseline_offset_theta_partials`) and the probit-q
    /// provider (`marginal_slope_baseline_offset_theta_partials`). Any drift in
    /// the extracted contraction (length checks, theta-dim probe, exit/derivative
    /// combination, or entry gating) breaks this with an exact (0.0) tolerance.
    ///
    /// Because the tolerance is zero, the reference below must keep the exact
    /// floating-point operation order of the contraction. For each row it adds
    /// the exit + derivative term first, then the gated entry term. Re-associating
    /// those sums would make the comparison fail spuriously.
    #[test]
    fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(0.028),
            rate: Some(0.011),
            makeham: Some(0.0025),
        };
        // Mixed entry interval: row 1 is origin-entry (age_entry==0, r_entry==0)
        // to exercise the entry-gating branch in the shared engine.
        let age_entry = array![3.0, 0.0, 5.5];
        let age_exit = array![8.0, 12.0, 16.0];
        let residuals = OffsetChannelResiduals {
            exit: array![0.7, -0.2, 0.45],
            entry: array![0.1, 0.0, -0.3],
            derivative: array![1.3, -0.6, 0.2],
            right: Array1::<f64>::zeros(3),
        };

        // Serial reference contraction matching the original inline body. Mirrors
        // the engine's exit+derivative/entry split and origin-entry gating.
        //
        // `partials` is either θ-partials provider. Each call returns one
        // (offset partial, derivative-channel partial) pair per θ component.
        let reference_gradient = |partials: &dyn Fn(
            f64,
            &SurvivalBaselineConfig,
        )
            -> Result<Option<Vec<(f64, f64)>>, String>|
         -> Array1<f64> {
            // θ dimension probed once from the first exit age.
            let theta_dim = partials(age_exit[0], &cfg)
                .expect("probe partials")
                .expect("nonlinear")
                .len();
            let mut acc = Array1::<f64>::zeros(theta_dim);
            for i in 0..age_exit.len() {
                // Exit contribution: offset partial weighted by r_X plus the
                // derivative-channel partial weighted by r_D, summed as one term.
                let p_exit = partials(age_exit[i], &cfg)
                    .expect("exit partials")
                    .expect("nonlinear");
                let r_x = residuals.exit[i];
                let r_d = residuals.derivative[i];
                for k in 0..theta_dim {
                    acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
                }
                // Entry contribution is gated on a nonzero r_E. Origin-entry
                // rows (age 0 here) therefore never reach the provider.
                let r_e = residuals.entry[i];
                if r_e != 0.0 {
                    let p_entry = partials(age_entry[i], &cfg)
                        .expect("entry partials")
                        .expect("nonlinear");
                    for k in 0..theta_dim {
                        acc[k] += r_e * p_entry[k].0;
                    }
                }
            }
            acc
        };

        // RP-eta provider parity.
        // NOTE(review): the third argument repeats `age_exit`; presumably it is
        // the age grid for the derivative channel. Confirm against the
        // `baseline_chain_rule_gradient` signature.
        let rp_engine = baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            age_exit.view(),
            &cfg,
            &residuals,
        )
        .expect("rp gradient")
        .expect("rp nonlinear");
        let rp_reference = reference_gradient(&baseline_offset_theta_partials);
        assert_eq!(rp_engine.len(), rp_reference.len());
        for k in 0..rp_engine.len() {
            // Tolerance 0.0: `assert_close` compares with `<=`, so this demands
            // exact equality.
            assert_close(
                rp_engine[k],
                rp_reference[k],
                0.0,
                &format!("rp engine vs inline reference theta[{k}]"),
            );
        }

        // Probit-q provider parity.
        let probit_engine = marginal_slope_baseline_chain_rule_gradient(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &residuals,
        )
        .expect("probit gradient")
        .expect("probit nonlinear");
        let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
        assert_eq!(probit_engine.len(), probit_reference.len());
        for k in 0..probit_engine.len() {
            assert_close(
                probit_engine[k],
                probit_reference[k],
                0.0,
                &format!("probit engine vs inline reference theta[{k}]"),
            );
        }
    }
4482
4483    /// Finite-difference verification of the analytic θ-gradient used by the
4484    /// survival location-scale workflow path.
4485    ///
4486    /// At a converged β, the envelope theorem reduces the profile-NLL gradient
4487    /// w.r.t. the baseline-config θ to a per-row residual contraction against
4488    /// the per-row offset-channel partials ∂o/∂θ:
4489    ///
4490    ///   d(NLL)/dθ_k = Σ_i [ r_X[i]·∂η_exit/∂θ_k + r_E[i]·∂η_entry/∂θ_k
4491    ///                       + r_D[i]·∂o_D_exit/∂θ_k ]
4492    ///
4493    /// (`baseline_chain_rule_gradient`). Because β is fixed, an explicit loss
4494    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i; θ) + r_E[i]·η(t_entry_i; θ)
4495    ///              + r_D[i]·o_D(t_exit_i; θ) ]`
4496    /// has gradient identically equal to the chain-rule output. Comparing the
4497    /// analytic gradient to a central-difference of L over `evaluate_survival_baseline`
4498    /// therefore exercises every piece of the chain rule (incl. the Gompertz
4499    /// rate / shape / Makeham partials at both entry and exit ages) without
4500    /// needing the full location-scale fit pipeline inside this unit-test
4501    /// module. If the chain rule disagrees with FD here, the workflow's
4502    /// gradient is wrong by exactly the same amount.
4503    #[test]
4504    fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4505        let cfg = SurvivalBaselineConfig {
4506            target: SurvivalBaselineTarget::GompertzMakeham,
4507            scale: None,
4508            shape: Some(0.05),
4509            rate: Some(0.012),
4510            makeham: Some(0.003),
4511        };
4512        // n = 8 small synthetic dataset spanning a realistic age range.
4513        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4514        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4515        // Synthetic per-row NLL residuals on the three offset channels. Mix of
4516        // signs / magnitudes / one zero-entry row (origin entry → r_E=0).
4517        let residuals = OffsetChannelResiduals {
4518            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4519            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4520            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4521            right: Array1::<f64>::zeros(8),
4522        };
4523
4524        let analytic = baseline_chain_rule_gradient(
4525            age_entry.view(),
4526            age_exit.view(),
4527            age_exit.view(),
4528            &cfg,
4529            &residuals,
4530        )
4531        .expect("analytic gradient ok")
4532        .expect("GM baseline has a θ-gradient");
4533        assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4534
4535        // Evaluate the offset-projected loss at a perturbed θ. Mirrors the
4536        // chain rule's algebra: the entry channel is only added for rows whose
4537        // r_E is nonzero (matching baseline_chain_rule_gradient's gating that
4538        // avoids calling evaluate_survival_baseline at age 0 for origin-entry
4539        // rows).
4540        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4541            let mut acc = 0.0;
4542            for i in 0..age_exit.len() {
4543                let (eta_exit_i, od_exit_i) =
4544                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4545                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4546                if residuals.entry[i] != 0.0 {
4547                    let (eta_entry_i, _) =
4548                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4549                    acc += residuals.entry[i] * eta_entry_i;
4550                }
4551            }
4552            acc
4553        };
4554
4555        let theta0 = survival_baseline_theta_from_config(&cfg)
4556            .expect("theta seed")
4557            .expect("GM has θ");
4558        // Spec requested δ = 1e-4 per axis. Use central differences over θ.
4559        let delta = 1e-4;
4560        let mut fd = Array1::<f64>::zeros(analytic.len());
4561        for k in 0..analytic.len() {
4562            let mut theta_plus = theta0.clone();
4563            theta_plus[k] += delta;
4564            let mut theta_minus = theta0.clone();
4565            theta_minus[k] -= delta;
4566            let cfg_plus =
4567                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4568            let cfg_minus =
4569                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4570            let lp = loss_at_cfg(&cfg_plus);
4571            let lm = loss_at_cfg(&cfg_minus);
4572            fd[k] = (lp - lm) / (2.0 * delta);
4573        }
4574
4575        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4576        let max_err = analytic
4577            .iter()
4578            .zip(fd.iter())
4579            .map(|(a, b)| (a - b).abs())
4580            .fold(0.0_f64, f64::max);
4581        let rel = max_err / (analytic_norm + 1e-12);
4582        // Print so the deliverable can quote the exact max-error number.
4583        eprintln!(
4584            "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4585             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4586             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4587        );
4588        assert!(
4589            rel < 1e-2,
4590            "analytic θ-gradient disagrees with central FD beyond 1%: \
4591             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4592             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4593        );
4594    }
4595
4596    /// Weibull (dim=2) companion to
4597    /// `gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference`.
4598    ///
4599    /// This is the FD gate for the analytic outer θ-gradient that the
4600    /// transformation/Weibull survival baseline optimizers now feed to BFGS
4601    /// (`optimize_survival_baseline_config_with_gradient_only`). At a *fixed* β
4602    /// the profile-NLL surface is
4603    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i;θ) + r_E[i]·η(t_entry_i;θ)
4604    ///              + r_D[i]·o_D(t_exit_i;θ) ]`,
4605    /// whose exact gradient is `baseline_chain_rule_gradient`. Comparing it to a
4606    /// central difference of `L` over `evaluate_survival_baseline` exercises the
4607    /// Weibull scale/shape partials at both entry and exit ages. If this
4608    /// disagrees with FD, the workflow's outer gradient is wrong by the same
4609    /// amount.
4610    #[test]
4611    fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4612        let cfg = SurvivalBaselineConfig {
4613            target: SurvivalBaselineTarget::Weibull,
4614            scale: Some(11.0),
4615            shape: Some(1.4),
4616            rate: None,
4617            makeham: None,
4618        };
4619        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4620        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4621        let residuals = OffsetChannelResiduals {
4622            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4623            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4624            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4625            right: Array1::<f64>::zeros(8),
4626        };
4627
4628        let analytic = baseline_chain_rule_gradient(
4629            age_entry.view(),
4630            age_exit.view(),
4631            age_exit.view(),
4632            &cfg,
4633            &residuals,
4634        )
4635        .expect("analytic gradient ok")
4636        .expect("Weibull baseline has a θ-gradient");
4637        assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4638
4639        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4640            let mut acc = 0.0;
4641            for i in 0..age_exit.len() {
4642                let (eta_exit_i, od_exit_i) =
4643                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4644                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4645                if residuals.entry[i] != 0.0 {
4646                    let (eta_entry_i, _) =
4647                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4648                    acc += residuals.entry[i] * eta_entry_i;
4649                }
4650            }
4651            acc
4652        };
4653
4654        let theta0 = survival_baseline_theta_from_config(&cfg)
4655            .expect("theta seed")
4656            .expect("Weibull has θ");
4657        let delta = 1e-4;
4658        let mut fd = Array1::<f64>::zeros(analytic.len());
4659        for k in 0..analytic.len() {
4660            let mut theta_plus = theta0.clone();
4661            theta_plus[k] += delta;
4662            let mut theta_minus = theta0.clone();
4663            theta_minus[k] -= delta;
4664            let cfg_plus =
4665                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4666            let cfg_minus =
4667                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4668            let lp = loss_at_cfg(&cfg_plus);
4669            let lm = loss_at_cfg(&cfg_minus);
4670            fd[k] = (lp - lm) / (2.0 * delta);
4671        }
4672
4673        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4674        let max_err = analytic
4675            .iter()
4676            .zip(fd.iter())
4677            .map(|(a, b)| (a - b).abs())
4678            .fold(0.0_f64, f64::max);
4679        let rel = max_err / (analytic_norm + 1e-12);
4680        eprintln!(
4681            "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4682             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4683             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4684        );
4685        assert!(
4686            rel < 1e-2,
4687            "analytic θ-gradient disagrees with central FD beyond 1%: \
4688             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4689             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4690        );
4691    }
4692
4693    // ─── baseline_offset_theta_partials — analytic vs central-difference ─
4694
4695    /// Central-difference of (eta, o_D) at fixed age wrt each θ component in
4696    /// the theta layout defined by `survival_baseline_theta_from_config`.
4697    ///
4698    /// `steps` is per-θ-component: the caller picks the step size appropriate
4699    /// for each channel. Gompertz / Gompertz–Makeham need a tiny step on the
4700    /// shape channel near the Taylor pivot |shape| < 1e-10 (so θ±h stays on
4701    /// the same branch), but a normal-scale step on log_rate / log_makeham;
4702    /// using the tiny shape-step on every channel corrupts the log_rate
4703    /// channel with `eps/(2h)` cancellation noise and has nothing to do with
4704    /// correctness of the analytic derivative.
4705    fn fd_baseline_offset(
4706        age: f64,
4707        cfg: &SurvivalBaselineConfig,
4708        steps: &[f64],
4709    ) -> Vec<(f64, f64)> {
4710        let theta = survival_baseline_theta_from_config(cfg)
4711            .expect("theta")
4712            .expect("non-linear baseline");
4713        assert_eq!(
4714            steps.len(),
4715            theta.len(),
4716            "fd_baseline_offset: step vector length must match θ dimension"
4717        );
4718        (0..theta.len())
4719            .map(|k| {
4720                let h = steps[k];
4721                let mut theta_plus = theta.clone();
4722                theta_plus[k] += h;
4723                let mut theta_minus = theta.clone();
4724                theta_minus[k] -= h;
4725                let cfg_plus =
4726                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4727                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4728                    .expect("minus cfg");
4729                let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4730                let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4731                ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4732            })
4733            .collect()
4734    }
4735
4736    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
4737        // `<=` so that bit-equal values satisfy tol = 0. With `<`, |a−e| < 0
4738        // is unsatisfiable and a zero-tolerance "must match exactly" call
4739        // would reject identical numbers.
4740        let ok = if expected.abs() < 1.0 {
4741            (actual - expected).abs() <= tol
4742        } else {
4743            (actual - expected).abs() <= tol * expected.abs().max(1.0)
4744        };
4745        assert!(
4746            ok,
4747            "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
4748        );
4749    }
4750
4751    #[test]
4752    fn gompertz_offset_partials_match_central_diff() {
4753        // Several (rate, shape, age) combinations spanning the small-shape
4754        // Taylor branch (|shape| < 1e-10) and the normal branch
4755        // (shape >> 1e-10), plus sign-reversed shape.
4756        let cases = [
4757            (0.5_f64, 0.01_f64, 30.0_f64),
4758            (0.2, 0.05, 60.0),
4759            (1.0, 0.001, 10.0),
4760            (0.4, 5e-11, 25.0),
4761            (0.4, -5e-11, 25.0),
4762            (0.3, -0.02, 40.0),
4763            (0.8, 0.2, 5.0),
4764        ];
4765        for &(rate, shape, age) in &cases {
4766            let cfg = SurvivalBaselineConfig {
4767                target: SurvivalBaselineTarget::Gompertz,
4768                scale: None,
4769                shape: Some(shape),
4770                rate: Some(rate),
4771                makeham: None,
4772            };
4773            let analytic = baseline_offset_theta_partials(age, &cfg)
4774                .expect("ok")
4775                .expect("non-linear");
4776            // Keep the FD probe inside the Taylor branch for tiny |shape| so
4777            // the numeric derivative matches the same small-shape map as the
4778            // analytic helper. log_rate always uses the normal step — rate
4779            // is a moderate-scale parameter and a 1e-11 step would swamp the
4780            // FD with cancellation noise.
4781            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4782            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
4783            assert_eq!(analytic.len(), 2);
4784            // Gompertz θ=(log_rate, shape). Rate channel: ∂eta/∂log_rate=1, ∂o_D/∂log_rate=0.
4785            assert_close(
4786                analytic[0].0,
4787                fd[0].0,
4788                1e-7,
4789                &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4790            );
4791            assert_close(
4792                analytic[0].1,
4793                fd[0].1,
4794                1e-7,
4795                &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4796            );
4797            // shape channel — larger tol because finite-differencing near
4798            // shape=0 amplifies rounding; 1e-5 is fine.
4799            assert_close(
4800                analytic[1].0,
4801                fd[1].0,
4802                1e-5,
4803                &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
4804            );
4805            assert_close(
4806                analytic[1].1,
4807                fd[1].1,
4808                1e-5,
4809                &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
4810            );
4811        }
4812    }
4813
4814    #[test]
4815    fn gompertz_offset_partials_log_rate_channel_is_trivial() {
4816        // Pure Gompertz: rate cancels in o_D, so ∂o_D/∂log_rate must be
4817        // exactly 0 and ∂eta/∂log_rate must be exactly 1. Verify the
4818        // analytic implementation returns the exact values, not FD-close.
4819        let cfg = SurvivalBaselineConfig {
4820            target: SurvivalBaselineTarget::Gompertz,
4821            scale: None,
4822            shape: Some(0.05),
4823            rate: Some(0.3),
4824            makeham: None,
4825        };
4826        let partials = baseline_offset_theta_partials(42.0, &cfg)
4827            .expect("ok")
4828            .expect("non-linear");
4829        assert_eq!(partials[0].0, 1.0);
4830        assert_eq!(partials[0].1, 0.0);
4831    }
4832
4833    #[test]
4834    fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
4835        // Both branches of gompertz_shape_derivatives should agree to high
4836        // precision at shape = 1e-10 + epsilon on the direct side vs
4837        // shape = 1e-10 - epsilon on the Taylor side. Here we spot-check
4838        // the continuity at the branch cutoff: shape slightly above and
4839        // slightly below 1e-10 must give values within O(shape²·t²)
4840        // (the Taylor truncation error).
4841        let age = 25.0;
4842        let rate = 0.4;
4843        let cfg_taylor = SurvivalBaselineConfig {
4844            target: SurvivalBaselineTarget::Gompertz,
4845            scale: None,
4846            shape: Some(0.5e-10),
4847            rate: Some(rate),
4848            makeham: None,
4849        };
4850        let cfg_direct = SurvivalBaselineConfig {
4851            target: SurvivalBaselineTarget::Gompertz,
4852            scale: None,
4853            shape: Some(2.0e-10),
4854            rate: Some(rate),
4855            makeham: None,
4856        };
4857        let p_t = baseline_offset_theta_partials(age, &cfg_taylor)
4858            .expect("ok")
4859            .expect("nl");
4860        let p_d = baseline_offset_theta_partials(age, &cfg_direct)
4861            .expect("ok")
4862            .expect("nl");
4863        // ∂eta/∂shape at shape≈0 should be t/2 = 12.5 on both sides.
4864        assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
4865        assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
4866        // ∂o_D/∂shape at shape≈0 should be 1/2.
4867        assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
4868        assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
4869    }
4870
4871    // ----------------------------------------------------------------------
4872    // Gompertz hazard-channel shape derivatives: FD oracle + Taylor-branch
4873    // continuity. These feed `survival_hazard_theta_partials` /
4874    // `survival_hazard_theta_first_second` (the marginal-slope probit
4875    // baseline). Before this test, the only coverage of
4876    // `gompertz_cumulative_shape_{,second_}derivative` was the indirect
4877    // marginal-slope Hessian FD at shape=0.025, which never touches the
4878    // small-shape (`|shape| < 1e-10`) Taylor branch nor directly FD-checks
4879    // these analytic shape derivatives.
4880    // ----------------------------------------------------------------------
4881
4882    #[test]
4883    fn gompertz_hazard_shape_derivatives_match_central_diff() {
4884        // shape stays well above the 1e-10 Taylor cutoff so the exact
4885        // closed-form branch is exercised and the expm1/exp arithmetic is
4886        // numerically clean. FD on the analytic value/first-derivative
4887        // confirms the first and second shape derivatives.
4888        let cases = [
4889            (10.0_f64, 0.012_f64, 0.05_f64),
4890            (2.5, 0.5, 0.2),
4891            (15.0, 0.003, 0.01),
4892            (40.0, 0.3, 0.001),
4893        ];
4894        let h = 1e-6;
4895        for &(age, rate, shape) in &cases {
4896            // First shape derivative of (H_G, h_G) vs central diff of value.
4897            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4898            let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
4899            let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
4900            assert_close(
4901                d_cum,
4902                (cum_p - cum_m) / (2.0 * h),
4903                1e-6,
4904                &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4905            );
4906            assert_close(
4907                d_inst,
4908                (inst_p - inst_m) / (2.0 * h),
4909                1e-6,
4910                &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4911            );
4912
4913            // Second shape derivative vs central diff of the first derivative.
4914            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4915            let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
4916            let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
4917            assert_close(
4918                d2_cum,
4919                (dcum_p - dcum_m) / (2.0 * h),
4920                1e-5,
4921                &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4922            );
4923            assert_close(
4924                d2_inst,
4925                (dinst_p - dinst_m) / (2.0 * h),
4926                1e-5,
4927                &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4928            );
4929        }
4930    }
4931
4932    #[test]
4933    fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4934        // At small x = shape·age the shape derivatives collapse to closed-form
4935        // limits. These MUST hold even for large ages with tiny shapes, which
4936        // is precisely the regime where the (cancelling) exact branch loses all
4937        // precision and the x-based pivot routes to the Taylor branch.
4938        //   ∂H_G/∂shape   -> rate·t²/2
4939        //   ∂h_G/∂shape   -> rate·t
4940        //   ∂²H_G/∂shape² -> rate·t³/3
4941        //   ∂²h_G/∂shape² -> rate·t²
4942        // The bug this guards: the second derivative's old `shape < 1e-10`
4943        // pivot ignored `age`, so e.g. (age=100, shape=1e-5 -> x=1e-3) took the
4944        // cancelling exact branch and returned a wildly wrong curvature.
4945        let cases = [
4946            (25.0_f64, 0.4_f64, 1e-9_f64),
4947            (100.0, 0.4, 1e-6),   // x = 1e-4
4948            (100.0, 0.012, 1e-6), // x = 1e-4, the old-pivot band (large age, tiny shape)
4949            (50.0, 1.2, 1e-8),
4950        ];
4951        // NOTE: every quantity below is compared against its shape->0 *limit*.
4952        // For the cancelling cumulative branches (∂H/∂shape, ∂²H/∂shape²,
4953        // ∂²h/∂shape²) the limit is the correct shape->0 target and the
4954        // implementation routes through Taylor in this band. But the
4955        // instantaneous first derivative ∂h_G/∂shape = rate·age·e^x carries NO
4956        // cancellation: it is exact, and its departure from the limit rate·t is
4957        // a genuine O(x) effect. At x=1e-3 that departure is ~1.2e-3 (> tol),
4958        // so the cases here keep x <= 1e-4 where the limit is a valid 1e-3
4959        // oracle for *all four* quantities. The cancelling-branch regression at
4960        // larger x is covered by gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap.
4961        for &(age, rate, shape) in &cases {
4962            let t = age;
4963            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4964            assert_close(
4965                d_cum,
4966                rate * t * t / 2.0,
4967                1e-3,
4968                &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4969            );
4970            assert_close(
4971                d_inst,
4972                rate * t,
4973                1e-3,
4974                &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4975            );
4976
4977            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4978            assert_close(
4979                d2_cum,
4980                rate * t * t * t / 3.0,
4981                1e-3,
4982                &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4983            );
4984            assert_close(
4985                d2_inst,
4986                rate * t * t,
4987                1e-3,
4988                &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4989            );
4990        }
4991    }
4992
4993    #[test]
4994    fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4995        // Regression: in the band shape ∈ [1e-10, ~1e-4] with a realistic age,
4996        // the OLD `shape < 1e-10` pivot sent ∂²H_G/∂shape² through the
4997        // catastrophically-cancelling exact branch. With age=100, shape=1e-9
4998        // (x=1e-7) the exact branch returned ~+5e1 vs the true ~rate·t³/3.
4999        // Assert the implementation now matches the closed-form limit to high
5000        // precision throughout that band, across several decades of shape.
5001        let age = 100.0;
5002        let rate = 0.4;
5003        let t = age;
5004        let truth = rate * t * t * t / 3.0; // 1.333e5
5005        // Start at shape=1e-5 (x=1e-3): below this the second derivative is,
5006        // to better than 1e-3 relative, equal to its shape->0 limit, so the
5007        // limit is a valid oracle. (At x=1e-2 the true value legitimately
5008        // departs from the limit by ~7e-3, which is a real O(x) correction,
5009        // not an error — so we do not extend the band up to shape=1e-4.)
5010        for k in 5..=12 {
5011            let shape = 10f64.powi(-(k as i32)); // 1e-5 .. 1e-12
5012            let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
5013            assert_close(
5014                d2_cum,
5015                truth,
5016                1e-3,
5017                &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
5018            );
5019        }
5020    }
5021
5022    #[test]
5023    fn weibull_offset_partials_match_central_diff() {
5024        let cases = [
5025            (0.5_f64, 1.2_f64, 25.0_f64),
5026            (2.0, 0.8, 60.0),
5027            (0.1, 3.0, 10.0),
5028        ];
5029        for &(scale, shape, age) in &cases {
5030            let cfg = SurvivalBaselineConfig {
5031                target: SurvivalBaselineTarget::Weibull,
5032                scale: Some(scale),
5033                shape: Some(shape),
5034                rate: None,
5035                makeham: None,
5036            };
5037            let analytic = baseline_offset_theta_partials(age, &cfg)
5038                .expect("ok")
5039                .expect("nl");
5040            let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
5041            assert_eq!(analytic.len(), 2);
5042            for k in 0..2 {
5043                assert_close(
5044                    analytic[k].0,
5045                    fd[k].0,
5046                    1e-7,
5047                    &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5048                );
5049                assert_close(
5050                    analytic[k].1,
5051                    fd[k].1,
5052                    1e-7,
5053                    &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5054                );
5055            }
5056            // Weibull o_D = shape/t is independent of scale; verify exactly.
5057            assert_eq!(analytic[0].1, 0.0);
5058        }
5059    }
5060
5061    #[test]
5062    fn gompertz_makeham_offset_partials_match_central_diff() {
5063        let cases = [
5064            (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
5065            (0.5, 0.01, 0.01, 25.0),
5066            (0.2, 0.001, 0.005, 60.0),
5067            (0.4, 5e-11, 0.01, 25.0),
5068            (0.4, -5e-11, 0.01, 25.0),
5069            (0.8, 0.2, 0.05, 5.0),
5070        ];
5071        for &(rate, shape, makeham, age) in &cases {
5072            let cfg = SurvivalBaselineConfig {
5073                target: SurvivalBaselineTarget::GompertzMakeham,
5074                scale: None,
5075                shape: Some(shape),
5076                rate: Some(rate),
5077                makeham: Some(makeham),
5078            };
5079            let analytic = baseline_offset_theta_partials(age, &cfg)
5080                .expect("ok")
5081                .expect("nl");
5082            // See gompertz_offset_partials_match_central_diff: tiny shape-step
5083            // is only needed for the shape component; log_rate and
5084            // log_makeham take the normal-scale step.
5085            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
5086            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
5087            assert_eq!(analytic.len(), 3);
5088            for k in 0..3 {
5089                assert_close(
5090                    analytic[k].0,
5091                    fd[k].0,
5092                    1e-5,
5093                    &format!(
5094                        "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5095                    ),
5096                );
5097                assert_close(
5098                    analytic[k].1,
5099                    fd[k].1,
5100                    1e-5,
5101                    &format!(
5102                        "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5103                    ),
5104                );
5105            }
5106        }
5107    }
5108
5109    #[test]
5110    fn linear_baseline_has_no_theta_partials() {
5111        let cfg = SurvivalBaselineConfig {
5112            target: SurvivalBaselineTarget::Linear,
5113            scale: None,
5114            shape: None,
5115            rate: None,
5116            makeham: None,
5117        };
5118        assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
5119    }
5120
5121    #[test]
5122    fn baseline_offset_partials_reject_non_positive_ages() {
5123        let cfg = SurvivalBaselineConfig {
5124            target: SurvivalBaselineTarget::Gompertz,
5125            scale: None,
5126            shape: Some(0.01),
5127            rate: Some(0.5),
5128            makeham: None,
5129        };
5130        assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
5131        assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
5132        assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
5133    }
5134
5135    // ─── baseline_chain_rule_gradient — mechanical and FD-vs-θ tests ─────
5136
5137    /// Mechanical sanity check: with only one event observation at known
5138    /// (r_X, r_E, r_D, age_exit, age_entry), the Gompertz chain-rule gradient
5139    /// reduces to the analytic linear combination of `baseline_offset_theta_partials`.
5140    #[test]
5141    fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
5142        let cfg = SurvivalBaselineConfig {
5143            target: SurvivalBaselineTarget::Gompertz,
5144            scale: None,
5145            shape: Some(0.05),
5146            rate: Some(0.3),
5147            makeham: None,
5148        };
5149        let age_entry = array![10.0_f64];
5150        let age_exit = array![25.0_f64];
5151        let residuals = OffsetChannelResiduals {
5152            exit: array![0.7_f64],
5153            entry: array![-0.2_f64],
5154            derivative: array![-0.4_f64],
5155            right: Array1::<f64>::zeros(1),
5156        };
5157        let grad = baseline_chain_rule_gradient(
5158            age_entry.view(),
5159            age_exit.view(),
5160            age_exit.view(),
5161            &cfg,
5162            &residuals,
5163        )
5164        .expect("ok")
5165        .expect("non-linear");
5166        // Hand-compute: grad[k] = r_X·∂eta_exit/∂θ_k + r_D·∂o_D_exit/∂θ_k + r_E·∂eta_entry/∂θ_k.
5167        let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5168            .unwrap()
5169            .unwrap();
5170        let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5171            .unwrap()
5172            .unwrap();
5173        for k in 0..p_exit.len() {
5174            let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5175            assert!(
5176                (grad[k] - expected).abs() < 1e-12,
5177                "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5178                grad[k],
5179                expected
5180            );
5181        }
5182    }
5183
5184    /// Origin-entry rows (r_entry == 0) must skip the baseline partials call at
5185    /// `age_entry = 0`, which would otherwise fail the positive-age precondition.
5186    #[test]
5187    fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5188        let cfg = SurvivalBaselineConfig {
5189            target: SurvivalBaselineTarget::Gompertz,
5190            scale: None,
5191            shape: Some(0.05),
5192            rate: Some(0.3),
5193            makeham: None,
5194        };
5195        let age_entry = array![0.0_f64, 5.0_f64];
5196        let age_exit = array![10.0_f64, 20.0_f64];
5197        let residuals = OffsetChannelResiduals {
5198            exit: array![0.5_f64, 0.3_f64],
5199            entry: array![0.0_f64, -0.1_f64], // row 0 is origin-entry (r_E = 0)
5200            derivative: array![-0.2_f64, 0.0_f64],
5201            right: Array1::<f64>::zeros(2),
5202        };
5203        // Must not error despite age_entry[0] == 0.
5204        let grad = baseline_chain_rule_gradient(
5205            age_entry.view(),
5206            age_exit.view(),
5207            age_exit.view(),
5208            &cfg,
5209            &residuals,
5210        )
5211        .expect("must not fail on origin-entry row with r_entry=0")
5212        .expect("non-linear");
5213        assert_eq!(grad.len(), 2);
5214        // Row 1's entry channel contributes, row 0's does not.
5215        let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5216        let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5217        let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5218        for k in 0..2 {
5219            let expected = 0.5 * p_exit_0[k].0
5220                + (-0.2) * p_exit_0[k].1
5221                + 0.3 * p_exit_1[k].0
5222                + (-0.1) * p_entry_1[k].0;
5223            assert!(
5224                (grad[k] - expected).abs() < 1e-12,
5225                "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5226                grad[k],
5227                expected
5228            );
5229        }
5230    }
5231
5232    /// Linear target has no θ-parameters; contractor returns None.
5233    #[test]
5234    fn chain_rule_gradient_linear_target_returns_none() {
5235        let cfg = SurvivalBaselineConfig {
5236            target: SurvivalBaselineTarget::Linear,
5237            scale: None,
5238            shape: None,
5239            rate: None,
5240            makeham: None,
5241        };
5242        let age_entry = array![1.0_f64];
5243        let age_exit = array![2.0_f64];
5244        let residuals = OffsetChannelResiduals {
5245            exit: array![0.1_f64],
5246            entry: array![0.0_f64],
5247            derivative: array![0.0_f64],
5248            right: Array1::<f64>::zeros(1),
5249        };
5250        let grad = baseline_chain_rule_gradient(
5251            age_entry.view(),
5252            age_exit.view(),
5253            age_exit.view(),
5254            &cfg,
5255            &residuals,
5256        )
5257        .expect("ok");
5258        assert!(grad.is_none());
5259    }
5260
5261    /// End-to-end envelope-theorem check: the chain-rule gradient at
5262    /// residuals-evaluated-at-β-fixed matches the central FD of the
5263    /// unpenalized NLL with respect to θ when the OFFSETS are recomputed
5264    /// from the perturbed cfg and β is held at its base value.
5265    ///
5266    /// This is the mathematical content of the envelope theorem applied to
5267    /// the penalized-deviance cost at fixed β: if β solves ∂C/∂β = 0 at
5268    /// (θ, β*), then the total derivative of C at (θ±h) when β is held at
5269    /// β* equals the partial derivative of C wrt θ at the base — up to
5270    /// O(h²) in the truncation error of central differences. For THIS test
5271    /// we're directly differencing NLL (the unpenalized piece that carries
5272    /// all the θ dependence), so the envelope identity is exact up to FD
5273    /// truncation.
5274    ///
5275    /// The test synthesizes a plausible residual set by hand rather than
5276    /// running PIRLS — what we're validating is the chain-rule contractor,
5277    /// not the fit. A PIRLS-based end-to-end check belongs in an
5278    /// integration test, not this unit-test module.
5279    #[test]
5280    fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5281        // Toy 3-observation case with two events (one origin-entry, one not)
5282        // and one censored row at large age.
5283        let cfg = SurvivalBaselineConfig {
5284            target: SurvivalBaselineTarget::Gompertz,
5285            scale: None,
5286            shape: Some(0.03),
5287            rate: Some(0.25),
5288            makeham: None,
5289        };
5290        let age_entry = array![0.0_f64, 5.0, 8.0];
5291        let age_exit = array![4.0_f64, 12.0, 20.0];
5292        // Weighted residuals at a notional β*. Values chosen in a plausible
5293        // range (~same order as w·exp(η)).
5294        let weights = array![1.0_f64, 2.0, 0.5];
5295        let events = [1.0_f64, 1.0, 0.0];
5296        // Fake a β* that yields finite eta_entry ± eta_exit ± s values by
5297        // directly specifying eta quantities. Contractor only consumes the
5298        // residuals, so the fake is sufficient.
5299        let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; // row 0 doesn't matter (origin entry)
5300        let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5301        let s_vals = [0.7_f64, 1.1, 1.5];
5302        let (r_x, r_e, r_d) = {
5303            let mut rx = Array1::<f64>::zeros(3);
5304            let mut re = Array1::<f64>::zeros(3);
5305            let mut rd = Array1::<f64>::zeros(3);
5306            for i in 0..3 {
5307                let w = weights[i];
5308                let d = events[i];
5309                rx[i] = w * (eta_exit_vals[i].exp() - d);
5310                re[i] = if i == 0 {
5311                    0.0 // origin entry
5312                } else {
5313                    -w * eta_entry_vals[i].exp()
5314                };
5315                rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5316            }
5317            (rx, re, rd)
5318        };
5319        let residuals = OffsetChannelResiduals {
5320            exit: r_x.clone(),
5321            entry: r_e.clone(),
5322            derivative: r_d.clone(),
5323            right: Array1::<f64>::zeros(3),
5324        };
5325        let grad = baseline_chain_rule_gradient(
5326            age_entry.view(),
5327            age_exit.view(),
5328            age_exit.view(),
5329            &cfg,
5330            &residuals,
5331        )
5332        .expect("ok")
5333        .expect("non-linear");
5334
5335        // Construct NLL(θ) with β* held to the same eta/s values by treating
5336        // eta_i, s_i as fixed "linear predictor" samples and shifting by
5337        // (offset(θ) - offset(θ_base)). That's exactly the RP NLL with β*
5338        // held constant and offsets varied through θ.
5339        let nll = |theta_plus: &Array1<f64>| -> f64 {
5340            let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5341            let mut sum = 0.0_f64;
5342            for i in 0..3 {
5343                let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5344                let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5345                let d_eta_x = eta_x_p - base.0;
5346                let d_d_x = d_x_p - base.1;
5347                let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5348                let s_new = s_vals[i] + d_d_x;
5349                let interval_entry = if i == 0 {
5350                    0.0_f64
5351                } else {
5352                    let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5353                    let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5354                    let d_eta_e = eta_e_p - base_e.0;
5355                    let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5356                    eta_entry_new.exp()
5357                };
5358                let w = weights[i];
5359                let d = events[i];
5360                let nll_i =
5361                    w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5362                sum += nll_i;
5363            }
5364            sum
5365        };
5366
5367        let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5368        let h = 1e-6;
5369        for k in 0..theta_base.len() {
5370            let mut tp = theta_base.clone();
5371            let mut tm = theta_base.clone();
5372            tp[k] += h;
5373            tm[k] -= h;
5374            let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5375            assert!(
5376                (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5377                "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5378                grad[k],
5379                fd
5380            );
5381        }
5382    }
5383
5384    /// Length-mismatch surfaces as an error, not a silent contraction.
5385    #[test]
5386    fn chain_rule_gradient_rejects_length_mismatch() {
5387        let cfg = SurvivalBaselineConfig {
5388            target: SurvivalBaselineTarget::Gompertz,
5389            scale: None,
5390            shape: Some(0.05),
5391            rate: Some(0.3),
5392            makeham: None,
5393        };
5394        let age_entry = array![1.0_f64, 2.0]; // length 2
5395        let age_exit = array![5.0_f64, 6.0, 7.0]; // length 3
5396        let residuals = OffsetChannelResiduals {
5397            exit: array![0.1_f64, 0.2, 0.3],
5398            entry: array![0.0_f64, 0.0, 0.0],
5399            derivative: array![0.0_f64, 0.0, 0.0],
5400            right: Array1::<f64>::zeros(3),
5401        };
5402        let err = baseline_chain_rule_gradient(
5403            age_entry.view(),
5404            age_exit.view(),
5405            age_exit.view(),
5406            &cfg,
5407            &residuals,
5408        )
5409        .expect_err("length mismatch must error");
5410        assert!(err.contains("length mismatch"), "err={err}");
5411    }
5412}