Skip to main content

gam_models/survival/
construction.rs

1//! Survival model construction helpers.
2//!
3//! Types and functions for building survival model components:
4//! - Baseline hazard targets (Weibull, Gompertz, Gompertz-Makeham)
5//! - Time basis construction (I-spline on log-time)
6//! - Baseline offset computation
7//! - Time wiggle construction
8//!
9//! These are the building blocks a library consumer needs to construct
10//! a `FitRequest::SurvivalLocationScale` without going through the CLI.
11
12use gam_terms::basis::{
13    BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14    BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15    create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18    DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19    SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24    WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25    monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::{InverseLink, StandardLink};
31use ndarray::{Array1, Array2, array, s};
32use rayon::prelude::*;
33
34// ---------------------------------------------------------------------------
35// Typed error
36// ---------------------------------------------------------------------------
37
/// Structured failure surface for survival-model construction helpers
/// (`parse_*`, baseline-config builders, time-basis construction).
///
/// Every variant carries a free-form `reason: String` payload; `Display` emits
/// that payload verbatim, so converting to `String` via the `From` impl
/// produces text byte-equivalent to the pre-refactor `Err(format!(...))`
/// call sites that were the only producers in this module.
///
/// NOTE(review): the `Display` / `From` impls are not spelled out in this
/// file; they presumably come from the `impl_reason_error_boilerplate!`
/// invocation that follows this enum — confirm against the macro definition.
///
/// The public CLI-input parsers (`parse_survival_distribution`,
/// `parse_survival_likelihood_mode`, `parse_survival_baseline_config`)
/// keep their `Result<_, String>` signatures — string is the natural
/// failure type for free-form user input — and route through this enum
/// internally via `From<SurvivalConstructionError> for String`.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// User-supplied configuration is malformed or out of range (knot
    /// counts, anchor offsets, derivative guards, ranks).
    InvalidConfig { reason: String },
    /// A required column or block of metadata is absent (e.g. saved
    /// survival ispline keep_cols, baseline target on a saved fit).
    MissingColumn { reason: String },
    /// Per-row / per-column shape disagreement (entry/exit lengths,
    /// penalty rank vs basis width, basis vs coefficient counts).
    IncompatibleDimensions { reason: String },
    /// Numeric / domain rejection: non-finite ratios, non-positive
    /// survival times, monotonicity violations, ispline-derivative
    /// underflow.
    DataValidationFailed { reason: String },
    /// Underlying basis / penalty builder rejected the construction
    /// request (invalid spline order, ispline keep_cols out of range,
    /// internal empty ispline time basis).
    BasisConstructionFailed { reason: String },
    /// User-named distribution / likelihood-mode / baseline target /
    /// time-basis kind is not one we recognise.
    UnsupportedDistribution { reason: String },
}
73
// Shared boilerplate for the reason-carrying error enum. The list mirrors the
// six `SurvivalConstructionError` variants in declaration order.
// NOTE(review): the macro is defined outside this file; it presumably emits the
// `Display` impl and the `From<SurvivalConstructionError> for String`
// conversion that the `.into()` call sites in this module rely on — confirm.
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
84
85// ---------------------------------------------------------------------------
86// Types
87// ---------------------------------------------------------------------------
88
/// Baseline target selected by the `--baseline-target` option.
///
/// [`parse_survival_baseline_config`] maps the CLI spelling
/// (`linear`, `weibull`, `gompertz`, `gompertz-makeham`) onto these variants
/// and [`survival_baseline_targetname`] maps them back.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// No additional parametric target:
    /// eta_target(t) = 0, so regularized model defaults to linear log-cumulative
    /// hazard from the existing time basis.
    Linear,
    /// Parametric target: Weibull baseline.
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Weibull,
    /// Parametric target: Gompertz baseline.
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Gompertz,
    /// Parametric target: Gompertz-Makeham baseline.
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    GompertzMakeham,
}
111
/// A baseline target together with the parameters that target uses.
///
/// [`parse_survival_baseline_config`] fills only the fields relevant to
/// `target` and leaves the others `None`:
///
/// * `Linear` — every parameter is `None`.
/// * `Weibull` — `scale` and `shape`, both finite and strictly positive.
/// * `Gompertz` — finite `shape` and strictly positive finite `rate`.
/// * `GompertzMakeham` — finite `shape`, strictly positive finite `rate` and
///   `makeham`.
///
/// The fields are public, so a value built by hand is not guaranteed to
/// satisfy these constraints.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    /// Which baseline family the parameters below belong to.
    pub target: SurvivalBaselineTarget,
    /// Weibull scale; `None` for every other target.
    pub scale: Option<f64>,
    /// Shape parameter (Weibull, Gompertz, Gompertz-Makeham); `None` for `Linear`.
    pub shape: Option<f64>,
    /// Rate parameter (Gompertz, Gompertz-Makeham); `None` otherwise.
    pub rate: Option<f64>,
    /// Makeham term (Gompertz-Makeham only); `None` otherwise.
    pub makeham: Option<f64>,
}
120
/// Selects which time basis a survival fit uses and carries that basis'
/// construction parameters.
///
/// NOTE(review): the builder that consumes this enum is outside the visible
/// part of the file; the per-variant notes on `None`, `Linear` and `BSpline`
/// are read off the names and fields only — confirm against the builder.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No parameters. Presumably "no time basis" — TODO confirm.
    None,
    /// No parameters. Presumably a linear-in-time basis — TODO confirm.
    Linear,
    /// B-spline basis described by its `degree` and `knots`, with a
    /// smoothing-penalty value `smooth_lambda`.
    BSpline {
        degree: usize,
        knots: Array1<f64>,
        smooth_lambda: f64,
    },
    /// I-spline value rows on the `log(t)` axis with non-negative
    /// coefficients (`γ ≥ 0`) enforcing structural monotonicity of
    /// `q(t) = I_basis(log t) · γ`. This replaces the row-wise
    /// `D β + o ≥ guard` derivative-guard constraints the marginal-slope
    /// family previously relied on.
    ///
    /// The design builder lives below at `_build_time_block`'s
    /// `SurvivalTimeBasisConfig::ISpline` arm and exposes:
    ///
    /// * `x_entry_time` / `x_exit_time` — I-spline value rows on the
    ///   `log(t)` axis. Non-negative entries plus `γ ≥ 0` give a
    ///   monotone-non-decreasing `q(t)`, the structural property the
    ///   marginal-slope family needs.
    /// * `x_derivative_time` — right-cumulative B-spline-derivative on
    ///   `log(t)` scaled by `1/t`, again non-negative with `γ ≥ 0`, so
    ///   `q'(t) ≥ 0` pointwise. The `derivative_guard` constant is added
    ///   externally by [`add_survival_time_derivative_guard_offset`],
    ///   leaving the derivative guarantee `q'(t) ≥ guard` exact.
    /// * 2nd-difference penalty on the underlying degree-`(k+1)` B-spline
    ///   coefficients, filtered through `keep_cols` for identifiability.
    ///
    /// `TimeBlockInput::time_monotonicity` declares to the consuming
    /// family how monotonicity is enforced. The marginal-slope
    /// construction site sets it to
    /// [`crate::survival::location_scale::TimeBlockMonotonicity::StructuralISpline`]
    /// so the family skips row-wise `D β + o ≥ guard` constraint
    /// generation and treats `γ ≥ 0` as the sole derivative-guard
    /// mechanism. The universal `validate_time_qd1_feasible` safety net
    /// runs regardless.
    ///
    /// An earlier iteration proposed a separate C-spline antiderivative
    /// parameterization that put `q'(t)` in the I-spline space and `q(t)`
    /// in the integral-of-I-spline space. That was mathematically
    /// equivalent but a strictly worse fit for the codebase (extra basis
    /// degree, an extra antiderivative builder, an extra identifiability
    /// path, an extra penalty); it was removed in favor of the canonical
    /// I-spline-value path here.
    ISpline {
        degree: usize,
        knots: Array1<f64>,
        keep_cols: Vec<usize>,
        smooth_lambda: f64,
    },
}
174
/// Persistable snapshot of the time-basis state used by a survival fit.
///
/// Every survival family routes through [`SurvivalTimeBuildOutput`] during
/// the fit, but the FFI save path needs only the metadata — not the full
/// design matrices. This struct is the single source of truth that flows
/// from the workflow-level basis construction, through the family-specific
/// fit result, into the saved-model payload via
/// [`crate::inference::model::FittedModelPayload::apply_survival_time_basis`].
///
/// Threading this snapshot end-to-end eliminates the prior bug pattern
/// where each FFI builder had to reconstruct the metadata from
/// `fit_config` + the formula (silent drift risk; one builder forgetting
/// to do so caused the marginal-slope save→load break).
///
/// Apart from `anchor`, every field is a copy of the same-named field on
/// [`SurvivalTimeBuildOutput`] (see [`SavedSurvivalTimeBasis::from_build`]).
#[derive(Clone, Debug, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Name of the time basis, copied from the build output.
    pub basisname: String,
    /// Spline degree, when the build output recorded one.
    pub degree: Option<usize>,
    /// Knot vector, when the build output recorded one.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns, when the build output recorded them.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing lambda, when the build output recorded one.
    pub smooth_lambda: Option<f64>,
    /// Entry anchor that was used during the fit.
    pub anchor: f64,
}
197
198impl SavedSurvivalTimeBasis {
199    /// Build a snapshot from the realised time-basis state and the entry
200    /// anchor that was used during the fit.
201    pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
202        Self {
203            basisname: build.basisname.clone(),
204            degree: build.degree,
205            knots: build.knots.clone(),
206            keep_cols: build.keep_cols.clone(),
207            smooth_lambda: build.smooth_lambda,
208            anchor,
209        }
210    }
211}
212
/// Everything produced when a survival time basis is built: the three time
/// design matrices, the penalties, and the metadata needed to describe or
/// rebuild the basis (the metadata subset is what
/// [`SavedSurvivalTimeBasis::from_build`] copies).
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time design at the entry times.
    /// NOTE(review): presumably one row per observation — confirm.
    pub x_entry_time: DesignMatrix,
    /// Time design at the exit times.
    pub x_exit_time: DesignMatrix,
    /// Time-derivative design.
    pub x_derivative_time: DesignMatrix,
    /// Penalty matrices for the time block.
    pub penalties: Vec<Array2<f64>>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Name of the basis; `ispline` (ASCII case-insensitive) is the one
    /// [`survival_basis_supports_structural_monotonicity`] accepts.
    pub basisname: String,
    /// Spline degree, when recorded for this basis.
    pub degree: Option<usize>,
    /// Knot vector, when recorded for this basis.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns, when recorded for this basis.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing lambda, when recorded for this basis.
    pub smooth_lambda: Option<f64>,
}
227
/// Smallest time value the helpers in this module hand out.
///
/// [`normalize_survival_time_pair`] raises entry times to at least this value
/// and raises exit times to at least `entry + SURVIVAL_TIME_FLOOR`;
/// [`positive_survival_time_seed`] never returns less than it.
pub const SURVIVAL_TIME_FLOOR: f64 = 1e-9;

/// Entry ages above this value mark genuine left truncation (delayed entry): the
/// row's cumulative-hazard interval starts at a positive left-tail time rather
/// than the time origin. Kept in lockstep with the working-model's
/// `ENTRY_AT_ORIGIN_THRESHOLD` so the "this row has an entry interval" and "the
/// data is left-truncated" decisions agree.
pub const SURVIVAL_DELAYED_ENTRY_THRESHOLD: f64 = 1e-8;

/// Seed smoothing penalty `λ` used when a survival time basis is reconstructed
/// from a build (or saved model) that did not carry an explicit `smooth_lambda`.
/// This is only an initial value for the REML smoothing search, not a fixed
/// policy: a small positive seed keeps the baseline spline lightly regularized
/// at the start so the outer optimizer begins from a well-conditioned point and
/// then adapts `λ` to the data. Kept in one place so the b-spline and i-spline
/// reconstruction paths cannot drift apart.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1e-2;

/// Default initial Gompertz / Gompertz-Makeham shape parameter when the user
/// does not supply `--baseline-shape`. The Gompertz hazard is
/// `h(t) = rate · exp(shape · t)`; a near-zero shape seeds the baseline at an
/// almost-flat (exponential-like) hazard, letting the fit grow the
/// age-acceleration term from the data rather than committing to a strong
/// curvature up front. Shared by the parse and fit-seed paths so both start
/// from the same neutral shape.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
254
/// Survival likelihood selected by the `--survival-likelihood` option.
///
/// [`parse_survival_likelihood_mode`] maps the CLI spelling onto a variant and
/// [`survival_likelihood_modename`] maps it back; the spelling is noted on
/// each variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `transformation`
    Transformation,
    /// `weibull`
    Weibull,
    /// `location-scale`
    LocationScale,
    /// `marginal-slope`
    MarginalSlope,
    /// `latent`
    Latent,
    /// `latent-binary`
    LatentBinary,
}
264
/// Result of building a time-wiggle component.
///
/// NOTE(review): the function that produces this struct is outside the visible
/// part of the file; field notes below are read off names and types — confirm.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices for the wiggle block.
    pub penalties: Vec<Array2<f64>>,
    /// Presumably the structural nullspace dimension of each entry in
    /// `penalties`, as on `SurvivalTimeBuildOutput` — TODO confirm.
    pub nullspace_dims: Vec<usize>,
    /// Knot vector of the wiggle basis.
    pub knots: Array1<f64>,
    /// Degree of the wiggle basis.
    pub degree: usize,
    /// Number of wiggle columns.
    pub ncols: usize,
}
272
273// ---------------------------------------------------------------------------
274// Time normalization
275// ---------------------------------------------------------------------------
276
277pub fn normalize_survival_time_pair(
278    entry_raw: f64,
279    exit_raw: f64,
280    row_index: usize,
281) -> Result<(f64, f64), String> {
282    if !entry_raw.is_finite() || !exit_raw.is_finite() {
283        return Err(SurvivalConstructionError::DataValidationFailed {
284            reason: format!("non-finite survival times at row {}", row_index + 1),
285        }
286        .into());
287    }
288    if entry_raw < 0.0 || exit_raw < 0.0 {
289        return Err(SurvivalConstructionError::DataValidationFailed {
290            reason: format!("negative survival times at row {}", row_index + 1),
291        }
292        .into());
293    }
294
295    let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
296    let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
297    Ok((entry, exit))
298}
299
300// ---------------------------------------------------------------------------
301// Basis monotonicity helpers
302// ---------------------------------------------------------------------------
303
/// Whether `basisname` names a time basis that is monotone by construction.
///
/// Only `ispline` qualifies; the comparison ignores ASCII case and does not
/// trim whitespace.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURAL_MONOTONE_BASIS: &str = "ispline";
    STRUCTURAL_MONOTONE_BASIS.eq_ignore_ascii_case(basisname)
}
307
308pub fn require_structural_survival_time_basis(
309    basisname: &str,
310    context: &str,
311) -> Result<(), String> {
312    if survival_basis_supports_structural_monotonicity(basisname) {
313        return Ok(());
314    }
315    Err(SurvivalConstructionError::UnsupportedDistribution {
316        reason: format!(
317            "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
318Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
319`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
320Re-run with `--time-basis ispline`."
321        ),
322    }
323    .into())
324}
325
326// ---------------------------------------------------------------------------
327// Baseline config parsing
328// ---------------------------------------------------------------------------
329
330pub fn parse_survival_baseline_config(
331    target_raw: &str,
332    scale: Option<f64>,
333    shape: Option<f64>,
334    rate: Option<f64>,
335    makeham: Option<f64>,
336) -> Result<SurvivalBaselineConfig, String> {
337    let target = match target_raw.to_ascii_lowercase().as_str() {
338        "linear" => SurvivalBaselineTarget::Linear,
339        "weibull" => SurvivalBaselineTarget::Weibull,
340        "gompertz" => SurvivalBaselineTarget::Gompertz,
341        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
342        other => {
343            return Err(SurvivalConstructionError::UnsupportedDistribution {
344                reason: format!(
345                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
346                ),
347            }
348            .into());
349        }
350    };
351
352    match target {
353        SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
354            target,
355            scale: None,
356            shape: None,
357            rate: None,
358            makeham: None,
359        }),
360        SurvivalBaselineTarget::Weibull => {
361            let scale = scale.ok_or_else(|| {
362                "--baseline-target weibull requires --baseline-scale > 0".to_string()
363            })?;
364            let shape = shape.ok_or_else(|| {
365                "--baseline-target weibull requires --baseline-shape > 0".to_string()
366            })?;
367            if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
368                return Err(
369                    "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
370                        .to_string(),
371                );
372            }
373            Ok(SurvivalBaselineConfig {
374                target,
375                scale: Some(scale),
376                shape: Some(shape),
377                rate: None,
378                makeham: None,
379            })
380        }
381        SurvivalBaselineTarget::Gompertz => {
382            let rate = rate.unwrap_or(1.0);
383            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
384            if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
385                return Err(
386                    "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
387                        .to_string(),
388                );
389            }
390            Ok(SurvivalBaselineConfig {
391                target,
392                scale: None,
393                shape: Some(shape),
394                rate: Some(rate),
395                makeham: None,
396            })
397        }
398        SurvivalBaselineTarget::GompertzMakeham => {
399            let rate = rate.unwrap_or(0.5);
400            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
401            let makeham = makeham.unwrap_or(0.5);
402            if !rate.is_finite()
403                || rate <= 0.0
404                || !shape.is_finite()
405                || !makeham.is_finite()
406                || makeham <= 0.0
407            {
408                return Err(
409                    "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
410                        .to_string(),
411                );
412            }
413            Ok(SurvivalBaselineConfig {
414                target,
415                scale: None,
416                shape: Some(shape),
417                rate: Some(rate),
418                makeham: Some(makeham),
419            })
420        }
421    }
422}
423
424// ---------------------------------------------------------------------------
425// Likelihood mode / distribution parsing
426// ---------------------------------------------------------------------------
427
428pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
429    match raw.to_ascii_lowercase().as_str() {
430        "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
431        "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
432        "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
433        "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
434        "latent" => Ok(SurvivalLikelihoodMode::Latent),
435        "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
436        other => Err(SurvivalConstructionError::UnsupportedDistribution {
437            reason: format!(
438                "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
439            ),
440        }
441        .into()),
442    }
443}
444
445pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
446    match mode {
447        SurvivalLikelihoodMode::Transformation => "transformation",
448        SurvivalLikelihoodMode::Weibull => "weibull",
449        SurvivalLikelihoodMode::LocationScale => "location-scale",
450        SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
451        SurvivalLikelihoodMode::Latent => "latent",
452        SurvivalLikelihoodMode::LatentBinary => "latent-binary",
453    }
454}
455
456pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
457    match raw.to_ascii_lowercase().as_str() {
458        "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
459        "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
460        "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
461        other => Err(SurvivalConstructionError::UnsupportedDistribution {
462            reason: format!(
463                "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
464            ),
465        }
466        .into()),
467    }
468}
469
470pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
471    match target {
472        SurvivalBaselineTarget::Linear => "linear",
473        SurvivalBaselineTarget::Weibull => "weibull",
474        SurvivalBaselineTarget::Gompertz => "gompertz",
475        SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
476    }
477}
478
479pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
480    let sum = age_exit
481        .iter()
482        .copied()
483        .filter(|value| value.is_finite() && *value > 0.0)
484        .sum::<f64>();
485    let count = age_exit
486        .iter()
487        .filter(|value| value.is_finite() && **value > 0.0)
488        .count()
489        .max(1);
490    (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
491}
492
493pub fn initial_survival_baseline_config_for_fit(
494    target_raw: &str,
495    scale: Option<f64>,
496    shape: Option<f64>,
497    rate: Option<f64>,
498    makeham: Option<f64>,
499    age_exit: &Array1<f64>,
500) -> Result<SurvivalBaselineConfig, String> {
501    let target = match target_raw.trim().to_ascii_lowercase().as_str() {
502        "linear" => SurvivalBaselineTarget::Linear,
503        "weibull" => SurvivalBaselineTarget::Weibull,
504        "gompertz" => SurvivalBaselineTarget::Gompertz,
505        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
506        other => {
507            return Err(SurvivalConstructionError::UnsupportedDistribution {
508                reason: format!(
509                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
510                ),
511            }
512            .into());
513        }
514    };
515    let time_scale_seed = positive_survival_time_seed(age_exit);
516    let cfg = match target {
517        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
518            target,
519            scale: None,
520            shape: None,
521            rate: None,
522            makeham: None,
523        },
524        SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
525            target,
526            scale: Some(scale.unwrap_or(time_scale_seed)),
527            shape: Some(shape.unwrap_or(1.0)),
528            rate: None,
529            makeham: None,
530        },
531        SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
532            target,
533            scale: None,
534            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
535            rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
536            makeham: None,
537        },
538        SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
539            target,
540            scale: None,
541            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
542            rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
543            makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
544        },
545    };
546    parse_survival_baseline_config(
547        survival_baseline_targetname(cfg.target),
548        cfg.scale,
549        cfg.shape,
550        cfg.rate,
551        cfg.makeham,
552    )
553}
554
555fn survival_baseline_theta_from_config(
556    cfg: &SurvivalBaselineConfig,
557) -> Result<Option<Array1<f64>>, String> {
558    Ok(match cfg.target {
559        SurvivalBaselineTarget::Linear => None,
560        SurvivalBaselineTarget::Weibull => Some(array![
561            cfg.scale
562                .ok_or_else(|| "missing weibull baseline scale".to_string())?
563                .ln(),
564            cfg.shape
565                .ok_or_else(|| "missing weibull baseline shape".to_string())?
566                .ln(),
567        ]),
568        SurvivalBaselineTarget::Gompertz => Some(array![
569            cfg.rate
570                .ok_or_else(|| "missing gompertz baseline rate".to_string())?
571                .ln(),
572            cfg.shape
573                .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
574        ]),
575        SurvivalBaselineTarget::GompertzMakeham => Some(array![
576            cfg.rate
577                .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
578                .ln(),
579            cfg.shape
580                .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
581            cfg.makeham
582                .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
583                .ln(),
584        ]),
585    })
586}
587
588fn survival_baseline_config_from_theta(
589    target: SurvivalBaselineTarget,
590    theta: &Array1<f64>,
591) -> Result<SurvivalBaselineConfig, String> {
592    let cfg = match target {
593        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
594            target,
595            scale: None,
596            shape: None,
597            rate: None,
598            makeham: None,
599        },
600        SurvivalBaselineTarget::Weibull => {
601            if theta.len() != 2 {
602                return Err(SurvivalConstructionError::IncompatibleDimensions {
603                    reason: format!(
604                        "weibull baseline parameter dimension mismatch: expected 2, got {}",
605                        theta.len()
606                    ),
607                }
608                .into());
609            }
610            SurvivalBaselineConfig {
611                target,
612                scale: Some(theta[0].exp()),
613                shape: Some(theta[1].exp()),
614                rate: None,
615                makeham: None,
616            }
617        }
618        SurvivalBaselineTarget::Gompertz => {
619            if theta.len() != 2 {
620                return Err(SurvivalConstructionError::IncompatibleDimensions {
621                    reason: format!(
622                        "gompertz baseline parameter dimension mismatch: expected 2, got {}",
623                        theta.len()
624                    ),
625                }
626                .into());
627            }
628            SurvivalBaselineConfig {
629                target,
630                scale: None,
631                shape: Some(theta[1]),
632                rate: Some(theta[0].exp()),
633                makeham: None,
634            }
635        }
636        SurvivalBaselineTarget::GompertzMakeham => {
637            if theta.len() != 3 {
638                return Err(SurvivalConstructionError::IncompatibleDimensions {
639                    reason: format!(
640                        "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
641                        theta.len()
642                    ),
643                }
644                .into());
645            }
646            SurvivalBaselineConfig {
647                target,
648                scale: None,
649                shape: Some(theta[1]),
650                rate: Some(theta[0].exp()),
651                makeham: Some(theta[2].exp()),
652            }
653        }
654    };
655    parse_survival_baseline_config(
656        survival_baseline_targetname(cfg.target),
657        cfg.scale,
658        cfg.shape,
659        cfg.rate,
660        cfg.makeham,
661    )
662}
663
/// Derivative contract for the shared baseline-θ outer optimizer.
///
/// The two public baseline optimizers (`…_with_gradient_only`,
/// `…_with_gradient`) differ in exactly one axis: how much derivative
/// information the objective closure supplies, and therefore which curvature
/// declaration the `OuterProblem` must advertise. Every baseline-θ path now
/// supplies an exact analytic gradient (profile-NLL envelope gradient), so both
/// contracts route to a gradient-based solver.
///
/// Everything else — θ↔config conversion, the ±6 box around the seed in θ
/// coordinates (log-space for every parameter except the Gompertz /
/// Gompertz-Makeham shape, which enters θ untransformed), the single-seed
/// config, the `run`/convergence/error-formatting boilerplate — is identical,
/// so it lives once in [`run_baseline_theta_optimizer`] and this enum selects
/// the per-contract `OuterProblem` configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Cost + analytic gradient, no analytic Hessian. Routes to BFGS, which
    /// builds its own quasi-Newton curvature from successive gradients.
    GradientOnly,
    /// Cost + analytic gradient + analytic Hessian. Routes to the primary
    /// second-order outer solver, which may use either the analytic Hessian or
    /// a BFGS approximation depending on the planner.
    GradientHessian,
}
686
687impl BaselineDerivativeContract {
688    /// Apply this contract's derivative declaration, solver class, tolerance,
689    /// and iteration budget to a freshly-constructed `OuterProblem`. The
690    /// bounds, initial ρ, and seed config are contract-independent and applied
691    /// by [`run_baseline_theta_optimizer`].
692    fn configure(
693        self,
694        problem: gam_solve::rho_optimizer::OuterProblem,
695    ) -> gam_solve::rho_optimizer::OuterProblem {
696        use gam_problem::{DeclaredHessianForm, Derivative};
697        match self {
698            // BFGS on a 2–3 dim problem with an exact gradient typically
699            // converges in 5–10 outer evaluations.
700            BaselineDerivativeContract::GradientOnly => problem
701                .with_gradient(Derivative::Analytic)
702                .with_hessian(DeclaredHessianForm::Unavailable)
703                .with_tolerance(1e-4)
704                .with_max_iter(240),
705            BaselineDerivativeContract::GradientHessian => problem
706                .with_gradient(Derivative::Analytic)
707                .with_hessian(DeclaredHessianForm::Either)
708                .with_tolerance(1e-4)
709                .with_max_iter(240),
710        }
711    }
712}
713
/// Shared engine behind the public baseline-config optimizers.
///
/// Owns every step that is identical across the two
/// [`BaselineDerivativeContract`] variants (gradient-only and
/// gradient+Hessian): config→θ seeding (with the linear/no-parameter early
/// return), the ±6 box around the seed in θ coordinates, the single-seed
/// `OuterProblem` skeleton, derivative-contract configuration,
/// `build_objective` wiring, `run`, the convergence check + error formatting,
/// and θ→config. The only contract-specific inputs are the already-wired
/// `cost_fn`/`eval_fn` closures (which embed the derivative shape and
/// dimension validation) and the `contract` selecting the `OuterProblem`
/// derivative declaration.
///
/// NOTE(review): earlier wording here spoke of three optimizers and a
/// cost-only contract; the contract enum has only the two variants above.
///
/// `context` labels the run and prefixes both failure messages.
///
/// # Errors
/// Returns a message when the seed config is missing a parameter, when the
/// outer solve fails (`"{context} failed: …"`), when it does not converge, or
/// when the optimized θ fails validation on the way back to a config.
fn run_baseline_theta_optimizer<Fc, Fe>(
    initial: &SurvivalBaselineConfig,
    context: &str,
    contract: BaselineDerivativeContract,
    cost_fn: Fc,
    eval_fn: Fe,
) -> Result<SurvivalBaselineConfig, String>
where
    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
    Fe: FnMut(
        &mut (),
        &Array1<f64>,
    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
{
    use gam_solve::rho_optimizer::OuterProblem;
    // A `Linear` target has no θ: nothing to optimize, hand the config back.
    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
        return Ok(initial.clone());
    };
    let dim = seed.len();
    let target = initial.target;
    // Search box: six units either side of the seed in every θ coordinate.
    let lower = seed.mapv(|v| v - 6.0);
    let upper = seed.mapv(|v| v + 6.0);
    let problem = contract
        .configure(OuterProblem::new(dim))
        .with_bounds(lower, upper)
        .with_initial_rho(seed.clone())
        // Single start point: the seed derived from `initial`.
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            num_auxiliary_trailing: dim,
            ..Default::default()
        });
    // The unit `()` is the objective's state; the two trailing `None`s leave
    // the optional hooks unset.
    let mut obj = problem.build_objective(
        (),
        cost_fn,
        eval_fn,
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, context)
        .map_err(|e| format!("{context} failed: {e}"))?;
    if !result.converged {
        return Err(SurvivalConstructionError::InvalidConfig {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    // Decode and re-validate the optimized θ.
    survival_baseline_config_from_theta(target, &result.rho)
}
784
785/// Shared engine for the two derivative-carrying baseline-config optimizers.
786///
787/// Both `…_with_gradient_only` and `…_with_gradient` route an objective that
788/// returns a fully-populated [`OuterEval`](gam_problem::OuterEval)
789/// (cost + analytic gradient, optionally + analytic Hessian) for a given
790/// config. Everything downstream of that — the `Rc<RefCell>` sharing that lets
791/// the same user closure back both the `cost_fn` and `eval_fn`, the θ→config
792/// conversion, and deriving the scalar `cost_fn` from the eval result — is
793/// identical, so it lives here once. The contract-specific axis is only which
794/// `HessianResult` the objective embeds, which the wrapper has already encoded
795/// in the returned `OuterEval`, so this helper is contract-agnostic beyond the
796/// `contract` it forwards to [`run_baseline_theta_optimizer`].
797fn run_baseline_theta_optimizer_with_eval<F>(
798    initial: &SurvivalBaselineConfig,
799    context: &str,
800    contract: BaselineDerivativeContract,
801    objective: F,
802) -> Result<SurvivalBaselineConfig, String>
803where
804    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
805{
806    let target = initial.target;
807    let engine_context = context.to_string();
808    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
809    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
810                        theta: &Array1<f64>|
811          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
812        let cfg = survival_baseline_config_from_theta(target, theta)
813            .map_err(crate::model_types::EstimationError::InvalidInput)?;
814        let eval =
815            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
816        if eval.gradient.len() != theta.len() {
817            return Err(crate::model_types::EstimationError::InvalidInput(format!(
818                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
819                eval.gradient.len(),
820                theta.len()
821            )));
822        }
823        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
824            if h.nrows() != theta.len() || h.ncols() != theta.len() {
825                return Err(crate::model_types::EstimationError::InvalidInput(format!(
826                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
827                    h.nrows(),
828                    h.ncols(),
829                    theta.len(),
830                    theta.len()
831                )));
832            }
833        }
834        Ok(eval)
835    };
836    let cost_objective = std::rc::Rc::clone(&objective);
837    let cost_eval = eval_at.clone();
838    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
839        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
840    };
841    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
842    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
843}
844
845/// Gradient-only outer baseline-config optimizer. Thin adapter over
846/// [`run_baseline_theta_optimizer`] under the
847/// [`BaselineDerivativeContract::GradientOnly`] contract, which advertises
848/// `DeclaredHessianForm::Unavailable`, so the planner routes to BFGS and
849/// builds its own quasi-Newton curvature from successive gradient
850/// evaluations. Used by the survival location-scale path which has a
851/// closed-form θ-gradient (`baseline_chain_rule_gradient` /
852/// `marginal_slope_baseline_chain_rule_gradient`) but no native analytic
853/// θ-Hessian; BFGS on a 2–3 dim problem with an exact gradient typically
854/// converges in 5–10 outer evaluations.
855pub fn optimize_survival_baseline_config_with_gradient_only<F>(
856    initial: &SurvivalBaselineConfig,
857    context: &str,
858    mut objective: F,
859) -> Result<SurvivalBaselineConfig, String>
860where
861    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
862{
863    use gam_problem::{HessianResult, OuterEval};
864    run_baseline_theta_optimizer_with_eval(
865        initial,
866        context,
867        BaselineDerivativeContract::GradientOnly,
868        move |cfg| {
869            let (cost, gradient) = objective(cfg)?;
870            Ok(OuterEval {
871                cost,
872                gradient,
873                hessian: HessianResult::Unavailable,
874                inner_beta_hint: None,
875            })
876        },
877    )
878}
879
880/// Gradient + Hessian outer baseline-config optimizer. Thin adapter over
881/// [`run_baseline_theta_optimizer`] under the
882/// [`BaselineDerivativeContract::GradientHessian`] contract, which advertises
883/// an analytic θ-Hessian so the primary second-order outer solver can use it.
884pub fn optimize_survival_baseline_config_with_gradient<F>(
885    initial: &SurvivalBaselineConfig,
886    context: &str,
887    mut objective: F,
888) -> Result<SurvivalBaselineConfig, String>
889where
890    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
891{
892    use gam_problem::{HessianResult, OuterEval};
893    run_baseline_theta_optimizer_with_eval(
894        initial,
895        context,
896        BaselineDerivativeContract::GradientHessian,
897        move |cfg| {
898            let (cost, gradient, hessian) = objective(cfg)?;
899            Ok(OuterEval {
900                cost,
901                gradient,
902                hessian: HessianResult::Analytic(hessian),
903                inner_beta_hint: None,
904            })
905        },
906    )
907}
908
909// ---------------------------------------------------------------------------
910// Time basis config (library-friendly: takes primitives, not CLI args)
911// ---------------------------------------------------------------------------
912
913pub fn parse_survival_time_basis_config(
914    time_basis: &str,
915    time_degree: usize,
916    time_num_internal_knots: usize,
917    time_smooth_lambda: f64,
918) -> Result<SurvivalTimeBasisConfig, String> {
919    match time_basis.to_ascii_lowercase().as_str() {
920        "none" => Ok(SurvivalTimeBasisConfig::None),
921        "ispline" => {
922            if time_degree < 1 {
923                return Err(
924                    "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
925                        .to_string(),
926                );
927            }
928            if time_num_internal_knots == 0 {
929                return Err(
930                    "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
931                        .to_string(),
932                );
933            }
934            if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
935                return Err(
936                    "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
937                        .to_string(),
938                );
939            }
940            Ok(SurvivalTimeBasisConfig::ISpline {
941                degree: time_degree,
942                knots: Array1::zeros(0),
943                keep_cols: Vec::new(),
944                smooth_lambda: time_smooth_lambda,
945            })
946        }
947        "linear" | "bspline" => {
948            // Forward to the shared structural-basis check so error text
949            // stays consistent with every other call site. `linear` /
950            // `bspline` are not structural, so this always returns Err;
951            // we map a (currently impossible) `Ok` to an explicit error
952            // string instead of `unreachable!`, keeping the match total
953            // without relying on a never-executes claim.
954            match require_structural_survival_time_basis(time_basis, "survival model configuration")
955            {
956                Err(e) => Err(e),
957                Ok(()) => Err(format!(
958                    "internal: structural-basis check accepted non-structural \
959                     survival time basis '{time_basis}'"
960                )),
961            }
962        }
963        other => Err(format!(
964            "unsupported --time-basis '{other}'; accepted values: ispline, none"
965        )),
966    }
967}
968
969// ---------------------------------------------------------------------------
970// Time basis construction
971// ---------------------------------------------------------------------------
972
973pub fn build_survival_time_basis(
974    age_entry: &Array1<f64>,
975    age_exit: &Array1<f64>,
976    cfg: SurvivalTimeBasisConfig,
977    infer_knots_if_needed: Option<(usize, f64)>,
978) -> Result<SurvivalTimeBuildOutput, String> {
979    fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
980        if let Some(row) = times.iter().position(|t| !t.is_finite()) {
981            return Err(SurvivalConstructionError::DataValidationFailed {
982                reason: format!(
983                    "survival time basis requires finite {label} times (row {})",
984                    row + 1
985                ),
986            }
987            .into());
988        }
989        if let Some(row) = times.iter().position(|t| *t < 0.0) {
990            return Err(SurvivalConstructionError::DataValidationFailed {
991                reason: format!(
992                    "survival time basis requires non-negative {label} times (row {})",
993                    row + 1
994                ),
995            }
996            .into());
997        }
998        Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
999    }
1000
1001    let n = age_entry.len();
1002    if n != age_exit.len() {
1003        return Err(SurvivalConstructionError::IncompatibleDimensions {
1004            reason: "survival time basis requires matching entry/exit lengths".to_string(),
1005        }
1006        .into());
1007    }
1008    for i in 0..n {
1009        if age_exit[i] < age_entry[i] {
1010            return Err(format!(
1011                "survival time basis requires exit times >= entry times (row {})",
1012                i + 1
1013            ));
1014        }
1015    }
1016    let log_entry = checked_log_survival_times(age_entry, "entry")?;
1017    let log_exit = checked_log_survival_times(age_exit, "exit")?;
1018
1019    fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1020        let n = log_entry.len();
1021        let entry_range = log_entry
1022            .iter()
1023            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1024                (lo.min(v), hi.max(v))
1025            });
1026        let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1027        if entry_degenerate {
1028            log_exit.clone()
1029        } else {
1030            let mut combined = Array1::<f64>::zeros(2 * n);
1031            for i in 0..n {
1032                combined[i] = log_entry[i];
1033                combined[n + i] = log_exit[i];
1034            }
1035            combined
1036        }
1037    }
1038
1039    /// Cap the requested monotone-baseline internal-knot count to what the
1040    /// observed time resolution can actually support.
1041    ///
1042    /// The survival location-scale baseline is a degree-`d` I-spline with
1043    /// `num_internal_knots + d` shape-varying columns. Its smoothing parameter
1044    /// is informed *only* by the distinct interior log-time points: with fewer
1045    /// distinct interior times than requested knots the baseline is
1046    /// rank-deficient, and the REML/LAML profile in the time smoothing
1047    /// parameter becomes a flat ridge — the exact-joint outer search then
1048    /// probes that ridge indefinitely (each inner constrained Newton burns its
1049    /// whole cycle budget without certifying convergence) and the fit never
1050    /// terminates. This is the survival analogue of the standard
1051    /// "df must not exceed the data resolution" guard (`mgcv` caps `k` at the
1052    /// number of unique covariate values; `flexsurv`/`rstpm2` use a handful of
1053    /// baseline knots): we never place more interior knots than there are
1054    /// distinct interior points, and we keep the total baseline dimension a
1055    /// bounded fraction of the sample so the smoothing profile stays curved.
1056    ///
1057    /// This clamp lives in the shared knot-inference routine so the fit and any
1058    /// independent rebuild of the time basis (e.g. a predictor reconstructing
1059    /// `design · β` at fresh covariates) resolve to the *same* knot vector from
1060    /// the same data — there is no raw/active dimension drift.
1061    fn data_capped_internal_knots(
1062        combined: &Array1<f64>,
1063        degree: usize,
1064        requested_internal_knots: usize,
1065    ) -> usize {
1066        if requested_internal_knots == 0 {
1067            return 0;
1068        }
1069        let mut sorted: Vec<f64> = combined.iter().copied().collect();
1070        sorted.sort_by(f64::total_cmp);
1071        let minval = sorted.first().copied().unwrap_or(0.0);
1072        let maxval = sorted.last().copied().unwrap_or(minval);
1073        if minval == maxval {
1074            // Degenerate (single distinct time): no interior structure to fit.
1075            return 1.min(requested_internal_knots);
1076        }
1077        let scale = (maxval - minval).abs().max(1.0);
1078        let tol = 1e-12 * scale;
1079        // Count distinct strictly-interior points (knots can only live strictly
1080        // between the data extremes).
1081        let mut distinct_interior = 0usize;
1082        let mut last: Option<f64> = None;
1083        for &x in &sorted {
1084            if x <= minval + tol || x >= maxval - tol {
1085                continue;
1086            }
1087            if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1088                continue;
1089            }
1090            distinct_interior += 1;
1091            last = Some(x);
1092        }
1093        // Distinct-point ceiling: cannot place more interior knots than there
1094        // are distinct interior values.
1095        let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1096        // Dimension-vs-resolution ceiling: keep the total baseline column count
1097        // `cap + degree` below ~1/4 of the distinct sample points so the
1098        // smoothing-parameter profile retains curvature (the data must be able
1099        // to identify the baseline shape, not just interpolate it). `n_distinct`
1100        // counts all distinct points (interior + the two extremes).
1101        let n_distinct = {
1102            let mut count = 0usize;
1103            let mut last: Option<f64> = None;
1104            for &x in &sorted {
1105                if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1106                    continue;
1107                }
1108                count += 1;
1109                last = Some(x);
1110            }
1111            count
1112        };
1113        let dim_budget = n_distinct / 4;
1114        let dim_cap = dim_budget.saturating_sub(degree);
1115        cap = cap.min(dim_cap.max(1));
1116        cap.max(1)
1117    }
1118
1119    fn infer_survival_time_knots(
1120        combined: &Array1<f64>,
1121        knot_degree: usize,
1122        validation_degree: usize,
1123        num_internal_knots: usize,
1124        basis_options: BasisOptions,
1125    ) -> Result<Array1<f64>, String> {
1126        // Identifiability/termination guard: never request more baseline
1127        // internal knots than the observed time resolution supports. See
1128        // `data_capped_internal_knots` for the full rationale (a flat smoothing
1129        // ridge on an over-parameterized baseline is what makes the survival
1130        // location-scale exact-joint outer search fail to terminate).
1131        let num_internal_knots =
1132            data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1133
1134        fn quantile_knot_inference_needs_uniform_fallback(
1135            combined: &Array1<f64>,
1136            num_internal_knots: usize,
1137        ) -> bool {
1138            if num_internal_knots == 0 || combined.is_empty() {
1139                return false;
1140            }
1141
1142            let mut sorted: Vec<f64> = combined.iter().copied().collect();
1143            sorted.sort_by(f64::total_cmp);
1144            let minval = sorted[0];
1145            let maxval = *sorted.last().unwrap_or(&minval);
1146            if minval == maxval {
1147                return false;
1148            }
1149
1150            let scale = (maxval - minval).abs().max(1.0);
1151            let tol = 1e-12 * scale;
1152            let mut support = Vec::with_capacity(sorted.len());
1153            let mut last: Option<f64> = None;
1154            for &x in &sorted {
1155                if x <= minval + tol || x >= maxval - tol {
1156                    continue;
1157                }
1158                if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1159                    continue;
1160                }
1161                support.push(x);
1162                last = Some(x);
1163            }
1164            if support.is_empty() {
1165                return true;
1166            }
1167
1168            let n = support.len();
1169            let mut prev_q = minval;
1170            for j in 1..=num_internal_knots {
1171                let p = j as f64 / (num_internal_knots + 1) as f64;
1172                let pos = p * (n.saturating_sub(1) as f64);
1173                let lo = pos.floor() as usize;
1174                let hi = pos.ceil() as usize;
1175                let frac = pos - lo as f64;
1176                let q = if lo == hi {
1177                    support[lo]
1178                } else {
1179                    support[lo] * (1.0 - frac) + support[hi] * frac
1180                }
1181                .clamp(minval, maxval);
1182                if q <= prev_q + tol || q >= maxval - tol {
1183                    return true;
1184                }
1185                prev_q = q;
1186            }
1187
1188            false
1189        }
1190
1191        let inferwith =
1192            |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1193                let built = build_bspline_basis_1d(
1194                    combined.view(),
1195                    &BSplineBasisSpec {
1196                        degree: knot_degree,
1197                        penalty_order: 2,
1198                        knotspec: BSplineKnotSpec::Automatic {
1199                            num_internal_knots: Some(num_internal_knots),
1200                            placement,
1201                        },
1202                        double_penalty: false,
1203                        identifiability: BSplineIdentifiability::None,
1204                        boundary: OneDimensionalBoundary::Open,
1205                        boundary_conditions: BSplineBoundaryConditions::default(),
1206                    },
1207                )
1208                .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1209                let knots = match built.metadata {
1210                    BasisMetadata::BSpline1D { knots, .. } => knots,
1211                    _ => {
1212                        return Err(
1213                            "internal error: expected BSpline1D metadata for survival time basis"
1214                                .to_string(),
1215                        );
1216                    }
1217                };
1218                // `knot_degree` is the clamped B-spline degree used to size
1219                // the knot vector. `validation_degree` is the public basis
1220                // degree passed to the final evaluator. They differ for
1221                // I-splines because `create_basis(..., BasisOptions::i_spline())`
1222                // internally raises the public degree by one to its working
1223                // B-spline antiderivative degree. Validating with
1224                // `knot_degree` here would raise a second time and reject the
1225                // coherent knot vector we just inferred.
1226                create_basis::<Dense>(
1227                    combined.view(),
1228                    KnotSource::Provided(knots.view()),
1229                    validation_degree,
1230                    basis_options,
1231                )
1232                .map_err(|e| e.to_string())?;
1233                Ok(knots)
1234            };
1235
1236        if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1237            inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1238        } else {
1239            inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1240        }
1241    }
1242
1243    match cfg {
1244        SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1245            x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1246            x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1247            x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1248            penalties: Vec::new(),
1249            nullspace_dims: Vec::new(),
1250            basisname: "none".to_string(),
1251            degree: None,
1252            knots: None,
1253            keep_cols: None,
1254            smooth_lambda: None,
1255        }),
1256        SurvivalTimeBasisConfig::Linear => {
1257            let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1258            let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1259            let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1260            for i in 0..n {
1261                x_entry_time[[i, 0]] = 1.0;
1262                x_exit_time[[i, 0]] = 1.0;
1263                x_entry_time[[i, 1]] = log_entry[i];
1264                x_exit_time[[i, 1]] = log_exit[i];
1265                x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1266            }
1267            Ok(SurvivalTimeBuildOutput {
1268                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1269                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1270                x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1271                penalties: Vec::new(),
1272                nullspace_dims: Vec::new(),
1273                basisname: "linear".to_string(),
1274                degree: None,
1275                knots: None,
1276                keep_cols: None,
1277                smooth_lambda: None,
1278            })
1279        }
1280        SurvivalTimeBasisConfig::BSpline {
1281            degree,
1282            knots,
1283            smooth_lambda,
1284        } => {
1285            let knotvec = if knots.is_empty() {
1286                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1287                    "internal error: bspline time basis requested without knot source".to_string()
1288                })?;
1289                let combined = survival_time_knot_input(&log_entry, &log_exit);
1290                infer_survival_time_knots(
1291                    &combined,
1292                    degree,
1293                    degree,
1294                    num_internal_knots,
1295                    BasisOptions::value(),
1296                )?
1297            } else {
1298                knots
1299            };
1300
1301            let entry_basis = build_bspline_basis_1d(
1302                log_entry.view(),
1303                &BSplineBasisSpec {
1304                    degree,
1305                    penalty_order: 2,
1306                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1307                    double_penalty: false,
1308                    identifiability: BSplineIdentifiability::None,
1309                    boundary: OneDimensionalBoundary::Open,
1310                    boundary_conditions: BSplineBoundaryConditions::default(),
1311                },
1312            )
1313            .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1314            let exit_basis = build_bspline_basis_1d(
1315                log_exit.view(),
1316                &BSplineBasisSpec {
1317                    degree,
1318                    penalty_order: 2,
1319                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1320                    double_penalty: false,
1321                    identifiability: BSplineIdentifiability::None,
1322                    boundary: OneDimensionalBoundary::Open,
1323                    boundary_conditions: BSplineBoundaryConditions::default(),
1324                },
1325            )
1326            .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1327
1328            let p_time = exit_basis.design.ncols();
1329            // Build derivative basis as sparse triplets — B-spline derivatives
1330            // have the same local support as the basis itself (at most degree+1
1331            // nonzeros per row), so building dense first wastes memory.
1332            let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1333            let mut deriv_buf = vec![0.0_f64; p_time];
1334            for i in 0..n {
1335                deriv_buf.fill(0.0);
1336                evaluate_bspline_derivative_scalar(
1337                    log_exit[i],
1338                    knotvec.view(),
1339                    degree,
1340                    &mut deriv_buf,
1341                )
1342                .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1343                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1344                for j in 0..p_time {
1345                    let v = deriv_buf[j] * chain;
1346                    if v.abs() > 1e-15 {
1347                        deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1348                    }
1349                }
1350            }
1351            let x_derivative_time =
1352                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1353                {
1354                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1355                    Err(_) => {
1356                        // Fallback: build dense
1357                        let mut dense = Array2::<f64>::zeros((n, p_time));
1358                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1359                            dense[[row, col]] = val;
1360                        }
1361                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1362                    }
1363                };
1364
1365            Ok(SurvivalTimeBuildOutput {
1366                x_entry_time: entry_basis.design,
1367                x_exit_time: exit_basis.design,
1368                x_derivative_time,
1369                nullspace_dims: entry_basis.nullspace_dims,
1370                penalties: entry_basis.penalties,
1371                basisname: "bspline".to_string(),
1372                degree: Some(degree),
1373                knots: Some(knotvec.to_vec()),
1374                keep_cols: None,
1375                smooth_lambda: Some(smooth_lambda),
1376            })
1377        }
1378        SurvivalTimeBasisConfig::ISpline {
1379            degree,
1380            knots,
1381            keep_cols,
1382            smooth_lambda,
1383        } => {
1384            let bspline_degree = degree
1385                .checked_add(1)
1386                .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1387            let knotvec = if knots.is_empty() {
1388                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1389                    "internal error: ispline time basis requested without knot source".to_string()
1390                })?;
1391                let combined = survival_time_knot_input(&log_entry, &log_exit);
1392                infer_survival_time_knots(
1393                    &combined,
1394                    bspline_degree,
1395                    degree,
1396                    num_internal_knots,
1397                    BasisOptions::i_spline(),
1398                )?
1399            } else {
1400                knots
1401            };
1402
1403            let (db_exit_arc, _) = create_basis::<Dense>(
1404                log_exit.view(),
1405                KnotSource::Provided(knotvec.view()),
1406                bspline_degree,
1407                BasisOptions::first_derivative(),
1408            )
1409            .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1410
1411            // Build full-width I-spline bases inside a block scope so the
1412            // large Arc allocations are freed when the block ends.
1413            let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1414                let (entry_arc, _) = create_basis::<Dense>(
1415                    log_entry.view(),
1416                    KnotSource::Provided(knotvec.view()),
1417                    degree,
1418                    BasisOptions::i_spline(),
1419                )
1420                .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1421                let (exit_arc, _) = create_basis::<Dense>(
1422                    log_exit.view(),
1423                    KnotSource::Provided(knotvec.view()),
1424                    degree,
1425                    BasisOptions::i_spline(),
1426                )
1427                .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1428
1429                let x_entry_full = entry_arc.as_ref();
1430                let x_exit_full = exit_arc.as_ref();
1431                let p_time_full = x_exit_full.ncols();
1432                if p_time_full == 0 {
1433                    return Err(SurvivalConstructionError::BasisConstructionFailed {
1434                        reason: "internal error: empty ispline time basis".to_string(),
1435                    }
1436                    .into());
1437                }
1438                let db_exit = db_exit_arc.as_ref();
1439                if db_exit.ncols() != p_time_full + 1 {
1440                    return Err(
1441                        "internal error: ispline derivative basis width must exceed basis width by one"
1442                            .to_string(),
1443                    );
1444                }
1445
1446                let keep_cols = if keep_cols.is_empty() {
1447                    let constant_tol = 1e-12_f64;
1448                    let mut inferred_keep_cols: Vec<usize> = Vec::new();
1449                    for j in 0..p_time_full {
1450                        let mut minv = f64::INFINITY;
1451                        let mut maxv = f64::NEG_INFINITY;
1452                        for i in 0..n {
1453                            let ve = x_exit_full[[i, j]];
1454                            let vs = x_entry_full[[i, j]];
1455                            minv = minv.min(ve.min(vs));
1456                            maxv = maxv.max(ve.max(vs));
1457                        }
1458                        if (maxv - minv) > constant_tol {
1459                            inferred_keep_cols.push(j);
1460                        }
1461                    }
1462                    inferred_keep_cols
1463                } else {
1464                    keep_cols
1465                };
1466                if keep_cols.is_empty() {
1467                    return Err(
1468                        "internal error: ispline basis has no shape-varying time columns"
1469                            .to_string(),
1470                    );
1471                }
1472                if keep_cols.iter().any(|&j| j >= p_time_full) {
1473                    return Err(SurvivalConstructionError::MissingColumn {
1474                        reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1475                    }
1476                    .into());
1477                }
1478
1479                let p_time = keep_cols.len();
1480                let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1481                let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1482                // entry_arc and exit_arc go out of scope here, freeing the
1483                // full-width bases before derivative computation below.
1484                (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1485            };
1486            let db_exit = db_exit_arc.as_ref();
1487
1488            // Build I-spline derivative as sparse triplets.  The derivative
1489            // is a cumulative sum of B-spline derivatives and typically has
1490            // more nonzeros per row than a plain B-spline, but still much
1491            // fewer than p_time for modest bases.
1492            let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1493            let mut found_nonfinite: Option<(usize, usize)> = None;
1494            for i in 0..n {
1495                let mut running = 0.0_f64;
1496                let mut d_i_log_full = vec![0.0_f64; p_time_full];
1497                for j in (1..db_exit.ncols()).rev() {
1498                    let term = db_exit[[i, j]];
1499                    if term.is_finite() {
1500                        running += term;
1501                    }
1502                    d_i_log_full[j - 1] = running;
1503                }
1504                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1505                for (j_new, &j_old) in keep_cols.iter().enumerate() {
1506                    let raw_v = d_i_log_full[j_old] * chain;
1507                    let v = if (-1e-12..0.0).contains(&raw_v) {
1508                        0.0
1509                    } else {
1510                        raw_v
1511                    };
1512                    if !v.is_finite() {
1513                        found_nonfinite = Some((i, j_new));
1514                    }
1515                    if v < -1e-12 {
1516                        return Err(format!(
1517                            "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1518                            i + 1,
1519                            j_new + 1,
1520                            v
1521                        ));
1522                    }
1523                    if v.abs() > 1e-15 {
1524                        deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1525                    }
1526                }
1527            }
1528            if let Some((row, col)) = found_nonfinite {
1529                return Err(format!(
1530                    "survival ispline derivative basis produced non-finite value at row {}, column {}",
1531                    row + 1,
1532                    col + 1
1533                ));
1534            }
1535            let x_derivative_time =
1536                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1537                {
1538                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1539                    Err(_) => {
1540                        let mut dense = Array2::<f64>::zeros((n, p_time));
1541                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1542                            dense[[row, col]] = val;
1543                        }
1544                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1545                    }
1546                };
1547
1548            let penalty_basis = build_bspline_basis_1d(
1549                log_exit.view(),
1550                &BSplineBasisSpec {
1551                    degree: bspline_degree,
1552                    penalty_order: 2,
1553                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1554                    double_penalty: false,
1555                    identifiability: BSplineIdentifiability::None,
1556                    boundary: OneDimensionalBoundary::Open,
1557                    boundary_conditions: BSplineBoundaryConditions::default(),
1558                },
1559            )
1560            .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1561            if penalty_basis.design.ncols() != p_time_full + 1 {
1562                return Err("internal error: ispline penalty dimension mismatch".to_string());
1563            }
1564            // I-spline curvature penalty in the *value* space of the baseline
1565            // log-cumulative-hazard, restricted to the retained (non-dropped)
1566            // coefficient block.
1567            //
1568            // The I-spline coefficient γ is the consecutive increment of the B-spline
1569            // value coefficients `c`: `c_0 = 0`, `c_k = Σ_{j<k} γ_j = (L γ)_k`, where
1570            // `L` is the `p_time × p_time` lower-triangular cumsum matrix. The
1571            // second-difference penalty on the B-spline values is `S_B = D₂ᵀD₂`
1572            // (the `penalty_basis.penalties` block). The correct curvature penalty
1573            // on γ is the **value-space congruence transform**
1574            //
1575            //   `S_I = Lᵀ S_B[1:,1:] L`,
1576            //
1577            // which satisfies `γᵀ S_I γ = (Lγ)ᵀ S_B[1:,1:] (Lγ)`.
1578            //
1579            // A constant γ (γ_k = γ₀ ∀k) maps to the linear value sequence
1580            // `c_k = k·γ₀`, which is annihilated by D₂: `D₂c = 0`. Therefore
1581            // `γᵀ S_I γ = 0` for constant γ, i.e. the **affine trend lies in the
1582            // penalty null space**. REML does not penalize the baseline slope
1583            // `d(log Λ)/d(log t)` or the overall level, so it correctly lets the
1584            // data determine these quantities without bias. The previous increment-
1585            // space form `S_B[1:,1:]` (applied directly to γ instead of Lγ) did NOT
1586            // have constant γ in its null space and therefore over-penalized affine
1587            // baselines, causing the fitted log-cumulative-hazard to lose its tail
1588            // slope to the penalty and fail quality tests (#1076).
1589            //
1590            // The value-space form has a 1-dimensional null space (span{(1,…,1)}),
1591            // declared via `nullspace_dims` so the REML generalized-logdet picks it
1592            // up. The penalized inner PIRLS is well-conditioned because the
1593            // likelihood Hessian H_lik has O(n_events) curvature along the affine
1594            // direction (the overall baseline level is identified by the data), and
1595            // the global stabilization ridge (ridge_lambda) provides an absolute
1596            // positive-definite floor.
1597            let mut penalties = Vec::<Array2<f64>>::new();
1598            for s_mat in &penalty_basis.penalties {
1599                if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1600                    continue;
1601                }
1602                // I-spline value-space penalty, computed in the CORRECT order
1603                // (gam#979). The B-spline value coefficients are the cumulative
1604                // sum of the I-spline increment coefficients, `c = L γ_full`, where
1605                // `L` is the FULL `p_time_full × p_time_full` LOWER-triangular
1606                // all-ones cumsum matrix (`L[i,j] = 1 iff j ≤ i`, so
1607                // `c_i = Σ_{j≤i} γ_j`). The value-space curvature penalty on the
1608                // full increment vector is the symmetric congruence
1609                //
1610                //   `S_I_full = Lᵀ · S_B[1:,1:] · L`,
1611                //
1612                // which is PSD because `S_B[1:,1:]` is a principal submatrix of the
1613                // PSD `S_B = D₂ᵀD₂` and congruence by any matrix preserves PSD.
1614                //
1615                // CRITICAL ORDERING (the gam#979 indefiniteness bug): the retained
1616                // columns `keep_cols` must be selected as a PRINCIPAL SUBMATRIX of
1617                // the FULL congruence `S_I_full` — i.e. congruence FIRST, selection
1618                // SECOND. The previous code selected `keep_cols` from `S_B[1:,1:]`
1619                // first and then applied a `p_time × p_time` cumsum to that
1620                // already-reduced block. Because the cumsum `L` couples every
1621                // increment, restricting the increment index set BEFORE the cumsum
1622                // does NOT commute with it: the reduced operator is a different,
1623                // generally INDEFINITE matrix (measured `s0_min_eval = −9.8e7`),
1624                // which makes `½γᵀS_Iγ` unbounded below and the penalized survival
1625                // NLL diverge (β drifts up the negative-eigenvalue mode, the inner
1626                // joint-Newton follows the unbounded objective, the outer REML never
1627                // terminates — the #979 hang). Doing the congruence on the full γ
1628                // and then taking the `keep_cols` principal submatrix restores the
1629                // PSD guarantee (a principal submatrix of a PSD matrix is PSD).
1630                let s_increment = s_mat.slice(s![1.., 1..]);
1631                if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1632                    return Err(format!(
1633                        "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1634                        s_increment.nrows(),
1635                        s_increment.ncols(),
1636                    ));
1637                }
1638                // Symmetrize the (already-symmetric) source with the shared
1639                // matrix utility. The survival builder's value-space
1640                // congruence is domain-specific; only the low-level symmetric
1641                // cleanup is common with the generic and SAE construction code.
1642                let mut s_full = s_increment.to_owned();
1643                symmetrize_in_place(&mut s_full);
1644                // S_mid = S_B[1:,1:] · L  (right-multiply by lower-triangular
1645                // cumsum): (S·L)[i,j] = Σ_k S[i,k]·L[k,j] = Σ_{k≥j} S[i,k]
1646                // because L[k,j] = 1 iff j ≤ k.
1647                let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1648                for i in 0..p_time_full {
1649                    for j in 0..p_time_full {
1650                        let mut v = 0.0;
1651                        for k in j..p_time_full {
1652                            v += s_full[[i, k]];
1653                        }
1654                        s_mid_full[[i, j]] = v;
1655                    }
1656                }
1657                // S_I_full = Lᵀ · S_mid = Lᵀ · S · L:
1658                // (Lᵀ·S_mid)[i,j] = Σ_k Lᵀ[i,k]·S_mid[k,j] = Σ_{k≥i} S_mid[k,j]
1659                // because Lᵀ[i,k] = L[k,i] = 1 iff i ≤ k.
1660                let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1661                for i in 0..p_time_full {
1662                    for j in 0..p_time_full {
1663                        let mut v = 0.0;
1664                        for k in i..p_time_full {
1665                            v += s_mid_full[[k, j]];
1666                        }
1667                        s_full_congruent[[i, j]] = v;
1668                    }
1669                }
1670                // Principal submatrix on the retained (shape-varying) columns.
1671                let mut local = Array2::<f64>::zeros((p_time, p_time));
1672                for (i_new, &i_old) in keep_cols.iter().enumerate() {
1673                    for (j_new, &j_old) in keep_cols.iter().enumerate() {
1674                        // Symmetrize on the way out to absorb residual
1675                        // floating-point asymmetry.
1676                        local[[i_new, j_new]] = 0.5
1677                            * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1678                    }
1679                }
1680                penalties.push(local);
1681            }
1682
1683            // PSD contract (gam#979). The value-space congruence Lᵀ S_B[1:,1:] L,
1684            // restricted to a principal submatrix, is positive semidefinite by
1685            // construction. A negative eigenvalue here means the construction has
1686            // regressed to the increment-space / wrong-ordering form that made the
1687            // penalized survival NLL unbounded below (the #979 divergence). Verify
1688            // it here, at construction, so the defect can never silently reach the
1689            // inner solver again. The tolerance is the same relative scale the
1690            // nullspace detection below uses; a numerically tiny negative (round-off
1691            // on the genuine 1-D null direction) is allowed, a structural one is not.
1692            for (idx, s_mat) in penalties.iter().enumerate() {
1693                let p = s_mat.nrows();
1694                if p == 0 {
1695                    continue;
1696                }
1697                if let Ok((evals, _)) =
1698                    gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1699                {
1700                    let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1701                        "internal error: ispline penalty eigenvalues not contiguous".to_string()
1702                    })?;
1703                    let max_ev = evals_slice
1704                        .iter()
1705                        .copied()
1706                        .fold(0.0_f64, |a, b| a.max(b.abs()))
1707                        .max(1.0);
1708                    let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1709                    let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1710                    if min_ev < neg_tol {
1711                        return Err(format!(
1712                            "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1713                             indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1714                             {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1715                        ));
1716                    }
1717                }
1718            }
1719
1720            // The value-space penalty S_I = L^T S_B[1:,1:] L has a 1-dimensional
1721            // null space (constant γ ↦ affine c ↦ D₂c = 0). Detect it spectrally
1722            // so the REML uses the generalized logdet over the penalized subspace.
1723            let nullspace_dims: Vec<usize> = penalties
1724                .iter()
1725                .map(|s_mat| {
1726                    let p = s_mat.nrows();
1727                    if p == 0 {
1728                        return 0;
1729                    }
1730                    match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1731                        Ok((evals, _)) => {
1732                            let evals_slice: &[f64] = evals.as_slice().unwrap();
1733                            let max_ev = evals_slice
1734                                .iter()
1735                                .copied()
1736                                .fold(0.0_f64, |a, b| a.max(b.abs()))
1737                                .max(1.0);
1738                            let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1739                            evals_slice.iter().filter(|&&e| e <= threshold).count()
1740                        }
1741                        Err(_) => 0,
1742                    }
1743                })
1744                .collect();
1745            Ok(SurvivalTimeBuildOutput {
1746                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1747                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1748                x_derivative_time,
1749                penalties,
1750                nullspace_dims,
1751                basisname: "ispline".to_string(),
1752                degree: Some(degree),
1753                knots: Some(knotvec.to_vec()),
1754                keep_cols: Some(keep_cols),
1755                smooth_lambda: Some(smooth_lambda),
1756            })
1757        }
1758    }
1759}
1760
1761pub fn resolved_survival_time_basis_config_from_build(
1762    basisname: &str,
1763    degree: Option<usize>,
1764    knots: Option<&Vec<f64>>,
1765    keep_cols: Option<&Vec<usize>>,
1766    smooth_lambda: Option<f64>,
1767) -> Result<SurvivalTimeBasisConfig, String> {
1768    match basisname {
1769        "none" => Ok(SurvivalTimeBasisConfig::None),
1770        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1771        "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1772            degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1773            knots: Array1::from_vec(
1774                knots
1775                    .cloned()
1776                    .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1777            ),
1778            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1779        }),
1780        "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1781            degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1782            knots: Array1::from_vec(
1783                knots
1784                    .cloned()
1785                    .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1786            ),
1787            keep_cols: keep_cols
1788                .cloned()
1789                .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1790            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1791        }),
1792        other => Err(format!("unsupported survival time basis '{other}'")),
1793    }
1794}
1795
1796pub fn resolve_survival_time_anchor_value(
1797    age_entry: &Array1<f64>,
1798    time_anchor: Option<f64>,
1799) -> Result<f64, String> {
1800    if age_entry.is_empty() {
1801        return Err("survival time anchor requires non-empty entry times".to_string());
1802    }
1803    let anchor = match time_anchor {
1804        Some(t_anchor) => {
1805            if !t_anchor.is_finite() || t_anchor < 0.0 {
1806                return Err(format!(
1807                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1808                ));
1809            }
1810            t_anchor
1811        }
1812        None => age_entry
1813            .iter()
1814            .copied()
1815            .min_by(f64::total_cmp)
1816            .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1817    };
1818    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1819}
1820
1821/// Marginal-slope centering anchor: a robust *interior* time on the **exit**
1822/// scale rather than the earliest entry age.
1823///
1824/// `center_survival_time_designs_at_anchor` subtracts the time-basis row at the
1825/// anchor from every entry/exit design row, so the anchor sets the origin of
1826/// the baseline-hazard I-spline's affine reparameterization. The
1827/// location-scale path anchors at the minimum entry age
1828/// ([`resolve_survival_time_anchor_value`]); for right-censored-only data that
1829/// minimum is ≈ the time origin, so centering is nearly a no-op.
1830///
1831/// Under **left truncation** the minimum entry age is a genuine positive
1832/// *left-tail* point, and centering there leaves the centered linear-trend
1833/// column `X(exit) − X(anchor)` large and one-signed across all rows (exit
1834/// times sit far to the right of the earliest entry). That column is the
1835/// unpenalized polynomial null space of the 2nd-difference time penalty, so the
1836/// inflated, one-signed column multiplies the marginal-slope time-block score
1837/// at the `γ = 0` monotone-cone seed up by hundreds — the constrained joint
1838/// Newton cannot certify KKT on it and REML rejects every seed (issue #751).
1839///
1840/// Centering instead at a robust interior location on the *exit* scale — the
1841/// **median exit age**, where the at-risk mass concentrates — keeps the
1842/// centered column small and two-signed (some exits below the median, some
1843/// above), so the exit-event likelihood pins the linear trend and the seed
1844/// score stays bounded. Re-centering is an exact affine reparameterization of
1845/// the baseline offset: the fitted `q(t)` and the REML objective are unchanged,
1846/// only the seed conditioning improves. The median is chosen (over the mean)
1847/// for robustness to the heavy right tail of survival times.
1848///
1849/// An explicit `--survival-time-anchor` is honored verbatim (same validation as
1850/// the location-scale path) so the user retains full control; the saved
1851/// `survival_time_anchor` scalar round-trips to predict unchanged.
1852pub fn resolve_survival_marginal_slope_time_anchor_value(
1853    age_entry: &Array1<f64>,
1854    age_exit: &Array1<f64>,
1855    time_anchor: Option<f64>,
1856) -> Result<f64, String> {
1857    if age_entry.is_empty() || age_exit.is_empty() {
1858        return Err(
1859            "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1860        );
1861    }
1862    let anchor = match time_anchor {
1863        Some(t_anchor) => {
1864            if !t_anchor.is_finite() || t_anchor < 0.0 {
1865                return Err(format!(
1866                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1867                ));
1868            }
1869            t_anchor
1870        }
1871        None => robust_interior_exit_anchor(age_exit),
1872    };
1873    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1874}
1875
1876/// Median exit age — a robust interior time on the exit scale, where the
1877/// at-risk mass concentrates. Used as the survival time-basis centering anchor
1878/// whenever the earliest entry is a positive left-tail point (delayed entry):
1879/// centering there keeps the reparameterized linear-trend column small and
1880/// two-signed instead of large and one-signed. The median is chosen over the
1881/// mean for robustness to the heavy right tail of survival times.
1882fn robust_interior_exit_anchor(age_exit: &Array1<f64>) -> f64 {
1883    let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1884    sorted.sort_by(f64::total_cmp);
1885    let m = sorted.len();
1886    if m == 0 {
1887        return SURVIVAL_TIME_FLOOR;
1888    }
1889    if m % 2 == 1 {
1890        sorted[m / 2]
1891    } else {
1892        0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1893    }
1894}
1895
1896/// Centering anchor for the default transformation (Royston-Parmar) survival
1897/// baseline.
1898///
1899/// For right-censored-only data the earliest entry age is ≈ the time origin, so
1900/// [`resolve_survival_time_anchor_value`] (min entry) is nearly a no-op and is
1901/// used unchanged. Under **left truncation** (every row enters at a positive
1902/// delayed-entry time) that minimum is a genuine left-tail point far below the
1903/// exit mass, and centering the I-spline time basis there leaves the
1904/// unpenalized linear-trend column `X(exit) − X(anchor)` large and one-signed
1905/// across all rows. That column is the null space of the 2nd-difference time
1906/// penalty, so the inflated one-signed column blows up the transformation
1907/// smoothing-parameter selection: it rails a penalty direction and collapses the
1908/// baseline to a covariate-independent, cumulative-hazard-inflated degenerate
1909/// fit (issue #1790 — the transformation-model analogue of the marginal-slope
1910/// #751 defect). Anchoring instead at the robust interior **median exit age**
1911/// keeps the centered column small and two-signed so the exit-event likelihood
1912/// pins the linear trend. Re-centering is an exact affine reparameterization of
1913/// the baseline offset — the fitted `q(t)` and REML objective are unchanged,
1914/// only the seed conditioning improves. An explicit `time_anchor` is honored
1915/// verbatim.
1916pub fn resolve_survival_transformation_time_anchor_value(
1917    age_entry: &Array1<f64>,
1918    age_exit: &Array1<f64>,
1919    time_anchor: Option<f64>,
1920) -> Result<f64, String> {
1921    if time_anchor.is_some() {
1922        return resolve_survival_time_anchor_value(age_entry, time_anchor);
1923    }
1924    if age_exit.is_empty() {
1925        return Err(
1926            "survival transformation time anchor requires non-empty exit times".to_string(),
1927        );
1928    }
1929    let min_entry = age_entry
1930        .iter()
1931        .copied()
1932        .fold(f64::INFINITY, f64::min);
1933    if min_entry > SURVIVAL_DELAYED_ENTRY_THRESHOLD {
1934        Ok(robust_interior_exit_anchor(age_exit).max(SURVIVAL_TIME_FLOOR))
1935    } else {
1936        resolve_survival_time_anchor_value(age_entry, None)
1937    }
1938}
1939
1940pub fn evaluate_survival_time_basis_row(
1941    age: f64,
1942    cfg: &SurvivalTimeBasisConfig,
1943) -> Result<Array1<f64>, String> {
1944    if !age.is_finite() || age < 0.0 {
1945        return Err(format!(
1946            "survival time basis row requires finite non-negative age, got {age}"
1947        ));
1948    }
1949    let age = age.max(SURVIVAL_TIME_FLOOR);
1950    let log_age = array![age.ln()];
1951    match cfg {
1952        SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1953        SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1954        SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1955            if knots.is_empty() {
1956                return Err(
1957                    "survival BSpline anchor evaluation requires resolved knot metadata"
1958                        .to_string(),
1959                );
1960            }
1961            let built = build_bspline_basis_1d(
1962                log_age.view(),
1963                &BSplineBasisSpec {
1964                    degree: *degree,
1965                    penalty_order: 2,
1966                    knotspec: BSplineKnotSpec::Provided(knots.clone()),
1967                    double_penalty: false,
1968                    identifiability: BSplineIdentifiability::None,
1969                    boundary: OneDimensionalBoundary::Open,
1970                    boundary_conditions: BSplineBoundaryConditions::default(),
1971                },
1972            )
1973            .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1974            Ok(built.design.to_dense().row(0).to_owned())
1975        }
1976        SurvivalTimeBasisConfig::ISpline {
1977            degree,
1978            knots,
1979            keep_cols,
1980            ..
1981        } => {
1982            if knots.is_empty() {
1983                return Err(
1984                    "survival ISpline anchor evaluation requires resolved knot metadata"
1985                        .to_string(),
1986                );
1987            }
1988            let (basis_arc, _) = create_basis::<Dense>(
1989                log_age.view(),
1990                KnotSource::Provided(knots.view()),
1991                *degree,
1992                BasisOptions::i_spline(),
1993            )
1994            .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1995            let basis = basis_arc.as_ref();
1996            let row = basis.row(0);
1997            if keep_cols.is_empty() {
1998                return Ok(row.to_owned());
1999            }
2000            if keep_cols.iter().any(|&j| j >= row.len()) {
2001                return Err(SurvivalConstructionError::MissingColumn {
2002                    reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
2003                }
2004                .into());
2005            }
2006            Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
2007        }
2008    }
2009}
2010
2011pub fn center_survival_time_designs_at_anchor(
2012    design_entry: &mut DesignMatrix,
2013    design_exit: &mut DesignMatrix,
2014    anchor_row: &Array1<f64>,
2015) -> Result<(), String> {
2016    if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
2017        return Err(format!(
2018            "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
2019            design_entry.ncols(),
2020            design_exit.ncols(),
2021            anchor_row.len()
2022        ));
2023    }
2024    // Centering destroys sparsity (every row gets a dense offset), so
2025    // materialize to dense.  This only runs once at construction time.
2026    fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
2027        let mut dense = dm.to_dense();
2028        for mut row in dense.rows_mut() {
2029            row -= &anchor.view();
2030        }
2031        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
2032    }
2033    center_dense(design_entry, anchor_row);
2034    center_dense(design_exit, anchor_row);
2035    Ok(())
2036}
2037
2038// ---------------------------------------------------------------------------
2039// Baseline evaluation (Gompertz, Weibull, Gompertz-Makeham)
2040// ---------------------------------------------------------------------------
2041
2042/// Partial derivatives of the baseline offsets `(eta_target, d_eta_target/dt)`
2043/// with respect to the θ-parameters in the same parameterization that
2044/// [`survival_baseline_theta_from_config`] / [`survival_baseline_config_from_theta`]
2045/// use:
2046///
2047/// - **Weibull**: θ = (log_scale, log_shape).  `eta = shape·(log t − log scale)`,
2048///   `o_D = shape/t`.
2049/// - **Gompertz**: θ = (log_rate, shape).  `eta = log H_G(t)` with
2050///   `H_G(t) = (rate/shape)·(exp(shape·t) − 1)`, `o_D = h_G(t)/H_G(t) =
2051///   shape·E/(E−1)` where `E = exp(shape·t)`.
2052/// - **Gompertz–Makeham**: θ = (log_rate, shape, log_makeham).
2053///   `eta = log H(t)` with `H(t) = makeham·t + H_G(t)`,
2054///   `o_D = (makeham + h_G(t)) / H(t)`.
2055///
2056/// Returns a flat `(d_eta/dθ_k, d_oD/dθ_k)` pair for each component of θ,
2057/// in the same order as `survival_baseline_theta_from_config`.  Linear has
2058/// no θ-parameters so returns `Ok(None)`.
2059///
2060/// The `eta`-channel derivatives are closed-form for every branch.  The
2061/// `o_D`-channel derivatives use the log-derivative identity
2062/// `∂o_D/∂θ = o_D · ∂log(o_D)/∂θ` which is more numerically stable near
2063/// the small-shape limit (shape·t → 0).  Near shape = 0 we fall back to
2064/// a third-order Taylor expansion with the same 1e-10 pivot that
2065/// `evaluate_survival_baseline` uses, keeping the value/derivative pair
2066/// continuous and agreement with the linear-hazard limit exact at shape=0.
2067pub fn baseline_offset_theta_partials(
2068    age: f64,
2069    cfg: &SurvivalBaselineConfig,
2070) -> Result<Option<Vec<(f64, f64)>>, String> {
2071    let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2072    else {
2073        return Ok(None);
2074    };
2075
2076    match params {
2077        ValidatedBaselineTarget::Weibull { scale, shape } => {
2078            // eta = shape·(log t − log scale)
2079            //     = shape·log t − shape·log scale
2080            // o_D = shape / t
2081            //
2082            // θ = (log_scale, log_shape):
2083            //   ∂eta/∂log_scale  = −shape          ∂o_D/∂log_scale = 0
2084            //   ∂eta/∂log_shape  = shape·(log t − log scale) = eta
2085            //   ∂o_D/∂log_shape  = shape / t = o_D
2086            let eta = shape * (age.ln() - scale.ln());
2087            let o_d = shape / age;
2088            let d_eta_d_log_scale = -shape;
2089            let d_od_d_log_scale = 0.0;
2090            let d_eta_d_log_shape = eta;
2091            let d_od_d_log_shape = o_d;
2092            Ok(Some(vec![
2093                (d_eta_d_log_scale, d_od_d_log_scale),
2094                (d_eta_d_log_shape, d_od_d_log_shape),
2095            ]))
2096        }
2097        ValidatedBaselineTarget::Gompertz { shape, .. } => {
2098            // θ = (log_rate, shape):
2099            //   Rate cancels in o_D = h/H for Gompertz, so ∂o_D/∂log_rate = 0
2100            //   and ∂eta/∂log_rate = 1. The shape channel uses
2101            //     ∂eta/∂shape   = −1/shape + t·E/(E−1)
2102            //     ∂log(o_D)/∂shape = 1/shape − t/(E−1)
2103            //     ∂o_D/∂shape  = o_D · ∂log(o_D)/∂shape
2104            //   Near shape=0 both numerators are 1/shape cancellations. Use
2105            //   Taylor expansions with the same 1e-10 pivot that
2106            //   gompertz_components uses in evaluate_survival_baseline.
2107            let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2108            Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2109        }
2110        ValidatedBaselineTarget::GompertzMakeham {
2111            rate,
2112            shape,
2113            makeham,
2114        } => {
2115            // H(t) = M·t + H_G(t),   H_G(t) = (rate/shape)·(E−1),  E = exp(shape·t)
2116            // h(t) = M + h_G(t),     h_G(t) = rate·E
2117            // o_D  = h/H
2118            //
2119            // θ = (log_rate, shape, log_makeham):
2120            //   ∂H/∂log_rate    = rate · ∂H/∂rate = H_G               (scales with rate)
2121            //   ∂H/∂shape       = H_G_shape                            (closed form below)
2122            //   ∂H/∂log_makeham = makeham · t                          (linear in makeham)
2123            //   ∂h/∂log_rate    = rate · ∂h/∂rate = h_G
2124            //   ∂h/∂shape       = h_G_shape = rate·t·E + 0              (= rate·t·E)
2125            //   ∂h/∂log_makeham = makeham
2126            //   ∂eta/∂θ = (∂H/∂θ) / H
2127            //   ∂o_D/∂θ = (∂h/∂θ − o_D · ∂H/∂θ) / H
2128            //           = (∂h/∂θ)/H − o_D · (∂H/∂θ)/H
2129            let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2130            let cum_total = makeham * age + cum_g;
2131            if cum_total <= 0.0 || !cum_total.is_finite() {
2132                return Err(SurvivalConstructionError::DataValidationFailed {
2133                    reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2134                }
2135                .into());
2136            }
2137            let inst_total = makeham + inst_g;
2138            let o_d = inst_total / cum_total;
2139            let inv_cum = 1.0 / cum_total;
2140            // Each channel: ∂cum/∂θ and ∂inst/∂θ → ∂eta/∂θ = ∂cum/∂θ / cum
2141            //                                       ∂o_D/∂θ = (∂inst/∂θ − o_D·∂cum/∂θ) / cum
2142            // log_rate channel: cum is linear in rate through H_G; ∂cum/∂rate = H_G/rate,
2143            //   so ∂cum/∂log_rate = H_G (= cum_g here). Similarly ∂inst/∂log_rate = h_G (= inst_g).
2144            let d_cum_dlr = cum_g;
2145            let d_inst_dlr = inst_g;
2146            let d_eta_dlr = d_cum_dlr * inv_cum;
2147            let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2148            // shape channel: only H_G and h_G have shape dependence.
2149            let (d_cum_dshape, d_inst_dshape) =
2150                gompertz_cumulative_shape_derivative(age, rate, shape);
2151            let d_eta_dshape = d_cum_dshape * inv_cum;
2152            let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2153            // log_makeham channel: cum contributes M·t, inst contributes M.
2154            //   ∂cum/∂log_makeham = makeham·t,  ∂inst/∂log_makeham = makeham.
2155            let d_cum_dlm = makeham * age;
2156            let d_inst_dlm = makeham;
2157            let d_eta_dlm = d_cum_dlm * inv_cum;
2158            let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2159            Ok(Some(vec![
2160                (d_eta_dlr, d_od_dlr),
2161                (d_eta_dshape, d_od_dshape),
2162                (d_eta_dlm, d_od_dlm),
2163            ]))
2164        }
2165    }
2166}
2167
2168/// Shared chain-rule θ-gradient contraction for baseline offsets.
2169///
2170/// Both [`baseline_chain_rule_gradient`] (RP eta offsets) and
2171/// [`marginal_slope_baseline_chain_rule_gradient`] (probit q-offsets) reduce to
2172/// the same contraction of [`OffsetChannelResiduals`] against per-age baseline
2173/// θ-partials; only the `partials` provider differs. This engine owns the length
2174/// checks, the θ-dim probe, the parallel per-row reduction, the entry gating, and
2175/// the error handling. Each provider returns, per age, a length-`theta_dim` vector
2176/// of `(∂eta/∂θ_k, ∂(d eta/dt)/∂θ_k)` pairs (or `(∂q/∂θ_k, ∂(dq/dt)/∂θ_k)` for the
2177/// probit channel), and `None` when `cfg` has no θ-parameters (`Linear` target).
2178///
2179/// Contract (envelope theorem at converged β; the penalty has no θ dependence):
2180///
2181///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2182///     = Σᵢ r_X[i]·(∂o_X_i/∂θ_k) + r_D[i]·(∂o_D_i/∂θ_k) + r_E[i]·(∂o_E_i/∂θ_k)
2183///       + r_R[i]·(∂o_R_i/∂θ_k)
2184///
2185/// where `r_X = residuals.exit`, `r_D = residuals.derivative`, `r_E =
2186/// residuals.entry`, `r_R = residuals.right` (all sampleweight-scaled already).
2187/// Exit and derivative partials both come from the `age_exit[i]` evaluation;
2188/// the entry partial from `age_entry[i]`; the interval upper-bound (`R`)
2189/// η-partial from `age_right[i]`. Origin-entry rows have `r_E[i] == 0` exactly
2190/// and non-interval rows have `r_R[i] == 0` exactly, so those partials are
2191/// skipped for those rows (avoiding the `age > 0` precondition failure when an
2192/// inactive boundary age is 0 / a placeholder).
2193///
2194/// Returns `Ok(None)` when the provider reports no θ-parameters.
2195fn baseline_chain_rule_gradient_with_partials<F>(
2196    label: &'static str,
2197    age_entry: ndarray::ArrayView1<'_, f64>,
2198    age_exit: ndarray::ArrayView1<'_, f64>,
2199    age_right: ndarray::ArrayView1<'_, f64>,
2200    cfg: &SurvivalBaselineConfig,
2201    residuals: &crate::survival::OffsetChannelResiduals,
2202    partials: F,
2203) -> Result<Option<Array1<f64>>, String>
2204where
2205    F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2206{
2207    let n = age_exit.len();
2208    if age_entry.len() != n
2209        || age_right.len() != n
2210        || residuals.exit.len() != n
2211        || residuals.entry.len() != n
2212        || residuals.derivative.len() != n
2213        || residuals.right.len() != n
2214    {
2215        return Err(format!(
2216            "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2217            age_entry.len(),
2218            n,
2219            age_right.len(),
2220            residuals.exit.len(),
2221            residuals.entry.len(),
2222            residuals.derivative.len(),
2223            residuals.right.len(),
2224        ));
2225    }
2226    // Probe θ-dim via any valid positive age. If the provider returns None the
2227    // config carries no θ-parameters (Linear target) and there is no θ-gradient.
2228    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2229    let theta_dim = match probe_age {
2230        Some(t) => match partials(t, cfg)? {
2231            None => return Ok(None),
2232            Some(v) => v.len(),
2233        },
2234        None => {
2235            return Err(format!("{label}: no valid positive age for dim probe"));
2236        }
2237    };
2238    // Per-row partial contractions are independent, but each row's
2239    // contribution is a `theta_dim`-vector of `O(theta_dim · partial_cost)`
2240    // flops — small enough that the rayon parallel reduction's split
2241    // overhead dominates for any plausible `theta_dim`, *and* the
2242    // non-associative IEEE-754 sum order across thread chunks made the
2243    // engine drift in the low-order bits from row to row. The serial
2244    // accumulator below mirrors the inline reference exactly (and remains
2245    // ~memory-bandwidth-bound at large-scale `n`), so the engine is now a
2246    // bit-for-bit replacement for the legacy path, not just a
2247    // floating-point-noise-equivalent one.
2248    let mut grad = Array1::<f64>::zeros(theta_dim);
2249    for i in 0..n {
2250        // Exit + derivative partials both come from the age_exit evaluation.
2251        let partials_exit = partials(age_exit[i], cfg)?
2252            .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2253        if partials_exit.len() != theta_dim {
2254            return Err(format!(
2255                "{label}: theta_dim drifted ({} != {})",
2256                partials_exit.len(),
2257                theta_dim
2258            ));
2259        }
2260        let r_x = residuals.exit[i];
2261        let r_d = residuals.derivative[i];
2262        for k in 0..theta_dim {
2263            let (d_eta_dk, d_od_dk) = partials_exit[k];
2264            grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2265        }
2266        // Entry channel is nonzero only for rows with a positive entry
2267        // interval; for origin-entry rows age_entry may be 0 and calling
2268        // the provider would error. Gate on residual==0.
2269        let r_e = residuals.entry[i];
2270        if r_e != 0.0 {
2271            let partials_entry = partials(age_entry[i], cfg)?
2272                .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2273            for k in 0..theta_dim {
2274                grad[k] += r_e * partials_entry[k].0;
2275            }
2276        }
2277        // Interval upper-bound (`R`) channel: `q_right = X_time(R)·β + o_R(θ)`
2278        // carries its own baseline-θ η-offset evaluated at `age_right[i]`. It is
2279        // an η-level offset with NO time-derivative channel (the interval
2280        // likelihood `log[S(L) − S(R)]` has no hazard-derivative term), so it
2281        // contracts against the η-partial `.0` only. Nonzero only for
2282        // interval-censored latent rows; for every other channel/model
2283        // `r_right[i] == 0` exactly, so the (possibly placeholder) `age_right[i]`
2284        // partial is never consulted.
2285        let r_r = residuals.right[i];
2286        if r_r != 0.0 {
2287            let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2288                format!("{label}: unexpected None from partials at right boundary")
2289            })?;
2290            if partials_right.len() != theta_dim {
2291                return Err(format!(
2292                    "{label}: theta_dim drifted at right boundary ({} != {})",
2293                    partials_right.len(),
2294                    theta_dim
2295                ));
2296            }
2297            for k in 0..theta_dim {
2298                grad[k] += r_r * partials_right[k].0;
2299            }
2300        }
2301    }
2302    Ok(Some(grad))
2303}
2304
2305/// Contract `OffsetChannelResiduals` against `baseline_offset_theta_partials`
2306/// to produce the closed-form θ-gradient of the unpenalized NLL at converged β.
2307///
2308/// Derivation (envelope theorem on the penalized objective, β* minimizes the
2309/// same cost wrt β and the penalty has no θ dependence):
2310///
2311///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2312///     = d[NLL(β*; o(θ))] / dθ_k
2313///     = Σᵢ (∂NLL_i/∂o_X[i])·(∂o_X_i/∂θ_k)
2314///       + (∂NLL_i/∂o_E[i])·(∂o_E_i/∂θ_k)
2315///       + (∂NLL_i/∂o_D[i])·(∂o_D_i/∂θ_k)
2316///       + (∂NLL_i/∂o_R[i])·(∂o_R_i/∂θ_k)
2317///
2318/// The four `∂NLL_i/∂o_channel` terms are the `exit`, `entry`, `derivative`,
2319/// `right` fields of [`OffsetChannelResiduals`] (sampleweight-scaled already).
2320/// The `∂o/∂θ_k` terms come from [`baseline_offset_theta_partials`] per obs at
2321/// the appropriate age.
2322///
2323/// Per the RP offset convention:
2324///   o_E[i] = eta_target(age_entry[i])
2325///   o_X[i] = eta_target(age_exit[i])
2326///   o_D[i] = d/dt eta_target(t) |_{t=age_exit[i]}
2327///   o_R[i] = eta_target(age_right[i])   (interval upper bound `R`; η-level only)
2328///
2329/// so the exit and derivative partials are both evaluated at `age_exit[i]`,
2330/// the entry partial at `age_entry[i]`, and the interval-right η-partial at
2331/// `age_right[i]`. The origin-entry case (`entry_at_origin[i]`) has
2332/// `r_entry[i] = 0` exactly and every non-interval row has `r_right[i] = 0`
2333/// exactly, so we skip the `baseline_offset_theta_partials(age, ..)` call for
2334/// those rows (avoiding the `age > 0` precondition failure when an inactive
2335/// boundary age is 0 / a placeholder).
2336///
2337/// Returns `Ok(None)` when `cfg.target == Linear` (no θ-parameters).
2338pub fn baseline_chain_rule_gradient(
2339    age_entry: ndarray::ArrayView1<'_, f64>,
2340    age_exit: ndarray::ArrayView1<'_, f64>,
2341    age_right: ndarray::ArrayView1<'_, f64>,
2342    cfg: &SurvivalBaselineConfig,
2343    residuals: &crate::survival::OffsetChannelResiduals,
2344) -> Result<Option<Array1<f64>>, String> {
2345    baseline_chain_rule_gradient_with_partials(
2346        "baseline_chain_rule_gradient",
2347        age_entry,
2348        age_exit,
2349        age_right,
2350        cfg,
2351        residuals,
2352        baseline_offset_theta_partials,
2353    )
2354}
2355
2356/// Chain-rule θ-gradient for marginal-slope probit baseline offsets.
2357///
2358/// This is the probit-survival counterpart of [`baseline_chain_rule_gradient`].
2359/// It contracts residuals against
2360/// [`marginal_slope_baseline_offset_theta_partials`], so the offset channels
2361/// are `(q_entry, q_exit, dq_exit/dt)` with `Phi(-q(t)) = exp(-H0(t))`.
2362pub fn marginal_slope_baseline_chain_rule_gradient(
2363    age_entry: ndarray::ArrayView1<'_, f64>,
2364    age_exit: ndarray::ArrayView1<'_, f64>,
2365    cfg: &SurvivalBaselineConfig,
2366    residuals: &crate::survival::OffsetChannelResiduals,
2367) -> Result<Option<Array1<f64>>, String> {
2368    // Marginal-slope has no interval upper-bound channel; `residuals.right` is
2369    // all-zero, so the right channel never contracts and `age_exit` serves as an
2370    // unconsulted placeholder for the (unused) `age_right` argument.
2371    baseline_chain_rule_gradient_with_partials(
2372        "marginal_slope_baseline_chain_rule_gradient",
2373        age_entry,
2374        age_exit,
2375        age_exit,
2376        cfg,
2377        residuals,
2378        marginal_slope_baseline_offset_theta_partials,
2379    )
2380}
2381
/// Gompertz cumulative and instantaneous hazard, `(H_G(t), h_G(t))`.
///
/// For `|shape| >= 1e-10` (and for a NaN shape) the closed forms
/// `H_G = (rate/shape)·expm1(shape·t)` and `h_G = rate·exp(shape·t)` are used.
/// Below that pivot the division by `shape` is avoided with the series
/// `H_G ≈ rate·t·(1 + x/2 + x²/6)` and `h_G ≈ rate·(1 + x + x²/2)`, where
/// `x = shape·t`.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let near_zero_shape = shape.abs() < 1e-10;
    if !near_zero_shape {
        // expm1 keeps the cumulative hazard accurate even when x is small.
        return ((rate / shape) * x.exp_m1(), rate * x.exp());
    }
    let cumulative_series = 1.0 + 0.5 * x + x * x / 6.0;
    let instant_series = 1.0 + x + 0.5 * x * x;
    (rate * age * cumulative_series, rate * instant_series)
}
2402
/// Shape-partials of the Gompertz hazards, returned as
/// `(∂H_G/∂shape, ∂h_G/∂shape)`.
///
/// With `E = exp(shape·t)` and `x = shape·t`:
///
/// ```text
/// H_G = (rate/shape)·(E − 1)        h_G = rate·E
/// ∂H_G/∂shape = rate·[t·E·shape − (E − 1)] / shape²
/// ∂h_G/∂shape = rate·t·E
/// ```
///
/// The bracket in `∂H_G/∂shape` equals `Σ_{k≥1} x^k·(k−1)/k!
/// = x²/2 + x³/3 + x⁴/8 + …`, so its two terms cancel to leading order. The
/// amount of cancellation depends on `x`, not on `shape` alone, which is why
/// the branch pivots on `|x| < 1e-4` and uses the truncated series
/// `rate·t²·(1/2 + x/3 + x²/8)` there.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let growth = x.exp();
    let instant_slope = rate * age * growth;

    let cumulative_slope = if x.abs() < 1e-4 {
        // Series branch: no division by shape, no cancellation.
        rate * age * age * (0.5 + x / 3.0 + x * x / 8.0)
    } else {
        // Closed form: t·E·shape − expm1(x), divided by shape².
        let cancelling_numerator = age * growth * shape - x.exp_m1();
        rate * cancelling_numerator / (shape * shape)
    };

    (cumulative_slope, instant_slope)
}
2443
/// Shape-partials `(∂eta/∂shape, ∂o_D/∂shape)` for the pure Gompertz baseline,
/// where `eta = log H_G(t)` and `o_D = h_G(t)/H_G(t) = shape·E/(E − 1)` with
/// `E = exp(shape·t)`.
///
/// The rate cancels in `o_D` and enters `eta` only as an additive `log rate`,
/// so neither partial depends on it; this helper covers the shape channel only.
///
/// Closed forms, with `x = shape·t`:
///
/// ```text
/// ∂eta/∂shape      = −1/shape + t·E/(E − 1)
/// ∂log(o_D)/∂shape = 1/shape − t/(E − 1)
/// ∂o_D/∂shape      = o_D · ∂log(o_D)/∂shape
/// ```
///
/// Both closed forms subtract two terms of size `1/shape` whose difference is
/// of size `t`, so their relative error grows like `ε/x`. The accuracy is thus
/// governed by the product `x`, not by `shape` alone; the branch pivots on
/// `|x| < 1e-4` (the same pivot `gompertz_cumulative_shape_derivative` uses)
/// and evaluates the series there:
///
/// ```text
/// ∂eta/∂shape      = t/2 + shape·t²/12 + O(t·x³)
/// ∂log(o_D)/∂shape = t/2 − shape·t²/12 + O(t·x³)
/// o_D              = 1/t + shape/2 + shape²·t/12 + O(x⁴/t)
/// ```
///
/// `age` is expected to be finite and positive (callers validate it).
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-4 {
        // Series branch: covers shape == 0 exactly and every (shape, age) pair
        // for which the closed forms would lose digits to cancellation.
        let t = age;
        let d_eta = 0.5 * t + shape * t * t / 12.0;
        let dlog_od = 0.5 * t - shape * t * t / 12.0;
        let o_d = 1.0 / t + 0.5 * shape + shape * shape * t / 12.0;
        (d_eta, o_d * dlog_od)
    } else {
        let e = x.exp();
        let em1 = x.exp_m1(); // E − 1 via expm1 for accuracy at moderate x
        let d_eta = -1.0 / shape + age * e / em1;
        // o_D = shape · E/(E−1); ∂log(o_D)/∂shape = 1/shape − t/(E−1)
        let o_d = shape * e / em1;
        let dlog_od = 1.0 / shape - age / em1;
        (d_eta, o_d * dlog_od)
    }
}
2476
/// Parametric baseline parameters that have already passed validation.
///
/// Values of this type are produced by `validated_baseline_params`, which
/// checks the age, extracts the fields the selected target requires, and
/// enforces each field's domain. Evaluators that match on this enum can
/// therefore use the scalars directly and differ only in what they assemble
/// from them (hazard values or θ-partials).
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    /// Weibull baseline; `scale` and `shape` are finite and strictly positive.
    Weibull {
        scale: f64,
        shape: f64,
    },
    /// Gompertz baseline; `rate` is finite and strictly positive, `shape` is
    /// finite (any sign, including zero).
    Gompertz {
        rate: f64,
        shape: f64,
    },
    /// Gompertz–Makeham baseline; `rate` and `makeham` are finite and strictly
    /// positive, `shape` is finite.
    GompertzMakeham {
        rate: f64,
        shape: f64,
        makeham: f64,
    },
}
2491
2492/// Shared prologue for the survival baseline hazard evaluators: validate the age,
2493/// then extract and domain-check the per-target parameters from `cfg`.
2494///
2495/// `Ok(None)` is the `Linear` target (no parametric baseline). `context` is woven
2496/// into the age-guard error so each caller keeps its specific phrasing.
2497fn validated_baseline_params(
2498    age: f64,
2499    cfg: &SurvivalBaselineConfig,
2500    context: &str,
2501) -> Result<Option<ValidatedBaselineTarget>, String> {
2502    if !age.is_finite() || age <= 0.0 {
2503        return Err(format!(
2504            "survival ages must be finite and positive for {context}"
2505        ));
2506    }
2507
2508    match cfg.target {
2509        SurvivalBaselineTarget::Linear => Ok(None),
2510        SurvivalBaselineTarget::Weibull => {
2511            let scale = cfg
2512                .scale
2513                .ok_or_else(|| "weibull missing scale".to_string())?;
2514            let shape = cfg
2515                .shape
2516                .ok_or_else(|| "weibull missing shape".to_string())?;
2517            if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2518                return Err(SurvivalConstructionError::InvalidConfig {
2519                    reason: "weibull baseline requires finite positive scale and shape".to_string(),
2520                }
2521                .into());
2522            }
2523            Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2524        }
2525        SurvivalBaselineTarget::Gompertz => {
2526            let rate = cfg
2527                .rate
2528                .ok_or_else(|| "gompertz missing rate".to_string())?;
2529            let shape = cfg
2530                .shape
2531                .ok_or_else(|| "gompertz missing shape".to_string())?;
2532            if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2533                return Err(
2534                    "gompertz baseline requires finite positive rate and finite shape".to_string(),
2535                );
2536            }
2537            Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2538        }
2539        SurvivalBaselineTarget::GompertzMakeham => {
2540            let rate = cfg
2541                .rate
2542                .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2543            let shape = cfg
2544                .shape
2545                .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2546            let makeham = cfg
2547                .makeham
2548                .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2549            if !(rate.is_finite()
2550                && shape.is_finite()
2551                && makeham.is_finite()
2552                && rate > 0.0
2553                && makeham > 0.0)
2554            {
2555                return Err(
2556                    "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2557                        .to_string(),
2558                );
2559            }
2560            Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2561                rate,
2562                shape,
2563                makeham,
2564            }))
2565        }
2566    }
2567}
2568
2569fn survival_hazard_theta_partials(
2570    age: f64,
2571    cfg: &SurvivalBaselineConfig,
2572) -> Result<Option<Vec<(f64, f64)>>, String> {
2573    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2574        return Ok(None);
2575    };
2576
2577    match params {
2578        ValidatedBaselineTarget::Weibull { scale, shape } => {
2579            let log_time_ratio = age.ln() - scale.ln();
2580            let cumulative_hazard = (age / scale).powf(shape);
2581            let instant_hazard = shape * cumulative_hazard / age;
2582            let eta = shape * log_time_ratio;
2583            Ok(Some(vec![
2584                (-shape * cumulative_hazard, -shape * instant_hazard),
2585                (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2586            ]))
2587        }
2588        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2589            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2590            let (d_cum_dshape, d_inst_dshape) =
2591                gompertz_cumulative_shape_derivative(age, rate, shape);
2592            Ok(Some(vec![
2593                (cumulative_hazard, instant_hazard),
2594                (d_cum_dshape, d_inst_dshape),
2595            ]))
2596        }
2597        ValidatedBaselineTarget::GompertzMakeham {
2598            rate,
2599            shape,
2600            makeham,
2601        } => {
2602            let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2603            let (d_cum_dshape, d_inst_dshape) =
2604                gompertz_cumulative_shape_derivative(age, rate, shape);
2605            Ok(Some(vec![
2606                (cum_gompertz, inst_gompertz),
2607                (d_cum_dshape, d_inst_dshape),
2608                (makeham * age, makeham),
2609            ]))
2610        }
2611    }
2612}
2613
2614fn survival_cumulative_and_instant_hazard(
2615    age: f64,
2616    cfg: &SurvivalBaselineConfig,
2617) -> Result<Option<(f64, f64)>, String> {
2618    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2619        return Ok(None);
2620    };
2621
2622    match params {
2623        ValidatedBaselineTarget::Weibull { scale, shape } => {
2624            let cumulative_hazard = (age / scale).powf(shape);
2625            let instant_hazard = shape * cumulative_hazard / age;
2626            Ok(Some((cumulative_hazard, instant_hazard)))
2627        }
2628        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2629            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2630            Ok(Some((cumulative_hazard, instant_hazard)))
2631        }
2632        ValidatedBaselineTarget::GompertzMakeham {
2633            rate,
2634            shape,
2635            makeham,
2636        } => {
2637            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2638            Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2639        }
2640    }
2641}
2642
/// Marginal-slope baseline quantities at a single age, as filled in by
/// `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Instantaneous baseline hazard `h(t)`.
    instant_hazard: f64,
    /// Probit-scale offset: the negated standard-normal quantile of the
    /// baseline survival probability `exp(-H(t))`.
    q: f64,
    /// Time derivative of `q`, computed as `h(t)·S(t)/phi(q)`.
    q_t: f64,
}
2649
2650fn evaluate_marginal_slope_baseline_point(
2651    age: f64,
2652    cfg: &SurvivalBaselineConfig,
2653) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2654    let Some((cumulative_hazard, instant_hazard)) =
2655        survival_cumulative_and_instant_hazard(age, cfg)?
2656    else {
2657        return Ok(None);
2658    };
2659    if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2660        return Err(format!(
2661            "{} marginal-slope baseline produced non-positive cumulative hazard",
2662            survival_baseline_targetname(cfg.target)
2663        ));
2664    }
2665    if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2666        return Err(format!(
2667            "{} marginal-slope baseline produced non-positive instant hazard",
2668            survival_baseline_targetname(cfg.target)
2669        ));
2670    }
2671    let survival = (-cumulative_hazard).exp();
2672    if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2673        return Err(format!(
2674            "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2675            survival_baseline_targetname(cfg.target)
2676        ));
2677    }
2678    let q = -standard_normal_quantile(survival).map_err(|e| {
2679        format!(
2680            "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2681            survival_baseline_targetname(cfg.target)
2682        )
2683    })?;
2684    let phi_q = normal_pdf(q);
2685    if !(phi_q.is_finite() && phi_q > 0.0) {
2686        return Err(format!(
2687            "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2688            survival_baseline_targetname(cfg.target)
2689        ));
2690    }
2691    Ok(Some(MarginalSlopeBaselinePoint {
2692        instant_hazard,
2693        q,
2694        q_t: instant_hazard * survival / phi_q,
2695    }))
2696}
2697
2698/// Evaluate the parametric baseline target at a given age.
2699/// Returns `(eta_target(age), d eta_target / d age)` on the log-cumulative-hazard scale.
2700pub fn evaluate_survival_baseline(
2701    age: f64,
2702    cfg: &SurvivalBaselineConfig,
2703) -> Result<(f64, f64), String> {
2704    if !age.is_finite() || age < 0.0 {
2705        return Err(
2706            "survival ages must be finite and non-negative for baseline target evaluation"
2707                .to_string(),
2708        );
2709    }
2710
2711    // At t = 0 every parametric cumulative-hazard target satisfies H(0) = 0
2712    // exactly (this is the defining property of a cumulative hazard:
2713    // S(0) = 1 ⇒ H(0) = -log S(0) = 0). The log-cumulative-hazard offset is
2714    // therefore eta(0) = log H(0) = -inf, and we report a zero log-derivative
2715    // since `exp(eta(0)) = H(0) = 0` is the only physically valid value.
2716    // Returning `Ok((-inf, 0.0))` keeps the baseline cumulative hazard exactly
2717    // zero at the origin; downstream callers that need to multiply this offset
2718    // into a linear predictor are responsible for handling the origin row via
2719    // the `entry_at_origin` / `exit_at_origin` gating already wired through the
2720    // engine.
2721    if age == 0.0 {
2722        return match cfg.target {
2723            SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2724            SurvivalBaselineTarget::Weibull
2725            | SurvivalBaselineTarget::Gompertz
2726            | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2727        };
2728    }
2729
2730    let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2731        return Ok((0.0, 0.0));
2732    };
2733
2734    match params {
2735        ValidatedBaselineTarget::Weibull { scale, shape } => {
2736            let eta = shape * (age.ln() - scale.ln());
2737            let derivative = shape / age;
2738            Ok((eta, derivative))
2739        }
2740        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2741            let (h, inst) = gompertz_hazard_components(age, rate, shape);
2742            if h <= 0.0 || !h.is_finite() {
2743                return Err(if shape.abs() < 1e-10 {
2744                    "invalid gompertz baseline at near-zero shape".to_string()
2745                } else {
2746                    "gompertz baseline produced non-positive cumulative hazard".to_string()
2747                });
2748            }
2749            let derivative = inst / h;
2750            Ok((h.ln(), derivative))
2751        }
2752        ValidatedBaselineTarget::GompertzMakeham {
2753            rate,
2754            shape,
2755            makeham,
2756        } => {
2757            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2758            let h = makeham * age + h_gompertz;
2759            if h <= 0.0 || !h.is_finite() {
2760                return Err(
2761                    "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2762                );
2763            }
2764            let inst = makeham + inst_gompertz;
2765            let derivative = inst / h;
2766            Ok((h.ln(), derivative))
2767        }
2768    }
2769}
2770
2771/// Evaluate the parametric baseline as the probit index whose marginal
2772/// survival is the true hazard survival `exp(-H0(t))`.
2773///
2774/// Returns `(q(age), dq / d age)` such that `Phi(-q(age)) = exp(-H0(age))`.
2775/// The derivative is `h0(t) * exp(-H0(t)) / phi(q(t))`.
2776pub fn evaluate_survival_marginal_slope_baseline(
2777    age: f64,
2778    cfg: &SurvivalBaselineConfig,
2779) -> Result<(f64, f64), String> {
2780    // Survival-curve origin. Every cumulative-hazard baseline satisfies
2781    // `H0(0) = 0` (`S0(0) = exp(-H0(0)) = 1`), so the probit index
2782    // `q(0) = -Phi^{-1}(S0(0)) = -Phi^{-1}(1) = -inf`: there is no *finite*
2783    // probit-survival offset at the origin. The survival surface anchors
2784    // `S(0) = 1` directly (see the `t <= 0` origin handling in the survival
2785    // predict paths), so the baseline contributes nothing here — report the
2786    // zero offset rather than aborting in the `age <= 0` hazard guard. This
2787    // mirrors `evaluate_survival_baseline`'s explicit `age == 0` branch on the
2788    // log-cumulative-hazard channel; without it the probit/marginal-slope
2789    // baseline path (location-scale + marginal-slope likelihoods) could not be
2790    // evaluated on a prediction grid whose first node is the origin (#1024).
2791    if age == 0.0 {
2792        return Ok((0.0, 0.0));
2793    }
2794    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2795        return Ok((0.0, 0.0));
2796    };
2797    Ok((point.q, point.q_t))
2798}
2799
2800/// Partial derivatives of the true survival marginal-slope probit offsets
2801/// `(q(t), dq(t)/dt)` with respect to the baseline θ-parameters.
2802///
2803/// The returned channels match `survival_baseline_theta_from_config`.  For
2804/// Gompertz-Makeham, θ is `(log_rate, shape, log_makeham)`.  If
2805/// `S(t)=exp(-H(t))`, `q(t)=-Phi^-1(S(t))`, `A(t)=S(t)/phi(q(t))`, and
2806/// `h(t)=dH/dt`, then
2807///
2808///   dq/dθ      = A * dH/dθ
2809///   d(q')/dθ   = A * (dh/dθ + h * (q*A - 1) * dH/dθ)
2810///
2811/// which keeps the probit transform and the hazard baseline analytically tied.
2812pub fn marginal_slope_baseline_offset_theta_partials(
2813    age: f64,
2814    cfg: &SurvivalBaselineConfig,
2815) -> Result<Option<Vec<(f64, f64)>>, String> {
2816    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2817        return Ok(None);
2818    };
2819    let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2820        .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2821    let a = point.q_t / point.instant_hazard;
2822    let a_log_derivative_factor = point.q * a - 1.0;
2823    Ok(Some(
2824        hazard_partials
2825            .into_iter()
2826            .map(|(d_h_cum, d_h_inst)| {
2827                (
2828                    a * d_h_cum,
2829                    a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2830                )
2831            })
2832            .collect(),
2833    ))
2834}
2835
2836/// Contract marginal-slope offset residuals and channel curvatures into the
2837/// exact Hessian with respect to baseline θ-parameters.
2838pub fn marginal_slope_baseline_chain_rule_hessian(
2839    age_entry: ndarray::ArrayView1<'_, f64>,
2840    age_exit: ndarray::ArrayView1<'_, f64>,
2841    cfg: &SurvivalBaselineConfig,
2842    residuals: &crate::survival::OffsetChannelResiduals,
2843    curvatures: &crate::survival::OffsetChannelCurvatures,
2844) -> Result<Option<Array2<f64>>, String> {
2845    let n = age_exit.len();
2846    if age_entry.len() != n
2847        || residuals.exit.len() != n
2848        || residuals.entry.len() != n
2849        || residuals.derivative.len() != n
2850        || curvatures.rows.len() != n
2851    {
2852        return Err(format!(
2853            "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2854            age_entry.len(),
2855            n,
2856            residuals.exit.len(),
2857            residuals.entry.len(),
2858            residuals.derivative.len(),
2859            curvatures.rows.len(),
2860        ));
2861    }
2862    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2863    let dim = match probe_age {
2864        Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2865            None => return Ok(None),
2866            Some(parts) => parts.first.len(),
2867        },
2868        None => {
2869            return Err(
2870                "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2871                    .to_string(),
2872            );
2873        }
2874    };
2875    // Per-row Hessian contractions are independent. Each row contributes a
2876    // dim×dim increment combining second partials (exit/entry channels) with
2877    // the curvature-weighted outer product of the (entry, exit, derivative)
2878    // first-partial Jacobians. Parallel try_fold/try_reduce accumulates them.
2879    let hessian = (0..n)
2880        .into_par_iter()
2881        .try_fold(
2882            || Array2::<f64>::zeros((dim, dim)),
2883            |mut acc, i| -> Result<Array2<f64>, String> {
2884                let exit_parts =
2885                    marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2886                        .ok_or_else(|| {
2887                            "unexpected None from marginal-slope second partials at exit"
2888                                .to_string()
2889                        })?;
2890                if exit_parts.first.len() != dim {
2891                    return Err(
2892                        "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2893                    );
2894                }
2895                let mut entry_parts = None;
2896                if residuals.entry[i] != 0.0 {
2897                    entry_parts = Some(
2898                        marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2899                            .ok_or_else(|| {
2900                                "unexpected None from marginal-slope second partials at entry"
2901                                    .to_string()
2902                            })?,
2903                    );
2904                }
2905                for a in 0..dim {
2906                    for b in 0..dim {
2907                        let j_exit_a = exit_parts.first[a].0;
2908                        let j_exit_b = exit_parts.first[b].0;
2909                        let j_deriv_a = exit_parts.first[a].1;
2910                        let j_deriv_b = exit_parts.first[b].1;
2911                        let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2912                            + residuals.derivative[i] * exit_parts.second[a][b].1;
2913                        if let Some(parts) = entry_parts.as_ref() {
2914                            value += residuals.entry[i] * parts.second[a][b].0;
2915                        }
2916                        let curv = curvatures.rows[i];
2917                        let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2918                        let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2919                        let ja = [j_entry_a, j_exit_a, j_deriv_a];
2920                        let jb = [j_entry_b, j_exit_b, j_deriv_b];
2921                        for u in 0..3 {
2922                            for v in 0..3 {
2923                                value += ja[u] * curv[u][v] * jb[v];
2924                            }
2925                        }
2926                        acc[[a, b]] += value;
2927                    }
2928                }
2929                Ok(acc)
2930            },
2931        )
2932        .try_reduce(|| Array2::<f64>::zeros((dim, dim)), |a, b| Ok(a + b))?;
2933    Ok(Some(hessian))
2934}
2935
/// First and second θ-partials of the marginal-slope probit offsets at a
/// single age, as assembled by
/// `marginal_slope_baseline_offset_theta_second_partials`.
struct MarginalSlopeThetaSecondPartials {
    /// One pair per θ-parameter: `.0` is the partial of `q`, `.1` the partial
    /// of `dq/dt`.
    first: Vec<(f64, f64)>,
    /// `dim × dim` table of second partials, using the same `(q, dq/dt)` pair
    /// layout as `first`.
    second: Vec<Vec<(f64, f64)>>,
}
2940
2941fn marginal_slope_baseline_offset_theta_second_partials(
2942    age: f64,
2943    cfg: &SurvivalBaselineConfig,
2944) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
2945    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2946        return Ok(None);
2947    };
2948    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
2949        return Ok(None);
2950    };
2951    let (cum_hazard, instant_hazard) = hazard;
2952    let survival = (-cum_hazard).exp();
2953    let a = survival / normal_pdf(point.q);
2954    let b = point.q * a - 1.0;
2955    let b_factor = a + point.q * b;
2956    let dim = first.len();
2957    let mut first_out = Vec::with_capacity(dim);
2958    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
2959    for i in 0..dim {
2960        let (h_i, inst_i) = first[i];
2961        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
2962    }
2963    for i in 0..dim {
2964        for j in 0..dim {
2965            let (h_i, inst_i) = first[i];
2966            let (h_j, inst_j) = first[j];
2967            let (h_ij, inst_ij) = second[i][j];
2968            let a_j = a * b * h_j;
2969            let b_j = a * h_j * b_factor;
2970            let q_ij = a * h_ij + a * b * h_i * h_j;
2971            let qt_inner_i = inst_i + instant_hazard * b * h_i;
2972            let qt_ij = a_j * qt_inner_i
2973                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
2974            second_out[i][j] = (q_ij, qt_ij);
2975        }
2976    }
2977    Ok(Some(MarginalSlopeThetaSecondPartials {
2978        first: first_out,
2979        second: second_out,
2980    }))
2981}
2982
/// `((cumulative_hazard, instant_hazard), first θ-partials, second θ-partials)`;
/// every partial is a `(cumulative, instantaneous)` pair.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2984
2985fn survival_hazard_theta_first_second(
2986    age: f64,
2987    cfg: &SurvivalBaselineConfig,
2988) -> Result<Option<HazardFirstSecond>, String> {
2989    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
2990        return Ok(None);
2991    };
2992    let first = survival_hazard_theta_partials(age, cfg)?
2993        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
2994    let dim = first.len();
2995    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
2996    match cfg.target {
2997        SurvivalBaselineTarget::Linear => return Ok(None),
2998        SurvivalBaselineTarget::Weibull => {
2999            let scale = cfg
3000                .scale
3001                .ok_or_else(|| "weibull missing scale".to_string())?;
3002            let shape = cfg
3003                .shape
3004                .ok_or_else(|| "weibull missing shape".to_string())?;
3005            let log_time_ratio = age.ln() - scale.ln();
3006            let cumulative_hazard = hazard.0;
3007            let instant_hazard = hazard.1;
3008            let eta = shape * log_time_ratio;
3009            second[0][0] = (
3010                shape * shape * cumulative_hazard,
3011                shape * shape * instant_hazard,
3012            );
3013            second[0][1] = (
3014                -shape * cumulative_hazard * (1.0 + eta),
3015                -shape * instant_hazard * (2.0 + eta),
3016            );
3017            second[1][0] = second[0][1];
3018            second[1][1] = (
3019                eta * cumulative_hazard * (1.0 + eta),
3020                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
3021            );
3022        }
3023        SurvivalBaselineTarget::Gompertz => {
3024            let rate = cfg
3025                .rate
3026                .ok_or_else(|| "gompertz missing rate".to_string())?;
3027            let shape = cfg
3028                .shape
3029                .ok_or_else(|| "gompertz missing shape".to_string())?;
3030            second[0][0] = first[0];
3031            second[0][1] = first[1];
3032            second[1][0] = first[1];
3033            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
3034        }
3035        SurvivalBaselineTarget::GompertzMakeham => {
3036            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
3037            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
3038            second[0][0] = first[0];
3039            second[0][1] = first[1];
3040            second[1][0] = first[1];
3041            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
3042            second[2][2] = first[2];
3043        }
3044    }
3045    Ok(Some((hazard, first, second)))
3046}
3047
/// Second derivative with respect to `shape` of the Gompertz cumulative and
/// instantaneous hazards, returned as `(∂²H_G/∂shape², ∂²h_G/∂shape²)`.
///
/// Closed form: `∂²H_G/∂shape² = rate·[t²·E/shape − 2·(shape·t·E − (E−1))/shape³]`
/// with `E = exp(shape·t)`, and `∂²h_G/∂shape² = rate·t²·E`.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;

    // The closed form subtracts O(1/shape³) terms whose leading parts cancel,
    // so its accuracy is governed by x = shape·age rather than by `shape`
    // alone, and the cancellation is much worse than in the first derivative's
    // 1/shape² form: the exact branch is already unusable for |x| < ~1e-4
    // (x=1e-9 gives roughly 98% relative error, x=1e-10 roughly 9700%). A
    // pivot on `shape < 1e-10` ignores `age` and sends those small-x cases
    // through the cancelling form, corrupting the marginal-slope baseline
    // Hessian near small shape. Pivot on x instead, with a wider threshold
    // than the first derivative uses: the 3-term Taylor series (through
    // O(x²)) is accurate to <1e-8 for |x| < 1e-3, and the exact branch is
    // clean above that.
    if x.abs() < 1e-3 {
        let cumulative_series = 1.0 / 3.0 + x / 4.0 + x * x / 10.0;
        let instant_series = 1.0 + x + 0.5 * x * x;
        return (
            rate * age * age * age * cumulative_series,
            rate * age * age * instant_series,
        );
    }

    let e = x.exp();
    let em1 = x.exp_m1();
    // shape·t·E − (E − 1), using expm1 for the bracketed difference.
    let n = x * e - em1;
    let cumulative = rate * (age * age * e / shape - 2.0 * n / (shape * shape * shape));
    let instant = rate * age * age * e;
    (cumulative, instant)
}
3078
3079// ---------------------------------------------------------------------------
3080// Baseline offsets
3081// ---------------------------------------------------------------------------
3082
/// Selects which baseline evaluator the shared offset builder calls per row.
#[derive(Clone, Copy)]
enum BaselineOffsetEvaluator {
    /// Offsets come from `evaluate_survival_baseline`.
    LogCumulativeHazard,
    /// Offsets come from `evaluate_survival_marginal_slope_baseline`.
    ProbitSurvival,
}
3088
3089impl BaselineOffsetEvaluator {
3090    fn length_error(self) -> String {
3091        match self {
3092            Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3093                reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3094            }
3095            .into(),
3096            Self::ProbitSurvival => {
3097                "survival probit baseline offsets require matching entry/exit lengths".to_string()
3098            }
3099        }
3100    }
3101
3102    fn finite_error(self) -> &'static str {
3103        match self {
3104            Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3105            Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3106        }
3107    }
3108
3109    fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3110        match self {
3111            Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3112            Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3113        }
3114    }
3115
3116    fn exit_is_finite(self, value: f64, age: f64) -> bool {
3117        match self {
3118            Self::LogCumulativeHazard => {
3119                value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3120            }
3121            Self::ProbitSurvival => value.is_finite(),
3122        }
3123    }
3124}
3125
3126fn build_survival_offsets_with_evaluator(
3127    age_entry: &Array1<f64>,
3128    age_exit: &Array1<f64>,
3129    cfg: &SurvivalBaselineConfig,
3130    evaluator: BaselineOffsetEvaluator,
3131) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3132    if age_entry.len() != age_exit.len() {
3133        return Err(evaluator.length_error());
3134    }
3135    let n = age_entry.len();
3136    // Each row's three offsets are independent across i. Compute the triplets
3137    // in parallel, then unpack into three Array1 outputs preserving order.
3138    let triples: Vec<(f64, f64, f64)> = (0..n)
3139        .into_par_iter()
3140        .map(|i| -> Result<(f64, f64, f64), String> {
3141            // Origin-entry rows are multiplied out by the survival engines, so
3142            // keep their entry channel finite even when the evaluator's natural
3143            // value at t=0 is undefined or -inf.
3144            let entry_age = age_entry[i];
3145            let e0 = if !entry_age.is_finite() {
3146                return Err(SurvivalConstructionError::DataValidationFailed {
3147                    reason: format!("non-finite entry age at row {i}"),
3148                }
3149                .into());
3150            } else if entry_age <= 0.0 {
3151                0.0
3152            } else {
3153                evaluator.evaluate(entry_age, cfg)?.0
3154            };
3155            let exit_age = age_exit[i];
3156            let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3157            if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3158                return Err(SurvivalConstructionError::DataValidationFailed {
3159                    reason: evaluator.finite_error().to_string(),
3160                }
3161                .into());
3162            }
3163            Ok((e0, e1, d1))
3164        })
3165        .collect::<Result<Vec<_>, String>>()?;
3166    let mut eta_entry = Array1::<f64>::zeros(n);
3167    let mut eta_exit = Array1::<f64>::zeros(n);
3168    let mut derivative_exit = Array1::<f64>::zeros(n);
3169    for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3170        eta_entry[i] = e0;
3171        eta_exit[i] = e1;
3172        derivative_exit[i] = d1;
3173    }
3174    Ok((eta_entry, eta_exit, derivative_exit))
3175}
3176
3177/// Compute baseline target offsets for all observations.
3178/// Returns `(eta_entry, eta_exit, derivative_exit)`.
3179pub fn build_survival_baseline_offsets(
3180    age_entry: &Array1<f64>,
3181    age_exit: &Array1<f64>,
3182    cfg: &SurvivalBaselineConfig,
3183) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3184    build_survival_offsets_with_evaluator(
3185        age_entry,
3186        age_exit,
3187        cfg,
3188        BaselineOffsetEvaluator::LogCumulativeHazard,
3189    )
3190}
3191
3192/// Compute probit-survival baseline target offsets for all observations.
3193/// Returns `(q_entry, q_exit, q_derivative_exit)` where `Phi(-q(t)) = exp(-H0(t))`.
3194pub fn build_survival_marginal_slope_baseline_offsets(
3195    age_entry: &Array1<f64>,
3196    age_exit: &Array1<f64>,
3197    cfg: &SurvivalBaselineConfig,
3198) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3199    build_survival_offsets_with_evaluator(
3200        age_entry,
3201        age_exit,
3202        cfg,
3203        BaselineOffsetEvaluator::ProbitSurvival,
3204    )
3205}
3206
3207pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3208    matches!(
3209        inverse_link,
3210        Some(
3211            InverseLink::Standard(StandardLink::Probit)
3212                | InverseLink::LatentCLogLog(_)
3213                | InverseLink::Sas(_)
3214                | InverseLink::BetaLogistic(_)
3215                | InverseLink::Mixture(_)
3216        )
3217    )
3218}
3219
3220pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3221    match likelihood_mode {
3222        SurvivalLikelihoodMode::LocationScale
3223        | SurvivalLikelihoodMode::Latent
3224        | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3225        SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3226        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3227    }
3228}
3229
3230pub fn build_survival_time_offsets_for_likelihood(
3231    age_entry: &Array1<f64>,
3232    age_exit: &Array1<f64>,
3233    baseline_cfg: &SurvivalBaselineConfig,
3234    likelihood_mode: SurvivalLikelihoodMode,
3235    inverse_link: Option<&InverseLink>,
3236) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3237    if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3238        || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3239            && location_scale_uses_probit_survival_baseline(inverse_link))
3240    {
3241        build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3242    } else {
3243        build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3244    }
3245}
3246
3247pub fn add_survival_time_derivative_guard_offset(
3248    age_entry: &Array1<f64>,
3249    age_exit: &Array1<f64>,
3250    anchor_time: f64,
3251    derivative_guard: f64,
3252    eta_offset_entry: &mut Array1<f64>,
3253    eta_offset_exit: &mut Array1<f64>,
3254    derivative_offset_exit: &mut Array1<f64>,
3255) -> Result<(), String> {
3256    if derivative_guard <= 0.0 {
3257        return Ok(());
3258    }
3259    let n = age_entry.len();
3260    if age_exit.len() != n
3261        || eta_offset_entry.len() != n
3262        || eta_offset_exit.len() != n
3263        || derivative_offset_exit.len() != n
3264    {
3265        return Err(SurvivalConstructionError::IncompatibleDimensions {
3266            reason: "survival derivative-guard offset lengths must match".to_string(),
3267        }
3268        .into());
3269    }
3270    for i in 0..n {
3271        eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3272        eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3273        derivative_offset_exit[i] += derivative_guard;
3274    }
3275    Ok(())
3276}
3277
/// Per-row baseline offsets for the latent survival likelihood, as produced by
/// `build_latent_survival_baseline_offsets`.
///
/// Under `HazardLoading::Full` the loaded fields hold the full baseline and the
/// unloaded fields are zero. Under `HazardLoading::LoadedVsUnloaded` the loaded
/// fields hold the Gompertz part and the unloaded fields the Makeham part.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    /// Log cumulative hazard of the loaded component at entry.
    pub loaded_eta_entry: Array1<f64>,
    /// Log cumulative hazard of the loaded component at exit.
    pub loaded_eta_exit: Array1<f64>,
    /// Derivative of the loaded log cumulative hazard at exit.
    pub loaded_derivative_exit: Array1<f64>,
    /// Unloaded cumulative hazard at entry (`makeham * entry`, or zero).
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded cumulative hazard at exit (`makeham * exit`, or zero).
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded instantaneous hazard at exit (`makeham`, or zero).
    pub unloaded_hazard_exit: Array1<f64>,
}
3287
3288pub fn build_latent_survival_baseline_offsets(
3289    age_entry: &Array1<f64>,
3290    age_exit: &Array1<f64>,
3291    cfg: &SurvivalBaselineConfig,
3292    loading: HazardLoading,
3293) -> Result<LatentSurvivalBaselineOffsets, String> {
3294    if age_entry.len() != age_exit.len() {
3295        return Err(
3296            "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3297        );
3298    }
3299
3300    fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3301        if shape.abs() < 1e-10 {
3302            // Taylor at shape=0 matching `gompertz_hazard_components`:
3303            //   H_G(t) = rate·t·(1 + (shape·t)/2 + (shape·t)²/6)
3304            //   h_G(t) = rate·(1 + shape·t + (shape·t)²/2)
3305            // Dropping the higher-order `shape*t` corrections silently
3306            // diverges this helper from its sibling for non-zero shape near
3307            // the cutoff and gives inconsistent loaded-vs-unloaded offsets.
3308            let x = shape * age;
3309            return (
3310                rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3311                rate * (1.0 + x + 0.5 * x * x),
3312            );
3313        }
3314        let shape_age = shape * age;
3315        let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3316        let instant_hazard = rate * shape_age.exp();
3317        (cumulative_hazard, instant_hazard)
3318    }
3319
3320    let n = age_entry.len();
3321
3322    // Per-row 6-tuple is independent. Evaluate in parallel into a Vec and then
3323    // unpack into the six Array1 outputs in original order.
3324    let rows: Vec<[f64; 6]> = (0..n)
3325        .into_par_iter()
3326        .map(|i| -> Result<[f64; 6], String> {
3327            let entry = age_entry[i];
3328            let exit = age_exit[i];
3329            if !entry.is_finite()
3330                || !exit.is_finite()
3331                || entry <= 0.0
3332                || exit <= 0.0
3333                || exit < entry
3334            {
3335                return Err(format!(
3336                    "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3337                    i + 1
3338                ));
3339            }
3340            match loading {
3341                HazardLoading::Full => {
3342                    let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3343                    let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3344                    Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3345                }
3346                HazardLoading::LoadedVsUnloaded => {
3347                    if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3348                        return Err(format!(
3349                            "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3350                            survival_baseline_targetname(cfg.target)
3351                        ));
3352                    }
3353                    let rate = cfg.rate.ok_or_else(|| {
3354                        "gompertz-makeham latent survival is missing baseline rate".to_string()
3355                    })?;
3356                    let shape = cfg.shape.ok_or_else(|| {
3357                        "gompertz-makeham latent survival is missing baseline shape".to_string()
3358                    })?;
3359                    let makeham = cfg.makeham.ok_or_else(|| {
3360                        "gompertz-makeham latent survival is missing baseline makeham".to_string()
3361                    })?;
3362                    let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3363                    let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3364                    if !(loaded_entry.is_finite()
3365                        && loaded_entry > 0.0
3366                        && loaded_exit.is_finite()
3367                        && loaded_exit > 0.0
3368                        && loaded_hazard.is_finite()
3369                        && loaded_hazard > 0.0)
3370                    {
3371                        return Err(format!(
3372                            "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3373                            i + 1
3374                        ));
3375                    }
3376                    Ok([
3377                        loaded_entry.ln(),
3378                        loaded_exit.ln(),
3379                        loaded_hazard / loaded_exit,
3380                        makeham * entry,
3381                        makeham * exit,
3382                        makeham,
3383                    ])
3384                }
3385            }
3386        })
3387        .collect::<Result<Vec<_>, String>>()?;
3388
3389    let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3390    let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3391    let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3392    let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3393    let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3394    let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3395    for (i, row) in rows.into_iter().enumerate() {
3396        loaded_eta_entry[i] = row[0];
3397        loaded_eta_exit[i] = row[1];
3398        loaded_derivative_exit[i] = row[2];
3399        unloaded_mass_entry[i] = row[3];
3400        unloaded_mass_exit[i] = row[4];
3401        unloaded_hazard_exit[i] = row[5];
3402    }
3403
3404    Ok(LatentSurvivalBaselineOffsets {
3405        loaded_eta_entry,
3406        loaded_eta_exit,
3407        loaded_derivative_exit,
3408        unloaded_mass_entry,
3409        unloaded_mass_exit,
3410        unloaded_hazard_exit,
3411    })
3412}
3413
3414// ---------------------------------------------------------------------------
3415// Time wiggle construction
3416// ---------------------------------------------------------------------------
3417
3418pub fn build_survival_timewiggle_derivative_design(
3419    eta_exit: &Array1<f64>,
3420    derivative_exit: &Array1<f64>,
3421    knots: &Array1<f64>,
3422    degree: usize,
3423) -> Result<Array2<f64>, String> {
3424    let mut design_derivative_exit =
3425        monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3426    for i in 0..design_derivative_exit.nrows() {
3427        let chain = derivative_exit[i];
3428        for j in 0..design_derivative_exit.ncols() {
3429            design_derivative_exit[[i, j]] *= chain;
3430        }
3431    }
3432    Ok(design_derivative_exit)
3433}
3434
3435/// Build the dynamic "baseline as prior" timewiggle runtime.
3436///
3437/// The baseline offsets are used only to initialize the wiggle knot placement
3438/// on a stable scalar scale.  The exact survival family evaluates the resulting
3439/// monotone wiggle dynamically on the current time predictor h0(t):
3440///
3441///   h(t) = g(h0(t)),   g(z) = z + w(z).
3442///
3443/// No fixed `B(eta_baseline)` design is constructed here.
3444pub fn build_survival_timewiggle_from_baseline(
3445    eta_entry: &Array1<f64>,
3446    eta_exit: &Array1<f64>,
3447    derivative_exit: &Array1<f64>,
3448    cfg: &LinkWiggleFormulaSpec,
3449) -> Result<SurvivalTimeWiggleBuild, String> {
3450    if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3451        return Err(
3452            "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3453        );
3454    }
3455    // Guard: if baseline offsets are all zero (linear baseline), the timewiggle
3456    // construction is degenerate — it adds only a constant, not time-varying structure.
3457    let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3458        && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3459        && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3460    if all_zero {
3461        return Err(
3462            "timewiggle requires a non-linear scalar survival baseline target; \
3463             the provided baseline offsets are all zero (linear baseline)"
3464                .to_string(),
3465        );
3466    }
3467    let n = eta_exit.len();
3468    let mut seed = Array1::<f64>::zeros(2 * n);
3469    for i in 0..n {
3470        seed[i] = eta_entry[i];
3471        seed[n + i] = eta_exit[i];
3472    }
3473    // Use the smallest requested positive penalty order as the primary
3474    // coefficient-space penalty so the fitted wiggle penalty system matches
3475    // the public formula exactly, including the slope (`order = 1`) case.
3476    let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3477    let wiggle_cfg = WiggleBlockConfig {
3478        degree: cfg.degree,
3479        num_internal_knots: cfg.num_internal_knots,
3480        penalty_order: primary_order,
3481        double_penalty: cfg.double_penalty,
3482    };
3483    let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3484    append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3485    let ncols = combined_block.design.ncols();
3486    Ok(SurvivalTimeWiggleBuild {
3487        nullspace_dims: combined_block.nullspace_dims.clone(),
3488        penalties: {
3489            combined_block
3490                .penalties
3491                .into_iter()
3492                .map(|ps| ps.to_global(ncols))
3493                .collect()
3494        },
3495        knots,
3496        degree: cfg.degree,
3497        ncols,
3498    })
3499}
3500
3501pub fn append_zero_tail_columns(
3502    x_entry: &mut DesignMatrix,
3503    x_exit: &mut DesignMatrix,
3504    x_derivative: &mut DesignMatrix,
3505    tail_cols: usize,
3506) {
3507    if tail_cols == 0 {
3508        return;
3509    }
3510    // Wiggle tail columns are dense, so materialize everything to dense.
3511    // This only runs once at construction time when time-wiggles are active.
3512    fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3513        let old = dm.to_dense();
3514        let n = old.nrows();
3515        let p_base = old.ncols();
3516        let mut out = Array2::<f64>::zeros((n, p_base + tail));
3517        out.slice_mut(s![.., 0..p_base]).assign(&old);
3518        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3519    }
3520    append_dense(x_entry, tail_cols);
3521    append_dense(x_exit, tail_cols);
3522    append_dense(x_derivative, tail_cols);
3523}
3524
3525// ---------------------------------------------------------------------------
3526// Resolved config (from build output back to config for serialization)
3527// ---------------------------------------------------------------------------
3528
3529// ---------------------------------------------------------------------------
3530// Time-varying covariate template
3531// ---------------------------------------------------------------------------
3532
3533/// Build a time-varying covariate block by tensoring the covariate design
3534/// with a 1D B-spline basis on log(time).
3535pub fn build_time_varying_survival_covariate_template(
3536    age_entry: &Array1<f64>,
3537    age_exit: &Array1<f64>,
3538    time_k: usize,
3539    time_degree: usize,
3540    block_name: &str,
3541) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3542    if time_k < time_degree + 1 {
3543        return Err(format!(
3544            "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3545            time_degree + 1
3546        ));
3547    }
3548    let num_internal_knots = time_k - (time_degree + 1);
3549
3550    let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3551    let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3552
3553    let time_spec = BSplineBasisSpec {
3554        degree: time_degree,
3555        penalty_order: 2,
3556        knotspec: BSplineKnotSpec::Automatic {
3557            num_internal_knots: Some(num_internal_knots),
3558            placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3559        },
3560        double_penalty: false,
3561        identifiability: BSplineIdentifiability::None,
3562        boundary: OneDimensionalBoundary::Open,
3563        boundary_conditions: BSplineBoundaryConditions::default(),
3564    };
3565
3566    let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3567        .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3568    let time_design_exit = time_build.design.to_dense();
3569
3570    let knots = match &time_build.metadata {
3571        BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3572        _ => {
3573            return Err(format!(
3574                "{block_name} time-margin basis returned unexpected metadata type"
3575            ));
3576        }
3577    };
3578
3579    let time_build_entry = build_bspline_basis_1d(
3580        log_entry.view(),
3581        &BSplineBasisSpec {
3582            degree: time_degree,
3583            penalty_order: 2,
3584            knotspec: BSplineKnotSpec::Provided(knots.clone()),
3585            double_penalty: false,
3586            identifiability: BSplineIdentifiability::None,
3587            boundary: OneDimensionalBoundary::Open,
3588            boundary_conditions: BSplineBoundaryConditions::default(),
3589        },
3590    )
3591    .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3592    let time_design_entry = time_build_entry.design.to_dense();
3593    let p_time = time_design_exit.ncols();
3594    let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3595    // Per-row derivative-basis evaluation is independent; each row owns its
3596    // own small `deriv_buf`. par_chunks_mut over the (n × p_time) output rows
3597    // hands disjoint mutable row-slices to rayon workers.
3598    time_design_derivative_exit
3599        .as_slice_mut()
3600        .expect("zeros are contiguous")
3601        .par_chunks_mut(p_time)
3602        .enumerate()
3603        .try_for_each(|(i, row_out)| -> Result<(), String> {
3604            let mut deriv_buf = vec![0.0_f64; p_time];
3605            evaluate_bspline_derivative_scalar(
3606                log_exit[i],
3607                knots.view(),
3608                time_degree,
3609                &mut deriv_buf,
3610            )
3611            .map_err(|e| {
3612                format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3613            })?;
3614            let chain = 1.0 / age_exit[i].max(1e-12);
3615            for j in 0..p_time {
3616                row_out[j] = deriv_buf[j] * chain;
3617            }
3618            Ok(())
3619        })?;
3620
3621    Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3622        time_basis_entry: time_design_entry,
3623        time_basis_exit: time_design_exit,
3624        time_basis_derivative_exit: time_design_derivative_exit,
3625        time_penalties: time_build.penalties,
3626    })
3627}
3628
3629#[cfg(test)]
3630mod tests {
3631    use super::{
3632        SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3633        baseline_chain_rule_gradient, baseline_offset_theta_partials,
3634        build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3635        build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3636        evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3637        gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3638        marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3639        marginal_slope_baseline_offset_theta_partials,
3640        optimize_survival_baseline_config_with_gradient,
3641        optimize_survival_baseline_config_with_gradient_only,
3642        resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3643        survival_baseline_theta_from_config,
3644    };
3645    use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3646    use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3647    use crate::probability::normal_cdf;
3648    use ndarray::{Array1, Array2, array};
3649
3650    #[test]
3651    fn survival_timewiggle_keeps_requested_order_one_penalty() {
3652        let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3653        let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3654        let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3655        let cfg = LinkWiggleFormulaSpec {
3656            degree: 3,
3657            num_internal_knots: 4,
3658            penalty_orders: vec![1, 2, 3],
3659            double_penalty: false,
3660        };
3661
3662        let build =
3663            build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3664                .expect("build survival timewiggle");
3665
3666        assert_eq!(build.penalties.len(), 3);
3667        assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3668        assert!(build.ncols > 0);
3669    }
3670
3671    #[test]
3672    fn marginal_slope_time_anchor_defaults_to_median_exit() {
3673        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3674        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3675        let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3676            .expect("resolve marginal-slope default time anchor");
3677
3678        assert!(
3679            (anchor - 19.0).abs() <= 1e-12,
3680            "marginal-slope default anchor should be median exit, got {anchor}"
3681        );
3682    }
3683
3684    #[test]
3685    fn marginal_slope_time_anchor_honors_explicit_value() {
3686        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3687        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3688        let anchor =
3689            resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3690                .expect("resolve explicit marginal-slope time anchor");
3691
3692        assert!(
3693            (anchor - 7.5).abs() <= 1e-12,
3694            "explicit marginal-slope anchor should round-trip, got {anchor}"
3695        );
3696    }
3697
3698    /// Derivative-contract parity for the two public baseline optimizers.
3699    ///
3700    /// After the unification onto `run_baseline_theta_optimizer`, the
3701    /// gradient-only and gradient+Hessian entry points differ *only* in how
3702    /// much derivative information they hand the outer solver — not in the
3703    /// surface they minimize. We exercise that invariant on a known
3704    /// strictly-convex quadratic in θ-space (Weibull baseline: θ = (ln scale,
3705    /// ln shape)) whose unique minimizer is `theta_star`, supplying the same
3706    /// objective as `(f, ∇f)` and as `(f, ∇f, ∇²f)`. Both contracts must
3707    /// recover the same minimizer config, not weakened to pass.
3708    #[test]
3709    fn baseline_optimizer_contracts_agree_on_shared_surface() {
3710        // SPD curvature and interior minimizer in θ-space. A is well away from
3711        // singular so both the analytic-Hessian and BFGS paths see the same
3712        // unambiguous bowl; θ* sits comfortably inside the ±6 box around the
3713        // θ=(0,0) seed below.
3714        let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3715        let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3716
3717        // Seed config at θ=(0,0) (scale=shape=1). The Linear early-return path
3718        // is not exercised here; Weibull has a genuine 2-dim θ to optimize.
3719        let initial = SurvivalBaselineConfig {
3720            target: SurvivalBaselineTarget::Weibull,
3721            scale: Some(1.0),
3722            shape: Some(1.0),
3723            rate: None,
3724            makeham: None,
3725        };
3726
3727        // θ recovered from a returned Weibull config, via the exact inverse of
3728        // the config→θ map the optimizers use internally.
3729        let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3730            survival_baseline_theta_from_config(cfg)
3731                .expect("config→θ")
3732                .expect("Weibull config has a θ")
3733        };
3734
3735        // Shared quadratic surface, evaluated by mapping config→θ so every
3736        // contract sees the identical objective.
3737        let curvature_cost = curvature.clone();
3738        let star_cost = theta_star.clone();
3739        let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3740            let theta = survival_baseline_theta_from_config(cfg)?
3741                .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3742            let d = &theta - &star_cost;
3743            let ad = curvature_cost.dot(&d);
3744            Ok(0.5 * d.dot(&ad))
3745        };
3746
3747        let curvature_grad = curvature.clone();
3748        let star_grad = theta_star.clone();
3749        let cost_for_grad = cost_at.clone();
3750        let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3751            &initial,
3752            "baseline parity (gradient-only)",
3753            move |cfg| {
3754                let cost = cost_for_grad(cfg)?;
3755                let theta = survival_baseline_theta_from_config(cfg)?
3756                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3757                let gradient = curvature_grad.dot(&(&theta - &star_grad));
3758                Ok((cost, gradient))
3759            },
3760        )
3761        .expect("gradient-only baseline optimization converges");
3762
3763        let curvature_hess = curvature.clone();
3764        let star_hess = theta_star.clone();
3765        let cost_for_hess = cost_at.clone();
3766        let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3767            &initial,
3768            "baseline parity (gradient+Hessian)",
3769            move |cfg| {
3770                let cost = cost_for_hess(cfg)?;
3771                let theta = survival_baseline_theta_from_config(cfg)?
3772                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3773                let gradient = curvature_hess.dot(&(&theta - &star_hess));
3774                Ok((cost, gradient, curvature_hess.clone()))
3775            },
3776        )
3777        .expect("gradient+Hessian baseline optimization converges");
3778
3779        let theta_grad_only = recovered_theta(&result_grad_only);
3780        let theta_grad_hess = recovered_theta(&result_grad_hess);
3781
3782        // Each contract recovers the true minimizer. 2e-3 is a safe,
3783        // un-weakened bound; both gradient paths land far tighter.
3784        for (label, theta) in [
3785            ("gradient-only", &theta_grad_only),
3786            ("gradient+Hessian", &theta_grad_hess),
3787        ] {
3788            let err = (theta - &theta_star)
3789                .mapv(f64::abs)
3790                .fold(0.0_f64, |a, &v| a.max(v));
3791            assert!(
3792                err <= 2e-3,
3793                "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3794            );
3795        }
3796
3797        // Cross-contract agreement: the three results must coincide, since the
3798        // only difference between the entry points is the derivative contract,
3799        // never the surface they minimize.
3800        let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3801            (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3802        };
3803        assert!(
3804            pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3805            "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3806        );
3807    }
3808
3809    #[test]
3810    fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
3811        let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3812        let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3813        let requested_degree = 3;
3814        let num_internal_knots = 1;
3815
3816        let built = build_survival_time_basis(
3817            &age_entry,
3818            &age_exit,
3819            SurvivalTimeBasisConfig::ISpline {
3820                degree: requested_degree,
3821                knots: Array1::zeros(0),
3822                keep_cols: Vec::new(),
3823                smooth_lambda: 1e-2,
3824            },
3825            Some((num_internal_knots, 1e-2)),
3826        )
3827        .expect("automatic cubic ispline with one interior knot builds");
3828
3829        let working_degree = requested_degree + 1;
3830        let knots = built.knots.expect("resolved ispline knots");
3831        assert_eq!(
3832            knots.len(),
3833            num_internal_knots + 2 * (working_degree + 1),
3834            "I-spline automatic knots must be clamped for the working B-spline degree"
3835        );
3836        assert_eq!(built.degree, Some(requested_degree));
3837        assert!(built.x_exit_time.ncols() > 0);
3838        assert_eq!(built.x_entry_time.ncols(), built.x_exit_time.ncols());
3839        assert_eq!(built.x_derivative_time.ncols(), built.x_exit_time.ncols());
3840    }
3841
3842    #[test]
3843    fn ispline_time_derivative_is_nonzero_at_right_boundary() {
3844        let age_entry = array![1.0_f64, 1.0, 1.0];
3845        let age_exit = array![4.0_f64, 4.0, 4.0];
3846        let left = 1.0_f64.ln();
3847        let right = 4.0_f64.ln();
3848        let mid = left + 0.5 * (right - left);
3849        let knots = array![left, left, left, left, mid, right, right, right, right];
3850
3851        let built = build_survival_time_basis(
3852            &age_entry,
3853            &age_exit,
3854            SurvivalTimeBasisConfig::ISpline {
3855                degree: 2,
3856                knots,
3857                keep_cols: Vec::new(),
3858                smooth_lambda: 1e-2,
3859            },
3860            None,
3861        )
3862        .expect("build right-boundary ispline time basis");
3863
3864        let derivative = built.x_derivative_time.as_dense_cow();
3865        let max_abs = derivative.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
3866        assert!(
3867            max_abs > 1e-8,
3868            "right-boundary I-spline derivative must use the left-hand endpoint slope"
3869        );
3870        for row in derivative.rows() {
3871            assert!(
3872                row.iter().any(|v| *v > 1e-8),
3873                "each row at the right boundary needs a positive hazard derivative"
3874            );
3875        }
3876    }
3877
3878    #[test]
3879    fn ispline_time_penalty_is_psd_under_nontrivial_keep_cols() {
3880        // PSD-invariant forward guard for the gam#979 survival hang. The I-spline
3881        // value-space curvature penalty on the increment coefficients is the
3882        // congruence `S_I = Lᵀ S_B[1:,1:] L`. When identifiability drops columns,
3883        // the retained block MUST be taken as a PRINCIPAL SUBMATRIX of the FULL
3884        // congruence (congruence first, column selection second). The historical
3885        // regression assembled the reduced penalty in the wrong order, producing
3886        // a strongly INDEFINITE matrix (measured `s0_min_eval = −9.8e7`); an
3887        // indefinite time penalty makes `½γᵀ S_I γ` unbounded below, the inner
3888        // joint-Newton follows the divergence, and the outer REML never
3889        // terminates — the survival marginal-slope hang.
3890        //
3891        // This test exercises the reduction with a NON-TRIVIAL `keep_cols`
3892        // (a proper subset, an interior column dropped) and asserts the assembled
3893        // penalty satisfies the PSD contract the fix guarantees. It locks the
3894        // invariant on the shipped code path so a future reassembly that
3895        // reintroduces an indefinite reduction is caught at construction rather
3896        // than silently as an outer-loop hang. (It is a forward invariant lock,
3897        // not a bit-exact replay of the removed buggy assembly.)
3898        let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3899        let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3900        let left = 1.0_f64.ln();
3901        let right = 21.0_f64.ln();
3902        let q1 = left + 0.25 * (right - left);
3903        let mid = left + 0.5 * (right - left);
3904        let q3 = left + 0.75 * (right - left);
3905        // Degree-2 I-spline with three interior knots -> a value-space basis wide
3906        // enough to drop an interior column and still leave the reduction
3907        // non-trivial (p_time < p_time_full).
3908        let knots = array![
3909            left, left, left, left, q1, mid, q3, right, right, right, right
3910        ];
3911
3912        // Discover the full basis width by building with all columns retained.
3913        let full = build_survival_time_basis(
3914            &age_entry,
3915            &age_exit,
3916            SurvivalTimeBasisConfig::ISpline {
3917                degree: 2,
3918                knots: knots.clone(),
3919                keep_cols: Vec::new(),
3920                smooth_lambda: 1e-2,
3921            },
3922            None,
3923        )
3924        .expect("build full-width ispline time basis");
3925        let p_time_full = full
3926            .keep_cols
3927            .as_ref()
3928            .map(|k| k.len())
3929            .unwrap_or_else(|| full.x_exit_time.ncols());
3930        assert!(
3931            p_time_full >= 3,
3932            "test needs at least 3 shape-varying columns to drop an interior one; got {p_time_full}"
3933        );
3934
3935        // Retain everything except one interior column, forcing the
3936        // principal-submatrix-of-the-full-congruence path.
3937        let keep_cols: Vec<usize> = (0..p_time_full).filter(|&j| j != 1).collect();
3938
3939        let built = build_survival_time_basis(
3940            &age_entry,
3941            &age_exit,
3942            SurvivalTimeBasisConfig::ISpline {
3943                degree: 2,
3944                knots,
3945                keep_cols: keep_cols.clone(),
3946                smooth_lambda: 1e-2,
3947            },
3948            None,
3949        )
3950        .expect(
3951            "reduced ispline penalty must build (PSD contract must accept the \
3952             congruence-first / select-second ordering)",
3953        );
3954
3955        assert_eq!(
3956            built.penalties.len(),
3957            1,
3958            "the ispline time basis should carry exactly one curvature penalty"
3959        );
3960        let s = &built.penalties[0];
3961        assert_eq!(s.nrows(), keep_cols.len());
3962        assert_eq!(s.ncols(), keep_cols.len());
3963
3964        let (evals, _) =
3965            gam_linalg::faer_ndarray::FaerEigh::eigh(s, faer::Side::Lower).expect("eigh of penalty");
3966        let evals_slice = evals.as_slice().expect("contiguous eigenvalues");
3967        let max_abs = evals_slice
3968            .iter()
3969            .copied()
3970            .fold(0.0_f64, |a, b| a.max(b.abs()))
3971            .max(1.0);
3972        let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
3973        let tol = -100.0 * (s.nrows() as f64) * f64::EPSILON * max_abs;
3974        assert!(
3975            min_ev >= tol,
3976            "reduced I-spline time penalty must be PSD (gam#979): min eigenvalue \
3977             {min_ev:.3e} < tol {tol:.3e}, max|eig| {max_abs:.3e}"
3978        );
3979    }
3980
3981    #[test]
3982    fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
3983        let cfg = SurvivalBaselineConfig {
3984            target: SurvivalBaselineTarget::GompertzMakeham,
3985            scale: None,
3986            shape: Some(0.07),
3987            rate: Some(0.012),
3988            makeham: Some(0.003),
3989        };
3990        let age = 11.5;
3991        let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
3992            .expect("evaluate marginal-slope gompertz-makeham baseline");
3993        let shape = cfg.shape.expect("shape");
3994        let rate = cfg.rate.expect("rate");
3995        let makeham = cfg.makeham.expect("makeham");
3996        let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
3997        let instant_hazard = makeham + rate * (shape * age).exp();
3998        let expected_survival = (-cumulative_hazard).exp();
3999        let actual_survival = normal_cdf(-q);
4000        assert!((actual_survival - expected_survival).abs() <= 1e-12);
4001
4002        let h = 1e-5;
4003        let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
4004            .expect("q plus")
4005            .0;
4006        let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
4007            .expect("q minus")
4008            .0;
4009        let fd = (q_plus - q_minus) / (2.0 * h);
4010        assert!((q_derivative - fd).abs() <= 1e-7);
4011        assert!(instant_hazard > 0.0);
4012    }
4013
4014    #[test]
4015    fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
4016        // Regression for #1024: the probit/marginal-slope baseline evaluator must
4017        // be defined at the survival-curve origin t = 0 (where S0(0) = 1, so the
4018        // probit index q(0) = -Phi^{-1}(1) = -inf and there is no finite offset),
4019        // exactly like its log-cumulative-hazard sibling `evaluate_survival_baseline`.
4020        // Before the fix the shared `age <= 0` hazard guard aborted, so a survival
4021        // prediction grid whose first node is the origin (the `Surv(time, event)`
4022        // right-censored shorthand) could not be evaluated for the location-scale /
4023        // marginal-slope likelihoods.
4024        let configs = [
4025            SurvivalBaselineConfig {
4026                target: SurvivalBaselineTarget::Linear,
4027                scale: None,
4028                shape: None,
4029                rate: None,
4030                makeham: None,
4031            },
4032            SurvivalBaselineConfig {
4033                target: SurvivalBaselineTarget::Weibull,
4034                scale: Some(2.5),
4035                shape: Some(1.3),
4036                rate: None,
4037                makeham: None,
4038            },
4039            SurvivalBaselineConfig {
4040                target: SurvivalBaselineTarget::Gompertz,
4041                scale: None,
4042                shape: Some(0.05),
4043                rate: Some(0.01),
4044                makeham: None,
4045            },
4046            SurvivalBaselineConfig {
4047                target: SurvivalBaselineTarget::GompertzMakeham,
4048                scale: None,
4049                shape: Some(0.07),
4050                rate: Some(0.012),
4051                makeham: Some(0.003),
4052            },
4053        ];
4054        for cfg in &configs {
4055            // The probit baseline returns a finite zero offset at the origin for
4056            // every target (the survival surface anchors S(0) = 1 directly).
4057            let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
4058                .expect("marginal-slope baseline must be evaluable at the origin");
4059            assert_eq!(q0, 0.0);
4060            assert_eq!(q0_derivative, 0.0);
4061
4062            // The log-cumulative-hazard sibling is likewise finite at the origin —
4063            // this parity is the whole point (the transformation likelihood already
4064            // worked because it rides this evaluator).
4065            let (eta0, eta0_derivative) =
4066                evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
4067            assert!(eta0_derivative.is_finite());
4068            assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);
4069
4070            // The batched offset builder must not abort when a query exit age is the
4071            // origin (this is the exact call the location-scale predict path makes on
4072            // the default surface grid). Entry stays at the origin, exit spans 0 -> t.
4073            let age_entry = array![0.0, 0.0];
4074            let age_exit = array![0.0, 1.5];
4075            let (entry, exit, derivative) =
4076                build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
4077                    .expect("probit baseline offsets must build through the origin");
4078            assert!(entry.iter().all(|v| v.is_finite()));
4079            assert!(exit.iter().all(|v| v.is_finite()));
4080            assert!(derivative.iter().all(|v| v.is_finite()));
4081            // The origin exit column carries no probit offset.
4082            assert_eq!(exit[0], 0.0);
4083        }
4084    }
4085
4086    #[test]
4087    fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
4088        let cfg = SurvivalBaselineConfig {
4089            target: SurvivalBaselineTarget::GompertzMakeham,
4090            scale: None,
4091            shape: Some(0.03),
4092            rate: Some(0.01),
4093            makeham: Some(0.002),
4094        };
4095        let age_entry = array![2.0, 4.0];
4096        let age_exit = array![5.0, 9.0];
4097        let (entry, exit, derivative) =
4098            build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
4099                .expect("marginal-slope baseline offsets");
4100        for i in 0..age_entry.len() {
4101            let entry_h = cfg.makeham.expect("makeham") * age_entry[i]
4102                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4103                    * ((cfg.shape.expect("shape") * age_entry[i]).exp() - 1.0);
4104            let exit_h = cfg.makeham.expect("makeham") * age_exit[i]
4105                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
4106                    * ((cfg.shape.expect("shape") * age_exit[i]).exp() - 1.0);
4107            assert!((normal_cdf(-entry[i]) - (-entry_h).exp()).abs() <= 1e-12);
4108            assert!((normal_cdf(-exit[i]) - (-exit_h).exp()).abs() <= 1e-12);
4109            assert!(derivative[i].is_finite() && derivative[i] > 0.0);
4110        }
4111    }
4112
4113    fn fd_marginal_slope_baseline_offset(
4114        age: f64,
4115        cfg: &SurvivalBaselineConfig,
4116        steps: &[f64],
4117    ) -> Vec<(f64, f64)> {
4118        let theta = survival_baseline_theta_from_config(cfg)
4119            .expect("theta")
4120            .expect("non-linear baseline");
4121        assert_eq!(
4122            steps.len(),
4123            theta.len(),
4124            "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
4125        );
4126        (0..theta.len())
4127            .map(|k| {
4128                let h = steps[k];
4129                let mut theta_plus = theta.clone();
4130                theta_plus[k] += h;
4131                let mut theta_minus = theta.clone();
4132                theta_minus[k] -= h;
4133                let cfg_plus =
4134                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4135                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4136                    .expect("minus cfg");
4137                let (q_p, qt_p) =
4138                    evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
4139                let (q_m, qt_m) =
4140                    evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
4141                ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
4142            })
4143            .collect()
4144    }
4145
4146    #[test]
4147    fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
4148        let cfg = SurvivalBaselineConfig {
4149            target: SurvivalBaselineTarget::GompertzMakeham,
4150            scale: None,
4151            shape: Some(0.04),
4152            rate: Some(0.013),
4153            makeham: Some(0.002),
4154        };
4155        let age = 17.0;
4156        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4157            .expect("partials")
4158            .expect("nonlinear");
4159        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-5, 1e-5]);
4160        assert_eq!(analytic.len(), fd.len());
4161        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4162            assert_close(*aq, *fq, 1e-6, &format!("gm-probit q theta[{k}]"));
4163            assert_close(*aqt, *fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
4164        }
4165    }
4166
4167    #[test]
4168    fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
4169        let cfg = SurvivalBaselineConfig {
4170            target: SurvivalBaselineTarget::GompertzMakeham,
4171            scale: None,
4172            shape: Some(1e-14),
4173            rate: Some(0.013),
4174            makeham: Some(0.002),
4175        };
4176        let age = 17.0;
4177        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4178            .expect("partials")
4179            .expect("nonlinear");
4180        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-11, 1e-5]);
4181        assert_eq!(analytic.len(), fd.len());
4182        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4183            assert_close(*aq, *fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
4184            assert_close(
4185                *aqt,
4186                *fqt,
4187                1e-5,
4188                &format!("near-zero gm-probit q' theta[{k}]"),
4189            );
4190        }
4191    }
4192
4193    fn shifted_quadratic_offset_residuals(
4194        age_entry: ndarray::ArrayView1<'_, f64>,
4195        age_exit: ndarray::ArrayView1<'_, f64>,
4196        base_cfg: &SurvivalBaselineConfig,
4197        candidate_cfg: &SurvivalBaselineConfig,
4198        base: &OffsetChannelResiduals,
4199        curvatures: &OffsetChannelCurvatures,
4200    ) -> OffsetChannelResiduals {
4201        let n = age_exit.len();
4202        let mut entry = base.entry.clone();
4203        let mut exit = base.exit.clone();
4204        let mut derivative = base.derivative.clone();
4205        for row in 0..n {
4206            let (_, base_exit, base_deriv) =
4207                baseline_marginal_slope_channels(age_exit[row], base_cfg);
4208            let (_, cand_exit, cand_deriv) =
4209                baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
4210            let base_entry = if base.entry[row] == 0.0 {
4211                0.0
4212            } else {
4213                baseline_marginal_slope_channels(age_entry[row], base_cfg).1
4214            };
4215            let cand_entry = if base.entry[row] == 0.0 {
4216                0.0
4217            } else {
4218                baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
4219            };
4220            let delta = [
4221                cand_entry - base_entry,
4222                cand_exit - base_exit,
4223                cand_deriv - base_deriv,
4224            ];
4225            let mut shift = [0.0; 3];
4226            for i in 0..3 {
4227                for j in 0..3 {
4228                    shift[i] += curvatures.rows[row][i][j] * delta[j];
4229                }
4230            }
4231            if base.entry[row] != 0.0 {
4232                entry[row] += shift[0];
4233            }
4234            exit[row] += shift[1];
4235            derivative[row] += shift[2];
4236        }
4237        OffsetChannelResiduals {
4238            entry,
4239            exit,
4240            derivative,
4241            right: base.right.clone(),
4242        }
4243    }
4244
4245    fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4246        let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4247        (q, q, q_t)
4248    }
4249
4250    #[test]
4251    fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
4252        let cfg = SurvivalBaselineConfig {
4253            target: SurvivalBaselineTarget::GompertzMakeham,
4254            scale: None,
4255            shape: Some(0.025),
4256            rate: Some(0.012),
4257            makeham: Some(0.003),
4258        };
4259        let theta = survival_baseline_theta_from_config(&cfg)
4260            .expect("theta")
4261            .expect("nonlinear");
4262        let age_entry = array![2.5, 0.0, 5.0];
4263        let age_exit = array![7.5, 11.0, 15.0];
4264        let base_residuals = OffsetChannelResiduals {
4265            entry: array![0.2, 0.0, -0.1],
4266            exit: array![0.6, -0.3, 0.4],
4267            derivative: array![-0.5, 0.25, 0.15],
4268            right: Array1::<f64>::zeros(3),
4269        };
4270        let curvatures = OffsetChannelCurvatures {
4271            rows: vec![
4272                [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
4273                [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
4274                [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
4275            ],
4276        };
4277        let analytic = marginal_slope_baseline_chain_rule_hessian(
4278            age_entry.view(),
4279            age_exit.view(),
4280            &cfg,
4281            &base_residuals,
4282            &curvatures,
4283        )
4284        .expect("hessian")
4285        .expect("nonlinear");
4286
4287        let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
4288            let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
4289                .expect("candidate cfg");
4290            let residuals = shifted_quadratic_offset_residuals(
4291                age_entry.view(),
4292                age_exit.view(),
4293                &cfg,
4294                &candidate,
4295                &base_residuals,
4296                &curvatures,
4297            );
4298            marginal_slope_baseline_chain_rule_gradient(
4299                age_entry.view(),
4300                age_exit.view(),
4301                &candidate,
4302                &residuals,
4303            )
4304            .expect("gradient")
4305            .expect("nonlinear")
4306        };
4307
4308        for j in 0..theta.len() {
4309            let step = if j == 1 { 2e-5 } else { 1e-5 };
4310            let mut plus = theta.clone();
4311            plus[j] += step;
4312            let mut minus = theta.clone();
4313            minus[j] -= step;
4314            let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
4315            for i in 0..theta.len() {
4316                assert_close(
4317                    analytic[[i, j]],
4318                    fd_col[i],
4319                    2e-5,
4320                    &format!("baseline Hessian ({i},{j})"),
4321                );
4322            }
4323        }
4324    }
4325
4326    #[test]
4327    fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
4328        let cfg = SurvivalBaselineConfig {
4329            target: SurvivalBaselineTarget::GompertzMakeham,
4330            scale: None,
4331            shape: Some(0.03),
4332            rate: Some(0.01),
4333            makeham: Some(0.002),
4334        };
4335        let age_entry = array![3.0, 6.0];
4336        let age_exit = array![8.0, 12.0];
4337        let residuals = OffsetChannelResiduals {
4338            exit: array![0.7, -0.2],
4339            entry: array![0.1, 0.4],
4340            derivative: array![1.3, -0.6],
4341            right: Array1::<f64>::zeros(2),
4342        };
4343        let grad = marginal_slope_baseline_chain_rule_gradient(
4344            age_entry.view(),
4345            age_exit.view(),
4346            &cfg,
4347            &residuals,
4348        )
4349        .expect("gradient")
4350        .expect("nonlinear");
4351
4352        let mut expected = Array1::<f64>::zeros(3);
4353        for i in 0..age_exit.len() {
4354            let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
4355                .expect("exit partials")
4356                .expect("nonlinear");
4357            let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
4358                .expect("entry partials")
4359                .expect("nonlinear");
4360            for k in 0..3 {
4361                expected[k] += residuals.exit[i] * exit_partials[k].0
4362                    + residuals.derivative[i] * exit_partials[k].1
4363                    + residuals.entry[i] * entry_partials[k].0;
4364            }
4365        }
4366        for k in 0..3 {
4367            assert_close(
4368                grad[k],
4369                expected[k],
4370                1e-12,
4371                &format!("gm-probit chain gradient theta[{k}]"),
4372            );
4373        }
4374    }
4375
4376    /// Parity guard for the shared `baseline_chain_rule_gradient_with_partials`
4377    /// engine (issue #429): both public gradient functions delegate to it with a
4378    /// different partials provider. This test reimplements the pre-unification
4379    /// inline contraction (the serial reference) and asserts bit-for-bit equality
4380    /// against the unified engine's output for BOTH providers on the same data —
4381    /// the RP-eta provider (`baseline_offset_theta_partials`) and the probit-q
4382    /// provider (`marginal_slope_baseline_offset_theta_partials`). Any drift in
4383    /// the extracted contraction (length checks, theta-dim probe, exit/derivative
4384    /// combination, or entry gating) breaks this with an exact (0.0) tolerance.
4385    #[test]
4386    fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
4387        let cfg = SurvivalBaselineConfig {
4388            target: SurvivalBaselineTarget::GompertzMakeham,
4389            scale: None,
4390            shape: Some(0.028),
4391            rate: Some(0.011),
4392            makeham: Some(0.0025),
4393        };
4394        // Mixed entry interval: row 1 is origin-entry (age_entry==0, r_entry==0)
4395        // to exercise the entry-gating branch in the shared engine.
4396        let age_entry = array![3.0, 0.0, 5.5];
4397        let age_exit = array![8.0, 12.0, 16.0];
4398        let residuals = OffsetChannelResiduals {
4399            exit: array![0.7, -0.2, 0.45],
4400            entry: array![0.1, 0.0, -0.3],
4401            derivative: array![1.3, -0.6, 0.2],
4402            right: Array1::<f64>::zeros(3),
4403        };
4404
4405        // Serial reference contraction matching the original inline body. Mirrors
4406        // the engine's exit+derivative/entry split and origin-entry gating.
4407        let reference_gradient = |partials: &dyn Fn(
4408            f64,
4409            &SurvivalBaselineConfig,
4410        )
4411            -> Result<Option<Vec<(f64, f64)>>, String>|
4412         -> Array1<f64> {
4413            let theta_dim = partials(age_exit[0], &cfg)
4414                .expect("probe partials")
4415                .expect("nonlinear")
4416                .len();
4417            let mut acc = Array1::<f64>::zeros(theta_dim);
4418            for i in 0..age_exit.len() {
4419                let p_exit = partials(age_exit[i], &cfg)
4420                    .expect("exit partials")
4421                    .expect("nonlinear");
4422                let r_x = residuals.exit[i];
4423                let r_d = residuals.derivative[i];
4424                for k in 0..theta_dim {
4425                    acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
4426                }
4427                let r_e = residuals.entry[i];
4428                if r_e != 0.0 {
4429                    let p_entry = partials(age_entry[i], &cfg)
4430                        .expect("entry partials")
4431                        .expect("nonlinear");
4432                    for k in 0..theta_dim {
4433                        acc[k] += r_e * p_entry[k].0;
4434                    }
4435                }
4436            }
4437            acc
4438        };
4439
4440        // RP-eta provider parity.
4441        let rp_engine = baseline_chain_rule_gradient(
4442            age_entry.view(),
4443            age_exit.view(),
4444            age_exit.view(),
4445            &cfg,
4446            &residuals,
4447        )
4448        .expect("rp gradient")
4449        .expect("rp nonlinear");
4450        let rp_reference = reference_gradient(&baseline_offset_theta_partials);
4451        assert_eq!(rp_engine.len(), rp_reference.len());
4452        for k in 0..rp_engine.len() {
4453            assert_close(
4454                rp_engine[k],
4455                rp_reference[k],
4456                0.0,
4457                &format!("rp engine vs inline reference theta[{k}]"),
4458            );
4459        }
4460
4461        // Probit-q provider parity.
4462        let probit_engine = marginal_slope_baseline_chain_rule_gradient(
4463            age_entry.view(),
4464            age_exit.view(),
4465            &cfg,
4466            &residuals,
4467        )
4468        .expect("probit gradient")
4469        .expect("probit nonlinear");
4470        let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
4471        assert_eq!(probit_engine.len(), probit_reference.len());
4472        for k in 0..probit_engine.len() {
4473            assert_close(
4474                probit_engine[k],
4475                probit_reference[k],
4476                0.0,
4477                &format!("probit engine vs inline reference theta[{k}]"),
4478            );
4479        }
4480    }
4481
4482    /// Finite-difference verification of the analytic θ-gradient used by the
4483    /// survival location-scale workflow path.
4484    ///
4485    /// At a converged β, the envelope theorem reduces the profile-NLL gradient
4486    /// w.r.t. the baseline-config θ to a per-row residual contraction against
4487    /// the per-row offset-channel partials ∂o/∂θ:
4488    ///
4489    ///   d(NLL)/dθ_k = Σ_i [ r_X[i]·∂η_exit/∂θ_k + r_E[i]·∂η_entry/∂θ_k
4490    ///                       + r_D[i]·∂o_D_exit/∂θ_k ]
4491    ///
4492    /// (`baseline_chain_rule_gradient`). Because β is fixed, an explicit loss
4493    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i; θ) + r_E[i]·η(t_entry_i; θ)
4494    ///              + r_D[i]·o_D(t_exit_i; θ) ]`
4495    /// has gradient identically equal to the chain-rule output. Comparing the
4496    /// analytic gradient to a central-difference of L over `evaluate_survival_baseline`
4497    /// therefore exercises every piece of the chain rule (incl. the Gompertz
4498    /// rate / shape / Makeham partials at both entry and exit ages) without
4499    /// needing the full location-scale fit pipeline inside this unit-test
4500    /// module. If the chain rule disagrees with FD here, the workflow's
4501    /// gradient is wrong by exactly the same amount.
4502    #[test]
4503    fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4504        let cfg = SurvivalBaselineConfig {
4505            target: SurvivalBaselineTarget::GompertzMakeham,
4506            scale: None,
4507            shape: Some(0.05),
4508            rate: Some(0.012),
4509            makeham: Some(0.003),
4510        };
4511        // n = 8 small synthetic dataset spanning a realistic age range.
4512        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4513        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4514        // Synthetic per-row NLL residuals on the three offset channels. Mix of
4515        // signs / magnitudes / one zero-entry row (origin entry → r_E=0).
4516        let residuals = OffsetChannelResiduals {
4517            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4518            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4519            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4520            right: Array1::<f64>::zeros(8),
4521        };
4522
4523        let analytic = baseline_chain_rule_gradient(
4524            age_entry.view(),
4525            age_exit.view(),
4526            age_exit.view(),
4527            &cfg,
4528            &residuals,
4529        )
4530        .expect("analytic gradient ok")
4531        .expect("GM baseline has a θ-gradient");
4532        assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4533
4534        // Evaluate the offset-projected loss at a perturbed θ. Mirrors the
4535        // chain rule's algebra: the entry channel is only added for rows whose
4536        // r_E is nonzero (matching baseline_chain_rule_gradient's gating that
4537        // avoids calling evaluate_survival_baseline at age 0 for origin-entry
4538        // rows).
4539        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4540            let mut acc = 0.0;
4541            for i in 0..age_exit.len() {
4542                let (eta_exit_i, od_exit_i) =
4543                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4544                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4545                if residuals.entry[i] != 0.0 {
4546                    let (eta_entry_i, _) =
4547                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4548                    acc += residuals.entry[i] * eta_entry_i;
4549                }
4550            }
4551            acc
4552        };
4553
4554        let theta0 = survival_baseline_theta_from_config(&cfg)
4555            .expect("theta seed")
4556            .expect("GM has θ");
4557        // Spec requested δ = 1e-4 per axis. Use central differences over θ.
4558        let delta = 1e-4;
4559        let mut fd = Array1::<f64>::zeros(analytic.len());
4560        for k in 0..analytic.len() {
4561            let mut theta_plus = theta0.clone();
4562            theta_plus[k] += delta;
4563            let mut theta_minus = theta0.clone();
4564            theta_minus[k] -= delta;
4565            let cfg_plus =
4566                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4567            let cfg_minus =
4568                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4569            let lp = loss_at_cfg(&cfg_plus);
4570            let lm = loss_at_cfg(&cfg_minus);
4571            fd[k] = (lp - lm) / (2.0 * delta);
4572        }
4573
4574        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4575        let max_err = analytic
4576            .iter()
4577            .zip(fd.iter())
4578            .map(|(a, b)| (a - b).abs())
4579            .fold(0.0_f64, f64::max);
4580        let rel = max_err / (analytic_norm + 1e-12);
4581        // Print so the deliverable can quote the exact max-error number.
4582        eprintln!(
4583            "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4584             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4585             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4586        );
4587        assert!(
4588            rel < 1e-2,
4589            "analytic θ-gradient disagrees with central FD beyond 1%: \
4590             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4591             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4592        );
4593    }
4594
4595    /// Weibull (dim=2) companion to
4596    /// `gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference`.
4597    ///
4598    /// This is the FD gate for the analytic outer θ-gradient that the
4599    /// transformation/Weibull survival baseline optimizers now feed to BFGS
4600    /// (`optimize_survival_baseline_config_with_gradient_only`). At a *fixed* β
4601    /// the profile-NLL surface is
4602    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i;θ) + r_E[i]·η(t_entry_i;θ)
4603    ///              + r_D[i]·o_D(t_exit_i;θ) ]`,
4604    /// whose exact gradient is `baseline_chain_rule_gradient`. Comparing it to a
4605    /// central difference of `L` over `evaluate_survival_baseline` exercises the
4606    /// Weibull scale/shape partials at both entry and exit ages. If this
4607    /// disagrees with FD, the workflow's outer gradient is wrong by the same
4608    /// amount.
4609    #[test]
4610    fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4611        let cfg = SurvivalBaselineConfig {
4612            target: SurvivalBaselineTarget::Weibull,
4613            scale: Some(11.0),
4614            shape: Some(1.4),
4615            rate: None,
4616            makeham: None,
4617        };
4618        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4619        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4620        let residuals = OffsetChannelResiduals {
4621            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4622            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4623            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4624            right: Array1::<f64>::zeros(8),
4625        };
4626
4627        let analytic = baseline_chain_rule_gradient(
4628            age_entry.view(),
4629            age_exit.view(),
4630            age_exit.view(),
4631            &cfg,
4632            &residuals,
4633        )
4634        .expect("analytic gradient ok")
4635        .expect("Weibull baseline has a θ-gradient");
4636        assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4637
4638        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4639            let mut acc = 0.0;
4640            for i in 0..age_exit.len() {
4641                let (eta_exit_i, od_exit_i) =
4642                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4643                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4644                if residuals.entry[i] != 0.0 {
4645                    let (eta_entry_i, _) =
4646                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4647                    acc += residuals.entry[i] * eta_entry_i;
4648                }
4649            }
4650            acc
4651        };
4652
4653        let theta0 = survival_baseline_theta_from_config(&cfg)
4654            .expect("theta seed")
4655            .expect("Weibull has θ");
4656        let delta = 1e-4;
4657        let mut fd = Array1::<f64>::zeros(analytic.len());
4658        for k in 0..analytic.len() {
4659            let mut theta_plus = theta0.clone();
4660            theta_plus[k] += delta;
4661            let mut theta_minus = theta0.clone();
4662            theta_minus[k] -= delta;
4663            let cfg_plus =
4664                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4665            let cfg_minus =
4666                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4667            let lp = loss_at_cfg(&cfg_plus);
4668            let lm = loss_at_cfg(&cfg_minus);
4669            fd[k] = (lp - lm) / (2.0 * delta);
4670        }
4671
4672        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4673        let max_err = analytic
4674            .iter()
4675            .zip(fd.iter())
4676            .map(|(a, b)| (a - b).abs())
4677            .fold(0.0_f64, f64::max);
4678        let rel = max_err / (analytic_norm + 1e-12);
4679        eprintln!(
4680            "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4681             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4682             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4683        );
4684        assert!(
4685            rel < 1e-2,
4686            "analytic θ-gradient disagrees with central FD beyond 1%: \
4687             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4688             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4689        );
4690    }
4691
4692    // ─── baseline_offset_theta_partials — analytic vs central-difference ─
4693
4694    /// Central-difference of (eta, o_D) at fixed age wrt each θ component in
4695    /// the theta layout defined by `survival_baseline_theta_from_config`.
4696    ///
4697    /// `steps` is per-θ-component: the caller picks the step size appropriate
4698    /// for each channel. Gompertz / Gompertz–Makeham need a tiny step on the
4699    /// shape channel near the Taylor pivot |shape| < 1e-10 (so θ±h stays on
4700    /// the same branch), but a normal-scale step on log_rate / log_makeham;
4701    /// using the tiny shape-step on every channel corrupts the log_rate
4702    /// channel with `eps/(2h)` cancellation noise and has nothing to do with
4703    /// correctness of the analytic derivative.
4704    fn fd_baseline_offset(
4705        age: f64,
4706        cfg: &SurvivalBaselineConfig,
4707        steps: &[f64],
4708    ) -> Vec<(f64, f64)> {
4709        let theta = survival_baseline_theta_from_config(cfg)
4710            .expect("theta")
4711            .expect("non-linear baseline");
4712        assert_eq!(
4713            steps.len(),
4714            theta.len(),
4715            "fd_baseline_offset: step vector length must match θ dimension"
4716        );
4717        (0..theta.len())
4718            .map(|k| {
4719                let h = steps[k];
4720                let mut theta_plus = theta.clone();
4721                theta_plus[k] += h;
4722                let mut theta_minus = theta.clone();
4723                theta_minus[k] -= h;
4724                let cfg_plus =
4725                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4726                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4727                    .expect("minus cfg");
4728                let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4729                let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4730                ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4731            })
4732            .collect()
4733    }
4734
4735    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
4736        // `<=` so that bit-equal values satisfy tol = 0. With `<`, |a−e| < 0
4737        // is unsatisfiable and a zero-tolerance "must match exactly" call
4738        // would reject identical numbers.
4739        let ok = if expected.abs() < 1.0 {
4740            (actual - expected).abs() <= tol
4741        } else {
4742            (actual - expected).abs() <= tol * expected.abs().max(1.0)
4743        };
4744        assert!(
4745            ok,
4746            "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
4747        );
4748    }
4749
4750    #[test]
4751    fn gompertz_offset_partials_match_central_diff() {
4752        // Several (rate, shape, age) combinations spanning the small-shape
4753        // Taylor branch (|shape| < 1e-10) and the normal branch
4754        // (shape >> 1e-10), plus sign-reversed shape.
4755        let cases = [
4756            (0.5_f64, 0.01_f64, 30.0_f64),
4757            (0.2, 0.05, 60.0),
4758            (1.0, 0.001, 10.0),
4759            (0.4, 5e-11, 25.0),
4760            (0.4, -5e-11, 25.0),
4761            (0.3, -0.02, 40.0),
4762            (0.8, 0.2, 5.0),
4763        ];
4764        for &(rate, shape, age) in &cases {
4765            let cfg = SurvivalBaselineConfig {
4766                target: SurvivalBaselineTarget::Gompertz,
4767                scale: None,
4768                shape: Some(shape),
4769                rate: Some(rate),
4770                makeham: None,
4771            };
4772            let analytic = baseline_offset_theta_partials(age, &cfg)
4773                .expect("ok")
4774                .expect("non-linear");
4775            // Keep the FD probe inside the Taylor branch for tiny |shape| so
4776            // the numeric derivative matches the same small-shape map as the
4777            // analytic helper. log_rate always uses the normal step — rate
4778            // is a moderate-scale parameter and a 1e-11 step would swamp the
4779            // FD with cancellation noise.
4780            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4781            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
4782            assert_eq!(analytic.len(), 2);
4783            // Gompertz θ=(log_rate, shape). Rate channel: ∂eta/∂log_rate=1, ∂o_D/∂log_rate=0.
4784            assert_close(
4785                analytic[0].0,
4786                fd[0].0,
4787                1e-7,
4788                &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4789            );
4790            assert_close(
4791                analytic[0].1,
4792                fd[0].1,
4793                1e-7,
4794                &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4795            );
4796            // shape channel — larger tol because finite-differencing near
4797            // shape=0 amplifies rounding; 1e-5 is fine.
4798            assert_close(
4799                analytic[1].0,
4800                fd[1].0,
4801                1e-5,
4802                &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
4803            );
4804            assert_close(
4805                analytic[1].1,
4806                fd[1].1,
4807                1e-5,
4808                &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
4809            );
4810        }
4811    }
4812
4813    #[test]
4814    fn gompertz_offset_partials_log_rate_channel_is_trivial() {
4815        // Pure Gompertz: rate cancels in o_D, so ∂o_D/∂log_rate must be
4816        // exactly 0 and ∂eta/∂log_rate must be exactly 1. Verify the
4817        // analytic implementation returns the exact values, not FD-close.
4818        let cfg = SurvivalBaselineConfig {
4819            target: SurvivalBaselineTarget::Gompertz,
4820            scale: None,
4821            shape: Some(0.05),
4822            rate: Some(0.3),
4823            makeham: None,
4824        };
4825        let partials = baseline_offset_theta_partials(42.0, &cfg)
4826            .expect("ok")
4827            .expect("non-linear");
4828        assert_eq!(partials[0].0, 1.0);
4829        assert_eq!(partials[0].1, 0.0);
4830    }
4831
4832    #[test]
4833    fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
4834        // Both branches of gompertz_shape_derivatives should agree to high
4835        // precision at shape = 1e-10 + epsilon on the direct side vs
4836        // shape = 1e-10 - epsilon on the Taylor side. Here we spot-check
4837        // the continuity at the branch cutoff: shape slightly above and
4838        // slightly below 1e-10 must give values within O(shape²·t²)
4839        // (the Taylor truncation error).
4840        let age = 25.0;
4841        let rate = 0.4;
4842        let cfg_taylor = SurvivalBaselineConfig {
4843            target: SurvivalBaselineTarget::Gompertz,
4844            scale: None,
4845            shape: Some(0.5e-10),
4846            rate: Some(rate),
4847            makeham: None,
4848        };
4849        let cfg_direct = SurvivalBaselineConfig {
4850            target: SurvivalBaselineTarget::Gompertz,
4851            scale: None,
4852            shape: Some(2.0e-10),
4853            rate: Some(rate),
4854            makeham: None,
4855        };
4856        let p_t = baseline_offset_theta_partials(age, &cfg_taylor)
4857            .expect("ok")
4858            .expect("nl");
4859        let p_d = baseline_offset_theta_partials(age, &cfg_direct)
4860            .expect("ok")
4861            .expect("nl");
4862        // ∂eta/∂shape at shape≈0 should be t/2 = 12.5 on both sides.
4863        assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
4864        assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
4865        // ∂o_D/∂shape at shape≈0 should be 1/2.
4866        assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
4867        assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
4868    }
4869
    // ----------------------------------------------------------------------
    // Gompertz hazard-channel shape derivatives: FD oracle + Taylor-branch
    // continuity. These feed `survival_hazard_theta_partials` /
    // `survival_hazard_theta_first_second` (the marginal-slope probit
    // baseline). Before these tests, the only coverage of
    // `gompertz_cumulative_shape_{,second_}derivative` was the indirect
    // marginal-slope Hessian FD at shape=0.025, which neither touches the
    // small-shape (`|shape| < 1e-10`) Taylor branch nor directly FD-checks
    // these analytic shape derivatives.
    // ----------------------------------------------------------------------
4880
4881    #[test]
4882    fn gompertz_hazard_shape_derivatives_match_central_diff() {
4883        // shape stays well above the 1e-10 Taylor cutoff so the exact
4884        // closed-form branch is exercised and the expm1/exp arithmetic is
4885        // numerically clean. FD on the analytic value/first-derivative
4886        // confirms the first and second shape derivatives.
4887        let cases = [
4888            (10.0_f64, 0.012_f64, 0.05_f64),
4889            (2.5, 0.5, 0.2),
4890            (15.0, 0.003, 0.01),
4891            (40.0, 0.3, 0.001),
4892        ];
4893        let h = 1e-6;
4894        for &(age, rate, shape) in &cases {
4895            // First shape derivative of (H_G, h_G) vs central diff of value.
4896            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4897            let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
4898            let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
4899            assert_close(
4900                d_cum,
4901                (cum_p - cum_m) / (2.0 * h),
4902                1e-6,
4903                &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4904            );
4905            assert_close(
4906                d_inst,
4907                (inst_p - inst_m) / (2.0 * h),
4908                1e-6,
4909                &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4910            );
4911
4912            // Second shape derivative vs central diff of the first derivative.
4913            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4914            let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
4915            let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
4916            assert_close(
4917                d2_cum,
4918                (dcum_p - dcum_m) / (2.0 * h),
4919                1e-5,
4920                &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4921            );
4922            assert_close(
4923                d2_inst,
4924                (dinst_p - dinst_m) / (2.0 * h),
4925                1e-5,
4926                &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4927            );
4928        }
4929    }
4930
4931    #[test]
4932    fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4933        // At small x = shape·age the shape derivatives collapse to closed-form
4934        // limits. These MUST hold even for large ages with tiny shapes, which
4935        // is precisely the regime where the (cancelling) exact branch loses all
4936        // precision and the x-based pivot routes to the Taylor branch.
4937        //   ∂H_G/∂shape   -> rate·t²/2
4938        //   ∂h_G/∂shape   -> rate·t
4939        //   ∂²H_G/∂shape² -> rate·t³/3
4940        //   ∂²h_G/∂shape² -> rate·t²
4941        // The bug this guards: the second derivative's old `shape < 1e-10`
4942        // pivot ignored `age`, so e.g. (age=100, shape=1e-5 -> x=1e-3) took the
4943        // cancelling exact branch and returned a wildly wrong curvature.
4944        let cases = [
4945            (25.0_f64, 0.4_f64, 1e-9_f64),
4946            (100.0, 0.4, 1e-6),   // x = 1e-4
4947            (100.0, 0.012, 1e-6), // x = 1e-4, the old-pivot band (large age, tiny shape)
4948            (50.0, 1.2, 1e-8),
4949        ];
4950        // NOTE: every quantity below is compared against its shape->0 *limit*.
4951        // For the cancelling cumulative branches (∂H/∂shape, ∂²H/∂shape²,
4952        // ∂²h/∂shape²) the limit is the correct shape->0 target and the
4953        // implementation routes through Taylor in this band. But the
4954        // instantaneous first derivative ∂h_G/∂shape = rate·age·e^x carries NO
4955        // cancellation: it is exact, and its departure from the limit rate·t is
4956        // a genuine O(x) effect. At x=1e-3 that departure is ~1.2e-3 (> tol),
4957        // so the cases here keep x <= 1e-4 where the limit is a valid 1e-3
4958        // oracle for *all four* quantities. The cancelling-branch regression at
4959        // larger x is covered by gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap.
4960        for &(age, rate, shape) in &cases {
4961            let t = age;
4962            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4963            assert_close(
4964                d_cum,
4965                rate * t * t / 2.0,
4966                1e-3,
4967                &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4968            );
4969            assert_close(
4970                d_inst,
4971                rate * t,
4972                1e-3,
4973                &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4974            );
4975
4976            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4977            assert_close(
4978                d2_cum,
4979                rate * t * t * t / 3.0,
4980                1e-3,
4981                &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4982            );
4983            assert_close(
4984                d2_inst,
4985                rate * t * t,
4986                1e-3,
4987                &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4988            );
4989        }
4990    }
4991
4992    #[test]
4993    fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4994        // Regression: in the band shape ∈ [1e-10, ~1e-4] with a realistic age,
4995        // the OLD `shape < 1e-10` pivot sent ∂²H_G/∂shape² through the
4996        // catastrophically-cancelling exact branch. With age=100, shape=1e-9
4997        // (x=1e-7) the exact branch returned ~+5e1 vs the true ~rate·t³/3.
4998        // Assert the implementation now matches the closed-form limit to high
4999        // precision throughout that band, across several decades of shape.
5000        let age = 100.0;
5001        let rate = 0.4;
5002        let t = age;
5003        let truth = rate * t * t * t / 3.0; // 1.333e5
5004        // Start at shape=1e-5 (x=1e-3): below this the second derivative is,
5005        // to better than 1e-3 relative, equal to its shape->0 limit, so the
5006        // limit is a valid oracle. (At x=1e-2 the true value legitimately
5007        // departs from the limit by ~7e-3, which is a real O(x) correction,
5008        // not an error — so we do not extend the band up to shape=1e-4.)
5009        for k in 5..=12 {
5010            let shape = 10f64.powi(-(k as i32)); // 1e-5 .. 1e-12
5011            let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
5012            assert_close(
5013                d2_cum,
5014                truth,
5015                1e-3,
5016                &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
5017            );
5018        }
5019    }
5020
5021    #[test]
5022    fn weibull_offset_partials_match_central_diff() {
5023        let cases = [
5024            (0.5_f64, 1.2_f64, 25.0_f64),
5025            (2.0, 0.8, 60.0),
5026            (0.1, 3.0, 10.0),
5027        ];
5028        for &(scale, shape, age) in &cases {
5029            let cfg = SurvivalBaselineConfig {
5030                target: SurvivalBaselineTarget::Weibull,
5031                scale: Some(scale),
5032                shape: Some(shape),
5033                rate: None,
5034                makeham: None,
5035            };
5036            let analytic = baseline_offset_theta_partials(age, &cfg)
5037                .expect("ok")
5038                .expect("nl");
5039            let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
5040            assert_eq!(analytic.len(), 2);
5041            for k in 0..2 {
5042                assert_close(
5043                    analytic[k].0,
5044                    fd[k].0,
5045                    1e-7,
5046                    &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5047                );
5048                assert_close(
5049                    analytic[k].1,
5050                    fd[k].1,
5051                    1e-7,
5052                    &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
5053                );
5054            }
5055            // Weibull o_D = shape/t is independent of scale; verify exactly.
5056            assert_eq!(analytic[0].1, 0.0);
5057        }
5058    }
5059
5060    #[test]
5061    fn gompertz_makeham_offset_partials_match_central_diff() {
5062        let cases = [
5063            (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
5064            (0.5, 0.01, 0.01, 25.0),
5065            (0.2, 0.001, 0.005, 60.0),
5066            (0.4, 5e-11, 0.01, 25.0),
5067            (0.4, -5e-11, 0.01, 25.0),
5068            (0.8, 0.2, 0.05, 5.0),
5069        ];
5070        for &(rate, shape, makeham, age) in &cases {
5071            let cfg = SurvivalBaselineConfig {
5072                target: SurvivalBaselineTarget::GompertzMakeham,
5073                scale: None,
5074                shape: Some(shape),
5075                rate: Some(rate),
5076                makeham: Some(makeham),
5077            };
5078            let analytic = baseline_offset_theta_partials(age, &cfg)
5079                .expect("ok")
5080                .expect("nl");
5081            // See gompertz_offset_partials_match_central_diff: tiny shape-step
5082            // is only needed for the shape component; log_rate and
5083            // log_makeham take the normal-scale step.
5084            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
5085            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
5086            assert_eq!(analytic.len(), 3);
5087            for k in 0..3 {
5088                assert_close(
5089                    analytic[k].0,
5090                    fd[k].0,
5091                    1e-5,
5092                    &format!(
5093                        "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5094                    ),
5095                );
5096                assert_close(
5097                    analytic[k].1,
5098                    fd[k].1,
5099                    1e-5,
5100                    &format!(
5101                        "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
5102                    ),
5103                );
5104            }
5105        }
5106    }
5107
5108    #[test]
5109    fn linear_baseline_has_no_theta_partials() {
5110        let cfg = SurvivalBaselineConfig {
5111            target: SurvivalBaselineTarget::Linear,
5112            scale: None,
5113            shape: None,
5114            rate: None,
5115            makeham: None,
5116        };
5117        assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
5118    }
5119
5120    #[test]
5121    fn baseline_offset_partials_reject_non_positive_ages() {
5122        let cfg = SurvivalBaselineConfig {
5123            target: SurvivalBaselineTarget::Gompertz,
5124            scale: None,
5125            shape: Some(0.01),
5126            rate: Some(0.5),
5127            makeham: None,
5128        };
5129        assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
5130        assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
5131        assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
5132    }
5133
5134    // ─── baseline_chain_rule_gradient — mechanical and FD-vs-θ tests ─────
5135
5136    /// Mechanical sanity check: with only one event observation at known
5137    /// (r_X, r_E, r_D, age_exit, age_entry), the Gompertz chain-rule gradient
5138    /// reduces to the analytic linear combination of `baseline_offset_theta_partials`.
5139    #[test]
5140    fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
5141        let cfg = SurvivalBaselineConfig {
5142            target: SurvivalBaselineTarget::Gompertz,
5143            scale: None,
5144            shape: Some(0.05),
5145            rate: Some(0.3),
5146            makeham: None,
5147        };
5148        let age_entry = array![10.0_f64];
5149        let age_exit = array![25.0_f64];
5150        let residuals = OffsetChannelResiduals {
5151            exit: array![0.7_f64],
5152            entry: array![-0.2_f64],
5153            derivative: array![-0.4_f64],
5154            right: Array1::<f64>::zeros(1),
5155        };
5156        let grad = baseline_chain_rule_gradient(
5157            age_entry.view(),
5158            age_exit.view(),
5159            age_exit.view(),
5160            &cfg,
5161            &residuals,
5162        )
5163        .expect("ok")
5164        .expect("non-linear");
5165        // Hand-compute: grad[k] = r_X·∂eta_exit/∂θ_k + r_D·∂o_D_exit/∂θ_k + r_E·∂eta_entry/∂θ_k.
5166        let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5167            .unwrap()
5168            .unwrap();
5169        let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5170            .unwrap()
5171            .unwrap();
5172        for k in 0..p_exit.len() {
5173            let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5174            assert!(
5175                (grad[k] - expected).abs() < 1e-12,
5176                "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5177                grad[k],
5178                expected
5179            );
5180        }
5181    }
5182
5183    /// Origin-entry rows (r_entry == 0) must skip the baseline partials call at
5184    /// `age_entry = 0`, which would otherwise fail the positive-age precondition.
5185    #[test]
5186    fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5187        let cfg = SurvivalBaselineConfig {
5188            target: SurvivalBaselineTarget::Gompertz,
5189            scale: None,
5190            shape: Some(0.05),
5191            rate: Some(0.3),
5192            makeham: None,
5193        };
5194        let age_entry = array![0.0_f64, 5.0_f64];
5195        let age_exit = array![10.0_f64, 20.0_f64];
5196        let residuals = OffsetChannelResiduals {
5197            exit: array![0.5_f64, 0.3_f64],
5198            entry: array![0.0_f64, -0.1_f64], // row 0 is origin-entry (r_E = 0)
5199            derivative: array![-0.2_f64, 0.0_f64],
5200            right: Array1::<f64>::zeros(2),
5201        };
5202        // Must not error despite age_entry[0] == 0.
5203        let grad = baseline_chain_rule_gradient(
5204            age_entry.view(),
5205            age_exit.view(),
5206            age_exit.view(),
5207            &cfg,
5208            &residuals,
5209        )
5210        .expect("must not fail on origin-entry row with r_entry=0")
5211        .expect("non-linear");
5212        assert_eq!(grad.len(), 2);
5213        // Row 1's entry channel contributes, row 0's does not.
5214        let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5215        let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5216        let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5217        for k in 0..2 {
5218            let expected = 0.5 * p_exit_0[k].0
5219                + (-0.2) * p_exit_0[k].1
5220                + 0.3 * p_exit_1[k].0
5221                + (-0.1) * p_entry_1[k].0;
5222            assert!(
5223                (grad[k] - expected).abs() < 1e-12,
5224                "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5225                grad[k],
5226                expected
5227            );
5228        }
5229    }
5230
5231    /// Linear target has no θ-parameters; contractor returns None.
5232    #[test]
5233    fn chain_rule_gradient_linear_target_returns_none() {
5234        let cfg = SurvivalBaselineConfig {
5235            target: SurvivalBaselineTarget::Linear,
5236            scale: None,
5237            shape: None,
5238            rate: None,
5239            makeham: None,
5240        };
5241        let age_entry = array![1.0_f64];
5242        let age_exit = array![2.0_f64];
5243        let residuals = OffsetChannelResiduals {
5244            exit: array![0.1_f64],
5245            entry: array![0.0_f64],
5246            derivative: array![0.0_f64],
5247            right: Array1::<f64>::zeros(1),
5248        };
5249        let grad = baseline_chain_rule_gradient(
5250            age_entry.view(),
5251            age_exit.view(),
5252            age_exit.view(),
5253            &cfg,
5254            &residuals,
5255        )
5256        .expect("ok");
5257        assert!(grad.is_none());
5258    }
5259
5260    /// End-to-end envelope-theorem check: the chain-rule gradient at
5261    /// residuals-evaluated-at-β-fixed matches the central FD of the
5262    /// unpenalized NLL with respect to θ when the OFFSETS are recomputed
5263    /// from the perturbed cfg and β is held at its base value.
5264    ///
5265    /// This is the mathematical content of the envelope theorem applied to
5266    /// the penalized-deviance cost at fixed β: if β solves ∂C/∂β = 0 at
5267    /// (θ, β*), then the total derivative of C at (θ±h) when β is held at
5268    /// β* equals the partial derivative of C wrt θ at the base — up to
5269    /// O(h²) in the truncation error of central differences. For THIS test
5270    /// we're directly differencing NLL (the unpenalized piece that carries
5271    /// all the θ dependence), so the envelope identity is exact up to FD
5272    /// truncation.
5273    ///
5274    /// The test synthesizes a plausible residual set by hand rather than
5275    /// running PIRLS — what we're validating is the chain-rule contractor,
5276    /// not the fit. A PIRLS-based end-to-end check belongs in an
5277    /// integration test, not this unit-test module.
5278    #[test]
5279    fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5280        // Toy 3-observation case with two events (one origin-entry, one not)
5281        // and one censored row at large age.
5282        let cfg = SurvivalBaselineConfig {
5283            target: SurvivalBaselineTarget::Gompertz,
5284            scale: None,
5285            shape: Some(0.03),
5286            rate: Some(0.25),
5287            makeham: None,
5288        };
5289        let age_entry = array![0.0_f64, 5.0, 8.0];
5290        let age_exit = array![4.0_f64, 12.0, 20.0];
5291        // Weighted residuals at a notional β*. Values chosen in a plausible
5292        // range (~same order as w·exp(η)).
5293        let weights = array![1.0_f64, 2.0, 0.5];
5294        let events = [1.0_f64, 1.0, 0.0];
5295        // Fake a β* that yields finite eta_entry ± eta_exit ± s values by
5296        // directly specifying eta quantities. Contractor only consumes the
5297        // residuals, so the fake is sufficient.
5298        let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; // row 0 doesn't matter (origin entry)
5299        let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5300        let s_vals = [0.7_f64, 1.1, 1.5];
5301        let (r_x, r_e, r_d) = {
5302            let mut rx = Array1::<f64>::zeros(3);
5303            let mut re = Array1::<f64>::zeros(3);
5304            let mut rd = Array1::<f64>::zeros(3);
5305            for i in 0..3 {
5306                let w = weights[i];
5307                let d = events[i];
5308                rx[i] = w * (eta_exit_vals[i].exp() - d);
5309                re[i] = if i == 0 {
5310                    0.0 // origin entry
5311                } else {
5312                    -w * eta_entry_vals[i].exp()
5313                };
5314                rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5315            }
5316            (rx, re, rd)
5317        };
5318        let residuals = OffsetChannelResiduals {
5319            exit: r_x.clone(),
5320            entry: r_e.clone(),
5321            derivative: r_d.clone(),
5322            right: Array1::<f64>::zeros(3),
5323        };
5324        let grad = baseline_chain_rule_gradient(
5325            age_entry.view(),
5326            age_exit.view(),
5327            age_exit.view(),
5328            &cfg,
5329            &residuals,
5330        )
5331        .expect("ok")
5332        .expect("non-linear");
5333
5334        // Construct NLL(θ) with β* held to the same eta/s values by treating
5335        // eta_i, s_i as fixed "linear predictor" samples and shifting by
5336        // (offset(θ) - offset(θ_base)). That's exactly the RP NLL with β*
5337        // held constant and offsets varied through θ.
5338        let nll = |theta_plus: &Array1<f64>| -> f64 {
5339            let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5340            let mut sum = 0.0_f64;
5341            for i in 0..3 {
5342                let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5343                let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5344                let d_eta_x = eta_x_p - base.0;
5345                let d_d_x = d_x_p - base.1;
5346                let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5347                let s_new = s_vals[i] + d_d_x;
5348                let interval_entry = if i == 0 {
5349                    0.0_f64
5350                } else {
5351                    let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5352                    let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5353                    let d_eta_e = eta_e_p - base_e.0;
5354                    let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5355                    eta_entry_new.exp()
5356                };
5357                let w = weights[i];
5358                let d = events[i];
5359                let nll_i =
5360                    w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5361                sum += nll_i;
5362            }
5363            sum
5364        };
5365
5366        let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5367        let h = 1e-6;
5368        for k in 0..theta_base.len() {
5369            let mut tp = theta_base.clone();
5370            let mut tm = theta_base.clone();
5371            tp[k] += h;
5372            tm[k] -= h;
5373            let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5374            assert!(
5375                (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5376                "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5377                grad[k],
5378                fd
5379            );
5380        }
5381    }
5382
5383    /// Length-mismatch surfaces as an error, not a silent contraction.
5384    #[test]
5385    fn chain_rule_gradient_rejects_length_mismatch() {
5386        let cfg = SurvivalBaselineConfig {
5387            target: SurvivalBaselineTarget::Gompertz,
5388            scale: None,
5389            shape: Some(0.05),
5390            rate: Some(0.3),
5391            makeham: None,
5392        };
5393        let age_entry = array![1.0_f64, 2.0]; // length 2
5394        let age_exit = array![5.0_f64, 6.0, 7.0]; // length 3
5395        let residuals = OffsetChannelResiduals {
5396            exit: array![0.1_f64, 0.2, 0.3],
5397            entry: array![0.0_f64, 0.0, 0.0],
5398            derivative: array![0.0_f64, 0.0, 0.0],
5399            right: Array1::<f64>::zeros(3),
5400        };
5401        let err = baseline_chain_rule_gradient(
5402            age_entry.view(),
5403            age_exit.view(),
5404            age_exit.view(),
5405            &cfg,
5406            &residuals,
5407        )
5408        .expect_err("length mismatch must error");
5409        assert!(err.contains("length mismatch"), "err={err}");
5410    }
5411}