Skip to main content

gam_models/survival/
predict.rs

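//! Library-side survival prediction pipeline.
//!
//! Extracts the hazard/survival/cumulative-hazard math from the CLI's
//! `run_predict_survival` so that both the CLI and the Python FFI share a
//! single entry point. The CLI retains ownership of progress bars and CSV
//! writing; everything else (design build, baseline + time basis
//! evaluation, link/time wiggles, and hazard/survival conversion) flows
//! through [`predict_survival`]. Delta-method uncertainty is computed here
//! only for location-scale models (see
//! `SurvivalPredictRequest::with_uncertainty`); uncertainty bounds for the
//! other likelihood modes remain in the CLI.
//!
//! The module also provides evaluation helpers for predicted survival:
//! restricted mean survival time, Harrell's concordance index, the
//! (integrated) IPCW Brier score, and a Kaplan–Meier estimator.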
9
10use std::collections::HashMap;
11
12use ndarray::{Array1, Array2, ArrayView2, s};
13
14use crate::scale_design::scale_transform_from_payload;
15use crate::survival::construction::{
16    SurvivalBaselineConfig, SurvivalBaselineTarget, SurvivalLikelihoodMode,
17    SurvivalTimeBuildOutput, add_survival_time_derivative_guard_offset, build_survival_time_basis,
18    build_survival_time_offsets_for_likelihood, build_survival_timewiggle_derivative_design,
19    center_survival_time_designs_at_anchor, evaluate_survival_time_basis_row,
20    normalize_survival_time_pair, parse_survival_likelihood_mode,
21    require_structural_survival_time_basis, resolved_survival_time_basis_config_from_build,
22    survival_derivative_guard_for_likelihood, survival_likelihood_modename,
23};
24use crate::survival::lognormal_kernel::FrailtySpec;
25use crate::survival::{
26    CompetingRisksCifResult, assemble_competing_risks_cif_from_endpoints,
27};
28use crate::wiggle::buildwiggle_block_input_from_knots;
29use crate::inference::model::{
30    FittedFamily, FittedModel as SavedModel, SavedBaselineTimeWiggleRuntime,
31    load_survival_time_basis_config_from_model, survival_baseline_config_from_model,
32};
33use crate::inference::predict_io::{BernoulliMarginalSlopePredictor, PredictInput};
34use gam_linalg::matrix::DesignMatrix;
35use crate::model_types::{BlockRole, FittedBlock, FittedLinkState, UnifiedFitResult};
36use crate::probability::signed_probit_logcdf_and_mills_ratio;
37use gam_solve::mixture_link::inverse_link_jet_for_inverse_link;
38use gam_terms::term_builder::resolve_role_col;
39use gam_terms::smooth::build_term_collection_design;
40use gam_terms::smooth::TermCollectionSpec;
41use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
42
43/// Resolved survival entry/exit column indices for a saved survival model.
44///
45/// `entry_col` is `None` when the model was trained with the right-censored
46/// shorthand `Surv(time, event)`; callers synthesize a zero entry time per
47/// row in that case via [`SurvivalTimeColumns::row_entry_time`]. Mirrors
48/// the CLI predict path so every site that consumes saved survival
49/// metadata applies the same fallback contract.
50pub struct SurvivalTimeColumns {
51    pub entry_col: Option<usize>,
52    pub exit_col: usize,
53}
54
55impl SurvivalTimeColumns {
56    /// Entry time for row `i`, defaulting to `0.0` when the saved model has
57    /// no `survival_entry` column (right-censored shorthand).
58    #[inline]
59    pub fn row_entry_time(&self, data: ArrayView2<'_, f64>, i: usize) -> f64 {
60        self.entry_col.map_or(0.0, |idx| data[[i, idx]])
61    }
62}
63
64/// Resolve saved survival entry/exit column names against the runtime
65/// `col_map`, treating an absent `survival_entry` as the right-censored
66/// shorthand (entry times synthesized as zero downstream).
67pub fn resolve_saved_survival_time_columns(
68    model: &SavedModel,
69    col_map: &HashMap<String, usize>,
70) -> Result<SurvivalTimeColumns, String> {
71    let entry_col: Option<usize> = model
72        .survival_entry
73        .as_deref()
74        .map(|name| resolve_role_col(col_map, name, "entry"))
75        .transpose()?;
76    let exitname = model
77        .survival_exit
78        .as_ref()
79        .ok_or_else(|| "survival model missing exit column metadata".to_string())?;
80    let exit_col = resolve_role_col(col_map, exitname, "exit")?;
81    Ok(SurvivalTimeColumns {
82        entry_col,
83        exit_col,
84    })
85}
86
/// Floor applied to survival probabilities before taking the cumulative
/// hazard `-ln(S)`.
///
/// `ln(1e-300) ≈ -690.8`, so flooring here caps `-ln(S)` near 691. Flooring
/// at `f64::MIN_POSITIVE` (≈ 2.2e-308) instead would let `-ln(S)` climb to
/// about 709, where downstream `exp(-cum)` underflow patterns stop
/// round-tripping through `clamp(0, 1)`. The value matches the
/// location-scale predict contract upstream.
const SURVIVAL_PROB_MIN_FOR_LOG: f64 = 1.0e-300;
93
94/// Typed errors emitted by the survival prediction pipeline.
95///
96/// Each variant carries a pre-formatted `reason` string so `Display` is
97/// byte-equivalent to the original `format!(...)` outputs the module used
98/// before the typed-error migration. The category split lets callers
99/// pattern-match on the failure kind without dragging the string apart.
100#[derive(Debug, Clone)]
101pub enum SurvivalPredictError {
102    /// Request-level input did not satisfy the predict contract: bad offset
103    /// lengths, malformed time grids, empty grids, non-finite times.
104    InvalidInput { reason: String },
105    /// The saved model is missing metadata required to drive the prediction
106    /// (anchor, link/distribution tags, likelihood-mode marker, etc.) or
107    /// carries legacy metadata that the current runtime refuses to consume.
108    MissingFitMetadata { reason: String },
109    /// Saved coefficient blocks, design columns, or baseline-timewiggle
110    /// runtime dimensions disagree with the rebuilt prediction designs.
111    IncompatibleSchema { reason: String },
112    /// The requested combination of saved-model mode and predict-time
113    /// options is not implemented in this library entry point yet (e.g.
114    /// uncertainty for non-location-scale, latent window prediction,
115    /// competing-risks with `with_uncertainty`).
116    UnsupportedConfiguration { reason: String },
117    /// A numerical step (hazard / derivative / survival reconstruction)
118    /// produced a non-finite or out-of-domain value that downstream code
119    /// cannot consume.
120    NumericalFailure { reason: String },
121    /// Saved-model validation failed below this prediction layer; the model
122    /// source error keeps its own payload/schema category.
123    ModelPayload {
124        context: &'static str,
125        source: crate::inference::model::FittedModelError,
126    },
127}
128
129impl std::fmt::Display for SurvivalPredictError {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            SurvivalPredictError::InvalidInput { reason }
133            | SurvivalPredictError::MissingFitMetadata { reason }
134            | SurvivalPredictError::IncompatibleSchema { reason }
135            | SurvivalPredictError::UnsupportedConfiguration { reason }
136            | SurvivalPredictError::NumericalFailure { reason } => f.write_str(reason),
137            SurvivalPredictError::ModelPayload { context, source } => {
138                write!(f, "{context}: {source}")
139            }
140        }
141    }
142}
143
144impl std::error::Error for SurvivalPredictError {
145    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
146        match self {
147            SurvivalPredictError::ModelPayload { source, .. } => Some(source),
148            SurvivalPredictError::InvalidInput { .. }
149            | SurvivalPredictError::MissingFitMetadata { .. }
150            | SurvivalPredictError::IncompatibleSchema { .. }
151            | SurvivalPredictError::UnsupportedConfiguration { .. }
152            | SurvivalPredictError::NumericalFailure { .. } => None,
153        }
154    }
155}
156
157impl From<SurvivalPredictError> for String {
158    fn from(err: SurvivalPredictError) -> String {
159        err.to_string()
160    }
161}
162
163impl From<String> for SurvivalPredictError {
164    /// Inbound conversion from the many `Result<_, String>` helpers this
165    /// module still calls into (basis builders, fit deserializers,
166    /// term-collection assembly). The text is preserved verbatim; we only
167    /// pick a category so external messages flow through `?` without
168    /// per-callsite `.map_err`.
169    fn from(reason: String) -> SurvivalPredictError {
170        SurvivalPredictError::InvalidInput { reason }
171    }
172}
173
174impl From<gam_data::DataError> for SurvivalPredictError {
175    /// Column-resolution failures from `resolve_role_col` / `resolve_col`
176    /// land as `InvalidInput` since they reflect a mismatch between the
177    /// caller-supplied predict frame and the model's expected schema.
178    fn from(err: gam_data::DataError) -> SurvivalPredictError {
179        SurvivalPredictError::InvalidInput {
180            reason: err.to_string(),
181        }
182    }
183}
184
185/// Inputs to the unified survival predict pipeline.
186pub struct SurvivalPredictRequest<'a> {
187    pub model: &'a SavedModel,
188    pub data: ArrayView2<'a, f64>,
189    pub col_map: &'a HashMap<String, usize>,
190    pub training_headers: Option<&'a Vec<String>>,
191    pub primary_offset: &'a Array1<f64>,
192    pub noise_offset: &'a Array1<f64>,
193    /// If `None`, every row is evaluated at its own `age_exit`. If
194    /// `Some(grid)`, every row is evaluated at every time in the grid.
195    pub time_grid: Option<&'a [f64]>,
196    /// When true, the result also carries delta-method standard errors
197    /// for the survival surface (response scale) and the linear
198    /// predictor.  Currently honored for `LocationScale` only; other
199    /// likelihood modes return `Err` rather than silently dropping
200    /// the request.
201    pub with_uncertainty: bool,
202}
203
204/// Result of [`predict_survival`].
205pub struct SurvivalPredictResult {
206    pub times: Vec<f64>,
207    pub hazard: Array2<f64>,
208    pub survival: Array2<f64>,
209    pub cumulative_hazard: Array2<f64>,
210    pub linear_predictor: Array1<f64>,
211    pub likelihood_mode: SurvivalLikelihoodMode,
212    /// Per-cell delta-method SE on the survival surface.  Same shape as
213    /// `survival`.  Populated only when the request set
214    /// `with_uncertainty = true` and the model class supports it.
215    pub survival_se: Option<Array2<f64>>,
216    /// Per-row delta-method SE on the linear predictor at the row's own
217    /// exit time.  Length `n`.  Populated under the same conditions as
218    /// `survival_se`.
219    pub eta_se: Option<Array1<f64>>,
220}
221
222/// Trapezoidal integral of a per-row survival curve `s(t)` sampled at the shared
223/// increasing `times` grid, restricted to `[0, tau]` — the restricted mean
224/// survival time (RMST) at horizon `tau`.
225///
226/// `RMST_i(tau) = \int_0^{tau} S_i(t) dt`. This is the standard clinical-trial
227/// survival summary (`survRM2`, lifelines `restricted_mean_survival_time`,
228/// flexsurv `rmst_*`): the area under the survival curve up to `tau`, equal to
229/// the mean of `min(T_i, tau)`. The curve is integrated with the trapezoid rule
230/// over the prediction grid; the head segment `[0, times[0]]` uses `S(0) = 1`
231/// (every subject is alive at the time origin), and when `tau` falls strictly
232/// inside a grid cell the survival value at `tau` is linearly interpolated so the
233/// partial cell contributes exactly. Grid points beyond `tau` are dropped.
234///
235/// Returns `None` when the grid is empty or `tau <= 0` (no area to accumulate),
236/// or when any sampled survival value on the integrated span is non-finite.
237fn restricted_mean_survival_time_from_curve(
238    times: &[f64],
239    survival_row: ndarray::ArrayView1<'_, f64>,
240    tau: f64,
241) -> Option<f64> {
242    if times.is_empty() || !(tau > 0.0) || !tau.is_finite() {
243        return None;
244    }
245    if times.len() != survival_row.len() {
246        return None;
247    }
248
249    // Survival at the cell boundaries we sweep through, starting from S(0) = 1.
250    let mut prev_t = 0.0_f64;
251    let mut prev_s = 1.0_f64;
252    let mut area = 0.0_f64;
253
254    for (idx, &t) in times.iter().enumerate() {
255        if !t.is_finite() || t < prev_t {
256            return None;
257        }
258        let s = survival_row[idx];
259        if !s.is_finite() {
260            return None;
261        }
262        if t >= tau {
263            // tau lands in (prev_t, t]; interpolate S(tau) and add the partial cell.
264            let span = t - prev_t;
265            let s_tau = if span > 0.0 {
266                let w = (tau - prev_t) / span;
267                prev_s + w * (s - prev_s)
268            } else {
269                prev_s
270            };
271            area += 0.5 * (prev_s + s_tau) * (tau - prev_t);
272            return Some(area);
273        }
274        area += 0.5 * (prev_s + s) * (t - prev_t);
275        prev_t = t;
276        prev_s = s;
277    }
278
279    // tau is beyond the last grid point: extend the last survival value flat to
280    // tau (conservative, matches survRM2's tau-at-or-before-last-event contract;
281    // callers wanting a strict horizon pass a tau within the grid).
282    area += prev_s * (tau - prev_t);
283    Some(area)
284}
285
286impl SurvivalPredictResult {
287    /// Per-row restricted mean survival time `\int_0^{tau} S_i(t) dt` from the
288    /// predicted survival surface. `tau` is the restriction horizon (e.g. the
289    /// study follow-up bound). Length-`n` vector, one RMST per predicted row.
290    ///
291    /// Returns `None` if the prediction grid is empty, `tau <= 0`, or any row's
292    /// survival curve carries a non-finite value on `[0, tau]`.
293    pub fn restricted_mean_survival_time(&self, tau: f64) -> Option<Array1<f64>> {
294        let n = self.survival.nrows();
295        let mut out = Array1::<f64>::zeros(n);
296        for i in 0..n {
297            let rmst =
298                restricted_mean_survival_time_from_curve(&self.times, self.survival.row(i), tau)?;
299            out[i] = rmst;
300        }
301        Some(out)
302    }
303}
304
305impl CompetingRisksPredictResult {
306    /// Per-row restricted mean survival time of the OVERALL (all-cause) survival
307    /// curve, `\int_0^{tau} S_overall_i(t) dt`. For competing risks the relevant
308    /// restricted-mean summary is taken on the all-cause survival
309    /// `exp(-sum_k H_k(t))`; cause-specific restricted-mean-time-lost is
310    /// `tau - RMST` partitioned by CIF and is left to the CIF surface directly.
311    pub fn restricted_mean_overall_survival_time(&self, tau: f64) -> Option<Array1<f64>> {
312        let n = self.overall_survival.nrows();
313        let mut out = Array1::<f64>::zeros(n);
314        for i in 0..n {
315            let rmst = restricted_mean_survival_time_from_curve(
316                &self.times,
317                self.overall_survival.row(i),
318                tau,
319            )?;
320            out[i] = rmst;
321        }
322        Some(out)
323    }
324}
325
/// Harrell's concordance index (C-index) of a survival risk score against
/// held-out outcomes.
///
/// A larger `risk[i]` must predict a SHORTER survival time (higher hazard).
/// A pair `(i, j)` is *comparable* when its failure ordering is observed:
///
/// * the observed times differ and the earlier one is an event; or
/// * the observed times tie and exactly one of the two is an event. The event
///   is taken to precede the censoring, since the censored subject was still
///   at risk at that time.
///
/// Some pairs are not comparable:
/// * tied times where both subjects failed (tied outcomes);
/// * tied times where both subjects were censored.
///
/// A comparable pair is concordant when the earlier-failing subject carries
/// the larger risk. Equal risks score half credit.
/// `C = (concordant + 0.5·tied_risk) / comparable`; `C = 0.5` is random
/// ranking and `C = 1.0` is a perfect ordering.
///
/// These pair rules match `survival::concordance`,
/// `lifelines.utils.concordance_index`, and scikit-survival
/// `concordance_index_censored`. An event is `event > 0.5`, the same threshold
/// used by [`ipcw_brier_score`] and [`KaplanMeier::fit`].
///
/// `time`, `event` (1 = event, 0 = censored), and `risk` must share length
/// `n`. Rows with a non-finite time or event, or a NaN risk, are skipped.
/// Returns `None` on a length mismatch or when no pair is comparable (e.g.
/// all rows censored). Runs in `O(n²)`.
pub fn harrell_concordance(time: &[f64], event: &[f64], risk: &[f64]) -> Option<f64> {
    let n = time.len();
    if n != event.len() || n != risk.len() {
        return None;
    }
    // A NaN time or risk would otherwise fall into the tie branch or score as
    // discordant, silently biasing C.
    let usable = |k: usize| time[k].is_finite() && event[k].is_finite() && !risk[k].is_nan();
    let is_event = |k: usize| event[k] > 0.5;

    let mut comparable = 0.0_f64;
    let mut concordant = 0.0_f64;
    for i in 0..n {
        if !usable(i) {
            continue;
        }
        for j in (i + 1)..n {
            if !usable(j) {
                continue;
            }
            let (early, late) = if time[i] < time[j] {
                (i, j)
            } else if time[j] < time[i] {
                (j, i)
            } else {
                match (is_event(i), is_event(j)) {
                    // Event tied with a censoring: the event happened first.
                    (true, false) => (i, j),
                    (false, true) => (j, i),
                    // Tied outcomes (both failed) or both censored: no
                    // observable ordering, so the pair is not comparable.
                    _ => continue,
                }
            };
            if !is_event(early) {
                // The earlier subject was censored: the true ordering is unknown.
                continue;
            }
            comparable += 1.0;
            if risk[early] > risk[late] {
                concordant += 1.0;
            } else if risk[early] == risk[late] {
                concordant += 0.5;
            }
        }
    }
    (comparable > 0.0).then(|| concordant / comparable)
}
377
/// IPCW (inverse-probability-of-censoring-weighted) Brier score of a predicted
/// survival probability at a fixed horizon `tau` against held-out outcomes.
/// This is the Graf et al. (1999) estimator used by scikit-survival
/// `brier_score`, `pec`, and `survival::brier`.
///
/// `s_pred[i]` is the model's predicted survival probability `S(tau | x_i)`.
/// `time`/`event` are the held-out observed time and event indicator, where
/// `event > 0.5` marks an event. `g_cens` is the censoring survival
/// `G(t) = P(C > t)`. It is called at the weighting time each subject needs,
/// e.g. `|t| km.at(t)` with `km = KaplanMeier::fit_censoring(time, event)`.
/// Each subject's squared residual `(target − Ŝ_i(τ))²` is reweighted by the
/// inverse censoring probability:
///   * event at/before `τ` (`T_i ≤ τ, δ_i = 1`) → target `0` (dead), weight `1/G(T_i)`;
///   * still alive past `τ` (`T_i > τ`)         → target `1` (alive), weight `1/G(τ)`;
///   * censored at/before `τ`                    → target undefined, contributes `0`.
///
/// The score is the **sample mean over all valid subjects** (Graf
/// normalization: divide by `n`, not by the sum of weights):
///   `BS(τ) = (1/n) Σ_i w_i·(target_i − Ŝ_i(τ))²`.
/// This is the convention scikit-survival, pec, and `survival::brier`
/// report, so values are directly comparable. Lower is better; `0` is perfect.
///
/// A subject is dropped from both numerator and denominator when its `time`
/// is non-finite or non-positive, or its `event` is non-finite. A zero
/// `event` is an ordinary censoring, not an invalid row. When `G` is not
/// strictly positive (including NaN) at a weighting time, the IPCW weight is
/// undefined. That subject stays in the denominator but contributes `0`
/// rather than `∞`, which keeps the estimator finite where the censoring KM
/// runs out of support.
///
/// Returns `None` when:
/// * the slice lengths differ;
/// * `tau` is NaN, which would otherwise route every subject to the "no
///   information" branch and report a perfect `0`;
/// * no subject is valid.
pub fn ipcw_brier_score(
    s_pred: &[f64],
    time: &[f64],
    event: &[f64],
    tau: f64,
    g_cens: impl Fn(f64) -> f64,
) -> Option<f64> {
    let n = s_pred.len();
    if n != time.len() || n != event.len() || tau.is_nan() {
        return None;
    }
    let mut n_valid = 0.0_f64;
    let mut acc = 0.0_f64;
    for ((&s_hat, &t_i), &e_i) in s_pred.iter().zip(time).zip(event) {
        if !t_i.is_finite() || !e_i.is_finite() || t_i <= 0.0 {
            continue;
        }
        // Every valid subject counts toward the Graf denominator, even when its
        // IPCW contribution is zero (censored before τ, or G undefined).
        n_valid += 1.0;
        let (target, g) = if t_i <= tau && e_i > 0.5 {
            // Failed at or before the horizon: weighted by 1/G(T_i).
            (0.0, g_cens(t_i))
        } else if t_i > tau {
            // Survived past the horizon: weighted by 1/G(τ).
            (1.0, g_cens(tau))
        } else {
            // Censored at or before τ: status at τ is unknown.
            continue;
        };
        if !(g > 0.0) {
            // G = 0 (or NaN): the IPCW weight is undefined; contribute 0.
            continue;
        }
        let weight = 1.0 / g;
        let resid = target - s_hat;
        acc += weight * resid * resid;
    }
    (n_valid > 0.0).then(|| acc / n_valid)
}
452
453/// Integrated IPCW Brier score (IBS) — the time-integrated [`ipcw_brier_score`],
454/// matching scikit-survival's `integrated_brier_score` and `pec`'s integrated
455/// prediction-error curve.
456///
457/// `s_pred` is the `n × m` matrix of predicted survival probabilities whose
458/// column `k` is `Ŝ_i(grid[k])`; `grid` is the strictly-increasing set of
459/// evaluation times. The per-time Graf Brier `BS(grid[k])` is integrated by the
460/// trapezoidal rule over the grid and normalized by the integration span:
461///   `IBS = (1 / (t_max − t_min)) ∫_{t_min}^{t_max} BS(t) dt`.
462///
463/// `g_cens` is the censoring survival `G(t) = P(C > t)` (see [`KaplanMeier`]).
464/// Integration is restricted to grid points within `[grid[0], horizon]`; pass
465/// `horizon = f64::INFINITY` to integrate the full grid. Restricting to the
466/// observed support is the standard guard against the extrapolation tail where
467/// no subject remains at risk and the IPCW weights become unstable.
468///
469/// Returns `None` if the grid is malformed (fewer than two usable points, wrong
470/// width, non-increasing) or every per-time Brier is undefined.
471pub fn integrated_ipcw_brier_score(
472    s_pred: ArrayView2<f64>,
473    time: &[f64],
474    event: &[f64],
475    grid: &[f64],
476    horizon: f64,
477    g_cens: impl Fn(f64) -> f64,
478) -> Option<f64> {
479    let m = grid.len();
480    if m < 2 || s_pred.ncols() != m || s_pred.nrows() != time.len() {
481        return None;
482    }
483    if grid.windows(2).any(|pair| !(pair[1] > pair[0])) {
484        return None;
485    }
486    // Collect (time, Brier) at every grid point inside the integration window.
487    let mut pts: Vec<(f64, f64)> = Vec::with_capacity(m);
488    for k in 0..m {
489        if grid[k] > horizon {
490            break;
491        }
492        let col = s_pred.column(k);
493        let col_slice: Vec<f64> = col.to_vec();
494        if let Some(bs) = ipcw_brier_score(&col_slice, time, event, grid[k], &g_cens) {
495            pts.push((grid[k], bs));
496        }
497    }
498    if pts.len() < 2 {
499        return None;
500    }
501    let span = pts[pts.len() - 1].0 - pts[0].0;
502    if !(span > 0.0) {
503        return None;
504    }
505    let mut integral = 0.0_f64;
506    for w in pts.windows(2) {
507        integral += 0.5 * (w[1].1 + w[0].1) * (w[1].0 - w[0].0);
508    }
509    Some(integral / span)
510}
511
/// Right-continuous Kaplan–Meier survival estimator `Ŝ(t) = ∏_{t_j ≤ t}(1 − d_j/n_j)`.
///
/// Built from observed `(time, event)` pairs. To estimate the **censoring**
/// survival `G(t) = P(C > t)` required by the IPCW Brier score, fit with the
/// event indicator flipped (`1 − event`) so that censorings are the "events"
/// of the reversed process. See [`KaplanMeier::fit_censoring`].
#[derive(Clone, Debug, Default)]
pub struct KaplanMeier {
    /// `(event_time, survival_after_that_time)`, strictly increasing in time.
    /// Only times with at least one event produce a step.
    steps: Vec<(f64, f64)>,
}

impl KaplanMeier {
    /// Fit the survival of the process whose event indicator is `event > 0.5`.
    ///
    /// Rows with a non-finite time or event, or a non-positive time, are
    /// ignored. `time` and `event` are paired positionally; if their lengths
    /// differ, the surplus entries of the longer slice are ignored.
    pub fn fit(time: &[f64], event: &[f64]) -> Self {
        let mut rows: Vec<(f64, bool)> = time
            .iter()
            .zip(event.iter())
            .filter_map(|(&t, &e)| {
                (t.is_finite() && e.is_finite() && t > 0.0).then_some((t, e > 0.5))
            })
            .collect();
        rows.sort_by(|a, b| a.0.total_cmp(&b.0));
        let mut steps = Vec::new();
        let mut at_risk = rows.len() as f64;
        let mut survival = 1.0_f64;
        let mut i = 0usize;
        while i < rows.len() {
            let t = rows[i].0;
            // Gather every row sharing time `t`; events among them are the deaths d_j.
            let mut j = i;
            let mut deaths = 0usize;
            while j < rows.len() && rows[j].0 == t {
                deaths += usize::from(rows[j].1);
                j += 1;
            }
            if deaths > 0 && at_risk > 0.0 {
                survival *= ((at_risk - deaths as f64) / at_risk).max(0.0);
                steps.push((t, survival));
            }
            // Everyone observed at `t` (events and censorings) leaves the risk set.
            at_risk -= (j - i) as f64;
            i = j;
        }
        Self { steps }
    }

    /// Fit the censoring survival `G(t) = P(C > t)` by reversing the event
    /// role. A censored observation (`event ≤ 0.5`) is an "event" of the
    /// censoring process, and a death (`event > 0.5`) is a censoring of it.
    ///
    /// Rows whose event indicator is non-finite are dropped, exactly as in
    /// [`KaplanMeier::fit`]. They are not recoded as censoring events.
    pub fn fit_censoring(time: &[f64], event: &[f64]) -> Self {
        let flipped: Vec<f64> = event
            .iter()
            .map(|&e| {
                if !e.is_finite() {
                    // Keep the row invalid so `fit` filters it out.
                    f64::NAN
                } else if e > 0.5 {
                    0.0
                } else {
                    1.0
                }
            })
            .collect();
        Self::fit(time, &flipped)
    }

    /// Right-continuous step lookup: `Ŝ(t)` is the survival at the last event
    /// time `≤ t`, or `1.0` before the first event (and for a NaN `t`).
    pub fn at(&self, t: f64) -> f64 {
        // `steps` is strictly increasing in time, so `step_t <= t` holds on a
        // prefix and a binary search finds its end.
        let after = self.steps.partition_point(|&(step_t, _)| step_t <= t);
        after.checked_sub(1).map_or(1.0, |last| self.steps[last].1)
    }
}
582
583/// Joint cause-specific competing-risks prediction result.
584pub struct CompetingRisksPredictResult {
585    pub times: Vec<f64>,
586    pub endpoint_names: Vec<String>,
587    /// Cause-specific instantaneous hazards, shaped endpoint x row x time.
588    pub hazard: Vec<Array2<f64>>,
589    /// Endpoint-specific survival surfaces exp(-H_k(t)), endpoint x row x time.
590    pub survival: Vec<Array2<f64>>,
591    /// Cause-specific cumulative hazards, endpoint x row x time.
592    pub cumulative_hazard: Vec<Array2<f64>>,
593    /// Aalen-Johansen cumulative incidence, endpoint x row x time.
594    pub cif: Vec<Array2<f64>>,
595    /// Overall survival exp(-sum_k H_k(t)), row x time.
596    pub overall_survival: Array2<f64>,
597    /// Per-endpoint linear predictor at each row's own exit time, endpoint x row.
598    pub linear_predictor: Vec<Array1<f64>>,
599    pub likelihood_mode: SurvivalLikelihoodMode,
600}
601
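/// Run the survival prediction pipeline.
///
/// Pure library function: no progress bars and no file I/O. Delta-method
/// uncertainty is returned only for location-scale models when the request
/// sets `with_uncertainty`; other likelihood modes reject that flag with an
/// error. The CLI wraps this with progress updates + CSV writes; the FFI
/// wraps it with JSON serialization.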
607pub fn predict_survival(
608    req: SurvivalPredictRequest<'_>,
609) -> Result<SurvivalPredictResult, SurvivalPredictError> {
610    let SurvivalPredictRequest {
611        model,
612        data,
613        col_map,
614        training_headers,
615        primary_offset,
616        noise_offset,
617        time_grid,
618        with_uncertainty,
619    } = req;
620
621    // `survival_entry == None` is the right-censored shorthand
622    // `Surv(time, event)` produced by `gam fit` / `gamfit.fit`: no entry
623    // column was supplied at training time, so entry ages default to
624    // zero at prediction time too. The CLI's `run_predict_survival`
625    // applies the same fallback; mirroring it here keeps `gam predict`,
626    // `gam sample`, and the Python `model.predict` FFI symmetric across
627    // every likelihood that lands in this code path (weibull,
628    // transformation, ...).
629    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
630    let exit_col = time_cols.exit_col;
631
632    let termspec = resolve_termspec_for_prediction(
633        &model.resolved_termspec,
634        training_headers,
635        col_map,
636        "resolved_termspec",
637    )?;
638    // Clip continuous covariate columns to the training range before basis
639    // assembly so polyharmonic / spline terms cannot extrapolate outside the
640    // data envelope. Times (`entry_col` / `exit_col`) are read from the
641    // original `data` view further down so the hazard integration stays on
642    // the raw timestamps the user supplied.
643    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
644    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
645    let cov_design = build_term_collection_design(cov_input, &termspec)
646        .map_err(|e| format!("failed to build survival prediction design: {e}"))?;
647
648    let n = data.nrows();
649    if primary_offset.len() != n || noise_offset.len() != n {
650        return Err(SurvivalPredictError::InvalidInput {
651            reason: format!(
652                "survival prediction offset length mismatch: rows={n}, offset={}, noise_offset={}",
653                primary_offset.len(),
654                noise_offset.len()
655            ),
656        });
657    }
658
659    use rayon::iter::{IntoParallelIterator, ParallelIterator};
660    let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
661        .into_par_iter()
662        .map(|i| {
663            normalize_survival_time_pair(time_cols.row_entry_time(data, i), data[[i, exit_col]], i)
664        })
665        .collect();
666    let pairs = pairs?;
667    let mut age_entry = Array1::<f64>::zeros(n);
668    let mut age_exit = Array1::<f64>::zeros(n);
669    for (i, (t0, t1)) in pairs.into_iter().enumerate() {
670        age_entry[i] = t0;
671        age_exit[i] = t1;
672    }
673
674    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
675
676    // Latent modes emit binary event-window probabilities, not survival
677    // curves. The CLI's `run_predict_saved_latent_*` helpers wrap them with
678    // window quadrature + uncertainty pipelines that aren't ported yet.
679    if matches!(
680        saved_likelihood_mode,
681        SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
682    ) {
683        return Err(SurvivalPredictError::UnsupportedConfiguration {
684            reason: format!(
685                "survival prediction via predict_survival does not support likelihood_mode={} yet; \
686             latent window prediction lives in the CLI's run_predict_saved_latent_window_impl \
687             pipeline and has not yet been ported to the library. Use the CLI predict command.",
688                survival_likelihood_modename(saved_likelihood_mode)
689            ),
690        });
691    }
692    // Location-scale: handled via a dedicated batch path that calls
693    // `predict_survival_location_scale` directly.
694    if saved_likelihood_mode == SurvivalLikelihoodMode::LocationScale {
695        return predict_survival_location_scale_batch(
696            model,
697            &age_entry,
698            &age_exit,
699            &cov_design,
700            primary_offset,
701            noise_offset,
702            training_headers,
703            col_map,
704            data,
705            time_grid,
706            with_uncertainty,
707        )
708        .map_err(SurvivalPredictError::from);
709    }
710    if with_uncertainty {
711        return Err(SurvivalPredictError::from(format!(
712            "predict_survival: with_uncertainty is currently supported only for the \
713             location-scale likelihood mode; got {}",
714            survival_likelihood_modename(saved_likelihood_mode)
715        )));
716    }
717
718    // Ambient time basis: built once with (age_entry, age_exit) so that
719    // the saved anchor / monotonicity checks fire at construction time.
720    let time_cfg = load_survival_time_basis_config_from_model(model)?;
721    let mut time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
722    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
723        &time_build.basisname,
724        time_build.degree,
725        time_build.knots.as_ref(),
726        time_build.keep_cols.as_ref(),
727        time_build.smooth_lambda,
728    )?;
729    // Single-cause Weibull without a learned baseline timewiggle carries its
730    // ENTIRE log-cumulative-hazard baseline in the fitted `[1, log t]` linear
731    // time-basis coefficients, not in a parametric offset. The fit centers that
732    // basis at the survival time anchor (`center_survival_time_designs_at_anchor`
733    // in the workflow), which zeroes the constant column so `beta[0]` is
734    // unidentified and the fitted baseline is exactly
735    // `beta[1] * (log t - log anchor)`. The model still SAVES a `Weibull`
736    // baseline target (recovered scale/shape) for CIF/reporting, but that
737    // metadata must NOT re-enter prediction as a parametric offset: doing so
738    // double-counts the baseline (offset + beta) and, combined with predicting
739    // against the UN-centered basis, collapses the survival surface to the
740    // degenerate `S(t) == 1` (issue #897). Mirror the fit here: center the basis
741    // at the anchor and carry a zero baseline offset, so predict reproduces the
742    // fitted `beta[1] * (log t - log anchor)`. Weibull-WITH-timewiggle is a
743    // different regime (the parametric offset is the baseline and beta carries
744    // only the wiggle deviation), so it is excluded.
745    let weibull_baseline_in_beta = saved_likelihood_mode == SurvivalLikelihoodMode::Weibull
746        && !model.has_baseline_time_wiggle();
747    let mut time_anchor: Option<f64> = None;
748    let mut time_anchor_row_cached: Option<Array1<f64>> = None;
749    if matches!(
750        saved_likelihood_mode,
751        SurvivalLikelihoodMode::LocationScale | SurvivalLikelihoodMode::MarginalSlope
752    ) || weibull_baseline_in_beta
753    {
754        let anchor = model
755            .survival_time_anchor
756            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
757        let time_anchor_row = evaluate_survival_time_basis_row(anchor, &resolved_time_cfg)?;
758        center_survival_time_designs_at_anchor(
759            &mut time_build.x_entry_time,
760            &mut time_build.x_exit_time,
761            &time_anchor_row,
762        )?;
763        time_anchor = Some(anchor);
764        time_anchor_row_cached = Some(time_anchor_row);
765    }
766    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull && !model.has_baseline_time_wiggle()
767    {
768        require_structural_survival_time_basis(&time_build.basisname, "saved survival sampling")?;
769    }
770    let mut baseline_cfg = saved_survival_runtime_baseline_config(model)?;
771    if weibull_baseline_in_beta {
772        baseline_cfg = SurvivalBaselineConfig {
773            target: SurvivalBaselineTarget::Linear,
774            scale: None,
775            shape: None,
776            rate: None,
777            makeham: None,
778        };
779    }
780
781    // Resolve the time-grid: either the explicit grid (same for every
782    // row) or per-row exit times (one column per row).
783    let per_row_eval = time_grid.is_none();
784    let eval_times: Vec<f64> = match time_grid {
785        Some(grid) => {
786            if grid.is_empty() {
787                return Err(SurvivalPredictError::InvalidInput {
788                    reason: "survival time_grid must contain at least one time".to_string(),
789                });
790            }
791            for (idx, &t) in grid.iter().enumerate() {
792                if !t.is_finite() || t < 0.0 {
793                    return Err(SurvivalPredictError::InvalidInput {
794                        reason: format!(
795                            "survival time_grid requires finite non-negative times (index {idx})",
796                        ),
797                    });
798                }
799            }
800            grid.to_vec()
801        }
802        None => Vec::new(),
803    };
804
805    let t_cols = if per_row_eval { 1 } else { eval_times.len() };
806    let mut hazard = Array2::<f64>::zeros((n, t_cols));
807    let mut survival = Array2::<f64>::zeros((n, t_cols));
808    let mut cumulative_hazard = Array2::<f64>::zeros((n, t_cols));
809    let mut linear_predictor = Array1::<f64>::zeros(n);
810
811    // For marginal-slope, build the saved predictor (with link-deviation +
812    // score-warp blocks plumbed in) once. The per-(row, t) loop reuses this
813    // predictor and only assembles the per-cell q-design slice. Without this,
814    // the library skipped link-deviation and score-warp replay entirely and
815    // disagreed with the CLI's `gam predict` on every flex model.
816    let marginal_slope_ctx = if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
817        // Baseline offsets at the predict-data's age_entry / age_exit. Used to
818        // build the predictor's `pred_input` (which we discard) — the actual
819        // per-(row, t) offset is rebuilt inside `evaluate_marginal_slope_row`.
820        let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
821            build_survival_time_offsets_for_likelihood(
822                &age_entry,
823                &age_exit,
824                &baseline_cfg,
825                saved_likelihood_mode,
826                None,
827            )?;
828        add_survival_time_derivative_guard_offset(
829            &age_entry,
830            &age_exit,
831            time_anchor.ok_or_else(|| {
832                "saved survival marginal-slope model missing survival_time_anchor".to_string()
833            })?,
834            survival_derivative_guard_for_likelihood(saved_likelihood_mode),
835            &mut eta_offset_entry,
836            &mut eta_offset_exit,
837            &mut derivative_offset_exit,
838        )?;
839        Some(build_marginal_slope_predict_context(
840            model,
841            data,
842            col_map,
843            training_headers,
844            &cov_design.design,
845            primary_offset,
846            noise_offset,
847            &time_build,
848            &eta_offset_entry,
849            &eta_offset_exit,
850            &derivative_offset_exit,
851        )?)
852    } else {
853        None
854    };
855
856    // Evaluate each row independently.  For an explicit time grid, each worker
857    // reuses the row's covariate slice across all grid times and returns a
858    // complete row, avoiding synchronized writes into the output matrices.
859    struct SurvivalPredictionRow {
860        hazard: Vec<f64>,
861        survival: Vec<f64>,
862        cumulative_hazard: Vec<f64>,
863        linear_predictor: f64,
864    }
865
866    let row_results: Result<Vec<SurvivalPredictionRow>, SurvivalPredictError> = (0..n)
867        .into_par_iter()
868        .map(|i| {
869            let cov_row = if matches!(
870                saved_likelihood_mode,
871                SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
872            ) {
873                Some(design_row_owned(
874                    &cov_design.design,
875                    i,
876                    "survival predict covariate row",
877                )?)
878            } else {
879                None
880            };
881            let evaluate_at = |t_query: f64| -> Result<(f64, f64, f64), SurvivalPredictError> {
882                let t_entry = age_entry[i].min(t_query);
883                let single_entry = Array1::from_elem(1, t_entry);
884                let single_exit = Array1::from_elem(1, t_query);
885                let mut row_time =
886                    build_survival_time_basis(&single_entry, &single_exit, time_cfg.clone(), None)?;
887                if let Some(anchor_row) = time_anchor_row_cached.as_ref() {
888                    center_survival_time_designs_at_anchor(
889                        &mut row_time.x_entry_time,
890                        &mut row_time.x_exit_time,
891                        anchor_row,
892                    )?;
893                }
894                let (mut r_eta_entry, mut r_eta_exit, mut r_deriv_exit) =
895                    build_survival_time_offsets_for_likelihood(
896                        &single_entry,
897                        &single_exit,
898                        &baseline_cfg,
899                        saved_likelihood_mode,
900                        None,
901                    )?;
902                if saved_likelihood_mode == SurvivalLikelihoodMode::MarginalSlope {
903                    add_survival_time_derivative_guard_offset(
904                        &single_entry,
905                        &single_exit,
906                        time_anchor.ok_or_else(|| {
907                            "saved survival marginal-slope model missing survival_time_anchor"
908                                .to_string()
909                        })?,
910                        survival_derivative_guard_for_likelihood(saved_likelihood_mode),
911                        &mut r_eta_entry,
912                        &mut r_eta_exit,
913                        &mut r_deriv_exit,
914                    )?;
915                }
916
917                match saved_likelihood_mode {
918                    SurvivalLikelihoodMode::MarginalSlope => {
919                        let ctx = marginal_slope_ctx.as_ref().ok_or_else(|| {
920                            "internal error: marginal-slope context missing for marginal-slope mode"
921                                .to_string()
922                        })?;
923                        evaluate_marginal_slope_row(
924                            i,
925                            ctx,
926                            &row_time,
927                            &r_eta_exit,
928                            &r_deriv_exit,
929                            primary_offset[i],
930                        )
931                    }
932                    SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
933                        let cov_row = cov_row.as_ref().ok_or_else(|| {
934                            "internal error: covariate row missing for Royston-Parmar prediction"
935                                .to_string()
936                        })?;
937                        evaluate_rp_row(
938                            model,
939                            &row_time,
940                            cov_row,
941                            r_eta_exit[0],
942                            r_deriv_exit[0],
943                            primary_offset[i],
944                        )
945                    }
946                    SurvivalLikelihoodMode::Latent
947                    | SurvivalLikelihoodMode::LatentBinary
948                    | SurvivalLikelihoodMode::LocationScale => {
949                        Err(SurvivalPredictError::NumericalFailure {
950                            reason: "unreachable: unsupported likelihood_mode filtered earlier"
951                                .to_string(),
952                        })
953                    }
954                }
955            };
956
957            let mut row = SurvivalPredictionRow {
958                hazard: vec![0.0; t_cols],
959                survival: vec![0.0; t_cols],
960                cumulative_hazard: vec![0.0; t_cols],
961                linear_predictor: 0.0,
962            };
963            if per_row_eval {
964                let (eta_t, cum_t, haz_t) = evaluate_at(age_exit[i])?;
965                row.linear_predictor = eta_t;
966                row.hazard[0] = haz_t;
967                row.cumulative_hazard[0] = cum_t;
968                row.survival[0] = (-cum_t).exp().clamp(0.0, 1.0);
969            } else {
970                for (j, &t_query) in eval_times.iter().enumerate() {
971                    if t_query <= 0.0 {
972                        row.hazard[j] = 0.0;
973                        row.cumulative_hazard[j] = 0.0;
974                        row.survival[j] = 1.0;
975                    } else {
976                        let (_eta_t, cum_t, haz_t) = evaluate_at(t_query)?;
977                        row.hazard[j] = haz_t;
978                        row.cumulative_hazard[j] = cum_t;
979                        row.survival[j] = (-cum_t).exp().clamp(0.0, 1.0);
980                    }
981                }
982                let (eta_t, _, _) = evaluate_at(age_exit[i])?;
983                row.linear_predictor = eta_t;
984            }
985            Ok(row)
986        })
987        .collect();
988
989    for (i, row) in row_results?.into_iter().enumerate() {
990        linear_predictor[i] = row.linear_predictor;
991        for j in 0..t_cols {
992            hazard[[i, j]] = row.hazard[j];
993            cumulative_hazard[[i, j]] = row.cumulative_hazard[j];
994            survival[[i, j]] = row.survival[j];
995        }
996    }
997
998    let times_out: Vec<f64> = if per_row_eval {
999        age_exit.to_vec()
1000    } else {
1001        eval_times
1002    };
1003
1004    Ok(SurvivalPredictResult {
1005        times: times_out,
1006        hazard,
1007        survival,
1008        cumulative_hazard,
1009        linear_predictor,
1010        likelihood_mode: saved_likelihood_mode,
1011        survival_se: None,
1012        eta_se: None,
1013    })
1014}
1015
1016pub fn predict_competing_risks_survival(
1017    req: SurvivalPredictRequest<'_>,
1018) -> Result<CompetingRisksPredictResult, SurvivalPredictError> {
1019    let SurvivalPredictRequest {
1020        model,
1021        data,
1022        col_map,
1023        training_headers,
1024        primary_offset,
1025        noise_offset,
1026        time_grid,
1027        with_uncertainty,
1028    } = req;
1029
1030    if with_uncertainty {
1031        return Err(SurvivalPredictError::UnsupportedConfiguration {
1032            reason: "competing-risks survival prediction does not yet support with_uncertainty"
1033                .to_string(),
1034        });
1035    }
1036
1037    let saved_likelihood_mode = require_saved_survival_likelihood_mode(model)?;
1038    if !matches!(
1039        saved_likelihood_mode,
1040        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
1041    ) {
1042        return Err(SurvivalPredictError::UnsupportedConfiguration {
1043            reason: format!(
1044                "joint cause-specific prediction supports transformation/weibull survival only; got {}",
1045                survival_likelihood_modename(saved_likelihood_mode)
1046            ),
1047        });
1048    }
1049
1050    let fit = fit_result_from_saved_model_for_prediction(model)?;
1051    let cause_count = model
1052        .survival_cause_count
1053        .unwrap_or(fit.blocks.len())
1054        .max(1);
1055    if cause_count <= 1 {
1056        return Err(SurvivalPredictError::MissingFitMetadata {
1057            reason: "competing-risks survival prediction requires a saved model with at least two causes"
1058                .to_string(),
1059        });
1060    }
1061    if fit.blocks.len() != cause_count {
1062        return Err(SurvivalPredictError::IncompatibleSchema {
1063            reason: format!(
1064                "saved competing-risks survival fit has {} coefficient blocks but metadata says {cause_count} causes",
1065                fit.blocks.len()
1066            ),
1067        });
1068    }
1069    let endpoint_names = model.survival_endpoint_names.clone().unwrap_or_else(|| {
1070        (1..=cause_count)
1071            .map(|idx| format!("cause_{idx}"))
1072            .collect()
1073    });
1074    if endpoint_names.len() != cause_count {
1075        return Err(SurvivalPredictError::IncompatibleSchema {
1076            reason: format!(
1077                "saved competing-risks survival endpoint_names has length {}, expected {cause_count}",
1078                endpoint_names.len()
1079            ),
1080        });
1081    }
1082
1083    // Right-censored shorthand: same fallback as the single-cause path
1084    // above — entry ages default to zero when the model was fit without
1085    // an explicit entry column.
1086    let time_cols = resolve_saved_survival_time_columns(model, col_map)?;
1087    let exit_col = time_cols.exit_col;
1088
1089    let termspec = resolve_termspec_for_prediction(
1090        &model.resolved_termspec,
1091        training_headers,
1092        col_map,
1093        "resolved_termspec",
1094    )?;
1095    let cov_clipped = model.axis_clip_to_training_ranges(data, col_map);
1096    let cov_input = cov_clipped.as_ref().map_or(data, |arr| arr.view());
1097    let cov_design = build_term_collection_design(cov_input, &termspec)
1098        .map_err(|e| format!("failed to build competing-risks prediction design: {e}"))?;
1099
1100    let n = data.nrows();
1101    if primary_offset.len() != n || noise_offset.len() != n {
1102        return Err(SurvivalPredictError::InvalidInput {
1103            reason: format!(
1104                "competing-risks prediction offset length mismatch: rows={n}, offset={}, noise_offset={}",
1105                primary_offset.len(),
1106                noise_offset.len()
1107            ),
1108        });
1109    }
1110
1111    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1112    let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
1113        .into_par_iter()
1114        .map(|i| {
1115            normalize_survival_time_pair(time_cols.row_entry_time(data, i), data[[i, exit_col]], i)
1116        })
1117        .collect();
1118    let pairs = pairs?;
1119    let mut age_entry = Array1::<f64>::zeros(n);
1120    let mut age_exit = Array1::<f64>::zeros(n);
1121    for (i, (t0, t1)) in pairs.into_iter().enumerate() {
1122        age_entry[i] = t0;
1123        age_exit[i] = t1;
1124    }
1125
1126    let time_cfg = load_survival_time_basis_config_from_model(model)?;
1127    let time_build = build_survival_time_basis(&age_entry, &age_exit, time_cfg.clone(), None)?;
1128    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
1129        &time_build.basisname,
1130        time_build.degree,
1131        time_build.knots.as_ref(),
1132        time_build.keep_cols.as_ref(),
1133        time_build.smooth_lambda,
1134    )?;
1135    // See the single-cause `predict_survival` note: per-cause Weibull baselines
1136    // (no learned timewiggle) live in the anchor-centered linear time-basis
1137    // coefficients, so prediction must center the basis at the saved anchor and
1138    // carry a zero parametric baseline offset rather than re-adding the saved
1139    // (reporting-only) `Weibull` target as an offset (issues #897 / #689 / #690).
1140    // The ambient `time_build` is consumed only for the structural-basis check;
1141    // the per-(cause, row) loop rebuilds and centers its own `row_time`, so the
1142    // anchor row is all that needs threading through.
1143    let weibull_baseline_in_beta = saved_likelihood_mode == SurvivalLikelihoodMode::Weibull
1144        && !model.has_baseline_time_wiggle();
1145    let cr_time_anchor_row: Option<Array1<f64>> = if weibull_baseline_in_beta {
1146        let anchor = model
1147            .survival_time_anchor
1148            .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
1149        Some(evaluate_survival_time_basis_row(
1150            anchor,
1151            &resolved_time_cfg,
1152        )?)
1153    } else {
1154        None
1155    };
1156    if saved_likelihood_mode != SurvivalLikelihoodMode::Weibull && !model.has_baseline_time_wiggle()
1157    {
1158        require_structural_survival_time_basis(
1159            &time_build.basisname,
1160            "saved competing-risks survival prediction",
1161        )?;
1162    }
1163    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
1164
1165    let per_row_eval = time_grid.is_none();
1166    let eval_times: Vec<f64> = match time_grid {
1167        Some(grid) => {
1168            if grid.is_empty() {
1169                return Err(SurvivalPredictError::InvalidInput {
1170                    reason: "survival time_grid must contain at least one time".to_string(),
1171                });
1172            }
1173            for (idx, &t) in grid.iter().enumerate() {
1174                if !t.is_finite() || t < 0.0 {
1175                    return Err(SurvivalPredictError::InvalidInput {
1176                        reason: format!(
1177                            "survival time_grid requires finite non-negative times (index {idx})",
1178                        ),
1179                    });
1180                }
1181            }
1182            grid.to_vec()
1183        }
1184        None => Vec::new(),
1185    };
1186    let t_cols = if per_row_eval { 1 } else { eval_times.len() };
1187
1188    // Refined internal grid for the Aalen-Johansen CIF assembly (gam#1385).
1189    //
1190    // The discrete AJ increment ΔF_k = S(t_{j-1})·(1−exp(−ΔH_total))·ΔH_k/ΔH_total
1191    // assumes the cause-specific hazard *ratio* h_k/h_total is constant within
1192    // each interval. On a coarse user grid with differently-shaped competing
1193    // hazards that assumption is violated, making the returned CIF a function of
1194    // the requested grid resolution (up to ~22% pointwise error) rather than a
1195    // pure function of the query time. We assemble AJ on a refined grid (extra
1196    // points inserted from 0 to the first user time and between consecutive user
1197    // times — cause-specific cumulative hazards are cheap closed-form
1198    // evaluate_at calls) and then read CIF/overall-survival back at the user's
1199    // requested times. The per-cause hazard/survival/cumulative_hazard returned
1200    // to the caller stay on the user grid (those are pointwise and already
1201    // grid-independent); only the AJ assembly uses the refinement.
1202    //
1203    // `refined_times` is strictly increasing and is a superset of `eval_times`;
1204    // `user_time_to_refined_index[j]` is the position of the j-th user time
1205    // inside `refined_times`. Per-row eval keeps its single-time anchor path.
1206    const CIF_REFINE_SUBINTERVALS: usize = 32;
1207    let (refined_times, user_time_to_refined_index): (Vec<f64>, Vec<usize>) = if per_row_eval {
1208        (Vec::new(), Vec::new())
1209    } else {
1210        let mut refined: Vec<f64> = Vec::new();
1211        let mut user_index: Vec<usize> = Vec::with_capacity(eval_times.len());
1212        let mut prev = 0.0_f64;
1213        for &t_user in &eval_times {
1214            // Insert CIF_REFINE_SUBINTERVALS-1 strictly-interior points in
1215            // (prev, t_user], landing exactly on t_user as the last point. Skip
1216            // the interior fill for a zero-length gap (duplicate / origin user
1217            // time) so `refined` stays strictly increasing.
1218            let gap = t_user - prev;
1219            if gap > 0.0 {
1220                for s in 1..CIF_REFINE_SUBINTERVALS {
1221                    let t_mid = prev + gap * (s as f64) / (CIF_REFINE_SUBINTERVALS as f64);
1222                    // Guard against ties from floating-point rounding.
1223                    if refined.last().is_none_or(|&last| t_mid > last) {
1224                        refined.push(t_mid);
1225                    }
1226                }
1227            }
1228            if refined.last().is_none_or(|&last| t_user > last) {
1229                refined.push(t_user);
1230            }
1231            user_index.push(refined.len() - 1);
1232            prev = t_user;
1233        }
1234        (refined, user_index)
1235    };
1236    let refined_cols = refined_times.len();
1237
1238    let saved_timewiggle_by_cause = saved_cause_specific_timewiggles(model, &fit, cause_count)?;
1239    let cov_rows = (0..n)
1240        .map(|i| design_row_owned(&cov_design.design, i, "competing-risks covariate row"))
1241        .collect::<Result<Vec<_>, _>>()?;
1242
1243    let mut hazard = (0..cause_count)
1244        .map(|_| Array2::<f64>::zeros((n, t_cols)))
1245        .collect::<Vec<_>>();
1246    let mut survival = (0..cause_count)
1247        .map(|_| Array2::<f64>::zeros((n, t_cols)))
1248        .collect::<Vec<_>>();
1249    let mut cumulative_hazard = (0..cause_count)
1250        .map(|_| Array2::<f64>::zeros((n, t_cols)))
1251        .collect::<Vec<_>>();
1252    // Cause-specific cumulative hazards on the refined AJ grid (gam#1385);
1253    // unused (zero-width) on the per-row-eval path.
1254    let mut cumulative_hazard_refined = (0..cause_count)
1255        .map(|_| Array2::<f64>::zeros((n, refined_cols)))
1256        .collect::<Vec<_>>();
1257    let mut linear_predictor = (0..cause_count)
1258        .map(|_| Array1::<f64>::zeros(n))
1259        .collect::<Vec<_>>();
1260
1261    struct CauseRow {
1262        cause: usize,
1263        row: usize,
1264        hazard: Vec<f64>,
1265        survival: Vec<f64>,
1266        cumulative: Vec<f64>,
1267        /// Cumulative hazard on the refined AJ grid (gam#1385); empty on the
1268        /// per-row-eval path.
1269        cumulative_refined: Vec<f64>,
1270        eta_exit: f64,
1271    }
1272
1273    let rows: Result<Vec<CauseRow>, SurvivalPredictError> = (0..cause_count * n)
1274        .into_par_iter()
1275        .map(|flat| {
1276            let cause = flat / n;
1277            let i = flat % n;
1278            let block = &fit.blocks[cause];
1279            let timewiggle = saved_timewiggle_by_cause[cause].as_ref();
1280            let evaluate_at = |t_query: f64| -> Result<(f64, f64, f64), SurvivalPredictError> {
1281                let t_entry = age_entry[i].min(t_query);
1282                let single_entry = Array1::from_elem(1, t_entry);
1283                let single_exit = Array1::from_elem(1, t_query);
1284                let mut row_time =
1285                    build_survival_time_basis(&single_entry, &single_exit, time_cfg.clone(), None)?;
1286                if let Some(anchor_row) = cr_time_anchor_row.as_ref() {
1287                    center_survival_time_designs_at_anchor(
1288                        &mut row_time.x_entry_time,
1289                        &mut row_time.x_exit_time,
1290                        anchor_row,
1291                    )?;
1292                }
1293                let (r_eta_exit, r_deriv_exit) = if weibull_baseline_in_beta {
1294                    (0.0, 0.0)
1295                } else {
1296                    let (_, eta_exit, deriv_exit) = build_survival_time_offsets_for_likelihood(
1297                        &single_entry,
1298                        &single_exit,
1299                        &baseline_cfg,
1300                        saved_likelihood_mode,
1301                        None,
1302                    )?;
1303                    (eta_exit[0], deriv_exit[0])
1304                };
1305                evaluate_rp_row_with_beta(
1306                    &block.beta,
1307                    timewiggle,
1308                    &row_time,
1309                    &cov_rows[i],
1310                    r_eta_exit,
1311                    r_deriv_exit,
1312                    primary_offset[i],
1313                )
1314            };
1315
1316            let mut out = CauseRow {
1317                cause,
1318                row: i,
1319                hazard: vec![0.0; t_cols],
1320                survival: vec![0.0; t_cols],
1321                cumulative: vec![0.0; t_cols],
1322                cumulative_refined: vec![0.0; refined_cols],
1323                eta_exit: 0.0,
1324            };
1325            if per_row_eval {
1326                let (eta_t, cum_t, haz_t) = evaluate_at(age_exit[i])?;
1327                out.eta_exit = eta_t;
1328                out.hazard[0] = haz_t;
1329                out.cumulative[0] = cum_t;
1330                out.survival[0] = (-cum_t).exp().clamp(0.0, 1.0);
1331            } else {
1332                for (j, &t_query) in eval_times.iter().enumerate() {
1333                    // Mirror the single-cause origin guard: every subject is
1334                    // alive at the time origin, so S(0)=1, H(0)=0, h(0)=0.
1335                    // Without this, the time basis floors t=0 to
1336                    // SURVIVAL_TIME_FLOOR and returns a nonzero hazard, which
1337                    // would anchor the Aalen-Johansen CIF assembly on a
1338                    // non-unit S(0) and bias every downstream value.
1339                    if t_query <= 0.0 {
1340                        out.hazard[j] = 0.0;
1341                        out.cumulative[j] = 0.0;
1342                        out.survival[j] = 1.0;
1343                    } else {
1344                        let (_eta_t, cum_t, haz_t) = evaluate_at(t_query)?;
1345                        out.hazard[j] = haz_t;
1346                        out.cumulative[j] = cum_t;
1347                        out.survival[j] = (-cum_t).exp().clamp(0.0, 1.0);
1348                    }
1349                }
1350                // Refined-grid cumulative hazards for the AJ CIF assembly
1351                // (gam#1385). Same closed-form evaluate_at; reuse the exact
1352                // user-grid values at the points that coincide so the returned
1353                // per-cause cumulative_hazard and the assembly agree at the user
1354                // times to the bit.
1355                for (jr, &t_query) in refined_times.iter().enumerate() {
1356                    out.cumulative_refined[jr] = if t_query <= 0.0 {
1357                        0.0
1358                    } else {
1359                        evaluate_at(t_query)?.1
1360                    };
1361                }
1362                let (eta_t, _, _) = evaluate_at(age_exit[i])?;
1363                out.eta_exit = eta_t;
1364            }
1365            Ok(out)
1366        })
1367        .collect();
1368
1369    for row in rows? {
1370        linear_predictor[row.cause][row.row] = row.eta_exit;
1371        for j in 0..t_cols {
1372            hazard[row.cause][[row.row, j]] = row.hazard[j];
1373            survival[row.cause][[row.row, j]] = row.survival[j];
1374            cumulative_hazard[row.cause][[row.row, j]] = row.cumulative[j];
1375        }
1376        for jr in 0..refined_cols {
1377            cumulative_hazard_refined[row.cause][[row.row, jr]] = row.cumulative_refined[jr];
1378        }
1379    }
1380
1381    // Assemble the Aalen-Johansen CIF on the refined grid (gam#1385), then read
1382    // the result back at the user-requested times so the CIF is grid-resolution
1383    // independent. Per-row eval keeps the single-anchor assembly path.
1384    let assembled = if per_row_eval {
1385        let assembly_times = Array1::from_elem(1, 0.0);
1386        assemble_competing_risks_cif_from_endpoints(assembly_times.view(), &cumulative_hazard)
1387            .map_err(|err| err.to_string())?
1388    } else {
1389        let assembly_times = Array1::from_vec(refined_times.clone());
1390        let refined_assembled = assemble_competing_risks_cif_from_endpoints(
1391            assembly_times.view(),
1392            &cumulative_hazard_refined,
1393        )
1394        .map_err(|err| err.to_string())?;
1395        // Project refined CIF / overall-survival columns onto the user grid.
1396        let mut cif_user = (0..cause_count)
1397            .map(|_| Array2::<f64>::zeros((n, t_cols)))
1398            .collect::<Vec<_>>();
1399        let mut overall_user = Array2::<f64>::zeros((n, t_cols));
1400        for (j_user, &jr) in user_time_to_refined_index.iter().enumerate() {
1401            for cause in 0..cause_count {
1402                for row in 0..n {
1403                    cif_user[cause][[row, j_user]] = refined_assembled.cif[cause][[row, jr]];
1404                }
1405            }
1406            for row in 0..n {
1407                overall_user[[row, j_user]] = refined_assembled.overall_survival[[row, jr]];
1408            }
1409        }
1410        CompetingRisksCifResult {
1411            cif: cif_user,
1412            overall_survival: overall_user,
1413        }
1414    };
1415    if assembled.cif.len() != cause_count {
1416        return Err(format!(
1417            "competing-risks CIF assembly produced {} endpoint matrices, expected {cause_count}",
1418            assembled.cif.len()
1419        )
1420        .into());
1421    }
1422    let cif = assembled.cif;
1423    let overall_survival = assembled.overall_survival;
1424    let times_out = if per_row_eval {
1425        age_exit.to_vec()
1426    } else {
1427        eval_times
1428    };
1429    Ok(CompetingRisksPredictResult {
1430        times: times_out,
1431        endpoint_names,
1432        hazard,
1433        survival,
1434        cumulative_hazard,
1435        cif,
1436        overall_survival,
1437        linear_predictor,
1438        likelihood_mode: saved_likelihood_mode,
1439    })
1440}
1441
1442fn saved_cause_specific_timewiggles(
1443    model: &SavedModel,
1444    fit: &UnifiedFitResult,
1445    cause_count: usize,
1446) -> Result<Vec<Option<SavedBaselineTimeWiggleRuntime>>, SurvivalPredictError> {
1447    let has_metadata = model.baseline_timewiggle_knots.is_some()
1448        || model.baseline_timewiggle_degree.is_some()
1449        || model.baseline_timewiggle_penalty_orders.is_some()
1450        || model.baseline_timewiggle_double_penalty.is_some()
1451        || model.beta_baseline_timewiggle_by_cause.is_some();
1452    if !has_metadata {
1453        return Ok(vec![None; cause_count]);
1454    }
1455    let knots = model.baseline_timewiggle_knots.clone().ok_or_else(|| {
1456        "joint cause-specific survival missing baseline_timewiggle_knots".to_string()
1457    })?;
1458    let degree = model.baseline_timewiggle_degree.ok_or_else(|| {
1459        "joint cause-specific survival missing baseline_timewiggle_degree".to_string()
1460    })?;
1461    let penalty_orders = model
1462        .baseline_timewiggle_penalty_orders
1463        .clone()
1464        .ok_or_else(|| {
1465            "joint cause-specific survival missing baseline_timewiggle_penalty_orders".to_string()
1466        })?;
1467    let double_penalty = model.baseline_timewiggle_double_penalty.ok_or_else(|| {
1468        "joint cause-specific survival missing baseline_timewiggle_double_penalty".to_string()
1469    })?;
1470    let by_cause = model
1471        .beta_baseline_timewiggle_by_cause
1472        .as_ref()
1473        .ok_or_else(|| {
1474            "joint cause-specific survival missing beta_baseline_timewiggle_by_cause".to_string()
1475        })?;
1476    if by_cause.len() != cause_count {
1477        return Err(SurvivalPredictError::IncompatibleSchema {
1478            reason: format!(
1479                "joint cause-specific survival has {} timewiggle coefficient blocks, expected {cause_count}",
1480                by_cause.len()
1481            ),
1482        });
1483    }
1484    for (cause, (block, beta_w)) in fit.blocks.iter().zip(by_cause).enumerate() {
1485        if beta_w.len() > block.beta.len() {
1486            return Err(SurvivalPredictError::IncompatibleSchema {
1487                reason: format!(
1488                    "joint cause-specific survival cause {} timewiggle beta has length {}, but endpoint beta has {} coefficients",
1489                    cause + 1,
1490                    beta_w.len(),
1491                    block.beta.len()
1492                ),
1493            });
1494        }
1495    }
1496    Ok(by_cause
1497        .iter()
1498        .map(|beta| {
1499            Some(SavedBaselineTimeWiggleRuntime {
1500                knots: knots.clone(),
1501                degree,
1502                penalty_orders: penalty_orders.clone(),
1503                double_penalty,
1504                beta: beta.clone(),
1505            })
1506        })
1507        .collect())
1508}
1509
1510// ---------------------------------------------------------------------------
1511// Per-mode single-row evaluators.
1512// ---------------------------------------------------------------------------
1513
1514/// Precomputed context for evaluating the saved survival marginal-slope
1515/// predictor row-by-row. Built once per call to `predict_survival` so the
1516/// per-(row, t) loop only assembles the per-time q-design slice.
1517struct MarginalSlopePredictContext {
1518    predictor: BernoulliMarginalSlopePredictor,
1519    /// Time-block coefficients (length `p_time_base + p_timewiggle`).
1520    beta_time: Array1<f64>,
1521    /// Covariate (marginal) coefficients.
1522    beta_marginal: Array1<f64>,
1523    saved_timewiggle: Option<SavedBaselineTimeWiggleRuntime>,
1524    /// Covariate design (n × p_marginal), kept operator-backed when possible.
1525    cov_design: DesignMatrix,
1526    /// Logslope design (n × p_logslope), kept operator-backed when possible.
1527    logslope_design: DesignMatrix,
1528    /// Per-row covariate eta = `cov_design[i] · beta_marginal`. Used to
1529    /// pre-compute `q_exit_base`.
1530    cov_eta: Array1<f64>,
1531    /// Per-row latent z (raw, un-normalized — the predictor's
1532    /// `latent_z_normalization` is applied internally).
1533    z_raw: Array1<f64>,
1534    /// Per-row noise offset, mirroring the `pred_input.offset_noise` slice
1535    /// used by the CLI.
1536    noise_offset: Array1<f64>,
1537}
1538
1539fn design_row_owned(
1540    design: &DesignMatrix,
1541    row: usize,
1542    context: &str,
1543) -> Result<Array1<f64>, SurvivalPredictError> {
1544    let chunk = design
1545        .try_row_chunk(row..row + 1)
1546        .map_err(|e| format!("{context}: {e}"))?;
1547    Ok(chunk.row(0).to_owned())
1548}
1549
1550fn build_marginal_slope_predict_context(
1551    model: &SavedModel,
1552    data: ArrayView2<'_, f64>,
1553    col_map: &HashMap<String, usize>,
1554    training_headers: Option<&Vec<String>>,
1555    cov_design: &DesignMatrix,
1556    primary_offset: &Array1<f64>,
1557    noise_offset: &Array1<f64>,
1558    time_build: &SurvivalTimeBuildOutput,
1559    eta_offset_entry: &Array1<f64>,
1560    eta_offset_exit: &Array1<f64>,
1561    derivative_offset_exit: &Array1<f64>,
1562) -> Result<MarginalSlopePredictContext, SurvivalPredictError> {
1563    let z_name = model
1564        .z_column
1565        .as_ref()
1566        .ok_or_else(|| "saved survival marginal-slope model missing z_column".to_string())?;
1567    let z_col = resolve_role_col(col_map, z_name, "z")?;
1568    let z_raw = data.column(z_col).to_owned();
1569
1570    let logslopespec = resolve_termspec_for_prediction(
1571        &model.resolved_termspec_logslope.as_ref().cloned(),
1572        training_headers,
1573        col_map,
1574        "resolved_termspec_logslope",
1575    )?;
1576    let logslope_clipped = model.axis_clip_to_training_ranges(data, col_map);
1577    let logslope_input = logslope_clipped.as_ref().map_or(data, |arr| arr.view());
1578    let logslope_design = build_term_collection_design(logslope_input, &logslopespec)
1579        .map_err(|e| format!("failed to build survival marginal-slope logslope design: {e}"))?;
1580
1581    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1582    let (predictor, _pred_input, _predictor_fit) = build_saved_survival_marginal_slope_predictor(
1583        model,
1584        &fit_saved,
1585        z_name,
1586        &z_raw,
1587        cov_design,
1588        &logslope_design.design,
1589        time_build,
1590        eta_offset_entry,
1591        eta_offset_exit,
1592        derivative_offset_exit,
1593        primary_offset,
1594        noise_offset,
1595    )?;
1596
1597    let blocks = &fit_saved.blocks;
1598    if blocks.len() < 3 {
1599        return Err(SurvivalPredictError::IncompatibleSchema {
1600            reason: format!(
1601                "saved survival marginal-slope model requires at least 3 blocks [time, marginal, slope], got {}",
1602                blocks.len()
1603            ),
1604        });
1605    }
1606    let beta_time = blocks[0].beta.clone();
1607    let beta_marginal = blocks[1].beta.clone();
1608    let saved_runtime = model.saved_prediction_runtime()?;
1609    let saved_timewiggle = saved_runtime.baseline_time_wiggle.clone();
1610
1611    // cov_eta is time-independent so doing it here avoids `O(n × T)`
1612    // re-multiplications inside the per-cell loop.
1613    let cov_eta = cov_design.dot(&beta_marginal);
1614
1615    Ok(MarginalSlopePredictContext {
1616        predictor,
1617        beta_time,
1618        beta_marginal,
1619        saved_timewiggle,
1620        cov_design: cov_design.clone(),
1621        logslope_design: logslope_design.design.clone(),
1622        cov_eta,
1623        z_raw,
1624        noise_offset: noise_offset.clone(),
1625    })
1626}
1627
1628/// Evaluate one (row, t) cell for the saved survival marginal-slope kernel.
1629///
1630/// Calls the saved [`BernoulliMarginalSlopePredictor`]
1631/// (`predict_eta_and_q_chain`) to obtain both the linear predictor `eta` and
1632/// the exact IFT-pullback factor `∂eta/∂q`. The survival-index time derivative
1633/// is then `(∂eta/∂q) · qd_with_wiggle`. In rigid mode this collapses to
1634/// `c · qd` (the closed-form probit-frailty composition); under score-warp /
1635/// link-deviation it picks up the exact implicit-function pull-back through the
1636/// per-row calibration intercept, mirroring `compute_survival_timepoint_exact`
1637/// in `survival_marginal_slope.rs`.
1638fn evaluate_marginal_slope_row(
1639    row_index: usize,
1640    ctx: &MarginalSlopePredictContext,
1641    row_time: &SurvivalTimeBuildOutput,
1642    r_eta_exit: &Array1<f64>,
1643    r_deriv_exit: &Array1<f64>,
1644    primary_offset_row: f64,
1645) -> Result<(f64, f64, f64), SurvivalPredictError> {
1646    let beta_time = &ctx.beta_time;
1647    let p_time_base = row_time.x_exit_time.ncols();
1648    let p_timewiggle = ctx
1649        .saved_timewiggle
1650        .as_ref()
1651        .map_or(0, |runtime| runtime.beta.len());
1652    if beta_time.len() != p_time_base + p_timewiggle {
1653        return Err(SurvivalPredictError::IncompatibleSchema {
1654            reason: format!(
1655                "saved survival marginal-slope time coefficient mismatch: beta has {} entries but expected base={} plus timewiggle={}",
1656                beta_time.len(),
1657                p_time_base,
1658                p_timewiggle
1659            ),
1660        });
1661    }
1662    let beta_time_base = beta_time.slice(s![..p_time_base]).to_owned();
1663
1664    // Pre-wiggle q-eta for this (row, t) cell. Mirrors the CLI's `q_exit_base`
1665    // construction in `build_saved_survival_marginal_slope_predictor`:
1666    //   q = time_basis(t) · beta_time_base + cov[row] · beta_marginal
1667    //       + r_eta_exit + primary_offset_row.
1668    let q_exit_base = row_time.x_exit_time.dot(&beta_time_base)[0]
1669        + ctx.cov_eta[row_index]
1670        + r_eta_exit[0]
1671        + primary_offset_row;
1672    let qd_exit_base = row_time.x_derivative_time.dot(&beta_time_base)[0] + r_deriv_exit[0];
1673
1674    // For timewiggle the `exit_design` row enters the predictor's q-design;
1675    // the `derivative_design` row enters the time-derivative used to build the
1676    // hazard. Both are evaluated at the wiggle anchor `q_exit_base`.
1677    let (qd_with_wiggle, exit_wiggle_design) = if let Some(runtime) = ctx.saved_timewiggle.as_ref()
1678    {
1679        let knots = Array1::from_vec(runtime.knots.clone());
1680        let beta_w = beta_time.slice(s![p_time_base..]).to_owned();
1681        let eta_exit_row = Array1::from_elem(1, q_exit_base);
1682        let deriv_row = Array1::from_elem(1, qd_exit_base);
1683        let exit_design = match buildwiggle_block_input_from_knots(
1684            eta_exit_row.view(),
1685            &knots,
1686            runtime.degree,
1687            2,
1688            false,
1689        )?
1690        .design
1691        {
1692            DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
1693            _ => {
1694                return Err(SurvivalPredictError::IncompatibleSchema {
1695                    reason: "saved baseline-timewiggle exit design must be dense".to_string(),
1696                });
1697            }
1698        };
1699        let derivative_design = build_survival_timewiggle_derivative_design(
1700            &eta_exit_row,
1701            &deriv_row,
1702            &knots,
1703            runtime.degree,
1704        )?;
1705        (
1706            qd_exit_base + derivative_design.dot(&beta_w)[0],
1707            Some(exit_design),
1708        )
1709    } else {
1710        (qd_exit_base, None)
1711    };
1712
1713    // Build a 1-row PredictInput for this (row, t) cell and call the saved
1714    // predictor. The predictor's `marginal_eta` formula is
1715    //   marginal_eta = q_design · combined_q_beta + baseline_marginal + offset
1716    // with `combined_q_beta = [beta_time | beta_marginal]` and the survival
1717    // predictor sets `baseline_marginal = 0`. We supply the full per-row
1718    // q_design = [time_basis(t) | timewiggle | cov_design[row]] so the
1719    // predictor reproduces `q_with_wiggle` exactly with `offset = r_eta_exit[0]
1720    // + primary_offset_row`.
1721    let cov_dim = ctx.beta_marginal.len();
1722    let q_design_ncols = p_time_base + p_timewiggle + cov_dim;
1723    let mut q_design_full = Array2::<f64>::zeros((1, q_design_ncols));
1724    q_design_full
1725        .slice_mut(s![.., ..p_time_base])
1726        .assign(&row_time.x_exit_time.to_dense());
1727    if let Some(exit_w) = exit_wiggle_design.as_ref() {
1728        q_design_full
1729            .slice_mut(s![.., p_time_base..p_time_base + p_timewiggle])
1730            .assign(exit_w);
1731    }
1732    if cov_dim > 0 {
1733        let cov_row = design_row_owned(
1734            &ctx.cov_design,
1735            row_index,
1736            "survival marginal covariate row",
1737        )?;
1738        q_design_full
1739            .slice_mut(s![.., p_time_base + p_timewiggle..])
1740            .row_mut(0)
1741            .assign(&cov_row);
1742    }
1743
1744    // Logslope design row + offset chosen so that the predictor's logslope_eta
1745    // equals our precomputed `slope_eta[row]`.  The predictor computes:
1746    //   logslope_eta = design_noise · beta_logslope + baseline_logslope
1747    //                  + offset_noise.
1748    // We feed the actual saved logslope row + the row's noise offset, matching
1749    // exactly the CLI's `pred_input.design_noise` / `offset_noise` slice.
1750    let logslope_row = design_row_owned(
1751        &ctx.logslope_design,
1752        row_index,
1753        "survival marginal logslope row",
1754    )?;
1755    let mut logslope_design_2d = Array2::<f64>::zeros((1, logslope_row.len()));
1756    logslope_design_2d.row_mut(0).assign(&logslope_row);
1757
1758    let pred_input = PredictInput {
1759        design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(q_design_full)),
1760        offset: Array1::from_elem(1, r_eta_exit[0] + primary_offset_row),
1761        design_noise: Some(DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
1762            logslope_design_2d,
1763        ))),
1764        offset_noise: Some(Array1::from_elem(1, ctx.noise_offset[row_index])),
1765        auxiliary_scalar: Some(Array1::from_elem(1, ctx.z_raw[row_index])),
1766        auxiliary_matrix: None,
1767    };
1768
1769    // Exact IFT pull-back: the predictor returns both `eta` and the analytic
1770    // factor `∂eta/∂q` for this (row, t). This gives d eta(t) / dt; the hazard
1771    // conversion below divides the event density by S(t).
1772    let (eta_arr, deta_dq_arr) = ctx
1773        .predictor
1774        .predict_eta_and_q_chain(&pred_input)
1775        .map_err(|e| format!("saved survival marginal-slope predictor eta failed: {e}"))?;
1776    let eta = eta_arr[0];
1777    // `qd_with_wiggle` is the base survival-index time derivative q'(t), built
1778    // identically to fit-time `qd1 = dq_dq0·d_raw` (the wiggle chain and the
1779    // `+derivative_guard` offset are both already folded into `qd_exit_base`),
1780    // so there is no predict-vs-fit desync in the derivative reconstruction.
1781    //
1782    // Fit enforces the monotonicity floor `q'(t) >= derivative_guard` ONLY at
1783    // each training row's own exit time (one `t` per row), via the active-set
1784    // guard constraints. A prediction horizon is an arbitrary `t` — typically a
1785    // single CIF horizon evaluated for every row — which generally is NOT one of
1786    // the constrained training exit times. Where that horizon lands in a region
1787    // of sparse/no training exits, the penalized baseline spline can extrapolate
1788    // to a locally decreasing survival index, so `q'(t) < 0` is a legitimate
1789    // model statement ("no instantaneous hazard accrues here"), not a numerical
1790    // bug. The instantaneous hazard rate is physically non-negative, so the
1791    // truthful response is to clamp the index time-derivative at its floor 0
1792    // (flat hazard, survival locally constant) rather than reject the whole
1793    // prediction — clamping keeps the CIF well-posed and monotone. Only a
1794    // non-finite derivative (a real numerical failure) is surfaced to the strict
1795    // validator below.
1796    let eta_derivative = marginal_slope_index_derivative_at_horizon(deta_dq_arr[0], qd_with_wiggle);
1797    let (cum, haz) = probit_survival_hazard_components(eta, eta_derivative)?;
1798    Ok((eta, cum, haz))
1799}
1800
/// Form the marginal-slope survival index time-derivative
/// `eta'(t) = deta_dq · qd_with_wiggle` at a prediction horizon and floor it
/// at zero.
///
/// - `deta_dq` is the chain factor `∂eta/∂q` returned by the saved predictor.
/// - `qd_with_wiggle` is the base survival-index time derivative `q'(t)`.
///
/// The hazard `h(t) = mills · eta'(t)` cannot be negative. A finite negative
/// product, which the penalized baseline can produce at horizons the fit's
/// monotonicity guard never constrained, is therefore mapped to `0`: a flat
/// hazard with locally constant survival.
///
/// A non-finite product (NaN or ±∞) is returned untouched so the downstream
/// strict validator can reject it as a genuine numerical failure.
#[inline]
fn marginal_slope_index_derivative_at_horizon(deta_dq: f64, qd_with_wiggle: f64) -> f64 {
    let raw_slope = deta_dq * qd_with_wiggle;
    if !raw_slope.is_finite() {
        // Leave NaN / ±∞ untouched for the validator.
        return raw_slope;
    }
    raw_slope.max(0.0)
}
1822
1823#[inline]
1824fn probit_survival_hazard_components(
1825    eta: f64,
1826    eta_derivative: f64,
1827) -> Result<(f64, f64), SurvivalPredictError> {
1828    if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative >= 0.0) {
1829        return Err(SurvivalPredictError::NumericalFailure {
1830            reason: format!(
1831                "saved survival marginal-slope prediction produced invalid survival index derivative: eta={eta}, eta_t={eta_derivative}"
1832            ),
1833        });
1834    }
1835
1836    // Survival marginal-slope defines S(t) = Phi(-eta(t)). The event density
1837    // is f(t) = phi(eta(t)) * eta'(t), while the hazard rate exposed by the
1838    // prediction API is h(t) = f(t) / S(t). The signed-probit helper returns
1839    // both log Phi(-eta) and the stable Mills ratio phi(eta) / Phi(-eta).
1840    let (log_survival, mills_ratio) = signed_probit_logcdf_and_mills_ratio(-eta);
1841    let cumulative_hazard = -log_survival;
1842    let hazard = if eta_derivative == 0.0 {
1843        0.0
1844    } else {
1845        mills_ratio * eta_derivative
1846    };
1847    // `>= 0.0` rejects NaN (a programming-bug signal) and accepts the full
1848    // mathematical range [0, +∞]. Saturated probit fits where the model
1849    // genuinely says S(t)→0 produce a +∞ cumulative hazard — that is the
1850    // truthful answer, and the consumer's `survival = exp(-cum).clamp(0,1)`
1851    // handles it cleanly. Rejecting +∞ would force the predictor to fail on
1852    // models that the inner solver has already certified as a valid fit.
1853    if !(cumulative_hazard >= 0.0 && hazard >= 0.0) {
1854        return Err(SurvivalPredictError::NumericalFailure {
1855            reason: format!(
1856                "saved survival marginal-slope prediction produced invalid survival components: eta={eta}, eta_t={eta_derivative}, log_survival={log_survival}, hazard={hazard}"
1857            ),
1858        });
1859    }
1860    Ok((cumulative_hazard, hazard))
1861}
1862
1863fn evaluate_rp_row(
1864    model: &SavedModel,
1865    row_time: &SurvivalTimeBuildOutput,
1866    cov_row: &Array1<f64>,
1867    eta_time_offset_row: f64,
1868    derivative_time_offset_row: f64,
1869    primary_offset_row: f64,
1870) -> Result<(f64, f64, f64), SurvivalPredictError> {
1871    let fit_saved = fit_result_from_saved_model_for_prediction(model)?;
1872    let saved_runtime = model.saved_prediction_runtime()?;
1873    evaluate_rp_row_with_beta(
1874        &fit_saved.beta,
1875        saved_runtime.baseline_time_wiggle.as_ref(),
1876        row_time,
1877        cov_row,
1878        eta_time_offset_row,
1879        derivative_time_offset_row,
1880        primary_offset_row,
1881    )
1882}
1883
1884fn evaluate_rp_row_with_beta(
1885    beta: &Array1<f64>,
1886    saved_timewiggle: Option<&SavedBaselineTimeWiggleRuntime>,
1887    row_time: &SurvivalTimeBuildOutput,
1888    cov_row: &Array1<f64>,
1889    eta_time_offset_row: f64,
1890    derivative_time_offset_row: f64,
1891    primary_offset_row: f64,
1892) -> Result<(f64, f64, f64), SurvivalPredictError> {
1893    let p_time = row_time.x_exit_time.ncols();
1894    let p_timewiggle = saved_timewiggle.map_or(0, |runtime| runtime.beta.len());
1895    let p_cov = cov_row.len();
1896    let p = p_time + p_timewiggle + p_cov;
1897    if beta.len() != p {
1898        return Err(SurvivalPredictError::IncompatibleSchema {
1899            reason: format!(
1900                "survival RP coefficient mismatch: beta has {} entries but design has {} columns",
1901                beta.len(),
1902                p
1903            ),
1904        });
1905    }
1906    let mut x_exit = Array2::<f64>::zeros((1, p));
1907    if p_time > 0 {
1908        x_exit
1909            .slice_mut(s![.., ..p_time])
1910            .assign(&row_time.x_exit_time.to_dense());
1911    }
1912    let mut eta_derivative = derivative_time_offset_row;
1913    if p_time > 0 {
1914        eta_derivative += row_time
1915            .x_derivative_time
1916            .dot(&beta.slice(s![..p_time]).to_owned())[0];
1917    }
1918    if let Some(runtime) = saved_timewiggle {
1919        let knots = Array1::from_vec(runtime.knots.clone());
1920        let beta_w = beta.slice(s![p_time..p_time + p_timewiggle]).to_owned();
1921        let eta_exit_row = Array1::from_elem(1, eta_time_offset_row);
1922        let derivative_exit_row = Array1::from_elem(1, derivative_time_offset_row);
1923        let exit_design = match buildwiggle_block_input_from_knots(
1924            eta_exit_row.view(),
1925            &knots,
1926            runtime.degree,
1927            2,
1928            false,
1929        )?
1930        .design
1931        {
1932            DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
1933            _ => {
1934                return Err(SurvivalPredictError::IncompatibleSchema {
1935                    reason: "saved baseline-timewiggle exit design must be dense".to_string(),
1936                });
1937            }
1938        };
1939        if exit_design.ncols() != p_timewiggle {
1940            return Err(SurvivalPredictError::IncompatibleSchema {
1941                reason: format!(
1942                    "survival RP timewiggle design mismatch: rebuilt {} columns but runtime expects {}",
1943                    exit_design.ncols(),
1944                    p_timewiggle
1945                ),
1946            });
1947        }
1948        x_exit
1949            .slice_mut(s![.., p_time..p_time + p_timewiggle])
1950            .assign(&exit_design);
1951        let derivative_design = build_survival_timewiggle_derivative_design(
1952            &eta_exit_row,
1953            &derivative_exit_row,
1954            &knots,
1955            runtime.degree,
1956        )?;
1957        eta_derivative += derivative_design.dot(&beta_w)[0];
1958    }
1959    if p_cov > 0 {
1960        x_exit
1961            .slice_mut(s![
1962                ..,
1963                (p_time + p_timewiggle)..(p_time + p_timewiggle + p_cov)
1964            ])
1965            .row_mut(0)
1966            .assign(cov_row);
1967    }
1968    let offset_view = Array1::from_elem(1, eta_time_offset_row + primary_offset_row);
1969    let likelihood = LikelihoodSpec::new(
1970        ResponseFamily::RoystonParmar,
1971        InverseLink::Standard(StandardLink::Identity),
1972    );
1973    let eta =
1974        predict_royston_parmar_eta(x_exit.view(), beta.view(), offset_view.view(), &likelihood)?[0];
1975    let (cum, haz) = royston_parmar_survival_hazard_components(eta, eta_derivative)?;
1976    Ok((eta, cum, haz))
1977}
1978
1979fn predict_royston_parmar_eta<X>(
1980    x: X,
1981    beta: ndarray::ArrayView1<'_, f64>,
1982    offset: ndarray::ArrayView1<'_, f64>,
1983    likelihood: &LikelihoodSpec,
1984) -> Result<Array1<f64>, SurvivalPredictError>
1985where
1986    X: Into<DesignMatrix>,
1987{
1988    if !matches!(likelihood.response, ResponseFamily::RoystonParmar)
1989        || !matches!(
1990            likelihood.link,
1991            InverseLink::Standard(StandardLink::Identity)
1992        )
1993    {
1994        return Err(SurvivalPredictError::UnsupportedConfiguration {
1995            reason: "survival prediction requires RoystonParmar with identity link".to_string(),
1996        });
1997    }
1998    let x = x.into();
1999    if x.nrows() != offset.len() || x.ncols() != beta.len() {
2000        return Err(SurvivalPredictError::IncompatibleSchema {
2001            reason: format!(
2002                "survival prediction design dimensions disagree: design is {}x{}, beta has length {}, offset has length {}",
2003                x.nrows(),
2004                x.ncols(),
2005                beta.len(),
2006                offset.len()
2007            ),
2008        });
2009    }
2010    let mut eta = x.matrixvectormultiply(&beta.to_owned());
2011    eta += &offset;
2012    Ok(eta)
2013}
2014
2015#[inline]
2016fn royston_parmar_survival_hazard_components(
2017    eta: f64,
2018    eta_derivative: f64,
2019) -> Result<(f64, f64), SurvivalPredictError> {
2020    // `eta = log Λ(t)` and `eta_derivative = d(log Λ)/dt`, so the instantaneous
2021    // hazard is `h(t) = Λ(t) · eta_derivative = dΛ/dt`. Reject only the true bug
2022    // signals: a non-finite `eta`, and a derivative that is NaN or genuinely
2023    // negative.
2024    //
2025    // `eta_derivative == 0` is a VALID boundary value, not a failure. The RP
2026    // baseline `log Λ(t)` is an I-spline (monotone non-decreasing cumulative
2027    // hazard): beyond its last interior knot every I-spline basis is flat, so
2028    // its time-derivative is exactly 0 and the instantaneous hazard there is 0
2029    // (`S(t)` locally constant). Any RP model predicted on a grid that extends
2030    // past its training support hits this regime on the tail nodes. The earlier
2031    // strict `> 0.0` gate spuriously failed those predictions (#1564). The
2032    // probit / marginal-slope sibling guard
2033    // (`probit_survival_hazard_components`) already accepts the full `[0, ∞)`
2034    // range and maps a zero derivative to a zero hazard; the RP guard must match.
2035    if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative >= 0.0) {
2036        return Err(SurvivalPredictError::NumericalFailure {
2037            reason: format!(
2038                "saved Royston-Parmar survival prediction produced invalid log-cumulative-hazard derivative: eta={eta}, eta_t={eta_derivative}"
2039            ),
2040        });
2041    }
2042    let cumulative_hazard = eta.exp();
2043    // `h(t) = Λ(t) · d(log Λ)/dt`. Compute the zero-derivative boundary FIRST so
2044    // the `Λ = +∞` (saturated tail, `eta >~ 709.78`) × `0` (flat I-spline)
2045    // indeterminate form resolves to the mathematically correct `0`, not the
2046    // `NaN` that `f64::INFINITY * 0.0` produces. A flat cumulative hazard has
2047    // zero instantaneous hazard regardless of its (possibly saturated) level.
2048    let hazard = if eta_derivative == 0.0 {
2049        0.0
2050    } else {
2051        cumulative_hazard * eta_derivative
2052    };
2053    // Royston-Parmar parameterizes `eta = log Lambda(t)`, so `Lambda = exp(eta)`
2054    // is unbounded above and `exp(eta)` saturates to `+∞` in f64 once
2055    // `eta >~ 709.78` — exactly the regime a saturated RP fit produces in the
2056    // right tail. The math is well-defined (`S(t) → 0`, `h(t) → ∞`); rejecting
2057    // `+∞` here would crash predict on a fit the inner solver already accepted.
2058    // `>= 0.0` rejects NaN (the only true bug signal) while allowing the full
2059    // [0, +∞] range. The consumer materializes survival via
2060    // `survival = exp(-cum).clamp(0, 1)`, which collapses cleanly at saturation.
2061    if !(cumulative_hazard >= 0.0 && hazard >= 0.0) {
2062        return Err(SurvivalPredictError::NumericalFailure {
2063            reason: format!(
2064                "saved Royston-Parmar survival prediction produced invalid survival components: eta={eta}, eta_t={eta_derivative}, cumulative_hazard={cumulative_hazard}, hazard={hazard}"
2065            ),
2066        });
2067    }
2068    Ok((cumulative_hazard, hazard))
2069}
2070
/// Batch evaluator for the location-scale survival likelihood mode.
///
/// Mirrors the CLI's LocationScale predict path (main.rs::run_predict_survival
/// LocationScale arm) but stays library-only: reuses the caller's covariate
/// design (`cov_design`) as the threshold design, builds the log-sigma design
/// from the saved `resolved_termspec_noise`, replays the saved scale-deviation
/// transform on the noise design, applies the survival time-derivative guard,
/// and evaluates survival through the location-scale predictors.
///
/// Evaluation layout:
/// - `time_grid == None`: one evaluation per input row at that row's own
///   `age_exit`; every output matrix has a single column.
/// - `time_grid == Some(grid)`: each subject is evaluated at every grid time
///   plus one trailing evaluation at its own `age_exit`. The grid columns fill
///   `survival` / `hazard` / `cumulative_hazard`; the trailing column only
///   feeds `linear_predictor` (and `eta_se`).
///
/// Query times `<= 0` are anchored to `S = 1`, `H = 0`, `h = 0` (and a zero
/// survival standard error).
///
/// When `with_uncertainty` is true the saved posterior covariance is propagated
/// through `predict_survival_location_scalewith_uncertainty`, populating
/// `survival_se` and `eta_se`; otherwise both are `None`.
///
/// # Errors
/// Returns `Err` for an empty `time_grid` or one containing non-finite or
/// negative times, for a saved model missing `survival_time_anchor` (or, with
/// uncertainty, `beta_covariance`), and for any failure bubbled up from design
/// construction or the location-scale predictors.
fn predict_survival_location_scale_batch(
    model: &SavedModel,
    age_entry: &Array1<f64>,
    age_exit: &Array1<f64>,
    cov_design: &gam_terms::smooth::TermCollectionDesign,
    primary_offset: &Array1<f64>,
    noise_offset: &Array1<f64>,
    training_headers: Option<&Vec<String>>,
    col_map: &HashMap<String, usize>,
    data: ArrayView2<'_, f64>,
    time_grid: Option<&[f64]>,
    with_uncertainty: bool,
) -> Result<SurvivalPredictResult, String> {
    use crate::scale_design::build_scale_deviation_operator;
    use crate::survival::construction::evaluate_survival_time_basis_row;
    use crate::survival::location_scale::{
        SurvivalLocationScalePredictInput, predict_survival_location_scale,
        predict_survival_location_scale_from_linear_components,
        predict_survival_location_scalewith_uncertainty,
    };
    use gam_linalg::matrix::DesignMatrix;

    let n = age_entry.len();
    let per_row_eval = time_grid.is_none();
    let eval_times: Vec<f64> = match time_grid {
        Some(grid) => {
            if grid.is_empty() {
                return Err("survival time_grid must contain at least one time".to_string());
            }
            for (idx, &t) in grid.iter().enumerate() {
                if !t.is_finite() || t < 0.0 {
                    return Err(format!(
                        "survival time_grid requires finite non-negative times (index {idx})",
                    ));
                }
            }
            grid.to_vec()
        }
        None => Vec::new(),
    };
    // Flattened layout: in grid mode each subject owns `eval_width = t_cols + 1`
    // consecutive slots, one per grid time plus a trailing slot at the subject's
    // own `age_exit` (read back for `linear_predictor` / `eta_se`). Per-row mode
    // uses a single slot per subject.
    let t_cols = if per_row_eval { 1 } else { eval_times.len() };
    let eval_width = if per_row_eval { 1 } else { t_cols + 1 };
    let saved_likelihood_mode = SurvivalLikelihoodMode::LocationScale;
    let baseline_cfg = saved_survival_runtime_baseline_config(model)?;
    let saved_fit = saved_survival_location_scale_fit_result(model)?;
    // Reduced parametric-AFT regime (issue #892): the fit removed the time warp
    // entirely (`h ≡ 0`, zero free time columns) and carried the σ-scaled `log t`
    // baseline as a per-row LOCATION shift `η_t → η_t − log t`, so the
    // standardized residual is `u = inv_sigma·(log t − η_t) = (log t − μ)/σ` and
    // σ is identified through the event Jacobian's `−log σ` term (the
    // survreg / lifelines / flexsurv AFT gauge). The saved model therefore carries
    // a time-warp β that is identically ZERO: the reduced time block has zero free
    // columns and the Gauge-owned affine shift is zero, so the finalized
    // `beta_time = T·β_reduced + a` is an all-zero length-`p` vector (exact zeros
    // — no arithmetic noise — or empty when p==0). A genuine
    // flexible location-scale fit always retains a non-zero unpenalized monotone
    // log-t trend in its warp (its affine null space is never shrunk away), so an
    // all-zero `beta_time` uniquely identifies the reduced regime. Predict must
    // MIRROR the `−log t` location shift instead of reconstructing a warp from the
    // zero `beta_time`; otherwise `S(t|x)` carries no `log t` dependence and is
    // wrong for every saved reduced-AFT model. Detected from the saved payload
    // alone (zero time-warp β + no learned baseline timewiggle), so no new
    // persisted flag is needed. (`iter().all` is `true` on an empty β too.)
    let reduced_parametric_aft =
        !model.has_baseline_time_wiggle() && saved_fit.beta_time().iter().all(|&b| b == 0.0);
    let time_cfg = load_survival_time_basis_config_from_model(model)?;
    // This first build over the observed (entry, exit) pairs resolves the
    // concrete time-basis config (basis name, degree, knots, kept columns) used
    // for the anchor row and the structural-basis check below; its designs are
    // superseded by the rebuild over the evaluation times further down.
    let mut time_build = build_survival_time_basis(age_entry, age_exit, time_cfg.clone(), None)?;
    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
        &time_build.basisname,
        time_build.degree,
        time_build.knots.as_ref(),
        time_build.keep_cols.as_ref(),
        time_build.smooth_lambda,
    )?;
    let time_anchor = model
        .survival_time_anchor
        .ok_or_else(|| "saved survival model missing survival_time_anchor".to_string())?;
    let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
    center_survival_time_designs_at_anchor(
        &mut time_build.x_entry_time,
        &mut time_build.x_exit_time,
        &time_anchor_row,
    )?;
    // The reduced-AFT regime has no structural time warp (the monotone baseline
    // rides the location channel), so the structural-basis requirement does not
    // apply to it.
    // NOTE(review): the context label below reads "saved survival sampling" even
    // though this is the prediction path — confirm the wording is intended.
    if !model.has_baseline_time_wiggle() && !reduced_parametric_aft {
        require_structural_survival_time_basis(&time_build.basisname, "saved survival sampling")?;
    }
    let saved_inverse_link = resolve_survival_inverse_link_from_saved(model)?;
    // Flatten (subject i, eval column col) into slot k = i * eval_width + col.
    // The entry time is clamped to the query time so entry <= exit holds even
    // when a grid time precedes the subject's own entry.
    let (eval_entry, eval_exit) = if per_row_eval {
        (age_entry.clone(), age_exit.clone())
    } else {
        let total = n * eval_width;
        let mut entry = Array1::<f64>::zeros(total);
        let mut exit = Array1::<f64>::zeros(total);
        {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let pairs: Vec<(f64, f64)> = (0..total)
                .into_par_iter()
                .map(|k| {
                    let i = k / eval_width;
                    let col = k % eval_width;
                    let t = if col < t_cols {
                        eval_times[col]
                    } else {
                        age_exit[i]
                    };
                    (age_entry[i].min(t), t)
                })
                .collect();
            for (k, (t0, t1)) in pairs.into_iter().enumerate() {
                entry[k] = t0;
                exit[k] = t1;
            }
        }
        (entry, exit)
    };
    // Rebuild the time designs over the flattened evaluation times and center
    // them at the saved anchor row, exactly as done for the observed times above.
    let mut time_build =
        build_survival_time_basis(&eval_entry, &eval_exit, time_cfg.clone(), None)?;
    center_survival_time_designs_at_anchor(
        &mut time_build.x_entry_time,
        &mut time_build.x_exit_time,
        &time_anchor_row,
    )?;
    let (mut eta_offset_entry, mut eta_offset_exit, mut derivative_offset_exit) =
        build_survival_time_offsets_for_likelihood(
            &eval_entry,
            &eval_exit,
            &baseline_cfg,
            saved_likelihood_mode,
            Some(&saved_inverse_link),
        )?;
    add_survival_time_derivative_guard_offset(
        &eval_entry,
        &eval_exit,
        time_anchor,
        survival_derivative_guard_for_likelihood(saved_likelihood_mode),
        &mut eta_offset_entry,
        &mut eta_offset_exit,
        &mut derivative_offset_exit,
    )?;
    if reduced_parametric_aft {
        // The warp is removed in this regime (`h ≡ 0`); the σ-scaled log-t baseline
        // rides the location channel via the `−log t` threshold shift applied
        // below. The saved `beta_time` is an all-zero length-`p` vector (the
        // reduced time block has zero free columns and a zero affine lift), so the
        // time-warp contribution `x_exit_time · beta_time` is identically zero for
        // ANY design — we therefore KEEP the full-width centered basis (so the
        // hazard's `beta.len() == x_exit_time.ncols()` check holds and the
        // scale-deviation primary keeps its full column count to match the saved
        // transform) and only zero the value OFFSET so `h_base = 0`. The derivative
        // is handled separately from `inv_sigma / t` in the hazard computation, so
        // the entry/derivative designs and offsets are left as built.
        eta_offset_exit = Array1::<f64>::zeros(eval_exit.len());
    }

    let saved_timewiggle_runtime = model.saved_baseline_time_wiggle()?;

    // Build threshold + log-sigma designs from the frozen saved specs. Re-using
    // resolve_termspec_for_prediction guarantees we honor the predict-data's
    // column layout via the model's training_headers.
    // The threshold design uses the same frozen spec as the covariate design
    // already built for predict_survival; reuse it instead of rebuilding.
    let threshold_design = cov_design;
    let log_sigmaspec = resolve_termspec_for_prediction(
        &model.resolved_termspec_noise,
        training_headers,
        col_map,
        "resolved_termspec_noise",
    )?;
    let sigma_clipped = model.axis_clip_to_training_ranges(data, col_map);
    let sigma_input = sigma_clipped.as_ref().map_or(data, |arr| arr.view());
    let raw_sigma_design =
        gam_terms::smooth::build_term_collection_design(sigma_input, &log_sigmaspec)
            .map_err(|err| format!("failed to build survival log-sigma design: {err}"))?;
    let survival_noise_transform = scale_transform_from_payload(
        &model.survival_noise_projection,
        &model.survival_noise_center,
        &model.survival_noise_scale,
        model.survival_noise_non_intercept_start,
        model.survival_noise_projection_ridge_alpha,
    )?;

    // Dense exit-time design; when a baseline timewiggle was saved, append
    // all-zero placeholder columns for its coefficients so the width matches the
    // saved `beta_time`.
    let x_time_exit_dense = time_build
        .x_exit_time
        .try_to_dense_by_chunks("survival location-scale prediction time-exit design")?;
    let total_rows = eval_exit.len();
    let x_time_exit = if let Some(runtime) = saved_timewiggle_runtime.as_ref() {
        let mut full =
            Array2::<f64>::zeros((total_rows, x_time_exit_dense.ncols() + runtime.beta.len()));
        full.slice_mut(s![.., 0..x_time_exit_dense.ncols()])
            .assign(&x_time_exit_dense);
        full
    } else {
        x_time_exit_dense
    };

    // Replicate each subject's covariate row once per flattened slot (identity
    // in per-row mode) so the covariate designs line up with the time designs.
    let repeat_rows =
        |matrix: &DesignMatrix, label: &str| -> Result<DesignMatrix, SurvivalPredictError> {
            if per_row_eval {
                return Ok(matrix.clone());
            }
            let dense = matrix.try_to_dense_by_chunks(label)?;
            let mut repeated = Array2::<f64>::zeros((total_rows, dense.ncols()));
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let rows: Vec<Vec<f64>> = (0..total_rows)
                .into_par_iter()
                .map(|k| dense.row(k / eval_width).to_vec())
                .collect();
            for (k, row) in rows.into_iter().enumerate() {
                for (j, value) in row.into_iter().enumerate() {
                    repeated[[k, j]] = value;
                }
            }
            Ok(DesignMatrix::from(repeated))
        };
    let threshold_matrix = repeat_rows(
        &threshold_design.design,
        "survival location-scale prediction threshold design",
    )?;
    let raw_sigma_matrix = repeat_rows(
        &raw_sigma_design.design,
        "survival location-scale prediction log-sigma design",
    )?;

    // Scale-deviation primary mirrors the fit's full location design: `x_time_exit`
    // carries the full centered basis (plus any timewiggle columns) in every regime
    // — including the reduced parametric-AFT regime, where the basis is retained at
    // full width with an all-zero `beta_time` — matching the fit's `time_design_exit`.
    let time_design = DesignMatrix::from(x_time_exit.clone());
    let survival_primary_design =
        DesignMatrix::hstack(vec![time_design, threshold_matrix.clone()])?;
    let prepared_sigma_design = if let Some(transform) = survival_noise_transform.as_ref() {
        build_scale_deviation_operator(survival_primary_design, raw_sigma_matrix, transform)?
    } else {
        raw_sigma_matrix
    };
    let link_wiggle_knots = model
        .linkwiggle_knots
        .as_ref()
        .map(|k| Array1::from_vec(k.clone()));
    let link_wiggle_degree = model.linkwiggle_degree;
    let time_wiggle_knots = saved_timewiggle_runtime
        .as_ref()
        .map(|w| Array1::from_vec(w.knots.clone()));
    let time_wiggle_degree = saved_timewiggle_runtime.as_ref().map(|w| w.degree);
    let time_wiggle_ncols = saved_timewiggle_runtime
        .as_ref()
        .map_or(0, |w| w.beta.len());

    // Expand a per-subject vector to the flattened (subject, slot) layout.
    let expand_vector = |values: &Array1<f64>| -> Array1<f64> {
        if per_row_eval {
            values.clone()
        } else {
            Array1::from_shape_fn(total_rows, |k| values[k / eval_width])
        }
    };
    // Threshold (location) offset. In the reduced parametric-AFT regime the
    // σ-scaled `log t` baseline rides the location channel: shift the effective
    // location `η_t → η_t − log t` per query time so the predicted standardized
    // residual reproduces `u = inv_sigma·(log t − η_t) = (log t − μ)/σ`, exactly
    // as the fit's `LocationLogTimeOffset` does. `eval_exit` already carries the
    // per-(row, time) query exit times in the same flattened layout as the
    // expanded offsets; `−log t` uses the same `SURVIVAL_TIME_FLOOR` floor as the
    // fit's `checked_log_survival_times` (issue #892).
    let eta_threshold_offset = {
        let mut offset = expand_vector(primary_offset);
        if reduced_parametric_aft {
            for (slot, &t) in offset.iter_mut().zip(eval_exit.iter()) {
                *slot -= t
                    .max(crate::survival::construction::SURVIVAL_TIME_FLOOR)
                    .ln();
            }
        }
        offset
    };
    // Build the SurvivalLocationScalePredictInput once, with replicated /
    // expanded designs and offsets, regardless of `per_row_eval`.  This
    // unifies the mean-only and uncertainty paths and lets the
    // uncertainty branch reuse the same input.
    let pred_input = SurvivalLocationScalePredictInput {
        x_time_exit,
        eta_time_offset_exit: eta_offset_exit.clone(),
        time_wiggle_knots: time_wiggle_knots.clone(),
        time_wiggle_degree,
        time_wiggle_ncols,
        x_threshold: threshold_matrix,
        eta_threshold_offset,
        x_log_sigma: prepared_sigma_design,
        eta_log_sigma_offset: expand_vector(noise_offset),
        x_link_wiggle: None,
        link_wiggle_knots: link_wiggle_knots.clone(),
        link_wiggle_degree,
        inverse_link: saved_inverse_link.clone(),
    };

    // Mean / SE computation.  The uncertainty path also computes the
    // survival mean and eta, so we use whichever output we have.
    let (eta_full, survival_prob_full, response_se_full, eta_se_full): (
        Array1<f64>,
        Array1<f64>,
        Option<Array1<f64>>,
        Option<Array1<f64>>,
    ) = if with_uncertainty {
        let cov = saved_fit.beta_covariance().ok_or_else(|| {
            "survival location-scale uncertainty: saved fit is missing the \
             posterior covariance; refit with the current CLI / library to \
             populate beta_covariance"
                .to_string()
        })?;
        let unc = predict_survival_location_scalewith_uncertainty(
            &pred_input,
            &saved_fit,
            cov,
            false,
            true,
        )
        .map_err(|err| format!("survival location-scale uncertainty predict failed: {err}"))?;
        let response_se = unc.response_standard_error.ok_or_else(|| {
            "survival location-scale uncertainty: response_standard_error \
             missing despite include_response_sd=true"
                .to_string()
        })?;
        (
            unc.eta,
            unc.survival_prob,
            Some(response_se),
            Some(unc.eta_standard_error),
        )
    } else if per_row_eval {
        let pred = predict_survival_location_scale(&pred_input, &saved_fit)
            .map_err(|err| format!("survival location-scale predict failed: {err}"))?;
        (pred.eta, pred.survival_prob, None, None)
    } else {
        let beta_threshold = saved_fit.beta_threshold();
        let beta_log_sigma = saved_fit.beta_log_sigma();
        let eta_t_subject =
            cov_design.design.matrixvectormultiply(&beta_threshold) + primary_offset;
        // `expand_vector(noise_offset)` already lives on `pred_input` as
        // `eta_log_sigma_offset`; reuse it instead of re-expanding the noise
        // offset (a per-call allocation when the time grid is explicit).
        // Despite the `_subject` suffix this has one entry per flattened slot: it
        // is summed with the already-expanded `eta_log_sigma_offset`.
        let eta_ls_subject = prepared_sigma_design_view(&pred_input)
            .matrixvectormultiply(&beta_log_sigma)
            + &pred_input.eta_log_sigma_offset;
        // This explicit-grid branch rebuilds the per-(row, time) location predictor
        // from the per-subject `eta_t_subject` directly (bypassing
        // `pred_input.eta_threshold_offset`), so it must apply the reduced-AFT
        // `−log t` location shift here too — otherwise the grid path would predict a
        // `log t`-flat surface even though the per-row path is shifted (issue #892).
        let mut eta_t = expand_vector(&eta_t_subject);
        if reduced_parametric_aft {
            for (slot, &t) in eta_t.iter_mut().zip(eval_exit.iter()) {
                *slot -= t
                    .max(crate::survival::construction::SURVIVAL_TIME_FLOOR)
                    .ln();
            }
        }
        let pred = predict_survival_location_scale_from_linear_components(
            &pred_input.x_time_exit,
            &eta_offset_exit,
            time_wiggle_knots.as_ref(),
            time_wiggle_degree,
            time_wiggle_ncols,
            &eta_t,
            &eta_ls_subject,
            link_wiggle_knots.as_ref(),
            link_wiggle_degree,
            &saved_inverse_link,
            &saved_fit,
        )
        .map_err(|err| format!("survival location-scale predict failed: {err}"))?;
        (pred.eta, pred.survival_prob, None, None)
    };

    let eta_derivative_full = if reduced_parametric_aft {
        // Reduced-AFT regime: the warp is `h ≡ 0` and the location carries the
        // `−log t` shift, so the standardized-residual time derivative is
        // `du/dt = d/dt[inv_sigma·(log t − μ)] = inv_sigma / t` (the fit's
        // `qdot = inv_sigma/t`). The time-warp design contributes nothing here
        // (its `beta_time` is all zero), so reconstruct `eta_derivative` directly
        // from `inv_sigma` and the floored query times rather than from the
        // time-derivative design and its offsets (issue #892).
        use crate::sigma_link::exp_sigma_inverse_from_eta_scalar;
        let beta_log_sigma = saved_fit.beta_log_sigma();
        let eta_ls = prepared_sigma_design_view(&pred_input).matrixvectormultiply(&beta_log_sigma)
            + &pred_input.eta_log_sigma_offset;
        let mut deriv = Array1::<f64>::zeros(eval_exit.len());
        for (k, slot) in deriv.iter_mut().enumerate() {
            let inv_sigma = exp_sigma_inverse_from_eta_scalar(eta_ls[k]);
            let t = eval_exit[k].max(crate::survival::construction::SURVIVAL_TIME_FLOOR);
            *slot = inv_sigma / t;
        }
        deriv
    } else {
        let x_time_derivative = time_build
            .x_derivative_time
            .try_to_dense_by_chunks("survival location-scale prediction time-derivative design")?;
        location_scale_eta_derivative_components(
            &x_time_derivative,
            &derivative_offset_exit,
            &pred_input.x_time_exit,
            &pred_input.eta_time_offset_exit,
            time_wiggle_knots.as_ref(),
            time_wiggle_degree,
            time_wiggle_ncols,
            &saved_fit,
        )?
    };
    // NOTE(review): `hazard_full` (and, on the non-reduced path, the
    // eta-derivative validation above, which rejects any non-finite or
    // non-positive entry) covers EVERY flattened slot. That includes slots whose
    // query time is <= 0 and each subject's trailing own-`age_exit` slot, even
    // though the origin anchor below discards those hazard values. If an origin
    // slot yields a non-finite eta / eta_t, the whole batch errors before the
    // anchor is reached. Confirm that upstream time flooring keeps those slots finite.
    let hazard_full = location_scale_hazard_from_eta_derivative(
        &eta_full,
        &eta_derivative_full,
        &saved_inverse_link,
    )?;

    let mut survival = Array2::<f64>::zeros((n, t_cols));
    let mut cumulative_hazard = Array2::<f64>::zeros((n, t_cols));
    let mut hazard = Array2::<f64>::zeros((n, t_cols));
    ndarray::Zip::indexed(&mut survival)
        .and(&mut cumulative_hazard)
        .and(&mut hazard)
        .par_for_each(|(i, j), s, ch, h| {
            // Survival-curve origin: at t = 0 everyone is still at risk, so
            // S(0) = 1, H(0) = 0 and h(0) = 0 exactly, independent of the
            // fitted baseline. Anchor the origin column directly instead of
            // routing it through the (probit-survival) baseline, whose index is
            // -inf at S0(0) = 1. This matches the transformation / marginal-slope
            // predict path's `t <= 0` handling and keeps the default surface grid
            // — whose first node is the origin for the `Surv(time, event)`
            // right-censored shorthand — evaluable end to end (#1024).
            let query_time = if per_row_eval {
                age_exit[i]
            } else {
                eval_times[j]
            };
            if query_time <= 0.0 {
                *s = 1.0;
                *ch = 0.0;
                *h = 0.0;
                return;
            }
            // Survival is clamped away from 0 before the log so H = -ln S stays finite.
            let k = if per_row_eval { i } else { i * eval_width + j };
            let surv = survival_prob_full[k].clamp(SURVIVAL_PROB_MIN_FOR_LOG, 1.0);
            *s = surv;
            *ch = -surv.ln();
            *h = hazard_full[k];
        });

    // The linear predictor is reported at each subject's own exit time: the
    // trailing slot (index `t_cols`) in grid mode.
    let linear_predictor = if per_row_eval {
        eta_full.clone()
    } else {
        Array1::from_shape_fn(n, |i| eta_full[i * eval_width + t_cols])
    };
    let times = if per_row_eval {
        age_exit.to_vec()
    } else {
        // Cloned (not moved) so the origin-column anchor below can still read the
        // per-column query times when assembling the survival standard errors.
        eval_times.clone()
    };

    let survival_se = response_se_full.as_ref().map(|response_se| {
        let mut out = Array2::<f64>::zeros((n, t_cols));
        ndarray::Zip::indexed(&mut out).par_for_each(|(i, j), slot| {
            // S(0) = 1 is a deterministic identity, so its standard error is 0
            // at the origin column (consistent with the anchored survival above).
            let query_time = if per_row_eval {
                age_exit[i]
            } else {
                eval_times[j]
            };
            if query_time <= 0.0 {
                *slot = 0.0;
                return;
            }
            let k = if per_row_eval { i } else { i * eval_width + j };
            *slot = response_se[k].max(0.0);
        });
        out
    });
    let eta_se_per_row = eta_se_full.as_ref().map(|eta_se| {
        if per_row_eval {
            eta_se.clone()
        } else {
            Array1::from_shape_fn(n, |i| eta_se[i * eval_width + t_cols])
        }
    });

    Ok(SurvivalPredictResult {
        times,
        hazard,
        survival,
        cumulative_hazard,
        linear_predictor,
        likelihood_mode: saved_likelihood_mode,
        survival_se,
        eta_se: eta_se_per_row,
    })
}
2578
2579/// Helper: borrow the prepared sigma design back from the pred_input
2580/// without consuming it.  Used so the mean-only fast path can reuse the
2581/// log-sigma design without an extra clone.
2582fn prepared_sigma_design_view(
2583    input: &crate::survival::location_scale::SurvivalLocationScalePredictInput,
2584) -> &gam_linalg::matrix::DesignMatrix {
2585    &input.x_log_sigma
2586}
2587
2588pub(crate) struct LocationScaleEtaComponents {
2589    pub h: Array1<f64>,
2590    pub time_jac: Array2<f64>,
2591    pub eta_t: Array1<f64>,
2592    pub eta_ls: Array1<f64>,
2593    pub inv_sigma: Array1<f64>,
2594}
2595
2596pub(crate) struct LocationScaleTimeWarpComponents {
2597    pub(crate) h: Array1<f64>,
2598    pub(crate) time_jac: Array2<f64>,
2599    pub(crate) time_wiggle_dq: Option<Array1<f64>>,
2600}
2601
2602pub(crate) fn location_scale_time_warp_components(
2603    x_time_exit: &Array2<f64>,
2604    eta_time_offset_exit: &Array1<f64>,
2605    time_wiggle_knots: Option<&Array1<f64>>,
2606    time_wiggle_degree: Option<usize>,
2607    time_wiggle_ncols: usize,
2608    fit: &UnifiedFitResult,
2609) -> Result<LocationScaleTimeWarpComponents, String> {
2610    let n = x_time_exit.nrows();
2611    if eta_time_offset_exit.len() != n {
2612        return Err("survival location-scale time-warp row mismatch across inputs".to_string());
2613    }
2614    let beta_time = fit.beta_time();
2615    if x_time_exit.ncols() != beta_time.len() {
2616        return Err(format!(
2617            "survival location-scale time-warp design mismatch: x_exit={} beta_time={}",
2618            x_time_exit.ncols(),
2619            beta_time.len()
2620        ));
2621    }
2622
2623    let p_time_total = beta_time.len();
2624    let p_wiggle = time_wiggle_ncols.min(p_time_total);
2625    let p_base = p_time_total - p_wiggle;
2626    let beta_base = beta_time.slice(s![..p_base]).to_owned();
2627    let h_base = if p_base > 0 {
2628        x_time_exit.slice(s![.., ..p_base]).dot(&beta_base) + eta_time_offset_exit
2629    } else {
2630        eta_time_offset_exit.clone()
2631    };
2632    let mut h = h_base.clone();
2633    let mut time_jac = x_time_exit.clone();
2634    let mut time_wiggle_dq = None;
2635    if p_wiggle > 0 {
2636        if x_time_exit
2637            .slice(s![.., p_base..p_time_total])
2638            .iter()
2639            .any(|&value| value != 0.0)
2640        {
2641            return Err(
2642                "survival location-scale timewiggle prediction requires zero placeholder tail columns"
2643                    .to_string(),
2644            );
2645        }
2646        let knots = time_wiggle_knots.ok_or_else(|| {
2647            "survival location-scale time-warp: timewiggle coefficients are missing knot metadata"
2648                .to_string()
2649        })?;
2650        let degree = time_wiggle_degree.ok_or_else(|| {
2651            "survival location-scale time-warp: timewiggle coefficients are missing degree metadata"
2652                .to_string()
2653        })?;
2654        let beta_w = beta_time.slice(s![p_base..p_time_total]).to_owned();
2655        let time_basis = crate::wiggle::monotone_wiggle_basis_with_derivative_order(
2656            h_base.view(),
2657            knots,
2658            degree,
2659            0,
2660        )?;
2661        let time_basis_d1 = crate::wiggle::monotone_wiggle_basis_with_derivative_order(
2662            h_base.view(),
2663            knots,
2664            degree,
2665            1,
2666        )?;
2667        if time_basis.ncols() != p_wiggle || time_basis_d1.ncols() != p_wiggle {
2668            return Err(format!(
2669                "survival location-scale time-warp timewiggle mismatch: value basis has {} columns, derivative basis has {}, beta has {}",
2670                time_basis.ncols(),
2671                time_basis_d1.ncols(),
2672                p_wiggle
2673            ));
2674        }
2675        let dq = time_basis_d1.dot(&beta_w) + 1.0;
2676        h = &h_base + &time_basis.dot(&beta_w);
2677        time_jac = Array2::<f64>::zeros((n, p_time_total));
2678        if p_base > 0 {
2679            let scaled_base = crate::survival::location_scale::scale_dense_rows(
2680                &x_time_exit.slice(s![.., ..p_base]).to_owned(),
2681                &dq,
2682            )?;
2683            time_jac.slice_mut(s![.., ..p_base]).assign(&scaled_base);
2684        }
2685        time_jac
2686            .slice_mut(s![.., p_base..p_time_total])
2687            .assign(&time_basis);
2688        time_wiggle_dq = Some(dq);
2689    }
2690
2691    Ok(LocationScaleTimeWarpComponents {
2692        h,
2693        time_jac,
2694        time_wiggle_dq,
2695    })
2696}
2697
2698pub(crate) fn location_scale_eta_components(
2699    x_time_exit: &Array2<f64>,
2700    eta_time_offset_exit: &Array1<f64>,
2701    time_wiggle_knots: Option<&Array1<f64>>,
2702    time_wiggle_degree: Option<usize>,
2703    time_wiggle_ncols: usize,
2704    x_threshold: &gam_linalg::matrix::DesignMatrix,
2705    eta_threshold_offset: &Array1<f64>,
2706    x_log_sigma: &gam_linalg::matrix::DesignMatrix,
2707    eta_log_sigma_offset: &Array1<f64>,
2708    fit: &UnifiedFitResult,
2709) -> Result<LocationScaleEtaComponents, String> {
2710    let n = x_time_exit.nrows();
2711    if x_threshold.nrows() != n
2712        || eta_threshold_offset.len() != n
2713        || x_log_sigma.nrows() != n
2714        || eta_log_sigma_offset.len() != n
2715    {
2716        return Err("survival location-scale eta component row mismatch across inputs".to_string());
2717    }
2718    let time_components = location_scale_time_warp_components(
2719        x_time_exit,
2720        eta_time_offset_exit,
2721        time_wiggle_knots,
2722        time_wiggle_degree,
2723        time_wiggle_ncols,
2724        fit,
2725    )?;
2726    let beta_threshold = fit.beta_threshold();
2727    let beta_log_sigma = fit.beta_log_sigma();
2728    let eta_t = x_threshold.matrixvectormultiply(&beta_threshold) + eta_threshold_offset;
2729    let eta_ls = x_log_sigma.matrixvectormultiply(&beta_log_sigma) + eta_log_sigma_offset;
2730    let inv_sigma = eta_ls.mapv(crate::sigma_link::exp_sigma_inverse_from_eta_scalar);
2731    Ok(LocationScaleEtaComponents {
2732        h: time_components.h,
2733        time_jac: time_components.time_jac,
2734        eta_t,
2735        eta_ls,
2736        inv_sigma,
2737    })
2738}
2739
2740fn location_scale_eta_derivative_components(
2741    x_time_derivative: &Array2<f64>,
2742    derivative_offset_exit: &Array1<f64>,
2743    x_time_exit: &Array2<f64>,
2744    eta_time_offset_exit: &Array1<f64>,
2745    time_wiggle_knots: Option<&Array1<f64>>,
2746    time_wiggle_degree: Option<usize>,
2747    time_wiggle_ncols: usize,
2748    fit: &UnifiedFitResult,
2749) -> Result<Array1<f64>, String> {
2750    let n = x_time_exit.nrows();
2751    if x_time_derivative.nrows() != n
2752        || derivative_offset_exit.len() != n
2753        || eta_time_offset_exit.len() != n
2754    {
2755        return Err(
2756            "survival location-scale hazard derivative row mismatch across inputs".to_string(),
2757        );
2758    }
2759    let beta_time = fit.beta_time();
2760    let p_time_total = beta_time.len();
2761    let p_wiggle = time_wiggle_ncols.min(p_time_total);
2762    let p_base = p_time_total - p_wiggle;
2763    if x_time_exit.ncols() != p_time_total || x_time_derivative.ncols() != p_base {
2764        return Err(format!(
2765            "survival location-scale hazard derivative design mismatch: x_exit={} beta_time={} x_derivative={} base={}",
2766            x_time_exit.ncols(),
2767            p_time_total,
2768            x_time_derivative.ncols(),
2769            p_base
2770        ));
2771    }
2772
2773    let time_components = location_scale_time_warp_components(
2774        x_time_exit,
2775        eta_time_offset_exit,
2776        time_wiggle_knots,
2777        time_wiggle_degree,
2778        time_wiggle_ncols,
2779        fit,
2780    )?;
2781    let beta_base = beta_time.slice(s![..p_base]).to_owned();
2782    let mut eta_derivative = if p_base > 0 {
2783        x_time_derivative.dot(&beta_base) + derivative_offset_exit
2784    } else {
2785        derivative_offset_exit.clone()
2786    };
2787    if let Some(dq) = time_components.time_wiggle_dq.as_ref() {
2788        eta_derivative *= dq;
2789    }
2790    if eta_derivative
2791        .iter()
2792        .any(|value| !(value.is_finite() && *value > 0.0))
2793    {
2794        return Err(
2795            "survival location-scale hazard derivative must be finite and positive".to_string(),
2796        );
2797    }
2798    Ok(eta_derivative)
2799}
2800
2801fn location_scale_hazard_from_eta_derivative(
2802    eta: &Array1<f64>,
2803    eta_derivative: &Array1<f64>,
2804    inverse_link: &InverseLink,
2805) -> Result<Array1<f64>, String> {
2806    if eta.len() != eta_derivative.len() {
2807        return Err(format!(
2808            "survival location-scale hazard row mismatch: eta={} eta_derivative={}",
2809            eta.len(),
2810            eta_derivative.len()
2811        ));
2812    }
2813    let values = eta
2814        .iter()
2815        .zip(eta_derivative.iter())
2816        .map(|(&q, &q_t)| location_scale_hazard_component(q, q_t, inverse_link))
2817        .collect::<Result<Vec<_>, _>>()?;
2818    Ok(Array1::from_vec(values))
2819}
2820
2821fn location_scale_hazard_component(
2822    eta: f64,
2823    eta_derivative: f64,
2824    inverse_link: &InverseLink,
2825) -> Result<f64, String> {
2826    if !(eta.is_finite() && eta_derivative.is_finite() && eta_derivative > 0.0) {
2827        return Err(format!(
2828            "survival location-scale hazard requires finite eta and positive eta_t, got eta={eta}, eta_t={eta_derivative}"
2829        ));
2830    }
2831    match inverse_link {
2832        InverseLink::Standard(StandardLink::Probit) => {
2833            let (_, hazard) = probit_survival_hazard_components(eta, eta_derivative)?;
2834            Ok(hazard)
2835        }
2836        InverseLink::Standard(StandardLink::CLogLog) => {
2837            let (_, hazard) = royston_parmar_survival_hazard_components(eta, eta_derivative)?;
2838            Ok(hazard)
2839        }
2840        InverseLink::Standard(StandardLink::Logit) => {
2841            let failure = if eta >= 0.0 {
2842                1.0 / (1.0 + (-eta).exp())
2843            } else {
2844                let exp_eta = eta.exp();
2845                exp_eta / (1.0 + exp_eta)
2846            };
2847            Ok(failure * eta_derivative)
2848        }
2849        InverseLink::Standard(StandardLink::Identity) => {
2850            let survival = 1.0 - eta;
2851            if !(survival.is_finite() && survival > 0.0) {
2852                return Err(format!(
2853                    "survival location-scale identity link produced invalid survival={survival} at eta={eta}"
2854                ));
2855            }
2856            Ok(eta_derivative / survival)
2857        }
2858        _ => {
2859            let jet = inverse_link_jet_for_inverse_link(inverse_link, eta)
2860                .map_err(|err| format!("survival location-scale inverse-link jet failed: {err}"))?;
2861            let survival = 1.0 - jet.mu;
2862            let hazard = jet.d1 * eta_derivative / survival;
2863            if !(survival.is_finite() && survival > 0.0 && hazard.is_finite() && hazard >= 0.0) {
2864                return Err(format!(
2865                    "survival location-scale inverse link produced invalid hazard components: eta={eta}, eta_t={eta_derivative}, failure={}, d_failure={}, survival={survival}, hazard={hazard}",
2866                    jet.mu, jet.d1
2867                ));
2868            }
2869            Ok(hazard)
2870        }
2871    }
2872}
2873
2874// ---------------------------------------------------------------------------
2875// Shared library helpers (used by the CLI wrapper too).
2876// ---------------------------------------------------------------------------
2877
2878/// Extract the saved survival likelihood mode from the model payload.
2879pub fn require_saved_survival_likelihood_mode(
2880    model: &SavedModel,
2881) -> Result<SurvivalLikelihoodMode, SurvivalPredictError> {
2882    if matches!(&model.family_state, FittedFamily::LatentSurvival { .. }) {
2883        return match model.survival_likelihood.as_deref() {
2884            Some("latent") => Ok(SurvivalLikelihoodMode::Latent),
2885            Some(other) => Err(SurvivalPredictError::MissingFitMetadata { reason: format!(
2886                "saved latent survival model has contradictory survival_likelihood metadata: expected 'latent', got '{other}'"
2887            ) }),
2888            None => Err(SurvivalPredictError::MissingFitMetadata {
2889                reason:
2890                    "saved latent survival model is missing survival_likelihood=latent metadata; refit"
2891                        .to_string(),
2892            }),
2893        };
2894    }
2895    if matches!(&model.family_state, FittedFamily::LatentBinary { .. }) {
2896        return match model.survival_likelihood.as_deref() {
2897            Some("latent-binary") => Ok(SurvivalLikelihoodMode::LatentBinary),
2898            Some(other) => Err(SurvivalPredictError::MissingFitMetadata { reason: format!(
2899                "saved latent binary model has contradictory survival_likelihood metadata: expected 'latent-binary', got '{other}'"
2900            ) }),
2901            None => Err(SurvivalPredictError::MissingFitMetadata {
2902                reason:
2903                    "saved latent binary model is missing survival_likelihood=latent-binary metadata; refit"
2904                        .to_string(),
2905            }),
2906        };
2907    }
2908    let raw = model.survival_likelihood.as_deref().ok_or_else(|| {
2909        "saved survival model is missing survival_likelihood metadata; refit".to_string()
2910    })?;
2911    parse_survival_likelihood_mode(raw).map_err(SurvivalPredictError::from)
2912}
2913
2914/// Baseline config persisted by the saved survival model.
2915pub fn saved_survival_runtime_baseline_config(
2916    model: &SavedModel,
2917) -> Result<SurvivalBaselineConfig, SurvivalPredictError> {
2918    survival_baseline_config_from_model(model).map_err(SurvivalPredictError::from)
2919}
2920
2921/// Resolve the covariate `TermCollectionSpec` for prediction, remapping
2922/// saved training-column indices onto the runtime dataset's layout.
2923pub fn resolve_termspec_for_prediction(
2924    modelspec: &Option<TermCollectionSpec>,
2925    training_headers: Option<&Vec<String>>,
2926    col_map: &HashMap<String, usize>,
2927    spec_label: &str,
2928) -> Result<TermCollectionSpec, SurvivalPredictError> {
2929    let saved = modelspec.as_ref().ok_or_else(|| {
2930        format!(
2931            "model is missing {spec_label}; refit to guarantee train/predict design consistency"
2932        )
2933    })?;
2934    saved.validate_frozen(spec_label)?;
2935    let headers = training_headers.ok_or_else(|| {
2936        "model is missing training_headers; refit to guarantee stable feature mapping at prediction time"
2937            .to_string()
2938    })?;
2939    let remapped = remap_term_collectionspec_columns(saved, headers, col_map)?;
2940    remapped.validate_frozen(spec_label)?;
2941    Ok(remapped)
2942}
2943
2944fn remap_term_collectionspec_columns(
2945    spec: &TermCollectionSpec,
2946    training_headers: &[String],
2947    prediction_column_map: &HashMap<String, usize>,
2948) -> Result<TermCollectionSpec, SurvivalPredictError> {
2949    // Delegate the (variant-exhaustive, easy-to-miss-a-field) walk to the
2950    // single shared authority on TermCollectionSpec; supply the survival
2951    // train→predict resolution as the per-index remap closure.
2952    spec.remap_feature_columns(|index| -> Result<usize, SurvivalPredictError> {
2953        let name = training_headers
2954            .get(index)
2955            .ok_or_else(|| format!("saved training column index {index} is out of bounds"))?;
2956        resolve_role_col(prediction_column_map, name, "prediction")
2957            .map_err(SurvivalPredictError::from)
2958    })
2959}
2960
2961/// Canonical saved fit result for prediction.
2962pub fn fit_result_from_saved_model_for_prediction(
2963    model: &SavedModel,
2964) -> Result<UnifiedFitResult, String> {
2965    model
2966        .fit_result
2967        .clone()
2968        .ok_or_else(|| "model is missing canonical fit_result payload; refit".to_string())
2969}
2970
2971/// Resolve the saved survival location-scale fit result.
2972///
2973/// Returns a `UnifiedFitResult` with the fitted inverse-link state
2974/// re-applied -- matching the CLI's behaviour in
2975/// `main.rs::saved_survival_location_scale_fit_result`.
2976pub fn saved_survival_location_scale_fit_result(
2977    model: &SavedModel,
2978) -> Result<UnifiedFitResult, SurvivalPredictError> {
2979    model.saved_prediction_runtime()?;
2980    let mut fit = model.fit_result.clone().ok_or_else(|| {
2981        "saved location-scale survival model missing canonical fit_result; refit".to_string()
2982    })?;
2983    let inverse_link = resolve_survival_inverse_link_from_saved(model)?;
2984    apply_inverse_link_state_to_fit_result(&mut fit, &inverse_link);
2985    Ok(fit)
2986}
2987
2988pub fn apply_inverse_link_state_to_fit_result(
2989    fit_result: &mut UnifiedFitResult,
2990    inverse_link: &InverseLink,
2991) {
2992    fit_result.fitted_link = match inverse_link {
2993        InverseLink::LatentCLogLog(state) => FittedLinkState::LatentCLogLog { state: *state },
2994        InverseLink::Sas(state) => FittedLinkState::Sas {
2995            state: *state,
2996            covariance: None,
2997        },
2998        InverseLink::BetaLogistic(state) => FittedLinkState::BetaLogistic {
2999            state: *state,
3000            covariance: None,
3001        },
3002        InverseLink::Mixture(state) => FittedLinkState::Mixture {
3003            state: state.clone(),
3004            covariance: None,
3005        },
3006        InverseLink::Standard(_) => FittedLinkState::Standard(None),
3007    };
3008}
3009
3010/// Resolve the saved survival inverse-link from saved link metadata and fitted
3011/// state.
3012pub fn resolve_survival_inverse_link_from_saved(
3013    model: &SavedModel,
3014) -> Result<InverseLink, SurvivalPredictError> {
3015    if let Some(link) = model.link.as_ref() {
3016        return Ok(link.clone());
3017    }
3018    Err(SurvivalPredictError::MissingFitMetadata {
3019        reason: "saved survival model is missing link metadata; refit".to_string(),
3020    })
3021}
3022
3023/// Concatenate referenced 1-D arrays into a single owned `Array1<f64>`.
3024pub fn concat_array1_refs(parts: &[&Array1<f64>]) -> Array1<f64> {
3025    let total: usize = parts.iter().map(|part| part.len()).sum();
3026    let mut out = Array1::<f64>::zeros(total);
3027    let mut offset = 0usize;
3028    for part in parts {
3029        let width = part.len();
3030        out.slice_mut(s![offset..offset + width]).assign(part);
3031        offset += width;
3032    }
3033    out
3034}
3035
3036/// Rebuild the saved baseline-timewiggle entry/exit/derivative design blocks
3037/// from the saved runtime metadata. Returns `None` when the saved model has no
3038/// baseline-timewiggle.
3039pub fn saved_baseline_timewiggle_components(
3040    eta_entry: &Array1<f64>,
3041    eta_exit: &Array1<f64>,
3042    derivative_exit: &Array1<f64>,
3043    model: &SavedModel,
3044) -> Result<Option<(Array2<f64>, Array2<f64>, Array2<f64>)>, SurvivalPredictError> {
3045    match model.saved_baseline_time_wiggle()? {
3046        None => Ok(None),
3047        Some(runtime) => {
3048            runtime.validate_global_monotonicity()?;
3049            let SavedBaselineTimeWiggleRuntime {
3050                knots,
3051                degree,
3052                beta,
3053                ..
3054            } = runtime;
3055            let knots = Array1::from_vec(knots);
3056            let entry = match buildwiggle_block_input_from_knots(
3057                eta_entry.view(),
3058                &knots,
3059                degree,
3060                2,
3061                false,
3062            )?
3063            .design
3064            {
3065                DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
3066                _ => {
3067                    return Err(SurvivalPredictError::IncompatibleSchema {
3068                        reason: "saved baseline-timewiggle entry design must be dense".to_string(),
3069                    });
3070                }
3071            };
3072            let exit = match buildwiggle_block_input_from_knots(
3073                eta_exit.view(),
3074                &knots,
3075                degree,
3076                2,
3077                false,
3078            )?
3079            .design
3080            {
3081                DesignMatrix::Dense(m) => m.to_dense_arc().as_ref().clone(),
3082                _ => {
3083                    return Err(SurvivalPredictError::IncompatibleSchema {
3084                        reason: "saved baseline-timewiggle exit design must be dense".to_string(),
3085                    });
3086                }
3087            };
3088            let betaw = beta;
3089            if entry.ncols() != betaw.len() || exit.ncols() != betaw.len() {
3090                return Err(SurvivalPredictError::IncompatibleSchema {
3091                    reason: format!(
3092                        "saved baseline-timewiggle dimension mismatch: coefficients have {} entries but basis has entry={} exit={}",
3093                        betaw.len(),
3094                        entry.ncols(),
3095                        exit.ncols()
3096                    ),
3097                });
3098            }
3099            let derivative = build_survival_timewiggle_derivative_design(
3100                eta_exit,
3101                derivative_exit,
3102                &knots,
3103                degree,
3104            )
3105            .map_err(|e| {
3106                e.replace(
3107                    "build baseline-timewiggle",
3108                    "evaluate saved baseline-timewiggle",
3109                )
3110            })?;
3111            if derivative.ncols() != betaw.len() {
3112                return Err(SurvivalPredictError::IncompatibleSchema {
3113                    reason: format!(
3114                        "saved baseline-timewiggle derivative dimension mismatch: coefficients have {} entries but derivative basis has {} columns",
3115                        betaw.len(),
3116                        derivative.ncols()
3117                    ),
3118                });
3119            }
3120            Ok(Some((entry, exit, derivative)))
3121        }
3122    }
3123}
3124
/// Build the saved survival marginal-slope predictor along with the matching
/// `PredictInput` and a `UnifiedFitResult` repackaged into the layout
/// `BernoulliMarginalSlopePredictor::from_unified` expects.
///
/// This is the single source of truth for assembling the marginal-slope
/// predictor at predict time. The CLI's `gam predict` flow and the
/// library-side `predict_survival` both call into this helper so they share
/// bit-identical eta math (link-deviation + score-warp replay included).
///
/// # Saved block layout
///
/// `fit_saved.blocks` must be, in order:
/// `[time, marginal, slope, score-warp?, link-deviation?, influence-absorber?]`.
/// Each optional block is present iff the saved runtime carries the
/// corresponding entry. The time block's coefficients are laid out as
/// `[base time basis | baseline-timewiggle?]`.
///
/// # Returned values
///
/// - The predictor built by `BernoulliMarginalSlopePredictor::from_unified`.
/// - A `PredictInput` with:
///   - location design `[x_exit_time | timewiggle exit? | cov_design]`;
///   - offset `eta_offset_exit + primary_offset`;
///   - `logslope_design` with `noise_offset` for the noise channel;
///   - `z` as the auxiliary scalar.
/// - The repackaged fit. Its blocks are
///   `[time+marginal (Mean), slope (Scale), score-warp? (Mean), link-deviation? (LinkWiggle)]`,
///   with `beta` re-concatenated from those blocks and `block_states` cleared.
///   The influence-absorber block, if saved, is not carried over.
///
/// # Errors
///
/// Returns `SurvivalPredictError` when any of these happens:
/// - the saved runtime carries legacy link-wiggle metadata;
/// - the saved block count or any coefficient width disagrees with the saved
///   runtime or the supplied designs;
/// - the rebuilt timewiggle design width disagrees with the saved
///   coefficients;
/// - `latent_z_normalization`, `latent_measure` or `logslope_baseline` is
///   missing.
///
/// It also propagates errors from `saved_prediction_runtime`,
/// `saved_baseline_timewiggle_components`, `DesignMatrix::hstack`,
/// `resolved_inverse_link` and `from_unified`.
pub fn build_saved_survival_marginal_slope_predictor(
    model: &SavedModel,
    fit_saved: &UnifiedFitResult,
    z_name: &str,
    z: &Array1<f64>,
    cov_design: &DesignMatrix,
    logslope_design: &DesignMatrix,
    time_build: &SurvivalTimeBuildOutput,
    eta_offset_entry: &Array1<f64>,
    eta_offset_exit: &Array1<f64>,
    derivative_offset_exit: &Array1<f64>,
    primary_offset: &Array1<f64>,
    noise_offset: &Array1<f64>,
) -> Result<
    (
        BernoulliMarginalSlopePredictor,
        PredictInput,
        UnifiedFitResult,
    ),
    SurvivalPredictError,
> {
    let saved_runtime = model.saved_prediction_runtime()?;
    // Legacy link-wiggle payloads are rejected outright; the message directs
    // the user to refit with the anchored link-deviation runtime.
    if saved_runtime.link_wiggle.is_some() {
        return Err(SurvivalPredictError::MissingFitMetadata {
            reason:
                "saved survival marginal-slope model contains legacy linkwiggle metadata; refit with the anchored link-deviation runtime"
                    .to_string(),
        });
    }

    let saved_score_runtime = saved_runtime.score_warp;
    let saved_link_runtime = saved_runtime.link_deviation;
    // #461: the absorbed Stage-1 influence block (when present) is the trailing
    // block. Its `γ` is DROPPED at predict (the orthogonalized β̂ is a
    // training-fit property), so it is NOT read below — but it IS persisted, so
    // the saved block count includes it.
    let influence_absorber_width = saved_runtime.influence_absorber_width;
    let blocks = &fit_saved.blocks;
    // Saved layout: [time, marginal, slope, score-warp?, link-deviation?, influence-absorber?].
    let expected_blocks = 3
        + usize::from(saved_score_runtime.is_some())
        + usize::from(saved_link_runtime.is_some())
        + usize::from(influence_absorber_width.is_some());
    if blocks.len() != expected_blocks {
        return Err(SurvivalPredictError::IncompatibleSchema {
            reason: format!(
                "saved survival marginal-slope model requires {} blocks [time, marginal, slope{}{}{}], got {}",
                expected_blocks,
                if saved_score_runtime.is_some() {
                    ", score-warp"
                } else {
                    ""
                },
                if saved_link_runtime.is_some() {
                    ", link-deviation"
                } else {
                    ""
                },
                if influence_absorber_width.is_some() {
                    ", influence-absorber(dropped)"
                } else {
                    ""
                },
                blocks.len(),
            ),
        });
    }

    // The block count was validated above, so the fixed indices below are in range.
    let beta_time = &blocks[0].beta;
    let beta_marginal = &blocks[1].beta;
    let beta_logslope = &blocks[2].beta;
    if let Some(runtime) = saved_score_runtime.as_ref() {
        let beta = &blocks[3].beta;
        if beta.len() != runtime.basis_dim {
            return Err(SurvivalPredictError::IncompatibleSchema {
                reason: format!(
                    "saved survival marginal-slope score-warp coefficient mismatch: beta has {} entries but runtime expects {}",
                    beta.len(),
                    runtime.basis_dim
                ),
            });
        }
    }
    if let Some(runtime) = saved_link_runtime.as_ref() {
        // Link-deviation sits right after score-warp when both are present.
        let idx = 3 + usize::from(saved_score_runtime.is_some());
        let beta = &blocks[idx].beta;
        if beta.len() != runtime.basis_dim {
            return Err(SurvivalPredictError::IncompatibleSchema {
                reason: format!(
                    "saved survival marginal-slope link-deviation coefficient mismatch: beta has {} entries but runtime expects {}",
                    beta.len(),
                    runtime.basis_dim
                ),
            });
        }
    }

    if beta_marginal.len() != cov_design.ncols() {
        return Err(SurvivalPredictError::IncompatibleSchema {
            reason: format!(
                "saved survival marginal-slope marginal coefficient mismatch: beta has {} entries but baseline design has {} columns",
                beta_marginal.len(),
                cov_design.ncols()
            ),
        });
    }
    if beta_logslope.len() != logslope_design.ncols() {
        return Err(SurvivalPredictError::IncompatibleSchema {
            reason: format!(
                "saved survival marginal-slope slope coefficient mismatch: beta has {} entries but slope design has {} columns",
                beta_logslope.len(),
                logslope_design.ncols()
            ),
        });
    }

    let p_time_base = time_build.x_exit_time.ncols();
    let saved_timewiggle = saved_runtime.baseline_time_wiggle;
    let p_timewiggle = saved_timewiggle
        .as_ref()
        .map_or(0, |runtime| runtime.beta.len());
    // Saved time coefficients are laid out as [base time basis | baseline-timewiggle].
    if beta_time.len() != p_time_base + p_timewiggle {
        return Err(SurvivalPredictError::IncompatibleSchema {
            reason: format!(
                "saved survival marginal-slope time coefficient mismatch: beta has {} entries but expected base={} plus timewiggle={}",
                beta_time.len(),
                p_time_base,
                p_timewiggle
            ),
        });
    }

    let beta_time_base = beta_time.slice(s![..p_time_base]).to_owned();
    // `cov_design · beta_marginal` is row-only (no time dependence); hoist it
    // once so both the entry- and exit-time baselines share the single
    // matrix-vector multiply instead of recomputing it.
    let cov_eta_marginal = cov_design.dot(beta_marginal);
    // Index values (and exit derivative) built from the base time coefficients
    // only, i.e. without any timewiggle contribution. The saved timewiggle
    // basis is evaluated at these values below.
    let q_entry_base = time_build.x_entry_time.dot(&beta_time_base)
        + &cov_eta_marginal
        + eta_offset_entry
        + primary_offset;
    let q_exit_base = time_build.x_exit_time.dot(&beta_time_base)
        + &cov_eta_marginal
        + eta_offset_exit
        + primary_offset;
    let qd_exit_base = time_build.x_derivative_time.dot(&beta_time_base) + derivative_offset_exit;

    // Location design columns: [x_exit_time | timewiggle exit? | cov_design].
    // This matches the coefficient order [beta_time, beta_marginal] used below.
    let mut q_design_parts = vec![time_build.x_exit_time.clone()];
    if saved_timewiggle.is_some() {
        // Only the exit-time wiggle design is used; entry and derivative designs are discarded.
        let (_, exit_w, _) = saved_baseline_timewiggle_components(
            &q_entry_base,
            &q_exit_base,
            &qd_exit_base,
            model,
        )?
        .ok_or_else(|| {
            "saved survival marginal-slope model is missing baseline-timewiggle runtime metadata"
                .to_string()
        })?;
        if exit_w.ncols() != p_timewiggle {
            return Err(SurvivalPredictError::IncompatibleSchema {
                reason: format!(
                    "saved survival marginal-slope timewiggle design mismatch: rebuilt {} columns but runtime expects {}",
                    exit_w.ncols(),
                    p_timewiggle
                ),
            });
        }
        q_design_parts.push(DesignMatrix::from(exit_w));
    }
    q_design_parts.push(cov_design.clone());
    let q_design = DesignMatrix::hstack(q_design_parts)?;

    // Repackage into the predictor layout:
    // [time+marginal (Mean), slope (Scale), score-warp? (Mean), link-deviation? (LinkWiggle)].
    // The influence-absorber block, if saved, is deliberately not carried over.
    let combined_q_beta = concat_array1_refs(&[beta_time, beta_marginal]);
    let combined_q_lambdas = concat_array1_refs(&[&blocks[0].lambdas, &blocks[1].lambdas]);
    let mut predictor_blocks = Vec::with_capacity(
        2 + usize::from(saved_score_runtime.is_some()) + usize::from(saved_link_runtime.is_some()),
    );
    predictor_blocks.push(FittedBlock {
        beta: combined_q_beta.clone(),
        role: BlockRole::Mean,
        edf: blocks[0].edf + blocks[1].edf,
        lambdas: combined_q_lambdas,
    });
    predictor_blocks.push(FittedBlock {
        beta: beta_logslope.clone(),
        role: BlockRole::Scale,
        edf: blocks[2].edf,
        lambdas: blocks[2].lambdas.clone(),
    });
    if saved_score_runtime.is_some() {
        let mut block = blocks[3].clone();
        block.role = BlockRole::Mean;
        predictor_blocks.push(block);
    }
    if saved_link_runtime.is_some() {
        let idx = 3 + usize::from(saved_score_runtime.is_some());
        let mut block = blocks[idx].clone();
        block.role = BlockRole::LinkWiggle;
        predictor_blocks.push(block);
    }

    let mut predictor_fit = fit_saved.clone();
    predictor_fit.blocks = predictor_blocks;
    // Keep the flat coefficient vector consistent with the rewritten block list.
    predictor_fit.beta = concat_array1_refs(
        &predictor_fit
            .blocks
            .iter()
            .map(|block| &block.beta)
            .collect::<Vec<_>>(),
    );
    // The block list above no longer matches the saved layout, so the saved
    // per-block states are dropped.
    predictor_fit.block_states.clear();

    let predictor = BernoulliMarginalSlopePredictor::from_unified(
        &predictor_fit,
        z_name.to_string(),
        model.latent_z_normalization.ok_or_else(|| {
            "saved survival marginal-slope model missing latent_z_normalization".to_string()
        })?,
        model.latent_measure.clone().ok_or_else(|| {
            "saved survival marginal-slope model missing latent_measure".to_string()
        })?,
        // NOTE(review): positional literal; its meaning is defined by
        // `from_unified`'s signature — confirm there.
        0.0,
        model.logslope_baseline.ok_or_else(|| {
            "saved survival marginal-slope model missing logslope_baseline".to_string()
        })?,
        // Fall back to probit when the saved model resolves no inverse link.
        model
            .resolved_inverse_link()?
            .unwrap_or(InverseLink::Standard(StandardLink::Probit)),
        model
            .family_state
            .frailty()
            .cloned()
            .unwrap_or(FrailtySpec::None),
        saved_score_runtime,
        saved_link_runtime,
        model.latent_z_rank_int_calibration.clone(),
        // Survival marginal-slope never engages the BMS-only conditional Auto
        // gate (#905); the field is always `None` for survival fits.
        model.latent_z_conditional_calibration.clone(),
    )?;

    // The timewiggle contribution enters through the design columns, not the
    // offset; the offset is the exit-time offset plus the primary offset.
    let pred_input = PredictInput {
        design: q_design,
        offset: eta_offset_exit + primary_offset,
        design_noise: Some(logslope_design.clone()),
        offset_noise: Some(noise_offset.clone()),
        auxiliary_scalar: Some(z.clone()),
        auxiliary_matrix: None,
    };

    Ok((predictor, pred_input, predictor_fit))
}
3385
3386#[cfg(test)]
3387mod tests {
3388    use super::*;
3389    use crate::probability::{normal_cdf, normal_pdf};
3390
3391    #[test]
3392    fn probit_survival_hazard_uses_density_over_survival() {
3393        let eta = 2.0;
3394        let eta_t = 0.3;
3395
3396        let (cum, hazard) =
3397            probit_survival_hazard_components(eta, eta_t).expect("valid components");
3398
3399        let survival = normal_cdf(-eta);
3400        let expected_cum = -survival.ln();
3401        let expected_hazard = normal_pdf(eta) * eta_t / survival;
3402        assert!((cum - expected_cum).abs() <= 1e-14);
3403        assert!((hazard - expected_hazard).abs() <= 1e-14);
3404    }
3405
3406    #[test]
3407    fn probit_survival_hazard_stays_finite_in_right_tail() {
3408        let eta = 40.0;
3409        let eta_t = 9.694_340_360_912_401e-5;
3410
3411        let event_density =
3412            (-0.5_f64 * eta * eta).exp() / (2.0 * std::f64::consts::PI).sqrt() * eta_t;
3413        assert_eq!(event_density, 0.0);
3414
3415        let (cum, hazard) =
3416            probit_survival_hazard_components(eta, eta_t).expect("valid tail components");
3417        assert!(cum > 800.0, "right-tail cumulative hazard was {cum}");
3418        assert!(
3419            (3.87e-3..3.89e-3).contains(&hazard),
3420            "right-tail hazard was {hazard}"
3421        );
3422    }
3423
3424    #[test]
3425    fn probit_survival_hazard_accepts_zero_time_derivative_as_flat_hazard() {
3426        let (cum, hazard) =
3427            probit_survival_hazard_components(1.0, 0.0).expect("zero derivative is flat hazard");
3428        assert!(cum > 0.0);
3429        assert_eq!(hazard, 0.0);
3430    }
3431
3432    #[test]
3433    fn marginal_slope_index_derivative_clamps_extrapolation_negative_to_flat_hazard() {
3434        // The #1040 end-to-end blocker: at a prediction horizon outside the
3435        // training exit times, the penalized baseline derivative q'(t) can dip
3436        // negative (e.g. the reported eta_t=-0.00135), producing a negative
3437        // index time-derivative the strict validator used to reject. The
3438        // physical hazard floor is 0, so the clamp must turn it into a flat
3439        // hazard the validator accepts — keeping predict/CIF runnable.
3440        let deta_dq = (1.0_f64 + 0.4 * 0.4).sqrt(); // rigid c = sqrt(1+sb^2) >= 1
3441        let qd_with_wiggle = -1.35e-3;
3442        let eta_t = marginal_slope_index_derivative_at_horizon(deta_dq, qd_with_wiggle);
3443        assert_eq!(
3444            eta_t, 0.0,
3445            "negative extrapolation derivative must clamp to 0"
3446        );
3447        // Downstream validator now accepts it as a flat-hazard point.
3448        let (cum, hazard) = probit_survival_hazard_components(-0.563, eta_t)
3449            .expect("clamped flat-hazard prediction must validate");
3450        assert!(
3451            cum >= 0.0,
3452            "cumulative hazard must be well-posed, got {cum}"
3453        );
3454        assert_eq!(
3455            hazard, 0.0,
3456            "clamped derivative gives zero instantaneous hazard"
3457        );
3458    }
3459
3460    #[test]
3461    fn marginal_slope_index_derivative_preserves_positive_and_nonfinite() {
3462        // A genuinely positive derivative passes through unchanged (scaled by
3463        // the chain factor), and a non-finite value is left for the strict
3464        // validator to reject as a real numerical failure rather than masked.
3465        let positive = marginal_slope_index_derivative_at_horizon(1.25, 0.8);
3466        assert!(
3467            (positive - 1.0).abs() <= 1e-15,
3468            "positive derivative scaled by chain factor"
3469        );
3470        let nonfinite = marginal_slope_index_derivative_at_horizon(1.25, f64::NAN);
3471        assert!(
3472            nonfinite.is_nan(),
3473            "non-finite derivative passes through unclamped"
3474        );
3475        assert!(
3476            probit_survival_hazard_components(0.5, nonfinite).is_err(),
3477            "non-finite derivative must still be rejected by the validator"
3478        );
3479    }
3480
3481    #[test]
3482    fn probit_survival_hazard_rejects_infinite_time_derivative() {
3483        let err = probit_survival_hazard_components(1.0, f64::INFINITY)
3484            .expect_err("infinite derivative should be invalid");
3485        assert!(
3486            err.to_string()
3487                .contains("invalid survival index derivative")
3488        );
3489    }
3490
3491    #[test]
3492    fn probit_survival_hazard_rejects_nan_inputs() {
3493        // The upstream input gate is the only line that rejects NaN — the
3494        // output gate (`>= 0.0`) is dead-code for finite input because
3495        // `signed_probit_logcdf_and_mills_ratio` is provably NaN-free on the
3496        // finite domain (every internal branch clamps `erfcx`/`cdf` away from
3497        // zero). Pin both NaN slots so the input gate cannot regress.
3498        let err_eta =
3499            probit_survival_hazard_components(f64::NAN, 0.5).expect_err("NaN eta must be rejected");
3500        assert!(
3501            err_eta
3502                .to_string()
3503                .contains("invalid survival index derivative")
3504        );
3505        let err_dt = probit_survival_hazard_components(1.0, f64::NAN)
3506            .expect_err("NaN eta_derivative must be rejected");
3507        assert!(
3508            err_dt
3509                .to_string()
3510                .contains("invalid survival index derivative")
3511        );
3512    }
3513
3514    #[test]
3515    fn probit_survival_hazard_rejects_negative_time_derivative() {
3516        // The CDF S(t) = Phi(-eta(t)) is monotone in t iff eta'(t) > 0. A
3517        // negative slope would give a non-monotone survival curve, which is
3518        // not a valid survival function.
3519        let err = probit_survival_hazard_components(1.0, -0.5)
3520            .expect_err("negative derivative should be invalid");
3521        assert!(
3522            err.to_string()
3523                .contains("invalid survival index derivative")
3524        );
3525    }
3526
3527    #[test]
3528    fn royston_parmar_hazard_is_cumulative_hazard_derivative() {
3529        let eta = 2.0_f64.ln();
3530        let eta_t = 0.25;
3531
3532        let (cum, hazard) =
3533            royston_parmar_survival_hazard_components(eta, eta_t).expect("valid components");
3534
3535        assert!((cum - 2.0).abs() <= 1e-14);
3536        assert!((hazard - 0.5).abs() <= 1e-14);
3537        assert_ne!(hazard, cum);
3538    }
3539
3540    #[test]
3541    fn royston_parmar_hazard_rejects_negative_log_hazard_derivative() {
3542        // A negative time-derivative of log Λ(t) means a *decreasing* cumulative
3543        // hazard — not a valid survival model. Only the genuinely-negative slope
3544        // is rejected; the zero boundary is valid (see the sibling test below).
3545        let err = royston_parmar_survival_hazard_components(0.0, -0.5)
3546            .expect_err("negative derivative should be invalid");
3547        assert!(
3548            err.to_string()
3549                .contains("invalid log-cumulative-hazard derivative")
3550        );
3551    }
3552
3553    #[test]
3554    fn royston_parmar_hazard_accepts_zero_derivative_as_flat_boundary() {
3555        // #1564: a monotone I-spline cumulative hazard is flat beyond its last
3556        // interior knot, so `d(log Λ)/dt == 0` exactly on any grid node past the
3557        // training support. That is a *valid* prediction (zero instantaneous
3558        // hazard, locally constant survival), not a numerical failure. The old
3559        // strict `> 0.0` gate rejected it and crashed saved-model RP predict.
3560        let eta = 1.9909019457445971_f64; // the exact η from the #1564 report
3561        let (cum, hazard) = royston_parmar_survival_hazard_components(eta, 0.0)
3562            .expect("zero derivative is a valid flat boundary, not an error");
3563        assert!((cum - eta.exp()).abs() <= 1e-12, "cum = Λ(t) = exp(η)");
3564        assert_eq!(
3565            hazard, 0.0,
3566            "flat cumulative hazard ⇒ zero instantaneous hazard"
3567        );
3568        // Survival is finite and well-defined at the boundary.
3569        let survival = (-cum).exp().clamp(0.0, 1.0);
3570        assert!(survival.is_finite() && (0.0..=1.0).contains(&survival));
3571    }
3572
3573    #[test]
3574    fn royston_parmar_hazard_zero_derivative_in_saturated_tail_is_zero_not_nan() {
3575        // The dangerous corner: a saturated tail (η large ⇒ Λ = exp(η) = +∞) that
3576        // also lands past the I-spline support (derivative == 0). The naive
3577        // product `+∞ * 0.0` is `NaN`, which would (a) trip the components guard
3578        // and (b) serialize to JSON `null` and break the Python parse (#1564,
3579        // bug 1). The hazard must resolve to the mathematically correct `0`.
3580        let eta = 1000.0_f64;
3581        assert!(
3582            eta.exp().is_infinite(),
3583            "test premise: exp(1000) overflows to +∞"
3584        );
3585        assert!(
3586            (f64::INFINITY * 0.0).is_nan(),
3587            "test premise: the naive product is NaN"
3588        );
3589        let (cum, hazard) = royston_parmar_survival_hazard_components(eta, 0.0)
3590            .expect("saturated + flat boundary must be valid");
3591        assert!(cum.is_infinite() && cum > 0.0, "cum saturates to +∞");
3592        assert_eq!(hazard, 0.0, "hazard at a flat boundary is 0, never NaN");
3593    }
3594
3595    #[test]
3596    fn royston_parmar_hazard_propagates_saturation_as_infinity() {
3597        // η = log Λ(t); a saturated RP fit can drive η well past the
3598        // exp(709.78)≈f64::MAX boundary in the right tail. The math is
3599        // S(t)→0, h(t)→∞; the helper must not reject this regime, because the
3600        // inner solver has already accepted the underlying fit.
3601        let eta = 1000.0_f64;
3602        let eta_t = 0.5_f64;
3603        assert!(eta.exp().is_infinite(), "test premise: exp(1000) overflows");
3604
3605        let (cum, hazard) = royston_parmar_survival_hazard_components(eta, eta_t)
3606            .expect("saturated RP fit must yield a result, not an error");
3607        assert!(cum.is_infinite() && cum > 0.0, "expected +∞ cum, got {cum}");
3608        assert!(
3609            hazard.is_infinite() && hazard > 0.0,
3610            "expected +∞ hazard, got {hazard}"
3611        );
3612
3613        // Consumer materializes survival via exp(-cum).clamp(0,1).
3614        let survival = (-cum).exp().clamp(0.0, 1.0);
3615        assert_eq!(survival, 0.0, "saturated cum_hazard must give survival 0");
3616    }
3617
3618    #[test]
3619    fn royston_parmar_hazard_rejects_nan_eta() {
3620        let err = royston_parmar_survival_hazard_components(f64::NAN, 0.5)
3621            .expect_err("NaN eta should be invalid");
3622        assert!(
3623            err.to_string()
3624                .contains("invalid log-cumulative-hazard derivative")
3625        );
3626    }
3627
3628    #[test]
3629    fn royston_parmar_hazard_left_tail_collapses_to_zero() {
3630        // η = log Λ(t); η → -∞ means Λ(t) → 0, so cum_hazard underflows to 0
3631        // and hazard rate underflows to 0. Survival → 1. No error.
3632        let eta = -1000.0_f64;
3633        let eta_t = 2.0_f64;
3634        assert_eq!(eta.exp(), 0.0, "test premise: exp(-1000) underflows to 0");
3635
3636        let (cum, hazard) = royston_parmar_survival_hazard_components(eta, eta_t)
3637            .expect("RP left tail must remain valid");
3638        assert_eq!(
3639            cum, 0.0,
3640            "left-tail cum_hazard should underflow to 0, got {cum}"
3641        );
3642        assert_eq!(
3643            hazard, 0.0,
3644            "left-tail hazard should underflow to 0, got {hazard}"
3645        );
3646
3647        // Consumer: survival = exp(-0) = 1.
3648        let survival = (-cum).exp().clamp(0.0, 1.0);
3649        assert_eq!(survival, 1.0);
3650    }
3651
3652    #[test]
3653    fn probit_survival_hazard_left_tail_collapses_to_zero() {
3654        // η→-∞ mirror of the right-tail test: survival → 1, hazard → 0.
3655        // Asymptote: Mills(η) = φ(η)/Φ(-η) → 0 as η → -∞ (φ underflows,
3656        // Φ(-η) → 1).  No error, no NaN, no spurious negativity.
3657        let eta = -40.0_f64;
3658        let eta_t = 1.5_f64;
3659
3660        let (cum, hazard) =
3661            probit_survival_hazard_components(eta, eta_t).expect("left tail must remain valid");
3662        assert!(
3663            (0.0..1e-300).contains(&cum),
3664            "left-tail cum should be ~0, got {cum}"
3665        );
3666        assert_eq!(
3667            hazard, 0.0,
3668            "left-tail hazard should underflow to 0, got {hazard}"
3669        );
3670    }
3671
3672    #[test]
3673    fn location_scale_logit_hazard_is_failure_slope_over_survival() {
3674        let eta = 0.7;
3675        let eta_t = 0.4;
3676
3677        let hazard = location_scale_hazard_component(
3678            eta,
3679            eta_t,
3680            &InverseLink::Standard(StandardLink::Logit),
3681        )
3682        .expect("valid logit hazard");
3683
3684        let failure = 1.0 / (1.0 + (-eta).exp());
3685        assert!((hazard - failure * eta_t).abs() <= 1e-14);
3686    }
3687
3688    #[test]
3689    fn location_scale_cloglog_hazard_matches_log_cumulative_hazard_derivative() {
3690        let eta = 1.5;
3691        let eta_t = 0.2;
3692
3693        let hazard = location_scale_hazard_component(
3694            eta,
3695            eta_t,
3696            &InverseLink::Standard(StandardLink::CLogLog),
3697        )
3698        .expect("valid cloglog hazard");
3699
3700        assert!((hazard - eta.exp() * eta_t).abs() <= 1e-14);
3701    }
3702
3703    // ---- IPCW Brier score (Graf et al. 1999) -------------------------------
3704
3705    #[test]
3706    fn kaplan_meier_censoring_is_right_continuous_step() {
3707        // Two censorings (events flipped) at t=4 and t=8; deaths at t=2,6.
3708        let time = [2.0, 4.0, 6.0, 8.0];
3709        let event = [1.0, 0.0, 1.0, 0.0];
3710        let g = KaplanMeier::fit_censoring(&time, &event);
3711        // Before the first censoring the censoring-survival is 1.
3712        assert!((g.at(0.0) - 1.0).abs() <= 1e-15);
3713        assert!((g.at(2.0) - 1.0).abs() <= 1e-15);
3714        assert!((g.at(3.999) - 1.0).abs() <= 1e-15);
3715        // At t=4 the at-risk set {4,6,8} loses one to censoring: G = 2/3.
3716        assert!((g.at(4.0) - 2.0 / 3.0).abs() <= 1e-12);
3717        assert!((g.at(5.0) - 2.0 / 3.0).abs() <= 1e-12);
3718        // A death at t=6 does not move the censoring KM.
3719        assert!((g.at(6.0) - 2.0 / 3.0).abs() <= 1e-12);
3720        // At t=8 the last (sole) at-risk subject is censored: G collapses to 0.
3721        assert!(g.at(8.0).abs() <= 1e-15);
3722    }
3723
3724    #[test]
3725    fn ipcw_brier_no_censoring_reduces_to_plain_brier() {
3726        // With no censoring G(t) ≡ 1, so the IPCW Brier is the ordinary Brier of
3727        // the predicted survival against the alive-indicator I(T_i > tau).
3728        let s_pred = [0.3, 0.7, 0.6, 0.2];
3729        let time = [2.0, 8.0, 10.0, 3.0];
3730        let event = [1.0, 1.0, 0.0, 1.0];
3731        let tau = 5.0;
3732        let g = KaplanMeier::fit_censoring(&time, &event);
3733        let bs = ipcw_brier_score(&s_pred, &time, &event, tau, |t| g.at(t)).unwrap();
3734        // targets: dead→0 (subj1,4), alive→1 (subj2,3).
3735        let expected =
3736            (0.3f64.powi(2) + (1.0 - 0.7f64).powi(2) + (1.0 - 0.6f64).powi(2) + 0.2f64.powi(2))
3737                / 4.0;
3738        assert!(
3739            (bs - expected).abs() <= 1e-12,
3740            "bs={bs} expected={expected}"
3741        );
3742    }
3743
3744    #[test]
3745    fn ipcw_brier_reweights_by_inverse_censoring_probability() {
3746        // Hand-computed Graf estimator with real censoring weights.
3747        // times/events: death@2, cens@4, death@6, cens@8; tau=5.
3748        // Censoring KM: G(5)=2/3 (one censoring at t=4 among {4,6,8}); G(2)=1.
3749        let s_pred = [0.4, 0.5, 0.7, 0.8];
3750        let time = [2.0, 4.0, 6.0, 8.0];
3751        let event = [1.0, 0.0, 1.0, 0.0];
3752        let tau = 5.0;
3753        let g = KaplanMeier::fit_censoring(&time, &event);
3754        let bs = ipcw_brier_score(&s_pred, &time, &event, tau, |t| g.at(t)).unwrap();
3755        // subj1 dead by 5: weight 1/G(2)=1, contrib 0.4²=0.16.
3756        // subj2 censored before 5: contributes 0.
3757        // subj3 alive: weight 1/G(5)=1.5, contrib 1.5·0.3²=0.135.
3758        // subj4 alive: weight 1/G(5)=1.5, contrib 1.5·0.2²=0.06.
3759        let expected = (0.16 + 0.0 + 0.135 + 0.06) / 4.0;
3760        assert!(
3761            (bs - expected).abs() <= 1e-12,
3762            "bs={bs} expected={expected}"
3763        );
3764    }
3765
3766    #[test]
3767    fn ipcw_brier_drops_invalid_rows_from_both_numerator_and_denominator() {
3768        // A NaN-time row and a non-positive-time row must not be counted at all.
3769        let s_pred = [0.3, 0.7, 0.5, 0.5];
3770        let time = [2.0, 8.0, f64::NAN, -1.0];
3771        let event = [1.0, 1.0, 1.0, 0.0];
3772        let g = KaplanMeier::fit_censoring(&time, &event);
3773        let bs = ipcw_brier_score(&s_pred, &time, &event, 5.0, |t| g.at(t)).unwrap();
3774        // Only subj1 (dead, contrib 0.3²) and subj2 (alive, contrib 0.3²) count;
3775        // censoring KM has no censorings so G≡1.
3776        let expected = (0.3f64.powi(2) + (1.0 - 0.7f64).powi(2)) / 2.0;
3777        assert!(
3778            (bs - expected).abs() <= 1e-12,
3779            "bs={bs} expected={expected}"
3780        );
3781    }
3782
3783    #[test]
3784    fn integrated_ipcw_brier_of_constant_brier_is_that_constant() {
3785        // A survival matrix whose every column equals a perfect classifier yields
3786        // BS(t)=0 at every grid point, so the integral is 0.
3787        let time = [2.0, 8.0, 10.0, 3.0];
3788        let event = [1.0, 1.0, 0.0, 1.0];
3789        let grid = [0.0, 1.0, 2.5, 4.0, 6.0];
3790        // Perfect prediction at every grid time given the (no-censoring) data is
3791        // not generally achievable, so instead test the integral of a literally
3792        // constant-in-time Brier: replicate one column across the grid.
3793        let col = [0.3, 0.7, 0.6, 0.2];
3794        let mut surv = Array2::<f64>::zeros((4, grid.len()));
3795        for k in 0..grid.len() {
3796            for i in 0..4 {
3797                surv[[i, k]] = col[i];
3798            }
3799        }
3800        let g = KaplanMeier::fit_censoring(&time, &event);
3801        let per_time = ipcw_brier_score(&col, &time, &event, grid[2], |t| g.at(t)).unwrap();
3802        // Because the predicted survival is identical at every grid time, BS(t)
3803        // is *not* constant (tau changes which subjects are "alive"), so use a
3804        // direct trapezoid as the oracle.
3805        let mut oracle_pts = Vec::new();
3806        for k in 0..grid.len() {
3807            oracle_pts.push((
3808                grid[k],
3809                ipcw_brier_score(&col, &time, &event, grid[k], |t| g.at(t)).unwrap(),
3810            ));
3811        }
3812        let mut integral = 0.0;
3813        for w in oracle_pts.windows(2) {
3814            integral += 0.5 * (w[0].1 + w[1].1) * (w[1].0 - w[0].0);
3815        }
3816        let oracle = integral / (grid[grid.len() - 1] - grid[0]);
3817        let ibs =
3818            integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, f64::INFINITY, |t| {
3819                g.at(t)
3820            })
3821            .unwrap();
3822        assert!((ibs - oracle).abs() <= 1e-12, "ibs={ibs} oracle={oracle}");
3823        // Sanity: per-time value is in a sensible [0,1]-ish range.
3824        assert!(per_time >= 0.0);
3825    }
3826
3827    #[test]
3828    fn integrated_ipcw_brier_respects_the_horizon_cutoff() {
3829        let time = [2.0, 8.0, 10.0, 3.0];
3830        let event = [1.0, 1.0, 0.0, 1.0];
3831        let grid = [0.0, 2.0, 4.0, 100.0];
3832        let col = [0.3, 0.7, 0.6, 0.2];
3833        let mut surv = Array2::<f64>::zeros((4, grid.len()));
3834        for k in 0..grid.len() {
3835            for i in 0..4 {
3836                surv[[i, k]] = col[i];
3837            }
3838        }
3839        let g = KaplanMeier::fit_censoring(&time, &event);
3840        // Horizon 5 drops the extrapolation point at t=100: integral runs [0,4].
3841        let restricted =
3842            integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, 5.0, |t| g.at(t))
3843                .unwrap();
3844        let full =
3845            integrated_ipcw_brier_score(surv.view(), &time, &event, &grid, f64::INFINITY, |t| {
3846                g.at(t)
3847            })
3848            .unwrap();
3849        // The huge [4,100] tail interval dominates the full integral, so the two
3850        // must differ substantially — the horizon guard is doing real work.
3851        assert!(
3852            (restricted - full).abs() > 1e-3,
3853            "horizon cutoff had no effect: restricted={restricted} full={full}"
3854        );
3855    }
3856
3857    #[test]
3858    fn integrated_ipcw_brier_rejects_malformed_grids() {
3859        let time = [2.0, 8.0];
3860        let event = [1.0, 0.0];
3861        let surv = Array2::<f64>::from_elem((2, 3), 0.5);
3862        let g = KaplanMeier::fit_censoring(&time, &event);
3863        // Non-increasing grid.
3864        let bad = [0.0, 2.0, 1.0];
3865        assert!(
3866            integrated_ipcw_brier_score(surv.view(), &time, &event, &bad, f64::INFINITY, |t| g
3867                .at(t))
3868            .is_none()
3869        );
3870        // Grid width mismatched to the survival matrix.
3871        let short = [0.0, 1.0];
3872        assert!(
3873            integrated_ipcw_brier_score(surv.view(), &time, &event, &short, f64::INFINITY, |t| g
3874                .at(t))
3875            .is_none()
3876        );
3877    }
3878}