Skip to main content

gam_models/survival/latent/
survival.rs

1//! Jointly learned latent-frailty survival and binary deployment families with
2//! a live time/baseline block.
3//!
4//! Model:
5//!   H_0(a) = exp(q(a)),
6//!   h_0(a) = dq(a)/da,
7//!   H(a | U) = H_0(a) * exp(U),
8//!   U ~ N(mu, sigma^2),
9//!   mu = X beta + offset.
10//!
11//! Unlike the old compiled-row path, the cumulative masses and baseline hazard
12//! are rebuilt inside the optimizer from the current time-basis coefficients.
13//! The family-level fit surface supports exact events, right censoring, and
14//! interval censoring `T ∈ (L, R]` (contribution `log[S(L) − S(R)]`). Interval
15//! rows carry the reserved [`LATENT_SURVIVAL_EVENT_INTERVAL`] event code and a
16//! dedicated upper-bound time channel (`time_design_right` / `q_right`); the
17//! 3-way event dispatch is [`latent_survival_event_type_for`]. Reached from the
18//! formula DSL via `SurvInterval(L, R, event) ~ ...`.
19
20use crate::custom_family::{
21    BlockWorkingSet, BlockwiseFitOptions, CustomFamily, ExactNewtonJointGradientEvaluation,
22    ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
23    PenaltyMatrix, fit_custom_family, fit_custom_family_fixed_log_lambdas,
24};
25use crate::gamlss::{FamilyMetadata, ParameterLink};
26use crate::sigma_link::{exp_sigma_eta_for_sigma_scalar, exp_sigma_from_eta_scalar};
27use crate::survival::latent::interval::{
28    LatentFrailtyResolution, LatentIntervalModel, LatentIntervalRowView,
29    validate_latent_interval_inputs,
30};
31use crate::survival::location_scale::{
32    TimeBlockInput, project_onto_linear_constraints, structural_time_coefficient_constraints,
33};
34use crate::survival::lognormal_kernel::{
35    FrailtySpec, HazardLoading, LatentSurvivalEventType, LatentSurvivalRow, LatentSurvivalRowJet,
36    log_kernel_bundle,
37};
38use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
39use crate::model_types::UnifiedFitResult;
40use gam_solve::pirls::LinearInequalityConstraints;
41use crate::probability::signed_log_sum_exp;
42use crate::quadrature::{IntegratedExpectationMode, QuadratureContext};
43use gam_terms::smooth::{
44    TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
45};
46use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
47use gam_problem::MIN_WEIGHT;
48use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
49use std::collections::BTreeMap;
50use std::sync::Arc;
51
/// Error type shared by the latent-survival and latent-binary family kernels
/// and by their fit-time / per-row validation helpers.
///
/// Each variant names a coarse semantic category. The `reason` payload keeps
/// the original diagnostic text unchanged, so callers that used to receive
/// plain `String` errors still see the same message through `Display`.
#[derive(Debug, Clone)]
pub enum LatentSurvivalError {
    /// The supplied frailty specification cannot be used by a latent-survival
    /// or latent-binary helper: wrong variant, a required fixed sigma is
    /// absent, or the fixed sigma is non-finite or negative.
    InvalidFrailty { reason: String },
    /// Dataset-level or per-row validation failed: the input is empty, the
    /// spec vectors disagree in length, or a row carries an invalid age,
    /// event code, weight, or unloaded mass.
    InvalidDataset { reason: String },
    /// A parameter-block state, eta vector, or directional-derivative argument
    /// handed to a family entry point does not have the expected length.
    BlockMismatch { reason: String },
    /// A value computed at runtime (sigma, baseline hazard derivative, kernel
    /// sum, event probability) turned out non-finite or outside its domain.
    NumericalFailure { reason: String },
    /// The requested time-block structure or event type has no implementation
    /// (non-structural monotonicity, interval-censored rows on the
    /// dynamic-derivative path).
    UnsupportedConfiguration { reason: String },
}
78
79impl_reason_error_boilerplate! {
80    LatentSurvivalError {
81        InvalidFrailty,
82        InvalidDataset,
83        BlockMismatch,
84        NumericalFailure,
85        UnsupportedConfiguration,
86    }
87}
88
89impl From<crate::block_layout::block_count::BlockCountMismatch> for LatentSurvivalError {
90    fn from(
91        err: crate::block_layout::block_count::BlockCountMismatch,
92    ) -> LatentSurvivalError {
93        LatentSurvivalError::BlockMismatch {
94            reason: err.message(),
95        }
96    }
97}
98
99impl From<String> for LatentSurvivalError {
100    /// Inbound conversion for the many `Result<_, String>` helpers this
101    /// module still calls into (term-collection design assembly, dense
102    /// chunk conversion, sparse linear constraints). The text is preserved
103    /// verbatim; we only pick a category so external messages flow through
104    /// `?` without per-callsite `.map_err`.
105    fn from(reason: String) -> LatentSurvivalError {
106        LatentSurvivalError::InvalidDataset { reason }
107    }
108}
109
/// Sentinel value of [`LatentSurvivalTermSpec::event_target`] that flags a row
/// as interval-censored on `(L, R]`.
///
/// Code `0` means right censoring and codes `>= 1` mean an exact event. Using
/// `u8::MAX` for the interval case keeps it clear of any exact-event count, so
/// the dispatch is the explicit 3-way map
/// `{0 → RightCensored, INTERVAL → IntervalCensored, k ≥ 1 → ExactEvent}`.
pub const LATENT_SURVIVAL_EVENT_INTERVAL: u8 = u8::MAX;
116
117#[inline]
118fn latent_survival_event_type_for(code: u8) -> LatentSurvivalEventType {
119    match code {
120        0 => LatentSurvivalEventType::RightCensored,
121        LATENT_SURVIVAL_EVENT_INTERVAL => LatentSurvivalEventType::IntervalCensored,
122        _ => LatentSurvivalEventType::ExactEvent,
123    }
124}
125
126#[derive(Clone)]
127pub struct LatentSurvivalTermSpec {
128    pub age_entry: Array1<f64>,
129    pub age_exit: Array1<f64>,
130    pub event_target: Array1<u8>,
131    pub weights: Array1<f64>,
132    pub derivative_guard: f64,
133    pub time_block: TimeBlockInput,
134    /// Time-basis design evaluated at the interval upper bound `R` (so
135    /// `q_right = design_right · β_time + offset_right`). `None` when the data
136    /// carries no interval-censored rows; the family then reuses the exit design
137    /// for the unused `q_right` channel. When `Some`, rows whose
138    /// `event_target == LATENT_SURVIVAL_EVENT_INTERVAL` contribute the interval
139    /// likelihood `log[S(L) − S(R)]`.
140    pub time_design_right: Option<DesignMatrix>,
141    pub time_offset_right: Option<Array1<f64>>,
142    pub unloaded_mass_entry: Array1<f64>,
143    pub unloaded_mass_exit: Array1<f64>,
144    /// Unloaded (background) cumulative mass at the interval upper bound `R`.
145    /// Length-`n`; entries for non-interval rows are ignored. Empty/`None`
146    /// folds to zero (full-loading interval rows).
147    pub unloaded_mass_right: Array1<f64>,
148    pub unloaded_hazard_exit: Array1<f64>,
149    pub meanspec: TermCollectionSpec,
150    pub mean_offset: Array1<f64>,
151}
152
153pub struct LatentSurvivalTermFitResult {
154    pub fit: UnifiedFitResult,
155    pub design: TermCollectionDesign,
156    pub resolvedspec: TermCollectionSpec,
157    pub latent_sd: f64,
158    /// Per-row residuals of the unpenalized NLL w.r.t. the additive baseline
159    /// time-block offsets `(entry, exit, derivative)` at the converged β̂.
160    /// Contracted against `baseline_offset_theta_partials` by
161    /// `baseline_chain_rule_gradient` to give the exact θ-gradient of the
162    /// profile penalized NLL for the outer baseline-config optimizer.
163    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
164}
165
166#[derive(Clone)]
167pub struct LatentBinaryTermSpec {
168    pub age_entry: Array1<f64>,
169    pub age_exit: Array1<f64>,
170    pub event_target: Array1<u8>,
171    pub weights: Array1<f64>,
172    pub derivative_guard: f64,
173    pub time_block: TimeBlockInput,
174    pub unloaded_mass_entry: Array1<f64>,
175    pub unloaded_mass_exit: Array1<f64>,
176    pub meanspec: TermCollectionSpec,
177    pub mean_offset: Array1<f64>,
178}
179
180pub struct LatentBinaryTermFitResult {
181    pub fit: UnifiedFitResult,
182    pub design: TermCollectionDesign,
183    pub resolvedspec: TermCollectionSpec,
184    /// Per-row residuals of the unpenalized NLL w.r.t. the additive baseline
185    /// time-block offsets `(entry, exit)` at the converged β̂ (the derivative
186    /// channel is identically zero for the binary deployment likelihood).
187    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
188}
189
190#[derive(Clone)]
191struct PreparedLatentTimeBlock {
192    design_entry: Array2<f64>,
193    design_exit: Array2<f64>,
194    design_derivative_exit: Array2<f64>,
195    /// Dense time-basis design at the interval upper bound `R`. Falls back to a
196    /// clone of `design_exit` when the spec supplies no interval design, so the
197    /// `q_right` channel is always well-defined (and unused for non-interval
198    /// rows).
199    design_right: Array2<f64>,
200    linear_constraints: Option<LinearInequalityConstraints>,
201    penalties: Vec<Array2<f64>>,
202    initial_beta: Option<Array1<f64>>,
203}
204
205#[derive(Clone)]
206pub struct LatentSurvivalFamily {
207    pub event_target: Array1<u8>,
208    pub weights: Array1<f64>,
209    pub latent_sd_fixed: Option<f64>,
210    pub hazard_loading: HazardLoading,
211    pub unloaded_mass_entry: Array1<f64>,
212    pub unloaded_mass_exit: Array1<f64>,
213    pub unloaded_hazard_exit: Array1<f64>,
214    pub x_time_entry: Array2<f64>,
215    pub x_time_exit: Array2<f64>,
216    pub x_time_derivative_exit: Array2<f64>,
217    /// Time-basis design evaluated at the interval upper bound `R` (so
218    /// `q_right = x_time_right · β_time + time_offset_right`). For non-interval
219    /// rows this row equals `x_time_exit`'s row (`q_right` is then unused by the
220    /// likelihood), so the matrix always has `n` rows and the same column count
221    /// as the other time designs.
222    pub x_time_right: Array2<f64>,
223    /// Time-block offset at the interval upper bound `R` (length `n`).
224    pub time_offset_right: Array1<f64>,
225    /// Unloaded (background) cumulative mass at the interval upper bound `R`
226    /// (length `n`). Ignored for non-interval rows.
227    pub unloaded_mass_right: Array1<f64>,
228    pub x_mean: DesignMatrix,
229    pub time_linear_constraints: Option<LinearInequalityConstraints>,
230    pub quadctx: Arc<QuadratureContext>,
231}
232
233#[derive(Clone)]
234pub struct LatentBinaryFamily {
235    pub event_target: Array1<u8>,
236    pub weights: Array1<f64>,
237    pub latent_sd: f64,
238    pub hazard_loading: HazardLoading,
239    pub unloaded_mass_entry: Array1<f64>,
240    pub unloaded_mass_exit: Array1<f64>,
241    pub x_time_entry: Array2<f64>,
242    pub x_time_exit: Array2<f64>,
243    pub x_mean: DesignMatrix,
244    pub time_linear_constraints: Option<LinearInequalityConstraints>,
245    pub quadctx: Arc<QuadratureContext>,
246}
247
248impl LatentSurvivalFamily {
249    pub const BLOCK_TIME: usize = 0;
250    pub const BLOCK_MEAN: usize = 1;
251    pub const BLOCK_LOG_SIGMA: usize = 2;
252
253    pub fn parameter_names() -> &'static [&'static str] {
254        &["time_transform", "mean"]
255    }
256
257    pub fn parameter_links() -> &'static [ParameterLink] {
258        &[ParameterLink::Identity, ParameterLink::Identity]
259    }
260
261    pub fn metadata() -> FamilyMetadata {
262        FamilyMetadata {
263            name: "latent_survival",
264            parameternames: Self::parameter_names(),
265            parameter_links: Self::parameter_links(),
266        }
267    }
268
269    fn split_time_eta<'a>(
270        &self,
271        block_states: &'a [ParameterBlockState],
272    ) -> Result<
273        (
274            ArrayView1<'a, f64>,
275            ArrayView1<'a, f64>,
276            ArrayView1<'a, f64>,
277            &'a Array1<f64>,
278        ),
279        LatentSurvivalError,
280    > {
281        let expected_blocks = if self.latent_sd_fixed.is_some() { 2 } else { 3 };
282        crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
283            "LatentSurvivalFamily",
284            expected_blocks,
285            block_states.len(),
286        )?;
287        let n = self.event_target.len();
288        let eta_time = &block_states[Self::BLOCK_TIME].eta;
289        let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
290        if eta_time.len() != 3 * n {
291            return Err(LatentSurvivalError::BlockMismatch {
292                reason: format!(
293                    "latent survival time eta length mismatch: got {}, expected {}",
294                    eta_time.len(),
295                    3 * n
296                ),
297            });
298        }
299        if eta_mean.len() != n || self.weights.len() != n {
300            return Err(LatentSurvivalError::BlockMismatch {
301                reason: "latent survival mean eta dimension mismatch".to_string(),
302            });
303        }
304        Ok((
305            eta_time.slice(s![0..n]),
306            eta_time.slice(s![n..2 * n]),
307            eta_time.slice(s![2 * n..3 * n]),
308            eta_mean,
309        ))
310    }
311
312    /// Per-row interval upper-bound time transform `q_right = x_time_right · β_time
313    /// + time_offset_right`. Shares the time-block coefficients with `q_exit`
314    /// (same monotone basis, evaluated at `R`), so it is read off the time
315    /// block's `beta` rather than carried as an extra eta channel. For
316    /// non-interval rows `x_time_right` equals `x_time_exit`, so the (unused)
317    /// value is simply `q_exit`.
318    fn time_q_right(
319        &self,
320        block_states: &[ParameterBlockState],
321    ) -> Result<Array1<f64>, LatentSurvivalError> {
322        let n = self.event_target.len();
323        let beta_time = &block_states[Self::BLOCK_TIME].beta;
324        if self.x_time_right.ncols() != beta_time.len() {
325            return Err(LatentSurvivalError::BlockMismatch {
326                reason: format!(
327                    "latent survival interval right design has {} columns but time beta has {}",
328                    self.x_time_right.ncols(),
329                    beta_time.len()
330                ),
331            });
332        }
333        if self.x_time_right.nrows() != n || self.time_offset_right.len() != n {
334            return Err(LatentSurvivalError::BlockMismatch {
335                reason: "latent survival interval right design/offset row count mismatch"
336                    .to_string(),
337            });
338        }
339        let mut q_right = self.x_time_right.dot(beta_time);
340        q_right += &self.time_offset_right;
341        Ok(q_right)
342    }
343
344    fn latent_sd(&self, block_states: &[ParameterBlockState]) -> Result<f64, LatentSurvivalError> {
345        if let Some(sigma) = self.latent_sd_fixed {
346            return Ok(sigma);
347        }
348        let eta = *block_states
349            .get(Self::BLOCK_LOG_SIGMA)
350            .and_then(|state| state.eta.get(0))
351            .ok_or_else(|| LatentSurvivalError::BlockMismatch {
352                reason: "latent survival learnable log_sigma block is missing".to_string(),
353            })?;
354        let sigma = exp_sigma_from_eta_scalar(eta);
355        if !(sigma.is_finite() && sigma > 0.0) {
356            return Err(LatentSurvivalError::NumericalFailure {
357                reason: format!(
358                    "latent survival learnable sigma became invalid: log_sigma={eta}, sigma={sigma}"
359                ),
360            });
361        }
362        Ok(sigma)
363    }
364}
365
366impl LatentBinaryFamily {
367    pub const BLOCK_TIME: usize = 0;
368    pub const BLOCK_MEAN: usize = 1;
369
370    fn split_time_eta<'a>(
371        &self,
372        block_states: &'a [ParameterBlockState],
373    ) -> Result<(ArrayView1<'a, f64>, ArrayView1<'a, f64>, &'a Array1<f64>), LatentSurvivalError>
374    {
375        crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
376            "LatentBinaryFamily",
377            2,
378            block_states.len(),
379        )?;
380        let n = self.event_target.len();
381        let eta_time = &block_states[Self::BLOCK_TIME].eta;
382        let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
383        if eta_time.len() != 3 * n {
384            return Err(LatentSurvivalError::BlockMismatch {
385                reason: format!(
386                    "latent binary time eta length mismatch: got {}, expected {}",
387                    eta_time.len(),
388                    3 * n
389                ),
390            });
391        }
392        if eta_mean.len() != n || self.weights.len() != n {
393            return Err(LatentSurvivalError::BlockMismatch {
394                reason: "latent binary mean eta dimension mismatch".to_string(),
395            });
396        }
397        Ok((
398            eta_time.slice(s![0..n]),
399            eta_time.slice(s![n..2 * n]),
400            eta_mean,
401        ))
402    }
403}
404
405pub fn fixed_latent_hazard_frailty(
406    frailty: &FrailtySpec,
407    context: &str,
408) -> Result<(f64, HazardLoading), String> {
409    fixed_latent_hazard_frailty_typed(frailty, context).map_err(Into::into)
410}
411
412fn fixed_latent_hazard_frailty_typed(
413    frailty: &FrailtySpec,
414    context: &str,
415) -> Result<(f64, HazardLoading), LatentSurvivalError> {
416    match frailty {
417        FrailtySpec::HazardMultiplier {
418            sigma_fixed: Some(sigma),
419            loading,
420        } if sigma.is_finite() && *sigma >= 0.0 => Ok((*sigma, *loading)),
421        FrailtySpec::HazardMultiplier {
422            sigma_fixed: Some(sigma),
423            ..
424        } => Err(LatentSurvivalError::InvalidFrailty {
425            reason: format!(
426                "{context} requires a finite fixed hazard-multiplier sigma >= 0, got {sigma}"
427            ),
428        }),
429        FrailtySpec::HazardMultiplier {
430            sigma_fixed: None, ..
431        } => Err(LatentSurvivalError::InvalidFrailty {
432            reason: format!("{context} currently requires a fixed hazard-multiplier sigma"),
433        }),
434        FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
435            reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
436        }),
437        FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
438            reason: format!("{context} requires a fixed HazardMultiplier frailty specification"),
439        }),
440    }
441}
442
443pub fn latent_hazard_loading(
444    frailty: &FrailtySpec,
445    context: &str,
446) -> Result<HazardLoading, String> {
447    latent_hazard_loading_typed(frailty, context).map_err(Into::into)
448}
449
450fn latent_hazard_loading_typed(
451    frailty: &FrailtySpec,
452    context: &str,
453) -> Result<HazardLoading, LatentSurvivalError> {
454    match frailty {
455        FrailtySpec::HazardMultiplier { loading, .. } => Ok(*loading),
456        FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
457            reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
458        }),
459        FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
460            reason: format!("{context} requires a HazardMultiplier frailty specification"),
461        }),
462    }
463}
464
/// Small copyable bundle of scalar derivative terms for the entry and exit
/// time channels. NOTE(review): judging by the field names these are
/// first-order terms and negated second-order terms — confirm the sign and
/// scaling conventions at the construction site.
#[derive(Clone, Copy)]
struct LatentSurvivalTimeJet {
    /// Gradient term for the entry channel.
    grad_entry: f64,
    /// Gradient term for the exit channel.
    grad_exit: f64,
    /// Negated Hessian term for the entry channel.
    neg_hess_entry: f64,
    /// Negated Hessian term for the exit channel.
    neg_hess_exit: f64,
}
472
473pub fn fit_latent_survival_terms(
474    data: ArrayView2<'_, f64>,
475    spec: LatentSurvivalTermSpec,
476    frailty: FrailtySpec,
477    options: &BlockwiseFitOptions,
478) -> Result<LatentSurvivalTermFitResult, String> {
479    let latent_sd = validate_latent_survival_inputs(data, &spec, &frailty)?;
480    let hazard_loading = latent_hazard_loading(&frailty, "latent-survival")?;
481    let mean_design =
482        build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
483    let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
484        .map_err(|e| e.to_string())?;
485    let time_prepared = prepare_latent_time_block(
486        &spec.time_block,
487        spec.time_design_right.as_ref(),
488        spec.derivative_guard,
489    )?;
490
491    let n = spec.event_target.len();
492    let time_offset_right = match spec.time_offset_right.as_ref() {
493        Some(offset) => {
494            if offset.len() != n {
495                return Err(format!(
496                    "latent survival interval right time offset must have length {n}, got {}",
497                    offset.len()
498                ));
499            }
500            offset.clone()
501        }
502        None => Array1::zeros(n),
503    };
504    let unloaded_mass_right = if spec.unloaded_mass_right.is_empty() {
505        Array1::zeros(n)
506    } else {
507        if spec.unloaded_mass_right.len() != n {
508            return Err(format!(
509                "latent survival interval right unloaded mass must have length {n}, got {}",
510                spec.unloaded_mass_right.len()
511            ));
512        }
513        spec.unloaded_mass_right.clone()
514    };
515
516    let family = LatentSurvivalFamily {
517        event_target: spec.event_target.clone(),
518        weights: spec.weights.clone(),
519        latent_sd_fixed: latent_sd,
520        hazard_loading,
521        unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
522        unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
523        unloaded_hazard_exit: spec.unloaded_hazard_exit.clone(),
524        x_time_entry: time_prepared.design_entry.clone(),
525        x_time_exit: time_prepared.design_exit.clone(),
526        x_time_derivative_exit: time_prepared.design_derivative_exit.clone(),
527        x_time_right: time_prepared.design_right.clone(),
528        time_offset_right,
529        unloaded_mass_right,
530        x_mean: mean_design.design.clone(),
531        time_linear_constraints: time_prepared.linear_constraints.clone(),
532        quadctx: Arc::new(QuadratureContext::new()),
533    };
534
535    let mut blocks = vec![
536        build_time_blockspec(&time_prepared, &spec.time_block),
537        build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
538    ];
539    if latent_sd.is_none() {
540        blocks.push(build_log_sigma_blockspec(
541            LEARNABLE_LATENT_SD_SEED,
542            mean_design.design.nrows(),
543        ));
544    }
545    // Interval warm start (issue #1108). Interval-censored rows contribute the
546    // NON-concave `ℓ = log[S(L) − S(R)]`; the coupled exact-joint inner Newton
547    // diverges from the cold seed (β_time = 1e-4, σ = 0.5) — the failure surfaces
548    // first as `fit_custom_family`'s outer ρ-seed startup validation rejecting
549    // every seed (`solver_started = 0`). We warm-start from a LOG-CONCAVE
550    // surrogate whose β/σ land in the interval basin, threaded via `initial_beta`
551    // (consumed by every inner solve, including each ρ-seed validation fit).
552    //
553    // Surrogate = right-censored at the bracket LOWER bound `L`. Its survival
554    // mass `S(L) = K_{0,B(L)}` is log-concave (PD Hessian) and — crucially —
555    // its time-block design is the SAME fixed-knot I-spline basis the interval
556    // fit uses, which is FULL RANK regardless of how heavily the inspection-grid
557    // `L` values are TIED (the basis columns are functions of the frozen knots,
558    // not of the observed time multiplicities). Unlike an exact-event surrogate
559    // it imposes NO per-row `q̇(L) > 0` hazard-derivative feasibility condition
560    // (which the tied/degenerate cold-start derivative design can violate), so it
561    // is robust where exact-event-at-L is not. The warm σ then refines from the
562    // bracket-width spread inside the (now in-basin) interval fit.
563    //
564    // Failure is NON-SILENT (#1108): a surrogate that errors or returns a
565    // non-finite / all-zero degenerate β is surfaced as a hard error rather than
566    // silently reverting to the diverging cold start (which masked the real
567    // failure across several attempts). Only `initial_beta` is seeded; the EXACT
568    // interval objective/gradient/Hessian are unchanged, so σ̂ is the true MLE.
569    let has_interval_rows = spec
570        .event_target
571        .iter()
572        .any(|&code| code == LATENT_SURVIVAL_EVENT_INTERVAL);
573    if has_interval_rows {
574        let censored_warm_event_target = spec.event_target.mapv(|code| {
575            if code == LATENT_SURVIVAL_EVENT_INTERVAL {
576                0u8
577            } else {
578                code
579            }
580        });
581        let mut warm_family = family.clone();
582        warm_family.event_target = censored_warm_event_target;
583        // Right-censored-at-L ignores the interval upper bound `R`, so the
584        // (unused) `q_right` channel cannot drift the fit; leaving the right
585        // design/mass in place is harmless (no interval row remains to read it).
586        let warm_fit_result = fit_custom_family_fixed_log_lambdas(
587            &warm_family,
588            &blocks,
589            options,
590            None,
591            0,
592            None,
593            false,
594        );
595        let warm_fit = match warm_fit_result {
596            Ok(fit) => fit,
597            Err(censored_error) => {
598                let has_finite_event_in_censored_surrogate =
599                    warm_family.event_target.iter().any(|&code| code != 0);
600                if has_finite_event_in_censored_surrogate {
601                    return Err(format!(
602                        "latent interval warm start: right-censored-at-L surrogate fit failed \
603                         (so the interval fit cannot be safely warm-started; this surrogate is \
604                         log-concave and should converge — investigate the surrogate, not the \
605                         interval kernel): {censored_error}"
606                    ));
607                }
608
609                // When every observed row is interval-censored, the
610                // right-censored-at-L surrogate contains no failures at all.
611                // Its likelihood is maximized only on the zero-hazard boundary
612                // (β_time -> -∞), so the fixed-λ Newton solve is correctly
613                // allowed to refuse it even though the objective is concave.
614                // Use the finite lower-endpoint event surrogate solely to obtain
615                // an interior β/σ seed for the exact interval likelihood below;
616                // no fitted surrogate likelihood or derivative is reused.
617                let lower_event_warm_target = spec.event_target.mapv(|code| {
618                    if code == LATENT_SURVIVAL_EVENT_INTERVAL {
619                        1u8
620                    } else {
621                        code
622                    }
623                });
624                let mut event_warm_family = family.clone();
625                event_warm_family.event_target = lower_event_warm_target;
626                fit_custom_family_fixed_log_lambdas(
627                    &event_warm_family,
628                    &blocks,
629                    options,
630                    None,
631                    0,
632                    None,
633                    false,
634                )
635                .map_err(|event_error| {
636                    format!(
637                        "latent interval warm start failed: the right-censored-at-L surrogate \
638                         has no finite failures and refused its boundary optimum ({censored_error}); \
639                         the finite lower-endpoint event surrogate also failed ({event_error})"
640                    )
641                })?
642            }
643        };
644        let warm_beta_usable = warm_fit
645            .block_states
646            .iter()
647            .any(|s| s.beta.iter().all(|v| v.is_finite()) && s.beta.iter().any(|&v| v != 0.0));
648        if !warm_beta_usable {
649            return Err(
650                "latent interval warm start: right-censored-at-L surrogate returned a \
651                 degenerate (non-finite or all-zero) β across every block; the warm start \
652                 cannot seed the interval fit. This indicates the surrogate's time-block \
653                 design is rank-deficient or the inner solve stalled at the seed — \
654                 investigate the surrogate before retrying the interval fit."
655                    .to_string(),
656            );
657        }
658        for (block, state) in blocks.iter_mut().zip(warm_fit.block_states.iter()) {
659            if state.beta.iter().all(|v| v.is_finite()) {
660                block.initial_beta = Some(state.beta.clone());
661            }
662        }
663    }
664    let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
665    let latent_sd = family.latent_sd(&fit.block_states)?;
666    let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
667    Ok(LatentSurvivalTermFitResult {
668        fit,
669        design: mean_design,
670        resolvedspec,
671        latent_sd,
672        baseline_offset_residuals,
673    })
674}
675
676pub fn fit_latent_binary_terms(
677    data: ArrayView2<'_, f64>,
678    spec: LatentBinaryTermSpec,
679    frailty: FrailtySpec,
680    options: &BlockwiseFitOptions,
681) -> Result<LatentBinaryTermFitResult, String> {
682    let latent_sd = validate_latent_binary_inputs(data, &spec, &frailty)?;
683    let (_, hazard_loading) = fixed_latent_hazard_frailty(&frailty, "latent-binary")?;
684    let mean_design =
685        build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
686    let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
687        .map_err(|e| e.to_string())?;
688    let time_prepared = prepare_latent_time_block(&spec.time_block, None, spec.derivative_guard)?;
689
690    let family = LatentBinaryFamily {
691        event_target: spec.event_target.clone(),
692        weights: spec.weights.clone(),
693        latent_sd,
694        hazard_loading,
695        unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
696        unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
697        x_time_entry: time_prepared.design_entry.clone(),
698        x_time_exit: time_prepared.design_exit.clone(),
699        x_mean: mean_design.design.clone(),
700        time_linear_constraints: time_prepared.linear_constraints.clone(),
701        quadctx: Arc::new(QuadratureContext::new()),
702    };
703
704    let blocks = vec![
705        build_time_blockspec(&time_prepared, &spec.time_block),
706        build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
707    ];
708    let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
709    let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
710    Ok(LatentBinaryTermFitResult {
711        fit,
712        design: mean_design,
713        resolvedspec,
714        baseline_offset_residuals,
715    })
716}
717
718/// Latent-survival adapter for the shared [`LatentIntervalModel`] driver.
719///
720/// Survival permits a learnable sigma (`sigma_fixed == None`) and carries the
721/// per-row unloaded baseline hazard at exit (which feeds the exact-event
722/// loaded/unloaded split); everything else is validated by the shared engine.
723struct LatentSurvivalModel;
724
725impl LatentIntervalModel for LatentSurvivalModel {
726    fn context() -> &'static str {
727        "latent-survival"
728    }
729
730    fn allows_interval() -> bool {
731        true
732    }
733
734    fn frailty_policy(
735        frailty: &FrailtySpec,
736    ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
737        match frailty {
738            FrailtySpec::HazardMultiplier {
739                sigma_fixed,
740                loading,
741            } => {
742                if let Some(sigma) = sigma_fixed
743                    && (!sigma.is_finite() || *sigma < 0.0)
744                {
745                    return Err(LatentSurvivalError::InvalidFrailty {
746                        reason: format!(
747                            "latent-survival requires a finite hazard-multiplier sigma >= 0, got {sigma}"
748                        ),
749                    });
750                }
751                Ok(LatentFrailtyResolution {
752                    sigma: *sigma_fixed,
753                    loading: *loading,
754                })
755            }
756            FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
757                reason: "latent-survival requires HazardMultiplier frailty, not GaussianShift"
758                    .to_string(),
759            }),
760            FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
761                reason: "latent-survival requires a HazardMultiplier frailty specification"
762                    .to_string(),
763            }),
764        }
765    }
766}
767
768fn validate_latent_survival_inputs(
769    data: ArrayView2<'_, f64>,
770    spec: &LatentSurvivalTermSpec,
771    frailty: &FrailtySpec,
772) -> Result<Option<f64>, LatentSurvivalError> {
773    let row = LatentIntervalRowView {
774        frailty,
775        age_entry: &spec.age_entry,
776        age_exit: &spec.age_exit,
777        event_target: &spec.event_target,
778        weights: &spec.weights,
779        unloaded_mass_entry: &spec.unloaded_mass_entry,
780        unloaded_mass_exit: &spec.unloaded_mass_exit,
781        unloaded_hazard_exit: Some(&spec.unloaded_hazard_exit),
782        mean_offset: &spec.mean_offset,
783        derivative_guard: spec.derivative_guard,
784        time_block: &spec.time_block,
785    };
786    validate_latent_interval_inputs::<LatentSurvivalModel>(data, &row)
787}
788
789pub(crate) fn validate_unloaded_components_for_loading(
790    context: &str,
791    row_index: usize,
792    loading: HazardLoading,
793    unloaded_entry: f64,
794    unloaded_exit: f64,
795    unloaded_hazard: Option<f64>,
796) -> Result<(), LatentSurvivalError> {
797    match loading {
798        HazardLoading::Full => {
799            if unloaded_entry != 0.0
800                || unloaded_exit != 0.0
801                || unloaded_hazard.is_some_and(|hazard| hazard != 0.0)
802            {
803                return Err(LatentSurvivalError::InvalidDataset {
804                    reason: format!(
805                        "{context} row {} uses full hazard loading, so unloaded components must be exactly zero; got entry_mass={}, exit_mass={}, exit_hazard={}",
806                        row_index + 1,
807                        unloaded_entry,
808                        unloaded_exit,
809                        unloaded_hazard.unwrap_or(0.0)
810                    ),
811                });
812            }
813        }
814        HazardLoading::LoadedVsUnloaded => {}
815    }
816    Ok(())
817}
818
819/// Latent-binary adapter for the shared [`LatentIntervalModel`] driver.
820///
821/// Binary never evaluates an exact event, so it requires a finite *fixed*
822/// latent sigma (via [`fixed_latent_hazard_frailty_typed`]) and carries no
823/// per-row unloaded hazard; every other invariant is validated by the shared
824/// engine.
825struct LatentBinaryModel;
826
827impl LatentIntervalModel for LatentBinaryModel {
828    fn context() -> &'static str {
829        "latent-binary"
830    }
831
832    fn frailty_policy(
833        frailty: &FrailtySpec,
834    ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
835        let (sigma, loading) = fixed_latent_hazard_frailty_typed(frailty, "latent-binary")?;
836        Ok(LatentFrailtyResolution {
837            sigma: Some(sigma),
838            loading,
839        })
840    }
841}
842
843fn validate_latent_binary_inputs(
844    data: ArrayView2<'_, f64>,
845    spec: &LatentBinaryTermSpec,
846    frailty: &FrailtySpec,
847) -> Result<f64, LatentSurvivalError> {
848    let row = LatentIntervalRowView {
849        frailty,
850        age_entry: &spec.age_entry,
851        age_exit: &spec.age_exit,
852        event_target: &spec.event_target,
853        weights: &spec.weights,
854        unloaded_mass_entry: &spec.unloaded_mass_entry,
855        unloaded_mass_exit: &spec.unloaded_mass_exit,
856        unloaded_hazard_exit: None,
857        mean_offset: &spec.mean_offset,
858        derivative_guard: spec.derivative_guard,
859        time_block: &spec.time_block,
860    };
861    // The binary `frailty_policy` always yields `Some(sigma)` (it rejects the
862    // learnable-scale case), so the shared driver's `Option<f64>` is `Some`
863    // here; surface a structured error rather than unwrapping if that ever
864    // changes.
865    validate_latent_interval_inputs::<LatentBinaryModel>(data, &row)?.ok_or_else(|| {
866        LatentSurvivalError::InvalidFrailty {
867            reason: "latent-binary requires a fixed latent sigma".to_string(),
868        }
869    })
870}
871
872fn prepare_latent_time_block(
873    input: &TimeBlockInput,
874    design_right: Option<&DesignMatrix>,
875    derivative_guard: f64,
876) -> Result<PreparedLatentTimeBlock, LatentSurvivalError> {
877    if !input.time_monotonicity.is_coordinate_cone() {
878        return Err(LatentSurvivalError::UnsupportedConfiguration {
879            reason: format!(
880                "latent survival requires a coordinate-cone monotonicity strategy; got {:?}",
881                input.time_monotonicity
882            ),
883        });
884    }
885    let design_entry = input
886        .design_entry
887        .try_to_dense_by_chunks("latent survival entry time design")?;
888    let design_exit = input
889        .design_exit
890        .try_to_dense_by_chunks("latent survival exit time design")?;
891    let design_derivative_exit = input
892        .design_derivative_exit
893        .try_to_dense_by_chunks("latent survival derivative time design")?;
894    // The interval upper-bound design shares the time-block coefficients with
895    // the exit design; when the data has no interval rows we reuse the exit
896    // design so `q_right` stays well-defined (its likelihood contribution is
897    // gated off for non-interval rows). When present it must match the exit
898    // design's shape (same basis, evaluated at R).
899    let design_right = match design_right {
900        Some(matrix) => {
901            let dense =
902                matrix.try_to_dense_by_chunks("latent survival interval right time design")?;
903            if dense.nrows() != design_exit.nrows() || dense.ncols() != design_exit.ncols() {
904                return Err(LatentSurvivalError::InvalidDataset {
905                    reason: format!(
906                        "latent survival interval right time design must match exit design shape \
907                         {:?}, got {:?}",
908                        design_exit.dim(),
909                        dense.dim()
910                    ),
911                });
912            }
913            dense
914        }
915        None => design_exit.clone(),
916    };
917    let linear_constraints = structural_time_coefficient_constraints(
918        &input.design_derivative_exit,
919        &input.derivative_offset_exit,
920        derivative_guard,
921    )?;
922    let initial_beta = match linear_constraints.as_ref() {
923        // `project_onto_linear_constraints` validates that any supplied
924        // `initial_beta` matches `design_exit.ncols()`; surface a mismatch as a
925        // structured error rather than letting an ndarray broadcast panic
926        // (issue #374).
927        Some(constraints) => Some(project_onto_linear_constraints(
928            design_exit.ncols(),
929            constraints,
930            input.initial_beta.as_ref(),
931        )?),
932        None => None,
933    };
934    Ok(PreparedLatentTimeBlock {
935        design_entry,
936        design_exit,
937        design_derivative_exit,
938        design_right,
939        linear_constraints,
940        penalties: input.penalties.clone(),
941        initial_beta,
942    })
943}
944
945fn stack_rows(blocks: &[&Array2<f64>]) -> Array2<f64> {
946    let ncols = blocks.first().map_or(0, |m| m.ncols());
947    let nrows = blocks.iter().map(|m| m.nrows()).sum();
948    let mut out = Array2::<f64>::zeros((nrows, ncols));
949    let mut row = 0usize;
950    for block in blocks {
951        let end = row + block.nrows();
952        out.slice_mut(s![row..end, ..]).assign(block);
953        row = end;
954    }
955    out
956}
957
958fn build_time_blockspec(
959    prepared: &PreparedLatentTimeBlock,
960    input: &TimeBlockInput,
961) -> ParameterBlockSpec {
962    // The solver produces a `3·n`-long time `eta` (the `[entry; exit; deriv]`
963    // channel stack that `split_time_eta` slices). That stacked operator is
964    // the eta-producing matrix and so belongs in `stacked_design`, paired with
965    // the matching `3·n`-row stacked offset. The audit / shape-policy invariant
966    // `design.nrows() == n_obs` is satisfied by exposing the single-channel
967    // n-row exit design as `design`; the audit never inspects `stacked_design`.
968    //
969    // This mirrors the survival location-scale fix for the same #326 class
970    // (`survival_location_scale.rs`): the previous code put the `3·n`-row
971    // stack in `design`, which tripped the flat identifiability audit's
972    // row-equality invariant (`block 1 (mean) has n rows, expected 3n`).
973    let stacked_design = stack_rows(&[
974        &prepared.design_entry,
975        &prepared.design_exit,
976        &prepared.design_derivative_exit,
977    ]);
978    let stacked_offset = gam_linalg::utils::stack_offsets(&[
979        &input.offset_entry,
980        &input.offset_exit,
981        &input.derivative_offset_exit,
982    ]);
983    ParameterBlockSpec {
984        name: "time_transform".to_string(),
985        design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
986            prepared.design_exit.clone(),
987        ))),
988        offset: input.offset_exit.clone(),
989        penalties: prepared
990            .penalties
991            .iter()
992            .cloned()
993            .map(PenaltyMatrix::Dense)
994            .collect(),
995        nullspace_dims: input.nullspace_dims.clone(),
996        initial_log_lambdas: input
997            .initial_log_lambdas
998            .clone()
999            .unwrap_or_else(|| Array1::zeros(prepared.penalties.len())),
1000        initial_beta: prepared.initial_beta.clone(),
1001        // Canonical-gauge ownership for the latent-survival joint design: the
1002        // time-transform block carries the structural monotone baseline that
1003        // anchors the parameterisation, so it owns any shared constant
1004        // direction (strictly higher than `mean`/`log_sigma` at 100). This
1005        // matches the survival location-scale gauge contract (time highest).
1006        gauge_priority: 200,
1007        jacobian_callback: None,
1008        stacked_design: Some(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
1009            stacked_design,
1010        )))),
1011        stacked_offset: Some(stacked_offset),
1012    }
1013}
1014
1015fn build_mean_blockspec(design: &TermCollectionDesign, offset: Array1<f64>) -> ParameterBlockSpec {
1016    ParameterBlockSpec {
1017        name: "mean".to_string(),
1018        design: design.design.clone(),
1019        offset,
1020        penalties: design.penalties_as_penalty_matrix(),
1021        nullspace_dims: design.nullspace_dims.clone(),
1022        initial_log_lambdas: Array1::zeros(design.penalties.len()),
1023        initial_beta: None,
1024        // Strictly below `time_transform` (200) so any constant direction
1025        // shared between the monotone time baseline and the mean intercept is
1026        // deterministically attributable to the lower-priority `mean` block by
1027        // the canonical-gauge RRQR (the descending-priority contract used by
1028        // survival location-scale; #366/#556 gauge story).
1029        gauge_priority: 150,
1030        jacobian_callback: None,
1031        stacked_design: None,
1032        stacked_offset: None,
1033    }
1034}
1035
/// Starting latent-frailty standard deviation when `sigma` is learnable
/// (`sigma_fixed == None`).
///
/// The log-sigma block is seeded at `log(0.5)` so the optimizer begins from a
/// moderate, well-conditioned dispersion (σ = 0.5): neither a near-degenerate
/// σ → 0 that flattens the frailty integral, nor a large σ that makes the
/// Gauss-Hermite quadrature heavy-tailed. From there it learns the data's
/// actual scale. This is only an initial value, not a constraint.
///
/// NOTE(review): the call site that passes this seed to
/// `build_log_sigma_blockspec` is outside this view — confirm it is the only
/// consumer.
const LEARNABLE_LATENT_SD_SEED: f64 = 0.5;
1043
1044fn build_log_sigma_blockspec(initial_sigma: f64, n_obs: usize) -> ParameterBlockSpec {
1045    ParameterBlockSpec {
1046        name: "log_sigma".to_string(),
1047        // The frailty/dispersion scale is a single GLOBAL hyperparameter (one free
1048        // coefficient), but the identifiability audit — and the canonical-row
1049        // architecture generally — require every block's effective Jacobian to carry
1050        // `n_obs` rows. A global scalar is realised the same way the survival
1051        // location-scale `log_sigma` block is (see `BinomialLocationScaleFamily`): an
1052        // `n_obs × 1` constant column of ones, so `eta = design · β` is the same scalar
1053        // broadcast to every observation. This keeps it a single free parameter while
1054        // exposing the `n_obs`-row shape the audit checks, and `latent_sd` reads
1055        // `eta[0]` — identical across rows by construction.
1056        design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(Array2::from_elem(
1057            (n_obs, 1),
1058            1.0,
1059        )))),
1060        offset: Array1::zeros(n_obs),
1061        penalties: vec![],
1062        nullspace_dims: vec![],
1063        initial_log_lambdas: Array1::zeros(0),
1064        initial_beta: Some(Array1::from_elem(
1065            1,
1066            exp_sigma_eta_for_sigma_scalar(initial_sigma),
1067        )),
1068        // Lowest of the three (time=200, mean=150): the learnable-scale channel
1069        // yields any shared constant to the location blocks.
1070        gauge_priority: 120,
1071        jacobian_callback: None,
1072        stacked_design: None,
1073        stacked_offset: None,
1074    }
1075}
1076
// Indices of the latent-survival primary coordinates. Each index selects one
// unit direction in `latent_survival_basis_direction`, and
// `LATENT_SURVIVAL_PRIMARY_DIM` is the number of such coordinates (six).
const LATENT_SURVIVAL_PRIMARY_Q_ENTRY: usize = 0;
const LATENT_SURVIVAL_PRIMARY_Q_EXIT: usize = 1;
const LATENT_SURVIVAL_PRIMARY_QDOT_EXIT: usize = 2;
// Interval-censored right boundary R: q_right = log B(R) shares the time-block
// coefficients with q_exit (same monotone transform, different time point), so
// it is a fourth linear functional of the time block, NOT an independent eta
// channel. It sits before `mu`/`log_sigma` so the "trailing optional log_sigma"
// invariant used by `active_primary` (= `LATENT_SURVIVAL_PRIMARY_LOG_SIGMA`)
// keeps q_right always active.
const LATENT_SURVIVAL_PRIMARY_Q_RIGHT: usize = 3;
const LATENT_SURVIVAL_PRIMARY_MU: usize = 4;
const LATENT_SURVIVAL_PRIMARY_LOG_SIGMA: usize = 5;
const LATENT_SURVIVAL_PRIMARY_DIM: usize = 6;
1090
1091use gam_math::jet_partitions::MultiDirJet as LatentMultiDirJet;
1092
/// Value and first four derivatives of `log(x)`:
/// `[ln x, 1/x, -1/x², 2/x³, -6/x⁴]`.
///
/// # Contract
///
/// `x` must be strictly positive, and no clamping is applied. An earlier
/// version substituted `x.max(1e-300)`, which fabricated enormous finite
/// derivatives (`1/1e-300` etc.) that belong to neither `log(x)` nor
/// `log(max(x, floor))` and would silently hide an upstream domain failure.
/// Both callers guarantee `x > 0`: one composes at the literal `1.0` (the
/// normalised log-sum base), the other passes `base` only after an explicit
/// `base.is_finite() && base > 0.0` check. Should a non-positive `x` arrive
/// regardless, the honest IEEE result (`-inf`/`NaN`) is returned, identically
/// in debug and release. For every valid `x > 0` the output is bit-identical
/// to the previous clamped version.
#[inline]
fn latent_unary_derivatives_log(x: f64) -> [f64; 5] {
    // Numerators of the 1st..4th derivatives: (-1)^(n-1) · (n-1)!.
    const NUMERATORS: [f64; 4] = [1.0, -1.0, 2.0, -6.0];

    let mut jet = [x.ln(); 5];
    // Running power x, x², x³, x⁴, built by repeated multiplication by `x`.
    let mut power = 1.0_f64;
    for (slot, numerator) in jet[1..].iter_mut().zip(NUMERATORS) {
        power *= x;
        *slot = numerator / power;
    }
    jet
}
1115
/// One signed monomial of a kernel sum, identified by its exponent tuple.
///
/// `latent_kernel_sum_log_jet` evaluates a term with log-magnitude
/// `ln|coeff| + q_exp·q + tau_exp·log_sigma_factor + qdot_power·ln(qdot) + bundle.get(k)`
/// and sign `signum(coeff)`. Terms with the same
/// `(q_exp, qdot_power, tau_exp, k)` are merged by
/// `latent_kernel_accumulate_term`.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryTerm {
    /// Signed multiplier; terms with a zero coefficient are skipped.
    coeff: f64,
    /// Multiplier of the state's `q` in the log-magnitude.
    q_exp: usize,
    /// Power of `qdot`; lowered by one when differentiating along `dqd`.
    qdot_power: usize,
    /// Multiplier of the state's `log_sigma_factor` in the log-magnitude.
    tau_exp: usize,
    /// Index into the bundle returned by `log_kernel_bundle`.
    k: usize,
}
1124
/// Direction in the single-state kernel coordinates along which
/// `latent_kernel_differentiate_terms` differentiates a term list.
///
/// A component equal to zero disables the corresponding derivative branch.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryDirection {
    /// Weight on the `q` coordinate.
    dq: f64,
    /// Weight on the `qdot` coordinate; only terms with `qdot_power > 0`
    /// respond to it.
    dqd: f64,
    /// Weight on the `mu` coordinate.
    dmu: f64,
    /// Weight on the sigma coordinate; the latent-survival direction maps
    /// fill it from `dlog_sigma`.
    dtau: f64,
}
1132
/// Direction in the six latent-survival primary coordinates, one component
/// per `LATENT_SURVIVAL_PRIMARY_*` index.
///
/// `latent_survival_basis_direction` builds the unit directions; the
/// `latent_survival_map_*_direction` helpers project a direction onto the
/// kernel coordinates of one state (entry, exit, interval left, interval
/// right).
#[derive(Clone, Copy, Debug)]
struct LatentSurvivalPrimaryDirection {
    /// Component along `q_entry`.
    dq_entry: f64,
    /// Component along `q_exit`.
    dq_exit: f64,
    /// Component along `qdot_exit`; forwarded to the exit state only for
    /// exact events.
    dqdot_exit: f64,
    /// Component along `q_right` (interval upper bound).
    dq_right: f64,
    /// Component along `mu`, shared by every state.
    dmu: f64,
    /// Component along `log_sigma`, shared by every state.
    dlog_sigma: f64,
}
1142
/// Point at which a kernel term list is evaluated by
/// `latent_kernel_sum_log_jet`.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryState {
    /// Log-mass coordinate; `exp(q)` is what gets passed to
    /// `log_kernel_bundle`.
    q: f64,
    /// Hazard-derivative coordinate; must be finite and positive whenever an
    /// evaluated term has `qdot_power > 0`.
    qdot: f64,
    /// Latent mean passed to `log_kernel_bundle`.
    mu: f64,
    /// Latent standard deviation passed to `log_kernel_bundle`.
    sigma: f64,
    /// Value multiplied by each term's `tau_exp` in the log-magnitude.
    /// NOTE(review): its relation to `sigma` is set by the caller and is not
    /// visible here.
    log_sigma_factor: f64,
}
1151
1152fn latent_kernel_accumulate_term(
1153    terms: &mut BTreeMap<(usize, usize, usize, usize), f64>,
1154    term: LatentKernelPrimaryTerm,
1155    scale: f64,
1156) {
1157    if scale == 0.0 || term.coeff == 0.0 {
1158        return;
1159    }
1160    *terms
1161        .entry((term.q_exp, term.qdot_power, term.tau_exp, term.k))
1162        .or_insert(0.0) += scale * term.coeff;
1163}
1164
1165fn latent_kernel_differentiate_terms(
1166    terms: &[LatentKernelPrimaryTerm],
1167    dir: LatentKernelPrimaryDirection,
1168) -> Vec<LatentKernelPrimaryTerm> {
1169    let mut out = BTreeMap::<(usize, usize, usize, usize), f64>::new();
1170    for term in terms {
1171        if dir.dq != 0.0 {
1172            if term.q_exp > 0 {
1173                latent_kernel_accumulate_term(&mut out, *term, dir.dq * term.q_exp as f64);
1174            }
1175            latent_kernel_accumulate_term(
1176                &mut out,
1177                LatentKernelPrimaryTerm {
1178                    q_exp: term.q_exp + 1,
1179                    k: term.k + 1,
1180                    ..*term
1181                },
1182                -dir.dq,
1183            );
1184        }
1185        if dir.dmu != 0.0 {
1186            if term.k > 0 {
1187                latent_kernel_accumulate_term(&mut out, *term, dir.dmu * term.k as f64);
1188            }
1189            latent_kernel_accumulate_term(
1190                &mut out,
1191                LatentKernelPrimaryTerm {
1192                    q_exp: term.q_exp + 1,
1193                    k: term.k + 1,
1194                    ..*term
1195                },
1196                -dir.dmu,
1197            );
1198        }
1199        if dir.dtau != 0.0 {
1200            if term.tau_exp > 0 {
1201                latent_kernel_accumulate_term(&mut out, *term, dir.dtau * term.tau_exp as f64);
1202            }
1203            let kf = term.k as f64;
1204            latent_kernel_accumulate_term(
1205                &mut out,
1206                LatentKernelPrimaryTerm {
1207                    tau_exp: term.tau_exp + 2,
1208                    ..*term
1209                },
1210                dir.dtau * kf * kf,
1211            );
1212            latent_kernel_accumulate_term(
1213                &mut out,
1214                LatentKernelPrimaryTerm {
1215                    q_exp: term.q_exp + 1,
1216                    tau_exp: term.tau_exp + 2,
1217                    k: term.k + 1,
1218                    ..*term
1219                },
1220                -dir.dtau * (2.0 * kf + 1.0),
1221            );
1222            latent_kernel_accumulate_term(
1223                &mut out,
1224                LatentKernelPrimaryTerm {
1225                    q_exp: term.q_exp + 2,
1226                    tau_exp: term.tau_exp + 2,
1227                    k: term.k + 2,
1228                    ..*term
1229                },
1230                dir.dtau,
1231            );
1232        }
1233        if dir.dqd != 0.0 && term.qdot_power > 0 {
1234            latent_kernel_accumulate_term(
1235                &mut out,
1236                LatentKernelPrimaryTerm {
1237                    qdot_power: term.qdot_power - 1,
1238                    ..*term
1239                },
1240                dir.dqd * term.qdot_power as f64,
1241            );
1242        }
1243    }
1244    out.into_iter()
1245        .filter_map(|((q_exp, qdot_power, tau_exp, k), coeff)| {
1246            (coeff != 0.0).then_some(LatentKernelPrimaryTerm {
1247                coeff,
1248                q_exp,
1249                qdot_power,
1250                tau_exp,
1251                k,
1252            })
1253        })
1254        .collect()
1255}
1256
1257fn latent_kernel_term_lists_for_directions(
1258    base_terms: &[LatentKernelPrimaryTerm],
1259    directions: &[LatentKernelPrimaryDirection],
1260) -> Vec<Vec<LatentKernelPrimaryTerm>> {
1261    fn build_mask(
1262        mask: usize,
1263        base_terms: &[LatentKernelPrimaryTerm],
1264        directions: &[LatentKernelPrimaryDirection],
1265        cache: &mut [Option<Vec<LatentKernelPrimaryTerm>>],
1266    ) -> Vec<LatentKernelPrimaryTerm> {
1267        if let Some(existing) = &cache[mask] {
1268            return existing.clone();
1269        }
1270        let built = if mask == 0 {
1271            base_terms.to_vec()
1272        } else {
1273            let bit = 1usize << mask.trailing_zeros();
1274            let prev = build_mask(mask ^ bit, base_terms, directions, cache);
1275            latent_kernel_differentiate_terms(&prev, directions[bit.trailing_zeros() as usize])
1276        };
1277        cache[mask] = Some(built.clone());
1278        built
1279    }
1280
1281    let mut cache = vec![None; 1usize << directions.len()];
1282    (0..cache.len())
1283        .map(|mask| build_mask(mask, base_terms, directions, &mut cache))
1284        .collect()
1285}
1286
1287fn latent_kernel_sum_log_jet(
1288    quadctx: &QuadratureContext,
1289    base_terms: &[LatentKernelPrimaryTerm],
1290    state: LatentKernelPrimaryState,
1291    directions: &[LatentKernelPrimaryDirection],
1292    context: &str,
1293) -> Result<LatentMultiDirJet, LatentSurvivalError> {
1294    let term_lists = latent_kernel_term_lists_for_directions(base_terms, directions);
1295    let max_k = term_lists
1296        .iter()
1297        .flat_map(|terms| terms.iter().map(|term| term.k))
1298        .max()
1299        .unwrap_or(0);
1300    let bundle =
1301        log_kernel_bundle(quadctx, state.q.exp(), state.mu, state.sigma, max_k).map_err(|e| {
1302            LatentSurvivalError::NumericalFailure {
1303                reason: format!("{context} kernel evaluation failed: {e}"),
1304            }
1305        })?;
1306
1307    let evaluate_terms =
1308        |terms: &[LatentKernelPrimaryTerm]| -> Result<(f64, f64), LatentSurvivalError> {
1309            let mut log_mags = Vec::new();
1310            let mut signs = Vec::new();
1311            for term in terms {
1312                if term.coeff == 0.0 {
1313                    continue;
1314                }
1315                if term.qdot_power > 0 && !(state.qdot.is_finite() && state.qdot > 0.0) {
1316                    return Err(LatentSurvivalError::NumericalFailure {
1317                        reason: format!(
1318                            "{context} requires positive finite qdot for exact-event directional terms, got {}",
1319                            state.qdot
1320                        ),
1321                    });
1322                }
1323                let log_qdot = if term.qdot_power > 0 {
1324                    state.qdot.ln()
1325                } else {
1326                    0.0
1327                };
1328                let log_mag = term.coeff.abs().ln()
1329                    + term.q_exp as f64 * state.q
1330                    + term.tau_exp as f64 * state.log_sigma_factor
1331                    + term.qdot_power as f64 * log_qdot
1332                    + bundle.get(term.k);
1333                log_mags.push(log_mag);
1334                signs.push(term.coeff.signum());
1335            }
1336            if log_mags.is_empty() {
1337                return Ok((f64::NEG_INFINITY, 0.0));
1338            }
1339            Ok(signed_log_sum_exp(&log_mags, &signs))
1340        };
1341
1342    let (base_log_sum, base_sign) = evaluate_terms(&term_lists[0])?;
1343    if !(base_log_sum.is_finite() && base_sign > 0.0) {
1344        return Err(LatentSurvivalError::NumericalFailure {
1345            reason: format!("{context} produced a non-positive signed kernel sum"),
1346        });
1347    }
1348
1349    let mut normalized = LatentMultiDirJet::constant(directions.len(), 1.0);
1350    for mask in 1..term_lists.len() {
1351        let (log_abs, sign) = evaluate_terms(&term_lists[mask])?;
1352        normalized.coeffs[mask] = if !log_abs.is_finite() || sign == 0.0 {
1353            0.0
1354        } else {
1355            sign * (log_abs - base_log_sum).exp()
1356        };
1357    }
1358
1359    let mut out = normalized.compose_unary(latent_unary_derivatives_log(1.0));
1360    out.coeffs[0] += base_log_sum;
1361    Ok(out)
1362}
1363
1364fn latent_survival_basis_direction(primary_idx: usize) -> LatentSurvivalPrimaryDirection {
1365    match primary_idx {
1366        LATENT_SURVIVAL_PRIMARY_Q_ENTRY => LatentSurvivalPrimaryDirection {
1367            dq_entry: 1.0,
1368            dq_exit: 0.0,
1369            dqdot_exit: 0.0,
1370            dq_right: 0.0,
1371            dmu: 0.0,
1372            dlog_sigma: 0.0,
1373        },
1374        LATENT_SURVIVAL_PRIMARY_Q_EXIT => LatentSurvivalPrimaryDirection {
1375            dq_entry: 0.0,
1376            dq_exit: 1.0,
1377            dqdot_exit: 0.0,
1378            dq_right: 0.0,
1379            dmu: 0.0,
1380            dlog_sigma: 0.0,
1381        },
1382        LATENT_SURVIVAL_PRIMARY_QDOT_EXIT => LatentSurvivalPrimaryDirection {
1383            dq_entry: 0.0,
1384            dq_exit: 0.0,
1385            dqdot_exit: 1.0,
1386            dq_right: 0.0,
1387            dmu: 0.0,
1388            dlog_sigma: 0.0,
1389        },
1390        LATENT_SURVIVAL_PRIMARY_Q_RIGHT => LatentSurvivalPrimaryDirection {
1391            dq_entry: 0.0,
1392            dq_exit: 0.0,
1393            dqdot_exit: 0.0,
1394            dq_right: 1.0,
1395            dmu: 0.0,
1396            dlog_sigma: 0.0,
1397        },
1398        LATENT_SURVIVAL_PRIMARY_MU => LatentSurvivalPrimaryDirection {
1399            dq_entry: 0.0,
1400            dq_exit: 0.0,
1401            dqdot_exit: 0.0,
1402            dq_right: 0.0,
1403            dmu: 1.0,
1404            dlog_sigma: 0.0,
1405        },
1406        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA => LatentSurvivalPrimaryDirection {
1407            dq_entry: 0.0,
1408            dq_exit: 0.0,
1409            dqdot_exit: 0.0,
1410            dq_right: 0.0,
1411            dmu: 0.0,
1412            dlog_sigma: 1.0,
1413        },
1414        // SAFETY: latent survival has exactly `LATENT_SURVIVAL_PRIMARY_DIM`
1415        // (= 5) primary directions, indexed 0..=4 via the module-private
1416        // `LATENT_SURVIVAL_PRIMARY_*` constants. All five are matched
1417        // above, so this wildcard fires only on an out-of-range index,
1418        // which the internal iteration bounds (`0..LATENT_SURVIVAL_PRIMARY_DIM`)
1419        // make unreachable.
1420        // SAFETY: primary_idx is bounded by LATENT_SURVIVAL_PRIMARY_DIM at every internal call site.
1421        _ => std::panic::panic_any(format!(
1422            "latent survival primary index out of bounds: primary_idx={primary_idx}, primary_dim={LATENT_SURVIVAL_PRIMARY_DIM}"
1423        )),
1424    }
1425}
1426
1427fn latent_survival_map_entry_direction(
1428    direction: LatentSurvivalPrimaryDirection,
1429) -> LatentKernelPrimaryDirection {
1430    LatentKernelPrimaryDirection {
1431        dq: direction.dq_entry,
1432        dqd: 0.0,
1433        dmu: direction.dmu,
1434        dtau: direction.dlog_sigma,
1435    }
1436}
1437
1438fn latent_survival_map_exit_direction(
1439    direction: LatentSurvivalPrimaryDirection,
1440    event_type: LatentSurvivalEventType,
1441) -> LatentKernelPrimaryDirection {
1442    LatentKernelPrimaryDirection {
1443        dq: direction.dq_exit,
1444        dqd: if matches!(event_type, LatentSurvivalEventType::ExactEvent) {
1445            direction.dqdot_exit
1446        } else {
1447            0.0
1448        },
1449        dmu: direction.dmu,
1450        dtau: direction.dlog_sigma,
1451    }
1452}
1453
1454/// Direction map for the interval-censored LEFT boundary state (mass `M_L =
1455/// exp(q_exit)`). The left boundary tracks the same `q_exit` time functional as
1456/// right-censoring (no hazard-derivative channel), plus the shared `mu`/`sigma`.
1457fn latent_survival_map_left_direction(
1458    direction: LatentSurvivalPrimaryDirection,
1459) -> LatentKernelPrimaryDirection {
1460    LatentKernelPrimaryDirection {
1461        dq: direction.dq_exit,
1462        dqd: 0.0,
1463        dmu: direction.dmu,
1464        dtau: direction.dlog_sigma,
1465    }
1466}
1467
1468/// Direction map for the interval-censored RIGHT boundary state (mass `M_R =
1469/// exp(q_right)`). The right boundary tracks the dedicated `q_right` functional
1470/// (which shares the time-block coefficients with `q_exit` but is evaluated at
1471/// the interval upper bound `R`), plus the shared `mu`/`sigma`.
1472fn latent_survival_map_right_direction(
1473    direction: LatentSurvivalPrimaryDirection,
1474) -> LatentKernelPrimaryDirection {
1475    LatentKernelPrimaryDirection {
1476        dq: direction.dq_right,
1477        dqd: 0.0,
1478        dmu: direction.dmu,
1479        dtau: direction.dlog_sigma,
1480    }
1481}
1482
1483fn latent_survival_row_primary_log_jet(
1484    quadctx: &QuadratureContext,
1485    row: &LatentSurvivalRow,
1486    q_entry: f64,
1487    q_exit: f64,
1488    qdot_exit: f64,
1489    q_right: f64,
1490    mu: f64,
1491    sigma: f64,
1492    log_sigma_factor: f64,
1493    directions: &[LatentSurvivalPrimaryDirection],
1494) -> Result<LatentMultiDirJet, String> {
1495    let entry_state = LatentKernelPrimaryState {
1496        q: q_entry,
1497        qdot: 1.0,
1498        mu,
1499        sigma,
1500        log_sigma_factor,
1501    };
1502    let entry_directions = directions
1503        .iter()
1504        .copied()
1505        .map(latent_survival_map_entry_direction)
1506        .collect::<Vec<_>>();
1507
1508    let denominator = latent_kernel_sum_log_jet(
1509        quadctx,
1510        &[LatentKernelPrimaryTerm {
1511            coeff: 1.0,
1512            q_exp: 0,
1513            qdot_power: 0,
1514            tau_exp: 0,
1515            k: 0,
1516        }],
1517        entry_state,
1518        &entry_directions,
1519        "latent survival denominator",
1520    )?;
1521
1522    // The numerator for right-censoring / exact events is a single-state log-sum
1523    // kernel at the exit mass. Interval censoring is the difference of two
1524    // single-state kernels at DIFFERENT masses (L at `q_exit`, R at `q_right`),
1525    // so it is assembled by `latent_survival_interval_numerator_log_jet` below.
1526    let numerator = match row.event_type {
1527        LatentSurvivalEventType::RightCensored | LatentSurvivalEventType::ExactEvent => {
1528            let exit_state = LatentKernelPrimaryState {
1529                q: q_exit,
1530                qdot: qdot_exit,
1531                mu,
1532                sigma,
1533                log_sigma_factor,
1534            };
1535            let exit_directions = directions
1536                .iter()
1537                .copied()
1538                .map(|dir| latent_survival_map_exit_direction(dir, row.event_type))
1539                .collect::<Vec<_>>();
1540            let numerator_terms = match row.event_type {
1541                LatentSurvivalEventType::RightCensored => vec![LatentKernelPrimaryTerm {
1542                    coeff: 1.0,
1543                    q_exp: 0,
1544                    qdot_power: 0,
1545                    tau_exp: 0,
1546                    k: 0,
1547                }],
1548                LatentSurvivalEventType::ExactEvent => {
1549                    let mut terms = Vec::new();
1550                    if row.hazard_unloaded > 0.0 {
1551                        terms.push(LatentKernelPrimaryTerm {
1552                            coeff: row.hazard_unloaded,
1553                            q_exp: 0,
1554                            qdot_power: 0,
1555                            tau_exp: 0,
1556                            k: 0,
1557                        });
1558                    }
1559                    terms.push(LatentKernelPrimaryTerm {
1560                        coeff: 1.0,
1561                        q_exp: 1,
1562                        qdot_power: 1,
1563                        tau_exp: 0,
1564                        k: 1,
1565                    });
1566                    terms
1567                }
1568                LatentSurvivalEventType::IntervalCensored => {
1569                    // Interval-censored rows are routed to the dedicated two-state
1570                    // numerator branch (the outer match arm below), so this inner
1571                    // arm is not reached; a clean error rather than a panic guards
1572                    // against a future routing change.
1573                    return Err(
1574                        "interval-censored row reached the single-state numerator branch; \
1575                         it must take the dedicated two-state branch"
1576                            .to_string(),
1577                    );
1578                }
1579            };
1580            latent_kernel_sum_log_jet(
1581                quadctx,
1582                &numerator_terms,
1583                exit_state,
1584                &exit_directions,
1585                "latent survival numerator",
1586            )?
1587        }
1588        LatentSurvivalEventType::IntervalCensored => latent_survival_interval_numerator_log_jet(
1589            quadctx,
1590            row,
1591            q_exit,
1592            q_right,
1593            mu,
1594            sigma,
1595            log_sigma_factor,
1596            directions,
1597        )?,
1598    };
1599
1600    let mut total = numerator.add(&denominator.scale(-1.0));
1601    // For interval rows the unloaded exit mass is folded into the per-boundary
1602    // coefficients `exp(-mass_unloaded_{left,right})` inside the two-state
1603    // numerator, so only the (constant) unloaded-entry term remains here; for
1604    // right-censoring / exact events the exit/entry unloaded masses are an
1605    // additive constant on the log-likelihood.
1606    match row.event_type {
1607        LatentSurvivalEventType::IntervalCensored => {
1608            total.coeffs[0] += row.mass_unloaded_entry;
1609        }
1610        _ => {
1611            total.coeffs[0] += -row.mass_unloaded_exit + row.mass_unloaded_entry;
1612        }
1613    }
1614    Ok(total)
1615}
1616
1617/// Interval-censored numerator jet `log[ c_L·K_{0,M_L} − c_R·K_{0,M_R} ]` where
1618/// `M_L = exp(q_exit)`, `M_R = exp(q_right)`, `c_L = exp(-mass_unloaded_left)`
1619/// and `c_R = exp(-mass_unloaded_right)`.
1620///
1621/// This is the dynamic-time analogue of the static
1622/// [`LatentSurvivalRowJet::interval_censored`] kernel: the interval likelihood
1623/// is the difference of two BOUNDARY survival masses, each a single-state
1624/// order-0 kernel, but at two DISTINCT cumulative masses. Because the two
1625/// boundaries respond to different time functionals (`q_exit` vs `q_right`) we
1626/// cannot fold them into one `latent_kernel_sum_log_jet` state. Instead we:
1627///   1. build each boundary's `log K_{0,M}` jet at its own state, with its own
1628///      direction map (left → `dq_exit`, right → `dq_right`; both share
1629///      `mu`/`sigma`),
1630///   2. lift each to the LINEAR domain via `exp` (a unary composition whose five
1631///      derivatives at value `v` are all `exp(v)`), scaled by its coefficient
1632///      `c_L` (resp. `−c_R`),
1633///   3. add the two linear-domain jets, and
1634///   4. drop back to the log domain via the same `log` unary composition the
1635///      single-state path uses.
1636/// Every multi-direction coefficient (value, score, neg-Hessian, 3rd, 4th)
1637/// follows by the Faà-di-Bruno composition already implemented in
1638/// `MultiDirJet::compose_unary`, so the derivative reductions are consistent
1639/// with the exact-event/right-censored branches by construction.
1640fn latent_survival_interval_numerator_log_jet(
1641    quadctx: &QuadratureContext,
1642    row: &LatentSurvivalRow,
1643    q_exit: f64,
1644    q_right: f64,
1645    mu: f64,
1646    sigma: f64,
1647    log_sigma_factor: f64,
1648    directions: &[LatentSurvivalPrimaryDirection],
1649) -> Result<LatentMultiDirJet, String> {
1650    let single_k0 = [LatentKernelPrimaryTerm {
1651        coeff: 1.0,
1652        q_exp: 0,
1653        qdot_power: 0,
1654        tau_exp: 0,
1655        k: 0,
1656    }];
1657
1658    let left_state = LatentKernelPrimaryState {
1659        q: q_exit,
1660        qdot: 1.0,
1661        mu,
1662        sigma,
1663        log_sigma_factor,
1664    };
1665    let right_state = LatentKernelPrimaryState {
1666        q: q_right,
1667        qdot: 1.0,
1668        mu,
1669        sigma,
1670        log_sigma_factor,
1671    };
1672    let left_directions = directions
1673        .iter()
1674        .copied()
1675        .map(latent_survival_map_left_direction)
1676        .collect::<Vec<_>>();
1677    let right_directions = directions
1678        .iter()
1679        .copied()
1680        .map(latent_survival_map_right_direction)
1681        .collect::<Vec<_>>();
1682
1683    let log_left = latent_kernel_sum_log_jet(
1684        quadctx,
1685        &single_k0,
1686        left_state,
1687        &left_directions,
1688        "latent survival interval left boundary",
1689    )?;
1690    let log_right = latent_kernel_sum_log_jet(
1691        quadctx,
1692        &single_k0,
1693        right_state,
1694        &right_directions,
1695        "latent survival interval right boundary",
1696    )?;
1697
1698    // Lift each boundary's log-kernel jet to the linear domain and scale by the
1699    // unloaded-mass prefactor. exp''''(v) = exp(v) for all orders, so the unary
1700    // derivative tower is `[exp(v); exp(v); exp(v); exp(v); exp(v)]`.
1701    let c_left = (-row.mass_unloaded_left).exp();
1702    let c_right = (-row.mass_unloaded_right).exp();
1703    let exp_left_value = log_left.coeff(0).exp();
1704    let exp_right_value = log_right.coeff(0).exp();
1705    let linear_left = log_left.compose_unary([exp_left_value; 5]).scale(c_left);
1706    let linear_right = log_right.compose_unary([exp_right_value; 5]).scale(c_right);
1707
1708    let linear_numerator = linear_left.add(&linear_right.scale(-1.0));
1709    let base = linear_numerator.coeff(0);
1710    if !(base.is_finite() && base > 0.0) {
1711        return Err(LatentSurvivalError::NumericalFailure {
1712            reason: format!(
1713                "latent survival interval numerator must be a positive survival-mass difference, \
1714                 got c_L*K0(M_L) - c_R*K0(M_R) = {base}; require M_L < M_R (i.e. L < R)"
1715            ),
1716        }
1717        .into());
1718    }
1719    // Drop back to the log domain. `latent_unary_derivatives_log(base)` is the
1720    // unary derivative tower of `ln` at the positive base value, so the composed
1721    // value channel is `ln(base)` and the higher coefficients are the
1722    // log-of-a-difference score / curvature, consistent with the single-state
1723    // log-sum path (which composes `ln` at its normalised base of 1).
1724    Ok(linear_numerator.compose_unary(latent_unary_derivatives_log(base)))
1725}
1726
1727fn latent_survival_row_primary_gradient_hessian(
1728    quadctx: &QuadratureContext,
1729    row: &LatentSurvivalRow,
1730    q_entry: f64,
1731    q_exit: f64,
1732    qdot_exit: f64,
1733    q_right: f64,
1734    mu: f64,
1735    sigma: f64,
1736    include_log_sigma: bool,
1737) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
1738    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1739    let mut gradient = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1740    let mut neg_hessian =
1741        Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1742    let active_primary = if include_log_sigma {
1743        LATENT_SURVIVAL_PRIMARY_DIM
1744    } else {
1745        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1746    };
1747    let log_lik = latent_survival_row_primary_log_jet(
1748        quadctx,
1749        row,
1750        q_entry,
1751        q_exit,
1752        qdot_exit,
1753        q_right,
1754        mu,
1755        sigma,
1756        log_sigma_factor,
1757        &[],
1758    )?
1759    .coeff(0);
1760    for a in 0..active_primary {
1761        let dir_a = latent_survival_basis_direction(a);
1762        gradient[a] = latent_survival_row_primary_log_jet(
1763            quadctx,
1764            row,
1765            q_entry,
1766            q_exit,
1767            qdot_exit,
1768            q_right,
1769            mu,
1770            sigma,
1771            log_sigma_factor,
1772            &[dir_a],
1773        )?
1774        .coeff(1);
1775        for b in a..active_primary {
1776            let coeff = latent_survival_row_primary_log_jet(
1777                quadctx,
1778                row,
1779                q_entry,
1780                q_exit,
1781                qdot_exit,
1782                q_right,
1783                mu,
1784                sigma,
1785                log_sigma_factor,
1786                &[dir_a, latent_survival_basis_direction(b)],
1787            )?
1788            .coeff(3);
1789            neg_hessian[[a, b]] = -coeff;
1790            neg_hessian[[b, a]] = -coeff;
1791        }
1792    }
1793    Ok((log_lik, gradient, neg_hessian))
1794}
1795
1796fn latent_survival_row_primary_third_contracted(
1797    quadctx: &QuadratureContext,
1798    row: &LatentSurvivalRow,
1799    q_entry: f64,
1800    q_exit: f64,
1801    qdot_exit: f64,
1802    q_right: f64,
1803    mu: f64,
1804    sigma: f64,
1805    direction: &Array1<f64>,
1806    include_log_sigma: bool,
1807) -> Result<Array2<f64>, String> {
1808    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1809    let active_primary = if include_log_sigma {
1810        LATENT_SURVIVAL_PRIMARY_DIM
1811    } else {
1812        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1813    };
1814    let dir = LatentSurvivalPrimaryDirection {
1815        dq_entry: direction[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1816        dq_exit: direction[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1817        dqdot_exit: direction[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1818        dq_right: direction[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1819        dmu: direction[LATENT_SURVIVAL_PRIMARY_MU],
1820        dlog_sigma: direction[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1821    };
1822    let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1823    for a in 0..active_primary {
1824        let dir_a = latent_survival_basis_direction(a);
1825        for b in a..active_primary {
1826            let coeff = latent_survival_row_primary_log_jet(
1827                quadctx,
1828                row,
1829                q_entry,
1830                q_exit,
1831                qdot_exit,
1832                q_right,
1833                mu,
1834                sigma,
1835                log_sigma_factor,
1836                &[dir_a, latent_survival_basis_direction(b), dir],
1837            )?
1838            .coeff(7);
1839            out[[a, b]] = -coeff;
1840            out[[b, a]] = -coeff;
1841        }
1842    }
1843    Ok(out)
1844}
1845
1846fn latent_survival_row_primary_fourth_contracted(
1847    quadctx: &QuadratureContext,
1848    row: &LatentSurvivalRow,
1849    q_entry: f64,
1850    q_exit: f64,
1851    qdot_exit: f64,
1852    q_right: f64,
1853    mu: f64,
1854    sigma: f64,
1855    direction_u: &Array1<f64>,
1856    direction_v: &Array1<f64>,
1857    include_log_sigma: bool,
1858) -> Result<Array2<f64>, String> {
1859    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1860    let active_primary = if include_log_sigma {
1861        LATENT_SURVIVAL_PRIMARY_DIM
1862    } else {
1863        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1864    };
1865    let dir_u = LatentSurvivalPrimaryDirection {
1866        dq_entry: direction_u[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1867        dq_exit: direction_u[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1868        dqdot_exit: direction_u[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1869        dq_right: direction_u[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1870        dmu: direction_u[LATENT_SURVIVAL_PRIMARY_MU],
1871        dlog_sigma: direction_u[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1872    };
1873    let dir_v = LatentSurvivalPrimaryDirection {
1874        dq_entry: direction_v[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1875        dq_exit: direction_v[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1876        dqdot_exit: direction_v[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1877        dq_right: direction_v[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1878        dmu: direction_v[LATENT_SURVIVAL_PRIMARY_MU],
1879        dlog_sigma: direction_v[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1880    };
1881    let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1882    for a in 0..active_primary {
1883        let dir_a = latent_survival_basis_direction(a);
1884        for b in a..active_primary {
1885            let coeff = latent_survival_row_primary_log_jet(
1886                quadctx,
1887                row,
1888                q_entry,
1889                q_exit,
1890                qdot_exit,
1891                q_right,
1892                mu,
1893                sigma,
1894                log_sigma_factor,
1895                &[dir_a, latent_survival_basis_direction(b), dir_u, dir_v],
1896            )?
1897            .coeff(15);
1898            out[[a, b]] = -coeff;
1899            out[[b, a]] = -coeff;
1900        }
1901    }
1902    Ok(out)
1903}
1904
/// Index layout of the flattened joint coefficient vector.
#[derive(Clone)]
struct LatentSurvivalJointSlices {
    /// Positions of the time/baseline block coefficients.
    time: std::ops::Range<usize>,
    /// Positions of the mean block coefficients.
    mean: std::ops::Range<usize>,
    /// Position of the `log sigma` coefficient, when that parameter is present.
    log_sigma: Option<std::ops::Range<usize>>,
    /// Length of the flattened joint coefficient vector.
    total: usize,
}
1912
1913#[derive(Clone)]
1914struct LatentSurvivalJointGradientAccum {
1915    ll: f64,
1916    gradient: Array1<f64>,
1917}
1918
1919#[derive(Clone)]
1920struct LatentSurvivalJointDenseAccum {
1921    ll: f64,
1922    gradient: Array1<f64>,
1923    hessian: Array2<f64>,
1924}
1925
1926#[derive(Clone)]
1927struct LatentSurvivalDenseHessianAccum {
1928    hessian: Array2<f64>,
1929}
1930
1931/// Process latent-survival rows in fixed contiguous chunks, using one
1932/// accumulator per rayon task and reducing those accumulators in chunk-index
1933/// order so gradient/Hessian assembly stays deterministic across runs.
1934fn deterministic_latent_survival_row_reduction<Acc, Init, Process, Combine>(
1935    n_rows: usize,
1936    init: Init,
1937    process_row: Process,
1938    mut combine: Combine,
1939) -> Result<Acc, String>
1940where
1941    Acc: Send,
1942    Init: Fn() -> Acc + Sync,
1943    Process: Fn(usize, &mut Acc) -> Result<(), String> + Sync,
1944    Combine: FnMut(&mut Acc, Acc),
1945{
1946    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1947
1948    const TARGET_CHUNK_COUNT: usize = 32;
1949    if n_rows == 0 {
1950        return Ok(init());
1951    }
1952    let chunk_size = n_rows.div_ceil(TARGET_CHUNK_COUNT).max(1);
1953    let n_chunks = n_rows.div_ceil(chunk_size);
1954    let chunk_accumulators: Vec<Acc> = (0..n_chunks)
1955        .into_par_iter()
1956        .map(|chunk_idx| -> Result<Acc, String> {
1957            let start = chunk_idx * chunk_size;
1958            let end = (start + chunk_size).min(n_rows);
1959            let mut acc = init();
1960            for row_idx in start..end {
1961                process_row(row_idx, &mut acc)?;
1962            }
1963            Ok(acc)
1964        })
1965        .collect::<Result<Vec<_>, String>>()?;
1966
1967    let mut total = init();
1968    for acc in chunk_accumulators {
1969        combine(&mut total, acc);
1970    }
1971    Ok(total)
1972}
1973
1974impl LatentSurvivalFamily {
1975    /// Assemble the per-row [`LatentSurvivalRow`] for `row_idx` from the family's
1976    /// unloaded-mass/hazard fields and the supplied per-row time quantiles.
1977    ///
1978    /// Shared by every per-row reduction (log-likelihood, gradient, Hessian,
1979    /// directional third derivatives): each previously inlined an identical
1980    /// `event_type` lookup followed by the same 12-argument
1981    /// `build_latent_survival_row` call. Behavior is unchanged.
1982    fn build_row_at(
1983        &self,
1984        row_idx: usize,
1985        q_entry: f64,
1986        q_exit: f64,
1987        qdot_exit: f64,
1988        q_right: f64,
1989    ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
1990        let event_type = latent_survival_event_type_for(self.event_target[row_idx]);
1991        build_latent_survival_row(
1992            row_idx,
1993            self.hazard_loading,
1994            event_type,
1995            q_entry,
1996            q_exit,
1997            qdot_exit,
1998            q_right,
1999            self.unloaded_mass_entry[row_idx],
2000            self.unloaded_mass_exit[row_idx],
2001            self.unloaded_mass_right[row_idx],
2002            self.unloaded_hazard_exit[row_idx],
2003        )
2004    }
2005
2006    fn joint_slices(&self) -> LatentSurvivalJointSlices {
2007        let p_time = self.x_time_exit.ncols();
2008        let p_mean = self.x_mean.ncols();
2009        let time = 0..p_time;
2010        let mean = p_time..p_time + p_mean;
2011        let log_sigma = self
2012            .latent_sd_fixed
2013            .is_none()
2014            .then_some((p_time + p_mean)..(p_time + p_mean + 1));
2015        LatentSurvivalJointSlices {
2016            total: log_sigma
2017                .as_ref()
2018                .map_or(p_time + p_mean, |range| range.end),
2019            time,
2020            mean,
2021            log_sigma,
2022        }
2023    }
2024
2025    fn row_primary_direction_from_flat(
2026        &self,
2027        row: usize,
2028        slices: &LatentSurvivalJointSlices,
2029        d_beta_flat: &Array1<f64>,
2030    ) -> Array1<f64> {
2031        let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
2032        let d_time = d_beta_flat.slice(s![slices.time.clone()]);
2033        out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
2034        out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
2035        out[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT] = self.x_time_derivative_exit.row(row).dot(&d_time);
2036        out[LATENT_SURVIVAL_PRIMARY_Q_RIGHT] = self.x_time_right.row(row).dot(&d_time);
2037        out[LATENT_SURVIVAL_PRIMARY_MU] = self
2038            .x_mean
2039            .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
2040        if let Some(range) = &slices.log_sigma {
2041            out[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA] = d_beta_flat[range.start];
2042        }
2043        out
2044    }
2045
2046    fn joint_block_ranges(&self) -> Vec<std::ops::Range<usize>> {
2047        let slices = self.joint_slices();
2048        let mut ranges = vec![slices.time.clone(), slices.mean.clone()];
2049        if let Some(log_sigma) = slices.log_sigma {
2050            ranges.push(log_sigma);
2051        }
2052        ranges
2053    }
2054
2055    fn add_pullback_primary_gradient(
2056        &self,
2057        target: &mut Array1<f64>,
2058        row: usize,
2059        slices: &LatentSurvivalJointSlices,
2060        primary_gradient: &Array1<f64>,
2061        weight: f64,
2062    ) -> Result<(), String> {
2063        for (primary_idx, time_vec) in [
2064            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2065            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2066            (
2067                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2068                self.x_time_derivative_exit.row(row),
2069            ),
2070            (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2071        ] {
2072            let scale = weight * primary_gradient[primary_idx];
2073            if scale == 0.0 {
2074                continue;
2075            }
2076            for i in 0..time_vec.len() {
2077                let xi = time_vec[i];
2078                if xi != 0.0 {
2079                    target[slices.time.start + i] += scale * xi;
2080                }
2081            }
2082        }
2083
2084        let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
2085        if mean_scale != 0.0 {
2086            self.x_mean
2087                .axpy_row_into(
2088                    row,
2089                    mean_scale,
2090                    &mut target.slice_mut(s![slices.mean.clone()]),
2091                )
2092                .map_err(|error| {
2093                    format!(
2094                        "latent survival mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
2095                        slices.mean,
2096                        target.len(),
2097                        self.x_mean.ncols()
2098                    )
2099                })?;
2100        }
2101
2102        if let Some(log_sigma) = &slices.log_sigma {
2103            target[log_sigma.start] += weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA];
2104        }
2105        Ok(())
2106    }
2107
2108    fn add_pullback_primary_hessian(
2109        &self,
2110        target: &mut Array2<f64>,
2111        row: usize,
2112        slices: &LatentSurvivalJointSlices,
2113        primary_hessian: &Array2<f64>,
2114    ) -> Result<(), String> {
2115        let time_weights = [
2116            primary_hessian[[
2117                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2118                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2119            ]],
2120            primary_hessian[[
2121                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2122                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2123            ]],
2124            primary_hessian[[
2125                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2126                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2127            ]],
2128            primary_hessian[[
2129                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2130                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2131            ]],
2132        ];
2133        let time_cross_weights = [
2134            (
2135                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2136                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2137                &self.x_time_entry,
2138                &self.x_time_exit,
2139            ),
2140            (
2141                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2142                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2143                &self.x_time_entry,
2144                &self.x_time_derivative_exit,
2145            ),
2146            (
2147                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2148                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2149                &self.x_time_exit,
2150                &self.x_time_derivative_exit,
2151            ),
2152            (
2153                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2154                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2155                &self.x_time_entry,
2156                &self.x_time_right,
2157            ),
2158            (
2159                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2160                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2161                &self.x_time_exit,
2162                &self.x_time_right,
2163            ),
2164            (
2165                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2166                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2167                &self.x_time_derivative_exit,
2168                &self.x_time_right,
2169            ),
2170        ];
2171        {
2172            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
2173            dense_outer_accumulate(time_target, time_weights[0], self.x_time_entry.row(row));
2174            dense_outer_accumulate(time_target, time_weights[1], self.x_time_exit.row(row));
2175            dense_outer_accumulate(
2176                time_target,
2177                time_weights[2],
2178                self.x_time_derivative_exit.row(row),
2179            );
2180            dense_outer_accumulate(time_target, time_weights[3], self.x_time_right.row(row));
2181            for (a, b, lhs, rhs) in time_cross_weights {
2182                let weight = primary_hessian[[a, b]];
2183                if weight == 0.0 {
2184                    continue;
2185                }
2186                dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
2187            }
2188        }
2189
2190        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
2191        self.x_mean
2192            .syr_row_into_view(
2193                row,
2194                mean_weight,
2195                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
2196            )
2197            .map_err(|error| {
2198                format!(
2199                    "latent survival mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
2200                    slices.mean,
2201                    target.dim(),
2202                    self.x_mean.ncols()
2203                )
2204            })?;
2205
2206        let mean_row = self
2207            .x_mean
2208            .try_row_chunk(row..row + 1)
2209            .map_err(|error| {
2210                format!(
2211                    "latent survival mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
2212                    self.x_mean.nrows(),
2213                    self.x_mean.ncols()
2214                )
2215            })?;
2216        let mean_vec = mean_row.row(0);
2217        let time_mean_weights = [
2218            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2219            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2220            (
2221                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2222                self.x_time_derivative_exit.row(row),
2223            ),
2224            (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2225        ];
2226        for (primary_idx, time_vec) in time_mean_weights {
2227            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
2228            if weight == 0.0 {
2229                continue;
2230            }
2231            for i in 0..time_vec.len() {
2232                let xi = time_vec[i];
2233                if xi == 0.0 {
2234                    continue;
2235                }
2236                for j in 0..mean_vec.len() {
2237                    let xj = mean_vec[j];
2238                    if xj == 0.0 {
2239                        continue;
2240                    }
2241                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
2242                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
2243                }
2244            }
2245        }
2246
2247        if let Some(log_sigma) = &slices.log_sigma {
2248            let sigma_idx = log_sigma.start;
2249            target[[sigma_idx, sigma_idx]] += primary_hessian[[
2250                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2251                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2252            ]];
2253
2254            for (primary_idx, time_vec) in [
2255                (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2256                (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2257                (
2258                    LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2259                    self.x_time_derivative_exit.row(row),
2260                ),
2261                (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2262            ] {
2263                let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_LOG_SIGMA]];
2264                if weight == 0.0 {
2265                    continue;
2266                }
2267                for i in 0..time_vec.len() {
2268                    let xi = time_vec[i];
2269                    if xi == 0.0 {
2270                        continue;
2271                    }
2272                    target[[slices.time.start + i, sigma_idx]] += weight * xi;
2273                    target[[sigma_idx, slices.time.start + i]] += weight * xi;
2274                }
2275            }
2276
2277            let mean_sigma_weight = primary_hessian[[
2278                LATENT_SURVIVAL_PRIMARY_MU,
2279                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2280            ]];
2281            if mean_sigma_weight != 0.0 {
2282                for j in 0..mean_vec.len() {
2283                    let xj = mean_vec[j];
2284                    if xj == 0.0 {
2285                        continue;
2286                    }
2287                    target[[slices.mean.start + j, sigma_idx]] += mean_sigma_weight * xj;
2288                    target[[sigma_idx, slices.mean.start + j]] += mean_sigma_weight * xj;
2289                }
2290            }
2291        }
2292        Ok(())
2293    }
2294
2295    fn evaluate_exact_newton_joint_gradient_dense(
2296        &self,
2297        block_states: &[ParameterBlockState],
2298    ) -> Result<(f64, Array1<f64>), String> {
2299        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2300        let q_right = self.time_q_right(block_states)?;
2301        let sigma = self.latent_sd(block_states)?;
2302        let slices = self.joint_slices();
2303        let include_log_sigma = slices.log_sigma.is_some();
2304        let total = slices.total;
2305        let acc = deterministic_latent_survival_row_reduction(
2306            self.event_target.len(),
2307            || LatentSurvivalJointGradientAccum {
2308                ll: 0.0,
2309                gradient: Array1::<f64>::zeros(total),
2310            },
2311            |row_idx, acc| {
2312                let wi = self.weights[row_idx];
2313                if wi <= MIN_WEIGHT {
2314                    return Ok(());
2315                }
2316                let row = self.build_row_at(
2317                    row_idx,
2318                    q_entry[row_idx],
2319                    q_exit[row_idx],
2320                    qdot_exit[row_idx],
2321                    q_right[row_idx],
2322                )?;
2323                let (row_ll, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2324                    &self.quadctx,
2325                    &row,
2326                    q_entry[row_idx],
2327                    q_exit[row_idx],
2328                    qdot_exit[row_idx],
2329                    q_right[row_idx],
2330                    mu[row_idx],
2331                    sigma,
2332                    include_log_sigma,
2333                )?;
2334                acc.ll += wi * row_ll;
2335                self.add_pullback_primary_gradient(
2336                    &mut acc.gradient,
2337                    row_idx,
2338                    &slices,
2339                    &primary_gradient,
2340                    wi,
2341                )?;
2342                Ok(())
2343            },
2344            |total_acc, chunk_acc| {
2345                total_acc.ll += chunk_acc.ll;
2346                total_acc.gradient += &chunk_acc.gradient;
2347            },
2348        )?;
2349        Ok((acc.ll, acc.gradient))
2350    }
2351
2352    /// Per-row residuals of the unpenalized NLL with respect to the three
2353    /// additive baseline time-block offsets `(entry, exit, derivative)`.
2354    ///
2355    /// The baseline configuration θ enters the latent-survival working model
2356    /// only through the additive offsets on the three time channels
2357    ///   q_entry = x_time_entry·β_time + o_E(θ),
2358    ///   q_exit  = x_time_exit·β_time  + o_X(θ),
2359    ///   q̇_exit = x_time_deriv·β_time + o_D(θ),
2360    /// exactly the offset channel the transformation path carries through
2361    /// [`WorkingModelSurvival::offset_channel_residuals`]. Because
2362    /// `∂q_ch/∂o_ch = 1`, the residual `∂NLL/∂o_ch_i` equals
2363    /// `−∂(log-likelihood)/∂q_ch_i`, and the per-row primary log-likelihood
2364    /// gradient over `(q_entry, q_exit, q̇_exit)` is precisely the
2365    /// `Q_ENTRY`/`Q_EXIT`/`QDOT_EXIT` components returned by
2366    /// [`latent_survival_row_primary_gradient_hessian`]. Sampleweight-scaled to
2367    /// match the [`OffsetChannelResiduals`] contract consumed by
2368    /// `baseline_chain_rule_gradient`.
2369    ///
2370    /// At the converged (constrained) β̂ the envelope theorem makes this the
2371    /// exact θ-gradient of the profile penalized NLL `0.5·deviance + 0.5·βᵀSβ`.
2372    /// The interval upper-bound `q_right = x_time_right·β_time + o_R(θ)` channel
2373    /// DOES carry its own baseline-θ offset `o_R(θ)` (the time basis evaluated at
2374    /// the bracket upper bound `R`), distinct from the exit offset at `L`, so its
2375    /// residual `−∂(log-likelihood)/∂q_right` is returned in the dedicated
2376    /// [`OffsetChannelResiduals::right`] channel; it is exactly 0 on every
2377    /// non-interval row (the `Q_RIGHT` primary channel is inert there) and the
2378    /// baseline-θ chain rule contracts it against the `age_right`-evaluated
2379    /// η-partial.
2380    pub fn offset_channel_residuals(
2381        &self,
2382        block_states: &[ParameterBlockState],
2383    ) -> Result<crate::survival::OffsetChannelResiduals, String> {
2384        let n = self.event_target.len();
2385        if block_states.is_empty() {
2386            // Degraded-fit fallback mirroring the location-scale family: an
2387            // empty block-state slate (ARC deterministic-replay stall) yields
2388            // zero residuals so the outer baseline-θ BFGS sees ‖g‖ = 0 and
2389            // terminates cleanly at the current θ̂ rather than panicking.
2390            log::warn!(
2391                "LatentSurvivalFamily::offset_channel_residuals: block_states is empty \
2392                 (degraded fit); returning zero offset residuals (n={n})"
2393            );
2394            return Ok(crate::survival::OffsetChannelResiduals {
2395                exit: Array1::<f64>::zeros(n),
2396                entry: Array1::<f64>::zeros(n),
2397                derivative: Array1::<f64>::zeros(n),
2398                right: Array1::<f64>::zeros(n),
2399            });
2400        }
2401        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2402        let q_right = self.time_q_right(block_states)?;
2403        let sigma = self.latent_sd(block_states)?;
2404        let include_log_sigma = self.joint_slices().log_sigma.is_some();
2405        let mut entry = Array1::<f64>::zeros(n);
2406        let mut exit = Array1::<f64>::zeros(n);
2407        let mut derivative = Array1::<f64>::zeros(n);
2408        let mut right = Array1::<f64>::zeros(n);
2409        for row_idx in 0..n {
2410            let wi = self.weights[row_idx];
2411            if wi <= MIN_WEIGHT {
2412                continue;
2413            }
2414            let row = self.build_row_at(
2415                row_idx,
2416                q_entry[row_idx],
2417                q_exit[row_idx],
2418                qdot_exit[row_idx],
2419                q_right[row_idx],
2420            )?;
2421            let (_, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2422                &self.quadctx,
2423                &row,
2424                q_entry[row_idx],
2425                q_exit[row_idx],
2426                qdot_exit[row_idx],
2427                q_right[row_idx],
2428                mu[row_idx],
2429                sigma,
2430                include_log_sigma,
2431            )?;
2432            // ∂NLL/∂o_ch = −w · ∂(log-likelihood)/∂q_ch.
2433            entry[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
2434            exit[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
2435            derivative[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
2436            // Interval upper-bound (`R`) channel. `q_right` shares the time-block
2437            // coefficients but carries its OWN baseline-θ η-offset evaluated at
2438            // `R` (`o_R(θ)`), so the profile-NLL θ-gradient must include it.
2439            // `∂(log-likelihood)/∂q_right` is exactly 0 for non-interval rows
2440            // (the `Q_RIGHT` channel is inert there), so this is 0 except on
2441            // interval-censored rows.
2442            right[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
2443        }
2444        Ok(crate::survival::OffsetChannelResiduals {
2445            exit,
2446            entry,
2447            derivative,
2448            right,
2449        })
2450    }
2451
2452    /// Block-diagonal-only pullback: writes only time-time, mean-mean, and
2453    /// log_sigma-log_sigma rowwise contributions into per-block targets.
2454    /// Used by `evaluate()` to populate per-block working sets without ever
2455    /// materializing the cross blocks the inner solver does not consume.
2456    fn add_pullback_primary_block_diagonals(
2457        &self,
2458        row: usize,
2459        primary_hessian: &Array2<f64>,
2460        time_target: &mut Array2<f64>,
2461        mean_target: &mut Array2<f64>,
2462        log_sigma_target: Option<&mut Array2<f64>>,
2463    ) -> Result<(), String> {
2464        let h = primary_hessian;
2465        // Time block: 4 squared rows (entry/exit/qdot/right) + 6 symmetric
2466        // crosses. The interval right-boundary functional `q_right` shares the
2467        // time-block coefficients, so it accumulates into the same time target.
2468        dense_outer_accumulate(
2469            time_target,
2470            h[[
2471                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2472                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2473            ]],
2474            self.x_time_entry.row(row),
2475        );
2476        dense_outer_accumulate(
2477            time_target,
2478            h[[
2479                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2480                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2481            ]],
2482            self.x_time_exit.row(row),
2483        );
2484        dense_outer_accumulate(
2485            time_target,
2486            h[[
2487                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2488                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2489            ]],
2490            self.x_time_derivative_exit.row(row),
2491        );
2492        dense_outer_accumulate(
2493            time_target,
2494            h[[
2495                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2496                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2497            ]],
2498            self.x_time_right.row(row),
2499        );
2500        for (a, b, lhs, rhs) in [
2501            (
2502                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2503                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2504                &self.x_time_entry,
2505                &self.x_time_exit,
2506            ),
2507            (
2508                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2509                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2510                &self.x_time_entry,
2511                &self.x_time_derivative_exit,
2512            ),
2513            (
2514                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2515                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2516                &self.x_time_exit,
2517                &self.x_time_derivative_exit,
2518            ),
2519            (
2520                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2521                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2522                &self.x_time_entry,
2523                &self.x_time_right,
2524            ),
2525            (
2526                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2527                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2528                &self.x_time_exit,
2529                &self.x_time_right,
2530            ),
2531            (
2532                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2533                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2534                &self.x_time_derivative_exit,
2535                &self.x_time_right,
2536            ),
2537        ] {
2538            let weight = h[[a, b]];
2539            if weight == 0.0 {
2540                continue;
2541            }
2542            dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
2543        }
2544        // Mean block.
2545        let mean_weight = h[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
2546        self.x_mean
2547            .syr_row_into_view(row, mean_weight, mean_target.view_mut())
2548            .map_err(|error| {
2549                format!(
2550                    "latent survival mean block-diagonal pullback dimension mismatch: row={row}, mean_target_dim={:?}, x_mean_cols={}, error={error}",
2551                    mean_target.dim(),
2552                    self.x_mean.ncols()
2553                )
2554            })?;
2555        // Log-σ block (scalar).
2556        if let Some(target) = log_sigma_target {
2557            target[[0, 0]] += h[[
2558                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2559                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2560            ]];
2561        }
2562        Ok(())
2563    }
2564
2565    /// Block-diagonal evaluator used by `evaluate()`. Returns the per-row
2566    /// log-likelihood, the joint gradient (sliced into block gradients by
2567    /// the caller), and the three per-block diagonal Hessians without ever
2568    /// materializing the full joint matrix.
2569    fn evaluate_exact_newton_block_diagonals(
2570        &self,
2571        block_states: &[ParameterBlockState],
2572    ) -> Result<
2573        (
2574            f64,
2575            Array1<f64>,
2576            Array2<f64>,
2577            Array2<f64>,
2578            Option<Array2<f64>>,
2579        ),
2580        String,
2581    > {
2582        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2583        let q_right = self.time_q_right(block_states)?;
2584        let sigma = self.latent_sd(block_states)?;
2585        let slices = self.joint_slices();
2586        let include_log_sigma = slices.log_sigma.is_some();
2587        let mut ll = 0.0;
2588        let mut gradient = Array1::<f64>::zeros(slices.total);
2589        let p_time = slices.time.len();
2590        let p_mean = slices.mean.len();
2591        let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
2592        let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
2593        let mut hess_log_sigma = if include_log_sigma {
2594            Some(Array2::<f64>::zeros((1, 1)))
2595        } else {
2596            None
2597        };
2598        for row_idx in 0..self.event_target.len() {
2599            let wi = self.weights[row_idx];
2600            if wi <= MIN_WEIGHT {
2601                continue;
2602            }
2603            let row = self.build_row_at(
2604                row_idx,
2605                q_entry[row_idx],
2606                q_exit[row_idx],
2607                qdot_exit[row_idx],
2608                q_right[row_idx],
2609            )?;
2610            let (row_ll, primary_gradient, primary_hessian) =
2611                latent_survival_row_primary_gradient_hessian(
2612                    &self.quadctx,
2613                    &row,
2614                    q_entry[row_idx],
2615                    q_exit[row_idx],
2616                    qdot_exit[row_idx],
2617                    q_right[row_idx],
2618                    mu[row_idx],
2619                    sigma,
2620                    include_log_sigma,
2621                )?;
2622            ll += wi * row_ll;
2623            self.add_pullback_primary_gradient(
2624                &mut gradient,
2625                row_idx,
2626                &slices,
2627                &primary_gradient,
2628                wi,
2629            )?;
2630            self.add_pullback_primary_block_diagonals(
2631                row_idx,
2632                &(wi * primary_hessian),
2633                &mut hess_time,
2634                &mut hess_mean,
2635                hess_log_sigma.as_mut(),
2636            )?;
2637        }
2638        Ok((ll, gradient, hess_time, hess_mean, hess_log_sigma))
2639    }
2640
2641    fn evaluate_exact_newton_joint_dense(
2642        &self,
2643        block_states: &[ParameterBlockState],
2644    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
2645        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2646        let q_right = self.time_q_right(block_states)?;
2647        let sigma = self.latent_sd(block_states)?;
2648        let slices = self.joint_slices();
2649        let include_log_sigma = slices.log_sigma.is_some();
2650        let total = slices.total;
2651        let acc = deterministic_latent_survival_row_reduction(
2652            self.event_target.len(),
2653            || LatentSurvivalJointDenseAccum {
2654                ll: 0.0,
2655                gradient: Array1::<f64>::zeros(total),
2656                hessian: Array2::<f64>::zeros((total, total)),
2657            },
2658            |row_idx, acc| {
2659                let wi = self.weights[row_idx];
2660                if wi <= MIN_WEIGHT {
2661                    return Ok(());
2662                }
2663                let row = self.build_row_at(
2664                    row_idx,
2665                    q_entry[row_idx],
2666                    q_exit[row_idx],
2667                    qdot_exit[row_idx],
2668                    q_right[row_idx],
2669                )?;
2670                let (row_ll, primary_gradient, primary_hessian) =
2671                    latent_survival_row_primary_gradient_hessian(
2672                        &self.quadctx,
2673                        &row,
2674                        q_entry[row_idx],
2675                        q_exit[row_idx],
2676                        qdot_exit[row_idx],
2677                        q_right[row_idx],
2678                        mu[row_idx],
2679                        sigma,
2680                        include_log_sigma,
2681                    )?;
2682                acc.ll += wi * row_ll;
2683                self.add_pullback_primary_gradient(
2684                    &mut acc.gradient,
2685                    row_idx,
2686                    &slices,
2687                    &primary_gradient,
2688                    wi,
2689                )?;
2690                self.add_pullback_primary_hessian(
2691                    &mut acc.hessian,
2692                    row_idx,
2693                    &slices,
2694                    &(wi * primary_hessian),
2695                )?;
2696                Ok(())
2697            },
2698            |total_acc, chunk_acc| {
2699                total_acc.ll += chunk_acc.ll;
2700                total_acc.gradient += &chunk_acc.gradient;
2701                total_acc.hessian += &chunk_acc.hessian;
2702            },
2703        )?;
2704        Ok((acc.ll, acc.gradient, acc.hessian))
2705    }
2706
2707    fn exact_newton_joint_hessian_directional_derivative_dense(
2708        &self,
2709        block_states: &[ParameterBlockState],
2710        d_beta_flat: &Array1<f64>,
2711    ) -> Result<Array2<f64>, String> {
2712        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2713        let q_right = self.time_q_right(block_states)?;
2714        let sigma = self.latent_sd(block_states)?;
2715        let slices = self.joint_slices();
2716        if d_beta_flat.len() != slices.total {
2717            return Err(format!(
2718                "latent survival joint dH direction length mismatch: got {}, expected {}",
2719                d_beta_flat.len(),
2720                slices.total
2721            ));
2722        }
2723        let include_log_sigma = slices.log_sigma.is_some();
2724        let total = slices.total;
2725        let acc = deterministic_latent_survival_row_reduction(
2726            self.event_target.len(),
2727            || LatentSurvivalDenseHessianAccum {
2728                hessian: Array2::<f64>::zeros((total, total)),
2729            },
2730            |row_idx, acc| {
2731                let wi = self.weights[row_idx];
2732                if wi <= MIN_WEIGHT {
2733                    return Ok(());
2734                }
2735                let row = self.build_row_at(
2736                    row_idx,
2737                    q_entry[row_idx],
2738                    q_exit[row_idx],
2739                    qdot_exit[row_idx],
2740                    q_right[row_idx],
2741                )?;
2742                let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
2743                let third = latent_survival_row_primary_third_contracted(
2744                    &self.quadctx,
2745                    &row,
2746                    q_entry[row_idx],
2747                    q_exit[row_idx],
2748                    qdot_exit[row_idx],
2749                    q_right[row_idx],
2750                    mu[row_idx],
2751                    sigma,
2752                    &direction,
2753                    include_log_sigma,
2754                )?;
2755                self.add_pullback_primary_hessian(
2756                    &mut acc.hessian,
2757                    row_idx,
2758                    &slices,
2759                    &(wi * third),
2760                )?;
2761                Ok(())
2762            },
2763            |total_acc, chunk_acc| {
2764                total_acc.hessian += &chunk_acc.hessian;
2765            },
2766        )?;
2767        Ok(acc.hessian)
2768    }
2769
2770    fn exact_newton_joint_hessian_second_directional_derivative_dense(
2771        &self,
2772        block_states: &[ParameterBlockState],
2773        d_beta_u_flat: &Array1<f64>,
2774        d_beta_v_flat: &Array1<f64>,
2775    ) -> Result<Array2<f64>, String> {
2776        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2777        let q_right = self.time_q_right(block_states)?;
2778        let sigma = self.latent_sd(block_states)?;
2779        let slices = self.joint_slices();
2780        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
2781            return Err(format!(
2782                "latent survival joint d2H direction length mismatch: got {} and {}, expected {}",
2783                d_beta_u_flat.len(),
2784                d_beta_v_flat.len(),
2785                slices.total
2786            ));
2787        }
2788        let include_log_sigma = slices.log_sigma.is_some();
2789        let total = slices.total;
2790        let acc = deterministic_latent_survival_row_reduction(
2791            self.event_target.len(),
2792            || LatentSurvivalDenseHessianAccum {
2793                hessian: Array2::<f64>::zeros((total, total)),
2794            },
2795            |row_idx, acc| {
2796                let wi = self.weights[row_idx];
2797                if wi <= MIN_WEIGHT {
2798                    return Ok(());
2799                }
2800                let row = self.build_row_at(
2801                    row_idx,
2802                    q_entry[row_idx],
2803                    q_exit[row_idx],
2804                    qdot_exit[row_idx],
2805                    q_right[row_idx],
2806                )?;
2807                let direction_u =
2808                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
2809                let direction_v =
2810                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
2811                let fourth = latent_survival_row_primary_fourth_contracted(
2812                    &self.quadctx,
2813                    &row,
2814                    q_entry[row_idx],
2815                    q_exit[row_idx],
2816                    qdot_exit[row_idx],
2817                    q_right[row_idx],
2818                    mu[row_idx],
2819                    sigma,
2820                    &direction_u,
2821                    &direction_v,
2822                    include_log_sigma,
2823                )?;
2824                self.add_pullback_primary_hessian(
2825                    &mut acc.hessian,
2826                    row_idx,
2827                    &slices,
2828                    &(wi * fourth),
2829                )?;
2830                Ok(())
2831            },
2832            |total_acc, chunk_acc| {
2833                total_acc.hessian += &chunk_acc.hessian;
2834            },
2835        )?;
2836        Ok(acc.hessian)
2837    }
2838}
2839
2840fn log_kernel_ratio(
2841    bundle: &crate::survival::lognormal_kernel::LogLognormalKernelBundle,
2842    num: usize,
2843    den: usize,
2844) -> f64 {
2845    let delta = bundle.get(num) - bundle.get(den);
2846    if delta.is_finite() {
2847        delta.exp()
2848    } else if delta > 0.0 {
2849        f64::INFINITY
2850    } else {
2851        0.0
2852    }
2853}
2854
2855fn logk_q_derivatives(
2856    quadctx: &QuadratureContext,
2857    k: usize,
2858    mass: f64,
2859    mu: f64,
2860    sigma: f64,
2861) -> Result<(f64, f64, IntegratedExpectationMode), LatentSurvivalError> {
2862    if mass <= 0.0 {
2863        return Ok((0.0, 0.0, IntegratedExpectationMode::ExactClosedForm));
2864    }
2865    let bundle = log_kernel_bundle(quadctx, mass, mu, sigma, k + 2).map_err(|e| {
2866        LatentSurvivalError::NumericalFailure {
2867            reason: format!("latent survival kernel evaluation failed: {e}"),
2868        }
2869    })?;
2870    let r1 = log_kernel_ratio(&bundle, k + 1, k);
2871    let r2 = log_kernel_ratio(&bundle, k + 2, k);
2872    let d1 = -mass * r1;
2873    let d2 = d1 + mass * mass * (r2 - r1 * r1);
2874    Ok((d1, d2, bundle.mode))
2875}
2876
2877fn latent_survival_time_jet(
2878    quadctx: &QuadratureContext,
2879    row: &LatentSurvivalRow,
2880    qdot_exit: f64,
2881    mu: f64,
2882    sigma: f64,
2883) -> Result<LatentSurvivalTimeJet, LatentSurvivalError> {
2884    let (entry_d1, entry_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_entry, mu, sigma)?;
2885    match row.event_type {
2886        LatentSurvivalEventType::RightCensored => {
2887            let (exit_d1, exit_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
2888            Ok(LatentSurvivalTimeJet {
2889                grad_entry: -entry_d1,
2890                grad_exit: exit_d1,
2891                neg_hess_entry: entry_d2,
2892                neg_hess_exit: -exit_d2,
2893            })
2894        }
2895        LatentSurvivalEventType::ExactEvent => {
2896            if !(qdot_exit.is_finite() && qdot_exit > 0.0) {
2897                return Err(LatentSurvivalError::NumericalFailure {
2898                    reason: format!(
2899                        "latent survival requires positive finite baseline hazard derivative, got {qdot_exit}"
2900                    ),
2901                });
2902            }
2903            if row.hazard_unloaded > 0.0 {
2904                let bundle =
2905                    log_kernel_bundle(quadctx, row.mass_exit, mu, sigma, 3).map_err(|e| {
2906                        LatentSurvivalError::NumericalFailure {
2907                            reason: format!("latent survival kernel evaluation failed: {e}"),
2908                        }
2909                    })?;
2910                let (unloaded_d1, unloaded_d2, _) =
2911                    logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
2912                let (loaded_log_d1, loaded_d2, _) =
2913                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
2914                let loaded_d1 = 1.0 + loaded_log_d1;
2915                let log_loaded = row.hazard_loaded.ln() + bundle.get(1);
2916                let log_unloaded = row.hazard_unloaded.ln() + bundle.get(0);
2917                let shift = log_loaded.max(log_unloaded);
2918                let loaded_weight = (log_loaded - shift).exp();
2919                let unloaded_weight = (log_unloaded - shift).exp();
2920                let normalizer = loaded_weight + unloaded_weight;
2921                if !(normalizer.is_finite() && normalizer > 0.0) {
2922                    return Err(LatentSurvivalError::NumericalFailure {
2923                        reason: "latent survival exact-event numerator became non-finite under loaded/unloaded hazard decomposition"
2924                            .to_string(),
2925                    });
2926                }
2927                let w_loaded = loaded_weight / normalizer;
2928                let w_unloaded = unloaded_weight / normalizer;
2929                let grad_exit = w_loaded * loaded_d1 + w_unloaded * unloaded_d1;
2930                let d2_exit = w_loaded * (loaded_d2 + loaded_d1 * loaded_d1)
2931                    + w_unloaded * (unloaded_d2 + unloaded_d1 * unloaded_d1)
2932                    - grad_exit * grad_exit;
2933                Ok(LatentSurvivalTimeJet {
2934                    grad_entry: -entry_d1,
2935                    grad_exit,
2936                    neg_hess_entry: entry_d2,
2937                    neg_hess_exit: -d2_exit,
2938                })
2939            } else {
2940                let (exit_d1, exit_d2, _) =
2941                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
2942                Ok(LatentSurvivalTimeJet {
2943                    grad_entry: -entry_d1,
2944                    grad_exit: 1.0 + exit_d1,
2945                    neg_hess_entry: entry_d2,
2946                    neg_hess_exit: -exit_d2,
2947                })
2948            }
2949        }
2950        LatentSurvivalEventType::IntervalCensored => {
2951            Err(LatentSurvivalError::UnsupportedConfiguration {
2952                reason:
2953                    "latent survival dynamic time derivatives do not implement interval censoring"
2954                        .to_string(),
2955            })
2956        }
2957    }
2958}
2959
2960fn dense_outer_accumulate<S>(
2961    target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2962    weight: f64,
2963    x: ArrayView1<'_, f64>,
2964) where
2965    S: ndarray::DataMut<Elem = f64>,
2966{
2967    for a in 0..x.len() {
2968        let xa = x[a];
2969        if xa == 0.0 {
2970            continue;
2971        }
2972        for b in 0..x.len() {
2973            let xb = x[b];
2974            if xb == 0.0 {
2975                continue;
2976            }
2977            target[[a, b]] += weight * xa * xb;
2978        }
2979    }
2980}
2981
2982fn dense_symmetric_cross_accumulate<S>(
2983    target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2984    weight: f64,
2985    x: ArrayView1<'_, f64>,
2986    y: ArrayView1<'_, f64>,
2987) where
2988    S: ndarray::DataMut<Elem = f64>,
2989{
2990    for a in 0..x.len() {
2991        let xa = x[a];
2992        let ya = y[a];
2993        if xa == 0.0 && ya == 0.0 {
2994            continue;
2995        }
2996        for b in 0..x.len() {
2997            let xb = x[b];
2998            let yb = y[b];
2999            let contribution = xa * yb + ya * xb;
3000            if contribution == 0.0 {
3001                continue;
3002            }
3003            target[[a, b]] += weight * contribution;
3004        }
3005    }
3006}
3007
3008fn build_latent_survival_row(
3009    row_index: usize,
3010    hazard_loading: HazardLoading,
3011    event_type: LatentSurvivalEventType,
3012    q_entry: f64,
3013    q_exit: f64,
3014    qdot_exit: f64,
3015    q_right: f64,
3016    unloaded_mass_entry: f64,
3017    unloaded_mass_exit: f64,
3018    unloaded_mass_right: f64,
3019    unloaded_hazard_exit: f64,
3020) -> Result<LatentSurvivalRow, LatentSurvivalError> {
3021    if !(q_entry.is_finite() && q_exit.is_finite()) {
3022        return Err(LatentSurvivalError::NumericalFailure {
3023            reason: format!(
3024                "latent survival requires finite q_entry and q_exit, got q_entry={q_entry}, q_exit={q_exit}"
3025            ),
3026        });
3027    }
3028    if q_exit < q_entry {
3029        return Err(LatentSurvivalError::NumericalFailure {
3030            reason: format!(
3031                "latent survival requires q_exit >= q_entry so cumulative mass is monotone, got q_entry={q_entry}, q_exit={q_exit}"
3032            ),
3033        });
3034    }
3035    if !(unloaded_mass_entry.is_finite()
3036        && unloaded_mass_exit.is_finite()
3037        && unloaded_hazard_exit.is_finite())
3038    {
3039        return Err(LatentSurvivalError::InvalidDataset {
3040            reason: format!(
3041                "latent survival requires finite unloaded components, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3042            ),
3043        });
3044    }
3045    if unloaded_mass_entry < 0.0
3046        || unloaded_mass_exit < unloaded_mass_entry
3047        || unloaded_hazard_exit < 0.0
3048    {
3049        return Err(LatentSurvivalError::InvalidDataset {
3050            reason: format!(
3051                "latent survival requires unloaded masses/hazard to be non-negative and monotone, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3052            ),
3053        });
3054    }
3055    let mass_entry = q_entry.exp();
3056    let mass_exit = q_exit.exp();
3057    let row = match event_type {
3058        LatentSurvivalEventType::RightCensored => {
3059            validate_unloaded_components_for_loading(
3060                "latent-survival",
3061                row_index,
3062                hazard_loading,
3063                unloaded_mass_entry,
3064                unloaded_mass_exit,
3065                Some(unloaded_hazard_exit),
3066            )?;
3067            LatentSurvivalRow::right_censored(
3068                mass_entry,
3069                mass_exit,
3070                unloaded_mass_entry,
3071                unloaded_mass_exit,
3072            )
3073        }
3074        LatentSurvivalEventType::ExactEvent => {
3075            validate_unloaded_components_for_loading(
3076                "latent-survival",
3077                row_index,
3078                hazard_loading,
3079                unloaded_mass_entry,
3080                unloaded_mass_exit,
3081                Some(unloaded_hazard_exit),
3082            )?;
3083            LatentSurvivalRow::exact_event(
3084                mass_entry,
3085                mass_exit,
3086                unloaded_mass_entry,
3087                unloaded_mass_exit,
3088                mass_exit
3089                    * if qdot_exit.is_finite() && qdot_exit > 0.0 {
3090                        qdot_exit
3091                    } else {
3092                        return Err(LatentSurvivalError::NumericalFailure {
3093                            reason: format!(
3094                                "latent survival exact event requires positive finite baseline hazard derivative, got {qdot_exit}"
3095                            ),
3096                        });
3097                    },
3098                unloaded_hazard_exit,
3099            )
3100        }
3101        LatentSurvivalEventType::IntervalCensored => {
3102            // Interval `(L, R]`: `q_exit` carries the LEFT boundary transform
3103            // `log B(L)` (so `mass_left = exp(q_exit)`) and `q_right` the RIGHT
3104            // boundary `log B(R)`. The likelihood is the survival-mass
3105            // difference `log[S(L) − S(R)]`, requiring `B(L) ≤ B(R)` i.e.
3106            // `q_exit ≤ q_right`. No event hazard participates, so the unloaded
3107            // exit hazard must be the full-loading zero (validated below via the
3108            // interval-specific unloaded check at the left/right boundaries).
3109            if !q_right.is_finite() {
3110                return Err(LatentSurvivalError::NumericalFailure {
3111                    reason: format!(
3112                        "latent survival interval row {} requires a finite q_right, got {q_right}",
3113                        row_index + 1
3114                    ),
3115                });
3116            }
3117            if q_right < q_exit {
3118                return Err(LatentSurvivalError::NumericalFailure {
3119                    reason: format!(
3120                        "latent survival interval row {} requires q_right >= q_exit (R >= L) so the \
3121                         survival-mass difference is non-negative, got q_left={q_exit}, q_right={q_right}",
3122                        row_index + 1
3123                    ),
3124                });
3125            }
3126            if !(unloaded_mass_right.is_finite()) || unloaded_mass_right < unloaded_mass_exit {
3127                return Err(LatentSurvivalError::InvalidDataset {
3128                    reason: format!(
3129                        "latent survival interval row {} requires a finite unloaded right mass >= unloaded left mass, got left={unloaded_mass_exit}, right={unloaded_mass_right}",
3130                        row_index + 1
3131                    ),
3132                });
3133            }
3134            // Interval rows carry no exit-event hazard; the loaded/unloaded
3135            // contract is validated by `LatentSurvivalRow::validate` (entry <=
3136            // left <= right monotonicity on both loaded and unloaded masses).
3137            let mass_right = q_right.exp();
3138            LatentSurvivalRow::interval_censored(
3139                mass_entry,
3140                mass_exit,
3141                mass_right,
3142                unloaded_mass_entry,
3143                unloaded_mass_exit,
3144                unloaded_mass_right,
3145            )
3146        }
3147    };
3148    row.validate()
3149        .map_err(|e| LatentSurvivalError::InvalidDataset {
3150            reason: e.to_string(),
3151        })?;
3152    Ok(row)
3153}
3154
/// Chain-rule coefficients of a binary row log-likelihood `ℓ` viewed as a
/// scalar function of the row log-survival `s`.
///
/// Built by [`binary_from_log_survival`]: a non-event row (`event = 0`) has
/// `ℓ(s) = s`, an event row (`event = 1`) has `ℓ(s) = log(1 − exp(s))`. With
/// `S = exp(s)`, the fields below hold `ℓ` and the signed derivatives of `ℓ`
/// that the Newton assembly multiplies against the survival-row jets.
#[derive(Copy, Clone)]
struct BinaryFromLogSurvival {
    /// The row log-likelihood `ℓ(s)` itself.
    log_lik: f64,
    /// `ℓ'(s)`: `−S/(1 − S)` for an event row, `1` for a non-event row
    /// (where `ℓ ≡ s`).
    grad_scale: f64,
    /// Multiplier for the survival row's negative Hessian (`−d²s/dβ²`) in
    ///     neg_Hess(log_lik) = neg_hess_scale * neg_hessian + outer_scale * score²
    /// The chain rule forces this to be `ℓ'`, i.e. the same number as
    /// `grad_scale`; it is stored under its own name only so call sites read
    /// naturally. [`binary_from_log_survival`] asserts the two agree.
    neg_hess_scale: f64,
    /// `−ℓ''(s)`: `S/(1 − S)²` for an event row, `0` for a non-event row.
    outer_scale: f64,
    /// `ℓ''(s)`, the first `s`-derivative of `grad_scale`.
    grad_scale_prime: f64,
    /// `ℓ'''(s)`, the second `s`-derivative of `grad_scale`.
    grad_scale_second: f64,
    /// `−ℓ'''(s)`, the first `s`-derivative of `outer_scale`.
    outer_scale_prime: f64,
    /// `−ℓ''''(s)`, the second `s`-derivative of `outer_scale`.
    outer_scale_second: f64,
}
3181
/// Closed-form value and first four `s`-derivatives of the event-row
/// log-likelihood `ℓ(s) = log(1 − exp(s))`.
///
/// The function does not take `s` directly: it takes `survival` (`S`) and
/// `event_prob` (`P`) and evaluates the formulas below, which are the
/// derivatives with respect to `s = log S` when `P = 1 − S`. The returned
/// tuple is `(ℓ, ℓ', ℓ'', ℓ''', ℓ'''')`.
///
/// Derivation, using `dS/ds = S`, `dP/ds = −S` and `P + S = 1`:
///
/// ```text
/// ℓ     = log P
/// ℓ'    = −S / P
/// ℓ''   = −S (P + S) / P²                    = −S / P²
/// ℓ'''  = −S (P + 2S) / P³                   = −S (1 + S) / P³
/// ℓ'''' = −[(S + 2S²) P + 3 S² (1 + S)] / P⁴ = −(S + 4S² + S³) / P⁴
/// ```
///
/// Keeping every derivative in this one place means the signs and algebra of
/// the scales derived from it cannot drift apart.
#[inline]
fn binary_log_survival_scales(survival: f64, event_prob: f64) -> (f64, f64, f64, f64, f64) {
    let (surv, prob) = (survival, event_prob);

    // Powers of P and S, built by repeated multiplication.
    let prob_sq = prob * prob;
    let prob_cu = prob_sq * prob;
    let prob_qu = prob_cu * prob;
    let surv_sq = surv * surv;
    let surv_cu = surv_sq * surv;

    let value = event_prob.ln();
    let first = -surv / prob;
    let second = -surv / prob_sq;
    let third = -surv * (1.0 + surv) / prob_cu;
    let fourth = -(surv + 4.0 * surv_sq + surv_cu) / prob_qu;

    (value, first, second, third, fourth)
}
3220
3221fn binary_from_log_survival(
3222    log_survival: f64,
3223    event: u8,
3224) -> Result<BinaryFromLogSurvival, LatentSurvivalError> {
3225    if event == 0 {
3226        // ℓ(s) = s ⇒ ℓ' = 1, ℓ'' = ℓ''' = ℓ'''' = 0.
3227        return Ok(BinaryFromLogSurvival {
3228            log_lik: log_survival,
3229            grad_scale: 1.0,
3230            neg_hess_scale: 1.0,
3231            outer_scale: 0.0,
3232            grad_scale_prime: 0.0,
3233            grad_scale_second: 0.0,
3234            outer_scale_prime: 0.0,
3235            outer_scale_second: 0.0,
3236        });
3237    }
3238    if event != 1 {
3239        return Err(LatentSurvivalError::InvalidDataset {
3240            reason: format!("latent-binary requires event targets in {{0,1}}, got {event}"),
3241        });
3242    }
3243    // Cap log S(t) strictly below zero so the event probability
3244    // `1 - exp(log S)` stays strictly positive even when the survival
3245    // probability rounds to exactly 1 (log S == 0): a zero event probability
3246    // would make the binary log-likelihood `log(event_prob)` diverge. The cap
3247    // is at the f64 resolution near 1.0, so it never perturbs a genuinely
3248    // informative survival value.
3249    const MAX_LOG_SURVIVAL: f64 = -1e-15;
3250    let log_survival = log_survival.min(MAX_LOG_SURVIVAL);
3251    let survival = log_survival.exp();
3252    let event_prob = 1.0 - survival;
3253    if !(event_prob.is_finite() && event_prob > 0.0) {
3254        return Err(LatentSurvivalError::NumericalFailure {
3255            reason: format!(
3256                "latent-binary encountered non-positive event probability from log survival {log_survival}"
3257            ),
3258        });
3259    }
3260    let (log_lik, ell_prime, ell_pp, ell_ppp, ell_pppp) =
3261        binary_log_survival_scales(survival, event_prob);
3262    let grad_scale = ell_prime;
3263    let neg_hess_scale = ell_prime; // coefficient on (-d²s/dβ²); equals ℓ'.
3264    let outer_scale = -ell_pp;
3265    let grad_scale_prime = ell_pp;
3266    let grad_scale_second = ell_ppp;
3267    let outer_scale_prime = -ell_ppp;
3268    let outer_scale_second = -ell_pppp;
3269    // The Newton accumulator at the call sites computes
3270    //     neg_Hess(log_lik) = neg_hess_scale * (-d²s/dβ²) + outer_scale * (ds/dβ)²
3271    // For this identity to hold by the chain rule, the coefficient on the
3272    // neg_hessian term must equal ℓ' (== grad_scale). Document the invariant.
3273    assert!(
3274        (grad_scale - neg_hess_scale).abs() <= 1e-15 * grad_scale.abs().max(1.0),
3275        "binary_from_log_survival invariant: neg_hess_scale ({neg_hess_scale}) must equal grad_scale ({grad_scale}) so that grad_scale and the coefficient on neg_hessian share sign"
3276    );
3277    assert!(
3278        outer_scale >= 0.0 || !outer_scale.is_finite(),
3279        "binary_from_log_survival invariant: outer_scale (= -ℓ'') must be non-negative for event=1; got {outer_scale}"
3280    );
3281    Ok(BinaryFromLogSurvival {
3282        log_lik,
3283        grad_scale,
3284        neg_hess_scale,
3285        outer_scale,
3286        grad_scale_prime,
3287        grad_scale_second,
3288        outer_scale_prime,
3289        outer_scale_second,
3290    })
3291}
3292
3293impl LatentBinaryFamily {
3294    /// Assemble the per-row [`LatentSurvivalRow`] for a row treated as a pure
3295    /// right-censored survival contribution (exit time is the censoring
3296    /// boundary, unit exit-hazard derivative, no right / post-exit unloaded
3297    /// mass). Shared by every per-row binary-from-survival pullback reduction;
3298    /// behavior is identical to the previously inlined `RightCensored` call.
3299    fn build_right_censored_row_at(
3300        &self,
3301        row_idx: usize,
3302        q_entry: f64,
3303        q_exit: f64,
3304    ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
3305        build_latent_survival_row(
3306            row_idx,
3307            self.hazard_loading,
3308            LatentSurvivalEventType::RightCensored,
3309            q_entry,
3310            q_exit,
3311            1.0,
3312            q_exit,
3313            self.unloaded_mass_entry[row_idx],
3314            self.unloaded_mass_exit[row_idx],
3315            0.0,
3316            0.0,
3317        )
3318    }
3319
3320    fn joint_slices(&self) -> LatentSurvivalJointSlices {
3321        let p_time = self.x_time_exit.ncols();
3322        let p_mean = self.x_mean.ncols();
3323        LatentSurvivalJointSlices {
3324            time: 0..p_time,
3325            mean: p_time..p_time + p_mean,
3326            log_sigma: None,
3327            total: p_time + p_mean,
3328        }
3329    }
3330
3331    fn row_primary_direction_from_flat(
3332        &self,
3333        row: usize,
3334        slices: &LatentSurvivalJointSlices,
3335        d_beta_flat: &Array1<f64>,
3336    ) -> Array1<f64> {
3337        let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
3338        let d_time = d_beta_flat.slice(s![slices.time.clone()]);
3339        out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
3340        out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
3341        out[LATENT_SURVIVAL_PRIMARY_MU] = self
3342            .x_mean
3343            .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
3344        out
3345    }
3346
3347    fn add_pullback_primary_gradient(
3348        &self,
3349        target: &mut Array1<f64>,
3350        row: usize,
3351        slices: &LatentSurvivalJointSlices,
3352        primary_gradient: &Array1<f64>,
3353        weight: f64,
3354    ) {
3355        for (primary_idx, time_vec) in [
3356            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
3357            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
3358        ] {
3359            let scale = weight * primary_gradient[primary_idx];
3360            if scale == 0.0 {
3361                continue;
3362            }
3363            for i in 0..time_vec.len() {
3364                let xi = time_vec[i];
3365                if xi != 0.0 {
3366                    target[slices.time.start + i] += scale * xi;
3367                }
3368            }
3369        }
3370
3371        let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
3372        if mean_scale != 0.0 {
3373            self.x_mean
3374                .axpy_row_into(
3375                    row,
3376                    mean_scale,
3377                    &mut target.slice_mut(s![slices.mean.clone()]),
3378                )
3379                // SAFETY: `slices.mean` sized at construction to match
3380                // `x_mean.ncols()`; an error means caller-side shape drift,
3381                // an invariant violation. A swallowed sentinel would silently
3382                // corrupt the joint gradient, so fail loudly instead.
3383                .unwrap_or_else(|error| {
3384                    panic!(
3385                        "latent binary mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
3386                        slices.mean,
3387                        target.len(),
3388                        self.x_mean.ncols()
3389                    )
3390                });
3391        }
3392    }
3393
3394    fn add_pullback_primary_hessian(
3395        &self,
3396        target: &mut Array2<f64>,
3397        row: usize,
3398        slices: &LatentSurvivalJointSlices,
3399        primary_hessian: &Array2<f64>,
3400    ) {
3401        {
3402            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
3403            dense_outer_accumulate(
3404                time_target,
3405                primary_hessian[[
3406                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3407                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3408                ]],
3409                self.x_time_entry.row(row),
3410            );
3411            dense_outer_accumulate(
3412                time_target,
3413                primary_hessian[[
3414                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3415                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3416                ]],
3417                self.x_time_exit.row(row),
3418            );
3419            dense_symmetric_cross_accumulate(
3420                time_target,
3421                primary_hessian[[
3422                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
3423                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
3424                ]],
3425                self.x_time_entry.row(row),
3426                self.x_time_exit.row(row),
3427            );
3428        }
3429
3430        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
3431        self.x_mean
3432            .syr_row_into_view(
3433                row,
3434                mean_weight,
3435                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
3436            )
3437            .unwrap_or_else(|error| {
3438                // SAFETY: `slices.mean` × `slices.mean` slab sized at
3439                // construction to `x_mean.ncols()` × `x_mean.ncols()`;
3440                // an error here is caller-side shape drift, an invariant
3441                // violation. A swallowed sentinel would silently corrupt the
3442                // joint Hessian, so fail loudly instead.
3443                panic!(
3444                    "latent binary mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
3445                    slices.mean,
3446                    target.dim(),
3447                    self.x_mean.ncols()
3448                )
3449            });
3450
3451        let mean_row = self
3452            .x_mean
3453            .try_row_chunk(row..row + 1)
3454            .unwrap_or_else(|error| {
3455                // SAFETY: row index comes from the enclosing `0..n` loop
3456                // bound by `self.x_mean.nrows()`, so `row..row+1` is
3457                // always a valid single-row chunk.
3458                panic!(
3459                    "latent binary mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
3460                    self.x_mean.nrows(),
3461                    self.x_mean.ncols()
3462                )
3463            });
3464        let mean_vec = mean_row.row(0);
3465        for (primary_idx, time_vec) in [
3466            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
3467            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
3468        ] {
3469            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
3470            if weight == 0.0 {
3471                continue;
3472            }
3473            for i in 0..time_vec.len() {
3474                let xi = time_vec[i];
3475                if xi == 0.0 {
3476                    continue;
3477                }
3478                for j in 0..mean_vec.len() {
3479                    let xj = mean_vec[j];
3480                    if xj == 0.0 {
3481                        continue;
3482                    }
3483                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
3484                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
3485                }
3486            }
3487        }
3488    }
3489
3490    fn evaluate_exact_newton_joint_dense(
3491        &self,
3492        block_states: &[ParameterBlockState],
3493    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3494        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3495        let slices = self.joint_slices();
3496        let mut ll = 0.0;
3497        let mut gradient = Array1::<f64>::zeros(slices.total);
3498        let mut hessian = Array2::<f64>::zeros((slices.total, slices.total));
3499        for row_idx in 0..self.event_target.len() {
3500            let wi = self.weights[row_idx];
3501            if wi <= MIN_WEIGHT {
3502                continue;
3503            }
3504            let row =
3505                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3506            let (row_log_survival, survival_gradient, survival_hessian) =
3507                latent_survival_row_primary_gradient_hessian(
3508                    &self.quadctx,
3509                    &row,
3510                    q_entry[row_idx],
3511                    q_exit[row_idx],
3512                    1.0,
3513                    q_exit[row_idx],
3514                    mu[row_idx],
3515                    self.latent_sd,
3516                    false,
3517                )?;
3518            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3519            ll += wi * binary.log_lik;
3520            let primary_gradient = binary.grad_scale * &survival_gradient;
3521            let mut primary_hessian = binary.grad_scale * survival_hessian;
3522            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3523                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3524                    primary_hessian[[a, b]] +=
3525                        binary.outer_scale * survival_gradient[a] * survival_gradient[b];
3526                }
3527            }
3528            self.add_pullback_primary_gradient(
3529                &mut gradient,
3530                row_idx,
3531                &slices,
3532                &primary_gradient,
3533                wi,
3534            );
3535            self.add_pullback_primary_hessian(
3536                &mut hessian,
3537                row_idx,
3538                &slices,
3539                &(wi * primary_hessian),
3540            );
3541        }
3542        Ok((ll, gradient, hessian))
3543    }
3544
3545    /// Per-row residuals of the unpenalized NLL with respect to the baseline
3546    /// time-block offsets `(entry, exit)`.
3547    ///
3548    /// The latent-binary deployment likelihood is a monotone scalar transform
3549    /// `ℓ_bin = b(log S_row)` of the latent-survival row log-survival, so by the
3550    /// chain rule `∂ℓ_bin/∂q_ch = b'(log S)·∂(log S)/∂q_ch = grad_scale·g_ch`,
3551    /// where `g_ch` are the `Q_ENTRY`/`Q_EXIT` components of the survival row
3552    /// primary gradient. The baseline θ enters only the additive entry/exit time
3553    /// offsets (`q̇_exit` is held at the constant deployment derivative `1`, so
3554    /// the derivative channel carries no baseline offset and its residual is 0).
3555    /// Sampleweight-scaled to match the [`OffsetChannelResiduals`] contract.
3556    pub fn offset_channel_residuals(
3557        &self,
3558        block_states: &[ParameterBlockState],
3559    ) -> Result<crate::survival::OffsetChannelResiduals, String> {
3560        let n = self.event_target.len();
3561        if block_states.is_empty() {
3562            log::warn!(
3563                "LatentBinaryFamily::offset_channel_residuals: block_states is empty \
3564                 (degraded fit); returning zero offset residuals (n={n})"
3565            );
3566            return Ok(crate::survival::OffsetChannelResiduals {
3567                exit: Array1::<f64>::zeros(n),
3568                entry: Array1::<f64>::zeros(n),
3569                derivative: Array1::<f64>::zeros(n),
3570                right: Array1::<f64>::zeros(n),
3571            });
3572        }
3573        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3574        let mut entry = Array1::<f64>::zeros(n);
3575        let mut exit = Array1::<f64>::zeros(n);
3576        for row_idx in 0..n {
3577            let wi = self.weights[row_idx];
3578            if wi <= MIN_WEIGHT {
3579                continue;
3580            }
3581            let row =
3582                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3583            let (row_log_survival, survival_gradient, _) =
3584                latent_survival_row_primary_gradient_hessian(
3585                    &self.quadctx,
3586                    &row,
3587                    q_entry[row_idx],
3588                    q_exit[row_idx],
3589                    1.0,
3590                    q_exit[row_idx],
3591                    mu[row_idx],
3592                    self.latent_sd,
3593                    false,
3594                )?;
3595            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3596            // ∂NLL/∂o_ch = −w · grad_scale · ∂(log S)/∂q_ch.
3597            entry[row_idx] =
3598                -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
3599            exit[row_idx] =
3600                -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
3601        }
3602        Ok(crate::survival::OffsetChannelResiduals {
3603            exit,
3604            entry,
3605            derivative: Array1::<f64>::zeros(n),
3606            // Latent-binary deployment has no interval upper bound; the `R`
3607            // channel is structurally absent (every row is right-censored).
3608            right: Array1::<f64>::zeros(n),
3609        })
3610    }
3611
3612    fn exact_newton_joint_hessian_directional_derivative_dense(
3613        &self,
3614        block_states: &[ParameterBlockState],
3615        d_beta_flat: &Array1<f64>,
3616    ) -> Result<Array2<f64>, String> {
3617        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3618        let slices = self.joint_slices();
3619        if d_beta_flat.len() != slices.total {
3620            return Err(format!(
3621                "latent binary joint dH direction length mismatch: got {}, expected {}",
3622                d_beta_flat.len(),
3623                slices.total
3624            ));
3625        }
3626        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
3627        for row_idx in 0..self.event_target.len() {
3628            let wi = self.weights[row_idx];
3629            if wi <= MIN_WEIGHT {
3630                continue;
3631            }
3632            let row =
3633                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3634            let (row_log_survival, survival_gradient, survival_hessian) =
3635                latent_survival_row_primary_gradient_hessian(
3636                    &self.quadctx,
3637                    &row,
3638                    q_entry[row_idx],
3639                    q_exit[row_idx],
3640                    1.0,
3641                    q_exit[row_idx],
3642                    mu[row_idx],
3643                    self.latent_sd,
3644                    false,
3645                )?;
3646            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3647            let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
3648            let third = latent_survival_row_primary_third_contracted(
3649                &self.quadctx,
3650                &row,
3651                q_entry[row_idx],
3652                q_exit[row_idx],
3653                1.0,
3654                q_exit[row_idx],
3655                mu[row_idx],
3656                self.latent_sd,
3657                &direction,
3658                false,
3659            )?;
3660            let g_u = -survival_hessian.dot(&direction);
3661            let t_u = survival_gradient.dot(&direction);
3662            let mut primary = binary.grad_scale * third;
3663            primary.scaled_add(binary.grad_scale_prime * t_u, &survival_hessian);
3664            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3665                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3666                    primary[[a, b]] += binary.outer_scale_prime
3667                        * t_u
3668                        * survival_gradient[a]
3669                        * survival_gradient[b]
3670                        + binary.outer_scale
3671                            * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b]);
3672                }
3673            }
3674            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
3675        }
3676        Ok(out)
3677    }
3678
3679    fn exact_newton_joint_hessian_second_directional_derivative_dense(
3680        &self,
3681        block_states: &[ParameterBlockState],
3682        d_beta_u_flat: &Array1<f64>,
3683        d_beta_v_flat: &Array1<f64>,
3684    ) -> Result<Array2<f64>, String> {
3685        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3686        let slices = self.joint_slices();
3687        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
3688            return Err(format!(
3689                "latent binary joint d2H direction length mismatch: got {} and {}, expected {}",
3690                d_beta_u_flat.len(),
3691                d_beta_v_flat.len(),
3692                slices.total
3693            ));
3694        }
3695        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
3696        for row_idx in 0..self.event_target.len() {
3697            let wi = self.weights[row_idx];
3698            if wi <= MIN_WEIGHT {
3699                continue;
3700            }
3701            let row =
3702                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3703            let (row_log_survival, survival_gradient, survival_hessian) =
3704                latent_survival_row_primary_gradient_hessian(
3705                    &self.quadctx,
3706                    &row,
3707                    q_entry[row_idx],
3708                    q_exit[row_idx],
3709                    1.0,
3710                    q_exit[row_idx],
3711                    mu[row_idx],
3712                    self.latent_sd,
3713                    false,
3714                )?;
3715            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3716            let direction_u = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
3717            let direction_v = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
3718            let third_u = latent_survival_row_primary_third_contracted(
3719                &self.quadctx,
3720                &row,
3721                q_entry[row_idx],
3722                q_exit[row_idx],
3723                1.0,
3724                q_exit[row_idx],
3725                mu[row_idx],
3726                self.latent_sd,
3727                &direction_u,
3728                false,
3729            )?;
3730            let third_v = latent_survival_row_primary_third_contracted(
3731                &self.quadctx,
3732                &row,
3733                q_entry[row_idx],
3734                q_exit[row_idx],
3735                1.0,
3736                q_exit[row_idx],
3737                mu[row_idx],
3738                self.latent_sd,
3739                &direction_v,
3740                false,
3741            )?;
3742            let fourth = latent_survival_row_primary_fourth_contracted(
3743                &self.quadctx,
3744                &row,
3745                q_entry[row_idx],
3746                q_exit[row_idx],
3747                1.0,
3748                q_exit[row_idx],
3749                mu[row_idx],
3750                self.latent_sd,
3751                &direction_u,
3752                &direction_v,
3753                false,
3754            )?;
3755            let g_u = -survival_hessian.dot(&direction_u);
3756            let g_v = -survival_hessian.dot(&direction_v);
3757            let g_uv = -third_v.dot(&direction_u);
3758            let t_u = survival_gradient.dot(&direction_u);
3759            let t_v = survival_gradient.dot(&direction_v);
3760            let l_uv = -direction_u.dot(&survival_hessian.dot(&direction_v));
3761            let c_u = binary.grad_scale_prime * t_u;
3762            let c_v = binary.grad_scale_prime * t_v;
3763            let c_uv = binary.grad_scale_second * t_u * t_v + binary.grad_scale_prime * l_uv;
3764            let o_u = binary.outer_scale_prime * t_u;
3765            let o_v = binary.outer_scale_prime * t_v;
3766            let o_uv = binary.outer_scale_second * t_u * t_v + binary.outer_scale_prime * l_uv;
3767            let mut primary = binary.grad_scale * fourth;
3768            primary.scaled_add(c_u, &third_v);
3769            primary.scaled_add(c_v, &third_u);
3770            primary.scaled_add(c_uv, &survival_hessian);
3771            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3772                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3773                    primary[[a, b]] += o_uv * survival_gradient[a] * survival_gradient[b]
3774                        + o_v * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b])
3775                        + o_u * (g_v[a] * survival_gradient[b] + survival_gradient[a] * g_v[b])
3776                        + binary.outer_scale
3777                            * (g_uv[a] * survival_gradient[b]
3778                                + g_u[a] * g_v[b]
3779                                + g_v[a] * g_u[b]
3780                                + survival_gradient[a] * g_uv[b]);
3781                }
3782            }
3783            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
3784        }
3785        Ok(out)
3786    }
3787}
3788
/// Shared interface that both `LatentSurvivalFamily` and `LatentBinaryFamily`
/// expose to the joint Hessian workspace.
///
/// The two families produce the same `ExactNewtonJointHessianWorkspace`
/// shape — five of the six workspace methods are pure delegations to a
/// matching family method (dense evaluation, directional derivatives, and the
/// `slices` cache). The only family-specific piece is the per-row matvec body:
/// the survival family iterates over real (entry, exit, ḋ) triples and may
/// carry a log-σ block, while the binary family rewrites the same row kernel
/// through `binary_from_log_survival(·)` to recover the per-row binary
/// gradient/Hessian. That single difference is captured by `ws_matvec_into`;
/// every other method is shared by the generic `LatentHessianWorkspace<F>`
/// below.
trait LatentJointHessianFamily {
    /// Joint coefficient layout of the family.
    ///
    /// `LatentHessianWorkspace::new` calls this once and caches the result;
    /// the cached `total` is the flat length the workspace allocates for
    /// matvec outputs and validates matvec operands against.
    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices;

    /// Dense joint evaluation at `block_states`.
    ///
    /// The workspace uses the third tuple element as the dense joint Hessian
    /// (and its diagonal). Both implementations in this file forward to the
    /// family's `evaluate_exact_newton_joint_dense`, whose first two elements
    /// are consumed elsewhere in this file as log-likelihood and gradient.
    ///
    /// # Errors
    /// Propagates the family's evaluation error string.
    fn ws_evaluate_dense(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String>;

    /// Dense directional derivative of the joint Hessian along the flat
    /// coefficient direction `d_beta_flat`; surfaced by the workspace's
    /// `directional_derivative`.
    fn ws_dh_directional(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String>;

    /// Dense second directional derivative of the joint Hessian along the
    /// flat directions `d_beta_u` and `d_beta_v`; surfaced by the workspace's
    /// `second_directional_derivative`.
    fn ws_dh_second_directional(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Array2<f64>, String>;

    /// Family-specific per-row Hessian matvec body, hoisted out of the
    /// workspace impl. Writes `out := H · v` (with `out.fill(0.0)` already
    /// performed by the caller) using the family's row kernel.
    ///
    /// The workspace checks that `v` and `out` both have length
    /// `slices.total` before calling this method, and passes its cached
    /// `slices` in.
    ///
    /// Both implementations in this file return `Ok(true)` on success; the
    /// workspace returns that flag unchanged from `hessian_matvec_into`.
    /// NOTE(review): the flag's exact meaning is defined by
    /// `ExactNewtonJointHessianWorkspace`, which is outside this view.
    ///
    /// # Errors
    /// Propagates any error raised by the family's row construction or row
    /// kernel.
    fn ws_matvec_into(
        &self,
        slices: &LatentSurvivalJointSlices,
        block_states: &[ParameterBlockState],
        v: &Array1<f64>,
        out: &mut Array1<f64>,
    ) -> Result<bool, String>;

    /// Family-name fragment used in the workspace's dimension-mismatch error
    /// message, so callers still see "latent survival …" / "latent binary …"
    /// after the workspace impl was unified.
    fn ws_label() -> &'static str;
}
3839
3840impl LatentJointHessianFamily for LatentSurvivalFamily {
3841    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3842        self.joint_slices()
3843    }
3844
3845    fn ws_evaluate_dense(
3846        &self,
3847        block_states: &[ParameterBlockState],
3848    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3849        self.evaluate_exact_newton_joint_dense(block_states)
3850    }
3851
3852    fn ws_dh_directional(
3853        &self,
3854        block_states: &[ParameterBlockState],
3855        d_beta_flat: &Array1<f64>,
3856    ) -> Result<Array2<f64>, String> {
3857        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3858    }
3859
3860    fn ws_dh_second_directional(
3861        &self,
3862        block_states: &[ParameterBlockState],
3863        d_beta_u: &Array1<f64>,
3864        d_beta_v: &Array1<f64>,
3865    ) -> Result<Array2<f64>, String> {
3866        self.exact_newton_joint_hessian_second_directional_derivative_dense(
3867            block_states,
3868            d_beta_u,
3869            d_beta_v,
3870        )
3871    }
3872
3873    fn ws_matvec_into(
3874        &self,
3875        slices: &LatentSurvivalJointSlices,
3876        block_states: &[ParameterBlockState],
3877        v: &Array1<f64>,
3878        out: &mut Array1<f64>,
3879    ) -> Result<bool, String> {
3880        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
3881        let q_right = self.time_q_right(block_states)?;
3882        let sigma = self.latent_sd(block_states)?;
3883        let include_log_sigma = slices.log_sigma.is_some();
3884        for row_idx in 0..self.event_target.len() {
3885            let wi = self.weights[row_idx];
3886            if wi <= MIN_WEIGHT {
3887                continue;
3888            }
3889            let row = self.build_row_at(
3890                row_idx,
3891                q_entry[row_idx],
3892                q_exit[row_idx],
3893                qdot_exit[row_idx],
3894                q_right[row_idx],
3895            )?;
3896            let (_, _, primary_hessian) = latent_survival_row_primary_gradient_hessian(
3897                &self.quadctx,
3898                &row,
3899                q_entry[row_idx],
3900                q_exit[row_idx],
3901                qdot_exit[row_idx],
3902                q_right[row_idx],
3903                mu[row_idx],
3904                sigma,
3905                include_log_sigma,
3906            )?;
3907            let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3908            let primary_hv = primary_hessian.dot(&primary_dir);
3909            self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi)?;
3910        }
3911        Ok(true)
3912    }
3913
3914    fn ws_label() -> &'static str {
3915        "survival"
3916    }
3917}
3918
3919impl LatentJointHessianFamily for LatentBinaryFamily {
3920    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3921        self.joint_slices()
3922    }
3923
3924    fn ws_evaluate_dense(
3925        &self,
3926        block_states: &[ParameterBlockState],
3927    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3928        self.evaluate_exact_newton_joint_dense(block_states)
3929    }
3930
3931    fn ws_dh_directional(
3932        &self,
3933        block_states: &[ParameterBlockState],
3934        d_beta_flat: &Array1<f64>,
3935    ) -> Result<Array2<f64>, String> {
3936        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3937    }
3938
3939    fn ws_dh_second_directional(
3940        &self,
3941        block_states: &[ParameterBlockState],
3942        d_beta_u: &Array1<f64>,
3943        d_beta_v: &Array1<f64>,
3944    ) -> Result<Array2<f64>, String> {
3945        self.exact_newton_joint_hessian_second_directional_derivative_dense(
3946            block_states,
3947            d_beta_u,
3948            d_beta_v,
3949        )
3950    }
3951
3952    fn ws_matvec_into(
3953        &self,
3954        slices: &LatentSurvivalJointSlices,
3955        block_states: &[ParameterBlockState],
3956        v: &Array1<f64>,
3957        out: &mut Array1<f64>,
3958    ) -> Result<bool, String> {
3959        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3960        for row_idx in 0..self.event_target.len() {
3961            let wi = self.weights[row_idx];
3962            if wi <= MIN_WEIGHT {
3963                continue;
3964            }
3965            let row =
3966                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3967            let (row_log_survival, survival_gradient, survival_hessian) =
3968                latent_survival_row_primary_gradient_hessian(
3969                    &self.quadctx,
3970                    &row,
3971                    q_entry[row_idx],
3972                    q_exit[row_idx],
3973                    1.0,
3974                    q_exit[row_idx],
3975                    mu[row_idx],
3976                    self.latent_sd,
3977                    false,
3978                )?;
3979            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3980            let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3981            let mut primary_hv = binary.grad_scale * survival_hessian.dot(&primary_dir);
3982            let outer_dot = survival_gradient.dot(&primary_dir);
3983            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3984                primary_hv[a] += binary.outer_scale * survival_gradient[a] * outer_dot;
3985            }
3986            self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi);
3987        }
3988        Ok(true)
3989    }
3990
3991    fn ws_label() -> &'static str {
3992        "binary"
3993    }
3994}
3995
/// Joint exact-Newton Hessian workspace shared by `LatentSurvivalFamily` and
/// `LatentBinaryFamily`. The two families plug into the workspace via
/// `LatentJointHessianFamily`; this struct holds the shared bookkeeping
/// (block states + cached slices) and routes every trait method either through
/// a thin family delegation or through the family's `ws_matvec_into` row
/// kernel.
struct LatentHessianWorkspace<F: LatentJointHessianFamily> {
    // Family whose `ws_*` methods the workspace calls. The `CustomFamily`
    // impls in this file hand over a clone of the family.
    family: F,
    // Block states passed to every family call made through the workspace.
    block_states: Vec<ParameterBlockState>,
    // Joint layout cached from `family.ws_joint_slices()` at construction;
    // `slices.total` is the flat length used for matvec operands.
    slices: LatentSurvivalJointSlices,
}
4007
4008impl<F: LatentJointHessianFamily> LatentHessianWorkspace<F> {
4009    fn new(family: F, block_states: Vec<ParameterBlockState>) -> Self {
4010        let slices = family.ws_joint_slices();
4011        Self {
4012            family,
4013            block_states,
4014            slices,
4015        }
4016    }
4017}
4018
4019impl<F> ExactNewtonJointHessianWorkspace for LatentHessianWorkspace<F>
4020where
4021    F: LatentJointHessianFamily + Send + Sync + 'static,
4022{
4023    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
4024        self.family
4025            .ws_evaluate_dense(&self.block_states)
4026            .map(|(_, _, hessian)| Some(hessian))
4027    }
4028
4029    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
4030        let mut out = Array1::<f64>::zeros(self.slices.total);
4031        self.hessian_matvec_into(v, &mut out)?;
4032        Ok(Some(out))
4033    }
4034
4035    fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
4036        if v.len() != self.slices.total || out.len() != self.slices.total {
4037            return Err(format!(
4038                "latent {} Hessian matvec dimension mismatch: v={} out={} expected={}",
4039                F::ws_label(),
4040                v.len(),
4041                out.len(),
4042                self.slices.total
4043            ));
4044        }
4045        out.fill(0.0);
4046        self.family
4047            .ws_matvec_into(&self.slices, &self.block_states, v, out)
4048    }
4049
4050    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
4051        let dense = self.family.ws_evaluate_dense(&self.block_states)?.2;
4052        Ok(Some(dense.diag().to_owned()))
4053    }
4054
4055    fn directional_derivative(
4056        &self,
4057        d_beta_flat: &Array1<f64>,
4058    ) -> Result<Option<Array2<f64>>, String> {
4059        self.family
4060            .ws_dh_directional(&self.block_states, d_beta_flat)
4061            .map(Some)
4062    }
4063
4064    fn second_directional_derivative(
4065        &self,
4066        d_beta_u: &Array1<f64>,
4067        d_beta_v: &Array1<f64>,
4068    ) -> Result<Option<Array2<f64>>, String> {
4069        self.family
4070            .ws_dh_second_directional(&self.block_states, d_beta_u, d_beta_v)
4071            .map(Some)
4072    }
4073}
4074
/// Joint Hessian workspace specialised to `LatentSurvivalFamily`; built by
/// that family's `exact_newton_joint_hessian_workspace`.
type LatentSurvivalHessianWorkspace = LatentHessianWorkspace<LatentSurvivalFamily>;
/// Joint Hessian workspace specialised to `LatentBinaryFamily`; built by that
/// family's `exact_newton_joint_hessian_workspace`.
type LatentBinaryHessianWorkspace = LatentHessianWorkspace<LatentBinaryFamily>;
4077
4078impl CustomFamily for LatentSurvivalFamily {
4079    // Latent survival fits keep the self-limiting Jeffreys/Firth curvature
4080    // active for their under-identification regime. The trait default flipped to
4081    // OFF in gam#1395 (flat-prior exact-Newton objective); opt back in here.
4082    fn joint_jeffreys_term_required(&self) -> bool {
4083        true
4084    }
4085
4086    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4087        true
4088    }
4089
4090    fn has_explicit_joint_hessian(&self) -> bool {
4091        true
4092    }
4093
4094    /// Engage the inner self-vanishing Levenberg–Marquardt μ on a full-rank but
4095    /// indefinite / ill-conditioned penalized joint Hessian, mirroring the
4096    /// sibling [`SurvivalMarginalSlopeFamily`]. Interval-censored rows contribute
4097    /// `ℓ = log[S(L) − S(R)]`, the log of a DIFFERENCE of two survival kernels:
4098    /// unlike the log-concave exact-event / right-censored contributions, its
4099    /// per-row Hessian is legitimately INDEFINITE away from the optimum, so the
4100    /// coupled exact-joint penalized Hessian on the constrained (monotone-cone)
4101    /// time block can be full-rank (`nullity == 0`) yet indefinite or severely
4102    /// ill-conditioned at the cold-start seed. The constrained-QP path already
4103    /// REFLECTS negative-curvature modes to `|λ|` (a convex modified-Newton
4104    /// model), but with this gate OFF it adds NO diagonal floor on a full-rank
4105    /// ill-conditioned reflected model, so the trust-region Newton oscillates on
4106    /// the near-singular mode and stalls out the inner budget before any KKT
4107    /// snapshot is taken ("exited the joint Newton path before convergence — no
4108    /// math snapshot"). Arming the gate adds the SAME self-vanishing μ
4109    /// (∝ the projected KKT residual `‖∇ℓ − Sβ + ∇Φ‖` → 0 at the fixed point) the
4110    /// marginal-slope survival inner relies on, so the step is a well-damped
4111    /// modified-Newton descent that converges, while the converged β̂ is the
4112    /// EXACT unconditioned optimum (μ → 0 there) — zero REML/LAML bias, exact
4113    /// gradient unchanged.
4114    fn levenberg_on_ill_conditioning(&self) -> bool {
4115        true
4116    }
4117
4118    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4119        // `evaluate_exact_newton_joint_dense` builds a fully dense joint
4120        // Hessian over (Σ p_b)² across time, mean, and optional log-σ blocks
4121        // via per-row pullback of the latent-survival primary kernel.
4122        crate::custom_family::joint_coupled_coefficient_hessian_cost(
4123            self.event_target.len() as u64,
4124            specs,
4125        )
4126    }
4127
4128    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4129        let (ll, joint_gradient, hess_time, hess_mean, hess_log_sigma) =
4130            self.evaluate_exact_newton_block_diagonals(block_states)?;
4131        let block_ranges = self.joint_block_ranges();
4132        let mut blockworking_sets = vec![
4133            BlockWorkingSet::ExactNewton {
4134                gradient: joint_gradient.slice(s![block_ranges[0].clone()]).to_owned(),
4135                hessian: SymmetricMatrix::Dense(hess_time),
4136            },
4137            BlockWorkingSet::ExactNewton {
4138                gradient: joint_gradient.slice(s![block_ranges[1].clone()]).to_owned(),
4139                hessian: SymmetricMatrix::Dense(hess_mean),
4140            },
4141        ];
4142        if let (Some(range), Some(hessian)) = (block_ranges.get(2).cloned(), hess_log_sigma) {
4143            blockworking_sets.push(BlockWorkingSet::ExactNewton {
4144                gradient: joint_gradient.slice(s![range]).to_owned(),
4145                hessian: SymmetricMatrix::Dense(hessian),
4146            });
4147        }
4148        Ok(FamilyEvaluation {
4149            log_likelihood: ll,
4150            blockworking_sets,
4151        })
4152    }
4153
4154    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4155        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4156        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
4157        let q_right = self.time_q_right(block_states)?;
4158        let latent_sd = self.latent_sd(block_states)?;
4159        let n = self.event_target.len();
4160        // Per-row latent-survival jet + log-lik contribution. Independent
4161        // across rows; sum via parallel reduce. `?` propagation happens
4162        // through a Result-collecting fold.
4163        let contributions: Result<Vec<f64>, String> = (0..n)
4164            .into_par_iter()
4165            .map(|i| -> Result<f64, String> {
4166                let wi = self.weights[i];
4167                if wi <= MIN_WEIGHT {
4168                    return Ok(0.0);
4169                }
4170                let row = self.build_row_at(i, q_entry[i], q_exit[i], qdot_exit[i], q_right[i])?;
4171                let jet = LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], latent_sd)
4172                    .map_err(|e| format!("LatentSurvivalFamily row {i}: {e}"))?;
4173                Ok(wi * jet.log_lik)
4174            })
4175            .collect();
4176        Ok(contributions?.into_iter().sum())
4177    }
4178
4179    fn block_linear_constraints(
4180        &self,
4181        _: &[ParameterBlockState],
4182        block_idx: usize,
4183        block_spec: &ParameterBlockSpec,
4184    ) -> Result<Option<LinearInequalityConstraints>, String> {
4185        assert!(!block_spec.name.is_empty());
4186        if block_idx == Self::BLOCK_TIME {
4187            Ok(self.time_linear_constraints.clone())
4188        } else {
4189            Ok(None)
4190        }
4191    }
4192
4193    fn exact_newton_joint_hessian(
4194        &self,
4195        block_states: &[ParameterBlockState],
4196    ) -> Result<Option<Array2<f64>>, String> {
4197        self.evaluate_exact_newton_joint_dense(block_states)
4198            .map(|(_, _, hessian)| Some(hessian))
4199    }
4200
4201    fn exact_newton_joint_hessian_workspace(
4202        &self,
4203        block_states: &[ParameterBlockState],
4204        _: &[ParameterBlockSpec],
4205    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4206        Ok(Some(Arc::new(LatentSurvivalHessianWorkspace::new(
4207            self.clone(),
4208            block_states.to_vec(),
4209        ))))
4210    }
4211
4212    fn exact_newton_joint_gradient_evaluation(
4213        &self,
4214        block_states: &[ParameterBlockState],
4215        _: &[ParameterBlockSpec],
4216    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4217        self.evaluate_exact_newton_joint_gradient_dense(block_states)
4218            .map(|(log_likelihood, gradient)| {
4219                Some(ExactNewtonJointGradientEvaluation {
4220                    log_likelihood,
4221                    gradient,
4222                })
4223            })
4224    }
4225
4226    fn exact_newton_joint_hessian_directional_derivative(
4227        &self,
4228        block_states: &[ParameterBlockState],
4229        d_beta_flat: &Array1<f64>,
4230    ) -> Result<Option<Array2<f64>>, String> {
4231        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4232            .map(Some)
4233    }
4234
4235    fn exact_newton_joint_hessiansecond_directional_derivative(
4236        &self,
4237        block_states: &[ParameterBlockState],
4238        d_beta_u_flat: &Array1<f64>,
4239        d_beta_v_flat: &Array1<f64>,
4240    ) -> Result<Option<Array2<f64>>, String> {
4241        self.exact_newton_joint_hessian_second_directional_derivative_dense(
4242            block_states,
4243            d_beta_u_flat,
4244            d_beta_v_flat,
4245        )
4246        .map(Some)
4247    }
4248
4249    fn requires_joint_outer_hyper_path(&self) -> bool {
4250        true
4251    }
4252}
4253
4254impl CustomFamily for LatentBinaryFamily {
4255    // Latent binary fits have a separation regime; keep the self-limiting
4256    // Jeffreys/Firth curvature active. The trait default flipped to OFF in
4257    // gam#1395 (flat-prior exact-Newton objective); opt back in here.
4258    fn joint_jeffreys_term_required(&self) -> bool {
4259        true
4260    }
4261
4262    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4263        true
4264    }
4265
4266    fn has_explicit_joint_hessian(&self) -> bool {
4267        true
4268    }
4269
4270    /// Same self-vanishing Levenberg–Marquardt gate as
4271    /// [`LatentSurvivalFamily`]: the latent-binary deployment shares the
4272    /// constrained (monotone-cone) coupled time block, so a full-rank but
4273    /// ill-conditioned penalized joint Hessian at the cold-start seed must get
4274    /// the self-vanishing μ floor rather than oscillating the constrained-QP
4275    /// trust region into a snapshot-less stall. μ → 0 at the fixed point, so the
4276    /// converged β̂ is exact (no REML/LAML bias).
4277    fn levenberg_on_ill_conditioning(&self) -> bool {
4278        true
4279    }
4280
4281    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4282        crate::custom_family::joint_coupled_coefficient_hessian_cost(
4283            self.event_target.len() as u64,
4284            specs,
4285        )
4286    }
4287
4288    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4289        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4290        let n = self.event_target.len();
4291        let p_time = self.x_time_exit.ncols();
4292        let p_mean = self.x_mean.ncols();
4293
4294        let mut ll = 0.0;
4295        let mut grad_time = Array1::<f64>::zeros(p_time);
4296        let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
4297        let mut grad_mean = Array1::<f64>::zeros(p_mean);
4298        let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
4299        // Reusable 1-row buffer for x_mean so we avoid allocating a fresh
4300        // Array2<f64> on every iteration via try_row_chunk(i..i+1).
4301        let mut mean_row_buf = Array2::<f64>::zeros((1, p_mean));
4302
4303        for i in 0..n {
4304            let wi = self.weights[i];
4305            if wi <= MIN_WEIGHT {
4306                continue;
4307            }
4308            if !(q_entry[i].is_finite() && q_exit[i].is_finite() && mu[i].is_finite()) {
4309                return Err(format!(
4310                    "latent-binary row {i} contains non-finite predictors: q_entry={}, q_exit={}, mu={}",
4311                    q_entry[i], q_exit[i], mu[i]
4312                ));
4313            }
4314            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4315            let survival_jet =
4316                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4317                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4318            let binary = binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?;
4319            ll += wi * binary.log_lik;
4320
4321            self.x_mean
4322                .row_chunk_into(i..i + 1, mean_row_buf.view_mut())
4323                .map_err(|e| format!("LatentBinaryFamily row {i} mean row_chunk: {e}"))?;
4324            let mean_vec = mean_row_buf.row(0);
4325            let mean_grad_scale = wi * binary.grad_scale * survival_jet.score;
4326            for j in 0..p_mean {
4327                grad_mean[j] += mean_grad_scale * mean_vec[j];
4328            }
4329            let mean_neg_hess = wi
4330                * (binary.neg_hess_scale * survival_jet.neg_hessian
4331                    + binary.outer_scale * survival_jet.score * survival_jet.score);
4332            dense_outer_accumulate(&mut hess_mean, mean_neg_hess, mean_vec);
4333
4334            let time_jet =
4335                latent_survival_time_jet(&self.quadctx, &row, 0.0, mu[i], self.latent_sd)?;
4336            let t_entry = self.x_time_entry.row(i);
4337            let t_exit = self.x_time_exit.row(i);
4338            for j in 0..p_time {
4339                grad_time[j] += wi
4340                    * binary.grad_scale
4341                    * (time_jet.grad_entry * t_entry[j] + time_jet.grad_exit * t_exit[j]);
4342            }
4343            dense_outer_accumulate(
4344                &mut hess_time,
4345                wi * binary.neg_hess_scale * time_jet.neg_hess_entry,
4346                t_entry,
4347            );
4348            dense_outer_accumulate(
4349                &mut hess_time,
4350                wi * binary.neg_hess_scale * time_jet.neg_hess_exit,
4351                t_exit,
4352            );
4353            if binary.outer_scale != 0.0 {
4354                dense_outer_accumulate(
4355                    &mut hess_time,
4356                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_entry,
4357                    t_entry,
4358                );
4359                dense_outer_accumulate(
4360                    &mut hess_time,
4361                    wi * binary.outer_scale * time_jet.grad_exit * time_jet.grad_exit,
4362                    t_exit,
4363                );
4364                dense_symmetric_cross_accumulate(
4365                    &mut hess_time,
4366                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_exit,
4367                    t_entry,
4368                    t_exit,
4369                );
4370            }
4371        }
4372
4373        Ok(FamilyEvaluation {
4374            log_likelihood: ll,
4375            blockworking_sets: vec![
4376                BlockWorkingSet::ExactNewton {
4377                    gradient: grad_time,
4378                    hessian: SymmetricMatrix::Dense(hess_time),
4379                },
4380                BlockWorkingSet::ExactNewton {
4381                    gradient: grad_mean,
4382                    hessian: SymmetricMatrix::Dense(hess_mean),
4383                },
4384            ],
4385        })
4386    }
4387
4388    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4389        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4390        let mut ll = 0.0;
4391        for i in 0..self.event_target.len() {
4392            let wi = self.weights[i];
4393            if wi <= MIN_WEIGHT {
4394                continue;
4395            }
4396            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4397            let survival_jet =
4398                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4399                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4400            ll +=
4401                wi * binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?.log_lik;
4402        }
4403        Ok(ll)
4404    }
4405
4406    fn block_linear_constraints(
4407        &self,
4408        _: &[ParameterBlockState],
4409        block_idx: usize,
4410        block_spec: &ParameterBlockSpec,
4411    ) -> Result<Option<LinearInequalityConstraints>, String> {
4412        assert!(!block_spec.name.is_empty());
4413        if block_idx == Self::BLOCK_TIME {
4414            Ok(self.time_linear_constraints.clone())
4415        } else {
4416            Ok(None)
4417        }
4418    }
4419
4420    fn exact_newton_joint_hessian(
4421        &self,
4422        block_states: &[ParameterBlockState],
4423    ) -> Result<Option<Array2<f64>>, String> {
4424        self.evaluate_exact_newton_joint_dense(block_states)
4425            .map(|(_, _, hessian)| Some(hessian))
4426    }
4427
4428    fn exact_newton_joint_hessian_workspace(
4429        &self,
4430        block_states: &[ParameterBlockState],
4431        _: &[ParameterBlockSpec],
4432    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4433        Ok(Some(Arc::new(LatentBinaryHessianWorkspace::new(
4434            self.clone(),
4435            block_states.to_vec(),
4436        ))))
4437    }
4438
4439    fn exact_newton_joint_gradient_evaluation(
4440        &self,
4441        block_states: &[ParameterBlockState],
4442        _: &[ParameterBlockSpec],
4443    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4444        self.evaluate_exact_newton_joint_dense(block_states)
4445            .map(|(log_likelihood, gradient, _)| {
4446                Some(ExactNewtonJointGradientEvaluation {
4447                    log_likelihood,
4448                    gradient,
4449                })
4450            })
4451    }
4452
4453    fn exact_newton_joint_hessian_directional_derivative(
4454        &self,
4455        block_states: &[ParameterBlockState],
4456        d_beta_flat: &Array1<f64>,
4457    ) -> Result<Option<Array2<f64>>, String> {
4458        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4459            .map(Some)
4460    }
4461
4462    fn exact_newton_joint_hessiansecond_directional_derivative(
4463        &self,
4464        block_states: &[ParameterBlockState],
4465        d_beta_u_flat: &Array1<f64>,
4466        d_beta_v_flat: &Array1<f64>,
4467    ) -> Result<Option<Array2<f64>>, String> {
4468        self.exact_newton_joint_hessian_second_directional_derivative_dense(
4469            block_states,
4470            d_beta_u_flat,
4471            d_beta_v_flat,
4472        )
4473        .map(Some)
4474    }
4475
4476    fn requires_joint_outer_hyper_path(&self) -> bool {
4477        true
4478    }
4479}
4480
4481#[cfg(test)]
4482mod tests {
4483    use super::*;
4484    use crate::custom_family::BlockWorkingSet;
4485    use gam_linalg::matrix::DenseDesignMatrix;
4486    use ndarray::array;
4487
4488    fn learnable_sigma_test_family() -> LatentSurvivalFamily {
4489        LatentSurvivalFamily {
4490            event_target: array![1u8, 0u8],
4491            weights: array![1.0, 0.7],
4492            latent_sd_fixed: None,
4493            hazard_loading: HazardLoading::LoadedVsUnloaded,
4494            unloaded_mass_entry: array![0.02, 0.03],
4495            unloaded_mass_exit: array![0.05, 0.08],
4496            unloaded_hazard_exit: array![0.04, 0.0],
4497            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4498            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4499            x_time_derivative_exit: array![[0.8, 0.4], [0.6, 0.5]],
4500            x_time_right: array![[1.3, 0.1], [0.9, 1.0]],
4501            time_offset_right: Array1::zeros(2),
4502            unloaded_mass_right: Array1::zeros(2),
4503            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4504            time_linear_constraints: None,
4505            quadctx: Arc::new(QuadratureContext::new()),
4506        }
4507    }
4508
4509    fn learnable_sigma_test_joint_beta() -> Array1<f64> {
4510        array![0.15, 0.25, 0.1, -0.15, 0.35_f64.ln()]
4511    }
4512
4513    fn survival_stress_test_family(n: usize) -> LatentSurvivalFamily {
4514        LatentSurvivalFamily {
4515            event_target: Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1u8 } else { 0u8 })),
4516            weights: Array1::from_iter((0..n).map(|i| 0.55 + 0.03 * ((i % 7) as f64))),
4517            latent_sd_fixed: None,
4518            hazard_loading: HazardLoading::LoadedVsUnloaded,
4519            unloaded_mass_entry: Array1::from_iter(
4520                (0..n).map(|i| 0.015 + 0.0015 * ((i % 11) as f64)),
4521            ),
4522            unloaded_mass_exit: Array1::from_iter((0..n).map(|i| 0.06 + 0.002 * ((i % 13) as f64))),
4523            unloaded_hazard_exit: Array1::from_iter((0..n).map(|i| {
4524                if i % 4 == 0 {
4525                    0.018 + 0.001 * ((i % 5) as f64)
4526                } else {
4527                    0.0
4528                }
4529            })),
4530            x_time_entry: Array2::from_shape_fn((n, 4), |(i, j)| {
4531                0.2 + 0.03 * ((i + 2 * j) % 9) as f64 - if j == 1 { 0.12 } else { 0.0 }
4532            }),
4533            x_time_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4534                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4535            }),
4536            x_time_derivative_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4537                0.45 + 0.015 * ((i + 3 * j) % 8) as f64
4538            }),
4539            x_time_right: Array2::from_shape_fn((n, 4), |(i, j)| {
4540                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4541            }),
4542            time_offset_right: Array1::zeros(n),
4543            unloaded_mass_right: Array1::zeros(n),
4544            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_shape_fn(
4545                (n, 3),
4546                |(i, j)| 0.1 + 0.04 * ((3 * i + j) % 7) as f64 - if j == 0 { 0.18 } else { 0.0 },
4547            ))),
4548            time_linear_constraints: None,
4549            quadctx: Arc::new(QuadratureContext::new()),
4550        }
4551    }
4552
4553    fn survival_stress_test_joint_beta() -> Array1<f64> {
4554        array![0.18, 0.11, 0.07, 0.13, -0.09, 0.05, 0.12, 0.42_f64.ln()]
4555    }
4556
4557    fn latent_survival_states_from_joint_beta(
4558        family: &LatentSurvivalFamily,
4559        joint_beta: &Array1<f64>,
4560    ) -> Vec<ParameterBlockState> {
4561        let slices = family.joint_slices();
4562        let n = family.event_target.len();
4563        let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4564        let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4565
4566        let mut eta_time = Array1::<f64>::zeros(3 * n);
4567        eta_time
4568            .slice_mut(s![0..n])
4569            .assign(&gam_linalg::faer_ndarray::fast_av(
4570                &family.x_time_entry,
4571                &beta_time,
4572            ));
4573        eta_time
4574            .slice_mut(s![n..2 * n])
4575            .assign(&gam_linalg::faer_ndarray::fast_av(
4576                &family.x_time_exit,
4577                &beta_time,
4578            ));
4579        eta_time
4580            .slice_mut(s![2 * n..3 * n])
4581            .assign(&gam_linalg::faer_ndarray::fast_av(
4582                &family.x_time_derivative_exit,
4583                &beta_time,
4584            ));
4585
4586        let mut states = vec![
4587            ParameterBlockState {
4588                beta: beta_time,
4589                eta: eta_time,
4590            },
4591            ParameterBlockState {
4592                beta: beta_mean.clone(),
4593                eta: family.x_mean.dot(&beta_mean),
4594            },
4595        ];
4596        if let Some(log_sigma) = slices.log_sigma {
4597            let beta_log_sigma = array![joint_beta[log_sigma.start]];
4598            states.push(ParameterBlockState {
4599                beta: beta_log_sigma.clone(),
4600                eta: beta_log_sigma,
4601            });
4602        }
4603        states
4604    }
4605
4606    fn max_relative_array1(left: &Array1<f64>, right: &Array1<f64>) -> f64 {
4607        left.iter()
4608            .zip(right.iter())
4609            .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4610            .fold(0.0_f64, f64::max)
4611    }
4612
4613    fn max_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4614        left.iter()
4615            .zip(right.iter())
4616            .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4617            .fold(0.0_f64, f64::max)
4618    }
4619
4620    fn frobenius_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4621        let mut diff2 = 0.0_f64;
4622        let mut scale2 = 0.0_f64;
4623        for (l, r) in left.iter().zip(right.iter()) {
4624            let d = l - r;
4625            diff2 += d * d;
4626            scale2 += l * l + r * r;
4627        }
4628        diff2.sqrt() / scale2.sqrt().max(1e-12)
4629    }
4630
4631    fn latent_survival_row_loglik_from_primary(
4632        quadctx: &QuadratureContext,
4633        row: &LatentSurvivalRow,
4634        primary: &Array1<f64>,
4635    ) -> f64 {
4636        let q_entry = primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
4637        let q_exit = primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
4638        let qdot_exit = primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
4639        let q_right = primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
4640        let mu = primary[LATENT_SURVIVAL_PRIMARY_MU];
4641        let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
4642        latent_survival_row_primary_gradient_hessian(
4643            quadctx, row, q_entry, q_exit, qdot_exit, q_right, mu, sigma, true,
4644        )
4645        .expect("row primary evaluation")
4646        .0
4647    }
4648
4649    fn latent_test_specs(n: usize, block_dims: &[(&str, usize)]) -> Vec<ParameterBlockSpec> {
4650        block_dims
4651            .iter()
4652            .map(|(name, p)| ParameterBlockSpec {
4653                name: (*name).to_string(),
4654                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, *p)))),
4655                offset: Array1::zeros(n),
4656                penalties: Vec::new(),
4657                nullspace_dims: Vec::new(),
4658                initial_log_lambdas: Array1::zeros(0),
4659                initial_beta: None,
4660                gauge_priority: 100,
4661                jacobian_callback: None,
4662                stacked_design: None,
4663                stacked_offset: None,
4664            })
4665            .collect()
4666    }
4667
4668    fn fixed_sigma_binary_test_family() -> LatentBinaryFamily {
4669        LatentBinaryFamily {
4670            event_target: array![1u8, 0u8],
4671            weights: array![1.0, 0.7],
4672            latent_sd: 0.35,
4673            hazard_loading: HazardLoading::LoadedVsUnloaded,
4674            unloaded_mass_entry: array![0.02, 0.03],
4675            unloaded_mass_exit: array![0.05, 0.08],
4676            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4677            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4678            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4679            time_linear_constraints: None,
4680            quadctx: Arc::new(QuadratureContext::new()),
4681        }
4682    }
4683
4684    fn latent_binary_states_from_joint_beta(
4685        family: &LatentBinaryFamily,
4686        joint_beta: &Array1<f64>,
4687    ) -> Vec<ParameterBlockState> {
4688        let slices = family.joint_slices();
4689        let n = family.event_target.len();
4690        let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4691        let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4692
4693        let mut eta_time = Array1::<f64>::zeros(3 * n);
4694        eta_time
4695            .slice_mut(s![0..n])
4696            .assign(&gam_linalg::faer_ndarray::fast_av(
4697                &family.x_time_entry,
4698                &beta_time,
4699            ));
4700        eta_time
4701            .slice_mut(s![n..2 * n])
4702            .assign(&gam_linalg::faer_ndarray::fast_av(
4703                &family.x_time_exit,
4704                &beta_time,
4705            ));
4706
4707        vec![
4708            ParameterBlockState {
4709                beta: beta_time,
4710                eta: eta_time,
4711            },
4712            ParameterBlockState {
4713                beta: beta_mean.clone(),
4714                eta: family.x_mean.dot(&beta_mean),
4715            },
4716        ]
4717    }
4718
4719    // --- shared latent-interval validation engine: parity / contract tests ---
4720
4721    use crate::survival::location_scale::{TimeBlockInput, TimeBlockMonotonicity};
4722
4723    /// Minimal, structurally valid `TimeBlockInput` for `n` rows and `p_time`
4724    /// columns, used to exercise the shared validation driver without standing
4725    /// up a full term-collection design.
4726    fn validation_time_block(n: usize, p_time: usize) -> TimeBlockInput {
4727        let design = |fill: f64| {
4728            DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_elem(
4729                (n, p_time),
4730                fill,
4731            )))
4732        };
4733        TimeBlockInput {
4734            design_entry: design(0.1),
4735            design_exit: design(0.2),
4736            design_derivative_exit: design(0.3),
4737            offset_entry: Array1::zeros(n),
4738            offset_exit: Array1::zeros(n),
4739            derivative_offset_exit: Array1::zeros(n),
4740            time_monotonicity: TimeBlockMonotonicity::EnforcedByCoordinateCone,
4741            penalties: Vec::new(),
4742            nullspace_dims: Vec::new(),
4743            initial_log_lambdas: None,
4744            initial_beta: None,
4745        }
4746    }
4747
4748    fn empty_meanspec() -> TermCollectionSpec {
4749        TermCollectionSpec {
4750            linear_terms: Vec::new(),
4751            random_effect_terms: Vec::new(),
4752            smooth_terms: Vec::new(),
4753        }
4754    }
4755
4756    /// A valid two-row latent-survival term spec (one exact event under loaded
4757    /// hazard, one right-censored row).
4758    fn valid_survival_spec(n: usize, p_time: usize) -> LatentSurvivalTermSpec {
4759        LatentSurvivalTermSpec {
4760            age_entry: Array1::zeros(n),
4761            age_exit: Array1::from_elem(n, 1.0),
4762            event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4763            weights: Array1::from_elem(n, 1.0),
4764            derivative_guard: 0.0,
4765            time_block: validation_time_block(n, p_time),
4766            time_design_right: None,
4767            time_offset_right: None,
4768            unloaded_mass_entry: Array1::from_elem(n, 0.01),
4769            unloaded_mass_exit: Array1::from_elem(n, 0.05),
4770            unloaded_mass_right: Array1::zeros(0),
4771            unloaded_hazard_exit: Array1::from_elem(n, 0.02),
4772            meanspec: empty_meanspec(),
4773            mean_offset: Array1::zeros(n),
4774        }
4775    }
4776
4777    /// A valid latent-binary term spec mirroring `valid_survival_spec` but
4778    /// without the per-row unloaded hazard.
4779    fn valid_binary_spec(n: usize, p_time: usize) -> LatentBinaryTermSpec {
4780        LatentBinaryTermSpec {
4781            age_entry: Array1::zeros(n),
4782            age_exit: Array1::from_elem(n, 1.0),
4783            event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4784            weights: Array1::from_elem(n, 1.0),
4785            derivative_guard: 0.0,
4786            time_block: validation_time_block(n, p_time),
4787            unloaded_mass_entry: Array1::from_elem(n, 0.01),
4788            unloaded_mass_exit: Array1::from_elem(n, 0.05),
4789            meanspec: empty_meanspec(),
4790            mean_offset: Array1::zeros(n),
4791        }
4792    }
4793
4794    fn loaded_frailty() -> FrailtySpec {
4795        FrailtySpec::HazardMultiplier {
4796            sigma_fixed: Some(0.3),
4797            loading: HazardLoading::LoadedVsUnloaded,
4798        }
4799    }
4800
4801    /// Both adapters route through the shared `validate_latent_interval_inputs`
4802    /// engine, but each must still emit its own context prefix and (for the
4803    /// size-mismatch / unloaded-decomposition diagnostics) the hazard-aware vs
4804    /// mass-only message variant. This pins the byte-for-byte contract the
4805    /// unification had to preserve, the property the issue's "old vs new
4806    /// validation errors" parity test guards.
4807    #[test]
4808    fn latent_interval_validation_parity_across_models() {
4809        let n = 2;
4810        let p_time = 2;
4811        let data = Array2::<f64>::zeros((n, 3));
4812
4813        // 1. A clean spec validates and round-trips the resolved sigma.
4814        //    Survival keeps the (possibly learnable) Option; binary unwraps to
4815        //    the fixed scalar.
4816        let surv_sigma = validate_latent_survival_inputs(
4817            data.view(),
4818            &valid_survival_spec(n, p_time),
4819            &loaded_frailty(),
4820        )
4821        .expect("valid survival spec must validate");
4822        assert_eq!(surv_sigma, Some(0.3));
4823        let bin_sigma = validate_latent_binary_inputs(
4824            data.view(),
4825            &valid_binary_spec(n, p_time),
4826            &loaded_frailty(),
4827        )
4828        .expect("valid binary spec must validate");
4829        assert_eq!(bin_sigma, 0.3);
4830
4831        // 2. Empty data: shared driver, per-model context prefix.
4832        let empty = Array2::<f64>::zeros((0, 3));
4833        let surv_empty = validate_latent_survival_inputs(
4834            empty.view(),
4835            &valid_survival_spec(n, p_time),
4836            &loaded_frailty(),
4837        )
4838        .expect_err("empty data must be rejected");
4839        assert_eq!(
4840            surv_empty.to_string(),
4841            "latent-survival requires a non-empty dataset"
4842        );
4843        let bin_empty = validate_latent_binary_inputs(
4844            empty.view(),
4845            &valid_binary_spec(n, p_time),
4846            &loaded_frailty(),
4847        )
4848        .expect_err("empty data must be rejected");
4849        assert_eq!(
4850            bin_empty.to_string(),
4851            "latent-binary requires a non-empty dataset"
4852        );
4853
4854        // 3. Size mismatch: survival's message carries `unloaded_hazard=`,
4855        //    binary's does not. This is the one shape that distinguishes the
4856        //    two row views feeding the shared driver.
4857        let mut surv_bad = valid_survival_spec(n, p_time);
4858        surv_bad.weights = Array1::from_elem(n + 1, 1.0);
4859        let surv_size = validate_latent_survival_inputs(data.view(), &surv_bad, &loaded_frailty())
4860            .expect_err("size mismatch must be rejected");
4861        let surv_msg = surv_size.to_string();
4862        assert!(
4863            surv_msg.starts_with("latent-survival size mismatch")
4864                && surv_msg.contains("unloaded_hazard="),
4865            "survival size-mismatch message must include unloaded_hazard: {surv_msg}"
4866        );
4867        let mut bin_bad = valid_binary_spec(n, p_time);
4868        bin_bad.weights = Array1::from_elem(n + 1, 1.0);
4869        let bin_size = validate_latent_binary_inputs(data.view(), &bin_bad, &loaded_frailty())
4870            .expect_err("size mismatch must be rejected");
4871        let bin_msg = bin_size.to_string();
4872        assert!(
4873            bin_msg.starts_with("latent-binary size mismatch")
4874                && !bin_msg.contains("unloaded_hazard"),
4875            "binary size-mismatch message must omit unloaded_hazard: {bin_msg}"
4876        );
4877
4878        // 4. Invalid unloaded decomposition: survival reports `exit_hazard=`,
4879        //    binary reports only the two masses.
4880        let mut surv_neg_hazard = valid_survival_spec(n, p_time);
4881        surv_neg_hazard.unloaded_hazard_exit[0] = -1.0;
4882        let surv_decomp =
4883            validate_latent_survival_inputs(data.view(), &surv_neg_hazard, &loaded_frailty())
4884                .expect_err("negative unloaded hazard must be rejected");
4885        assert_eq!(
4886            surv_decomp.to_string(),
4887            "latent-survival row 1 has invalid unloaded hazard decomposition: entry_mass=0.01, exit_mass=0.05, exit_hazard=-1"
4888        );
4889        let mut bin_bad_mass = valid_binary_spec(n, p_time);
4890        bin_bad_mass.unloaded_mass_exit[0] = 0.0; // exit < entry
4891        let bin_decomp =
4892            validate_latent_binary_inputs(data.view(), &bin_bad_mass, &loaded_frailty())
4893                .expect_err("non-monotone unloaded mass must be rejected");
4894        assert_eq!(
4895            bin_decomp.to_string(),
4896            "latent-binary row 1 has invalid unloaded mass decomposition: entry_mass=0.01, exit_mass=0"
4897        );
4898
4899        // 5. Per-row interval/event/weight diagnostics share one engine, so an
4900        //    identical invalid input yields identical (modulo prefix) text.
4901        let mut surv_event = valid_survival_spec(n, p_time);
4902        surv_event.event_target[1] = 7;
4903        let surv_event_err =
4904            validate_latent_survival_inputs(data.view(), &surv_event, &loaded_frailty())
4905                .expect_err("invalid event target must be rejected");
4906        assert_eq!(
4907            surv_event_err.to_string(),
4908            "latent-survival row 2 has invalid event target 7; expected 0 or 1"
4909        );
4910        let mut bin_event = valid_binary_spec(n, p_time);
4911        bin_event.event_target[1] = 7;
4912        let bin_event_err =
4913            validate_latent_binary_inputs(data.view(), &bin_event, &loaded_frailty())
4914                .expect_err("invalid event target must be rejected");
4915        assert_eq!(
4916            bin_event_err.to_string(),
4917            "latent-binary row 2 has invalid event target 7; expected 0 or 1"
4918        );
4919
4920        // 6. Frailty policy divergence: survival accepts a learnable scale
4921        //    (`sigma_fixed = None` ⇒ `Ok(None)`), binary rejects it.
4922        let learnable = FrailtySpec::HazardMultiplier {
4923            sigma_fixed: None,
4924            loading: HazardLoading::LoadedVsUnloaded,
4925        };
4926        let surv_learnable = validate_latent_survival_inputs(
4927            data.view(),
4928            &valid_survival_spec(n, p_time),
4929            &learnable,
4930        )
4931        .expect("survival accepts a learnable latent scale");
4932        assert_eq!(surv_learnable, None);
4933        let bin_learnable =
4934            validate_latent_binary_inputs(data.view(), &valid_binary_spec(n, p_time), &learnable)
4935                .expect_err("binary requires a fixed latent scale");
4936        assert_eq!(
4937            bin_learnable.to_string(),
4938            "latent-binary currently requires a fixed hazard-multiplier sigma"
4939        );
4940
4941        // 7. The time-block shape check is owned by the shared driver: a
4942        //    column-count mismatch is reported with the per-model prefix.
4943        let mut surv_time_bad = valid_survival_spec(n, p_time);
4944        surv_time_bad.time_block.design_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
4945            Array2::from_elem((n, p_time + 1), 0.1),
4946        ));
4947        let surv_time_err =
4948            validate_latent_survival_inputs(data.view(), &surv_time_bad, &loaded_frailty())
4949                .expect_err("time block column mismatch must be rejected");
4950        assert!(
4951            surv_time_err
4952                .to_string()
4953                .starts_with("latent-survival time block column mismatch"),
4954            "unexpected survival time-block message: {surv_time_err}"
4955        );
4956    }
4957
4958    #[test]
4959    fn latent_survival_coefficient_cost_uses_joint_coupled_formula() {
4960        // `evaluate_exact_newton_joint_dense` builds a fully dense joint
4961        // Hessian over (Σ p_b)² across the time, mean, and log-σ blocks via
4962        // per-row pullback of the latent-survival primary kernel. The override
4963        // must reflect that joint coupling rather than the block-diagonal
4964        // default.
4965        let family = learnable_sigma_test_family();
4966        let n = family.event_target.len() as u64;
4967        let p_time = 2u64;
4968        let p_mean = 2u64;
4969        let p_log_sigma = 1u64;
4970        let specs = vec![
4971            ParameterBlockSpec {
4972                name: "time".to_string(),
4973                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4974                    n as usize,
4975                    p_time as usize,
4976                )))),
4977                offset: Array1::zeros(n as usize),
4978                penalties: Vec::new(),
4979                nullspace_dims: Vec::new(),
4980                initial_log_lambdas: Array1::zeros(0),
4981                initial_beta: None,
4982                gauge_priority: 100,
4983                jacobian_callback: None,
4984                stacked_design: None,
4985                stacked_offset: None,
4986            },
4987            ParameterBlockSpec {
4988                name: "mean".to_string(),
4989                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4990                    n as usize,
4991                    p_mean as usize,
4992                )))),
4993                offset: Array1::zeros(n as usize),
4994                penalties: Vec::new(),
4995                nullspace_dims: Vec::new(),
4996                initial_log_lambdas: Array1::zeros(0),
4997                initial_beta: None,
4998                gauge_priority: 100,
4999                jacobian_callback: None,
5000                stacked_design: None,
5001                stacked_offset: None,
5002            },
5003            ParameterBlockSpec {
5004                name: "log_sigma".to_string(),
5005                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
5006                    n as usize,
5007                    p_log_sigma as usize,
5008                )))),
5009                offset: Array1::zeros(n as usize),
5010                penalties: Vec::new(),
5011                nullspace_dims: Vec::new(),
5012                initial_log_lambdas: Array1::zeros(0),
5013                initial_beta: None,
5014                gauge_priority: 100,
5015                jacobian_callback: None,
5016                stacked_design: None,
5017                stacked_offset: None,
5018            },
5019        ];
5020        let p_total = p_time + p_mean + p_log_sigma;
5021        let expected_joint = n * p_total * p_total;
5022        let expected_block_diag =
5023            n * (p_time * p_time + p_mean * p_mean + p_log_sigma * p_log_sigma);
5024        assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
5025        // Cross-block fill (time–mean, time–log_sigma, mean–log_sigma) makes
5026        // the joint cost strictly larger than the block-diagonal default.
5027        assert!(expected_joint > expected_block_diag);
5028    }
5029
5030    #[test]
5031    fn latent_family_planner_keeps_outer_hessian_at_large_n() {
5032        use crate::custom_family::custom_family_outer_derivatives;
5033        use gam_problem::{DeclaredHessianForm, Derivative};
5034
5035        let options = BlockwiseFitOptions::default();
5036        let large_n = 50_001;
5037
5038        let survival = learnable_sigma_test_family();
5039        let survival_specs =
5040            latent_test_specs(large_n, &[("time", 2), ("mean", 2), ("log_sigma", 1)]);
5041        let (surv_grad, surv_hess) =
5042            custom_family_outer_derivatives(&survival, &survival_specs, &options);
5043        assert_eq!(surv_grad, Derivative::Analytic);
5044        assert_eq!(surv_hess, DeclaredHessianForm::Either);
5045
5046        let binary = fixed_sigma_binary_test_family();
5047        let binary_specs = latent_test_specs(large_n, &[("time", 2), ("mean", 2)]);
5048        let (bin_grad, bin_hess) =
5049            custom_family_outer_derivatives(&binary, &binary_specs, &options);
5050        assert_eq!(bin_grad, Derivative::Analytic);
5051        assert_eq!(bin_hess, DeclaredHessianForm::Either);
5052    }
5053
5054    #[test]
5055    fn latent_families_arm_self_vanishing_levenberg_on_ill_conditioning() {
5056        // Regression guard for #1108. The interval-censored row contribution
5057        // `ℓ = log[S(L) − S(R)]` is the log of a DIFFERENCE of survival kernels and
5058        // is legitimately NON-concave (indefinite per-row Hessian) away from the
5059        // optimum; on the constrained (monotone-cone) coupled time block this can
5060        // make the penalized joint Hessian full-rank yet indefinite / severely
5061        // ill-conditioned at the cold-start seed. The coupled exact-joint inner
5062        // solver only adds the self-vanishing Levenberg–Marquardt diagonal floor
5063        // (the cure for a full-rank ill-conditioned reflected QP that otherwise
5064        // oscillates the trust region into a snapshot-less stall) when the family
5065        // opts in via `levenberg_on_ill_conditioning()`. Both latent families MUST
5066        // keep this armed (the default is `false`, which leaves the interval inner
5067        // solve diverging with "exited the joint Newton path before convergence").
5068        assert!(
5069            learnable_sigma_test_family().levenberg_on_ill_conditioning(),
5070            "LatentSurvivalFamily must arm the self-vanishing Levenberg floor so the \
5071             indefinite interval-censored joint Hessian converges (see #1108)"
5072        );
5073        assert!(
5074            fixed_sigma_binary_test_family().levenberg_on_ill_conditioning(),
5075            "LatentBinaryFamily must arm the self-vanishing Levenberg floor on its \
5076             constrained coupled time block (see #1108)"
5077        );
5078    }
5079
5080    #[test]
5081    fn latent_binary_exact_joint_hessian_and_workspace_matvec_match_fd() {
5082        let family = fixed_sigma_binary_test_family();
5083        let beta = array![0.15, 0.25, 0.1, -0.15];
5084        let states = latent_binary_states_from_joint_beta(&family, &beta);
5085        let h = 1e-6;
5086
5087        let analytic_hessian = family
5088            .exact_newton_joint_hessian(&states)
5089            .expect("analytic latent binary joint hessian evaluation")
5090            .expect("latent binary should expose exact joint hessian");
5091
5092        for j in 0..beta.len() {
5093            let mut beta_plus = beta.clone();
5094            beta_plus[j] += h;
5095            let gradient_plus = family
5096                .exact_newton_joint_gradient_evaluation(
5097                    &latent_binary_states_from_joint_beta(&family, &beta_plus),
5098                    &[],
5099                )
5100                .expect("joint gradient plus")
5101                .expect("joint gradient should exist")
5102                .gradient;
5103
5104            let mut beta_minus = beta.clone();
5105            beta_minus[j] -= h;
5106            let gradient_minus = family
5107                .exact_newton_joint_gradient_evaluation(
5108                    &latent_binary_states_from_joint_beta(&family, &beta_minus),
5109                    &[],
5110                )
5111                .expect("joint gradient minus")
5112                .expect("joint gradient should exist")
5113                .gradient;
5114
5115            let fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * h));
5116            let analytic_column = analytic_hessian.column(j).to_owned();
5117            let rel = max_relative_array1(&analytic_column, &fd_column);
5118            assert!(
5119                rel < 5e-4,
5120                "latent binary joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={fd_column:?}"
5121            );
5122        }
5123
5124        let workspace = family
5125            .exact_newton_joint_hessian_workspace(&states, &[])
5126            .expect("latent binary hessian workspace")
5127            .expect("workspace should exist");
5128        let direction = array![0.4, -0.2, 0.3, 0.1];
5129        let hv = workspace
5130            .hessian_matvec(&direction)
5131            .expect("workspace matvec")
5132            .expect("workspace should support matvec");
5133        let dense_hv = analytic_hessian.dot(&direction);
5134        assert!(
5135            max_relative_array1(&hv, &dense_hv) < 1e-12,
5136            "latent binary workspace HVP mismatch: hv={hv:?}, dense={dense_hv:?}"
5137        );
5138
5139        let dh = workspace
5140            .directional_derivative(&direction)
5141            .expect("workspace dH")
5142            .expect("workspace should support dH");
5143        let fd_step = 1e-5;
5144        let h_plus = family
5145            .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
5146                &family,
5147                &(beta.clone() + &(fd_step * &direction)),
5148            ))
5149            .expect("hessian plus")
5150            .expect("hessian plus should exist");
5151        let h_minus = family
5152            .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
5153                &family,
5154                &(beta - &(fd_step * &direction)),
5155            ))
5156            .expect("hessian minus")
5157            .expect("hessian minus should exist");
5158        let fd_dh = (&h_plus - &h_minus) / (2.0 * fd_step);
5159        assert!(
5160            max_relative_array2(&dh, &fd_dh) < 2e-4,
5161            "latent binary workspace dH mismatch: dh={dh:?}, fd={fd_dh:?}"
5162        );
5163    }
5164
5165    #[test]
5166    fn latent_survival_learnable_sigma_block_matches_family_fd() {
5167        let family = learnable_sigma_test_family();
5168        let beta = learnable_sigma_test_joint_beta();
5169        let states = latent_survival_states_from_joint_beta(&family, &beta);
5170        let slices = family.joint_slices();
5171        let sigma_idx = slices
5172            .log_sigma
5173            .as_ref()
5174            .expect("learnable sigma test family should expose log_sigma")
5175            .start;
5176        let h = 2e-4;
5177
5178        let eval = family
5179            .evaluate(&states)
5180            .expect("learnable latent survival evaluation");
5181        let joint_gradient = family
5182            .exact_newton_joint_gradient_evaluation(&states, &[])
5183            .expect("joint gradient evaluation")
5184            .expect("joint gradient should exist")
5185            .gradient;
5186        let joint_hessian = family
5187            .exact_newton_joint_hessian(&states)
5188            .expect("joint hessian evaluation")
5189            .expect("joint hessian should exist");
5190        assert_eq!(eval.blockworking_sets.len(), 3);
5191
5192        let (block_grad, block_neg_hess) =
5193            match &eval.blockworking_sets[LatentSurvivalFamily::BLOCK_LOG_SIGMA] {
5194                BlockWorkingSet::ExactNewton { gradient, hessian } => {
5195                    let neg_hess = match hessian {
5196                        SymmetricMatrix::Dense(mat) => mat[[0, 0]],
5197                        _ => panic!("log_sigma block should use a dense exact-Newton Hessian"),
5198                    };
5199                    (gradient[0], neg_hess)
5200                }
5201                _ => panic!("log_sigma block should use ExactNewton"),
5202            };
5203
5204        assert!((block_grad - joint_gradient[sigma_idx]).abs() < 1e-12);
5205        assert!((block_neg_hess - joint_hessian[[sigma_idx, sigma_idx]]).abs() < 1e-12);
5206
5207        let mut beta_plus = beta.clone();
5208        beta_plus[sigma_idx] += h;
5209        let ll_plus = family
5210            .log_likelihood_only(&latent_survival_states_from_joint_beta(&family, &beta_plus))
5211            .expect("ll plus");
5212        let ll_0 = family.log_likelihood_only(&states).expect("ll base");
5213        let mut beta_minus = beta.clone();
5214        beta_minus[sigma_idx] -= h;
5215        let ll_minus = family
5216            .log_likelihood_only(&latent_survival_states_from_joint_beta(
5217                &family,
5218                &beta_minus,
5219            ))
5220            .expect("ll minus");
5221
5222        let fd_grad = (ll_plus - ll_minus) / (2.0 * h);
5223        let fd_neg_hess = -(ll_plus - 2.0 * ll_0 + ll_minus) / (h * h);
5224        assert!(
5225            (joint_gradient[sigma_idx] - fd_grad).abs()
5226                / joint_gradient[sigma_idx]
5227                    .abs()
5228                    .max(fd_grad.abs())
5229                    .max(1e-12)
5230                < 2e-3,
5231            "family log_sigma grad={}, fd={fd_grad}",
5232            joint_gradient[sigma_idx]
5233        );
5234        assert!(
5235            (joint_hessian[[sigma_idx, sigma_idx]] - fd_neg_hess).abs()
5236                / joint_hessian[[sigma_idx, sigma_idx]]
5237                    .abs()
5238                    .max(fd_neg_hess.abs())
5239                    .max(1e-10)
5240                < 2e-2,
5241            "family log_sigma neg_hess={}, fd={fd_neg_hess}",
5242            joint_hessian[[sigma_idx, sigma_idx]]
5243        );
5244    }
5245
/// Compares every column of the analytic joint Hessian with the negated
/// central finite difference of the joint gradient along that coordinate.
#[test]
fn latent_survival_exact_joint_hessian_matches_gradient_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let base_states = latent_survival_states_from_joint_beta(&family, &beta);
    let h = 1e-6;

    let analytic_hessian = family
        .exact_newton_joint_hessian(&base_states)
        .expect("analytic joint hessian evaluation")
        .expect("latent survival should expose exact joint hessian");

    // Joint gradient after moving one joint coefficient by `step`; `label`
    // is the panic message used if the evaluation itself fails.
    let gradient_at_shift = |coord: usize, step: f64, label: &str| {
        let mut shifted = beta.clone();
        shifted[coord] += step;
        let shifted_states = latent_survival_states_from_joint_beta(&family, &shifted);
        family
            .exact_newton_joint_gradient_evaluation(&shifted_states, &[])
            .expect(label)
            .expect("joint gradient should exist")
            .gradient
    };

    for j in 0..beta.len() {
        let gradient_plus = gradient_at_shift(j, h, "joint gradient plus");
        let gradient_minus = gradient_at_shift(j, -h, "joint gradient minus");

        // The analytic matrix is compared against the NEGATED gradient
        // difference, so the sign flip is applied once here.
        let neg_fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * h));
        let analytic_column = analytic_hessian.column(j).to_owned();
        let rel = max_relative_array1(&analytic_column, &neg_fd_column);
        assert!(
            rel < 5e-4,
            "joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={:?}",
            neg_fd_column
        );
    }
}
5291
/// Finite-difference check of `LatentSurvivalFamily::offset_channel_residuals`.
///
/// For each time channel `ch`, the summed residual `Σ_i r^ch_i` is compared
/// with the central difference of `−ℓ` under a constant shift `δ` added to
/// that channel's slice of the time block's `eta`. A uniform additive offset
/// `o_ch` moves exactly that slice, so the sum is the directional derivative
/// `∂(−ℓ)/∂o_ch`. The baseline-θ enters only through these offsets, which
/// makes this the check of the envelope-theorem baseline-θ gradient primitive.
#[test]
fn latent_survival_offset_channel_residuals_match_finite_difference() {
    let family = survival_stress_test_family(24);
    let beta = survival_stress_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let n = family.event_target.len();

    let residuals = family
        .offset_channel_residuals(&states)
        .expect("offset channel residuals");
    let sum_entry: f64 = residuals.entry.sum();
    let sum_exit: f64 = residuals.exit.sum();
    let sum_deriv: f64 = residuals.derivative.sum();

    // `−ℓ` with channel `channel` (0, 1 or 2) of the time eta shifted by `delta`.
    // The test addresses channel `c` as rows `c*n .. (c+1)*n` of that eta.
    let neg_ll_with_offset = |channel: usize, delta: f64| -> f64 {
        let start = match channel {
            0..=2 => channel * n,
            _ => unreachable!(),
        };
        let mut shifted = states.clone();
        let mut channel_eta = shifted[LatentSurvivalFamily::BLOCK_TIME]
            .eta
            .slice_mut(s![start..start + n]);
        channel_eta.mapv_inplace(|eta| eta + delta);
        let evaluation = family
            .evaluate_exact_newton_joint_gradient_dense(&shifted)
            .expect("shifted joint gradient evaluation");
        -evaluation.0
    };

    let h = 1e-6;
    // Central difference of `−ℓ` in the constant offset of one channel.
    let central_fd = |channel: usize| -> f64 {
        (neg_ll_with_offset(channel, h) - neg_ll_with_offset(channel, -h)) / (2.0 * h)
    };
    let fd_entry = central_fd(0);
    let fd_exit = central_fd(1);
    let fd_deriv = central_fd(2);

    // Mixed tolerance: relative to the FD value, with an absolute floor of 1.
    let agrees = |analytic: f64, fd: f64| (analytic - fd).abs() <= 1e-5 * fd.abs().max(1.0);

    assert!(
        agrees(sum_entry, fd_entry),
        "entry-channel residual sum mismatch: analytic={sum_entry}, fd={fd_entry}"
    );
    assert!(
        agrees(sum_exit, fd_exit),
        "exit-channel residual sum mismatch: analytic={sum_exit}, fd={fd_exit}"
    );
    assert!(
        agrees(sum_deriv, fd_deriv),
        "derivative-channel residual sum mismatch: analytic={sum_deriv}, fd={fd_deriv}"
    );
}
5349
/// Evaluates the joint gradient, dense joint Hessian, and its first and
/// second directional derivatives twice each on a 96-row stress family and
/// requires the two passes to agree exactly. The matrices must also be
/// finite and symmetric to within 1e-12 relative error.
#[test]
fn latent_survival_exact_joint_parallel_stress_is_repeatable() {
    let family = survival_stress_test_family(96);
    let beta = survival_stress_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let direction_u = array![0.03, -0.02, 0.01, 0.04, -0.015, 0.025, -0.005, 0.02];
    let direction_v = array![-0.01, 0.035, -0.025, 0.015, 0.02, -0.01, 0.03, -0.015];

    // --- log-likelihood + gradient ---------------------------------------
    let gradient_pass = |label: &str| {
        family
            .evaluate_exact_newton_joint_gradient_dense(&states)
            .expect(label)
    };
    let (ll_a, grad_a) = gradient_pass("stress joint gradient evaluation");
    let (ll_b, grad_b) = gradient_pass("repeat stress joint gradient evaluation");
    // Bit comparison: the scalar must be reproduced exactly, not just closely.
    assert_eq!(ll_a.to_bits(), ll_b.to_bits());
    assert_eq!(grad_a, grad_b);

    // --- log-likelihood + gradient + dense Hessian -----------------------
    let dense_pass = |label: &str| {
        family
            .evaluate_exact_newton_joint_dense(&states)
            .expect(label)
    };
    let (joint_ll_a, joint_grad_a, hess_a) = dense_pass("stress joint dense evaluation");
    let (joint_ll_b, joint_grad_b, hess_b) = dense_pass("repeat stress joint dense evaluation");
    assert_eq!(joint_ll_a.to_bits(), joint_ll_b.to_bits());
    assert_eq!(joint_grad_a, joint_grad_b);
    assert_eq!(hess_a, hess_b);
    assert!(hess_a.iter().all(|entry| entry.is_finite()));
    let hess_transposed = hess_a.t().to_owned();
    assert!(max_relative_array2(&hess_a, &hess_transposed) < 1e-12);

    // --- first directional derivative of the Hessian ---------------------
    let dh_pass = |label: &str| {
        family
            .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
            .expect(label)
    };
    let dh_a = dh_pass("stress joint dH evaluation");
    let dh_b = dh_pass("repeat stress joint dH evaluation");
    assert_eq!(dh_a, dh_b);
    assert!(dh_a.iter().all(|entry| entry.is_finite()));
    let dh_transposed = dh_a.t().to_owned();
    assert!(max_relative_array2(&dh_a, &dh_transposed) < 1e-12);

    // --- second directional derivative of the Hessian --------------------
    let d2h_pass = |label: &str| {
        family
            .exact_newton_joint_hessian_second_directional_derivative_dense(
                &states,
                &direction_u,
                &direction_v,
            )
            .expect(label)
    };
    let d2h_a = d2h_pass("stress joint d2H evaluation");
    let d2h_b = d2h_pass("repeat stress joint d2H evaluation");
    assert_eq!(d2h_a, d2h_b);
    assert!(d2h_a.iter().all(|entry| entry.is_finite()));
    let d2h_transposed = d2h_a.t().to_owned();
    assert!(max_relative_array2(&d2h_a, &d2h_transposed) < 1e-12);
}
5407
/// Compares the analytic directional derivative of the joint Hessian with a
/// central finite difference of the joint Hessian itself along the same
/// direction, using a Frobenius-relative error.
#[test]
fn latent_survival_exact_joint_dh_matches_hessian_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let h = 2e-4;
    let direction = array![0.07, -0.03, 0.05, 0.02, -0.04];

    let analytic = family
        .exact_newton_joint_hessian_directional_derivative(&states, &direction)
        .expect("analytic joint dH evaluation")
        .expect("latent survival should expose exact joint dH");

    // Step of length `h` along `direction`, applied forwards and backwards.
    let step = h * &direction;
    let beta_forward = &beta + &step;
    let beta_backward = &beta - &step;

    let states_forward = latent_survival_states_from_joint_beta(&family, &beta_forward);
    let hessian_plus = family
        .exact_newton_joint_hessian(&states_forward)
        .expect("joint hessian plus")
        .expect("joint hessian should exist");

    let states_backward = latent_survival_states_from_joint_beta(&family, &beta_backward);
    let hessian_minus = family
        .exact_newton_joint_hessian(&states_backward)
        .expect("joint hessian minus")
        .expect("joint hessian should exist");

    let fd = (&hessian_plus - &hessian_minus) / (2.0 * h);
    let rel = frobenius_relative_array2(&analytic, &fd);
    assert!(rel < 2e-3, "joint dH mismatch: rel={rel}");
}
5440
/// Checks the analytic second directional derivative of the joint Hessian
/// two ways: it must be symmetric under swapping the two directions, and it
/// must match a central finite difference (along `direction_v`) of the first
/// directional derivative taken along `direction_u`.
#[test]
fn latent_survival_exact_joint_d2h_matches_directional_fd() {
    let family = learnable_sigma_test_family();
    let beta = learnable_sigma_test_joint_beta();
    let states = latent_survival_states_from_joint_beta(&family, &beta);
    let h = 5e-4;
    let direction_u = array![0.07, -0.03, 0.05, 0.02, -0.04];
    let direction_v = array![-0.02, 0.06, -0.01, 0.03, 0.05];

    // NOTE(review): this method name has no underscore between `hessian` and
    // `second`, unlike the `_dense` sibling — confirm it matches the trait.
    let analytic = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &states,
            &direction_u,
            &direction_v,
        )
        .expect("analytic joint d2H evaluation")
        .expect("latent survival should expose exact joint d2H");
    let swapped = family
        .exact_newton_joint_hessiansecond_directional_derivative(
            &states,
            &direction_v,
            &direction_u,
        )
        .expect("swapped analytic joint d2H evaluation")
        .expect("latent survival should expose exact joint d2H");
    let symmetry_rel = max_relative_array2(&analytic, &swapped);
    assert!(
        symmetry_rel < 1e-10,
        "joint d2H should be symmetric in directions, got rel={symmetry_rel}"
    );

    // Perturb the coefficients along `direction_v`, then differentiate the
    // Hessian along `direction_u` at each perturbed point.
    let step_v = h * &direction_v;
    let beta_forward = &beta + &step_v;
    let beta_backward = &beta - &step_v;

    let states_forward = latent_survival_states_from_joint_beta(&family, &beta_forward);
    let dh_plus = family
        .exact_newton_joint_hessian_directional_derivative(&states_forward, &direction_u)
        .expect("joint dH plus")
        .expect("joint dH should exist");

    let states_backward = latent_survival_states_from_joint_beta(&family, &beta_backward);
    let dh_minus = family
        .exact_newton_joint_hessian_directional_derivative(&states_backward, &direction_u)
        .expect("joint dH minus")
        .expect("joint dH should exist");

    let fd = (&dh_plus - &dh_minus) / (2.0 * h);
    let rel = frobenius_relative_array2(&analytic, &fd);
    assert!(rel < 2.5e-2, "joint d2H mismatch: rel={rel}");
}
5497
/// FD-verifies the analytic per-row gradient and negative Hessian of an
/// exact-event row against the row log-likelihood, for every primary
/// coordinate and every coordinate pair.
#[test]
fn latent_survival_row_primary_derivatives_match_fd() {
    let quadctx = QuadratureContext::new();
    let row = LatentSurvivalRow::exact_event(0.35, 1.4, 0.1, 0.45, 0.8, 0.12);
    // Coordinate order: [q_entry, q_exit, qdot_exit, q_right, mu, log_sigma].
    // The row is an exact event, so `q_right` is an inert channel (the
    // likelihood does not depend on it); the loops below therefore also
    // exercise its zero gradient/Hessian entries.
    let primary = array![
        0.35f64.ln(),
        1.4f64.ln(),
        0.8,
        1.6f64.ln(),
        -0.2,
        0.4f64.ln()
    ];
    let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
    let h_grad = 1e-6;
    let h_hess = 2e-4;

    let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
        &quadctx,
        &row,
        primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
        primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
        primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
        primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
        primary[LATENT_SURVIVAL_PRIMARY_MU],
        sigma,
        true,
    )
    .expect("analytic row primary gradient/hessian");

    // Row log-likelihood at `primary` after applying each `(coordinate, delta)`
    // shift in order. A coordinate listed twice is shifted twice, which is
    // what the diagonal (j == k) stencil points need.
    let loglik_at = |shifts: &[(usize, f64)]| {
        let mut point = primary.clone();
        for &(coord, delta) in shifts {
            point[coord] += delta;
        }
        latent_survival_row_loglik_from_primary(&quadctx, &row, &point)
    };

    for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
        // Central difference for the gradient entry.
        let fd_grad = (loglik_at(&[(j, h_grad)]) - loglik_at(&[(j, -h_grad)])) / (2.0 * h_grad);
        let grad_scale = gradient[j].abs().max(fd_grad.abs()).max(1e-12);
        let rel_grad = (gradient[j] - fd_grad).abs() / grad_scale;
        assert!(
            rel_grad < 2e-4,
            "row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
            gradient[j]
        );

        for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
            // Four-point stencil for the mixed second derivative, negated to
            // match the negative-Hessian convention of the analytic output.
            let fd_neg_hess = -(loglik_at(&[(j, h_hess), (k, h_hess)])
                - loglik_at(&[(j, h_hess), (k, -h_hess)])
                - loglik_at(&[(j, -h_hess), (k, h_hess)])
                + loglik_at(&[(j, -h_hess), (k, -h_hess)]))
                / (4.0 * h_hess * h_hess);
            let analytic = neg_hessian[[j, k]];
            let abs_err = (analytic - fd_neg_hess).abs();
            let hess_scale = analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
            let rel = abs_err / hess_scale;
            // Pass on either a small absolute or a small relative error.
            assert!(
                abs_err < 2e-5 || rel < 2e-3,
                "row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
            );
        }
    }
}
5575
/// FD-verifies the analytic gradient and negative Hessian of an
/// interval-censored row, `ℓ = log[S(L) − S(R)] − log S(entry)`, with respect
/// to all six primary coordinates.
///
/// The left boundary mass `M_L = exp(q_exit)` and the right boundary mass
/// `M_R = exp(q_right)` are differentiated independently here, which the
/// static μ-only `LatentSurvivalRowJet::interval_censored` never exercises.
/// The bracket is chosen so that `S(L) − S(R)` is comfortably positive and
/// the log-of-a-difference curvature stays well-conditioned.
#[test]
fn latent_survival_interval_row_primary_derivatives_match_fd() {
    let quadctx = QuadratureContext::new();

    // Bracket with entry < L < R and wide gaps between the masses.
    let q_entry = -1.2_f64; // M_entry = e^{−1.2} ≈ 0.30
    let q_exit = -0.4_f64; // L: M_L = e^{−0.4} ≈ 0.67
    let q_right = 0.5_f64; // R: M_R = e^{0.5} ≈ 1.65 (> M_L)
    let mu = -0.15_f64;
    let log_sigma = 0.3_f64; // σ ≈ 1.35

    // The unloaded masses are small and monotone (entry ≤ left ≤ right).
    let row = LatentSurvivalRow::interval_censored(
        q_entry.exp(), // mass_entry (consistency only; jet reads q's)
        q_exit.exp(),  // mass_left
        q_right.exp(), // mass_right
        0.01,          // mass_unloaded_entry
        0.02,          // mass_unloaded_left
        0.05,          // mass_unloaded_right
    );
    assert!(matches!(
        row.event_type,
        LatentSurvivalEventType::IntervalCensored
    ));

    // Coordinate order: [q_entry, q_exit/L, qdot_exit, q_right/R, mu, log_sigma].
    // `qdot_exit` is inert for interval rows (no hazard-derivative channel),
    // so the loops below also exercise its zero gradient/Hessian entries.
    let primary = array![q_entry, q_exit, 0.7, q_right, mu, log_sigma];
    let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
    let h_grad = 1e-6;
    let h_hess = 2e-4;

    let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
        &quadctx,
        &row,
        primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
        primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
        primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
        primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
        primary[LATENT_SURVIVAL_PRIMARY_MU],
        sigma,
        true,
    )
    .expect("analytic interval row primary gradient/hessian");

    // A well-posed bracket has a positive survival-mass difference, so the
    // log-likelihood value itself has to be finite.
    let value = latent_survival_row_loglik_from_primary(&quadctx, &row, &primary);
    assert!(
        value.is_finite(),
        "interval row log-likelihood must be finite on a well-posed bracket, got {value}"
    );

    // Row log-likelihood at `primary` after applying each `(coordinate, delta)`
    // shift in order. A coordinate listed twice is shifted twice, which is
    // what the diagonal (j == k) stencil points need.
    let loglik_at = |shifts: &[(usize, f64)]| {
        let mut point = primary.clone();
        for &(coord, delta) in shifts {
            point[coord] += delta;
        }
        latent_survival_row_loglik_from_primary(&quadctx, &row, &point)
    };

    for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
        // Central difference for the gradient entry.
        let fd_grad = (loglik_at(&[(j, h_grad)]) - loglik_at(&[(j, -h_grad)])) / (2.0 * h_grad);
        let grad_scale = gradient[j].abs().max(fd_grad.abs()).max(1e-12);
        let rel_grad = (gradient[j] - fd_grad).abs() / grad_scale;
        assert!(
            rel_grad < 2e-4,
            "interval row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
            gradient[j]
        );

        for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
            // Four-point stencil for the mixed second derivative, negated to
            // match the negative-Hessian convention of the analytic output.
            let fd_neg_hess = -(loglik_at(&[(j, h_hess), (k, h_hess)])
                - loglik_at(&[(j, h_hess), (k, -h_hess)])
                - loglik_at(&[(j, -h_hess), (k, h_hess)])
                + loglik_at(&[(j, -h_hess), (k, -h_hess)]))
                / (4.0 * h_hess * h_hess);
            let analytic = neg_hessian[[j, k]];
            let abs_err = (analytic - fd_neg_hess).abs();
            let hess_scale = analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
            let rel = abs_err / hess_scale;
            // Pass on either a small absolute or a small relative error.
            assert!(
                abs_err < 5e-5 || rel < 3e-3,
                "interval row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
            );
        }
    }
}
5684}