Skip to main content

gam_models/survival/
construction.rs

1//! Survival model construction helpers.
2//!
3//! Types and functions for building survival model components:
4//! - Baseline hazard targets (Weibull, Gompertz, Gompertz-Makeham)
5//! - Time basis construction (I-spline on log-time)
6//! - Baseline offset computation
7//! - Time wiggle construction
8//!
9//! These are the building blocks a library consumer needs to construct
10//! a `FitRequest::SurvivalLocationScale` without going through the CLI.
11
12use gam_terms::basis::{
13    BSplineBasisSpec, BSplineBoundaryConditions, BSplineIdentifiability, BSplineKnotSpec,
14    BasisMetadata, BasisOptions, Dense, KnotSource, OneDimensionalBoundary, build_bspline_basis_1d,
15    create_basis, evaluate_bspline_derivative_scalar,
16};
17use crate::survival::location_scale::{
18    DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD, ResidualDistribution,
19    SurvivalCovariateTermBlockTemplate,
20};
21use crate::survival::lognormal_kernel::HazardLoading;
22use crate::survival::marginal_slope::DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD;
23use crate::wiggle::{
24    WiggleBlockConfig, append_selected_wiggle_penalty_orders, buildwiggle_block_input_from_seed,
25    monotone_wiggle_basis_with_derivative_order, split_wiggle_penalty_orders,
26};
27use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
28use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SparseDesignMatrix, symmetrize_in_place};
29use crate::probability::{normal_pdf, standard_normal_quantile};
30use gam_problem::{InverseLink, StandardLink};
31use ndarray::{Array1, Array2, array, s};
32use rayon::prelude::*;
33
34// ---------------------------------------------------------------------------
35// Typed error
36// ---------------------------------------------------------------------------
37
/// Structured failure surface for survival-model construction helpers
/// (`parse_*`, baseline-config builders, time-basis construction).
///
/// Every variant carries a free-form `reason: String` payload. `Display`
/// emits that payload verbatim, so converting to `String` via the `From`
/// impl produces text byte-equivalent to the pre-refactor
/// `Err(format!(...))` call sites that this enum replaced.
///
/// The public CLI-input parsers (`parse_survival_distribution`,
/// `parse_survival_likelihood_mode`, `parse_survival_baseline_config`)
/// keep their `Result<_, String>` signatures, because a string is the
/// natural failure type for free-form user input. They are intended to
/// route their failures through this enum internally via
/// `From<SurvivalConstructionError> for String`.
#[derive(Clone, Debug)]
pub enum SurvivalConstructionError {
    /// User-supplied configuration is malformed or out of range (knot
    /// counts, anchor offsets, derivative guards, ranks).
    InvalidConfig { reason: String },
    /// A required column or block of metadata is absent (e.g. saved
    /// survival ispline keep_cols, baseline target on a saved fit).
    MissingColumn { reason: String },
    /// Per-row / per-column shape disagreement (entry/exit lengths,
    /// penalty rank vs basis width, basis vs coefficient counts).
    IncompatibleDimensions { reason: String },
    /// Numeric / domain rejection: non-finite ratios, non-positive
    /// survival times, monotonicity violations, ispline-derivative
    /// underflow.
    DataValidationFailed { reason: String },
    /// Underlying basis / penalty builder rejected the construction
    /// request (invalid spline order, ispline keep_cols out of range,
    /// internal empty ispline time basis).
    BasisConstructionFailed { reason: String },
    /// User-named distribution / likelihood-mode / baseline target /
    /// time-basis kind is not one we recognise.
    UnsupportedDistribution { reason: String },
}

// Shared boilerplate for reason-carrying error enums. Per the doc above,
// this supplies the verbatim `Display` and the `From<...> for String`
// conversion used by the `.into()` calls in this module.
// NOTE(review): see the macro definition for the exact set of generated
// impls (e.g. whether `std::error::Error` is included).
impl_reason_error_boilerplate! {
    SurvivalConstructionError {
        InvalidConfig,
        MissingColumn,
        IncompatibleDimensions,
        DataValidationFailed,
        BasisConstructionFailed,
        UnsupportedDistribution,
    }
}
84
85// ---------------------------------------------------------------------------
86// Types
87// ---------------------------------------------------------------------------
88
/// Parametric baseline-hazard family used as the target for the survival
/// time effect.
///
/// Parsed from the `--baseline-target` spellings `linear`, `weibull`,
/// `gompertz`, and `gompertz-makeham` (see `parse_survival_baseline_config`).
/// `survival_baseline_targetname` returns the canonical spelling.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalBaselineTarget {
    /// No additional parametric target: `eta_target(t) = 0`. The regularized
    /// model therefore defaults to a linear log-cumulative hazard from the
    /// existing time basis. Carries no baseline parameters.
    Linear,
    /// Parametric target: Weibull baseline (parameters: `scale`, `shape`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Weibull,
    /// Parametric target: Gompertz baseline (parameters: `rate`, `shape`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    Gompertz,
    /// Parametric target: Gompertz-Makeham baseline (parameters: `rate`,
    /// `shape`, `makeham`).
    ///
    /// Transformation/cloglog survival uses `eta_target(t) = log(H0(t))`;
    /// marginal-slope probit survival uses `q(t) = -Phi^-1(exp(-H0(t)))`.
    GompertzMakeham,
}
111
/// Baseline target plus the parameters that target uses.
///
/// Only the fields relevant to `target` are `Some` in configs produced by
/// `parse_survival_baseline_config`; the others are `None`.
#[derive(Clone, Debug)]
pub struct SurvivalBaselineConfig {
    /// Which parametric baseline family this config describes.
    pub target: SurvivalBaselineTarget,
    /// Weibull scale (`--baseline-scale`); finite and `> 0` when validated.
    pub scale: Option<f64>,
    /// Weibull shape (finite, `> 0`) or Gompertz / Gompertz-Makeham shape
    /// (finite, may be negative).
    pub shape: Option<f64>,
    /// Gompertz / Gompertz-Makeham rate (`--baseline-rate`); finite, `> 0`.
    pub rate: Option<f64>,
    /// Gompertz-Makeham constant term (`--baseline-makeham`); finite, `> 0`.
    pub makeham: Option<f64>,
}
120
/// Requested time-basis family for a survival fit.
#[derive(Clone, Debug)]
pub enum SurvivalTimeBasisConfig {
    /// No time basis requested.
    /// NOTE(review): exact column layout is defined by the time-block
    /// builder — confirm there.
    None,
    /// Linear time effect.
    /// NOTE(review): which time axis (raw vs log) is decided by the
    /// time-block builder — confirm there.
    Linear,
    /// Unconstrained B-spline time basis (not structurally monotone, so it is
    /// rejected where `require_structural_survival_time_basis` is enforced).
    BSpline {
        /// Spline degree.
        degree: usize,
        /// Knot vector.
        knots: Array1<f64>,
        /// Smoothing penalty weight for the time basis (presumably the
        /// initial λ for the REML smoothing search — confirm with caller).
        smooth_lambda: f64,
    },
    /// I-spline value rows on the `log(t)` axis with non-negative
    /// coefficients (`γ ≥ 0`) enforcing structural monotonicity of
    /// `q(t) = I_basis(log t) · γ`. This replaces the row-wise
    /// `D β + o ≥ guard` derivative-guard constraints the marginal-slope
    /// family previously relied on.
    ///
    /// The design builder lives at `_build_time_block`'s
    /// `SurvivalTimeBasisConfig::ISpline` arm and exposes:
    ///
    /// * `x_entry_time` / `x_exit_time` — I-spline value rows on the
    ///   `log(t)` axis. Non-negative entries plus `γ ≥ 0` give a
    ///   monotone-non-decreasing `q(t)`, the structural property the
    ///   marginal-slope family needs.
    /// * `x_derivative_time` — right-cumulative B-spline-derivative on
    ///   `log(t)` scaled by `1/t`, again non-negative with `γ ≥ 0`, so
    ///   `q'(t) ≥ 0` pointwise. The `derivative_guard` constant is added
    ///   externally by [`add_survival_time_derivative_guard_offset`],
    ///   leaving the derivative guarantee `q'(t) ≥ guard` exact.
    /// * 2nd-difference penalty on the underlying degree-`(k+1)` B-spline
    ///   coefficients, filtered through `keep_cols` for identifiability.
    ///
    /// `TimeBlockInput::time_monotonicity` declares to the consuming
    /// family how monotonicity is enforced. The marginal-slope
    /// construction site sets it to
    /// [`crate::survival::location_scale::TimeBlockMonotonicity::StructuralISpline`]
    /// so the family skips row-wise `D β + o ≥ guard` constraint
    /// generation and treats `γ ≥ 0` as the sole derivative-guard
    /// mechanism. The universal `validate_time_qd1_feasible` safety net
    /// runs regardless.
    ///
    /// An earlier iteration proposed a separate C-spline antiderivative
    /// parameterization that put `q'(t)` in the I-spline space and `q(t)`
    /// in the integral-of-I-spline space. That was mathematically
    /// equivalent but a strictly worse fit for the codebase (extra basis
    /// degree, an extra antiderivative builder, an extra identifiability
    /// path, an extra penalty); it was removed in favor of the canonical
    /// I-spline-value path here.
    ISpline {
        /// Spline degree.
        degree: usize,
        /// Knot vector on the `log(t)` axis.
        knots: Array1<f64>,
        /// Columns of the underlying basis retained for identifiability.
        keep_cols: Vec<usize>,
        /// Smoothing penalty weight for the time basis (presumably the
        /// initial λ for the REML smoothing search — confirm with caller).
        smooth_lambda: f64,
    },
}
174
/// Persistable snapshot of the time-basis state used by a survival fit.
///
/// Every survival family routes through [`SurvivalTimeBuildOutput`] during
/// the fit, but the FFI save path needs only the metadata, not the full
/// design matrices. This struct is the single source of truth that flows
/// from the workflow-level basis construction, through the family-specific
/// fit result, into the saved-model payload via
/// [`crate::inference::model::FittedModelPayload::apply_survival_time_basis`].
///
/// Threading this snapshot end-to-end eliminates the prior bug pattern
/// where each FFI builder had to reconstruct the metadata from
/// `fit_config` + the formula. That pattern carried a silent drift risk,
/// and one builder forgetting to do so caused the marginal-slope
/// save→load break.
#[derive(Clone, Debug, PartialEq)]
pub struct SavedSurvivalTimeBasis {
    /// Basis family name, copied from [`SurvivalTimeBuildOutput::basisname`].
    pub basisname: String,
    /// Spline degree, if the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, if the basis has one.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns (I-spline identifiability), if any.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing penalty weight recorded by the build, if any.
    pub smooth_lambda: Option<f64>,
    /// Entry anchor that was used during the fit.
    pub anchor: f64,
}
197
198impl SavedSurvivalTimeBasis {
199    /// Build a snapshot from the realised time-basis state and the entry
200    /// anchor that was used during the fit.
201    pub fn from_build(build: &SurvivalTimeBuildOutput, anchor: f64) -> Self {
202        Self {
203            basisname: build.basisname.clone(),
204            degree: build.degree,
205            knots: build.knots.clone(),
206            keep_cols: build.keep_cols.clone(),
207            smooth_lambda: build.smooth_lambda,
208            anchor,
209        }
210    }
211}
212
/// Realised time-basis design for a survival fit: design matrices,
/// penalties, and the metadata later persisted via
/// [`SavedSurvivalTimeBasis::from_build`].
#[derive(Clone)]
pub struct SurvivalTimeBuildOutput {
    /// Time-basis rows evaluated at each observation's entry time.
    pub x_entry_time: DesignMatrix,
    /// Time-basis rows evaluated at each observation's exit time.
    pub x_exit_time: DesignMatrix,
    /// Time-derivative rows of the basis.
    /// NOTE(review): evaluation point (presumably exit time) — confirm in the builder.
    pub x_derivative_time: DesignMatrix,
    /// Penalty matrices on the time-basis coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Basis family name (e.g. `"ispline"`).
    pub basisname: String,
    /// Spline degree, if the basis has one.
    pub degree: Option<usize>,
    /// Knot vector, if the basis has one.
    pub knots: Option<Vec<f64>>,
    /// Retained basis columns (I-spline identifiability), if any.
    pub keep_cols: Option<Vec<usize>>,
    /// Smoothing penalty weight, if the build carried one.
    pub smooth_lambda: Option<f64>,
}
227
/// Smallest positive survival time used by construction helpers.
///
/// `normalize_survival_time_pair` floors entry times at this value and keeps
/// exit at least this far above entry. `positive_survival_time_seed` never
/// returns less than this.
pub const SURVIVAL_TIME_FLOOR: f64 = 1e-9;

/// Seed smoothing penalty `λ` used when a survival time basis is reconstructed
/// from a build (or saved model) that did not carry an explicit `smooth_lambda`.
/// This is only an initial value for the REML smoothing search, not a fixed
/// policy. A small positive seed keeps the baseline spline lightly regularized
/// at the start, so the outer optimizer begins from a well-conditioned point and
/// then adapts `λ` to the data. Kept in one place so the b-spline and i-spline
/// reconstruction paths cannot drift apart.
const SURVIVAL_TIME_SMOOTH_LAMBDA_SEED: f64 = 1e-2;

/// Default initial Gompertz / Gompertz-Makeham shape parameter when the user
/// does not supply `--baseline-shape`. The Gompertz hazard is
/// `h(t) = rate · exp(shape · t)`. A near-zero shape seeds the baseline at an
/// almost-flat (exponential-like) hazard, letting the fit grow the
/// age-acceleration term from the data rather than committing to a strong
/// curvature up front. Shared by the parse and fit-seed paths so both start
/// from the same neutral shape.
const GOMPERTZ_DEFAULT_SHAPE_SEED: f64 = 0.01;
247
/// Survival likelihood selected by `--survival-likelihood`.
///
/// CLI spellings (see `parse_survival_likelihood_mode` /
/// `survival_likelihood_modename`): `transformation`, `weibull`,
/// `location-scale`, `marginal-slope`, `latent`, `latent-binary`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SurvivalLikelihoodMode {
    /// `transformation`
    Transformation,
    /// `weibull`
    Weibull,
    /// `location-scale`
    LocationScale,
    /// `marginal-slope`
    MarginalSlope,
    /// `latent`
    Latent,
    /// `latent-binary`
    LatentBinary,
}
257
/// Output of the time-wiggle construction: penalty structure and spline
/// metadata for the wiggle block.
/// NOTE(review): field meanings below are inferred from names and from the
/// parallel fields of `SurvivalTimeBuildOutput` — confirm against the builder.
pub struct SurvivalTimeWiggleBuild {
    /// Penalty matrices on the wiggle coefficients.
    pub penalties: Vec<Array2<f64>>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Knot vector of the wiggle spline.
    pub knots: Array1<f64>,
    /// Degree of the wiggle spline.
    pub degree: usize,
    /// Number of wiggle basis columns.
    pub ncols: usize,
}
265
266// ---------------------------------------------------------------------------
267// Time normalization
268// ---------------------------------------------------------------------------
269
270pub fn normalize_survival_time_pair(
271    entry_raw: f64,
272    exit_raw: f64,
273    row_index: usize,
274) -> Result<(f64, f64), String> {
275    if !entry_raw.is_finite() || !exit_raw.is_finite() {
276        return Err(SurvivalConstructionError::DataValidationFailed {
277            reason: format!("non-finite survival times at row {}", row_index + 1),
278        }
279        .into());
280    }
281    if entry_raw < 0.0 || exit_raw < 0.0 {
282        return Err(SurvivalConstructionError::DataValidationFailed {
283            reason: format!("negative survival times at row {}", row_index + 1),
284        }
285        .into());
286    }
287
288    let entry = entry_raw.max(SURVIVAL_TIME_FLOOR);
289    let exit = exit_raw.max(entry + SURVIVAL_TIME_FLOOR);
290    Ok((entry, exit))
291}
292
293// ---------------------------------------------------------------------------
294// Basis monotonicity helpers
295// ---------------------------------------------------------------------------
296
/// Whether `basisname` names a time basis that enforces a monotone cumulative
/// time effect by construction.
///
/// Only `ispline` qualifies; the comparison is ASCII case-insensitive and does
/// not trim whitespace.
pub fn survival_basis_supports_structural_monotonicity(basisname: &str) -> bool {
    const STRUCTURAL_MONOTONE_BASIS: &str = "ispline";
    STRUCTURAL_MONOTONE_BASIS.eq_ignore_ascii_case(basisname)
}
300
301pub fn require_structural_survival_time_basis(
302    basisname: &str,
303    context: &str,
304) -> Result<(), String> {
305    if survival_basis_supports_structural_monotonicity(basisname) {
306        return Ok(());
307    }
308    Err(SurvivalConstructionError::UnsupportedDistribution {
309        reason: format!(
310            "{context} requires a structural monotone survival time basis, but got '{basisname}'. \
311Only `ispline` is accepted here because its basis functions enforce a monotone cumulative time effect by construction. \
312`{basisname}` can fit non-monotone shapes, which can break survival semantics. \
313Re-run with `--time-basis ispline`."
314        ),
315    }
316    .into())
317}
318
319// ---------------------------------------------------------------------------
320// Baseline config parsing
321// ---------------------------------------------------------------------------
322
323pub fn parse_survival_baseline_config(
324    target_raw: &str,
325    scale: Option<f64>,
326    shape: Option<f64>,
327    rate: Option<f64>,
328    makeham: Option<f64>,
329) -> Result<SurvivalBaselineConfig, String> {
330    let target = match target_raw.to_ascii_lowercase().as_str() {
331        "linear" => SurvivalBaselineTarget::Linear,
332        "weibull" => SurvivalBaselineTarget::Weibull,
333        "gompertz" => SurvivalBaselineTarget::Gompertz,
334        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
335        other => {
336            return Err(SurvivalConstructionError::UnsupportedDistribution {
337                reason: format!(
338                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
339                ),
340            }
341            .into());
342        }
343    };
344
345    match target {
346        SurvivalBaselineTarget::Linear => Ok(SurvivalBaselineConfig {
347            target,
348            scale: None,
349            shape: None,
350            rate: None,
351            makeham: None,
352        }),
353        SurvivalBaselineTarget::Weibull => {
354            let scale = scale.ok_or_else(|| {
355                "--baseline-target weibull requires --baseline-scale > 0".to_string()
356            })?;
357            let shape = shape.ok_or_else(|| {
358                "--baseline-target weibull requires --baseline-shape > 0".to_string()
359            })?;
360            if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
361                return Err(
362                    "weibull baseline requires finite positive --baseline-scale and --baseline-shape"
363                        .to_string(),
364                );
365            }
366            Ok(SurvivalBaselineConfig {
367                target,
368                scale: Some(scale),
369                shape: Some(shape),
370                rate: None,
371                makeham: None,
372            })
373        }
374        SurvivalBaselineTarget::Gompertz => {
375            let rate = rate.unwrap_or(1.0);
376            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
377            if !rate.is_finite() || rate <= 0.0 || !shape.is_finite() {
378                return Err(
379                    "gompertz baseline requires finite --baseline-shape and positive --baseline-rate"
380                        .to_string(),
381                );
382            }
383            Ok(SurvivalBaselineConfig {
384                target,
385                scale: None,
386                shape: Some(shape),
387                rate: Some(rate),
388                makeham: None,
389            })
390        }
391        SurvivalBaselineTarget::GompertzMakeham => {
392            let rate = rate.unwrap_or(0.5);
393            let shape = shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED);
394            let makeham = makeham.unwrap_or(0.5);
395            if !rate.is_finite()
396                || rate <= 0.0
397                || !shape.is_finite()
398                || !makeham.is_finite()
399                || makeham <= 0.0
400            {
401                return Err(
402                    "gompertz-makeham baseline requires finite --baseline-shape, positive --baseline-rate, and positive --baseline-makeham"
403                        .to_string(),
404                );
405            }
406            Ok(SurvivalBaselineConfig {
407                target,
408                scale: None,
409                shape: Some(shape),
410                rate: Some(rate),
411                makeham: Some(makeham),
412            })
413        }
414    }
415}
416
417// ---------------------------------------------------------------------------
418// Likelihood mode / distribution parsing
419// ---------------------------------------------------------------------------
420
421pub fn parse_survival_likelihood_mode(raw: &str) -> Result<SurvivalLikelihoodMode, String> {
422    match raw.to_ascii_lowercase().as_str() {
423        "transformation" => Ok(SurvivalLikelihoodMode::Transformation),
424        "weibull" => Ok(SurvivalLikelihoodMode::Weibull),
425        "location-scale" => Ok(SurvivalLikelihoodMode::LocationScale),
426        "marginal-slope" => Ok(SurvivalLikelihoodMode::MarginalSlope),
427        "latent" => Ok(SurvivalLikelihoodMode::Latent),
428        "latent-binary" => Ok(SurvivalLikelihoodMode::LatentBinary),
429        other => Err(SurvivalConstructionError::UnsupportedDistribution {
430            reason: format!(
431                "unsupported --survival-likelihood '{other}'; use transformation|weibull|location-scale|marginal-slope|latent|latent-binary"
432            ),
433        }
434        .into()),
435    }
436}
437
438pub const fn survival_likelihood_modename(mode: SurvivalLikelihoodMode) -> &'static str {
439    match mode {
440        SurvivalLikelihoodMode::Transformation => "transformation",
441        SurvivalLikelihoodMode::Weibull => "weibull",
442        SurvivalLikelihoodMode::LocationScale => "location-scale",
443        SurvivalLikelihoodMode::MarginalSlope => "marginal-slope",
444        SurvivalLikelihoodMode::Latent => "latent",
445        SurvivalLikelihoodMode::LatentBinary => "latent-binary",
446    }
447}
448
449pub fn parse_survival_distribution(raw: &str) -> Result<ResidualDistribution, String> {
450    match raw.to_ascii_lowercase().as_str() {
451        "gaussian" | "probit" => Ok(ResidualDistribution::Gaussian),
452        "gumbel" | "cloglog" => Ok(ResidualDistribution::Gumbel),
453        "logistic" | "logit" => Ok(ResidualDistribution::Logistic),
454        other => Err(SurvivalConstructionError::UnsupportedDistribution {
455            reason: format!(
456                "unsupported survmodel(distribution='{other}'); accepted: gaussian / probit, gumbel / cloglog, logistic / logit"
457            ),
458        }
459        .into()),
460    }
461}
462
463pub const fn survival_baseline_targetname(target: SurvivalBaselineTarget) -> &'static str {
464    match target {
465        SurvivalBaselineTarget::Linear => "linear",
466        SurvivalBaselineTarget::Weibull => "weibull",
467        SurvivalBaselineTarget::Gompertz => "gompertz",
468        SurvivalBaselineTarget::GompertzMakeham => "gompertz-makeham",
469    }
470}
471
472pub fn positive_survival_time_seed(age_exit: &Array1<f64>) -> f64 {
473    let sum = age_exit
474        .iter()
475        .copied()
476        .filter(|value| value.is_finite() && *value > 0.0)
477        .sum::<f64>();
478    let count = age_exit
479        .iter()
480        .filter(|value| value.is_finite() && **value > 0.0)
481        .count()
482        .max(1);
483    (sum / count as f64).max(SURVIVAL_TIME_FLOOR)
484}
485
486pub fn initial_survival_baseline_config_for_fit(
487    target_raw: &str,
488    scale: Option<f64>,
489    shape: Option<f64>,
490    rate: Option<f64>,
491    makeham: Option<f64>,
492    age_exit: &Array1<f64>,
493) -> Result<SurvivalBaselineConfig, String> {
494    let target = match target_raw.trim().to_ascii_lowercase().as_str() {
495        "linear" => SurvivalBaselineTarget::Linear,
496        "weibull" => SurvivalBaselineTarget::Weibull,
497        "gompertz" => SurvivalBaselineTarget::Gompertz,
498        "gompertz-makeham" => SurvivalBaselineTarget::GompertzMakeham,
499        other => {
500            return Err(SurvivalConstructionError::UnsupportedDistribution {
501                reason: format!(
502                    "unsupported --baseline-target '{other}'; use linear|weibull|gompertz|gompertz-makeham"
503                ),
504            }
505            .into());
506        }
507    };
508    let time_scale_seed = positive_survival_time_seed(age_exit);
509    let cfg = match target {
510        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
511            target,
512            scale: None,
513            shape: None,
514            rate: None,
515            makeham: None,
516        },
517        SurvivalBaselineTarget::Weibull => SurvivalBaselineConfig {
518            target,
519            scale: Some(scale.unwrap_or(time_scale_seed)),
520            shape: Some(shape.unwrap_or(1.0)),
521            rate: None,
522            makeham: None,
523        },
524        SurvivalBaselineTarget::Gompertz => SurvivalBaselineConfig {
525            target,
526            scale: None,
527            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
528            rate: Some(rate.unwrap_or(1.0 / time_scale_seed)),
529            makeham: None,
530        },
531        SurvivalBaselineTarget::GompertzMakeham => SurvivalBaselineConfig {
532            target,
533            scale: None,
534            shape: Some(shape.unwrap_or(GOMPERTZ_DEFAULT_SHAPE_SEED)),
535            rate: Some(rate.unwrap_or(0.5 / time_scale_seed)),
536            makeham: Some(makeham.unwrap_or(0.5 / time_scale_seed)),
537        },
538    };
539    parse_survival_baseline_config(
540        survival_baseline_targetname(cfg.target),
541        cfg.scale,
542        cfg.shape,
543        cfg.rate,
544        cfg.makeham,
545    )
546}
547
548fn survival_baseline_theta_from_config(
549    cfg: &SurvivalBaselineConfig,
550) -> Result<Option<Array1<f64>>, String> {
551    Ok(match cfg.target {
552        SurvivalBaselineTarget::Linear => None,
553        SurvivalBaselineTarget::Weibull => Some(array![
554            cfg.scale
555                .ok_or_else(|| "missing weibull baseline scale".to_string())?
556                .ln(),
557            cfg.shape
558                .ok_or_else(|| "missing weibull baseline shape".to_string())?
559                .ln(),
560        ]),
561        SurvivalBaselineTarget::Gompertz => Some(array![
562            cfg.rate
563                .ok_or_else(|| "missing gompertz baseline rate".to_string())?
564                .ln(),
565            cfg.shape
566                .ok_or_else(|| "missing gompertz baseline shape".to_string())?,
567        ]),
568        SurvivalBaselineTarget::GompertzMakeham => Some(array![
569            cfg.rate
570                .ok_or_else(|| "missing gompertz-makeham baseline rate".to_string())?
571                .ln(),
572            cfg.shape
573                .ok_or_else(|| "missing gompertz-makeham baseline shape".to_string())?,
574            cfg.makeham
575                .ok_or_else(|| "missing gompertz-makeham baseline makeham".to_string())?
576                .ln(),
577        ]),
578    })
579}
580
581fn survival_baseline_config_from_theta(
582    target: SurvivalBaselineTarget,
583    theta: &Array1<f64>,
584) -> Result<SurvivalBaselineConfig, String> {
585    let cfg = match target {
586        SurvivalBaselineTarget::Linear => SurvivalBaselineConfig {
587            target,
588            scale: None,
589            shape: None,
590            rate: None,
591            makeham: None,
592        },
593        SurvivalBaselineTarget::Weibull => {
594            if theta.len() != 2 {
595                return Err(SurvivalConstructionError::IncompatibleDimensions {
596                    reason: format!(
597                        "weibull baseline parameter dimension mismatch: expected 2, got {}",
598                        theta.len()
599                    ),
600                }
601                .into());
602            }
603            SurvivalBaselineConfig {
604                target,
605                scale: Some(theta[0].exp()),
606                shape: Some(theta[1].exp()),
607                rate: None,
608                makeham: None,
609            }
610        }
611        SurvivalBaselineTarget::Gompertz => {
612            if theta.len() != 2 {
613                return Err(SurvivalConstructionError::IncompatibleDimensions {
614                    reason: format!(
615                        "gompertz baseline parameter dimension mismatch: expected 2, got {}",
616                        theta.len()
617                    ),
618                }
619                .into());
620            }
621            SurvivalBaselineConfig {
622                target,
623                scale: None,
624                shape: Some(theta[1]),
625                rate: Some(theta[0].exp()),
626                makeham: None,
627            }
628        }
629        SurvivalBaselineTarget::GompertzMakeham => {
630            if theta.len() != 3 {
631                return Err(SurvivalConstructionError::IncompatibleDimensions {
632                    reason: format!(
633                        "gompertz-makeham baseline parameter dimension mismatch: expected 3, got {}",
634                        theta.len()
635                    ),
636                }
637                .into());
638            }
639            SurvivalBaselineConfig {
640                target,
641                scale: None,
642                shape: Some(theta[1]),
643                rate: Some(theta[0].exp()),
644                makeham: Some(theta[2].exp()),
645            }
646        }
647    };
648    parse_survival_baseline_config(
649        survival_baseline_targetname(cfg.target),
650        cfg.scale,
651        cfg.shape,
652        cfg.rate,
653        cfg.makeham,
654    )
655}
656
/// Derivative contract for the shared baseline-θ outer optimizer.
///
/// The two public baseline optimizers (`…_with_gradient_only`,
/// `…_with_gradient`) differ along exactly one axis: how much derivative
/// information the objective closure supplies, and therefore which curvature
/// declaration the `OuterProblem` must advertise. Every baseline-θ path now
/// supplies an exact analytic gradient (profile-NLL envelope gradient), so both
/// contracts route to a gradient-based solver.
///
/// Everything else is identical, so it lives once in
/// [`run_baseline_theta_optimizer`]. That includes the θ↔config conversion, the
/// ±6 box around the seed θ (log-space for scale / rate / makeham, natural
/// scale for the Gompertz shape), the single-seed config, and the
/// `run`/convergence/error-formatting boilerplate. This enum selects the
/// per-contract `OuterProblem` configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BaselineDerivativeContract {
    /// Cost + analytic gradient, no analytic Hessian. Routes to BFGS, which
    /// builds its own quasi-Newton curvature from successive gradients.
    GradientOnly,
    /// Cost + analytic gradient + analytic Hessian. Routes to the primary
    /// second-order outer solver, which may use either the analytic Hessian or
    /// a BFGS approximation depending on the planner.
    GradientHessian,
}
679
680impl BaselineDerivativeContract {
681    /// Apply this contract's derivative declaration, solver class, tolerance,
682    /// and iteration budget to a freshly-constructed `OuterProblem`. The
683    /// bounds, initial ρ, and seed config are contract-independent and applied
684    /// by [`run_baseline_theta_optimizer`].
685    fn configure(
686        self,
687        problem: gam_solve::rho_optimizer::OuterProblem,
688    ) -> gam_solve::rho_optimizer::OuterProblem {
689        use gam_problem::{DeclaredHessianForm, Derivative};
690        match self {
691            // BFGS on a 2–3 dim problem with an exact gradient typically
692            // converges in 5–10 outer evaluations.
693            BaselineDerivativeContract::GradientOnly => problem
694                .with_gradient(Derivative::Analytic)
695                .with_hessian(DeclaredHessianForm::Unavailable)
696                .with_tolerance(1e-4)
697                .with_max_iter(240),
698            BaselineDerivativeContract::GradientHessian => problem
699                .with_gradient(Derivative::Analytic)
700                .with_hessian(DeclaredHessianForm::Either)
701                .with_tolerance(1e-4)
702                .with_max_iter(240),
703        }
704    }
705}
706
707/// Shared engine behind the three public baseline-config optimizers.
708///
709/// Owns every step that is identical across the cost-only, gradient-only, and
710/// gradient+Hessian contracts: config→θ seeding (with the linear/no-parameter
711/// early return), the ±6 log-space box, the single-seed `OuterProblem`
712/// skeleton, derivative-contract configuration, `build_objective` wiring,
713/// `run`, the convergence check + error formatting, and θ→config. The only
714/// contract-specific inputs are the already-wired `cost_fn`/`eval_fn` closures
715/// (which embed the derivative shape and dimension validation) and the
716/// `contract` selecting the `OuterProblem` derivative declaration.
717fn run_baseline_theta_optimizer<Fc, Fe>(
718    initial: &SurvivalBaselineConfig,
719    context: &str,
720    contract: BaselineDerivativeContract,
721    cost_fn: Fc,
722    eval_fn: Fe,
723) -> Result<SurvivalBaselineConfig, String>
724where
725    Fc: FnMut(&mut (), &Array1<f64>) -> Result<f64, crate::model_types::EstimationError>,
726    Fe: FnMut(
727        &mut (),
728        &Array1<f64>,
729    ) -> Result<gam_problem::OuterEval, crate::model_types::EstimationError>,
730{
731    use gam_solve::rho_optimizer::OuterProblem;
732    let Some(seed) = survival_baseline_theta_from_config(initial)? else {
733        return Ok(initial.clone());
734    };
735    let dim = seed.len();
736    let target = initial.target;
737    let lower = seed.mapv(|v| v - 6.0);
738    let upper = seed.mapv(|v| v + 6.0);
739    let problem = contract
740        .configure(OuterProblem::new(dim))
741        .with_bounds(lower, upper)
742        .with_initial_rho(seed.clone())
743        .with_seed_config(crate::seeding::SeedConfig {
744            max_seeds: 1,
745            seed_budget: 1,
746            num_auxiliary_trailing: dim,
747            ..Default::default()
748        });
749    let mut obj = problem.build_objective(
750        (),
751        cost_fn,
752        eval_fn,
753        None::<fn(&mut ())>,
754        None::<
755            fn(
756                &mut (),
757                &Array1<f64>,
758            ) -> Result<gam_problem::EfsEval, crate::model_types::EstimationError>,
759        >,
760    );
761    let result = problem
762        .run(&mut obj, context)
763        .map_err(|e| format!("{context} failed: {e}"))?;
764    if !result.converged {
765        return Err(SurvivalConstructionError::InvalidConfig {
766            reason: format!(
767                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
768                result.iterations,
769                result.final_value,
770                result.final_grad_norm_report(),
771            ),
772        }
773        .into());
774    }
775    survival_baseline_config_from_theta(target, &result.rho)
776}
777
778/// Shared engine for the two derivative-carrying baseline-config optimizers.
779///
780/// Both `…_with_gradient_only` and `…_with_gradient` route an objective that
781/// returns a fully-populated [`OuterEval`](gam_problem::OuterEval)
782/// (cost + analytic gradient, optionally + analytic Hessian) for a given
783/// config. Everything downstream of that — the `Rc<RefCell>` sharing that lets
784/// the same user closure back both the `cost_fn` and `eval_fn`, the θ→config
785/// conversion, and deriving the scalar `cost_fn` from the eval result — is
786/// identical, so it lives here once. The contract-specific axis is only which
787/// `HessianResult` the objective embeds, which the wrapper has already encoded
788/// in the returned `OuterEval`, so this helper is contract-agnostic beyond the
789/// `contract` it forwards to [`run_baseline_theta_optimizer`].
790fn run_baseline_theta_optimizer_with_eval<F>(
791    initial: &SurvivalBaselineConfig,
792    context: &str,
793    contract: BaselineDerivativeContract,
794    objective: F,
795) -> Result<SurvivalBaselineConfig, String>
796where
797    F: FnMut(&SurvivalBaselineConfig) -> Result<gam_problem::OuterEval, String>,
798{
799    let target = initial.target;
800    let engine_context = context.to_string();
801    let objective = std::rc::Rc::new(std::cell::RefCell::new(objective));
802    let eval_at = move |obj: &std::rc::Rc<std::cell::RefCell<F>>,
803                        theta: &Array1<f64>|
804          -> Result<gam_problem::OuterEval, crate::model_types::EstimationError> {
805        let cfg = survival_baseline_config_from_theta(target, theta)
806            .map_err(crate::model_types::EstimationError::InvalidInput)?;
807        let eval =
808            obj.borrow_mut()(&cfg).map_err(crate::model_types::EstimationError::InvalidInput)?;
809        if eval.gradient.len() != theta.len() {
810            return Err(crate::model_types::EstimationError::InvalidInput(format!(
811                "{engine_context}: baseline gradient dimension mismatch: got {}, expected {}",
812                eval.gradient.len(),
813                theta.len()
814            )));
815        }
816        if let gam_problem::HessianResult::Analytic(ref h) = eval.hessian {
817            if h.nrows() != theta.len() || h.ncols() != theta.len() {
818                return Err(crate::model_types::EstimationError::InvalidInput(format!(
819                    "{engine_context}: baseline Hessian dimension mismatch: got {}x{}, expected {}x{}",
820                    h.nrows(),
821                    h.ncols(),
822                    theta.len(),
823                    theta.len()
824                )));
825            }
826        }
827        Ok(eval)
828    };
829    let cost_objective = std::rc::Rc::clone(&objective);
830    let cost_eval = eval_at.clone();
831    let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
832        cost_eval(&cost_objective, theta).map(|eval| eval.cost)
833    };
834    let eval_fn = move |_: &mut (), theta: &Array1<f64>| eval_at(&objective, theta);
835    run_baseline_theta_optimizer(initial, context, contract, cost_fn, eval_fn)
836}
837
838/// Gradient-only outer baseline-config optimizer. Thin adapter over
839/// [`run_baseline_theta_optimizer`] under the
840/// [`BaselineDerivativeContract::GradientOnly`] contract, which advertises
841/// `DeclaredHessianForm::Unavailable`, so the planner routes to BFGS and
842/// builds its own quasi-Newton curvature from successive gradient
843/// evaluations. Used by the survival location-scale path which has a
844/// closed-form θ-gradient (`baseline_chain_rule_gradient` /
845/// `marginal_slope_baseline_chain_rule_gradient`) but no native analytic
846/// θ-Hessian; BFGS on a 2–3 dim problem with an exact gradient typically
847/// converges in 5–10 outer evaluations.
848pub fn optimize_survival_baseline_config_with_gradient_only<F>(
849    initial: &SurvivalBaselineConfig,
850    context: &str,
851    mut objective: F,
852) -> Result<SurvivalBaselineConfig, String>
853where
854    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>), String>,
855{
856    use gam_problem::{HessianResult, OuterEval};
857    run_baseline_theta_optimizer_with_eval(
858        initial,
859        context,
860        BaselineDerivativeContract::GradientOnly,
861        move |cfg| {
862            let (cost, gradient) = objective(cfg)?;
863            Ok(OuterEval {
864                cost,
865                gradient,
866                hessian: HessianResult::Unavailable,
867                inner_beta_hint: None,
868            })
869        },
870    )
871}
872
873/// Gradient + Hessian outer baseline-config optimizer. Thin adapter over
874/// [`run_baseline_theta_optimizer`] under the
875/// [`BaselineDerivativeContract::GradientHessian`] contract, which advertises
876/// an analytic θ-Hessian so the primary second-order outer solver can use it.
877pub fn optimize_survival_baseline_config_with_gradient<F>(
878    initial: &SurvivalBaselineConfig,
879    context: &str,
880    mut objective: F,
881) -> Result<SurvivalBaselineConfig, String>
882where
883    F: FnMut(&SurvivalBaselineConfig) -> Result<(f64, Array1<f64>, Array2<f64>), String>,
884{
885    use gam_problem::{HessianResult, OuterEval};
886    run_baseline_theta_optimizer_with_eval(
887        initial,
888        context,
889        BaselineDerivativeContract::GradientHessian,
890        move |cfg| {
891            let (cost, gradient, hessian) = objective(cfg)?;
892            Ok(OuterEval {
893                cost,
894                gradient,
895                hessian: HessianResult::Analytic(hessian),
896                inner_beta_hint: None,
897            })
898        },
899    )
900}
901
902// ---------------------------------------------------------------------------
903// Time basis config (library-friendly: takes primitives, not CLI args)
904// ---------------------------------------------------------------------------
905
906pub fn parse_survival_time_basis_config(
907    time_basis: &str,
908    time_degree: usize,
909    time_num_internal_knots: usize,
910    time_smooth_lambda: f64,
911) -> Result<SurvivalTimeBasisConfig, String> {
912    match time_basis.to_ascii_lowercase().as_str() {
913        "none" => Ok(SurvivalTimeBasisConfig::None),
914        "ispline" => {
915            if time_degree < 1 {
916                return Err(
917                    "time-basis degree must be >= 1 for ispline time basis (CLI: --time-degree; Python: time_degree=)"
918                        .to_string(),
919                );
920            }
921            if time_num_internal_knots == 0 {
922                return Err(
923                    "time-basis must have > 0 internal knots for ispline time basis (CLI: --time-num-internal-knots; Python: time_num_internal_knots=)"
924                        .to_string(),
925                );
926            }
927            if !time_smooth_lambda.is_finite() || time_smooth_lambda < 0.0 {
928                return Err(
929                    "time-basis smoothing lambda must be finite and >= 0 (CLI: --time-smooth-lambda; Python: time_smooth_lambda=)"
930                        .to_string(),
931                );
932            }
933            Ok(SurvivalTimeBasisConfig::ISpline {
934                degree: time_degree,
935                knots: Array1::zeros(0),
936                keep_cols: Vec::new(),
937                smooth_lambda: time_smooth_lambda,
938            })
939        }
940        "linear" | "bspline" => {
941            // Forward to the shared structural-basis check so error text
942            // stays consistent with every other call site. `linear` /
943            // `bspline` are not structural, so this always returns Err;
944            // we map a (currently impossible) `Ok` to an explicit error
945            // string instead of `unreachable!`, keeping the match total
946            // without relying on a never-executes claim.
947            match require_structural_survival_time_basis(time_basis, "survival model configuration")
948            {
949                Err(e) => Err(e),
950                Ok(()) => Err(format!(
951                    "internal: structural-basis check accepted non-structural \
952                     survival time basis '{time_basis}'"
953                )),
954            }
955        }
956        other => Err(format!(
957            "unsupported --time-basis '{other}'; accepted values: ispline, none"
958        )),
959    }
960}
961
962// ---------------------------------------------------------------------------
963// Time basis construction
964// ---------------------------------------------------------------------------
965
966pub fn build_survival_time_basis(
967    age_entry: &Array1<f64>,
968    age_exit: &Array1<f64>,
969    cfg: SurvivalTimeBasisConfig,
970    infer_knots_if_needed: Option<(usize, f64)>,
971) -> Result<SurvivalTimeBuildOutput, String> {
972    fn checked_log_survival_times(times: &Array1<f64>, label: &str) -> Result<Array1<f64>, String> {
973        if let Some(row) = times.iter().position(|t| !t.is_finite()) {
974            return Err(SurvivalConstructionError::DataValidationFailed {
975                reason: format!(
976                    "survival time basis requires finite {label} times (row {})",
977                    row + 1
978                ),
979            }
980            .into());
981        }
982        if let Some(row) = times.iter().position(|t| *t < 0.0) {
983            return Err(SurvivalConstructionError::DataValidationFailed {
984                reason: format!(
985                    "survival time basis requires non-negative {label} times (row {})",
986                    row + 1
987                ),
988            }
989            .into());
990        }
991        Ok(times.mapv(|t| t.max(SURVIVAL_TIME_FLOOR).ln()))
992    }
993
994    let n = age_entry.len();
995    if n != age_exit.len() {
996        return Err(SurvivalConstructionError::IncompatibleDimensions {
997            reason: "survival time basis requires matching entry/exit lengths".to_string(),
998        }
999        .into());
1000    }
1001    for i in 0..n {
1002        if age_exit[i] < age_entry[i] {
1003            return Err(format!(
1004                "survival time basis requires exit times >= entry times (row {})",
1005                i + 1
1006            ));
1007        }
1008    }
1009    let log_entry = checked_log_survival_times(age_entry, "entry")?;
1010    let log_exit = checked_log_survival_times(age_exit, "exit")?;
1011
1012    fn survival_time_knot_input(log_entry: &Array1<f64>, log_exit: &Array1<f64>) -> Array1<f64> {
1013        let n = log_entry.len();
1014        let entry_range = log_entry
1015            .iter()
1016            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
1017                (lo.min(v), hi.max(v))
1018            });
1019        let entry_degenerate = (entry_range.1 - entry_range.0).abs() < 1e-8;
1020        if entry_degenerate {
1021            log_exit.clone()
1022        } else {
1023            let mut combined = Array1::<f64>::zeros(2 * n);
1024            for i in 0..n {
1025                combined[i] = log_entry[i];
1026                combined[n + i] = log_exit[i];
1027            }
1028            combined
1029        }
1030    }
1031
1032    /// Cap the requested monotone-baseline internal-knot count to what the
1033    /// observed time resolution can actually support.
1034    ///
1035    /// The survival location-scale baseline is a degree-`d` I-spline with
1036    /// `num_internal_knots + d` shape-varying columns. Its smoothing parameter
1037    /// is informed *only* by the distinct interior log-time points: with fewer
1038    /// distinct interior times than requested knots the baseline is
1039    /// rank-deficient, and the REML/LAML profile in the time smoothing
1040    /// parameter becomes a flat ridge — the exact-joint outer search then
1041    /// probes that ridge indefinitely (each inner constrained Newton burns its
1042    /// whole cycle budget without certifying convergence) and the fit never
1043    /// terminates. This is the survival analogue of the standard
1044    /// "df must not exceed the data resolution" guard (`mgcv` caps `k` at the
1045    /// number of unique covariate values; `flexsurv`/`rstpm2` use a handful of
1046    /// baseline knots): we never place more interior knots than there are
1047    /// distinct interior points, and we keep the total baseline dimension a
1048    /// bounded fraction of the sample so the smoothing profile stays curved.
1049    ///
1050    /// This clamp lives in the shared knot-inference routine so the fit and any
1051    /// independent rebuild of the time basis (e.g. a predictor reconstructing
1052    /// `design · β` at fresh covariates) resolve to the *same* knot vector from
1053    /// the same data — there is no raw/active dimension drift.
1054    fn data_capped_internal_knots(
1055        combined: &Array1<f64>,
1056        degree: usize,
1057        requested_internal_knots: usize,
1058    ) -> usize {
1059        if requested_internal_knots == 0 {
1060            return 0;
1061        }
1062        let mut sorted: Vec<f64> = combined.iter().copied().collect();
1063        sorted.sort_by(f64::total_cmp);
1064        let minval = sorted.first().copied().unwrap_or(0.0);
1065        let maxval = sorted.last().copied().unwrap_or(minval);
1066        if minval == maxval {
1067            // Degenerate (single distinct time): no interior structure to fit.
1068            return 1.min(requested_internal_knots);
1069        }
1070        let scale = (maxval - minval).abs().max(1.0);
1071        let tol = 1e-12 * scale;
1072        // Count distinct strictly-interior points (knots can only live strictly
1073        // between the data extremes).
1074        let mut distinct_interior = 0usize;
1075        let mut last: Option<f64> = None;
1076        for &x in &sorted {
1077            if x <= minval + tol || x >= maxval - tol {
1078                continue;
1079            }
1080            if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1081                continue;
1082            }
1083            distinct_interior += 1;
1084            last = Some(x);
1085        }
1086        // Distinct-point ceiling: cannot place more interior knots than there
1087        // are distinct interior values.
1088        let mut cap = requested_internal_knots.min(distinct_interior.max(1));
1089        // Dimension-vs-resolution ceiling: keep the total baseline column count
1090        // `cap + degree` below ~1/4 of the distinct sample points so the
1091        // smoothing-parameter profile retains curvature (the data must be able
1092        // to identify the baseline shape, not just interpolate it). `n_distinct`
1093        // counts all distinct points (interior + the two extremes).
1094        let n_distinct = {
1095            let mut count = 0usize;
1096            let mut last: Option<f64> = None;
1097            for &x in &sorted {
1098                if last.is_some_and(|prev| (x - prev).abs() <= tol) {
1099                    continue;
1100                }
1101                count += 1;
1102                last = Some(x);
1103            }
1104            count
1105        };
1106        let dim_budget = n_distinct / 4;
1107        let dim_cap = dim_budget.saturating_sub(degree);
1108        cap = cap.min(dim_cap.max(1));
1109        cap.max(1)
1110    }
1111
1112    fn infer_survival_time_knots(
1113        combined: &Array1<f64>,
1114        knot_degree: usize,
1115        validation_degree: usize,
1116        num_internal_knots: usize,
1117        basis_options: BasisOptions,
1118    ) -> Result<Array1<f64>, String> {
1119        // Identifiability/termination guard: never request more baseline
1120        // internal knots than the observed time resolution supports. See
1121        // `data_capped_internal_knots` for the full rationale (a flat smoothing
1122        // ridge on an over-parameterized baseline is what makes the survival
1123        // location-scale exact-joint outer search fail to terminate).
1124        let num_internal_knots =
1125            data_capped_internal_knots(combined, validation_degree, num_internal_knots);
1126
1127        fn quantile_knot_inference_needs_uniform_fallback(
1128            combined: &Array1<f64>,
1129            num_internal_knots: usize,
1130        ) -> bool {
1131            if num_internal_knots == 0 || combined.is_empty() {
1132                return false;
1133            }
1134
1135            let mut sorted: Vec<f64> = combined.iter().copied().collect();
1136            sorted.sort_by(f64::total_cmp);
1137            let minval = sorted[0];
1138            let maxval = *sorted.last().unwrap_or(&minval);
1139            if minval == maxval {
1140                return false;
1141            }
1142
1143            let scale = (maxval - minval).abs().max(1.0);
1144            let tol = 1e-12 * scale;
1145            let mut support = Vec::with_capacity(sorted.len());
1146            let mut last: Option<f64> = None;
1147            for &x in &sorted {
1148                if x <= minval + tol || x >= maxval - tol {
1149                    continue;
1150                }
1151                if last.map(|prev| (x - prev).abs() <= tol).unwrap_or(false) {
1152                    continue;
1153                }
1154                support.push(x);
1155                last = Some(x);
1156            }
1157            if support.is_empty() {
1158                return true;
1159            }
1160
1161            let n = support.len();
1162            let mut prev_q = minval;
1163            for j in 1..=num_internal_knots {
1164                let p = j as f64 / (num_internal_knots + 1) as f64;
1165                let pos = p * (n.saturating_sub(1) as f64);
1166                let lo = pos.floor() as usize;
1167                let hi = pos.ceil() as usize;
1168                let frac = pos - lo as f64;
1169                let q = if lo == hi {
1170                    support[lo]
1171                } else {
1172                    support[lo] * (1.0 - frac) + support[hi] * frac
1173                }
1174                .clamp(minval, maxval);
1175                if q <= prev_q + tol || q >= maxval - tol {
1176                    return true;
1177                }
1178                prev_q = q;
1179            }
1180
1181            false
1182        }
1183
1184        let inferwith =
1185            |placement: gam_terms::basis::BSplineKnotPlacement| -> Result<Array1<f64>, String> {
1186                let built = build_bspline_basis_1d(
1187                    combined.view(),
1188                    &BSplineBasisSpec {
1189                        degree: knot_degree,
1190                        penalty_order: 2,
1191                        knotspec: BSplineKnotSpec::Automatic {
1192                            num_internal_knots: Some(num_internal_knots),
1193                            placement,
1194                        },
1195                        double_penalty: false,
1196                        identifiability: BSplineIdentifiability::None,
1197                        boundary: OneDimensionalBoundary::Open,
1198                        boundary_conditions: BSplineBoundaryConditions::default(),
1199                    },
1200                )
1201                .map_err(|e| format!("failed to infer survival time knots: {e}"))?;
1202                let knots = match built.metadata {
1203                    BasisMetadata::BSpline1D { knots, .. } => knots,
1204                    _ => {
1205                        return Err(
1206                            "internal error: expected BSpline1D metadata for survival time basis"
1207                                .to_string(),
1208                        );
1209                    }
1210                };
1211                // `knot_degree` is the clamped B-spline degree used to size
1212                // the knot vector. `validation_degree` is the public basis
1213                // degree passed to the final evaluator. They differ for
1214                // I-splines because `create_basis(..., BasisOptions::i_spline())`
1215                // internally raises the public degree by one to its working
1216                // B-spline antiderivative degree. Validating with
1217                // `knot_degree` here would raise a second time and reject the
1218                // coherent knot vector we just inferred.
1219                create_basis::<Dense>(
1220                    combined.view(),
1221                    KnotSource::Provided(knots.view()),
1222                    validation_degree,
1223                    basis_options,
1224                )
1225                .map_err(|e| e.to_string())?;
1226                Ok(knots)
1227            };
1228
1229        if quantile_knot_inference_needs_uniform_fallback(combined, num_internal_knots) {
1230            inferwith(gam_terms::basis::BSplineKnotPlacement::Uniform)
1231        } else {
1232            inferwith(gam_terms::basis::BSplineKnotPlacement::Quantile)
1233        }
1234    }
1235
1236    match cfg {
1237        SurvivalTimeBasisConfig::None => Ok(SurvivalTimeBuildOutput {
1238            x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1239            x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1240            x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, 0)))),
1241            penalties: Vec::new(),
1242            nullspace_dims: Vec::new(),
1243            basisname: "none".to_string(),
1244            degree: None,
1245            knots: None,
1246            keep_cols: None,
1247            smooth_lambda: None,
1248        }),
1249        SurvivalTimeBasisConfig::Linear => {
1250            let mut x_entry_time = Array2::<f64>::zeros((n, 2));
1251            let mut x_exit_time = Array2::<f64>::zeros((n, 2));
1252            let mut x_derivative_time = Array2::<f64>::zeros((n, 2));
1253            for i in 0..n {
1254                x_entry_time[[i, 0]] = 1.0;
1255                x_exit_time[[i, 0]] = 1.0;
1256                x_entry_time[[i, 1]] = log_entry[i];
1257                x_exit_time[[i, 1]] = log_exit[i];
1258                x_derivative_time[[i, 1]] = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1259            }
1260            Ok(SurvivalTimeBuildOutput {
1261                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1262                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1263                x_derivative_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_derivative_time)),
1264                penalties: Vec::new(),
1265                nullspace_dims: Vec::new(),
1266                basisname: "linear".to_string(),
1267                degree: None,
1268                knots: None,
1269                keep_cols: None,
1270                smooth_lambda: None,
1271            })
1272        }
1273        SurvivalTimeBasisConfig::BSpline {
1274            degree,
1275            knots,
1276            smooth_lambda,
1277        } => {
1278            let knotvec = if knots.is_empty() {
1279                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1280                    "internal error: bspline time basis requested without knot source".to_string()
1281                })?;
1282                let combined = survival_time_knot_input(&log_entry, &log_exit);
1283                infer_survival_time_knots(
1284                    &combined,
1285                    degree,
1286                    degree,
1287                    num_internal_knots,
1288                    BasisOptions::value(),
1289                )?
1290            } else {
1291                knots
1292            };
1293
1294            let entry_basis = build_bspline_basis_1d(
1295                log_entry.view(),
1296                &BSplineBasisSpec {
1297                    degree,
1298                    penalty_order: 2,
1299                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1300                    double_penalty: false,
1301                    identifiability: BSplineIdentifiability::None,
1302                    boundary: OneDimensionalBoundary::Open,
1303                    boundary_conditions: BSplineBoundaryConditions::default(),
1304                },
1305            )
1306            .map_err(|e| format!("failed to build bspline entry basis: {e}"))?;
1307            let exit_basis = build_bspline_basis_1d(
1308                log_exit.view(),
1309                &BSplineBasisSpec {
1310                    degree,
1311                    penalty_order: 2,
1312                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1313                    double_penalty: false,
1314                    identifiability: BSplineIdentifiability::None,
1315                    boundary: OneDimensionalBoundary::Open,
1316                    boundary_conditions: BSplineBoundaryConditions::default(),
1317                },
1318            )
1319            .map_err(|e| format!("failed to build bspline exit basis: {e}"))?;
1320
1321            let p_time = exit_basis.design.ncols();
1322            // Build derivative basis as sparse triplets — B-spline derivatives
1323            // have the same local support as the basis itself (at most degree+1
1324            // nonzeros per row), so building dense first wastes memory.
1325            let mut deriv_triplets = Vec::with_capacity(n * (degree + 1));
1326            let mut deriv_buf = vec![0.0_f64; p_time];
1327            for i in 0..n {
1328                deriv_buf.fill(0.0);
1329                evaluate_bspline_derivative_scalar(
1330                    log_exit[i],
1331                    knotvec.view(),
1332                    degree,
1333                    &mut deriv_buf,
1334                )
1335                .map_err(|e| format!("failed to evaluate bspline derivative: {e}"))?;
1336                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1337                for j in 0..p_time {
1338                    let v = deriv_buf[j] * chain;
1339                    if v.abs() > 1e-15 {
1340                        deriv_triplets.push(faer::sparse::Triplet::new(i, j, v));
1341                    }
1342                }
1343            }
1344            let x_derivative_time =
1345                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1346                {
1347                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1348                    Err(_) => {
1349                        // Fallback: build dense
1350                        let mut dense = Array2::<f64>::zeros((n, p_time));
1351                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1352                            dense[[row, col]] = val;
1353                        }
1354                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1355                    }
1356                };
1357
1358            Ok(SurvivalTimeBuildOutput {
1359                x_entry_time: entry_basis.design,
1360                x_exit_time: exit_basis.design,
1361                x_derivative_time,
1362                nullspace_dims: entry_basis.nullspace_dims,
1363                penalties: entry_basis.penalties,
1364                basisname: "bspline".to_string(),
1365                degree: Some(degree),
1366                knots: Some(knotvec.to_vec()),
1367                keep_cols: None,
1368                smooth_lambda: Some(smooth_lambda),
1369            })
1370        }
1371        SurvivalTimeBasisConfig::ISpline {
1372            degree,
1373            knots,
1374            keep_cols,
1375            smooth_lambda,
1376        } => {
1377            let bspline_degree = degree
1378                .checked_add(1)
1379                .ok_or_else(|| "ispline degree overflow while building knot basis".to_string())?;
1380            let knotvec = if knots.is_empty() {
1381                let (num_internal_knots, _) = infer_knots_if_needed.ok_or_else(|| {
1382                    "internal error: ispline time basis requested without knot source".to_string()
1383                })?;
1384                let combined = survival_time_knot_input(&log_entry, &log_exit);
1385                infer_survival_time_knots(
1386                    &combined,
1387                    bspline_degree,
1388                    degree,
1389                    num_internal_knots,
1390                    BasisOptions::i_spline(),
1391                )?
1392            } else {
1393                knots
1394            };
1395
1396            let (db_exit_arc, _) = create_basis::<Dense>(
1397                log_exit.view(),
1398                KnotSource::Provided(knotvec.view()),
1399                bspline_degree,
1400                BasisOptions::first_derivative(),
1401            )
1402            .map_err(|e| format!("failed to build ispline derivative basis: {e}"))?;
1403
1404            // Build full-width I-spline bases inside a block scope so the
1405            // large Arc allocations are freed when the block ends.
1406            let (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full) = {
1407                let (entry_arc, _) = create_basis::<Dense>(
1408                    log_entry.view(),
1409                    KnotSource::Provided(knotvec.view()),
1410                    degree,
1411                    BasisOptions::i_spline(),
1412                )
1413                .map_err(|e| format!("failed to build ispline entry basis: {e}"))?;
1414                let (exit_arc, _) = create_basis::<Dense>(
1415                    log_exit.view(),
1416                    KnotSource::Provided(knotvec.view()),
1417                    degree,
1418                    BasisOptions::i_spline(),
1419                )
1420                .map_err(|e| format!("failed to build ispline exit basis: {e}"))?;
1421
1422                let x_entry_full = entry_arc.as_ref();
1423                let x_exit_full = exit_arc.as_ref();
1424                let p_time_full = x_exit_full.ncols();
1425                if p_time_full == 0 {
1426                    return Err(SurvivalConstructionError::BasisConstructionFailed {
1427                        reason: "internal error: empty ispline time basis".to_string(),
1428                    }
1429                    .into());
1430                }
1431                let db_exit = db_exit_arc.as_ref();
1432                if db_exit.ncols() != p_time_full + 1 {
1433                    return Err(
1434                        "internal error: ispline derivative basis width must exceed basis width by one"
1435                            .to_string(),
1436                    );
1437                }
1438
1439                let keep_cols = if keep_cols.is_empty() {
1440                    let constant_tol = 1e-12_f64;
1441                    let mut inferred_keep_cols: Vec<usize> = Vec::new();
1442                    for j in 0..p_time_full {
1443                        let mut minv = f64::INFINITY;
1444                        let mut maxv = f64::NEG_INFINITY;
1445                        for i in 0..n {
1446                            let ve = x_exit_full[[i, j]];
1447                            let vs = x_entry_full[[i, j]];
1448                            minv = minv.min(ve.min(vs));
1449                            maxv = maxv.max(ve.max(vs));
1450                        }
1451                        if (maxv - minv) > constant_tol {
1452                            inferred_keep_cols.push(j);
1453                        }
1454                    }
1455                    inferred_keep_cols
1456                } else {
1457                    keep_cols
1458                };
1459                if keep_cols.is_empty() {
1460                    return Err(
1461                        "internal error: ispline basis has no shape-varying time columns"
1462                            .to_string(),
1463                    );
1464                }
1465                if keep_cols.iter().any(|&j| j >= p_time_full) {
1466                    return Err(SurvivalConstructionError::MissingColumn {
1467                        reason: "saved survival ispline keep_cols exceed basis width".to_string(),
1468                    }
1469                    .into());
1470                }
1471
1472                let p_time = keep_cols.len();
1473                let x_entry_time = x_entry_full.select(ndarray::Axis(1), &keep_cols);
1474                let x_exit_time = x_exit_full.select(ndarray::Axis(1), &keep_cols);
1475                // entry_arc and exit_arc go out of scope here, freeing the
1476                // full-width bases before derivative computation below.
1477                (x_entry_time, x_exit_time, keep_cols, p_time, p_time_full)
1478            };
1479            let db_exit = db_exit_arc.as_ref();
1480
1481            // Build I-spline derivative as sparse triplets.  The derivative
1482            // is a cumulative sum of B-spline derivatives and typically has
1483            // more nonzeros per row than a plain B-spline, but still much
1484            // fewer than p_time for modest bases.
1485            let mut deriv_triplets = Vec::with_capacity(n * p_time.min(16));
1486            let mut found_nonfinite: Option<(usize, usize)> = None;
1487            for i in 0..n {
1488                let mut running = 0.0_f64;
1489                let mut d_i_log_full = vec![0.0_f64; p_time_full];
1490                for j in (1..db_exit.ncols()).rev() {
1491                    let term = db_exit[[i, j]];
1492                    if term.is_finite() {
1493                        running += term;
1494                    }
1495                    d_i_log_full[j - 1] = running;
1496                }
1497                let chain = 1.0 / age_exit[i].max(SURVIVAL_TIME_FLOOR);
1498                for (j_new, &j_old) in keep_cols.iter().enumerate() {
1499                    let raw_v = d_i_log_full[j_old] * chain;
1500                    let v = if (-1e-12..0.0).contains(&raw_v) {
1501                        0.0
1502                    } else {
1503                        raw_v
1504                    };
1505                    if !v.is_finite() {
1506                        found_nonfinite = Some((i, j_new));
1507                    }
1508                    if v < -1e-12 {
1509                        return Err(format!(
1510                            "survival ispline derivative basis must stay non-negative at row {}, column {}; found {:.3e}",
1511                            i + 1,
1512                            j_new + 1,
1513                            v
1514                        ));
1515                    }
1516                    if v.abs() > 1e-15 {
1517                        deriv_triplets.push(faer::sparse::Triplet::new(i, j_new, v));
1518                    }
1519                }
1520            }
1521            if let Some((row, col)) = found_nonfinite {
1522                return Err(format!(
1523                    "survival ispline derivative basis produced non-finite value at row {}, column {}",
1524                    row + 1,
1525                    col + 1
1526                ));
1527            }
1528            let x_derivative_time =
1529                match faer::sparse::SparseColMat::try_new_from_triplets(n, p_time, &deriv_triplets)
1530                {
1531                    Ok(sparse) => DesignMatrix::Sparse(SparseDesignMatrix::new(sparse)),
1532                    Err(_) => {
1533                        let mut dense = Array2::<f64>::zeros((n, p_time));
1534                        for &faer::sparse::Triplet { row, col, val } in &deriv_triplets {
1535                            dense[[row, col]] = val;
1536                        }
1537                        DesignMatrix::Dense(DenseDesignMatrix::from(dense))
1538                    }
1539                };
1540
1541            let penalty_basis = build_bspline_basis_1d(
1542                log_exit.view(),
1543                &BSplineBasisSpec {
1544                    degree: bspline_degree,
1545                    penalty_order: 2,
1546                    knotspec: BSplineKnotSpec::Provided(knotvec.clone()),
1547                    double_penalty: false,
1548                    identifiability: BSplineIdentifiability::None,
1549                    boundary: OneDimensionalBoundary::Open,
1550                    boundary_conditions: BSplineBoundaryConditions::default(),
1551                },
1552            )
1553            .map_err(|e| format!("failed to build ispline smoothing penalty: {e}"))?;
1554            if penalty_basis.design.ncols() != p_time_full + 1 {
1555                return Err("internal error: ispline penalty dimension mismatch".to_string());
1556            }
1557            // I-spline curvature penalty in the *value* space of the baseline
1558            // log-cumulative-hazard, restricted to the retained (non-dropped)
1559            // coefficient block.
1560            //
1561            // The I-spline coefficient γ is the consecutive increment of the B-spline
1562            // value coefficients `c`: `c_0 = 0`, `c_k = Σ_{j<k} γ_j = (L γ)_k`, where
1563            // `L` is the `p_time × p_time` lower-triangular cumsum matrix. The
1564            // second-difference penalty on the B-spline values is `S_B = D₂ᵀD₂`
1565            // (the `penalty_basis.penalties` block). The correct curvature penalty
1566            // on γ is the **value-space congruence transform**
1567            //
1568            //   `S_I = Lᵀ S_B[1:,1:] L`,
1569            //
1570            // which satisfies `γᵀ S_I γ = (Lγ)ᵀ S_B[1:,1:] (Lγ)`.
1571            //
1572            // A constant γ (γ_k = γ₀ ∀k) maps to the linear value sequence
1573            // `c_k = k·γ₀`, which is annihilated by D₂: `D₂c = 0`. Therefore
1574            // `γᵀ S_I γ = 0` for constant γ, i.e. the **affine trend lies in the
1575            // penalty null space**. REML does not penalize the baseline slope
1576            // `d(log Λ)/d(log t)` or the overall level, so it correctly lets the
1577            // data determine these quantities without bias. The previous increment-
1578            // space form `S_B[1:,1:]` (applied directly to γ instead of Lγ) did NOT
1579            // have constant γ in its null space and therefore over-penalized affine
1580            // baselines, causing the fitted log-cumulative-hazard to lose its tail
1581            // slope to the penalty and fail quality tests (#1076).
1582            //
1583            // The value-space form has a 1-dimensional null space (span{(1,…,1)}),
1584            // declared via `nullspace_dims` so the REML generalized-logdet picks it
1585            // up. The penalized inner PIRLS is well-conditioned because the
1586            // likelihood Hessian H_lik has O(n_events) curvature along the affine
1587            // direction (the overall baseline level is identified by the data), and
1588            // the global stabilization ridge (ridge_lambda) provides an absolute
1589            // positive-definite floor.
1590            let mut penalties = Vec::<Array2<f64>>::new();
1591            for s_mat in &penalty_basis.penalties {
1592                if s_mat.nrows() != p_time_full + 1 || s_mat.ncols() != p_time_full + 1 {
1593                    continue;
1594                }
1595                // I-spline value-space penalty, computed in the CORRECT order
1596                // (gam#979). The B-spline value coefficients are the cumulative
1597                // sum of the I-spline increment coefficients, `c = L γ_full`, where
1598                // `L` is the FULL `p_time_full × p_time_full` LOWER-triangular
1599                // all-ones cumsum matrix (`L[i,j] = 1 iff j ≤ i`, so
1600                // `c_i = Σ_{j≤i} γ_j`). The value-space curvature penalty on the
1601                // full increment vector is the symmetric congruence
1602                //
1603                //   `S_I_full = Lᵀ · S_B[1:,1:] · L`,
1604                //
1605                // which is PSD because `S_B[1:,1:]` is a principal submatrix of the
1606                // PSD `S_B = D₂ᵀD₂` and congruence by any matrix preserves PSD.
1607                //
1608                // CRITICAL ORDERING (the gam#979 indefiniteness bug): the retained
1609                // columns `keep_cols` must be selected as a PRINCIPAL SUBMATRIX of
1610                // the FULL congruence `S_I_full` — i.e. congruence FIRST, selection
1611                // SECOND. The previous code selected `keep_cols` from `S_B[1:,1:]`
1612                // first and then applied a `p_time × p_time` cumsum to that
1613                // already-reduced block. Because the cumsum `L` couples every
1614                // increment, restricting the increment index set BEFORE the cumsum
1615                // does NOT commute with it: the reduced operator is a different,
1616                // generally INDEFINITE matrix (measured `s0_min_eval = −9.8e7`),
1617                // which makes `½γᵀS_Iγ` unbounded below and the penalized survival
1618                // NLL diverge (β drifts up the negative-eigenvalue mode, the inner
1619                // joint-Newton follows the unbounded objective, the outer REML never
1620                // terminates — the #979 hang). Doing the congruence on the full γ
1621                // and then taking the `keep_cols` principal submatrix restores the
1622                // PSD guarantee (a principal submatrix of a PSD matrix is PSD).
1623                let s_increment = s_mat.slice(s![1.., 1..]);
1624                if s_increment.nrows() != p_time_full || s_increment.ncols() != p_time_full {
1625                    return Err(format!(
1626                        "internal error: ispline penalty increment block must be {p_time_full}x{p_time_full}, got {}x{}",
1627                        s_increment.nrows(),
1628                        s_increment.ncols(),
1629                    ));
1630                }
1631                // Symmetrize the (already-symmetric) source with the shared
1632                // matrix utility. The survival builder's value-space
1633                // congruence is domain-specific; only the low-level symmetric
1634                // cleanup is common with the generic and SAE construction code.
1635                let mut s_full = s_increment.to_owned();
1636                symmetrize_in_place(&mut s_full);
1637                // S_mid = S_B[1:,1:] · L  (right-multiply by lower-triangular
1638                // cumsum): (S·L)[i,j] = Σ_k S[i,k]·L[k,j] = Σ_{k≥j} S[i,k]
1639                // because L[k,j] = 1 iff j ≤ k.
1640                let mut s_mid_full = Array2::<f64>::zeros((p_time_full, p_time_full));
1641                for i in 0..p_time_full {
1642                    for j in 0..p_time_full {
1643                        let mut v = 0.0;
1644                        for k in j..p_time_full {
1645                            v += s_full[[i, k]];
1646                        }
1647                        s_mid_full[[i, j]] = v;
1648                    }
1649                }
1650                // S_I_full = Lᵀ · S_mid = Lᵀ · S · L:
1651                // (Lᵀ·S_mid)[i,j] = Σ_k Lᵀ[i,k]·S_mid[k,j] = Σ_{k≥i} S_mid[k,j]
1652                // because Lᵀ[i,k] = L[k,i] = 1 iff i ≤ k.
1653                let mut s_full_congruent = Array2::<f64>::zeros((p_time_full, p_time_full));
1654                for i in 0..p_time_full {
1655                    for j in 0..p_time_full {
1656                        let mut v = 0.0;
1657                        for k in i..p_time_full {
1658                            v += s_mid_full[[k, j]];
1659                        }
1660                        s_full_congruent[[i, j]] = v;
1661                    }
1662                }
1663                // Principal submatrix on the retained (shape-varying) columns.
1664                let mut local = Array2::<f64>::zeros((p_time, p_time));
1665                for (i_new, &i_old) in keep_cols.iter().enumerate() {
1666                    for (j_new, &j_old) in keep_cols.iter().enumerate() {
1667                        // Symmetrize on the way out to absorb residual
1668                        // floating-point asymmetry.
1669                        local[[i_new, j_new]] = 0.5
1670                            * (s_full_congruent[[i_old, j_old]] + s_full_congruent[[j_old, i_old]]);
1671                    }
1672                }
1673                penalties.push(local);
1674            }
1675
1676            // PSD contract (gam#979). The value-space congruence Lᵀ S_B[1:,1:] L,
1677            // restricted to a principal submatrix, is positive semidefinite by
1678            // construction. A negative eigenvalue here means the construction has
1679            // regressed to the increment-space / wrong-ordering form that made the
1680            // penalized survival NLL unbounded below (the #979 divergence). Verify
1681            // it here, at construction, so the defect can never silently reach the
1682            // inner solver again. The tolerance is the same relative scale the
1683            // nullspace detection below uses; a numerically tiny negative (round-off
1684            // on the genuine 1-D null direction) is allowed, a structural one is not.
1685            for (idx, s_mat) in penalties.iter().enumerate() {
1686                let p = s_mat.nrows();
1687                if p == 0 {
1688                    continue;
1689                }
1690                if let Ok((evals, _)) =
1691                    gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower)
1692                {
1693                    let evals_slice: &[f64] = evals.as_slice().ok_or_else(|| {
1694                        "internal error: ispline penalty eigenvalues not contiguous".to_string()
1695                    })?;
1696                    let max_ev = evals_slice
1697                        .iter()
1698                        .copied()
1699                        .fold(0.0_f64, |a, b| a.max(b.abs()))
1700                        .max(1.0);
1701                    let min_ev = evals_slice.iter().copied().fold(f64::INFINITY, f64::min);
1702                    let neg_tol = -100.0 * (p as f64) * f64::EPSILON * max_ev;
1703                    if min_ev < neg_tol {
1704                        return Err(format!(
1705                            "internal error (gam#979): assembled ispline time-block penalty {idx} is \
1706                             indefinite (min eigenvalue {min_ev:.3e} < tol {neg_tol:.3e}, max |eig| \
1707                             {max_ev:.3e}); the value-space congruence Lᵀ S_B[1:,1:] L must be PSD"
1708                        ));
1709                    }
1710                }
1711            }
1712
1713            // The value-space penalty S_I = L^T S_B[1:,1:] L has a 1-dimensional
1714            // null space (constant γ ↦ affine c ↦ D₂c = 0). Detect it spectrally
1715            // so the REML uses the generalized logdet over the penalized subspace.
1716            let nullspace_dims: Vec<usize> = penalties
1717                .iter()
1718                .map(|s_mat| {
1719                    let p = s_mat.nrows();
1720                    if p == 0 {
1721                        return 0;
1722                    }
1723                    match gam_linalg::faer_ndarray::FaerEigh::eigh(s_mat, faer::Side::Lower) {
1724                        Ok((evals, _)) => {
1725                            let evals_slice: &[f64] = evals.as_slice().unwrap();
1726                            let max_ev = evals_slice
1727                                .iter()
1728                                .copied()
1729                                .fold(0.0_f64, |a, b| a.max(b.abs()))
1730                                .max(1.0);
1731                            let threshold = 100.0 * (p as f64) * f64::EPSILON * max_ev;
1732                            evals_slice.iter().filter(|&&e| e <= threshold).count()
1733                        }
1734                        Err(_) => 0,
1735                    }
1736                })
1737                .collect();
1738            Ok(SurvivalTimeBuildOutput {
1739                x_entry_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_entry_time)),
1740                x_exit_time: DesignMatrix::Dense(DenseDesignMatrix::from(x_exit_time)),
1741                x_derivative_time,
1742                penalties,
1743                nullspace_dims,
1744                basisname: "ispline".to_string(),
1745                degree: Some(degree),
1746                knots: Some(knotvec.to_vec()),
1747                keep_cols: Some(keep_cols),
1748                smooth_lambda: Some(smooth_lambda),
1749            })
1750        }
1751    }
1752}
1753
1754pub fn resolved_survival_time_basis_config_from_build(
1755    basisname: &str,
1756    degree: Option<usize>,
1757    knots: Option<&Vec<f64>>,
1758    keep_cols: Option<&Vec<usize>>,
1759    smooth_lambda: Option<f64>,
1760) -> Result<SurvivalTimeBasisConfig, String> {
1761    match basisname {
1762        "none" => Ok(SurvivalTimeBasisConfig::None),
1763        "linear" => Ok(SurvivalTimeBasisConfig::Linear),
1764        "bspline" => Ok(SurvivalTimeBasisConfig::BSpline {
1765            degree: degree.ok_or_else(|| "survival bspline basis is missing degree".to_string())?,
1766            knots: Array1::from_vec(
1767                knots
1768                    .cloned()
1769                    .ok_or_else(|| "survival bspline basis is missing knots".to_string())?,
1770            ),
1771            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1772        }),
1773        "ispline" => Ok(SurvivalTimeBasisConfig::ISpline {
1774            degree: degree.ok_or_else(|| "survival ispline basis is missing degree".to_string())?,
1775            knots: Array1::from_vec(
1776                knots
1777                    .cloned()
1778                    .ok_or_else(|| "survival ispline basis is missing knots".to_string())?,
1779            ),
1780            keep_cols: keep_cols
1781                .cloned()
1782                .ok_or_else(|| "survival ispline basis is missing keep_cols".to_string())?,
1783            smooth_lambda: smooth_lambda.unwrap_or(SURVIVAL_TIME_SMOOTH_LAMBDA_SEED),
1784        }),
1785        other => Err(format!("unsupported survival time basis '{other}'")),
1786    }
1787}
1788
1789pub fn resolve_survival_time_anchor_value(
1790    age_entry: &Array1<f64>,
1791    time_anchor: Option<f64>,
1792) -> Result<f64, String> {
1793    if age_entry.is_empty() {
1794        return Err("survival time anchor requires non-empty entry times".to_string());
1795    }
1796    let anchor = match time_anchor {
1797        Some(t_anchor) => {
1798            if !t_anchor.is_finite() || t_anchor < 0.0 {
1799                return Err(format!(
1800                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1801                ));
1802            }
1803            t_anchor
1804        }
1805        None => age_entry
1806            .iter()
1807            .copied()
1808            .min_by(f64::total_cmp)
1809            .ok_or_else(|| "failed to select survival time anchor".to_string())?,
1810    };
1811    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1812}
1813
1814/// Marginal-slope centering anchor: a robust *interior* time on the **exit**
1815/// scale rather than the earliest entry age.
1816///
1817/// `center_survival_time_designs_at_anchor` subtracts the time-basis row at the
1818/// anchor from every entry/exit design row, so the anchor sets the origin of
1819/// the baseline-hazard I-spline's affine reparameterization. The
1820/// location-scale path anchors at the minimum entry age
1821/// ([`resolve_survival_time_anchor_value`]); for right-censored-only data that
1822/// minimum is ≈ the time origin, so centering is nearly a no-op.
1823///
1824/// Under **left truncation** the minimum entry age is a genuine positive
1825/// *left-tail* point, and centering there leaves the centered linear-trend
1826/// column `X(exit) − X(anchor)` large and one-signed across all rows (exit
1827/// times sit far to the right of the earliest entry). That column is the
1828/// unpenalized polynomial null space of the 2nd-difference time penalty, so the
1829/// inflated, one-signed column multiplies the marginal-slope time-block score
1830/// at the `γ = 0` monotone-cone seed up by hundreds — the constrained joint
1831/// Newton cannot certify KKT on it and REML rejects every seed (issue #751).
1832///
1833/// Centering instead at a robust interior location on the *exit* scale — the
1834/// **median exit age**, where the at-risk mass concentrates — keeps the
1835/// centered column small and two-signed (some exits below the median, some
1836/// above), so the exit-event likelihood pins the linear trend and the seed
1837/// score stays bounded. Re-centering is an exact affine reparameterization of
1838/// the baseline offset: the fitted `q(t)` and the REML objective are unchanged,
1839/// only the seed conditioning improves. The median is chosen (over the mean)
1840/// for robustness to the heavy right tail of survival times.
1841///
1842/// An explicit `--survival-time-anchor` is honored verbatim (same validation as
1843/// the location-scale path) so the user retains full control; the saved
1844/// `survival_time_anchor` scalar round-trips to predict unchanged.
1845pub fn resolve_survival_marginal_slope_time_anchor_value(
1846    age_entry: &Array1<f64>,
1847    age_exit: &Array1<f64>,
1848    time_anchor: Option<f64>,
1849) -> Result<f64, String> {
1850    if age_entry.is_empty() || age_exit.is_empty() {
1851        return Err(
1852            "survival marginal-slope time anchor requires non-empty entry/exit times".to_string(),
1853        );
1854    }
1855    let anchor = match time_anchor {
1856        Some(t_anchor) => {
1857            if !t_anchor.is_finite() || t_anchor < 0.0 {
1858                return Err(format!(
1859                    "survival time anchor must be finite and non-negative, got {t_anchor}"
1860                ));
1861            }
1862            t_anchor
1863        }
1864        None => {
1865            let mut sorted: Vec<f64> = age_exit.iter().copied().collect();
1866            sorted.sort_by(f64::total_cmp);
1867            let m = sorted.len();
1868            if m % 2 == 1 {
1869                sorted[m / 2]
1870            } else {
1871                0.5 * (sorted[m / 2 - 1] + sorted[m / 2])
1872            }
1873        }
1874    };
1875    Ok(anchor.max(SURVIVAL_TIME_FLOOR))
1876}
1877
1878pub fn evaluate_survival_time_basis_row(
1879    age: f64,
1880    cfg: &SurvivalTimeBasisConfig,
1881) -> Result<Array1<f64>, String> {
1882    if !age.is_finite() || age < 0.0 {
1883        return Err(format!(
1884            "survival time basis row requires finite non-negative age, got {age}"
1885        ));
1886    }
1887    let age = age.max(SURVIVAL_TIME_FLOOR);
1888    let log_age = array![age.ln()];
1889    match cfg {
1890        SurvivalTimeBasisConfig::None => Ok(Array1::zeros(0)),
1891        SurvivalTimeBasisConfig::Linear => Ok(array![1.0, age.ln()]),
1892        SurvivalTimeBasisConfig::BSpline { degree, knots, .. } => {
1893            if knots.is_empty() {
1894                return Err(
1895                    "survival BSpline anchor evaluation requires resolved knot metadata"
1896                        .to_string(),
1897                );
1898            }
1899            let built = build_bspline_basis_1d(
1900                log_age.view(),
1901                &BSplineBasisSpec {
1902                    degree: *degree,
1903                    penalty_order: 2,
1904                    knotspec: BSplineKnotSpec::Provided(knots.clone()),
1905                    double_penalty: false,
1906                    identifiability: BSplineIdentifiability::None,
1907                    boundary: OneDimensionalBoundary::Open,
1908                    boundary_conditions: BSplineBoundaryConditions::default(),
1909                },
1910            )
1911            .map_err(|e| format!("failed to evaluate survival bspline anchor row: {e}"))?;
1912            Ok(built.design.to_dense().row(0).to_owned())
1913        }
1914        SurvivalTimeBasisConfig::ISpline {
1915            degree,
1916            knots,
1917            keep_cols,
1918            ..
1919        } => {
1920            if knots.is_empty() {
1921                return Err(
1922                    "survival ISpline anchor evaluation requires resolved knot metadata"
1923                        .to_string(),
1924                );
1925            }
1926            let (basis_arc, _) = create_basis::<Dense>(
1927                log_age.view(),
1928                KnotSource::Provided(knots.view()),
1929                *degree,
1930                BasisOptions::i_spline(),
1931            )
1932            .map_err(|e| format!("failed to evaluate survival ispline anchor row: {e}"))?;
1933            let basis = basis_arc.as_ref();
1934            let row = basis.row(0);
1935            if keep_cols.is_empty() {
1936                return Ok(row.to_owned());
1937            }
1938            if keep_cols.iter().any(|&j| j >= row.len()) {
1939                return Err(SurvivalConstructionError::MissingColumn {
1940                    reason: "survival ISpline anchor keep_cols exceed basis width".to_string(),
1941                }
1942                .into());
1943            }
1944            Ok(Array1::from_iter(keep_cols.iter().map(|&j| row[j])))
1945        }
1946    }
1947}
1948
1949pub fn center_survival_time_designs_at_anchor(
1950    design_entry: &mut DesignMatrix,
1951    design_exit: &mut DesignMatrix,
1952    anchor_row: &Array1<f64>,
1953) -> Result<(), String> {
1954    if design_entry.ncols() != anchor_row.len() || design_exit.ncols() != anchor_row.len() {
1955        return Err(format!(
1956            "survival time anchoring column mismatch: entry={}, exit={}, anchor={}",
1957            design_entry.ncols(),
1958            design_exit.ncols(),
1959            anchor_row.len()
1960        ));
1961    }
1962    // Centering destroys sparsity (every row gets a dense offset), so
1963    // materialize to dense.  This only runs once at construction time.
1964    fn center_dense(dm: &mut DesignMatrix, anchor: &Array1<f64>) {
1965        let mut dense = dm.to_dense();
1966        for mut row in dense.rows_mut() {
1967            row -= &anchor.view();
1968        }
1969        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(dense));
1970    }
1971    center_dense(design_entry, anchor_row);
1972    center_dense(design_exit, anchor_row);
1973    Ok(())
1974}
1975
1976// ---------------------------------------------------------------------------
1977// Baseline evaluation (Gompertz, Weibull, Gompertz-Makeham)
1978// ---------------------------------------------------------------------------
1979
1980/// Partial derivatives of the baseline offsets `(eta_target, d_eta_target/dt)`
1981/// with respect to the θ-parameters in the same parameterization that
1982/// [`survival_baseline_theta_from_config`] / [`survival_baseline_config_from_theta`]
1983/// use:
1984///
1985/// - **Weibull**: θ = (log_scale, log_shape).  `eta = shape·(log t − log scale)`,
1986///   `o_D = shape/t`.
1987/// - **Gompertz**: θ = (log_rate, shape).  `eta = log H_G(t)` with
1988///   `H_G(t) = (rate/shape)·(exp(shape·t) − 1)`, `o_D = h_G(t)/H_G(t) =
1989///   shape·E/(E−1)` where `E = exp(shape·t)`.
1990/// - **Gompertz–Makeham**: θ = (log_rate, shape, log_makeham).
1991///   `eta = log H(t)` with `H(t) = makeham·t + H_G(t)`,
1992///   `o_D = (makeham + h_G(t)) / H(t)`.
1993///
1994/// Returns a flat `(d_eta/dθ_k, d_oD/dθ_k)` pair for each component of θ,
1995/// in the same order as `survival_baseline_theta_from_config`.  Linear has
1996/// no θ-parameters so returns `Ok(None)`.
1997///
1998/// The `eta`-channel derivatives are closed-form for every branch.  The
1999/// `o_D`-channel derivatives use the log-derivative identity
2000/// `∂o_D/∂θ = o_D · ∂log(o_D)/∂θ` which is more numerically stable near
2001/// the small-shape limit (shape·t → 0).  Near shape = 0 we fall back to
2002/// a third-order Taylor expansion with the same 1e-10 pivot that
2003/// `evaluate_survival_baseline` uses, keeping the value/derivative pair
2004/// continuous and agreement with the linear-hazard limit exact at shape=0.
2005pub fn baseline_offset_theta_partials(
2006    age: f64,
2007    cfg: &SurvivalBaselineConfig,
2008) -> Result<Option<Vec<(f64, f64)>>, String> {
2009    let Some(params) = validated_baseline_params(age, cfg, "baseline derivative evaluation")?
2010    else {
2011        return Ok(None);
2012    };
2013
2014    match params {
2015        ValidatedBaselineTarget::Weibull { scale, shape } => {
2016            // eta = shape·(log t − log scale)
2017            //     = shape·log t − shape·log scale
2018            // o_D = shape / t
2019            //
2020            // θ = (log_scale, log_shape):
2021            //   ∂eta/∂log_scale  = −shape          ∂o_D/∂log_scale = 0
2022            //   ∂eta/∂log_shape  = shape·(log t − log scale) = eta
2023            //   ∂o_D/∂log_shape  = shape / t = o_D
2024            let eta = shape * (age.ln() - scale.ln());
2025            let o_d = shape / age;
2026            let d_eta_d_log_scale = -shape;
2027            let d_od_d_log_scale = 0.0;
2028            let d_eta_d_log_shape = eta;
2029            let d_od_d_log_shape = o_d;
2030            Ok(Some(vec![
2031                (d_eta_d_log_scale, d_od_d_log_scale),
2032                (d_eta_d_log_shape, d_od_d_log_shape),
2033            ]))
2034        }
2035        ValidatedBaselineTarget::Gompertz { shape, .. } => {
2036            // θ = (log_rate, shape):
2037            //   Rate cancels in o_D = h/H for Gompertz, so ∂o_D/∂log_rate = 0
2038            //   and ∂eta/∂log_rate = 1. The shape channel uses
2039            //     ∂eta/∂shape   = −1/shape + t·E/(E−1)
2040            //     ∂log(o_D)/∂shape = 1/shape − t/(E−1)
2041            //     ∂o_D/∂shape  = o_D · ∂log(o_D)/∂shape
2042            //   Near shape=0 both numerators are 1/shape cancellations. Use
2043            //   Taylor expansions with the same 1e-10 pivot that
2044            //   gompertz_components uses in evaluate_survival_baseline.
2045            let (d_eta_d_shape, d_od_d_shape) = gompertz_shape_derivatives(age, shape);
2046            Ok(Some(vec![(1.0, 0.0), (d_eta_d_shape, d_od_d_shape)]))
2047        }
2048        ValidatedBaselineTarget::GompertzMakeham {
2049            rate,
2050            shape,
2051            makeham,
2052        } => {
2053            // H(t) = M·t + H_G(t),   H_G(t) = (rate/shape)·(E−1),  E = exp(shape·t)
2054            // h(t) = M + h_G(t),     h_G(t) = rate·E
2055            // o_D  = h/H
2056            //
2057            // θ = (log_rate, shape, log_makeham):
2058            //   ∂H/∂log_rate    = rate · ∂H/∂rate = H_G               (scales with rate)
2059            //   ∂H/∂shape       = H_G_shape                            (closed form below)
2060            //   ∂H/∂log_makeham = makeham · t                          (linear in makeham)
2061            //   ∂h/∂log_rate    = rate · ∂h/∂rate = h_G
2062            //   ∂h/∂shape       = h_G_shape = rate·t·E + 0              (= rate·t·E)
2063            //   ∂h/∂log_makeham = makeham
2064            //   ∂eta/∂θ = (∂H/∂θ) / H
2065            //   ∂o_D/∂θ = (∂h/∂θ − o_D · ∂H/∂θ) / H
2066            //           = (∂h/∂θ)/H − o_D · (∂H/∂θ)/H
2067            let (cum_g, inst_g) = gompertz_hazard_components(age, rate, shape);
2068            let cum_total = makeham * age + cum_g;
2069            if cum_total <= 0.0 || !cum_total.is_finite() {
2070                return Err(SurvivalConstructionError::DataValidationFailed {
2071                    reason: "gm baseline produced non-positive cumulative hazard".to_string(),
2072                }
2073                .into());
2074            }
2075            let inst_total = makeham + inst_g;
2076            let o_d = inst_total / cum_total;
2077            let inv_cum = 1.0 / cum_total;
2078            // Each channel: ∂cum/∂θ and ∂inst/∂θ → ∂eta/∂θ = ∂cum/∂θ / cum
2079            //                                       ∂o_D/∂θ = (∂inst/∂θ − o_D·∂cum/∂θ) / cum
2080            // log_rate channel: cum is linear in rate through H_G; ∂cum/∂rate = H_G/rate,
2081            //   so ∂cum/∂log_rate = H_G (= cum_g here). Similarly ∂inst/∂log_rate = h_G (= inst_g).
2082            let d_cum_dlr = cum_g;
2083            let d_inst_dlr = inst_g;
2084            let d_eta_dlr = d_cum_dlr * inv_cum;
2085            let d_od_dlr = (d_inst_dlr - o_d * d_cum_dlr) * inv_cum;
2086            // shape channel: only H_G and h_G have shape dependence.
2087            let (d_cum_dshape, d_inst_dshape) =
2088                gompertz_cumulative_shape_derivative(age, rate, shape);
2089            let d_eta_dshape = d_cum_dshape * inv_cum;
2090            let d_od_dshape = (d_inst_dshape - o_d * d_cum_dshape) * inv_cum;
2091            // log_makeham channel: cum contributes M·t, inst contributes M.
2092            //   ∂cum/∂log_makeham = makeham·t,  ∂inst/∂log_makeham = makeham.
2093            let d_cum_dlm = makeham * age;
2094            let d_inst_dlm = makeham;
2095            let d_eta_dlm = d_cum_dlm * inv_cum;
2096            let d_od_dlm = (d_inst_dlm - o_d * d_cum_dlm) * inv_cum;
2097            Ok(Some(vec![
2098                (d_eta_dlr, d_od_dlr),
2099                (d_eta_dshape, d_od_dshape),
2100                (d_eta_dlm, d_od_dlm),
2101            ]))
2102        }
2103    }
2104}
2105
2106/// Shared chain-rule θ-gradient contraction for baseline offsets.
2107///
2108/// Both [`baseline_chain_rule_gradient`] (RP eta offsets) and
2109/// [`marginal_slope_baseline_chain_rule_gradient`] (probit q-offsets) reduce to
2110/// the same contraction of [`OffsetChannelResiduals`] against per-age baseline
2111/// θ-partials; only the `partials` provider differs. This engine owns the length
2112/// checks, the θ-dim probe, the parallel per-row reduction, the entry gating, and
2113/// the error handling. Each provider returns, per age, a length-`theta_dim` vector
2114/// of `(∂eta/∂θ_k, ∂(d eta/dt)/∂θ_k)` pairs (or `(∂q/∂θ_k, ∂(dq/dt)/∂θ_k)` for the
2115/// probit channel), and `None` when `cfg` has no θ-parameters (`Linear` target).
2116///
2117/// Contract (envelope theorem at converged β; the penalty has no θ dependence):
2118///
2119///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2120///     = Σᵢ r_X[i]·(∂o_X_i/∂θ_k) + r_D[i]·(∂o_D_i/∂θ_k) + r_E[i]·(∂o_E_i/∂θ_k)
2121///       + r_R[i]·(∂o_R_i/∂θ_k)
2122///
2123/// where `r_X = residuals.exit`, `r_D = residuals.derivative`, `r_E =
2124/// residuals.entry`, `r_R = residuals.right` (all sampleweight-scaled already).
2125/// Exit and derivative partials both come from the `age_exit[i]` evaluation;
2126/// the entry partial from `age_entry[i]`; the interval upper-bound (`R`)
2127/// η-partial from `age_right[i]`. Origin-entry rows have `r_E[i] == 0` exactly
2128/// and non-interval rows have `r_R[i] == 0` exactly, so those partials are
2129/// skipped for those rows (avoiding the `age > 0` precondition failure when an
2130/// inactive boundary age is 0 / a placeholder).
2131///
2132/// Returns `Ok(None)` when the provider reports no θ-parameters.
2133fn baseline_chain_rule_gradient_with_partials<F>(
2134    label: &'static str,
2135    age_entry: ndarray::ArrayView1<'_, f64>,
2136    age_exit: ndarray::ArrayView1<'_, f64>,
2137    age_right: ndarray::ArrayView1<'_, f64>,
2138    cfg: &SurvivalBaselineConfig,
2139    residuals: &crate::survival::OffsetChannelResiduals,
2140    partials: F,
2141) -> Result<Option<Array1<f64>>, String>
2142where
2143    F: Fn(f64, &SurvivalBaselineConfig) -> Result<Option<Vec<(f64, f64)>>, String> + Sync,
2144{
2145    let n = age_exit.len();
2146    if age_entry.len() != n
2147        || age_right.len() != n
2148        || residuals.exit.len() != n
2149        || residuals.entry.len() != n
2150        || residuals.derivative.len() != n
2151        || residuals.right.len() != n
2152    {
2153        return Err(format!(
2154            "{label}: length mismatch (age_entry={}, age_exit={}, age_right={}, r_exit={}, r_entry={}, r_deriv={}, r_right={})",
2155            age_entry.len(),
2156            n,
2157            age_right.len(),
2158            residuals.exit.len(),
2159            residuals.entry.len(),
2160            residuals.derivative.len(),
2161            residuals.right.len(),
2162        ));
2163    }
2164    // Probe θ-dim via any valid positive age. If the provider returns None the
2165    // config carries no θ-parameters (Linear target) and there is no θ-gradient.
2166    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2167    let theta_dim = match probe_age {
2168        Some(t) => match partials(t, cfg)? {
2169            None => return Ok(None),
2170            Some(v) => v.len(),
2171        },
2172        None => {
2173            return Err(format!("{label}: no valid positive age for dim probe"));
2174        }
2175    };
2176    // Per-row partial contractions are independent, but each row's
2177    // contribution is a `theta_dim`-vector of `O(theta_dim · partial_cost)`
2178    // flops — small enough that the rayon parallel reduction's split
2179    // overhead dominates for any plausible `theta_dim`, *and* the
2180    // non-associative IEEE-754 sum order across thread chunks made the
2181    // engine drift in the low-order bits from row to row. The serial
2182    // accumulator below mirrors the inline reference exactly (and remains
2183    // ~memory-bandwidth-bound at large-scale `n`), so the engine is now a
2184    // bit-for-bit replacement for the legacy path, not just a
2185    // floating-point-noise-equivalent one.
2186    let mut grad = Array1::<f64>::zeros(theta_dim);
2187    for i in 0..n {
2188        // Exit + derivative partials both come from the age_exit evaluation.
2189        let partials_exit = partials(age_exit[i], cfg)?
2190            .ok_or_else(|| format!("{label}: unexpected None from partials at exit"))?;
2191        if partials_exit.len() != theta_dim {
2192            return Err(format!(
2193                "{label}: theta_dim drifted ({} != {})",
2194                partials_exit.len(),
2195                theta_dim
2196            ));
2197        }
2198        let r_x = residuals.exit[i];
2199        let r_d = residuals.derivative[i];
2200        for k in 0..theta_dim {
2201            let (d_eta_dk, d_od_dk) = partials_exit[k];
2202            grad[k] += r_x * d_eta_dk + r_d * d_od_dk;
2203        }
2204        // Entry channel is nonzero only for rows with a positive entry
2205        // interval; for origin-entry rows age_entry may be 0 and calling
2206        // the provider would error. Gate on residual==0.
2207        let r_e = residuals.entry[i];
2208        if r_e != 0.0 {
2209            let partials_entry = partials(age_entry[i], cfg)?
2210                .ok_or_else(|| format!("{label}: unexpected None from partials at entry"))?;
2211            for k in 0..theta_dim {
2212                grad[k] += r_e * partials_entry[k].0;
2213            }
2214        }
2215        // Interval upper-bound (`R`) channel: `q_right = X_time(R)·β + o_R(θ)`
2216        // carries its own baseline-θ η-offset evaluated at `age_right[i]`. It is
2217        // an η-level offset with NO time-derivative channel (the interval
2218        // likelihood `log[S(L) − S(R)]` has no hazard-derivative term), so it
2219        // contracts against the η-partial `.0` only. Nonzero only for
2220        // interval-censored latent rows; for every other channel/model
2221        // `r_right[i] == 0` exactly, so the (possibly placeholder) `age_right[i]`
2222        // partial is never consulted.
2223        let r_r = residuals.right[i];
2224        if r_r != 0.0 {
2225            let partials_right = partials(age_right[i], cfg)?.ok_or_else(|| {
2226                format!("{label}: unexpected None from partials at right boundary")
2227            })?;
2228            if partials_right.len() != theta_dim {
2229                return Err(format!(
2230                    "{label}: theta_dim drifted at right boundary ({} != {})",
2231                    partials_right.len(),
2232                    theta_dim
2233                ));
2234            }
2235            for k in 0..theta_dim {
2236                grad[k] += r_r * partials_right[k].0;
2237            }
2238        }
2239    }
2240    Ok(Some(grad))
2241}
2242
2243/// Contract `OffsetChannelResiduals` against `baseline_offset_theta_partials`
2244/// to produce the closed-form θ-gradient of the unpenalized NLL at converged β.
2245///
2246/// Derivation (envelope theorem on the penalized objective, β* minimizes the
2247/// same cost wrt β and the penalty has no θ dependence):
2248///
2249///   d[0.5·deviance + 0.5·βᵀS_λβ] / dθ_k
2250///     = d[NLL(β*; o(θ))] / dθ_k
2251///     = Σᵢ (∂NLL_i/∂o_X[i])·(∂o_X_i/∂θ_k)
2252///       + (∂NLL_i/∂o_E[i])·(∂o_E_i/∂θ_k)
2253///       + (∂NLL_i/∂o_D[i])·(∂o_D_i/∂θ_k)
2254///       + (∂NLL_i/∂o_R[i])·(∂o_R_i/∂θ_k)
2255///
2256/// The four `∂NLL_i/∂o_channel` terms are the `exit`, `entry`, `derivative`,
2257/// `right` fields of [`OffsetChannelResiduals`] (sampleweight-scaled already).
2258/// The `∂o/∂θ_k` terms come from [`baseline_offset_theta_partials`] per obs at
2259/// the appropriate age.
2260///
2261/// Per the RP offset convention:
2262///   o_E[i] = eta_target(age_entry[i])
2263///   o_X[i] = eta_target(age_exit[i])
2264///   o_D[i] = d/dt eta_target(t) |_{t=age_exit[i]}
2265///   o_R[i] = eta_target(age_right[i])   (interval upper bound `R`; η-level only)
2266///
2267/// so the exit and derivative partials are both evaluated at `age_exit[i]`,
2268/// the entry partial at `age_entry[i]`, and the interval-right η-partial at
2269/// `age_right[i]`. The origin-entry case (`entry_at_origin[i]`) has
2270/// `r_entry[i] = 0` exactly and every non-interval row has `r_right[i] = 0`
2271/// exactly, so we skip the `baseline_offset_theta_partials(age, ..)` call for
2272/// those rows (avoiding the `age > 0` precondition failure when an inactive
2273/// boundary age is 0 / a placeholder).
2274///
2275/// Returns `Ok(None)` when `cfg.target == Linear` (no θ-parameters).
2276pub fn baseline_chain_rule_gradient(
2277    age_entry: ndarray::ArrayView1<'_, f64>,
2278    age_exit: ndarray::ArrayView1<'_, f64>,
2279    age_right: ndarray::ArrayView1<'_, f64>,
2280    cfg: &SurvivalBaselineConfig,
2281    residuals: &crate::survival::OffsetChannelResiduals,
2282) -> Result<Option<Array1<f64>>, String> {
2283    baseline_chain_rule_gradient_with_partials(
2284        "baseline_chain_rule_gradient",
2285        age_entry,
2286        age_exit,
2287        age_right,
2288        cfg,
2289        residuals,
2290        baseline_offset_theta_partials,
2291    )
2292}
2293
2294/// Chain-rule θ-gradient for marginal-slope probit baseline offsets.
2295///
2296/// This is the probit-survival counterpart of [`baseline_chain_rule_gradient`].
2297/// It contracts residuals against
2298/// [`marginal_slope_baseline_offset_theta_partials`], so the offset channels
2299/// are `(q_entry, q_exit, dq_exit/dt)` with `Phi(-q(t)) = exp(-H0(t))`.
2300pub fn marginal_slope_baseline_chain_rule_gradient(
2301    age_entry: ndarray::ArrayView1<'_, f64>,
2302    age_exit: ndarray::ArrayView1<'_, f64>,
2303    cfg: &SurvivalBaselineConfig,
2304    residuals: &crate::survival::OffsetChannelResiduals,
2305) -> Result<Option<Array1<f64>>, String> {
2306    // Marginal-slope has no interval upper-bound channel; `residuals.right` is
2307    // all-zero, so the right channel never contracts and `age_exit` serves as an
2308    // unconsulted placeholder for the (unused) `age_right` argument.
2309    baseline_chain_rule_gradient_with_partials(
2310        "marginal_slope_baseline_chain_rule_gradient",
2311        age_entry,
2312        age_exit,
2313        age_exit,
2314        cfg,
2315        residuals,
2316        marginal_slope_baseline_offset_theta_partials,
2317    )
2318}
2319
/// Gompertz cumulative and instantaneous hazards `(H_G(t), h_G(t))`.
///
/// With `x = shape·t`:
///
///   H_G(t) = (rate/shape)·expm1(x),   h_G(t) = rate·exp(x).
///
/// The closed form for `H_G` is `0/0` at `shape = 0`. For `|shape| < 1e-10`,
/// both quantities therefore use their second-order Taylor expansions in `x`.
/// This single evaluator is shared by `evaluate_survival_baseline`,
/// `baseline_offset_theta_partials`, `survival_hazard_theta_partials` and
/// `survival_cumulative_and_instant_hazard`.
#[inline]
fn gompertz_hazard_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if shape.abs() < 1e-10 {
        // H_G ≈ rate·t·(1 + x/2 + x²/6),  h_G ≈ rate·(1 + x + x²/2).
        let cumulative = rate * age * (1.0 + 0.5 * x + x * x / 6.0);
        let instant = rate * (1.0 + x + 0.5 * x * x);
        return (cumulative, instant);
    }
    ((rate / shape) * x.exp_m1(), rate * x.exp())
}
2340
/// Shape partials of the Gompertz hazards: `(∂H_G/∂shape, ∂h_G/∂shape)`.
///
/// With `x = shape·t` and `E = exp(x)`:
///
///   H_G(t) = (rate/shape)·(E−1),           h_G(t) = rate·E
///   ∂H_G/∂shape = rate·[x·E − (E−1)] / shape²
///   ∂h_G/∂shape = rate·t·E
///
/// `x·E − (E−1)` is a difference of two O(x) terms whose leading parts cancel,
/// so the closed form loses roughly `eps/|x|` relative accuracy as `x → 0`. The
/// accuracy is governed by `x`, not by `shape` alone, so the branch switches on
/// `x`. Expanding, `x·E − (E−1) = Σ_{k≥2} (k−1)·x^k/k! = x²/2 + x³/3 + x⁴/8 + …`,
/// which gives
///
///   ∂H_G/∂shape = rate·t²·(1/2 + x/3 + x²/8 + O(x³)).
///
/// The truncated series is used for `|x| < 1e-4`, where its truncation error is
/// far below double precision. The closed form is used otherwise.
#[inline]
fn gompertz_cumulative_shape_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    let d_inst = rate * age * x.exp();
    if x.abs() < 1e-4 {
        let t = age;
        let d_cum = rate * t * t * (0.5 + x / 3.0 + x * x / 8.0);
        return (d_cum, d_inst);
    }
    // Closed form: numerator is t·E·shape − expm1(x).
    let numerator = age * x.exp() * shape - x.exp_m1();
    (rate * numerator / (shape * shape), d_inst)
}
2381
/// Shape-channel partials `(∂eta/∂shape, ∂o_D/∂shape)` for the pure Gompertz
/// baseline, where `eta = log H(t)` and `o_D = h(t)/H(t)`.
///
/// With `x = shape·t` and `E = exp(x)`:
///
///   ∂eta/∂shape      = −1/shape + t·E/(E−1)
///   o_D              = shape·E/(E−1)
///   ∂log(o_D)/∂shape = 1/shape − t/(E−1)
///   ∂o_D/∂shape      = o_D · ∂log(o_D)/∂shape
///
/// The rate cancels in `o_D` and `∂eta/∂log_rate = 1`, so only the shape channel
/// needs this helper.
///
/// Both `∂eta/∂shape` and `∂log(o_D)/∂shape` subtract two O(1/shape) terms
/// whose leading parts cancel. The closed form's relative error therefore grows
/// like `eps/|x|`, and it depends on the dimensionless product `x`, not on
/// `shape` alone. For example, shape = 2e-10 at t = 1 keeps only about six
/// correct digits in closed form. The series branch is selected on `|x| < 1e-4`,
/// matching `gompertz_cumulative_shape_derivative`; the retained terms have
/// O(x³) relative truncation error there.
#[inline]
fn gompertz_shape_derivatives(age: f64, shape: f64) -> (f64, f64) {
    let x = shape * age;
    if x.abs() < 1e-4 {
        // Series from t·E/(E−1) = 1/shape + t/2 + shape·t²/12 + O(shape³·t⁴):
        //   ∂eta/∂shape      = t/2 + shape·t²/12 + O(x³·t)
        //   o_D              = 1/t + shape/2 + shape²·t/12 + O(x⁴/t)
        //   ∂log(o_D)/∂shape = t/2 − shape·t²/12 + O(x³·t)
        let t = age;
        let d_eta = 0.5 * t + shape * t * t / 12.0;
        let dlog_od = 0.5 * t - shape * t * t / 12.0;
        let o_d = 1.0 / t + 0.5 * shape + shape * shape * t / 12.0;
        (d_eta, o_d * dlog_od)
    } else {
        // E/(E−1) = 1/(1 − e^{−x}) = −1/expm1(−x). This is identical in exact
        // arithmetic, but it stays finite for large positive x, where
        // exp(x)/expm1(x) would be inf/inf = NaN.
        let e_over_em1 = -1.0 / (-x).exp_m1();
        let d_eta = -1.0 / shape + age * e_over_em1;
        let o_d = shape * e_over_em1;
        let dlog_od = 1.0 / shape - age / x.exp_m1();
        (d_eta, o_d * dlog_od)
    }
}
2414
/// Baseline parameters that have already passed the shared age guard and the
/// per-target presence / finiteness / positivity checks in
/// `validated_baseline_params`.
///
/// Centralizing those checks means every evaluator built on it
/// (`evaluate_survival_baseline`, `baseline_offset_theta_partials`,
/// `survival_hazard_theta_partials`, `survival_cumulative_and_instant_hazard`)
/// works from already-checked scalars and differs only in what it computes.
/// The parameter-free `Linear` target has no variant here; callers see it as `None`.
#[derive(Clone, Copy, Debug)]
enum ValidatedBaselineTarget {
    /// `H(t) = (t/scale)^shape`; `scale` and `shape` finite and strictly positive.
    Weibull { scale: f64, shape: f64 },
    /// `H(t) = (rate/shape)·(exp(shape·t) − 1)`; `rate` finite and strictly
    /// positive, `shape` finite (either sign).
    Gompertz { rate: f64, shape: f64 },
    /// Gompertz plus a constant Makeham term: `H(t) = makeham·t + H_Gompertz(t)`;
    /// `rate` and `makeham` finite and strictly positive, `shape` finite.
    GompertzMakeham { rate: f64, shape: f64, makeham: f64 },
}
2429
2430/// Shared prologue for the survival baseline hazard evaluators: validate the age,
2431/// then extract and domain-check the per-target parameters from `cfg`.
2432///
2433/// `Ok(None)` is the `Linear` target (no parametric baseline). `context` is woven
2434/// into the age-guard error so each caller keeps its specific phrasing.
2435fn validated_baseline_params(
2436    age: f64,
2437    cfg: &SurvivalBaselineConfig,
2438    context: &str,
2439) -> Result<Option<ValidatedBaselineTarget>, String> {
2440    if !age.is_finite() || age <= 0.0 {
2441        return Err(format!(
2442            "survival ages must be finite and positive for {context}"
2443        ));
2444    }
2445
2446    match cfg.target {
2447        SurvivalBaselineTarget::Linear => Ok(None),
2448        SurvivalBaselineTarget::Weibull => {
2449            let scale = cfg
2450                .scale
2451                .ok_or_else(|| "weibull missing scale".to_string())?;
2452            let shape = cfg
2453                .shape
2454                .ok_or_else(|| "weibull missing shape".to_string())?;
2455            if !(scale.is_finite() && shape.is_finite() && scale > 0.0 && shape > 0.0) {
2456                return Err(SurvivalConstructionError::InvalidConfig {
2457                    reason: "weibull baseline requires finite positive scale and shape".to_string(),
2458                }
2459                .into());
2460            }
2461            Ok(Some(ValidatedBaselineTarget::Weibull { scale, shape }))
2462        }
2463        SurvivalBaselineTarget::Gompertz => {
2464            let rate = cfg
2465                .rate
2466                .ok_or_else(|| "gompertz missing rate".to_string())?;
2467            let shape = cfg
2468                .shape
2469                .ok_or_else(|| "gompertz missing shape".to_string())?;
2470            if !(rate.is_finite() && shape.is_finite() && rate > 0.0) {
2471                return Err(
2472                    "gompertz baseline requires finite positive rate and finite shape".to_string(),
2473                );
2474            }
2475            Ok(Some(ValidatedBaselineTarget::Gompertz { rate, shape }))
2476        }
2477        SurvivalBaselineTarget::GompertzMakeham => {
2478            let rate = cfg
2479                .rate
2480                .ok_or_else(|| "gompertz-makeham missing rate".to_string())?;
2481            let shape = cfg
2482                .shape
2483                .ok_or_else(|| "gompertz-makeham missing shape".to_string())?;
2484            let makeham = cfg
2485                .makeham
2486                .ok_or_else(|| "gompertz-makeham missing makeham".to_string())?;
2487            if !(rate.is_finite()
2488                && shape.is_finite()
2489                && makeham.is_finite()
2490                && rate > 0.0
2491                && makeham > 0.0)
2492            {
2493                return Err(
2494                    "gompertz-makeham baseline requires finite positive rate, makeham, and finite shape"
2495                        .to_string(),
2496                );
2497            }
2498            Ok(Some(ValidatedBaselineTarget::GompertzMakeham {
2499                rate,
2500                shape,
2501                makeham,
2502            }))
2503        }
2504    }
2505}
2506
2507fn survival_hazard_theta_partials(
2508    age: f64,
2509    cfg: &SurvivalBaselineConfig,
2510) -> Result<Option<Vec<(f64, f64)>>, String> {
2511    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard partials")? else {
2512        return Ok(None);
2513    };
2514
2515    match params {
2516        ValidatedBaselineTarget::Weibull { scale, shape } => {
2517            let log_time_ratio = age.ln() - scale.ln();
2518            let cumulative_hazard = (age / scale).powf(shape);
2519            let instant_hazard = shape * cumulative_hazard / age;
2520            let eta = shape * log_time_ratio;
2521            Ok(Some(vec![
2522                (-shape * cumulative_hazard, -shape * instant_hazard),
2523                (eta * cumulative_hazard, (1.0 + eta) * instant_hazard),
2524            ]))
2525        }
2526        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2527            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2528            let (d_cum_dshape, d_inst_dshape) =
2529                gompertz_cumulative_shape_derivative(age, rate, shape);
2530            Ok(Some(vec![
2531                (cumulative_hazard, instant_hazard),
2532                (d_cum_dshape, d_inst_dshape),
2533            ]))
2534        }
2535        ValidatedBaselineTarget::GompertzMakeham {
2536            rate,
2537            shape,
2538            makeham,
2539        } => {
2540            let (cum_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2541            let (d_cum_dshape, d_inst_dshape) =
2542                gompertz_cumulative_shape_derivative(age, rate, shape);
2543            Ok(Some(vec![
2544                (cum_gompertz, inst_gompertz),
2545                (d_cum_dshape, d_inst_dshape),
2546                (makeham * age, makeham),
2547            ]))
2548        }
2549    }
2550}
2551
2552fn survival_cumulative_and_instant_hazard(
2553    age: f64,
2554    cfg: &SurvivalBaselineConfig,
2555) -> Result<Option<(f64, f64)>, String> {
2556    let Some(params) = validated_baseline_params(age, cfg, "baseline hazard evaluation")? else {
2557        return Ok(None);
2558    };
2559
2560    match params {
2561        ValidatedBaselineTarget::Weibull { scale, shape } => {
2562            let cumulative_hazard = (age / scale).powf(shape);
2563            let instant_hazard = shape * cumulative_hazard / age;
2564            Ok(Some((cumulative_hazard, instant_hazard)))
2565        }
2566        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2567            let (cumulative_hazard, instant_hazard) = gompertz_hazard_components(age, rate, shape);
2568            Ok(Some((cumulative_hazard, instant_hazard)))
2569        }
2570        ValidatedBaselineTarget::GompertzMakeham {
2571            rate,
2572            shape,
2573            makeham,
2574        } => {
2575            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2576            Ok(Some((makeham * age + h_gompertz, makeham + inst_gompertz)))
2577        }
2578    }
2579}
2580
/// Probit-scale baseline quantities at one age, as produced by
/// `evaluate_marginal_slope_baseline_point`.
#[derive(Clone, Copy, Debug)]
struct MarginalSlopeBaselinePoint {
    /// Baseline instantaneous hazard `h0(t)`.
    instant_hazard: f64,
    /// Probit index `q(t)` with `Phi(-q(t)) = exp(-H0(t))`.
    q: f64,
    /// Time derivative `dq/dt = h0(t)·exp(-H0(t)) / phi(q(t))`.
    q_t: f64,
}
2587
2588fn evaluate_marginal_slope_baseline_point(
2589    age: f64,
2590    cfg: &SurvivalBaselineConfig,
2591) -> Result<Option<MarginalSlopeBaselinePoint>, String> {
2592    let Some((cumulative_hazard, instant_hazard)) =
2593        survival_cumulative_and_instant_hazard(age, cfg)?
2594    else {
2595        return Ok(None);
2596    };
2597    if !(cumulative_hazard.is_finite() && cumulative_hazard > 0.0) {
2598        return Err(format!(
2599            "{} marginal-slope baseline produced non-positive cumulative hazard",
2600            survival_baseline_targetname(cfg.target)
2601        ));
2602    }
2603    if !(instant_hazard.is_finite() && instant_hazard > 0.0) {
2604        return Err(format!(
2605            "{} marginal-slope baseline produced non-positive instant hazard",
2606            survival_baseline_targetname(cfg.target)
2607        ));
2608    }
2609    let survival = (-cumulative_hazard).exp();
2610    if !(survival.is_finite() && survival > 0.0 && survival < 1.0) {
2611        return Err(format!(
2612            "{} marginal-slope baseline survival must be strictly inside (0,1), got {survival}",
2613            survival_baseline_targetname(cfg.target)
2614        ));
2615    }
2616    let q = -standard_normal_quantile(survival).map_err(|e| {
2617        format!(
2618            "{} marginal-slope baseline failed to invert survival probability {survival}: {e}",
2619            survival_baseline_targetname(cfg.target)
2620        )
2621    })?;
2622    let phi_q = normal_pdf(q);
2623    if !(phi_q.is_finite() && phi_q > 0.0) {
2624        return Err(format!(
2625            "{} marginal-slope baseline produced non-positive probit density phi(q)={phi_q} at q={q}",
2626            survival_baseline_targetname(cfg.target)
2627        ));
2628    }
2629    Ok(Some(MarginalSlopeBaselinePoint {
2630        instant_hazard,
2631        q,
2632        q_t: instant_hazard * survival / phi_q,
2633    }))
2634}
2635
2636/// Evaluate the parametric baseline target at a given age.
2637/// Returns `(eta_target(age), d eta_target / d age)` on the log-cumulative-hazard scale.
2638pub fn evaluate_survival_baseline(
2639    age: f64,
2640    cfg: &SurvivalBaselineConfig,
2641) -> Result<(f64, f64), String> {
2642    if !age.is_finite() || age < 0.0 {
2643        return Err(
2644            "survival ages must be finite and non-negative for baseline target evaluation"
2645                .to_string(),
2646        );
2647    }
2648
2649    // At t = 0 every parametric cumulative-hazard target satisfies H(0) = 0
2650    // exactly (this is the defining property of a cumulative hazard:
2651    // S(0) = 1 ⇒ H(0) = -log S(0) = 0). The log-cumulative-hazard offset is
2652    // therefore eta(0) = log H(0) = -inf, and we report a zero log-derivative
2653    // since `exp(eta(0)) = H(0) = 0` is the only physically valid value.
2654    // Returning `Ok((-inf, 0.0))` keeps the baseline cumulative hazard exactly
2655    // zero at the origin; downstream callers that need to multiply this offset
2656    // into a linear predictor are responsible for handling the origin row via
2657    // the `entry_at_origin` / `exit_at_origin` gating already wired through the
2658    // engine.
2659    if age == 0.0 {
2660        return match cfg.target {
2661            SurvivalBaselineTarget::Linear => Ok((0.0, 0.0)),
2662            SurvivalBaselineTarget::Weibull
2663            | SurvivalBaselineTarget::Gompertz
2664            | SurvivalBaselineTarget::GompertzMakeham => Ok((f64::NEG_INFINITY, 0.0)),
2665        };
2666    }
2667
2668    let Some(params) = validated_baseline_params(age, cfg, "baseline target evaluation")? else {
2669        return Ok((0.0, 0.0));
2670    };
2671
2672    match params {
2673        ValidatedBaselineTarget::Weibull { scale, shape } => {
2674            let eta = shape * (age.ln() - scale.ln());
2675            let derivative = shape / age;
2676            Ok((eta, derivative))
2677        }
2678        ValidatedBaselineTarget::Gompertz { rate, shape } => {
2679            let (h, inst) = gompertz_hazard_components(age, rate, shape);
2680            if h <= 0.0 || !h.is_finite() {
2681                return Err(if shape.abs() < 1e-10 {
2682                    "invalid gompertz baseline at near-zero shape".to_string()
2683                } else {
2684                    "gompertz baseline produced non-positive cumulative hazard".to_string()
2685                });
2686            }
2687            let derivative = inst / h;
2688            Ok((h.ln(), derivative))
2689        }
2690        ValidatedBaselineTarget::GompertzMakeham {
2691            rate,
2692            shape,
2693            makeham,
2694        } => {
2695            let (h_gompertz, inst_gompertz) = gompertz_hazard_components(age, rate, shape);
2696            let h = makeham * age + h_gompertz;
2697            if h <= 0.0 || !h.is_finite() {
2698                return Err(
2699                    "gompertz-makeham baseline produced non-positive cumulative hazard".to_string(),
2700                );
2701            }
2702            let inst = makeham + inst_gompertz;
2703            let derivative = inst / h;
2704            Ok((h.ln(), derivative))
2705        }
2706    }
2707}
2708
2709/// Evaluate the parametric baseline as the probit index whose marginal
2710/// survival is the true hazard survival `exp(-H0(t))`.
2711///
2712/// Returns `(q(age), dq / d age)` such that `Phi(-q(age)) = exp(-H0(age))`.
2713/// The derivative is `h0(t) * exp(-H0(t)) / phi(q(t))`.
2714pub fn evaluate_survival_marginal_slope_baseline(
2715    age: f64,
2716    cfg: &SurvivalBaselineConfig,
2717) -> Result<(f64, f64), String> {
2718    // Survival-curve origin. Every cumulative-hazard baseline satisfies
2719    // `H0(0) = 0` (`S0(0) = exp(-H0(0)) = 1`), so the probit index
2720    // `q(0) = -Phi^{-1}(S0(0)) = -Phi^{-1}(1) = -inf`: there is no *finite*
2721    // probit-survival offset at the origin. The survival surface anchors
2722    // `S(0) = 1` directly (see the `t <= 0` origin handling in the survival
2723    // predict paths), so the baseline contributes nothing here — report the
2724    // zero offset rather than aborting in the `age <= 0` hazard guard. This
2725    // mirrors `evaluate_survival_baseline`'s explicit `age == 0` branch on the
2726    // log-cumulative-hazard channel; without it the probit/marginal-slope
2727    // baseline path (location-scale + marginal-slope likelihoods) could not be
2728    // evaluated on a prediction grid whose first node is the origin (#1024).
2729    if age == 0.0 {
2730        return Ok((0.0, 0.0));
2731    }
2732    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2733        return Ok((0.0, 0.0));
2734    };
2735    Ok((point.q, point.q_t))
2736}
2737
2738/// Partial derivatives of the true survival marginal-slope probit offsets
2739/// `(q(t), dq(t)/dt)` with respect to the baseline θ-parameters.
2740///
2741/// The returned channels match `survival_baseline_theta_from_config`.  For
2742/// Gompertz-Makeham, θ is `(log_rate, shape, log_makeham)`.  If
2743/// `S(t)=exp(-H(t))`, `q(t)=-Phi^-1(S(t))`, `A(t)=S(t)/phi(q(t))`, and
2744/// `h(t)=dH/dt`, then
2745///
2746///   dq/dθ      = A * dH/dθ
2747///   d(q')/dθ   = A * (dh/dθ + h * (q*A - 1) * dH/dθ)
2748///
2749/// which keeps the probit transform and the hazard baseline analytically tied.
2750pub fn marginal_slope_baseline_offset_theta_partials(
2751    age: f64,
2752    cfg: &SurvivalBaselineConfig,
2753) -> Result<Option<Vec<(f64, f64)>>, String> {
2754    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
2755        return Ok(None);
2756    };
2757    let hazard_partials = survival_hazard_theta_partials(age, cfg)?
2758        .ok_or_else(|| "unexpected missing hazard partials for nonlinear baseline".to_string())?;
2759    let a = point.q_t / point.instant_hazard;
2760    let a_log_derivative_factor = point.q * a - 1.0;
2761    Ok(Some(
2762        hazard_partials
2763            .into_iter()
2764            .map(|(d_h_cum, d_h_inst)| {
2765                (
2766                    a * d_h_cum,
2767                    a * (d_h_inst + point.instant_hazard * a_log_derivative_factor * d_h_cum),
2768                )
2769            })
2770            .collect(),
2771    ))
2772}
2773
2774/// Contract marginal-slope offset residuals and channel curvatures into the
2775/// exact Hessian with respect to baseline θ-parameters.
2776pub fn marginal_slope_baseline_chain_rule_hessian(
2777    age_entry: ndarray::ArrayView1<'_, f64>,
2778    age_exit: ndarray::ArrayView1<'_, f64>,
2779    cfg: &SurvivalBaselineConfig,
2780    residuals: &crate::survival::OffsetChannelResiduals,
2781    curvatures: &crate::survival::OffsetChannelCurvatures,
2782) -> Result<Option<Array2<f64>>, String> {
2783    let n = age_exit.len();
2784    if age_entry.len() != n
2785        || residuals.exit.len() != n
2786        || residuals.entry.len() != n
2787        || residuals.derivative.len() != n
2788        || curvatures.rows.len() != n
2789    {
2790        return Err(format!(
2791            "marginal_slope_baseline_chain_rule_hessian: length mismatch (age_entry={}, age_exit={}, r_exit={}, r_entry={}, r_deriv={}, h_rows={})",
2792            age_entry.len(),
2793            n,
2794            residuals.exit.len(),
2795            residuals.entry.len(),
2796            residuals.derivative.len(),
2797            curvatures.rows.len(),
2798        ));
2799    }
2800    let probe_age = age_exit.iter().copied().find(|v| v.is_finite() && *v > 0.0);
2801    let dim = match probe_age {
2802        Some(t) => match marginal_slope_baseline_offset_theta_second_partials(t, cfg)? {
2803            None => return Ok(None),
2804            Some(parts) => parts.first.len(),
2805        },
2806        None => {
2807            return Err(
2808                "marginal_slope_baseline_chain_rule_hessian: no valid positive age for dim probe"
2809                    .to_string(),
2810            );
2811        }
2812    };
2813    // Per-row Hessian contractions are independent. Each row contributes a
2814    // dim×dim increment combining second partials (exit/entry channels) with
2815    // the curvature-weighted outer product of the (entry, exit, derivative)
2816    // first-partial Jacobians. Parallel try_fold/try_reduce accumulates them.
2817    let hessian = (0..n)
2818        .into_par_iter()
2819        .try_fold(
2820            || Array2::<f64>::zeros((dim, dim)),
2821            |mut acc, i| -> Result<Array2<f64>, String> {
2822                let exit_parts =
2823                    marginal_slope_baseline_offset_theta_second_partials(age_exit[i], cfg)?
2824                        .ok_or_else(|| {
2825                            "unexpected None from marginal-slope second partials at exit"
2826                                .to_string()
2827                        })?;
2828                if exit_parts.first.len() != dim {
2829                    return Err(
2830                        "marginal_slope_baseline_chain_rule_hessian: theta_dim drifted".to_string(),
2831                    );
2832                }
2833                let mut entry_parts = None;
2834                if residuals.entry[i] != 0.0 {
2835                    entry_parts = Some(
2836                        marginal_slope_baseline_offset_theta_second_partials(age_entry[i], cfg)?
2837                            .ok_or_else(|| {
2838                                "unexpected None from marginal-slope second partials at entry"
2839                                    .to_string()
2840                            })?,
2841                    );
2842                }
2843                for a in 0..dim {
2844                    for b in 0..dim {
2845                        let j_exit_a = exit_parts.first[a].0;
2846                        let j_exit_b = exit_parts.first[b].0;
2847                        let j_deriv_a = exit_parts.first[a].1;
2848                        let j_deriv_b = exit_parts.first[b].1;
2849                        let mut value = residuals.exit[i] * exit_parts.second[a][b].0
2850                            + residuals.derivative[i] * exit_parts.second[a][b].1;
2851                        if let Some(parts) = entry_parts.as_ref() {
2852                            value += residuals.entry[i] * parts.second[a][b].0;
2853                        }
2854                        let curv = curvatures.rows[i];
2855                        let j_entry_a = entry_parts.as_ref().map_or(0.0, |parts| parts.first[a].0);
2856                        let j_entry_b = entry_parts.as_ref().map_or(0.0, |parts| parts.first[b].0);
2857                        let ja = [j_entry_a, j_exit_a, j_deriv_a];
2858                        let jb = [j_entry_b, j_exit_b, j_deriv_b];
2859                        for u in 0..3 {
2860                            for v in 0..3 {
2861                                value += ja[u] * curv[u][v] * jb[v];
2862                            }
2863                        }
2864                        acc[[a, b]] += value;
2865                    }
2866                }
2867                Ok(acc)
2868            },
2869        )
2870        .try_reduce(|| Array2::<f64>::zeros((dim, dim)), |a, b| Ok(a + b))?;
2871    Ok(Some(hessian))
2872}
2873
/// First and second θ-partials of the probit-survival offset channels
/// `(q, q')` at a single age, where `q'` is the age derivative of `q`.
struct MarginalSlopeThetaSecondPartials {
    /// `first[i] = (dq/dθ_i, dq'/dθ_i)`.
    first: Vec<(f64, f64)>,
    /// `second[i][j] = (d²q/dθ_i dθ_j, d²q'/dθ_i dθ_j)`.
    second: Vec<Vec<(f64, f64)>>,
}
2878
/// First and second θ-partials of the probit-survival offsets `(q, q')` at `age`.
///
/// The inputs are `S = exp(-H)`, `A = S / phi(q)` and `B = q*A - 1`, together
/// with the hazard partials `(H_i, h_i)` and `(H_ij, h_ij)` from
/// `survival_hazard_theta_first_second`. The outputs are:
///
///   dq/dθ_i        = A * H_i
///   dq'/dθ_i       = A * (h_i + h*B*H_i)
///   d²q/dθ_i dθ_j  = A*H_ij + A*B*H_i*H_j
///   d²q'/dθ_i dθ_j = A_j * (h_i + h*B*H_i)
///                    + A * (h_ij + h_j*B*H_i + h*(B_j*H_i + B*H_ij))
///
/// where `A_j = dA/dθ_j = A*B*H_j` and `B_j = dB/dθ_j = A*H_j*(A + q*B)`.
///
/// Returns `Ok(None)` when either `evaluate_marginal_slope_baseline_point` or
/// `survival_hazard_theta_first_second` yields `None`. The latter does so for
/// the linear baseline.
fn marginal_slope_baseline_offset_theta_second_partials(
    age: f64,
    cfg: &SurvivalBaselineConfig,
) -> Result<Option<MarginalSlopeThetaSecondPartials>, String> {
    let Some(point) = evaluate_marginal_slope_baseline_point(age, cfg)? else {
        return Ok(None);
    };
    let Some((hazard, first, second)) = survival_hazard_theta_first_second(age, cfg)? else {
        return Ok(None);
    };
    let (cum_hazard, instant_hazard) = hazard;
    let survival = (-cum_hazard).exp();
    // A = S / phi(q): the factor mapping dH/dθ to dq/dθ.
    let a = survival / normal_pdf(point.q);
    // B = q*A - 1, so that dA/dθ_j = A * B * H_j.
    let b = point.q * a - 1.0;
    // dB/dθ_j = A * H_j * (A + q*B); this is the shared `(A + q*B)` factor.
    let b_factor = a + point.q * b;
    let dim = first.len();
    let mut first_out = Vec::with_capacity(dim);
    let mut second_out = vec![vec![(0.0, 0.0); dim]; dim];
    for i in 0..dim {
        let (h_i, inst_i) = first[i];
        first_out.push((a * h_i, a * (inst_i + instant_hazard * b * h_i)));
    }
    for i in 0..dim {
        for j in 0..dim {
            let (h_i, inst_i) = first[i];
            let (h_j, inst_j) = first[j];
            let (h_ij, inst_ij) = second[i][j];
            // dA/dθ_j and dB/dθ_j (see the doc comment above).
            let a_j = a * b * h_j;
            let b_j = a * h_j * b_factor;
            let q_ij = a * h_ij + a * b * h_i * h_j;
            // Inner bracket of dq'/dθ_i; differentiated by the product rule below.
            let qt_inner_i = inst_i + instant_hazard * b * h_i;
            let qt_ij = a_j * qt_inner_i
                + a * (inst_ij + inst_j * b * h_i + instant_hazard * (b_j * h_i + b * h_ij));
            second_out[i][j] = (q_ij, qt_ij);
        }
    }
    Ok(Some(MarginalSlopeThetaSecondPartials {
        first: first_out,
        second: second_out,
    }))
}
2920
/// `((H, h), first, second)`: cumulative and instantaneous hazard at one age,
/// their θ-partials `first[i] = (dH/dθ_i, dh/dθ_i)`, and the second partials
/// `second[i][j] = (d²H/dθ_i dθ_j, d²h/dθ_i dθ_j)`.
type HazardFirstSecond = ((f64, f64), Vec<(f64, f64)>, Vec<Vec<(f64, f64)>>);
2922
2923fn survival_hazard_theta_first_second(
2924    age: f64,
2925    cfg: &SurvivalBaselineConfig,
2926) -> Result<Option<HazardFirstSecond>, String> {
2927    let Some(hazard) = survival_cumulative_and_instant_hazard(age, cfg)? else {
2928        return Ok(None);
2929    };
2930    let first = survival_hazard_theta_partials(age, cfg)?
2931        .ok_or_else(|| "unexpected missing hazard partials".to_string())?;
2932    let dim = first.len();
2933    let mut second = vec![vec![(0.0, 0.0); dim]; dim];
2934    match cfg.target {
2935        SurvivalBaselineTarget::Linear => return Ok(None),
2936        SurvivalBaselineTarget::Weibull => {
2937            let scale = cfg
2938                .scale
2939                .ok_or_else(|| "weibull missing scale".to_string())?;
2940            let shape = cfg
2941                .shape
2942                .ok_or_else(|| "weibull missing shape".to_string())?;
2943            let log_time_ratio = age.ln() - scale.ln();
2944            let cumulative_hazard = hazard.0;
2945            let instant_hazard = hazard.1;
2946            let eta = shape * log_time_ratio;
2947            second[0][0] = (
2948                shape * shape * cumulative_hazard,
2949                shape * shape * instant_hazard,
2950            );
2951            second[0][1] = (
2952                -shape * cumulative_hazard * (1.0 + eta),
2953                -shape * instant_hazard * (2.0 + eta),
2954            );
2955            second[1][0] = second[0][1];
2956            second[1][1] = (
2957                eta * cumulative_hazard * (1.0 + eta),
2958                (eta + (1.0 + eta) * (1.0 + eta)) * instant_hazard,
2959            );
2960        }
2961        SurvivalBaselineTarget::Gompertz => {
2962            let rate = cfg
2963                .rate
2964                .ok_or_else(|| "gompertz missing rate".to_string())?;
2965            let shape = cfg
2966                .shape
2967                .ok_or_else(|| "gompertz missing shape".to_string())?;
2968            second[0][0] = first[0];
2969            second[0][1] = first[1];
2970            second[1][0] = first[1];
2971            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
2972        }
2973        SurvivalBaselineTarget::GompertzMakeham => {
2974            let rate = cfg.rate.ok_or_else(|| "gm missing rate".to_string())?;
2975            let shape = cfg.shape.ok_or_else(|| "gm missing shape".to_string())?;
2976            second[0][0] = first[0];
2977            second[0][1] = first[1];
2978            second[1][0] = first[1];
2979            second[1][1] = gompertz_cumulative_shape_second_derivative(age, rate, shape);
2980            second[2][2] = first[2];
2981        }
2982    }
2983    Ok(Some((hazard, first, second)))
2984}
2985
/// Second derivatives with respect to `shape` of the Gompertz cumulative hazard
/// `H_G = rate * (exp(shape*t) - 1) / shape` and the instantaneous hazard
/// `h_G = rate * exp(shape*t)`, returned as `(d²H_G/dshape², d²h_G/dshape²)`.
#[inline]
fn gompertz_cumulative_shape_second_derivative(age: f64, rate: f64, shape: f64) -> (f64, f64) {
    // The exact form
    //   d²H_G/dshape² = rate * [t²·E/shape - 2·(shape·t·E - (E - 1))/shape³]
    // subtracts two O(1/shape³) quantities whose leading parts cancel. Its
    // accuracy is therefore governed by x = shape·age, not by shape alone. The
    // cancellation is far worse than for the first derivative: measured relative
    // errors are ~98% at x = 1e-9 and ~9700% at x = 1e-10. Pivoting on `shape`
    // alone (ignoring `age`) sent such small-x cases through the exact form and
    // corrupted the marginal-slope Hessian near small shape. Below |x| = 1e-3 the
    // series through O(x²) is accurate to < 1e-8; above it the exact form is clean.
    let x = shape * age;
    let near_zero_shape = x.abs() < 1e-3;
    if !near_zero_shape {
        let growth = x.exp();
        let numerator = shape * age * growth - x.exp_m1();
        let d2_cumulative =
            rate * (age * age * growth / shape - 2.0 * numerator / (shape * shape * shape));
        let d2_instant = rate * age * age * growth;
        return (d2_cumulative, d2_instant);
    }
    // Series: d²H_G/dshape² = rate·t³·(1/3 + x/4 + x²/10 + …),
    //         d²h_G/dshape² = rate·t²·(1 + x + x²/2 + …).
    let cumulative_series = 1.0 / 3.0 + x / 4.0 + x * x / 10.0;
    (
        rate * age * age * age * cumulative_series,
        rate * age * age * (1.0 + x + 0.5 * x * x),
    )
}
3016
3017// ---------------------------------------------------------------------------
3018// Baseline offsets
3019// ---------------------------------------------------------------------------
3020
3021#[derive(Clone, Copy)]
3022enum BaselineOffsetEvaluator {
3023    LogCumulativeHazard,
3024    ProbitSurvival,
3025}
3026
3027impl BaselineOffsetEvaluator {
3028    fn length_error(self) -> String {
3029        match self {
3030            Self::LogCumulativeHazard => SurvivalConstructionError::IncompatibleDimensions {
3031                reason: "survival baseline offsets require matching entry/exit lengths".to_string(),
3032            }
3033            .into(),
3034            Self::ProbitSurvival => {
3035                "survival probit baseline offsets require matching entry/exit lengths".to_string()
3036            }
3037        }
3038    }
3039
3040    fn finite_error(self) -> &'static str {
3041        match self {
3042            Self::LogCumulativeHazard => "non-finite survival baseline offsets computed",
3043            Self::ProbitSurvival => "non-finite survival probit baseline offsets computed",
3044        }
3045    }
3046
3047    fn evaluate(self, age: f64, cfg: &SurvivalBaselineConfig) -> Result<(f64, f64), String> {
3048        match self {
3049            Self::LogCumulativeHazard => evaluate_survival_baseline(age, cfg),
3050            Self::ProbitSurvival => evaluate_survival_marginal_slope_baseline(age, cfg),
3051        }
3052    }
3053
3054    fn exit_is_finite(self, value: f64, age: f64) -> bool {
3055        match self {
3056            Self::LogCumulativeHazard => {
3057                value.is_finite() || (age == 0.0 && value == f64::NEG_INFINITY)
3058            }
3059            Self::ProbitSurvival => value.is_finite(),
3060        }
3061    }
3062}
3063
3064fn build_survival_offsets_with_evaluator(
3065    age_entry: &Array1<f64>,
3066    age_exit: &Array1<f64>,
3067    cfg: &SurvivalBaselineConfig,
3068    evaluator: BaselineOffsetEvaluator,
3069) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3070    if age_entry.len() != age_exit.len() {
3071        return Err(evaluator.length_error());
3072    }
3073    let n = age_entry.len();
3074    // Each row's three offsets are independent across i. Compute the triplets
3075    // in parallel, then unpack into three Array1 outputs preserving order.
3076    let triples: Vec<(f64, f64, f64)> = (0..n)
3077        .into_par_iter()
3078        .map(|i| -> Result<(f64, f64, f64), String> {
3079            // Origin-entry rows are multiplied out by the survival engines, so
3080            // keep their entry channel finite even when the evaluator's natural
3081            // value at t=0 is undefined or -inf.
3082            let entry_age = age_entry[i];
3083            let e0 = if !entry_age.is_finite() {
3084                return Err(SurvivalConstructionError::DataValidationFailed {
3085                    reason: format!("non-finite entry age at row {i}"),
3086                }
3087                .into());
3088            } else if entry_age <= 0.0 {
3089                0.0
3090            } else {
3091                evaluator.evaluate(entry_age, cfg)?.0
3092            };
3093            let exit_age = age_exit[i];
3094            let (e1, d1) = evaluator.evaluate(exit_age, cfg)?;
3095            if !e0.is_finite() || !evaluator.exit_is_finite(e1, exit_age) || !d1.is_finite() {
3096                return Err(SurvivalConstructionError::DataValidationFailed {
3097                    reason: evaluator.finite_error().to_string(),
3098                }
3099                .into());
3100            }
3101            Ok((e0, e1, d1))
3102        })
3103        .collect::<Result<Vec<_>, String>>()?;
3104    let mut eta_entry = Array1::<f64>::zeros(n);
3105    let mut eta_exit = Array1::<f64>::zeros(n);
3106    let mut derivative_exit = Array1::<f64>::zeros(n);
3107    for (i, (e0, e1, d1)) in triples.into_iter().enumerate() {
3108        eta_entry[i] = e0;
3109        eta_exit[i] = e1;
3110        derivative_exit[i] = d1;
3111    }
3112    Ok((eta_entry, eta_exit, derivative_exit))
3113}
3114
3115/// Compute baseline target offsets for all observations.
3116/// Returns `(eta_entry, eta_exit, derivative_exit)`.
3117pub fn build_survival_baseline_offsets(
3118    age_entry: &Array1<f64>,
3119    age_exit: &Array1<f64>,
3120    cfg: &SurvivalBaselineConfig,
3121) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3122    build_survival_offsets_with_evaluator(
3123        age_entry,
3124        age_exit,
3125        cfg,
3126        BaselineOffsetEvaluator::LogCumulativeHazard,
3127    )
3128}
3129
3130/// Compute probit-survival baseline target offsets for all observations.
3131/// Returns `(q_entry, q_exit, q_derivative_exit)` where `Phi(-q(t)) = exp(-H0(t))`.
3132pub fn build_survival_marginal_slope_baseline_offsets(
3133    age_entry: &Array1<f64>,
3134    age_exit: &Array1<f64>,
3135    cfg: &SurvivalBaselineConfig,
3136) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3137    build_survival_offsets_with_evaluator(
3138        age_entry,
3139        age_exit,
3140        cfg,
3141        BaselineOffsetEvaluator::ProbitSurvival,
3142    )
3143}
3144
3145pub fn location_scale_uses_probit_survival_baseline(inverse_link: Option<&InverseLink>) -> bool {
3146    matches!(
3147        inverse_link,
3148        Some(
3149            InverseLink::Standard(StandardLink::Probit)
3150                | InverseLink::LatentCLogLog(_)
3151                | InverseLink::Sas(_)
3152                | InverseLink::BetaLogistic(_)
3153                | InverseLink::Mixture(_)
3154        )
3155    )
3156}
3157
3158pub fn survival_derivative_guard_for_likelihood(likelihood_mode: SurvivalLikelihoodMode) -> f64 {
3159    match likelihood_mode {
3160        SurvivalLikelihoodMode::LocationScale
3161        | SurvivalLikelihoodMode::Latent
3162        | SurvivalLikelihoodMode::LatentBinary => DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD,
3163        SurvivalLikelihoodMode::MarginalSlope => DEFAULT_SURVIVAL_MARGINAL_SLOPE_DERIVATIVE_GUARD,
3164        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => 0.0,
3165    }
3166}
3167
3168pub fn build_survival_time_offsets_for_likelihood(
3169    age_entry: &Array1<f64>,
3170    age_exit: &Array1<f64>,
3171    baseline_cfg: &SurvivalBaselineConfig,
3172    likelihood_mode: SurvivalLikelihoodMode,
3173    inverse_link: Option<&InverseLink>,
3174) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), String> {
3175    if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
3176        || (likelihood_mode == SurvivalLikelihoodMode::LocationScale
3177            && location_scale_uses_probit_survival_baseline(inverse_link))
3178    {
3179        build_survival_marginal_slope_baseline_offsets(age_entry, age_exit, baseline_cfg)
3180    } else {
3181        build_survival_baseline_offsets(age_entry, age_exit, baseline_cfg)
3182    }
3183}
3184
3185pub fn add_survival_time_derivative_guard_offset(
3186    age_entry: &Array1<f64>,
3187    age_exit: &Array1<f64>,
3188    anchor_time: f64,
3189    derivative_guard: f64,
3190    eta_offset_entry: &mut Array1<f64>,
3191    eta_offset_exit: &mut Array1<f64>,
3192    derivative_offset_exit: &mut Array1<f64>,
3193) -> Result<(), String> {
3194    if derivative_guard <= 0.0 {
3195        return Ok(());
3196    }
3197    let n = age_entry.len();
3198    if age_exit.len() != n
3199        || eta_offset_entry.len() != n
3200        || eta_offset_exit.len() != n
3201        || derivative_offset_exit.len() != n
3202    {
3203        return Err(SurvivalConstructionError::IncompatibleDimensions {
3204            reason: "survival derivative-guard offset lengths must match".to_string(),
3205        }
3206        .into());
3207    }
3208    for i in 0..n {
3209        eta_offset_entry[i] += derivative_guard * (age_entry[i] - anchor_time);
3210        eta_offset_exit[i] += derivative_guard * (age_exit[i] - anchor_time);
3211        derivative_offset_exit[i] += derivative_guard;
3212    }
3213    Ok(())
3214}
3215
/// Per-row baseline offsets for the latent survival likelihood, as produced by
/// `build_latent_survival_baseline_offsets`.
///
/// Under `HazardLoading::Full` the whole baseline is "loaded": the `loaded_*`
/// channels come from `evaluate_survival_baseline` and every `unloaded_*`
/// channel is zero. Under `HazardLoading::LoadedVsUnloaded` (Gompertz-Makeham
/// only), the Gompertz component is loaded and the Makeham term is carried
/// separately as unloaded mass and hazard.
#[derive(Clone, Debug)]
pub struct LatentSurvivalBaselineOffsets {
    /// Log of the loaded cumulative hazard at the entry age.
    pub loaded_eta_entry: Array1<f64>,
    /// Log of the loaded cumulative hazard at the exit age.
    pub loaded_eta_exit: Array1<f64>,
    /// Age derivative of `loaded_eta_exit`. This is `h_G / H_G` under
    /// LoadedVsUnloaded, and `evaluate_survival_baseline`'s derivative under
    /// Full loading.
    pub loaded_derivative_exit: Array1<f64>,
    /// Unloaded cumulative hazard at entry (`makeham * entry`; 0 under Full).
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded cumulative hazard at exit (`makeham * exit`; 0 under Full).
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded instantaneous hazard at exit (`makeham`; 0 under Full).
    pub unloaded_hazard_exit: Array1<f64>,
}
3225
3226pub fn build_latent_survival_baseline_offsets(
3227    age_entry: &Array1<f64>,
3228    age_exit: &Array1<f64>,
3229    cfg: &SurvivalBaselineConfig,
3230    loading: HazardLoading,
3231) -> Result<LatentSurvivalBaselineOffsets, String> {
3232    if age_entry.len() != age_exit.len() {
3233        return Err(
3234            "latent survival baseline offsets require matching entry/exit lengths".to_string(),
3235        );
3236    }
3237
3238    fn gompertz_components(age: f64, rate: f64, shape: f64) -> (f64, f64) {
3239        if shape.abs() < 1e-10 {
3240            // Taylor at shape=0 matching `gompertz_hazard_components`:
3241            //   H_G(t) = rate·t·(1 + (shape·t)/2 + (shape·t)²/6)
3242            //   h_G(t) = rate·(1 + shape·t + (shape·t)²/2)
3243            // Dropping the higher-order `shape*t` corrections silently
3244            // diverges this helper from its sibling for non-zero shape near
3245            // the cutoff and gives inconsistent loaded-vs-unloaded offsets.
3246            let x = shape * age;
3247            return (
3248                rate * age * (1.0 + 0.5 * x + x * x / 6.0),
3249                rate * (1.0 + x + 0.5 * x * x),
3250            );
3251        }
3252        let shape_age = shape * age;
3253        let cumulative_hazard = (rate / shape) * shape_age.exp_m1();
3254        let instant_hazard = rate * shape_age.exp();
3255        (cumulative_hazard, instant_hazard)
3256    }
3257
3258    let n = age_entry.len();
3259
3260    // Per-row 6-tuple is independent. Evaluate in parallel into a Vec and then
3261    // unpack into the six Array1 outputs in original order.
3262    let rows: Vec<[f64; 6]> = (0..n)
3263        .into_par_iter()
3264        .map(|i| -> Result<[f64; 6], String> {
3265            let entry = age_entry[i];
3266            let exit = age_exit[i];
3267            if !entry.is_finite()
3268                || !exit.is_finite()
3269                || entry <= 0.0
3270                || exit <= 0.0
3271                || exit < entry
3272            {
3273                return Err(format!(
3274                    "latent survival baseline offsets require finite positive entry/exit ages with exit >= entry (row {})",
3275                    i + 1
3276                ));
3277            }
3278            match loading {
3279                HazardLoading::Full => {
3280                    let (eta_entry, _) = evaluate_survival_baseline(entry, cfg)?;
3281                    let (eta_exit, derivative_exit) = evaluate_survival_baseline(exit, cfg)?;
3282                    Ok([eta_entry, eta_exit, derivative_exit, 0.0, 0.0, 0.0])
3283                }
3284                HazardLoading::LoadedVsUnloaded => {
3285                    if cfg.target != SurvivalBaselineTarget::GompertzMakeham {
3286                        return Err(format!(
3287                            "HazardLoading::LoadedVsUnloaded requires --baseline-target gompertz-makeham, got {}",
3288                            survival_baseline_targetname(cfg.target)
3289                        ));
3290                    }
3291                    let rate = cfg.rate.ok_or_else(|| {
3292                        "gompertz-makeham latent survival is missing baseline rate".to_string()
3293                    })?;
3294                    let shape = cfg.shape.ok_or_else(|| {
3295                        "gompertz-makeham latent survival is missing baseline shape".to_string()
3296                    })?;
3297                    let makeham = cfg.makeham.ok_or_else(|| {
3298                        "gompertz-makeham latent survival is missing baseline makeham".to_string()
3299                    })?;
3300                    let (loaded_entry, _) = gompertz_components(entry, rate, shape);
3301                    let (loaded_exit, loaded_hazard) = gompertz_components(exit, rate, shape);
3302                    if !(loaded_entry.is_finite()
3303                        && loaded_entry > 0.0
3304                        && loaded_exit.is_finite()
3305                        && loaded_exit > 0.0
3306                        && loaded_hazard.is_finite()
3307                        && loaded_hazard > 0.0)
3308                    {
3309                        return Err(format!(
3310                            "gompertz-makeham latent loaded component produced a non-positive or non-finite hazard decomposition at row {}",
3311                            i + 1
3312                        ));
3313                    }
3314                    Ok([
3315                        loaded_entry.ln(),
3316                        loaded_exit.ln(),
3317                        loaded_hazard / loaded_exit,
3318                        makeham * entry,
3319                        makeham * exit,
3320                        makeham,
3321                    ])
3322                }
3323            }
3324        })
3325        .collect::<Result<Vec<_>, String>>()?;
3326
3327    let mut loaded_eta_entry = Array1::<f64>::zeros(n);
3328    let mut loaded_eta_exit = Array1::<f64>::zeros(n);
3329    let mut loaded_derivative_exit = Array1::<f64>::zeros(n);
3330    let mut unloaded_mass_entry = Array1::<f64>::zeros(n);
3331    let mut unloaded_mass_exit = Array1::<f64>::zeros(n);
3332    let mut unloaded_hazard_exit = Array1::<f64>::zeros(n);
3333    for (i, row) in rows.into_iter().enumerate() {
3334        loaded_eta_entry[i] = row[0];
3335        loaded_eta_exit[i] = row[1];
3336        loaded_derivative_exit[i] = row[2];
3337        unloaded_mass_entry[i] = row[3];
3338        unloaded_mass_exit[i] = row[4];
3339        unloaded_hazard_exit[i] = row[5];
3340    }
3341
3342    Ok(LatentSurvivalBaselineOffsets {
3343        loaded_eta_entry,
3344        loaded_eta_exit,
3345        loaded_derivative_exit,
3346        unloaded_mass_entry,
3347        unloaded_mass_exit,
3348        unloaded_hazard_exit,
3349    })
3350}
3351
3352// ---------------------------------------------------------------------------
3353// Time wiggle construction
3354// ---------------------------------------------------------------------------
3355
3356pub fn build_survival_timewiggle_derivative_design(
3357    eta_exit: &Array1<f64>,
3358    derivative_exit: &Array1<f64>,
3359    knots: &Array1<f64>,
3360    degree: usize,
3361) -> Result<Array2<f64>, String> {
3362    let mut design_derivative_exit =
3363        monotone_wiggle_basis_with_derivative_order(eta_exit.view(), knots, degree, 1)?;
3364    for i in 0..design_derivative_exit.nrows() {
3365        let chain = derivative_exit[i];
3366        for j in 0..design_derivative_exit.ncols() {
3367            design_derivative_exit[[i, j]] *= chain;
3368        }
3369    }
3370    Ok(design_derivative_exit)
3371}
3372
3373/// Build the dynamic "baseline as prior" timewiggle runtime.
3374///
3375/// The baseline offsets are used only to initialize the wiggle knot placement
3376/// on a stable scalar scale.  The exact survival family evaluates the resulting
3377/// monotone wiggle dynamically on the current time predictor h0(t):
3378///
3379///   h(t) = g(h0(t)),   g(z) = z + w(z).
3380///
3381/// No fixed `B(eta_baseline)` design is constructed here.
3382pub fn build_survival_timewiggle_from_baseline(
3383    eta_entry: &Array1<f64>,
3384    eta_exit: &Array1<f64>,
3385    derivative_exit: &Array1<f64>,
3386    cfg: &LinkWiggleFormulaSpec,
3387) -> Result<SurvivalTimeWiggleBuild, String> {
3388    if eta_entry.len() != eta_exit.len() || eta_exit.len() != derivative_exit.len() {
3389        return Err(
3390            "baseline-timewiggle requires matching entry/exit/derivative lengths".to_string(),
3391        );
3392    }
3393    // Guard: if baseline offsets are all zero (linear baseline), the timewiggle
3394    // construction is degenerate — it adds only a constant, not time-varying structure.
3395    let all_zero = eta_entry.iter().all(|&v| v.abs() < 1e-15)
3396        && eta_exit.iter().all(|&v| v.abs() < 1e-15)
3397        && derivative_exit.iter().all(|&v| v.abs() < 1e-15);
3398    if all_zero {
3399        return Err(
3400            "timewiggle requires a non-linear scalar survival baseline target; \
3401             the provided baseline offsets are all zero (linear baseline)"
3402                .to_string(),
3403        );
3404    }
3405    let n = eta_exit.len();
3406    let mut seed = Array1::<f64>::zeros(2 * n);
3407    for i in 0..n {
3408        seed[i] = eta_entry[i];
3409        seed[n + i] = eta_exit[i];
3410    }
3411    // Use the smallest requested positive penalty order as the primary
3412    // coefficient-space penalty so the fitted wiggle penalty system matches
3413    // the public formula exactly, including the slope (`order = 1`) case.
3414    let (primary_order, extra_orders) = split_wiggle_penalty_orders(2, &cfg.penalty_orders);
3415    let wiggle_cfg = WiggleBlockConfig {
3416        degree: cfg.degree,
3417        num_internal_knots: cfg.num_internal_knots,
3418        penalty_order: primary_order,
3419        double_penalty: cfg.double_penalty,
3420    };
3421    let (mut combined_block, knots) = buildwiggle_block_input_from_seed(seed.view(), &wiggle_cfg)?;
3422    append_selected_wiggle_penalty_orders(&mut combined_block, &extra_orders)?;
3423    let ncols = combined_block.design.ncols();
3424    Ok(SurvivalTimeWiggleBuild {
3425        nullspace_dims: combined_block.nullspace_dims.clone(),
3426        penalties: {
3427            combined_block
3428                .penalties
3429                .into_iter()
3430                .map(|ps| ps.to_global(ncols))
3431                .collect()
3432        },
3433        knots,
3434        degree: cfg.degree,
3435        ncols,
3436    })
3437}
3438
3439pub fn append_zero_tail_columns(
3440    x_entry: &mut DesignMatrix,
3441    x_exit: &mut DesignMatrix,
3442    x_derivative: &mut DesignMatrix,
3443    tail_cols: usize,
3444) {
3445    if tail_cols == 0 {
3446        return;
3447    }
3448    // Wiggle tail columns are dense, so materialize everything to dense.
3449    // This only runs once at construction time when time-wiggles are active.
3450    fn append_dense(dm: &mut DesignMatrix, tail: usize) {
3451        let old = dm.to_dense();
3452        let n = old.nrows();
3453        let p_base = old.ncols();
3454        let mut out = Array2::<f64>::zeros((n, p_base + tail));
3455        out.slice_mut(s![.., 0..p_base]).assign(&old);
3456        *dm = DesignMatrix::Dense(DenseDesignMatrix::from(out));
3457    }
3458    append_dense(x_entry, tail_cols);
3459    append_dense(x_exit, tail_cols);
3460    append_dense(x_derivative, tail_cols);
3461}
3462
3463// ---------------------------------------------------------------------------
3464// Resolved config (from build output back to config for serialization)
3465// ---------------------------------------------------------------------------
3466
3467// ---------------------------------------------------------------------------
3468// Time-varying covariate template
3469// ---------------------------------------------------------------------------
3470
3471/// Build a time-varying covariate block by tensoring the covariate design
3472/// with a 1D B-spline basis on log(time).
3473pub fn build_time_varying_survival_covariate_template(
3474    age_entry: &Array1<f64>,
3475    age_exit: &Array1<f64>,
3476    time_k: usize,
3477    time_degree: usize,
3478    block_name: &str,
3479) -> Result<SurvivalCovariateTermBlockTemplate, String> {
3480    if time_k < time_degree + 1 {
3481        return Err(format!(
3482            "--{block_name}-time-k must be >= degree + 1 = {}, got {time_k}",
3483            time_degree + 1
3484        ));
3485    }
3486    let num_internal_knots = time_k - (time_degree + 1);
3487
3488    let log_entry = age_entry.mapv(|t| t.max(1e-12).ln());
3489    let log_exit = age_exit.mapv(|t| t.max(1e-12).ln());
3490
3491    let time_spec = BSplineBasisSpec {
3492        degree: time_degree,
3493        penalty_order: 2,
3494        knotspec: BSplineKnotSpec::Automatic {
3495            num_internal_knots: Some(num_internal_knots),
3496            placement: gam_terms::basis::BSplineKnotPlacement::Quantile,
3497        },
3498        double_penalty: false,
3499        identifiability: BSplineIdentifiability::None,
3500        boundary: OneDimensionalBoundary::Open,
3501        boundary_conditions: BSplineBoundaryConditions::default(),
3502    };
3503
3504    let time_build = build_bspline_basis_1d(log_exit.view(), &time_spec)
3505        .map_err(|e| format!("failed to build {block_name} time-margin B-spline basis: {e}"))?;
3506    let time_design_exit = time_build.design.to_dense();
3507
3508    let knots = match &time_build.metadata {
3509        BasisMetadata::BSpline1D { knots, .. } => knots.clone(),
3510        _ => {
3511            return Err(format!(
3512                "{block_name} time-margin basis returned unexpected metadata type"
3513            ));
3514        }
3515    };
3516
3517    let time_build_entry = build_bspline_basis_1d(
3518        log_entry.view(),
3519        &BSplineBasisSpec {
3520            degree: time_degree,
3521            penalty_order: 2,
3522            knotspec: BSplineKnotSpec::Provided(knots.clone()),
3523            double_penalty: false,
3524            identifiability: BSplineIdentifiability::None,
3525            boundary: OneDimensionalBoundary::Open,
3526            boundary_conditions: BSplineBoundaryConditions::default(),
3527        },
3528    )
3529    .map_err(|e| format!("failed to evaluate {block_name} time-margin basis at entry: {e}"))?;
3530    let time_design_entry = time_build_entry.design.to_dense();
3531    let p_time = time_design_exit.ncols();
3532    let mut time_design_derivative_exit = Array2::<f64>::zeros((age_exit.len(), p_time));
3533    // Per-row derivative-basis evaluation is independent; each row owns its
3534    // own small `deriv_buf`. par_chunks_mut over the (n × p_time) output rows
3535    // hands disjoint mutable row-slices to rayon workers.
3536    time_design_derivative_exit
3537        .as_slice_mut()
3538        .expect("zeros are contiguous")
3539        .par_chunks_mut(p_time)
3540        .enumerate()
3541        .try_for_each(|(i, row_out)| -> Result<(), String> {
3542            let mut deriv_buf = vec![0.0_f64; p_time];
3543            evaluate_bspline_derivative_scalar(
3544                log_exit[i],
3545                knots.view(),
3546                time_degree,
3547                &mut deriv_buf,
3548            )
3549            .map_err(|e| {
3550                format!("failed to evaluate {block_name} time-margin derivative basis: {e}")
3551            })?;
3552            let chain = 1.0 / age_exit[i].max(1e-12);
3553            for j in 0..p_time {
3554                row_out[j] = deriv_buf[j] * chain;
3555            }
3556            Ok(())
3557        })?;
3558
3559    Ok(SurvivalCovariateTermBlockTemplate::TimeVarying {
3560        time_basis_entry: time_design_entry,
3561        time_basis_exit: time_design_exit,
3562        time_basis_derivative_exit: time_design_derivative_exit,
3563        time_penalties: time_build.penalties,
3564    })
3565}
3566
3567#[cfg(test)]
3568mod tests {
3569    use super::{
3570        SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalTimeBasisConfig,
3571        baseline_chain_rule_gradient, baseline_offset_theta_partials,
3572        build_survival_marginal_slope_baseline_offsets, build_survival_time_basis,
3573        build_survival_timewiggle_from_baseline, evaluate_survival_baseline,
3574        evaluate_survival_marginal_slope_baseline, gompertz_cumulative_shape_derivative,
3575        gompertz_cumulative_shape_second_derivative, gompertz_hazard_components,
3576        marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
3577        marginal_slope_baseline_offset_theta_partials,
3578        optimize_survival_baseline_config_with_gradient,
3579        optimize_survival_baseline_config_with_gradient_only,
3580        resolve_survival_marginal_slope_time_anchor_value, survival_baseline_config_from_theta,
3581        survival_baseline_theta_from_config,
3582    };
3583    use crate::survival::{OffsetChannelCurvatures, OffsetChannelResiduals};
3584    use gam_terms::inference::formula_dsl::LinkWiggleFormulaSpec;
3585    use crate::probability::normal_cdf;
3586    use ndarray::{Array1, Array2, array};
3587
3588    #[test]
3589    fn survival_timewiggle_keeps_requested_order_one_penalty() {
3590        let eta_entry = array![0.1, 0.3, 0.5, 0.8];
3591        let eta_exit = array![0.4, 0.7, 1.0, 1.4];
3592        let derivative_exit = array![0.9, 1.1, 1.2, 1.3];
3593        let cfg = LinkWiggleFormulaSpec {
3594            degree: 3,
3595            num_internal_knots: 4,
3596            penalty_orders: vec![1, 2, 3],
3597            double_penalty: false,
3598        };
3599
3600        let build =
3601            build_survival_timewiggle_from_baseline(&eta_entry, &eta_exit, &derivative_exit, &cfg)
3602                .expect("build survival timewiggle");
3603
3604        assert_eq!(build.penalties.len(), 3);
3605        assert_eq!(build.nullspace_dims, vec![1, 2, 3]);
3606        assert!(build.ncols > 0);
3607    }
3608
3609    #[test]
3610    fn marginal_slope_time_anchor_defaults_to_median_exit() {
3611        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3612        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3613        let anchor = resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)
3614            .expect("resolve marginal-slope default time anchor");
3615
3616        assert!(
3617            (anchor - 19.0).abs() <= 1e-12,
3618            "marginal-slope default anchor should be median exit, got {anchor}"
3619        );
3620    }
3621
3622    #[test]
3623    fn marginal_slope_time_anchor_honors_explicit_value() {
3624        let age_entry = array![9.0, 1.0, 4.0, 6.0];
3625        let age_exit = array![20.0, 12.0, 18.0, 30.0];
3626        let anchor =
3627            resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, Some(7.5))
3628                .expect("resolve explicit marginal-slope time anchor");
3629
3630        assert!(
3631            (anchor - 7.5).abs() <= 1e-12,
3632            "explicit marginal-slope anchor should round-trip, got {anchor}"
3633        );
3634    }
3635
3636    /// Derivative-contract parity for the two public baseline optimizers.
3637    ///
3638    /// After the unification onto `run_baseline_theta_optimizer`, the
3639    /// gradient-only and gradient+Hessian entry points differ *only* in how
3640    /// much derivative information they hand the outer solver — not in the
3641    /// surface they minimize. We exercise that invariant on a known
3642    /// strictly-convex quadratic in θ-space (Weibull baseline: θ = (ln scale,
3643    /// ln shape)) whose unique minimizer is `theta_star`, supplying the same
3644    /// objective as `(f, ∇f)` and as `(f, ∇f, ∇²f)`. Both contracts must
3645    /// recover the same minimizer config, not weakened to pass.
3646    #[test]
3647    fn baseline_optimizer_contracts_agree_on_shared_surface() {
3648        // SPD curvature and interior minimizer in θ-space. A is well away from
3649        // singular so both the analytic-Hessian and BFGS paths see the same
3650        // unambiguous bowl; θ* sits comfortably inside the ±6 box around the
3651        // θ=(0,0) seed below.
3652        let curvature: Array2<f64> = array![[3.0, 0.5], [0.5, 2.0]];
3653        let theta_star: Array1<f64> = array![2.5_f64.ln(), 1.3_f64.ln()];
3654
3655        // Seed config at θ=(0,0) (scale=shape=1). The Linear early-return path
3656        // is not exercised here; Weibull has a genuine 2-dim θ to optimize.
3657        let initial = SurvivalBaselineConfig {
3658            target: SurvivalBaselineTarget::Weibull,
3659            scale: Some(1.0),
3660            shape: Some(1.0),
3661            rate: None,
3662            makeham: None,
3663        };
3664
3665        // θ recovered from a returned Weibull config, via the exact inverse of
3666        // the config→θ map the optimizers use internally.
3667        let recovered_theta = |cfg: &SurvivalBaselineConfig| -> Array1<f64> {
3668            survival_baseline_theta_from_config(cfg)
3669                .expect("config→θ")
3670                .expect("Weibull config has a θ")
3671        };
3672
3673        // Shared quadratic surface, evaluated by mapping config→θ so every
3674        // contract sees the identical objective.
3675        let curvature_cost = curvature.clone();
3676        let star_cost = theta_star.clone();
3677        let cost_at = move |cfg: &SurvivalBaselineConfig| -> Result<f64, String> {
3678            let theta = survival_baseline_theta_from_config(cfg)?
3679                .ok_or_else(|| "expected a θ for the cost surface".to_string())?;
3680            let d = &theta - &star_cost;
3681            let ad = curvature_cost.dot(&d);
3682            Ok(0.5 * d.dot(&ad))
3683        };
3684
3685        let curvature_grad = curvature.clone();
3686        let star_grad = theta_star.clone();
3687        let cost_for_grad = cost_at.clone();
3688        let result_grad_only = optimize_survival_baseline_config_with_gradient_only(
3689            &initial,
3690            "baseline parity (gradient-only)",
3691            move |cfg| {
3692                let cost = cost_for_grad(cfg)?;
3693                let theta = survival_baseline_theta_from_config(cfg)?
3694                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3695                let gradient = curvature_grad.dot(&(&theta - &star_grad));
3696                Ok((cost, gradient))
3697            },
3698        )
3699        .expect("gradient-only baseline optimization converges");
3700
3701        let curvature_hess = curvature.clone();
3702        let star_hess = theta_star.clone();
3703        let cost_for_hess = cost_at.clone();
3704        let result_grad_hess = optimize_survival_baseline_config_with_gradient(
3705            &initial,
3706            "baseline parity (gradient+Hessian)",
3707            move |cfg| {
3708                let cost = cost_for_hess(cfg)?;
3709                let theta = survival_baseline_theta_from_config(cfg)?
3710                    .ok_or_else(|| "expected a θ for the gradient".to_string())?;
3711                let gradient = curvature_hess.dot(&(&theta - &star_hess));
3712                Ok((cost, gradient, curvature_hess.clone()))
3713            },
3714        )
3715        .expect("gradient+Hessian baseline optimization converges");
3716
3717        let theta_grad_only = recovered_theta(&result_grad_only);
3718        let theta_grad_hess = recovered_theta(&result_grad_hess);
3719
3720        // Each contract recovers the true minimizer. 2e-3 is a safe,
3721        // un-weakened bound; both gradient paths land far tighter.
3722        for (label, theta) in [
3723            ("gradient-only", &theta_grad_only),
3724            ("gradient+Hessian", &theta_grad_hess),
3725        ] {
3726            let err = (theta - &theta_star)
3727                .mapv(f64::abs)
3728                .fold(0.0_f64, |a, &v| a.max(v));
3729            assert!(
3730                err <= 2e-3,
3731                "{label} contract recovered θ {theta:?} off true minimizer {theta_star:?} by {err:e}"
3732            );
3733        }
3734
3735        // Cross-contract agreement: the three results must coincide, since the
3736        // only difference between the entry points is the derivative contract,
3737        // never the surface they minimize.
3738        let pairwise_max = |a: &Array1<f64>, b: &Array1<f64>| -> f64 {
3739            (a - b).mapv(f64::abs).fold(0.0_f64, |acc, &v| acc.max(v))
3740        };
3741        assert!(
3742            pairwise_max(&theta_grad_only, &theta_grad_hess) <= 2e-3,
3743            "gradient-only vs gradient+Hessian disagree: {theta_grad_only:?} vs {theta_grad_hess:?}"
3744        );
3745    }
3746
3747    #[test]
3748    fn automatic_ispline_time_knots_are_sized_for_antiderivative_degree() {
3749        let age_entry = array![1.0_f64, 1.0, 1.0, 1.0, 1.0, 1.0];
3750        let age_exit = array![2.0_f64, 3.0, 5.0, 8.0, 13.0, 21.0];
3751        let requested_degree = 3;
3752        let num_internal_knots = 1;
3753
3754        let built = build_survival_time_basis(
3755            &age_entry,
3756            &age_exit,
3757            SurvivalTimeBasisConfig::ISpline {
3758                degree: requested_degree,
3759                knots: Array1::zeros(0),
3760                keep_cols: Vec::new(),
3761                smooth_lambda: 1e-2,
3762            },
3763            Some((num_internal_knots, 1e-2)),
3764        )
3765        .expect("automatic cubic ispline with one interior knot builds");
3766
3767        let working_degree = requested_degree + 1;
3768        let knots = built.knots.expect("resolved ispline knots");
3769        assert_eq!(
3770            knots.len(),
3771            num_internal_knots + 2 * (working_degree + 1),
3772            "I-spline automatic knots must be clamped for the working B-spline degree"
3773        );
3774        assert_eq!(built.degree, Some(requested_degree));
3775        assert!(built.x_exit_time.ncols() > 0);
3776        assert_eq!(built.x_entry_time.ncols(), built.x_exit_time.ncols());
3777        assert_eq!(built.x_derivative_time.ncols(), built.x_exit_time.ncols());
3778    }
3779
3780    #[test]
3781    fn ispline_time_derivative_is_nonzero_at_right_boundary() {
3782        let age_entry = array![1.0_f64, 1.0, 1.0];
3783        let age_exit = array![4.0_f64, 4.0, 4.0];
3784        let left = 1.0_f64.ln();
3785        let right = 4.0_f64.ln();
3786        let mid = left + 0.5 * (right - left);
3787        let knots = array![left, left, left, left, mid, right, right, right, right];
3788
3789        let built = build_survival_time_basis(
3790            &age_entry,
3791            &age_exit,
3792            SurvivalTimeBasisConfig::ISpline {
3793                degree: 2,
3794                knots,
3795                keep_cols: Vec::new(),
3796                smooth_lambda: 1e-2,
3797            },
3798            None,
3799        )
3800        .expect("build right-boundary ispline time basis");
3801
3802        let derivative = built.x_derivative_time.as_dense_cow();
3803        let max_abs = derivative.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
3804        assert!(
3805            max_abs > 1e-8,
3806            "right-boundary I-spline derivative must use the left-hand endpoint slope"
3807        );
3808        for row in derivative.rows() {
3809            assert!(
3810                row.iter().any(|v| *v > 1e-8),
3811                "each row at the right boundary needs a positive hazard derivative"
3812            );
3813        }
3814    }
3815
3816    #[test]
3817    fn marginal_slope_baseline_maps_gompertz_makeham_survival_to_probit_index() {
3818        let cfg = SurvivalBaselineConfig {
3819            target: SurvivalBaselineTarget::GompertzMakeham,
3820            scale: None,
3821            shape: Some(0.07),
3822            rate: Some(0.012),
3823            makeham: Some(0.003),
3824        };
3825        let age = 11.5;
3826        let (q, q_derivative) = evaluate_survival_marginal_slope_baseline(age, &cfg)
3827            .expect("evaluate marginal-slope gompertz-makeham baseline");
3828        let shape = cfg.shape.expect("shape");
3829        let rate = cfg.rate.expect("rate");
3830        let makeham = cfg.makeham.expect("makeham");
3831        let cumulative_hazard = makeham * age + (rate / shape) * ((shape * age).exp() - 1.0);
3832        let instant_hazard = makeham + rate * (shape * age).exp();
3833        let expected_survival = (-cumulative_hazard).exp();
3834        let actual_survival = normal_cdf(-q);
3835        assert!((actual_survival - expected_survival).abs() <= 1e-12);
3836
3837        let h = 1e-5;
3838        let q_plus = evaluate_survival_marginal_slope_baseline(age + h, &cfg)
3839            .expect("q plus")
3840            .0;
3841        let q_minus = evaluate_survival_marginal_slope_baseline(age - h, &cfg)
3842            .expect("q minus")
3843            .0;
3844        let fd = (q_plus - q_minus) / (2.0 * h);
3845        assert!((q_derivative - fd).abs() <= 1e-7);
3846        assert!(instant_hazard > 0.0);
3847    }
3848
3849    #[test]
3850    fn marginal_slope_baseline_is_evaluable_at_the_survival_curve_origin() {
3851        // Regression for #1024: the probit/marginal-slope baseline evaluator must
3852        // be defined at the survival-curve origin t = 0 (where S0(0) = 1, so the
3853        // probit index q(0) = -Phi^{-1}(1) = -inf and there is no finite offset),
3854        // exactly like its log-cumulative-hazard sibling `evaluate_survival_baseline`.
3855        // Before the fix the shared `age <= 0` hazard guard aborted, so a survival
3856        // prediction grid whose first node is the origin (the `Surv(time, event)`
3857        // right-censored shorthand) could not be evaluated for the location-scale /
3858        // marginal-slope likelihoods.
3859        let configs = [
3860            SurvivalBaselineConfig {
3861                target: SurvivalBaselineTarget::Linear,
3862                scale: None,
3863                shape: None,
3864                rate: None,
3865                makeham: None,
3866            },
3867            SurvivalBaselineConfig {
3868                target: SurvivalBaselineTarget::Weibull,
3869                scale: Some(2.5),
3870                shape: Some(1.3),
3871                rate: None,
3872                makeham: None,
3873            },
3874            SurvivalBaselineConfig {
3875                target: SurvivalBaselineTarget::Gompertz,
3876                scale: None,
3877                shape: Some(0.05),
3878                rate: Some(0.01),
3879                makeham: None,
3880            },
3881            SurvivalBaselineConfig {
3882                target: SurvivalBaselineTarget::GompertzMakeham,
3883                scale: None,
3884                shape: Some(0.07),
3885                rate: Some(0.012),
3886                makeham: Some(0.003),
3887            },
3888        ];
3889        for cfg in &configs {
3890            // The probit baseline returns a finite zero offset at the origin for
3891            // every target (the survival surface anchors S(0) = 1 directly).
3892            let (q0, q0_derivative) = evaluate_survival_marginal_slope_baseline(0.0, cfg)
3893                .expect("marginal-slope baseline must be evaluable at the origin");
3894            assert_eq!(q0, 0.0);
3895            assert_eq!(q0_derivative, 0.0);
3896
3897            // The log-cumulative-hazard sibling is likewise finite at the origin —
3898            // this parity is the whole point (the transformation likelihood already
3899            // worked because it rides this evaluator).
3900            let (eta0, eta0_derivative) =
3901                evaluate_survival_baseline(0.0, cfg).expect("log-cum-hazard baseline at origin");
3902            assert!(eta0_derivative.is_finite());
3903            assert!(eta0.is_finite() || eta0 == f64::NEG_INFINITY);
3904
3905            // The batched offset builder must not abort when a query exit age is the
3906            // origin (this is the exact call the location-scale predict path makes on
3907            // the default surface grid). Entry stays at the origin, exit spans 0 -> t.
3908            let age_entry = array![0.0, 0.0];
3909            let age_exit = array![0.0, 1.5];
3910            let (entry, exit, derivative) =
3911                build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, cfg)
3912                    .expect("probit baseline offsets must build through the origin");
3913            assert!(entry.iter().all(|v| v.is_finite()));
3914            assert!(exit.iter().all(|v| v.is_finite()));
3915            assert!(derivative.iter().all(|v| v.is_finite()));
3916            // The origin exit column carries no probit offset.
3917            assert_eq!(exit[0], 0.0);
3918        }
3919    }
3920
3921    #[test]
3922    fn marginal_slope_baseline_offsets_use_true_gompertz_makeham_survival() {
3923        let cfg = SurvivalBaselineConfig {
3924            target: SurvivalBaselineTarget::GompertzMakeham,
3925            scale: None,
3926            shape: Some(0.03),
3927            rate: Some(0.01),
3928            makeham: Some(0.002),
3929        };
3930        let age_entry = array![2.0, 4.0];
3931        let age_exit = array![5.0, 9.0];
3932        let (entry, exit, derivative) =
3933            build_survival_marginal_slope_baseline_offsets(&age_entry, &age_exit, &cfg)
3934                .expect("marginal-slope baseline offsets");
3935        for i in 0..age_entry.len() {
3936            let entry_h = cfg.makeham.expect("makeham") * age_entry[i]
3937                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
3938                    * ((cfg.shape.expect("shape") * age_entry[i]).exp() - 1.0);
3939            let exit_h = cfg.makeham.expect("makeham") * age_exit[i]
3940                + (cfg.rate.expect("rate") / cfg.shape.expect("shape"))
3941                    * ((cfg.shape.expect("shape") * age_exit[i]).exp() - 1.0);
3942            assert!((normal_cdf(-entry[i]) - (-entry_h).exp()).abs() <= 1e-12);
3943            assert!((normal_cdf(-exit[i]) - (-exit_h).exp()).abs() <= 1e-12);
3944            assert!(derivative[i].is_finite() && derivative[i] > 0.0);
3945        }
3946    }
3947
3948    fn fd_marginal_slope_baseline_offset(
3949        age: f64,
3950        cfg: &SurvivalBaselineConfig,
3951        steps: &[f64],
3952    ) -> Vec<(f64, f64)> {
3953        let theta = survival_baseline_theta_from_config(cfg)
3954            .expect("theta")
3955            .expect("non-linear baseline");
3956        assert_eq!(
3957            steps.len(),
3958            theta.len(),
3959            "fd_marginal_slope_baseline_offset: step vector length must match θ dimension"
3960        );
3961        (0..theta.len())
3962            .map(|k| {
3963                let h = steps[k];
3964                let mut theta_plus = theta.clone();
3965                theta_plus[k] += h;
3966                let mut theta_minus = theta.clone();
3967                theta_minus[k] -= h;
3968                let cfg_plus =
3969                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
3970                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
3971                    .expect("minus cfg");
3972                let (q_p, qt_p) =
3973                    evaluate_survival_marginal_slope_baseline(age, &cfg_plus).expect("q+");
3974                let (q_m, qt_m) =
3975                    evaluate_survival_marginal_slope_baseline(age, &cfg_minus).expect("q-");
3976                ((q_p - q_m) / (2.0 * h), (qt_p - qt_m) / (2.0 * h))
3977            })
3978            .collect()
3979    }
3980
3981    #[test]
3982    fn marginal_slope_baseline_theta_partials_match_fd_for_gompertz_makeham() {
3983        let cfg = SurvivalBaselineConfig {
3984            target: SurvivalBaselineTarget::GompertzMakeham,
3985            scale: None,
3986            shape: Some(0.04),
3987            rate: Some(0.013),
3988            makeham: Some(0.002),
3989        };
3990        let age = 17.0;
3991        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
3992            .expect("partials")
3993            .expect("nonlinear");
3994        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-5, 1e-5]);
3995        assert_eq!(analytic.len(), fd.len());
3996        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
3997            assert_close(*aq, *fq, 1e-6, &format!("gm-probit q theta[{k}]"));
3998            assert_close(*aqt, *fqt, 1e-6, &format!("gm-probit q' theta[{k}]"));
3999        }
4000    }
4001
4002    #[test]
4003    fn marginal_slope_baseline_theta_partials_match_fd_near_zero_gompertz_shape() {
4004        let cfg = SurvivalBaselineConfig {
4005            target: SurvivalBaselineTarget::GompertzMakeham,
4006            scale: None,
4007            shape: Some(1e-14),
4008            rate: Some(0.013),
4009            makeham: Some(0.002),
4010        };
4011        let age = 17.0;
4012        let analytic = marginal_slope_baseline_offset_theta_partials(age, &cfg)
4013            .expect("partials")
4014            .expect("nonlinear");
4015        let fd = fd_marginal_slope_baseline_offset(age, &cfg, &[1e-5, 1e-11, 1e-5]);
4016        assert_eq!(analytic.len(), fd.len());
4017        for (k, ((aq, aqt), (fq, fqt))) in analytic.iter().zip(fd.iter()).enumerate() {
4018            assert_close(*aq, *fq, 1e-5, &format!("near-zero gm-probit q theta[{k}]"));
4019            assert_close(
4020                *aqt,
4021                *fqt,
4022                1e-5,
4023                &format!("near-zero gm-probit q' theta[{k}]"),
4024            );
4025        }
4026    }
4027
    /// Test helper: offset-channel residuals at `candidate_cfg`, built by
    /// shifting the residuals `base` (taken at `base_cfg`) by the per-row
    /// curvature matrix applied to the change in baseline channels
    /// `(entry, exit, derivative)` between the two configs.
    ///
    /// Per row: `residual(candidate) = base + C_row · (channels(candidate) - channels(base))`.
    /// A row whose `base.entry` residual is exactly `0.0` is treated as having
    /// no entry contribution: its entry channel delta is forced to zero and its
    /// entry residual is left unchanged. The `right` channel is copied through
    /// untouched. The Hessian FD test differentiates the chain-rule gradient
    /// evaluated on residuals produced this way.
    fn shifted_quadratic_offset_residuals(
        age_entry: ndarray::ArrayView1<'_, f64>,
        age_exit: ndarray::ArrayView1<'_, f64>,
        base_cfg: &SurvivalBaselineConfig,
        candidate_cfg: &SurvivalBaselineConfig,
        base: &OffsetChannelResiduals,
        curvatures: &OffsetChannelCurvatures,
    ) -> OffsetChannelResiduals {
        let n = age_exit.len();
        let mut entry = base.entry.clone();
        let mut exit = base.exit.clone();
        let mut derivative = base.derivative.clone();
        for row in 0..n {
            // Exit and derivative channels are both evaluated at the exit age.
            let (_, base_exit, base_deriv) =
                baseline_marginal_slope_channels(age_exit[row], base_cfg);
            let (_, cand_exit, cand_deriv) =
                baseline_marginal_slope_channels(age_exit[row], candidate_cfg);
            // Entry channel (q at the entry age), skipped for "no entry" rows
            // flagged by a zero base entry residual.
            let base_entry = if base.entry[row] == 0.0 {
                0.0
            } else {
                baseline_marginal_slope_channels(age_entry[row], base_cfg).1
            };
            let cand_entry = if base.entry[row] == 0.0 {
                0.0
            } else {
                baseline_marginal_slope_channels(age_entry[row], candidate_cfg).1
            };
            let delta = [
                cand_entry - base_entry,
                cand_exit - base_exit,
                cand_deriv - base_deriv,
            ];
            // shift = C_row · delta (3×3 matrix-vector product).
            let mut shift = [0.0; 3];
            for i in 0..3 {
                for j in 0..3 {
                    shift[i] += curvatures.rows[row][i][j] * delta[j];
                }
            }
            if base.entry[row] != 0.0 {
                entry[row] += shift[0];
            }
            exit[row] += shift[1];
            derivative[row] += shift[2];
        }
        OffsetChannelResiduals {
            entry,
            exit,
            derivative,
            right: base.right.clone(),
        }
    }
4079
4080    fn baseline_marginal_slope_channels(age: f64, cfg: &SurvivalBaselineConfig) -> (f64, f64, f64) {
4081        let (q, q_t) = evaluate_survival_marginal_slope_baseline(age, cfg).expect("baseline");
4082        (q, q, q_t)
4083    }
4084
    /// The analytic chain-rule Hessian in baseline θ must match a central finite
    /// difference of the chain-rule gradient, where the residuals at each
    /// perturbed θ come from `shifted_quadratic_offset_residuals` (base
    /// residuals shifted by the per-row curvatures times the channel deltas).
    #[test]
    fn marginal_slope_baseline_chain_rule_hessian_matches_fd_gradient() {
        let cfg = SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::GompertzMakeham,
            scale: None,
            shape: Some(0.025),
            rate: Some(0.012),
            makeham: Some(0.003),
        };
        let theta = survival_baseline_theta_from_config(&cfg)
            .expect("theta")
            .expect("nonlinear");
        // Row 1 enters at the origin with a zero entry residual, which the
        // residual helper treats as a row with no entry contribution.
        let age_entry = array![2.5, 0.0, 5.0];
        let age_exit = array![7.5, 11.0, 15.0];
        let base_residuals = OffsetChannelResiduals {
            entry: array![0.2, 0.0, -0.1],
            exit: array![0.6, -0.3, 0.4],
            derivative: array![-0.5, 0.25, 0.15],
            right: Array1::<f64>::zeros(3),
        };
        // One symmetric 3×3 curvature per row over (entry, exit, derivative).
        let curvatures = OffsetChannelCurvatures {
            rows: vec![
                [[1.4, 0.2, -0.1], [0.2, 1.1, 0.05], [-0.1, 0.05, 0.7]],
                [[0.9, -0.15, 0.0], [-0.15, 1.3, 0.12], [0.0, 0.12, 0.8]],
                [[1.2, 0.05, 0.09], [0.05, 0.95, -0.04], [0.09, -0.04, 0.6]],
            ],
        };
        let analytic = marginal_slope_baseline_chain_rule_hessian(
            age_entry.view(),
            age_exit.view(),
            &cfg,
            &base_residuals,
            &curvatures,
        )
        .expect("hessian")
        .expect("nonlinear");

        // Chain-rule gradient at a perturbed θ, using residuals shifted to that θ.
        let gradient_at = |theta_candidate: &Array1<f64>| -> Array1<f64> {
            let candidate = survival_baseline_config_from_theta(cfg.target, theta_candidate)
                .expect("candidate cfg");
            let residuals = shifted_quadratic_offset_residuals(
                age_entry.view(),
                age_exit.view(),
                &cfg,
                &candidate,
                &base_residuals,
                &curvatures,
            );
            marginal_slope_baseline_chain_rule_gradient(
                age_entry.view(),
                age_exit.view(),
                &candidate,
                &residuals,
            )
            .expect("gradient")
            .expect("nonlinear")
        };

        // Column j of the Hessian vs. the central difference of the gradient
        // along θ_j. θ[1] uses a slightly larger step (reason not stated here).
        for j in 0..theta.len() {
            let step = if j == 1 { 2e-5 } else { 1e-5 };
            let mut plus = theta.clone();
            plus[j] += step;
            let mut minus = theta.clone();
            minus[j] -= step;
            let fd_col = (&gradient_at(&plus) - &gradient_at(&minus)) / (2.0 * step);
            for i in 0..theta.len() {
                assert_close(
                    analytic[[i, j]],
                    fd_col[i],
                    2e-5,
                    &format!("baseline Hessian ({i},{j})"),
                );
            }
        }
    }
4160
4161    #[test]
4162    fn marginal_slope_baseline_chain_rule_gradient_contracts_probit_partials() {
4163        let cfg = SurvivalBaselineConfig {
4164            target: SurvivalBaselineTarget::GompertzMakeham,
4165            scale: None,
4166            shape: Some(0.03),
4167            rate: Some(0.01),
4168            makeham: Some(0.002),
4169        };
4170        let age_entry = array![3.0, 6.0];
4171        let age_exit = array![8.0, 12.0];
4172        let residuals = OffsetChannelResiduals {
4173            exit: array![0.7, -0.2],
4174            entry: array![0.1, 0.4],
4175            derivative: array![1.3, -0.6],
4176            right: Array1::<f64>::zeros(2),
4177        };
4178        let grad = marginal_slope_baseline_chain_rule_gradient(
4179            age_entry.view(),
4180            age_exit.view(),
4181            &cfg,
4182            &residuals,
4183        )
4184        .expect("gradient")
4185        .expect("nonlinear");
4186
4187        let mut expected = Array1::<f64>::zeros(3);
4188        for i in 0..age_exit.len() {
4189            let exit_partials = marginal_slope_baseline_offset_theta_partials(age_exit[i], &cfg)
4190                .expect("exit partials")
4191                .expect("nonlinear");
4192            let entry_partials = marginal_slope_baseline_offset_theta_partials(age_entry[i], &cfg)
4193                .expect("entry partials")
4194                .expect("nonlinear");
4195            for k in 0..3 {
4196                expected[k] += residuals.exit[i] * exit_partials[k].0
4197                    + residuals.derivative[i] * exit_partials[k].1
4198                    + residuals.entry[i] * entry_partials[k].0;
4199            }
4200        }
4201        for k in 0..3 {
4202            assert_close(
4203                grad[k],
4204                expected[k],
4205                1e-12,
4206                &format!("gm-probit chain gradient theta[{k}]"),
4207            );
4208        }
4209    }
4210
4211    /// Parity guard for the shared `baseline_chain_rule_gradient_with_partials`
4212    /// engine (issue #429): both public gradient functions delegate to it with a
4213    /// different partials provider. This test reimplements the pre-unification
4214    /// inline contraction (the serial reference) and asserts bit-for-bit equality
4215    /// against the unified engine's output for BOTH providers on the same data —
4216    /// the RP-eta provider (`baseline_offset_theta_partials`) and the probit-q
4217    /// provider (`marginal_slope_baseline_offset_theta_partials`). Any drift in
4218    /// the extracted contraction (length checks, theta-dim probe, exit/derivative
4219    /// combination, or entry gating) breaks this with an exact (0.0) tolerance.
4220    #[test]
4221    fn baseline_chain_rule_gradient_engine_matches_inline_reference() {
4222        let cfg = SurvivalBaselineConfig {
4223            target: SurvivalBaselineTarget::GompertzMakeham,
4224            scale: None,
4225            shape: Some(0.028),
4226            rate: Some(0.011),
4227            makeham: Some(0.0025),
4228        };
4229        // Mixed entry interval: row 1 is origin-entry (age_entry==0, r_entry==0)
4230        // to exercise the entry-gating branch in the shared engine.
4231        let age_entry = array![3.0, 0.0, 5.5];
4232        let age_exit = array![8.0, 12.0, 16.0];
4233        let residuals = OffsetChannelResiduals {
4234            exit: array![0.7, -0.2, 0.45],
4235            entry: array![0.1, 0.0, -0.3],
4236            derivative: array![1.3, -0.6, 0.2],
4237            right: Array1::<f64>::zeros(3),
4238        };
4239
4240        // Serial reference contraction matching the original inline body. Mirrors
4241        // the engine's exit+derivative/entry split and origin-entry gating.
4242        let reference_gradient = |partials: &dyn Fn(
4243            f64,
4244            &SurvivalBaselineConfig,
4245        )
4246            -> Result<Option<Vec<(f64, f64)>>, String>|
4247         -> Array1<f64> {
4248            let theta_dim = partials(age_exit[0], &cfg)
4249                .expect("probe partials")
4250                .expect("nonlinear")
4251                .len();
4252            let mut acc = Array1::<f64>::zeros(theta_dim);
4253            for i in 0..age_exit.len() {
4254                let p_exit = partials(age_exit[i], &cfg)
4255                    .expect("exit partials")
4256                    .expect("nonlinear");
4257                let r_x = residuals.exit[i];
4258                let r_d = residuals.derivative[i];
4259                for k in 0..theta_dim {
4260                    acc[k] += r_x * p_exit[k].0 + r_d * p_exit[k].1;
4261                }
4262                let r_e = residuals.entry[i];
4263                if r_e != 0.0 {
4264                    let p_entry = partials(age_entry[i], &cfg)
4265                        .expect("entry partials")
4266                        .expect("nonlinear");
4267                    for k in 0..theta_dim {
4268                        acc[k] += r_e * p_entry[k].0;
4269                    }
4270                }
4271            }
4272            acc
4273        };
4274
4275        // RP-eta provider parity.
4276        let rp_engine = baseline_chain_rule_gradient(
4277            age_entry.view(),
4278            age_exit.view(),
4279            age_exit.view(),
4280            &cfg,
4281            &residuals,
4282        )
4283        .expect("rp gradient")
4284        .expect("rp nonlinear");
4285        let rp_reference = reference_gradient(&baseline_offset_theta_partials);
4286        assert_eq!(rp_engine.len(), rp_reference.len());
4287        for k in 0..rp_engine.len() {
4288            assert_close(
4289                rp_engine[k],
4290                rp_reference[k],
4291                0.0,
4292                &format!("rp engine vs inline reference theta[{k}]"),
4293            );
4294        }
4295
4296        // Probit-q provider parity.
4297        let probit_engine = marginal_slope_baseline_chain_rule_gradient(
4298            age_entry.view(),
4299            age_exit.view(),
4300            &cfg,
4301            &residuals,
4302        )
4303        .expect("probit gradient")
4304        .expect("probit nonlinear");
4305        let probit_reference = reference_gradient(&marginal_slope_baseline_offset_theta_partials);
4306        assert_eq!(probit_engine.len(), probit_reference.len());
4307        for k in 0..probit_engine.len() {
4308            assert_close(
4309                probit_engine[k],
4310                probit_reference[k],
4311                0.0,
4312                &format!("probit engine vs inline reference theta[{k}]"),
4313            );
4314        }
4315    }
4316
4317    /// Finite-difference verification of the analytic θ-gradient used by the
4318    /// survival location-scale workflow path.
4319    ///
4320    /// At a converged β, the envelope theorem reduces the profile-NLL gradient
4321    /// w.r.t. the baseline-config θ to a per-row residual contraction against
4322    /// the per-row offset-channel partials ∂o/∂θ:
4323    ///
4324    ///   d(NLL)/dθ_k = Σ_i [ r_X[i]·∂η_exit/∂θ_k + r_E[i]·∂η_entry/∂θ_k
4325    ///                       + r_D[i]·∂o_D_exit/∂θ_k ]
4326    ///
4327    /// (`baseline_chain_rule_gradient`). Because β is fixed, an explicit loss
4328    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i; θ) + r_E[i]·η(t_entry_i; θ)
4329    ///              + r_D[i]·o_D(t_exit_i; θ) ]`
4330    /// has gradient identically equal to the chain-rule output. Comparing the
4331    /// analytic gradient to a central-difference of L over `evaluate_survival_baseline`
4332    /// therefore exercises every piece of the chain rule (incl. the Gompertz
4333    /// rate / shape / Makeham partials at both entry and exit ages) without
4334    /// needing the full location-scale fit pipeline inside this unit-test
4335    /// module. If the chain rule disagrees with FD here, the workflow's
4336    /// gradient is wrong by exactly the same amount.
4337    #[test]
4338    fn gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference() {
4339        let cfg = SurvivalBaselineConfig {
4340            target: SurvivalBaselineTarget::GompertzMakeham,
4341            scale: None,
4342            shape: Some(0.05),
4343            rate: Some(0.012),
4344            makeham: Some(0.003),
4345        };
4346        // n = 8 small synthetic dataset spanning a realistic age range.
4347        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4348        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4349        // Synthetic per-row NLL residuals on the three offset channels. Mix of
4350        // signs / magnitudes / one zero-entry row (origin entry → r_E=0).
4351        let residuals = OffsetChannelResiduals {
4352            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4353            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4354            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4355            right: Array1::<f64>::zeros(8),
4356        };
4357
4358        let analytic = baseline_chain_rule_gradient(
4359            age_entry.view(),
4360            age_exit.view(),
4361            age_exit.view(),
4362            &cfg,
4363            &residuals,
4364        )
4365        .expect("analytic gradient ok")
4366        .expect("GM baseline has a θ-gradient");
4367        assert_eq!(analytic.len(), 3, "GM θ has 3 components");
4368
4369        // Evaluate the offset-projected loss at a perturbed θ. Mirrors the
4370        // chain rule's algebra: the entry channel is only added for rows whose
4371        // r_E is nonzero (matching baseline_chain_rule_gradient's gating that
4372        // avoids calling evaluate_survival_baseline at age 0 for origin-entry
4373        // rows).
4374        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4375            let mut acc = 0.0;
4376            for i in 0..age_exit.len() {
4377                let (eta_exit_i, od_exit_i) =
4378                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4379                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4380                if residuals.entry[i] != 0.0 {
4381                    let (eta_entry_i, _) =
4382                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4383                    acc += residuals.entry[i] * eta_entry_i;
4384                }
4385            }
4386            acc
4387        };
4388
4389        let theta0 = survival_baseline_theta_from_config(&cfg)
4390            .expect("theta seed")
4391            .expect("GM has θ");
4392        // Spec requested δ = 1e-4 per axis. Use central differences over θ.
4393        let delta = 1e-4;
4394        let mut fd = Array1::<f64>::zeros(analytic.len());
4395        for k in 0..analytic.len() {
4396            let mut theta_plus = theta0.clone();
4397            theta_plus[k] += delta;
4398            let mut theta_minus = theta0.clone();
4399            theta_minus[k] -= delta;
4400            let cfg_plus =
4401                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4402            let cfg_minus =
4403                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4404            let lp = loss_at_cfg(&cfg_plus);
4405            let lm = loss_at_cfg(&cfg_minus);
4406            fd[k] = (lp - lm) / (2.0 * delta);
4407        }
4408
4409        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4410        let max_err = analytic
4411            .iter()
4412            .zip(fd.iter())
4413            .map(|(a, b)| (a - b).abs())
4414            .fold(0.0_f64, f64::max);
4415        let rel = max_err / (analytic_norm + 1e-12);
4416        // Print so the deliverable can quote the exact max-error number.
4417        eprintln!(
4418            "gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference: \
4419             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4420             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4421        );
4422        assert!(
4423            rel < 1e-2,
4424            "analytic θ-gradient disagrees with central FD beyond 1%: \
4425             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4426             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4427        );
4428    }
4429
4430    /// Weibull (dim=2) companion to
4431    /// `gompertz_makeham_baseline_chain_rule_gradient_matches_finite_difference`.
4432    ///
4433    /// This is the FD gate for the analytic outer θ-gradient that the
4434    /// transformation/Weibull survival baseline optimizers now feed to BFGS
4435    /// (`optimize_survival_baseline_config_with_gradient_only`). At a *fixed* β
4436    /// the profile-NLL surface is
4437    /// `L(θ) = Σ_i [ r_X[i]·η(t_exit_i;θ) + r_E[i]·η(t_entry_i;θ)
4438    ///              + r_D[i]·o_D(t_exit_i;θ) ]`,
4439    /// whose exact gradient is `baseline_chain_rule_gradient`. Comparing it to a
4440    /// central difference of `L` over `evaluate_survival_baseline` exercises the
4441    /// Weibull scale/shape partials at both entry and exit ages. If this
4442    /// disagrees with FD, the workflow's outer gradient is wrong by the same
4443    /// amount.
4444    #[test]
4445    fn weibull_baseline_chain_rule_gradient_matches_finite_difference() {
4446        let cfg = SurvivalBaselineConfig {
4447            target: SurvivalBaselineTarget::Weibull,
4448            scale: Some(11.0),
4449            shape: Some(1.4),
4450            rate: None,
4451            makeham: None,
4452        };
4453        let age_entry = array![5.0, 8.0, 12.0, 0.5, 20.0, 30.0, 45.0, 60.0];
4454        let age_exit = array![10.0, 15.0, 25.0, 4.0, 35.0, 50.0, 65.0, 80.0];
4455        let residuals = OffsetChannelResiduals {
4456            exit: array![0.42, -0.18, 0.73, -0.91, 0.05, -0.27, 0.61, -0.34],
4457            entry: array![-0.12, 0.31, -0.44, 0.0, 0.16, -0.22, 0.07, -0.51],
4458            derivative: array![1.04, -0.65, 0.18, -1.21, 0.42, -0.13, 0.88, -0.27],
4459            right: Array1::<f64>::zeros(8),
4460        };
4461
4462        let analytic = baseline_chain_rule_gradient(
4463            age_entry.view(),
4464            age_exit.view(),
4465            age_exit.view(),
4466            &cfg,
4467            &residuals,
4468        )
4469        .expect("analytic gradient ok")
4470        .expect("Weibull baseline has a θ-gradient");
4471        assert_eq!(analytic.len(), 2, "Weibull θ has 2 components");
4472
4473        let loss_at_cfg = |cfg_eval: &SurvivalBaselineConfig| -> f64 {
4474            let mut acc = 0.0;
4475            for i in 0..age_exit.len() {
4476                let (eta_exit_i, od_exit_i) =
4477                    evaluate_survival_baseline(age_exit[i], cfg_eval).expect("eval exit");
4478                acc += residuals.exit[i] * eta_exit_i + residuals.derivative[i] * od_exit_i;
4479                if residuals.entry[i] != 0.0 {
4480                    let (eta_entry_i, _) =
4481                        evaluate_survival_baseline(age_entry[i], cfg_eval).expect("eval entry");
4482                    acc += residuals.entry[i] * eta_entry_i;
4483                }
4484            }
4485            acc
4486        };
4487
4488        let theta0 = survival_baseline_theta_from_config(&cfg)
4489            .expect("theta seed")
4490            .expect("Weibull has θ");
4491        let delta = 1e-4;
4492        let mut fd = Array1::<f64>::zeros(analytic.len());
4493        for k in 0..analytic.len() {
4494            let mut theta_plus = theta0.clone();
4495            theta_plus[k] += delta;
4496            let mut theta_minus = theta0.clone();
4497            theta_minus[k] -= delta;
4498            let cfg_plus =
4499                survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("cfg(θ+δ)");
4500            let cfg_minus =
4501                survival_baseline_config_from_theta(cfg.target, &theta_minus).expect("cfg(θ-δ)");
4502            let lp = loss_at_cfg(&cfg_plus);
4503            let lm = loss_at_cfg(&cfg_minus);
4504            fd[k] = (lp - lm) / (2.0 * delta);
4505        }
4506
4507        let analytic_norm = analytic.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4508        let max_err = analytic
4509            .iter()
4510            .zip(fd.iter())
4511            .map(|(a, b)| (a - b).abs())
4512            .fold(0.0_f64, f64::max);
4513        let rel = max_err / (analytic_norm + 1e-12);
4514        eprintln!(
4515            "weibull_baseline_chain_rule_gradient_matches_finite_difference: \
4516             analytic={analytic:?} fd={fd:?} max_err={max_err:.3e} \
4517             analytic_inf_norm={analytic_norm:.3e} rel={rel:.3e}"
4518        );
4519        assert!(
4520            rel < 1e-2,
4521            "analytic θ-gradient disagrees with central FD beyond 1%: \
4522             analytic={analytic:?}, fd={fd:?}, max_err={max_err:.3e}, \
4523             rel={rel:.3e} (analytic_inf_norm={analytic_norm:.3e})"
4524        );
4525    }
4526
4527    // ─── baseline_offset_theta_partials — analytic vs central-difference ─
4528
4529    /// Central-difference of (eta, o_D) at fixed age wrt each θ component in
4530    /// the theta layout defined by `survival_baseline_theta_from_config`.
4531    ///
4532    /// `steps` is per-θ-component: the caller picks the step size appropriate
4533    /// for each channel. Gompertz / Gompertz–Makeham need a tiny step on the
4534    /// shape channel near the Taylor pivot |shape| < 1e-10 (so θ±h stays on
4535    /// the same branch), but a normal-scale step on log_rate / log_makeham;
4536    /// using the tiny shape-step on every channel corrupts the log_rate
4537    /// channel with `eps/(2h)` cancellation noise and has nothing to do with
4538    /// correctness of the analytic derivative.
4539    fn fd_baseline_offset(
4540        age: f64,
4541        cfg: &SurvivalBaselineConfig,
4542        steps: &[f64],
4543    ) -> Vec<(f64, f64)> {
4544        let theta = survival_baseline_theta_from_config(cfg)
4545            .expect("theta")
4546            .expect("non-linear baseline");
4547        assert_eq!(
4548            steps.len(),
4549            theta.len(),
4550            "fd_baseline_offset: step vector length must match θ dimension"
4551        );
4552        (0..theta.len())
4553            .map(|k| {
4554                let h = steps[k];
4555                let mut theta_plus = theta.clone();
4556                theta_plus[k] += h;
4557                let mut theta_minus = theta.clone();
4558                theta_minus[k] -= h;
4559                let cfg_plus =
4560                    survival_baseline_config_from_theta(cfg.target, &theta_plus).expect("plus cfg");
4561                let cfg_minus = survival_baseline_config_from_theta(cfg.target, &theta_minus)
4562                    .expect("minus cfg");
4563                let (eta_p, od_p) = evaluate_survival_baseline(age, &cfg_plus).expect("eta+");
4564                let (eta_m, od_m) = evaluate_survival_baseline(age, &cfg_minus).expect("eta-");
4565                ((eta_p - eta_m) / (2.0 * h), (od_p - od_m) / (2.0 * h))
4566            })
4567            .collect()
4568    }
4569
4570    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
4571        // `<=` so that bit-equal values satisfy tol = 0. With `<`, |a−e| < 0
4572        // is unsatisfiable and a zero-tolerance "must match exactly" call
4573        // would reject identical numbers.
4574        let ok = if expected.abs() < 1.0 {
4575            (actual - expected).abs() <= tol
4576        } else {
4577            (actual - expected).abs() <= tol * expected.abs().max(1.0)
4578        };
4579        assert!(
4580            ok,
4581            "{what}: analytic={actual:.6e} fd={expected:.6e} (tol={tol:.1e})"
4582        );
4583    }
4584
4585    #[test]
4586    fn gompertz_offset_partials_match_central_diff() {
4587        // Several (rate, shape, age) combinations spanning the small-shape
4588        // Taylor branch (|shape| < 1e-10) and the normal branch
4589        // (shape >> 1e-10), plus sign-reversed shape.
4590        let cases = [
4591            (0.5_f64, 0.01_f64, 30.0_f64),
4592            (0.2, 0.05, 60.0),
4593            (1.0, 0.001, 10.0),
4594            (0.4, 5e-11, 25.0),
4595            (0.4, -5e-11, 25.0),
4596            (0.3, -0.02, 40.0),
4597            (0.8, 0.2, 5.0),
4598        ];
4599        for &(rate, shape, age) in &cases {
4600            let cfg = SurvivalBaselineConfig {
4601                target: SurvivalBaselineTarget::Gompertz,
4602                scale: None,
4603                shape: Some(shape),
4604                rate: Some(rate),
4605                makeham: None,
4606            };
4607            let analytic = baseline_offset_theta_partials(age, &cfg)
4608                .expect("ok")
4609                .expect("non-linear");
4610            // Keep the FD probe inside the Taylor branch for tiny |shape| so
4611            // the numeric derivative matches the same small-shape map as the
4612            // analytic helper. log_rate always uses the normal step — rate
4613            // is a moderate-scale parameter and a 1e-11 step would swamp the
4614            // FD with cancellation noise.
4615            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4616            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape]);
4617            assert_eq!(analytic.len(), 2);
4618            // Gompertz θ=(log_rate, shape). Rate channel: ∂eta/∂log_rate=1, ∂o_D/∂log_rate=0.
4619            assert_close(
4620                analytic[0].0,
4621                fd[0].0,
4622                1e-7,
4623                &format!("gompertz ∂eta/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4624            );
4625            assert_close(
4626                analytic[0].1,
4627                fd[0].1,
4628                1e-7,
4629                &format!("gompertz ∂o_D/∂log_rate (rate={rate}, shape={shape}, age={age})"),
4630            );
4631            // shape channel — larger tol because finite-differencing near
4632            // shape=0 amplifies rounding; 1e-5 is fine.
4633            assert_close(
4634                analytic[1].0,
4635                fd[1].0,
4636                1e-5,
4637                &format!("gompertz ∂eta/∂shape (rate={rate}, shape={shape}, age={age})"),
4638            );
4639            assert_close(
4640                analytic[1].1,
4641                fd[1].1,
4642                1e-5,
4643                &format!("gompertz ∂o_D/∂shape (rate={rate}, shape={shape}, age={age})"),
4644            );
4645        }
4646    }
4647
4648    #[test]
4649    fn gompertz_offset_partials_log_rate_channel_is_trivial() {
4650        // Pure Gompertz: rate cancels in o_D, so ∂o_D/∂log_rate must be
4651        // exactly 0 and ∂eta/∂log_rate must be exactly 1. Verify the
4652        // analytic implementation returns the exact values, not FD-close.
4653        let cfg = SurvivalBaselineConfig {
4654            target: SurvivalBaselineTarget::Gompertz,
4655            scale: None,
4656            shape: Some(0.05),
4657            rate: Some(0.3),
4658            makeham: None,
4659        };
4660        let partials = baseline_offset_theta_partials(42.0, &cfg)
4661            .expect("ok")
4662            .expect("non-linear");
4663        assert_eq!(partials[0].0, 1.0);
4664        assert_eq!(partials[0].1, 0.0);
4665    }
4666
4667    #[test]
4668    fn gompertz_offset_partials_small_shape_taylor_agrees_with_direct_branch() {
4669        // Both branches of gompertz_shape_derivatives should agree to high
4670        // precision at shape = 1e-10 + epsilon on the direct side vs
4671        // shape = 1e-10 - epsilon on the Taylor side. Here we spot-check
4672        // the continuity at the branch cutoff: shape slightly above and
4673        // slightly below 1e-10 must give values within O(shape²·t²)
4674        // (the Taylor truncation error).
4675        let age = 25.0;
4676        let rate = 0.4;
4677        let cfg_taylor = SurvivalBaselineConfig {
4678            target: SurvivalBaselineTarget::Gompertz,
4679            scale: None,
4680            shape: Some(0.5e-10),
4681            rate: Some(rate),
4682            makeham: None,
4683        };
4684        let cfg_direct = SurvivalBaselineConfig {
4685            target: SurvivalBaselineTarget::Gompertz,
4686            scale: None,
4687            shape: Some(2.0e-10),
4688            rate: Some(rate),
4689            makeham: None,
4690        };
4691        let p_t = baseline_offset_theta_partials(age, &cfg_taylor)
4692            .expect("ok")
4693            .expect("nl");
4694        let p_d = baseline_offset_theta_partials(age, &cfg_direct)
4695            .expect("ok")
4696            .expect("nl");
4697        // ∂eta/∂shape at shape≈0 should be t/2 = 12.5 on both sides.
4698        assert_close(p_t[1].0, 12.5, 1e-8, "taylor ∂eta/∂shape near 0");
4699        assert_close(p_d[1].0, 12.5, 1e-8, "direct ∂eta/∂shape near 0");
4700        // ∂o_D/∂shape at shape≈0 should be 1/2.
4701        assert_close(p_t[1].1, 0.5, 1e-8, "taylor ∂o_D/∂shape near 0");
4702        assert_close(p_d[1].1, 0.5, 1e-8, "direct ∂o_D/∂shape near 0");
4703    }
4704
4705    // ----------------------------------------------------------------------
    // Gompertz hazard-channel shape derivatives: FD oracle + Taylor-branch
    // continuity. These feed `survival_hazard_theta_partials` /
    // `survival_hazard_theta_first_second` (the marginal-slope probit
    // baseline). Before these tests, the only coverage of
    // `gompertz_cumulative_shape_{,second_}derivative` was the indirect
    // marginal-slope Hessian FD at shape=0.025. That check neither reaches the
    // small-shape Taylor branch (now selected on x = shape·age; originally on
    // `|shape| < 1e-10`) nor FD-checks these analytic shape derivatives directly.
4714    // ----------------------------------------------------------------------
4715
4716    #[test]
4717    fn gompertz_hazard_shape_derivatives_match_central_diff() {
4718        // shape stays well above the 1e-10 Taylor cutoff so the exact
4719        // closed-form branch is exercised and the expm1/exp arithmetic is
4720        // numerically clean. FD on the analytic value/first-derivative
4721        // confirms the first and second shape derivatives.
4722        let cases = [
4723            (10.0_f64, 0.012_f64, 0.05_f64),
4724            (2.5, 0.5, 0.2),
4725            (15.0, 0.003, 0.01),
4726            (40.0, 0.3, 0.001),
4727        ];
4728        let h = 1e-6;
4729        for &(age, rate, shape) in &cases {
4730            // First shape derivative of (H_G, h_G) vs central diff of value.
4731            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4732            let (cum_p, inst_p) = gompertz_hazard_components(age, rate, shape + h);
4733            let (cum_m, inst_m) = gompertz_hazard_components(age, rate, shape - h);
4734            assert_close(
4735                d_cum,
4736                (cum_p - cum_m) / (2.0 * h),
4737                1e-6,
4738                &format!("∂H_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4739            );
4740            assert_close(
4741                d_inst,
4742                (inst_p - inst_m) / (2.0 * h),
4743                1e-6,
4744                &format!("∂h_G/∂shape (age={age}, rate={rate}, shape={shape})"),
4745            );
4746
4747            // Second shape derivative vs central diff of the first derivative.
4748            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4749            let (dcum_p, dinst_p) = gompertz_cumulative_shape_derivative(age, rate, shape + h);
4750            let (dcum_m, dinst_m) = gompertz_cumulative_shape_derivative(age, rate, shape - h);
4751            assert_close(
4752                d2_cum,
4753                (dcum_p - dcum_m) / (2.0 * h),
4754                1e-5,
4755                &format!("∂²H_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4756            );
4757            assert_close(
4758                d2_inst,
4759                (dinst_p - dinst_m) / (2.0 * h),
4760                1e-5,
4761                &format!("∂²h_G/∂shape² (age={age}, rate={rate}, shape={shape})"),
4762            );
4763        }
4764    }
4765
4766    #[test]
4767    fn gompertz_hazard_shape_derivatives_small_shape_match_analytic_limit() {
4768        // At small x = shape·age the shape derivatives collapse to closed-form
4769        // limits. These MUST hold even for large ages with tiny shapes, which
4770        // is precisely the regime where the (cancelling) exact branch loses all
4771        // precision and the x-based pivot routes to the Taylor branch.
4772        //   ∂H_G/∂shape   -> rate·t²/2
4773        //   ∂h_G/∂shape   -> rate·t
4774        //   ∂²H_G/∂shape² -> rate·t³/3
4775        //   ∂²h_G/∂shape² -> rate·t²
4776        // The bug this guards: the second derivative's old `shape < 1e-10`
4777        // pivot ignored `age`, so e.g. (age=100, shape=1e-5 -> x=1e-3) took the
4778        // cancelling exact branch and returned a wildly wrong curvature.
4779        let cases = [
4780            (25.0_f64, 0.4_f64, 1e-9_f64),
4781            (100.0, 0.4, 1e-6),   // x = 1e-4
4782            (100.0, 0.012, 1e-6), // x = 1e-4, the old-pivot band (large age, tiny shape)
4783            (50.0, 1.2, 1e-8),
4784        ];
4785        // NOTE: every quantity below is compared against its shape->0 *limit*.
4786        // For the cancelling cumulative branches (∂H/∂shape, ∂²H/∂shape²,
4787        // ∂²h/∂shape²) the limit is the correct shape->0 target and the
4788        // implementation routes through Taylor in this band. But the
4789        // instantaneous first derivative ∂h_G/∂shape = rate·age·e^x carries NO
4790        // cancellation: it is exact, and its departure from the limit rate·t is
4791        // a genuine O(x) effect. At x=1e-3 that departure is ~1.2e-3 (> tol),
4792        // so the cases here keep x <= 1e-4 where the limit is a valid 1e-3
4793        // oracle for *all four* quantities. The cancelling-branch regression at
4794        // larger x is covered by gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap.
4795        for &(age, rate, shape) in &cases {
4796            let t = age;
4797            let (d_cum, d_inst) = gompertz_cumulative_shape_derivative(age, rate, shape);
4798            assert_close(
4799                d_cum,
4800                rate * t * t / 2.0,
4801                1e-3,
4802                &format!("∂H_G/∂shape limit (age={age}, shape={shape})"),
4803            );
4804            assert_close(
4805                d_inst,
4806                rate * t,
4807                1e-3,
4808                &format!("∂h_G/∂shape limit (age={age}, shape={shape})"),
4809            );
4810
4811            let (d2_cum, d2_inst) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4812            assert_close(
4813                d2_cum,
4814                rate * t * t * t / 3.0,
4815                1e-3,
4816                &format!("∂²H_G/∂shape² limit (age={age}, shape={shape})"),
4817            );
4818            assert_close(
4819                d2_inst,
4820                rate * t * t,
4821                1e-3,
4822                &format!("∂²h_G/∂shape² limit (age={age}, shape={shape})"),
4823            );
4824        }
4825    }
4826
4827    #[test]
4828    fn gompertz_second_shape_derivative_is_accurate_in_old_pivot_gap() {
4829        // Regression: in the band shape ∈ [1e-10, ~1e-4] with a realistic age,
4830        // the OLD `shape < 1e-10` pivot sent ∂²H_G/∂shape² through the
4831        // catastrophically-cancelling exact branch. With age=100, shape=1e-9
4832        // (x=1e-7) the exact branch returned ~+5e1 vs the true ~rate·t³/3.
4833        // Assert the implementation now matches the closed-form limit to high
4834        // precision throughout that band, across several decades of shape.
4835        let age = 100.0;
4836        let rate = 0.4;
4837        let t = age;
4838        let truth = rate * t * t * t / 3.0; // 1.333e5
4839        // Start at shape=1e-5 (x=1e-3): below this the second derivative is,
4840        // to better than 1e-3 relative, equal to its shape->0 limit, so the
4841        // limit is a valid oracle. (At x=1e-2 the true value legitimately
4842        // departs from the limit by ~7e-3, which is a real O(x) correction,
4843        // not an error — so we do not extend the band up to shape=1e-4.)
4844        for k in 5..=12 {
4845            let shape = 10f64.powi(-(k as i32)); // 1e-5 .. 1e-12
4846            let (d2_cum, _) = gompertz_cumulative_shape_second_derivative(age, rate, shape);
4847            assert_close(
4848                d2_cum,
4849                truth,
4850                1e-3,
4851                &format!("∂²H_G/∂shape² in old-pivot gap (age={age}, shape=1e-{k})"),
4852            );
4853        }
4854    }
4855
4856    #[test]
4857    fn weibull_offset_partials_match_central_diff() {
4858        let cases = [
4859            (0.5_f64, 1.2_f64, 25.0_f64),
4860            (2.0, 0.8, 60.0),
4861            (0.1, 3.0, 10.0),
4862        ];
4863        for &(scale, shape, age) in &cases {
4864            let cfg = SurvivalBaselineConfig {
4865                target: SurvivalBaselineTarget::Weibull,
4866                scale: Some(scale),
4867                shape: Some(shape),
4868                rate: None,
4869                makeham: None,
4870            };
4871            let analytic = baseline_offset_theta_partials(age, &cfg)
4872                .expect("ok")
4873                .expect("nl");
4874            let fd = fd_baseline_offset(age, &cfg, &[1e-5, 1e-5]);
4875            assert_eq!(analytic.len(), 2);
4876            for k in 0..2 {
4877                assert_close(
4878                    analytic[k].0,
4879                    fd[k].0,
4880                    1e-7,
4881                    &format!("weibull ∂eta/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
4882                );
4883                assert_close(
4884                    analytic[k].1,
4885                    fd[k].1,
4886                    1e-7,
4887                    &format!("weibull ∂o_D/∂θ[{k}] (scale={scale}, shape={shape}, age={age})"),
4888                );
4889            }
4890            // Weibull o_D = shape/t is independent of scale; verify exactly.
4891            assert_eq!(analytic[0].1, 0.0);
4892        }
4893    }
4894
4895    #[test]
4896    fn gompertz_makeham_offset_partials_match_central_diff() {
4897        let cases = [
4898            (0.3_f64, 0.05_f64, 0.002_f64, 40.0_f64),
4899            (0.5, 0.01, 0.01, 25.0),
4900            (0.2, 0.001, 0.005, 60.0),
4901            (0.4, 5e-11, 0.01, 25.0),
4902            (0.4, -5e-11, 0.01, 25.0),
4903            (0.8, 0.2, 0.05, 5.0),
4904        ];
4905        for &(rate, shape, makeham, age) in &cases {
4906            let cfg = SurvivalBaselineConfig {
4907                target: SurvivalBaselineTarget::GompertzMakeham,
4908                scale: None,
4909                shape: Some(shape),
4910                rate: Some(rate),
4911                makeham: Some(makeham),
4912            };
4913            let analytic = baseline_offset_theta_partials(age, &cfg)
4914                .expect("ok")
4915                .expect("nl");
4916            // See gompertz_offset_partials_match_central_diff: tiny shape-step
4917            // is only needed for the shape component; log_rate and
4918            // log_makeham take the normal-scale step.
4919            let h_shape = if shape.abs() < 1e-9 { 1e-11 } else { 1e-5 };
4920            let fd = fd_baseline_offset(age, &cfg, &[1e-5, h_shape, 1e-5]);
4921            assert_eq!(analytic.len(), 3);
4922            for k in 0..3 {
4923                assert_close(
4924                    analytic[k].0,
4925                    fd[k].0,
4926                    1e-5,
4927                    &format!(
4928                        "gm ∂eta/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
4929                    ),
4930                );
4931                assert_close(
4932                    analytic[k].1,
4933                    fd[k].1,
4934                    1e-5,
4935                    &format!(
4936                        "gm ∂o_D/∂θ[{k}] (rate={rate}, shape={shape}, mk={makeham}, age={age})"
4937                    ),
4938                );
4939            }
4940        }
4941    }
4942
4943    #[test]
4944    fn linear_baseline_has_no_theta_partials() {
4945        let cfg = SurvivalBaselineConfig {
4946            target: SurvivalBaselineTarget::Linear,
4947            scale: None,
4948            shape: None,
4949            rate: None,
4950            makeham: None,
4951        };
4952        assert!(baseline_offset_theta_partials(5.0, &cfg).unwrap().is_none());
4953    }
4954
4955    #[test]
4956    fn baseline_offset_partials_reject_non_positive_ages() {
4957        let cfg = SurvivalBaselineConfig {
4958            target: SurvivalBaselineTarget::Gompertz,
4959            scale: None,
4960            shape: Some(0.01),
4961            rate: Some(0.5),
4962            makeham: None,
4963        };
4964        assert!(baseline_offset_theta_partials(0.0, &cfg).is_err());
4965        assert!(baseline_offset_theta_partials(-1.0, &cfg).is_err());
4966        assert!(baseline_offset_theta_partials(f64::NAN, &cfg).is_err());
4967    }
4968
4969    // ─── baseline_chain_rule_gradient — mechanical and FD-vs-θ tests ─────
4970
4971    /// Mechanical sanity check: with only one event observation at known
4972    /// (r_X, r_E, r_D, age_exit, age_entry), the Gompertz chain-rule gradient
4973    /// reduces to the analytic linear combination of `baseline_offset_theta_partials`.
4974    #[test]
4975    fn chain_rule_gradient_single_obs_reduces_to_pointwise_contract() {
4976        let cfg = SurvivalBaselineConfig {
4977            target: SurvivalBaselineTarget::Gompertz,
4978            scale: None,
4979            shape: Some(0.05),
4980            rate: Some(0.3),
4981            makeham: None,
4982        };
4983        let age_entry = array![10.0_f64];
4984        let age_exit = array![25.0_f64];
4985        let residuals = OffsetChannelResiduals {
4986            exit: array![0.7_f64],
4987            entry: array![-0.2_f64],
4988            derivative: array![-0.4_f64],
4989            right: Array1::<f64>::zeros(1),
4990        };
4991        let grad = baseline_chain_rule_gradient(
4992            age_entry.view(),
4993            age_exit.view(),
4994            age_exit.view(),
4995            &cfg,
4996            &residuals,
4997        )
4998        .expect("ok")
4999        .expect("non-linear");
5000        // Hand-compute: grad[k] = r_X·∂eta_exit/∂θ_k + r_D·∂o_D_exit/∂θ_k + r_E·∂eta_entry/∂θ_k.
5001        let p_exit = baseline_offset_theta_partials(age_exit[0], &cfg)
5002            .unwrap()
5003            .unwrap();
5004        let p_entry = baseline_offset_theta_partials(age_entry[0], &cfg)
5005            .unwrap()
5006            .unwrap();
5007        for k in 0..p_exit.len() {
5008            let expected = 0.7 * p_exit[k].0 + (-0.4) * p_exit[k].1 + (-0.2) * p_entry[k].0;
5009            assert!(
5010                (grad[k] - expected).abs() < 1e-12,
5011                "chain-rule contract mismatch at k={k}: got={:.6e} expected={:.6e}",
5012                grad[k],
5013                expected
5014            );
5015        }
5016    }
5017
5018    /// Origin-entry rows (r_entry == 0) must skip the baseline partials call at
5019    /// `age_entry = 0`, which would otherwise fail the positive-age precondition.
5020    #[test]
5021    fn chain_rule_gradient_skips_entry_call_for_origin_entry_rows() {
5022        let cfg = SurvivalBaselineConfig {
5023            target: SurvivalBaselineTarget::Gompertz,
5024            scale: None,
5025            shape: Some(0.05),
5026            rate: Some(0.3),
5027            makeham: None,
5028        };
5029        let age_entry = array![0.0_f64, 5.0_f64];
5030        let age_exit = array![10.0_f64, 20.0_f64];
5031        let residuals = OffsetChannelResiduals {
5032            exit: array![0.5_f64, 0.3_f64],
5033            entry: array![0.0_f64, -0.1_f64], // row 0 is origin-entry (r_E = 0)
5034            derivative: array![-0.2_f64, 0.0_f64],
5035            right: Array1::<f64>::zeros(2),
5036        };
5037        // Must not error despite age_entry[0] == 0.
5038        let grad = baseline_chain_rule_gradient(
5039            age_entry.view(),
5040            age_exit.view(),
5041            age_exit.view(),
5042            &cfg,
5043            &residuals,
5044        )
5045        .expect("must not fail on origin-entry row with r_entry=0")
5046        .expect("non-linear");
5047        assert_eq!(grad.len(), 2);
5048        // Row 1's entry channel contributes, row 0's does not.
5049        let p_exit_0 = baseline_offset_theta_partials(10.0, &cfg).unwrap().unwrap();
5050        let p_exit_1 = baseline_offset_theta_partials(20.0, &cfg).unwrap().unwrap();
5051        let p_entry_1 = baseline_offset_theta_partials(5.0, &cfg).unwrap().unwrap();
5052        for k in 0..2 {
5053            let expected = 0.5 * p_exit_0[k].0
5054                + (-0.2) * p_exit_0[k].1
5055                + 0.3 * p_exit_1[k].0
5056                + (-0.1) * p_entry_1[k].0;
5057            assert!(
5058                (grad[k] - expected).abs() < 1e-12,
5059                "origin-entry contract at k={k}: got={:.6e} expected={:.6e}",
5060                grad[k],
5061                expected
5062            );
5063        }
5064    }
5065
5066    /// Linear target has no θ-parameters; contractor returns None.
5067    #[test]
5068    fn chain_rule_gradient_linear_target_returns_none() {
5069        let cfg = SurvivalBaselineConfig {
5070            target: SurvivalBaselineTarget::Linear,
5071            scale: None,
5072            shape: None,
5073            rate: None,
5074            makeham: None,
5075        };
5076        let age_entry = array![1.0_f64];
5077        let age_exit = array![2.0_f64];
5078        let residuals = OffsetChannelResiduals {
5079            exit: array![0.1_f64],
5080            entry: array![0.0_f64],
5081            derivative: array![0.0_f64],
5082            right: Array1::<f64>::zeros(1),
5083        };
5084        let grad = baseline_chain_rule_gradient(
5085            age_entry.view(),
5086            age_exit.view(),
5087            age_exit.view(),
5088            &cfg,
5089            &residuals,
5090        )
5091        .expect("ok");
5092        assert!(grad.is_none());
5093    }
5094
5095    /// End-to-end envelope-theorem check: the chain-rule gradient at
5096    /// residuals-evaluated-at-β-fixed matches the central FD of the
5097    /// unpenalized NLL with respect to θ when the OFFSETS are recomputed
5098    /// from the perturbed cfg and β is held at its base value.
5099    ///
5100    /// This is the mathematical content of the envelope theorem applied to
5101    /// the penalized-deviance cost at fixed β: if β solves ∂C/∂β = 0 at
5102    /// (θ, β*), then the total derivative of C at (θ±h) when β is held at
5103    /// β* equals the partial derivative of C wrt θ at the base — up to
5104    /// O(h²) in the truncation error of central differences. For THIS test
5105    /// we're directly differencing NLL (the unpenalized piece that carries
5106    /// all the θ dependence), so the envelope identity is exact up to FD
5107    /// truncation.
5108    ///
5109    /// The test synthesizes a plausible residual set by hand rather than
5110    /// running PIRLS — what we're validating is the chain-rule contractor,
5111    /// not the fit. A PIRLS-based end-to-end check belongs in an
5112    /// integration test, not this unit-test module.
5113    #[test]
5114    fn chain_rule_gradient_matches_fd_of_nll_through_offset_perturbation() {
5115        // Toy 3-observation case with two events (one origin-entry, one not)
5116        // and one censored row at large age.
5117        let cfg = SurvivalBaselineConfig {
5118            target: SurvivalBaselineTarget::Gompertz,
5119            scale: None,
5120            shape: Some(0.03),
5121            rate: Some(0.25),
5122            makeham: None,
5123        };
5124        let age_entry = array![0.0_f64, 5.0, 8.0];
5125        let age_exit = array![4.0_f64, 12.0, 20.0];
5126        // Weighted residuals at a notional β*. Values chosen in a plausible
5127        // range (~same order as w·exp(η)).
5128        let weights = array![1.0_f64, 2.0, 0.5];
5129        let events = [1.0_f64, 1.0, 0.0];
5130        // Fake a β* that yields finite eta_entry ± eta_exit ± s values by
5131        // directly specifying eta quantities. Contractor only consumes the
5132        // residuals, so the fake is sufficient.
5133        let eta_entry_vals = [-100.0_f64, 0.5, 0.8]; // row 0 doesn't matter (origin entry)
5134        let eta_exit_vals = [0.4_f64, 0.9, 1.3];
5135        let s_vals = [0.7_f64, 1.1, 1.5];
5136        let (r_x, r_e, r_d) = {
5137            let mut rx = Array1::<f64>::zeros(3);
5138            let mut re = Array1::<f64>::zeros(3);
5139            let mut rd = Array1::<f64>::zeros(3);
5140            for i in 0..3 {
5141                let w = weights[i];
5142                let d = events[i];
5143                rx[i] = w * (eta_exit_vals[i].exp() - d);
5144                re[i] = if i == 0 {
5145                    0.0 // origin entry
5146                } else {
5147                    -w * eta_entry_vals[i].exp()
5148                };
5149                rd[i] = if d > 0.0 { -w * d / s_vals[i] } else { 0.0 };
5150            }
5151            (rx, re, rd)
5152        };
5153        let residuals = OffsetChannelResiduals {
5154            exit: r_x.clone(),
5155            entry: r_e.clone(),
5156            derivative: r_d.clone(),
5157            right: Array1::<f64>::zeros(3),
5158        };
5159        let grad = baseline_chain_rule_gradient(
5160            age_entry.view(),
5161            age_exit.view(),
5162            age_exit.view(),
5163            &cfg,
5164            &residuals,
5165        )
5166        .expect("ok")
5167        .expect("non-linear");
5168
5169        // Construct NLL(θ) with β* held to the same eta/s values by treating
5170        // eta_i, s_i as fixed "linear predictor" samples and shifting by
5171        // (offset(θ) - offset(θ_base)). That's exactly the RP NLL with β*
5172        // held constant and offsets varied through θ.
5173        let nll = |theta_plus: &Array1<f64>| -> f64 {
5174            let cfg_p = survival_baseline_config_from_theta(cfg.target, theta_plus).expect("cfg_p");
5175            let mut sum = 0.0_f64;
5176            for i in 0..3 {
5177                let (eta_x_p, d_x_p) = evaluate_survival_baseline(age_exit[i], &cfg_p).unwrap();
5178                let base = evaluate_survival_baseline(age_exit[i], &cfg).unwrap();
5179                let d_eta_x = eta_x_p - base.0;
5180                let d_d_x = d_x_p - base.1;
5181                let eta_exit_new = eta_exit_vals[i] + d_eta_x;
5182                let s_new = s_vals[i] + d_d_x;
5183                let interval_entry = if i == 0 {
5184                    0.0_f64
5185                } else {
5186                    let (eta_e_p, _) = evaluate_survival_baseline(age_entry[i], &cfg_p).unwrap();
5187                    let base_e = evaluate_survival_baseline(age_entry[i], &cfg).unwrap();
5188                    let d_eta_e = eta_e_p - base_e.0;
5189                    let eta_entry_new = eta_entry_vals[i] + d_eta_e;
5190                    eta_entry_new.exp()
5191                };
5192                let w = weights[i];
5193                let d = events[i];
5194                let nll_i =
5195                    w * (eta_exit_new.exp() - interval_entry - d * (eta_exit_new + s_new.ln()));
5196                sum += nll_i;
5197            }
5198            sum
5199        };
5200
5201        let theta_base = survival_baseline_theta_from_config(&cfg).unwrap().unwrap();
5202        let h = 1e-6;
5203        for k in 0..theta_base.len() {
5204            let mut tp = theta_base.clone();
5205            let mut tm = theta_base.clone();
5206            tp[k] += h;
5207            tm[k] -= h;
5208            let fd = (nll(&tp) - nll(&tm)) / (2.0 * h);
5209            assert!(
5210                (grad[k] - fd).abs() < 1e-5 * grad[k].abs().max(1.0),
5211                "chain-rule θ[{k}]: analytic={:.6e} fd={:.6e}",
5212                grad[k],
5213                fd
5214            );
5215        }
5216    }
5217
5218    /// Length-mismatch surfaces as an error, not a silent contraction.
5219    #[test]
5220    fn chain_rule_gradient_rejects_length_mismatch() {
5221        let cfg = SurvivalBaselineConfig {
5222            target: SurvivalBaselineTarget::Gompertz,
5223            scale: None,
5224            shape: Some(0.05),
5225            rate: Some(0.3),
5226            makeham: None,
5227        };
5228        let age_entry = array![1.0_f64, 2.0]; // length 2
5229        let age_exit = array![5.0_f64, 6.0, 7.0]; // length 3
5230        let residuals = OffsetChannelResiduals {
5231            exit: array![0.1_f64, 0.2, 0.3],
5232            entry: array![0.0_f64, 0.0, 0.0],
5233            derivative: array![0.0_f64, 0.0, 0.0],
5234            right: Array1::<f64>::zeros(3),
5235        };
5236        let err = baseline_chain_rule_gradient(
5237            age_entry.view(),
5238            age_exit.view(),
5239            age_exit.view(),
5240            &cfg,
5241            &residuals,
5242        )
5243        .expect_err("length mismatch must error");
5244        assert!(err.contains("length mismatch"), "err={err}");
5245    }
5246}