Skip to main content

gam_models/survival/latent/
survival.rs

1//! Jointly learned latent-frailty survival and binary deployment families with
2//! a live time/baseline block.
3//!
4//! Model:
5//!   H_0(a) = exp(q(a)),
6//!   h_0(a) = dq(a)/da,
7//!   H(a | U) = H_0(a) * exp(U),
8//!   U ~ N(mu, sigma^2),
9//!   mu = X beta + offset.
10//!
11//! Unlike the old compiled-row path, the cumulative masses and baseline hazard
12//! are rebuilt inside the optimizer from the current time-basis coefficients.
13//! The family-level fit surface supports exact events, right censoring, and
14//! interval censoring `T ∈ (L, R]` (contribution `log[S(L) − S(R)]`). Interval
15//! rows carry the reserved [`LATENT_SURVIVAL_EVENT_INTERVAL`] event code and a
16//! dedicated upper-bound time channel (`time_design_right` / `q_right`); the
17//! 3-way event dispatch is [`latent_survival_event_type_for`]. Reached from the
18//! formula DSL via `SurvInterval(L, R, event) ~ ...`.
19
20use crate::custom_family::{
21    BlockWorkingSet, BlockwiseFitOptions, CustomFamily, ExactNewtonJointGradientEvaluation,
22    ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
23    PenaltyMatrix, fit_custom_family, fit_custom_family_fixed_log_lambdas,
24};
25use crate::gamlss::{FamilyMetadata, ParameterLink};
26use crate::sigma_link::{exp_sigma_eta_for_sigma_scalar, exp_sigma_from_eta_scalar};
27use crate::survival::latent::interval::{
28    LatentFrailtyResolution, LatentIntervalModel, LatentIntervalRowView,
29    validate_latent_interval_inputs,
30};
31use crate::survival::location_scale::{
32    TimeBlockInput, project_onto_linear_constraints, structural_time_coefficient_constraints,
33};
34use crate::survival::lognormal_kernel::{
35    FrailtySpec, HazardLoading, LatentSurvivalEventType, LatentSurvivalRow, LatentSurvivalRowJet,
36    log_kernel_bundle,
37};
38use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
39use crate::model_types::UnifiedFitResult;
40use gam_solve::pirls::LinearInequalityConstraints;
41use crate::probability::signed_log_sum_exp;
42use crate::quadrature::{IntegratedExpectationMode, QuadratureContext};
43use gam_terms::smooth::{
44    TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
45};
46use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
47use gam_problem::MIN_WEIGHT;
48use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
49use std::collections::BTreeMap;
50use std::sync::Arc;
51
/// Typed error for the latent-survival / latent-binary family kernels and
/// their fit-time and per-row validation helpers. Variants pick the semantic
/// bucket while the inner `reason` carries the original byte-equivalent
/// message so external callers that previously consumed `String` errors keep
/// the same diagnostic text via `Display`.
///
/// Inbound conversions defined alongside this type: `String` maps to
/// `InvalidDataset`, and `BlockCountMismatch` maps to `BlockMismatch`. The
/// public `String`-returning entry points in this module convert back to
/// `String` with `?` / `Into::into`. NOTE(review): that outbound impl is
/// presumably generated by `impl_reason_error_boilerplate!`; confirm against
/// the macro.
#[derive(Debug, Clone)]
pub enum LatentSurvivalError {
    /// The frailty spec supplied to a latent-survival or latent-binary
    /// helper is incompatible (wrong variant, missing fixed sigma, non-finite
    /// or negative fixed sigma).
    InvalidFrailty { reason: String },
    /// Per-row dataset validation failed: empty input, size mismatch across
    /// the spec vectors, or invalid age / event / weight / unloaded-mass
    /// values for an individual row. Also the catch-all bucket for untyped
    /// `String` errors converted via `From<String>`.
    InvalidDataset { reason: String },
    /// A parameter-block state, eta vector, or directional-derivative
    /// argument supplied to a family entry point has the wrong length.
    BlockMismatch { reason: String },
    /// A runtime numerical value (sigma, baseline hazard derivative, kernel
    /// sum, event probability) became non-finite or out-of-domain.
    NumericalFailure { reason: String },
    /// The requested combination of time-block structure or event type is
    /// not implemented (non-structural monotonicity, interval-censored rows
    /// on the dynamic-derivative path).
    UnsupportedConfiguration { reason: String },
}
78
// Shared reason-carrying error plumbing for every `LatentSurvivalError`
// variant listed below. NOTE(review): the exact impls come from the macro
// definition (presumably `Display`, `std::error::Error`, and the
// `From<LatentSurvivalError> for String` conversion that the `String`-returning
// entry points in this file rely on via `?` / `Into::into`). Confirm against
// the macro before relying on any specific impl.
impl_reason_error_boilerplate! {
    LatentSurvivalError {
        InvalidFrailty,
        InvalidDataset,
        BlockMismatch,
        NumericalFailure,
        UnsupportedConfiguration,
    }
}
88
89impl From<crate::block_layout::block_count::BlockCountMismatch> for LatentSurvivalError {
90    fn from(
91        err: crate::block_layout::block_count::BlockCountMismatch,
92    ) -> LatentSurvivalError {
93        LatentSurvivalError::BlockMismatch {
94            reason: err.message(),
95        }
96    }
97}
98
99impl From<String> for LatentSurvivalError {
100    /// Inbound conversion for the many `Result<_, String>` helpers this
101    /// module still calls into (term-collection design assembly, dense
102    /// chunk conversion, sparse linear constraints). The text is preserved
103    /// verbatim; we only pick a category so external messages flow through
104    /// `?` without per-callsite `.map_err`.
105    fn from(reason: String) -> LatentSurvivalError {
106        LatentSurvivalError::InvalidDataset { reason }
107    }
108}
109
/// Reserved [`LatentSurvivalTermSpec::event_target`] code marking an
/// interval-censored row `(L, R]`. Exact-event codes are `>= 1` and right
/// censoring is `0`; the interval code is the sentinel `u8::MAX` so it never
/// collides with an exact-event count and the dispatch is an explicit 3-way map
/// `{0 → RightCensored, INTERVAL → IntervalCensored, k ≥ 1 → ExactEvent}`.
///
/// Consequence: an exact-event code of `255` cannot be expressed. Any row
/// carrying that value is dispatched as interval-censored.
pub const LATENT_SURVIVAL_EVENT_INTERVAL: u8 = u8::MAX;
116
117#[inline]
118fn latent_survival_event_type_for(code: u8) -> LatentSurvivalEventType {
119    match code {
120        0 => LatentSurvivalEventType::RightCensored,
121        LATENT_SURVIVAL_EVENT_INTERVAL => LatentSurvivalEventType::IntervalCensored,
122        _ => LatentSurvivalEventType::ExactEvent,
123    }
124}
125
/// Input specification consumed by [`fit_latent_survival_terms`].
///
/// Per-row arrays are indexed by observation. The fitter takes the row count
/// `n` from `event_target.len()`.
#[derive(Clone)]
pub struct LatentSurvivalTermSpec {
    /// Per-row entry age. In this file it is only read by the shared input
    /// validator.
    pub age_entry: Array1<f64>,
    /// Per-row exit age. In this file it is only read by the shared input
    /// validator.
    pub age_exit: Array1<f64>,
    /// Per-row event code. The encoding is:
    /// - `0`: right-censored
    /// - [`LATENT_SURVIVAL_EVENT_INTERVAL`]: interval-censored
    /// - any other value: exact event
    pub event_target: Array1<u8>,
    /// Per-row weights, copied into the family.
    pub weights: Array1<f64>,
    /// Guard value forwarded to both the shared input validator and
    /// `prepare_latent_time_block`.
    pub derivative_guard: f64,
    /// Time-basis input. `prepare_latent_time_block` turns it into the dense
    /// entry/exit/derivative designs, penalties, and constraints.
    pub time_block: TimeBlockInput,
    /// Time-basis design evaluated at the interval upper bound `R` (so
    /// `q_right = design_right · β_time + offset_right`). `None` when the data
    /// carries no interval-censored rows; the family then reuses the exit design
    /// for the unused `q_right` channel. When `Some`, rows whose
    /// `event_target == LATENT_SURVIVAL_EVENT_INTERVAL` contribute the interval
    /// likelihood `log[S(L) − S(R)]`.
    pub time_design_right: Option<DesignMatrix>,
    /// Time-block offset at the interval upper bound `R`. `None` folds to
    /// zeros. When `Some`, the fitter requires length `n`.
    pub time_offset_right: Option<Array1<f64>>,
    /// Unloaded (background) cumulative mass at entry, passed to the shared
    /// input validator and copied into the family.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded (background) cumulative mass at exit, passed to the shared
    /// input validator and copied into the family.
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded (background) cumulative mass at the interval upper bound `R`.
    /// Entries for non-interval rows are ignored. An empty array folds to zero
    /// (full-loading interval rows); a non-empty array must have length `n`.
    pub unloaded_mass_right: Array1<f64>,
    /// Unloaded baseline hazard at exit. It feeds the exact-event
    /// loaded/unloaded split.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Term-collection spec for the frailty mean `μ = Xβ + offset`. It is
    /// built and frozen against `data` by the fitter.
    pub meanspec: TermCollectionSpec,
    /// Offset of the mean linear predictor (the `offset` in `μ`).
    pub mean_offset: Array1<f64>,
}
152
/// Output of [`fit_latent_survival_terms`].
pub struct LatentSurvivalTermFitResult {
    /// Raw blockwise fit. Its `block_states` follow the
    /// [`LatentSurvivalFamily`] block indices.
    pub fit: UnifiedFitResult,
    /// Mean term-collection design built from the input data.
    pub design: TermCollectionDesign,
    /// Mean spec frozen from `design` via `freeze_term_collection_from_design`.
    pub resolvedspec: TermCollectionSpec,
    /// Frailty SD: the fixed value, or (when learnable) `exp`-linked from the
    /// fitted log-sigma block.
    pub latent_sd: f64,
    /// Per-row residuals of the unpenalized NLL w.r.t. the additive baseline
    /// time-block offsets `(entry, exit, derivative)` at the converged β̂.
    /// Contracted against `baseline_offset_theta_partials` by
    /// `baseline_chain_rule_gradient` to give the exact θ-gradient of the
    /// profile penalized NLL for the outer baseline-config optimizer.
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
165
/// Input specification consumed by [`fit_latent_binary_terms`].
///
/// Same per-row layout as [`LatentSurvivalTermSpec`], but with no interval
/// upper-bound channels and no unloaded exit hazard.
#[derive(Clone)]
pub struct LatentBinaryTermSpec {
    /// Per-row entry age, read by the binary input validator.
    pub age_entry: Array1<f64>,
    /// Per-row exit age, read by the binary input validator.
    pub age_exit: Array1<f64>,
    /// Per-row binary target code, copied into the family.
    pub event_target: Array1<u8>,
    /// Per-row weights, copied into the family.
    pub weights: Array1<f64>,
    /// Guard value forwarded to `prepare_latent_time_block`.
    pub derivative_guard: f64,
    /// Time-basis input prepared by `prepare_latent_time_block`.
    pub time_block: TimeBlockInput,
    /// Unloaded (background) cumulative mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded (background) cumulative mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Term-collection spec for the frailty mean predictor.
    pub meanspec: TermCollectionSpec,
    /// Offset of the mean linear predictor.
    pub mean_offset: Array1<f64>,
}
179
/// Output of [`fit_latent_binary_terms`].
pub struct LatentBinaryTermFitResult {
    /// Raw blockwise fit. Its `block_states` follow the [`LatentBinaryFamily`]
    /// block indices.
    pub fit: UnifiedFitResult,
    /// Mean term-collection design built from the input data.
    pub design: TermCollectionDesign,
    /// Mean spec frozen from `design` via `freeze_term_collection_from_design`.
    pub resolvedspec: TermCollectionSpec,
    /// Per-row residuals of the unpenalized NLL w.r.t. the additive baseline
    /// time-block offsets `(entry, exit)` at the converged β̂ (the derivative
    /// channel is identically zero for the binary deployment likelihood).
    pub baseline_offset_residuals: crate::survival::OffsetChannelResiduals,
}
189
/// Dense, fit-ready form of a [`TimeBlockInput`], as produced by
/// `prepare_latent_time_block`. It feeds both the family designs and the time
/// block spec.
#[derive(Clone)]
struct PreparedLatentTimeBlock {
    /// Dense time-basis design at entry.
    design_entry: Array2<f64>,
    /// Dense time-basis design at exit.
    design_exit: Array2<f64>,
    /// Dense derivative design at exit. NOTE(review): presumably `d/da` of the
    /// time basis, giving `h_0 = dq/da`; confirm against
    /// `prepare_latent_time_block`.
    design_derivative_exit: Array2<f64>,
    /// Dense time-basis design at the interval upper bound `R`. Falls back to a
    /// clone of `design_exit` when the spec supplies no interval design, so the
    /// `q_right` channel is always well-defined (and unused for non-interval
    /// rows).
    design_right: Array2<f64>,
    /// Optional linear inequality constraints on the time coefficients.
    /// Presumably the structural monotonicity constraints; TODO confirm.
    linear_constraints: Option<LinearInequalityConstraints>,
    /// Penalty matrices for the time block.
    penalties: Vec<Array2<f64>>,
    /// Optional initial time coefficients.
    initial_beta: Option<Array1<f64>>,
}
204
/// Latent-frailty survival family with a live time/baseline block.
///
/// Block layout: [`Self::BLOCK_TIME`], [`Self::BLOCK_MEAN`], plus
/// [`Self::BLOCK_LOG_SIGMA`] only when `latent_sd_fixed` is `None`.
#[derive(Clone)]
pub struct LatentSurvivalFamily {
    /// Per-row event code: `0` right-censored, the interval sentinel, or an
    /// exact event.
    pub event_target: Array1<u8>,
    /// Per-row weights (length must match the mean eta).
    pub weights: Array1<f64>,
    /// `Some(σ)` for a fixed frailty SD. `None` means σ is learned through the
    /// log-sigma block.
    pub latent_sd_fixed: Option<f64>,
    /// Hazard loading resolved from the frailty spec.
    pub hazard_loading: HazardLoading,
    /// Unloaded (background) cumulative mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded (background) cumulative mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Unloaded baseline hazard at exit.
    pub unloaded_hazard_exit: Array1<f64>,
    /// Dense time-basis design at entry.
    pub x_time_entry: Array2<f64>,
    /// Dense time-basis design at exit.
    pub x_time_exit: Array2<f64>,
    /// Dense derivative design at exit.
    pub x_time_derivative_exit: Array2<f64>,
    /// Time-basis design evaluated at the interval upper bound `R` (so
    /// `q_right = x_time_right · β_time + time_offset_right`). `q_right` is
    /// unused by the likelihood for non-interval rows. When the spec supplies
    /// no interval design, the whole matrix is a copy of `x_time_exit`. The
    /// matrix always has `n` rows and the same column count as the other time
    /// designs.
    pub x_time_right: Array2<f64>,
    /// Time-block offset at the interval upper bound `R` (length `n`).
    pub time_offset_right: Array1<f64>,
    /// Unloaded (background) cumulative mass at the interval upper bound `R`
    /// (length `n`). Ignored for non-interval rows.
    pub unloaded_mass_right: Array1<f64>,
    /// Design of the frailty-mean predictor.
    pub x_mean: DesignMatrix,
    /// Optional inequality constraints on the time coefficients.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Shared quadrature context. Each fitter in this file creates a fresh one.
    pub quadctx: Arc<QuadratureContext>,
}
232
/// Latent-frailty binary deployment family: a two-block layout
/// ([`Self::BLOCK_TIME`], [`Self::BLOCK_MEAN`]) with a fixed frailty SD.
#[derive(Clone)]
pub struct LatentBinaryFamily {
    /// Per-row binary target code.
    pub event_target: Array1<u8>,
    /// Per-row weights (length must match the mean eta).
    pub weights: Array1<f64>,
    /// Fixed frailty SD (the binary family has no learnable sigma block).
    pub latent_sd: f64,
    /// Hazard loading resolved from the frailty spec.
    pub hazard_loading: HazardLoading,
    /// Unloaded (background) cumulative mass at entry.
    pub unloaded_mass_entry: Array1<f64>,
    /// Unloaded (background) cumulative mass at exit.
    pub unloaded_mass_exit: Array1<f64>,
    /// Dense time-basis design at entry.
    pub x_time_entry: Array2<f64>,
    /// Dense time-basis design at exit.
    pub x_time_exit: Array2<f64>,
    /// Design of the frailty-mean predictor.
    pub x_mean: DesignMatrix,
    /// Optional inequality constraints on the time coefficients.
    pub time_linear_constraints: Option<LinearInequalityConstraints>,
    /// Shared quadrature context. The fitter in this file creates a fresh one.
    pub quadctx: Arc<QuadratureContext>,
}
247
248impl LatentSurvivalFamily {
249    pub const BLOCK_TIME: usize = 0;
250    pub const BLOCK_MEAN: usize = 1;
251    pub const BLOCK_LOG_SIGMA: usize = 2;
252
253    pub fn parameter_names() -> &'static [&'static str] {
254        &["time_transform", "mean"]
255    }
256
257    pub fn parameter_links() -> &'static [ParameterLink] {
258        &[ParameterLink::Identity, ParameterLink::Identity]
259    }
260
261    pub fn metadata() -> FamilyMetadata {
262        FamilyMetadata {
263            name: "latent_survival",
264            parameternames: Self::parameter_names(),
265            parameter_links: Self::parameter_links(),
266        }
267    }
268
269    fn split_time_eta<'a>(
270        &self,
271        block_states: &'a [ParameterBlockState],
272    ) -> Result<
273        (
274            ArrayView1<'a, f64>,
275            ArrayView1<'a, f64>,
276            ArrayView1<'a, f64>,
277            &'a Array1<f64>,
278        ),
279        LatentSurvivalError,
280    > {
281        let expected_blocks = if self.latent_sd_fixed.is_some() { 2 } else { 3 };
282        crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
283            "LatentSurvivalFamily",
284            expected_blocks,
285            block_states.len(),
286        )?;
287        let n = self.event_target.len();
288        let eta_time = &block_states[Self::BLOCK_TIME].eta;
289        let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
290        if eta_time.len() != 3 * n {
291            return Err(LatentSurvivalError::BlockMismatch {
292                reason: format!(
293                    "latent survival time eta length mismatch: got {}, expected {}",
294                    eta_time.len(),
295                    3 * n
296                ),
297            });
298        }
299        if eta_mean.len() != n || self.weights.len() != n {
300            return Err(LatentSurvivalError::BlockMismatch {
301                reason: "latent survival mean eta dimension mismatch".to_string(),
302            });
303        }
304        Ok((
305            eta_time.slice(s![0..n]),
306            eta_time.slice(s![n..2 * n]),
307            eta_time.slice(s![2 * n..3 * n]),
308            eta_mean,
309        ))
310    }
311
312    /// Per-row interval upper-bound time transform `q_right = x_time_right · β_time
313    /// + time_offset_right`. Shares the time-block coefficients with `q_exit`
314    /// (same monotone basis, evaluated at `R`), so it is read off the time
315    /// block's `beta` rather than carried as an extra eta channel. For
316    /// non-interval rows `x_time_right` equals `x_time_exit`, so the (unused)
317    /// value is simply `q_exit`.
318    fn time_q_right(
319        &self,
320        block_states: &[ParameterBlockState],
321    ) -> Result<Array1<f64>, LatentSurvivalError> {
322        let n = self.event_target.len();
323        let beta_time = &block_states[Self::BLOCK_TIME].beta;
324        if self.x_time_right.ncols() != beta_time.len() {
325            return Err(LatentSurvivalError::BlockMismatch {
326                reason: format!(
327                    "latent survival interval right design has {} columns but time beta has {}",
328                    self.x_time_right.ncols(),
329                    beta_time.len()
330                ),
331            });
332        }
333        if self.x_time_right.nrows() != n || self.time_offset_right.len() != n {
334            return Err(LatentSurvivalError::BlockMismatch {
335                reason: "latent survival interval right design/offset row count mismatch"
336                    .to_string(),
337            });
338        }
339        let mut q_right = self.x_time_right.dot(beta_time);
340        q_right += &self.time_offset_right;
341        Ok(q_right)
342    }
343
344    fn latent_sd(&self, block_states: &[ParameterBlockState]) -> Result<f64, LatentSurvivalError> {
345        if let Some(sigma) = self.latent_sd_fixed {
346            return Ok(sigma);
347        }
348        let eta = *block_states
349            .get(Self::BLOCK_LOG_SIGMA)
350            .and_then(|state| state.eta.get(0))
351            .ok_or_else(|| LatentSurvivalError::BlockMismatch {
352                reason: "latent survival learnable log_sigma block is missing".to_string(),
353            })?;
354        let sigma = exp_sigma_from_eta_scalar(eta);
355        if !(sigma.is_finite() && sigma > 0.0) {
356            return Err(LatentSurvivalError::NumericalFailure {
357                reason: format!(
358                    "latent survival learnable sigma became invalid: log_sigma={eta}, sigma={sigma}"
359                ),
360            });
361        }
362        Ok(sigma)
363    }
364}
365
366impl LatentBinaryFamily {
367    pub const BLOCK_TIME: usize = 0;
368    pub const BLOCK_MEAN: usize = 1;
369
370    fn split_time_eta<'a>(
371        &self,
372        block_states: &'a [ParameterBlockState],
373    ) -> Result<(ArrayView1<'a, f64>, ArrayView1<'a, f64>, &'a Array1<f64>), LatentSurvivalError>
374    {
375        crate::block_layout::block_count::validate_block_count::<LatentSurvivalError>(
376            "LatentBinaryFamily",
377            2,
378            block_states.len(),
379        )?;
380        let n = self.event_target.len();
381        let eta_time = &block_states[Self::BLOCK_TIME].eta;
382        let eta_mean = &block_states[Self::BLOCK_MEAN].eta;
383        if eta_time.len() != 3 * n {
384            return Err(LatentSurvivalError::BlockMismatch {
385                reason: format!(
386                    "latent binary time eta length mismatch: got {}, expected {}",
387                    eta_time.len(),
388                    3 * n
389                ),
390            });
391        }
392        if eta_mean.len() != n || self.weights.len() != n {
393            return Err(LatentSurvivalError::BlockMismatch {
394                reason: "latent binary mean eta dimension mismatch".to_string(),
395            });
396        }
397        Ok((
398            eta_time.slice(s![0..n]),
399            eta_time.slice(s![n..2 * n]),
400            eta_mean,
401        ))
402    }
403}
404
405pub fn fixed_latent_hazard_frailty(
406    frailty: &FrailtySpec,
407    context: &str,
408) -> Result<(f64, HazardLoading), String> {
409    fixed_latent_hazard_frailty_typed(frailty, context).map_err(Into::into)
410}
411
412fn fixed_latent_hazard_frailty_typed(
413    frailty: &FrailtySpec,
414    context: &str,
415) -> Result<(f64, HazardLoading), LatentSurvivalError> {
416    match frailty {
417        FrailtySpec::HazardMultiplier {
418            sigma_fixed: Some(sigma),
419            loading,
420        } if sigma.is_finite() && *sigma >= 0.0 => Ok((*sigma, *loading)),
421        FrailtySpec::HazardMultiplier {
422            sigma_fixed: Some(sigma),
423            ..
424        } => Err(LatentSurvivalError::InvalidFrailty {
425            reason: format!(
426                "{context} requires a finite fixed hazard-multiplier sigma >= 0, got {sigma}"
427            ),
428        }),
429        FrailtySpec::HazardMultiplier {
430            sigma_fixed: None, ..
431        } => Err(LatentSurvivalError::InvalidFrailty {
432            reason: format!("{context} currently requires a fixed hazard-multiplier sigma"),
433        }),
434        FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
435            reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
436        }),
437        FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
438            reason: format!("{context} requires a fixed HazardMultiplier frailty specification"),
439        }),
440    }
441}
442
443pub fn latent_hazard_loading(
444    frailty: &FrailtySpec,
445    context: &str,
446) -> Result<HazardLoading, String> {
447    latent_hazard_loading_typed(frailty, context).map_err(Into::into)
448}
449
450fn latent_hazard_loading_typed(
451    frailty: &FrailtySpec,
452    context: &str,
453) -> Result<HazardLoading, LatentSurvivalError> {
454    match frailty {
455        FrailtySpec::HazardMultiplier { loading, .. } => Ok(*loading),
456        FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
457            reason: format!("{context} requires HazardMultiplier frailty, not GaussianShift"),
458        }),
459        FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
460            reason: format!("{context} requires a HazardMultiplier frailty specification"),
461        }),
462    }
463}
464
/// Per-row derivative bundle for the entry and exit time channels.
///
/// NOTE(review): judging by the field names, these are presumably first
/// derivatives (`grad_*`) and negated second derivatives (`neg_hess_*`) of a
/// row's log-likelihood w.r.t. `q_entry` / `q_exit`. The type is not
/// constructed in the visible portion of this file; confirm against its users.
#[derive(Clone, Copy)]
struct LatentSurvivalTimeJet {
    grad_entry: f64,
    grad_exit: f64,
    neg_hess_entry: f64,
    neg_hess_exit: f64,
}
472
473pub fn fit_latent_survival_terms(
474    data: ArrayView2<'_, f64>,
475    spec: LatentSurvivalTermSpec,
476    frailty: FrailtySpec,
477    options: &BlockwiseFitOptions,
478) -> Result<LatentSurvivalTermFitResult, String> {
479    let latent_sd = validate_latent_survival_inputs(data, &spec, &frailty)?;
480    let hazard_loading = latent_hazard_loading(&frailty, "latent-survival")?;
481    let mean_design =
482        build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
483    let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
484        .map_err(|e| e.to_string())?;
485    let time_prepared = prepare_latent_time_block(
486        &spec.time_block,
487        spec.time_design_right.as_ref(),
488        spec.derivative_guard,
489    )?;
490
491    let n = spec.event_target.len();
492    let time_offset_right = match spec.time_offset_right.as_ref() {
493        Some(offset) => {
494            if offset.len() != n {
495                return Err(format!(
496                    "latent survival interval right time offset must have length {n}, got {}",
497                    offset.len()
498                ));
499            }
500            offset.clone()
501        }
502        None => Array1::zeros(n),
503    };
504    let unloaded_mass_right = if spec.unloaded_mass_right.is_empty() {
505        Array1::zeros(n)
506    } else {
507        if spec.unloaded_mass_right.len() != n {
508            return Err(format!(
509                "latent survival interval right unloaded mass must have length {n}, got {}",
510                spec.unloaded_mass_right.len()
511            ));
512        }
513        spec.unloaded_mass_right.clone()
514    };
515
516    let family = LatentSurvivalFamily {
517        event_target: spec.event_target.clone(),
518        weights: spec.weights.clone(),
519        latent_sd_fixed: latent_sd,
520        hazard_loading,
521        unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
522        unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
523        unloaded_hazard_exit: spec.unloaded_hazard_exit.clone(),
524        x_time_entry: time_prepared.design_entry.clone(),
525        x_time_exit: time_prepared.design_exit.clone(),
526        x_time_derivative_exit: time_prepared.design_derivative_exit.clone(),
527        x_time_right: time_prepared.design_right.clone(),
528        time_offset_right,
529        unloaded_mass_right,
530        x_mean: mean_design.design.clone(),
531        time_linear_constraints: time_prepared.linear_constraints.clone(),
532        quadctx: Arc::new(QuadratureContext::new()),
533    };
534
535    let mut blocks = vec![
536        build_time_blockspec(&time_prepared, &spec.time_block),
537        build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
538    ];
539    if latent_sd.is_none() {
540        blocks.push(build_log_sigma_blockspec(
541            LEARNABLE_LATENT_SD_SEED,
542            mean_design.design.nrows(),
543        ));
544    }
545    // Interval warm start (issue #1108). Interval-censored rows contribute the
546    // NON-concave `ℓ = log[S(L) − S(R)]`; the coupled exact-joint inner Newton
547    // diverges from the cold seed (β_time = 1e-4, σ = 0.5) — the failure surfaces
548    // first as `fit_custom_family`'s outer ρ-seed startup validation rejecting
549    // every seed (`solver_started = 0`). We warm-start from a LOG-CONCAVE
550    // surrogate whose β/σ land in the interval basin, threaded via `initial_beta`
551    // (consumed by every inner solve, including each ρ-seed validation fit).
552    //
553    // Surrogate = right-censored at the bracket LOWER bound `L`. Its survival
554    // mass `S(L) = K_{0,B(L)}` is log-concave (PD Hessian) and — crucially —
555    // its time-block design is the SAME fixed-knot I-spline basis the interval
556    // fit uses, which is FULL RANK regardless of how heavily the inspection-grid
557    // `L` values are TIED (the basis columns are functions of the frozen knots,
558    // not of the observed time multiplicities). Unlike an exact-event surrogate
559    // it imposes NO per-row `q̇(L) > 0` hazard-derivative feasibility condition
560    // (which the tied/degenerate cold-start derivative design can violate), so it
561    // is robust where exact-event-at-L is not. The warm σ then refines from the
562    // bracket-width spread inside the (now in-basin) interval fit.
563    //
564    // Failure is NON-SILENT (#1108): a surrogate that errors or returns a
565    // non-finite / all-zero degenerate β is surfaced as a hard error rather than
566    // silently reverting to the diverging cold start (which masked the real
567    // failure across several attempts). Only `initial_beta` is seeded; the EXACT
568    // interval objective/gradient/Hessian are unchanged, so σ̂ is the true MLE.
569    let has_interval_rows = spec
570        .event_target
571        .iter()
572        .any(|&code| code == LATENT_SURVIVAL_EVENT_INTERVAL);
573    if has_interval_rows {
574        let warm_event_target = spec.event_target.mapv(|code| {
575            if code == LATENT_SURVIVAL_EVENT_INTERVAL {
576                0u8
577            } else {
578                code
579            }
580        });
581        let mut warm_family = family.clone();
582        warm_family.event_target = warm_event_target;
583        // Right-censored-at-L ignores the interval upper bound `R`, so the
584        // (unused) `q_right` channel cannot drift the fit; leaving the right
585        // design/mass in place is harmless (no interval row remains to read it).
586        let warm_fit = fit_custom_family_fixed_log_lambdas(
587            &warm_family,
588            &blocks,
589            options,
590            None,
591            0,
592            None,
593            false,
594        )
595        .map_err(|e| {
596            format!(
597                "latent interval warm start: right-censored-at-L surrogate fit failed \
598                 (so the interval fit cannot be safely warm-started; this surrogate is \
599                 log-concave and should converge — investigate the surrogate, not the \
600                 interval kernel): {e}"
601            )
602        })?;
603        let warm_beta_usable = warm_fit
604            .block_states
605            .iter()
606            .any(|s| s.beta.iter().all(|v| v.is_finite()) && s.beta.iter().any(|&v| v != 0.0));
607        if !warm_beta_usable {
608            return Err(
609                "latent interval warm start: right-censored-at-L surrogate returned a \
610                 degenerate (non-finite or all-zero) β across every block; the warm start \
611                 cannot seed the interval fit. This indicates the surrogate's time-block \
612                 design is rank-deficient or the inner solve stalled at the seed — \
613                 investigate the surrogate before retrying the interval fit."
614                    .to_string(),
615            );
616        }
617        for (block, state) in blocks.iter_mut().zip(warm_fit.block_states.iter()) {
618            if state.beta.iter().all(|v| v.is_finite()) {
619                block.initial_beta = Some(state.beta.clone());
620            }
621        }
622    }
623    let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
624    let latent_sd = family.latent_sd(&fit.block_states)?;
625    let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
626    Ok(LatentSurvivalTermFitResult {
627        fit,
628        design: mean_design,
629        resolvedspec,
630        latent_sd,
631        baseline_offset_residuals,
632    })
633}
634
635pub fn fit_latent_binary_terms(
636    data: ArrayView2<'_, f64>,
637    spec: LatentBinaryTermSpec,
638    frailty: FrailtySpec,
639    options: &BlockwiseFitOptions,
640) -> Result<LatentBinaryTermFitResult, String> {
641    let latent_sd = validate_latent_binary_inputs(data, &spec, &frailty)?;
642    let (_, hazard_loading) = fixed_latent_hazard_frailty(&frailty, "latent-binary")?;
643    let mean_design =
644        build_term_collection_design(data, &spec.meanspec).map_err(|e| e.to_string())?;
645    let resolvedspec = freeze_term_collection_from_design(&spec.meanspec, &mean_design)
646        .map_err(|e| e.to_string())?;
647    let time_prepared = prepare_latent_time_block(&spec.time_block, None, spec.derivative_guard)?;
648
649    let family = LatentBinaryFamily {
650        event_target: spec.event_target.clone(),
651        weights: spec.weights.clone(),
652        latent_sd,
653        hazard_loading,
654        unloaded_mass_entry: spec.unloaded_mass_entry.clone(),
655        unloaded_mass_exit: spec.unloaded_mass_exit.clone(),
656        x_time_entry: time_prepared.design_entry.clone(),
657        x_time_exit: time_prepared.design_exit.clone(),
658        x_mean: mean_design.design.clone(),
659        time_linear_constraints: time_prepared.linear_constraints.clone(),
660        quadctx: Arc::new(QuadratureContext::new()),
661    };
662
663    let blocks = vec![
664        build_time_blockspec(&time_prepared, &spec.time_block),
665        build_mean_blockspec(&mean_design, spec.mean_offset.clone()),
666    ];
667    let fit = fit_custom_family(&family, &blocks, options).map_err(|e| e.to_string())?;
668    let baseline_offset_residuals = family.offset_channel_residuals(&fit.block_states)?;
669    Ok(LatentBinaryTermFitResult {
670        fit,
671        design: mean_design,
672        resolvedspec,
673        baseline_offset_residuals,
674    })
675}
676
677/// Latent-survival adapter for the shared [`LatentIntervalModel`] driver.
678///
679/// Survival permits a learnable sigma (`sigma_fixed == None`) and carries the
680/// per-row unloaded baseline hazard at exit (which feeds the exact-event
681/// loaded/unloaded split); everything else is validated by the shared engine.
682struct LatentSurvivalModel;
683
684impl LatentIntervalModel for LatentSurvivalModel {
685    fn context() -> &'static str {
686        "latent-survival"
687    }
688
689    fn allows_interval() -> bool {
690        true
691    }
692
693    fn frailty_policy(
694        frailty: &FrailtySpec,
695    ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
696        match frailty {
697            FrailtySpec::HazardMultiplier {
698                sigma_fixed,
699                loading,
700            } => {
701                if let Some(sigma) = sigma_fixed
702                    && (!sigma.is_finite() || *sigma < 0.0)
703                {
704                    return Err(LatentSurvivalError::InvalidFrailty {
705                        reason: format!(
706                            "latent-survival requires a finite hazard-multiplier sigma >= 0, got {sigma}"
707                        ),
708                    });
709                }
710                Ok(LatentFrailtyResolution {
711                    sigma: *sigma_fixed,
712                    loading: *loading,
713                })
714            }
715            FrailtySpec::GaussianShift { .. } => Err(LatentSurvivalError::InvalidFrailty {
716                reason: "latent-survival requires HazardMultiplier frailty, not GaussianShift"
717                    .to_string(),
718            }),
719            FrailtySpec::None => Err(LatentSurvivalError::InvalidFrailty {
720                reason: "latent-survival requires a HazardMultiplier frailty specification"
721                    .to_string(),
722            }),
723        }
724    }
725}
726
727fn validate_latent_survival_inputs(
728    data: ArrayView2<'_, f64>,
729    spec: &LatentSurvivalTermSpec,
730    frailty: &FrailtySpec,
731) -> Result<Option<f64>, LatentSurvivalError> {
732    let row = LatentIntervalRowView {
733        frailty,
734        age_entry: &spec.age_entry,
735        age_exit: &spec.age_exit,
736        event_target: &spec.event_target,
737        weights: &spec.weights,
738        unloaded_mass_entry: &spec.unloaded_mass_entry,
739        unloaded_mass_exit: &spec.unloaded_mass_exit,
740        unloaded_hazard_exit: Some(&spec.unloaded_hazard_exit),
741        mean_offset: &spec.mean_offset,
742        derivative_guard: spec.derivative_guard,
743        time_block: &spec.time_block,
744    };
745    validate_latent_interval_inputs::<LatentSurvivalModel>(data, &row)
746}
747
748pub(crate) fn validate_unloaded_components_for_loading(
749    context: &str,
750    row_index: usize,
751    loading: HazardLoading,
752    unloaded_entry: f64,
753    unloaded_exit: f64,
754    unloaded_hazard: Option<f64>,
755) -> Result<(), LatentSurvivalError> {
756    match loading {
757        HazardLoading::Full => {
758            if unloaded_entry != 0.0
759                || unloaded_exit != 0.0
760                || unloaded_hazard.is_some_and(|hazard| hazard != 0.0)
761            {
762                return Err(LatentSurvivalError::InvalidDataset {
763                    reason: format!(
764                        "{context} row {} uses full hazard loading, so unloaded components must be exactly zero; got entry_mass={}, exit_mass={}, exit_hazard={}",
765                        row_index + 1,
766                        unloaded_entry,
767                        unloaded_exit,
768                        unloaded_hazard.unwrap_or(0.0)
769                    ),
770                });
771            }
772        }
773        HazardLoading::LoadedVsUnloaded => {}
774    }
775    Ok(())
776}
777
778/// Latent-binary adapter for the shared [`LatentIntervalModel`] driver.
779///
780/// Binary never evaluates an exact event, so it requires a finite *fixed*
781/// latent sigma (via [`fixed_latent_hazard_frailty_typed`]) and carries no
782/// per-row unloaded hazard; every other invariant is validated by the shared
783/// engine.
784struct LatentBinaryModel;
785
786impl LatentIntervalModel for LatentBinaryModel {
787    fn context() -> &'static str {
788        "latent-binary"
789    }
790
791    fn frailty_policy(
792        frailty: &FrailtySpec,
793    ) -> Result<LatentFrailtyResolution, LatentSurvivalError> {
794        let (sigma, loading) = fixed_latent_hazard_frailty_typed(frailty, "latent-binary")?;
795        Ok(LatentFrailtyResolution {
796            sigma: Some(sigma),
797            loading,
798        })
799    }
800}
801
802fn validate_latent_binary_inputs(
803    data: ArrayView2<'_, f64>,
804    spec: &LatentBinaryTermSpec,
805    frailty: &FrailtySpec,
806) -> Result<f64, LatentSurvivalError> {
807    let row = LatentIntervalRowView {
808        frailty,
809        age_entry: &spec.age_entry,
810        age_exit: &spec.age_exit,
811        event_target: &spec.event_target,
812        weights: &spec.weights,
813        unloaded_mass_entry: &spec.unloaded_mass_entry,
814        unloaded_mass_exit: &spec.unloaded_mass_exit,
815        unloaded_hazard_exit: None,
816        mean_offset: &spec.mean_offset,
817        derivative_guard: spec.derivative_guard,
818        time_block: &spec.time_block,
819    };
820    // The binary `frailty_policy` always yields `Some(sigma)` (it rejects the
821    // learnable-scale case), so the shared driver's `Option<f64>` is `Some`
822    // here; surface a structured error rather than unwrapping if that ever
823    // changes.
824    validate_latent_interval_inputs::<LatentBinaryModel>(data, &row)?.ok_or_else(|| {
825        LatentSurvivalError::InvalidFrailty {
826            reason: "latent-binary requires a fixed latent sigma".to_string(),
827        }
828    })
829}
830
831fn prepare_latent_time_block(
832    input: &TimeBlockInput,
833    design_right: Option<&DesignMatrix>,
834    derivative_guard: f64,
835) -> Result<PreparedLatentTimeBlock, LatentSurvivalError> {
836    if !input.time_monotonicity.is_coordinate_cone() {
837        return Err(LatentSurvivalError::UnsupportedConfiguration {
838            reason: format!(
839                "latent survival requires a coordinate-cone monotonicity strategy; got {:?}",
840                input.time_monotonicity
841            ),
842        });
843    }
844    let design_entry = input
845        .design_entry
846        .try_to_dense_by_chunks("latent survival entry time design")?;
847    let design_exit = input
848        .design_exit
849        .try_to_dense_by_chunks("latent survival exit time design")?;
850    let design_derivative_exit = input
851        .design_derivative_exit
852        .try_to_dense_by_chunks("latent survival derivative time design")?;
853    // The interval upper-bound design shares the time-block coefficients with
854    // the exit design; when the data has no interval rows we reuse the exit
855    // design so `q_right` stays well-defined (its likelihood contribution is
856    // gated off for non-interval rows). When present it must match the exit
857    // design's shape (same basis, evaluated at R).
858    let design_right = match design_right {
859        Some(matrix) => {
860            let dense =
861                matrix.try_to_dense_by_chunks("latent survival interval right time design")?;
862            if dense.nrows() != design_exit.nrows() || dense.ncols() != design_exit.ncols() {
863                return Err(LatentSurvivalError::InvalidDataset {
864                    reason: format!(
865                        "latent survival interval right time design must match exit design shape \
866                         {:?}, got {:?}",
867                        design_exit.dim(),
868                        dense.dim()
869                    ),
870                });
871            }
872            dense
873        }
874        None => design_exit.clone(),
875    };
876    let linear_constraints = structural_time_coefficient_constraints(
877        &input.design_derivative_exit,
878        &input.derivative_offset_exit,
879        derivative_guard,
880    )?;
881    let initial_beta = match linear_constraints.as_ref() {
882        // `project_onto_linear_constraints` validates that any supplied
883        // `initial_beta` matches `design_exit.ncols()`; surface a mismatch as a
884        // structured error rather than letting an ndarray broadcast panic
885        // (issue #374).
886        Some(constraints) => Some(project_onto_linear_constraints(
887            design_exit.ncols(),
888            constraints,
889            input.initial_beta.as_ref(),
890        )?),
891        None => None,
892    };
893    Ok(PreparedLatentTimeBlock {
894        design_entry,
895        design_exit,
896        design_derivative_exit,
897        design_right,
898        linear_constraints,
899        penalties: input.penalties.clone(),
900        initial_beta,
901    })
902}
903
904fn stack_rows(blocks: &[&Array2<f64>]) -> Array2<f64> {
905    let ncols = blocks.first().map_or(0, |m| m.ncols());
906    let nrows = blocks.iter().map(|m| m.nrows()).sum();
907    let mut out = Array2::<f64>::zeros((nrows, ncols));
908    let mut row = 0usize;
909    for block in blocks {
910        let end = row + block.nrows();
911        out.slice_mut(s![row..end, ..]).assign(block);
912        row = end;
913    }
914    out
915}
916
917fn build_time_blockspec(
918    prepared: &PreparedLatentTimeBlock,
919    input: &TimeBlockInput,
920) -> ParameterBlockSpec {
921    // The solver produces a `3·n`-long time `eta` (the `[entry; exit; deriv]`
922    // channel stack that `split_time_eta` slices). That stacked operator is
923    // the eta-producing matrix and so belongs in `stacked_design`, paired with
924    // the matching `3·n`-row stacked offset. The audit / shape-policy invariant
925    // `design.nrows() == n_obs` is satisfied by exposing the single-channel
926    // n-row exit design as `design`; the audit never inspects `stacked_design`.
927    //
928    // This mirrors the survival location-scale fix for the same #326 class
929    // (`survival_location_scale.rs`): the previous code put the `3·n`-row
930    // stack in `design`, which tripped the flat identifiability audit's
931    // row-equality invariant (`block 1 (mean) has n rows, expected 3n`).
932    let stacked_design = stack_rows(&[
933        &prepared.design_entry,
934        &prepared.design_exit,
935        &prepared.design_derivative_exit,
936    ]);
937    let stacked_offset = gam_linalg::utils::stack_offsets(&[
938        &input.offset_entry,
939        &input.offset_exit,
940        &input.derivative_offset_exit,
941    ]);
942    ParameterBlockSpec {
943        name: "time_transform".to_string(),
944        design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
945            prepared.design_exit.clone(),
946        ))),
947        offset: input.offset_exit.clone(),
948        penalties: prepared
949            .penalties
950            .iter()
951            .cloned()
952            .map(PenaltyMatrix::Dense)
953            .collect(),
954        nullspace_dims: input.nullspace_dims.clone(),
955        initial_log_lambdas: input
956            .initial_log_lambdas
957            .clone()
958            .unwrap_or_else(|| Array1::zeros(prepared.penalties.len())),
959        initial_beta: prepared.initial_beta.clone(),
960        // Canonical-gauge ownership for the latent-survival joint design: the
961        // time-transform block carries the structural monotone baseline that
962        // anchors the parameterisation, so it owns any shared constant
963        // direction (strictly higher than `mean`/`log_sigma` at 100). This
964        // matches the survival location-scale gauge contract (time highest).
965        gauge_priority: 200,
966        jacobian_callback: None,
967        stacked_design: Some(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
968            stacked_design,
969        )))),
970        stacked_offset: Some(stacked_offset),
971    }
972}
973
974fn build_mean_blockspec(design: &TermCollectionDesign, offset: Array1<f64>) -> ParameterBlockSpec {
975    ParameterBlockSpec {
976        name: "mean".to_string(),
977        design: design.design.clone(),
978        offset,
979        penalties: design.penalties_as_penalty_matrix(),
980        nullspace_dims: design.nullspace_dims.clone(),
981        initial_log_lambdas: Array1::zeros(design.penalties.len()),
982        initial_beta: None,
983        // Strictly below `time_transform` (200) so any constant direction
984        // shared between the monotone time baseline and the mean intercept is
985        // deterministically attributable to the lower-priority `mean` block by
986        // the canonical-gauge RRQR (the descending-priority contract used by
987        // survival location-scale; #366/#556 gauge story).
988        gauge_priority: 150,
989        jacobian_callback: None,
990        stacked_design: None,
991        stacked_offset: None,
992    }
993}
994
/// Initial latent-frailty standard deviation used when `sigma` is learnable
/// (`sigma_fixed == None`).
///
/// The `log_sigma` block starts from the link-scale value of σ = 0.5
/// (`log(0.5)` under an exp sigma link). That is a moderate, well-conditioned
/// dispersion: not so small (σ → 0) that the frailty integral flattens, and
/// not so large that the Gauss-Hermite quadrature turns heavy-tailed. The
/// optimizer then learns the data's actual scale. This is a starting point
/// only, never a constraint.
const LEARNABLE_LATENT_SD_SEED: f64 = 0.5_f64;
1002
1003fn build_log_sigma_blockspec(initial_sigma: f64, n_obs: usize) -> ParameterBlockSpec {
1004    ParameterBlockSpec {
1005        name: "log_sigma".to_string(),
1006        // The frailty/dispersion scale is a single GLOBAL hyperparameter (one free
1007        // coefficient), but the identifiability audit — and the canonical-row
1008        // architecture generally — require every block's effective Jacobian to carry
1009        // `n_obs` rows. A global scalar is realised the same way the survival
1010        // location-scale `log_sigma` block is (see `BinomialLocationScaleFamily`): an
1011        // `n_obs × 1` constant column of ones, so `eta = design · β` is the same scalar
1012        // broadcast to every observation. This keeps it a single free parameter while
1013        // exposing the `n_obs`-row shape the audit checks, and `latent_sd` reads
1014        // `eta[0]` — identical across rows by construction.
1015        design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(Array2::from_elem(
1016            (n_obs, 1),
1017            1.0,
1018        )))),
1019        offset: Array1::zeros(n_obs),
1020        penalties: vec![],
1021        nullspace_dims: vec![],
1022        initial_log_lambdas: Array1::zeros(0),
1023        initial_beta: Some(Array1::from_elem(
1024            1,
1025            exp_sigma_eta_for_sigma_scalar(initial_sigma),
1026        )),
1027        // Lowest of the three (time=200, mean=150): the learnable-scale channel
1028        // yields any shared constant to the location blocks.
1029        gauge_priority: 120,
1030        jacobian_callback: None,
1031        stacked_design: None,
1032        stacked_offset: None,
1033    }
1034}
1035
// Indices of the latent-survival primary channels (see
// `latent_survival_basis_direction`). Each index is chained off the previous
// one so the layout stays contiguous and `..._DIM` cannot drift from it.
const LATENT_SURVIVAL_PRIMARY_Q_ENTRY: usize = 0;
const LATENT_SURVIVAL_PRIMARY_Q_EXIT: usize = LATENT_SURVIVAL_PRIMARY_Q_ENTRY + 1;
const LATENT_SURVIVAL_PRIMARY_QDOT_EXIT: usize = LATENT_SURVIVAL_PRIMARY_Q_EXIT + 1;
// Interval-censored right boundary R: q_right = log B(R) shares the time-block
// coefficients with q_exit (same monotone transform, different time point).
// It is therefore a fourth linear functional of the time block, NOT an
// independent eta channel. It sits before `mu`/`log_sigma` so the "trailing
// optional log_sigma" invariant used by `active_primary`
// (= `LATENT_SURVIVAL_PRIMARY_LOG_SIGMA`) keeps q_right always active.
const LATENT_SURVIVAL_PRIMARY_Q_RIGHT: usize = LATENT_SURVIVAL_PRIMARY_QDOT_EXIT + 1;
const LATENT_SURVIVAL_PRIMARY_MU: usize = LATENT_SURVIVAL_PRIMARY_Q_RIGHT + 1;
const LATENT_SURVIVAL_PRIMARY_LOG_SIGMA: usize = LATENT_SURVIVAL_PRIMARY_MU + 1;
const LATENT_SURVIVAL_PRIMARY_DIM: usize = LATENT_SURVIVAL_PRIMARY_LOG_SIGMA + 1;
1049
1050use gam_math::jet_partitions::MultiDirJet as LatentMultiDirJet;
1051
/// Value and first four derivatives of `log(x)`:
/// `[ln x, 1/x, -1/x², 2/x³, -6/x⁴]`.
///
/// # Contract
///
/// `x` must be strictly positive, and nothing is clamped. An earlier version
/// substituted `x.max(1e-300)`. That produced huge finite "derivatives"
/// (`1/1e-300`, …) that belong to neither `log(x)` nor `log(max(x, floor))`,
/// and it hid upstream domain failures. Supported callers pass `x > 0`; in
/// this module the normalised log-sum jet composes at the literal `1.0`. For
/// a non-positive `x` the honest IEEE result (`-inf`/`NaN`) is returned, the
/// same in debug and release. For every valid `x > 0` the output is
/// bit-identical to the old clamped version.
#[inline]
fn latent_unary_derivatives_log(x: f64) -> [f64; 5] {
    // Powers are built by repeated multiplication, so rounding matches the
    // direct `x * x * …` form.
    let squared = x * x;
    let cubed = squared * x;
    let fourth = cubed * x;
    let mut derivs = [0.0_f64; 5];
    derivs[0] = x.ln();
    derivs[1] = x.recip();
    derivs[2] = -1.0 / squared;
    derivs[3] = 2.0 / cubed;
    derivs[4] = -6.0 / fourth;
    derivs
}
1074
/// One monomial of a kernel log-sum expression.
///
/// `latent_kernel_sum_log_jet` evaluates it as
/// `coeff · exp(q)^q_exp · qdot^qdot_power · exp(log_sigma_factor)^tau_exp · exp(bundle.get(k))`,
/// where `bundle` comes from `log_kernel_bundle`.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryTerm {
    /// Signed scalar multiplier.
    coeff: f64,
    /// Power of `exp(q)`, i.e. the multiplier of `q` on the log scale.
    q_exp: usize,
    /// Power of `qdot`. In this file only the exact-event hazard term starts
    /// with a nonzero power; differentiation only lowers it.
    qdot_power: usize,
    /// Power of `exp(log_sigma_factor)`.
    tau_exp: usize,
    /// Index into the log-kernel bundle.
    k: usize,
}
1083
/// A tangent direction in a single kernel state's primary coordinates,
/// consumed by `latent_kernel_differentiate_terms`.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryDirection {
    /// Component along `q` (log of the kernel mass `exp(q)`).
    dq: f64,
    /// Component along `qdot` (only acts on terms with `qdot_power > 0`).
    dqd: f64,
    /// Component along `mu`.
    dmu: f64,
    /// Component along the log-sigma coordinate.
    dtau: f64,
}
1091
/// A tangent direction over the six latent-survival primary channels, one
/// field per `LATENT_SURVIVAL_PRIMARY_*` index. The `latent_survival_map_*`
/// functions project it onto each kernel state.
#[derive(Clone, Copy, Debug)]
struct LatentSurvivalPrimaryDirection {
    /// Component along `q_entry` (`LATENT_SURVIVAL_PRIMARY_Q_ENTRY`).
    dq_entry: f64,
    /// Component along `q_exit` (`LATENT_SURVIVAL_PRIMARY_Q_EXIT`).
    dq_exit: f64,
    /// Component along `qdot_exit` (`LATENT_SURVIVAL_PRIMARY_QDOT_EXIT`).
    dqdot_exit: f64,
    /// Component along `q_right` (`LATENT_SURVIVAL_PRIMARY_Q_RIGHT`).
    dq_right: f64,
    /// Component along `mu` (`LATENT_SURVIVAL_PRIMARY_MU`).
    dmu: f64,
    /// Component along `log_sigma` (`LATENT_SURVIVAL_PRIMARY_LOG_SIGMA`).
    dlog_sigma: f64,
}
1101
/// Point at which one kernel log-sum is evaluated.
#[derive(Clone, Copy, Debug)]
struct LatentKernelPrimaryState {
    /// Log of the kernel mass; `exp(q)` is passed to `log_kernel_bundle`.
    q: f64,
    /// Derivative channel; must be positive and finite whenever a term with
    /// `qdot_power > 0` is evaluated. Entry states use a placeholder `1.0`.
    qdot: f64,
    /// Latent location.
    mu: f64,
    /// Latent scale passed to `log_kernel_bundle`.
    sigma: f64,
    /// Log-scale factor added `tau_exp` times to each term's log magnitude.
    log_sigma_factor: f64,
}
1110
1111fn latent_kernel_accumulate_term(
1112    terms: &mut BTreeMap<(usize, usize, usize, usize), f64>,
1113    term: LatentKernelPrimaryTerm,
1114    scale: f64,
1115) {
1116    if scale == 0.0 || term.coeff == 0.0 {
1117        return;
1118    }
1119    *terms
1120        .entry((term.q_exp, term.qdot_power, term.tau_exp, term.k))
1121        .or_insert(0.0) += scale * term.coeff;
1122}
1123
/// Differentiates a kernel term list once along `dir`.
///
/// Returns a new list with like terms merged (keyed by
/// `(q_exp, qdot_power, tau_exp, k)`, in `BTreeMap` key order) and exactly-zero
/// coefficients dropped.
///
/// Write each term `T` as `coeff · exp(q)^q_exp · qdot^qdot_power ·
/// s^tau_exp · K_k`, the form `latent_kernel_sum_log_jet` evaluates, with
/// `s = exp(log_sigma_factor)` and `K_k = exp(bundle.get(k))`. Each rule below
/// is scaled by the matching direction component:
///
/// - `dq`:  `q_exp · T`, plus `-T` shifted to `q_exp + 1, k + 1`;
/// - `dmu`: `k · T`, plus `-T` shifted to `q_exp + 1, k + 1`;
/// - `dtau`: `tau_exp · T`, plus three terms at `tau_exp + 2`:
///   `k² · T`, `-(2k + 1) · T` shifted to `q_exp + 1, k + 1`, and `T` shifted
///   to `q_exp + 2, k + 2`;
/// - `dqd`: `qdot_power · T` shifted to `qdot_power - 1`.
///
/// NOTE(review): these rules are consistent with `K_k` being the
/// Gaussian-frailty moment `∫ e^{k u} exp(-e^{q+u}) φ(u; mu, sigma) du`, with
/// the `dtau` rule following from Stein's identity when `s = sigma`. That
/// definition lives in `log_kernel_bundle` and is not visible here; confirm
/// against it before changing any rule.
fn latent_kernel_differentiate_terms(
    terms: &[LatentKernelPrimaryTerm],
    dir: LatentKernelPrimaryDirection,
) -> Vec<LatentKernelPrimaryTerm> {
    // Ordered map: merged output (and float accumulation order) is deterministic.
    let mut out = BTreeMap::<(usize, usize, usize, usize), f64>::new();
    for term in terms {
        // d/dq: power rule on exp(q)^q_exp, plus the kernel-index shift k -> k+1
        // that also gains one factor of exp(q).
        if dir.dq != 0.0 {
            if term.q_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dq * term.q_exp as f64);
            }
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dq,
            );
        }
        // d/dmu: k·T from the kernel index, plus the same k -> k+1 shift as dq.
        if dir.dmu != 0.0 {
            if term.k > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dmu * term.k as f64);
            }
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dmu,
            );
        }
        // d/dtau: power rule on s^tau_exp, plus a second-order (tau_exp + 2)
        // expansion over kernel indices k, k+1, k+2.
        if dir.dtau != 0.0 {
            if term.tau_exp > 0 {
                latent_kernel_accumulate_term(&mut out, *term, dir.dtau * term.tau_exp as f64);
            }
            let kf = term.k as f64;
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    tau_exp: term.tau_exp + 2,
                    ..*term
                },
                dir.dtau * kf * kf,
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 1,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 1,
                    ..*term
                },
                -dir.dtau * (2.0 * kf + 1.0),
            );
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    q_exp: term.q_exp + 2,
                    tau_exp: term.tau_exp + 2,
                    k: term.k + 2,
                    ..*term
                },
                dir.dtau,
            );
        }
        // d/dqdot: plain power rule; terms without a qdot factor are constant.
        if dir.dqd != 0.0 && term.qdot_power > 0 {
            latent_kernel_accumulate_term(
                &mut out,
                LatentKernelPrimaryTerm {
                    qdot_power: term.qdot_power - 1,
                    ..*term
                },
                dir.dqd * term.qdot_power as f64,
            );
        }
    }
    // Drop slots whose merged contributions cancelled exactly.
    out.into_iter()
        .filter_map(|((q_exp, qdot_power, tau_exp, k), coeff)| {
            (coeff != 0.0).then_some(LatentKernelPrimaryTerm {
                coeff,
                q_exp,
                qdot_power,
                tau_exp,
                k,
            })
        })
        .collect()
}
1215
1216fn latent_kernel_term_lists_for_directions(
1217    base_terms: &[LatentKernelPrimaryTerm],
1218    directions: &[LatentKernelPrimaryDirection],
1219) -> Vec<Vec<LatentKernelPrimaryTerm>> {
1220    fn build_mask(
1221        mask: usize,
1222        base_terms: &[LatentKernelPrimaryTerm],
1223        directions: &[LatentKernelPrimaryDirection],
1224        cache: &mut [Option<Vec<LatentKernelPrimaryTerm>>],
1225    ) -> Vec<LatentKernelPrimaryTerm> {
1226        if let Some(existing) = &cache[mask] {
1227            return existing.clone();
1228        }
1229        let built = if mask == 0 {
1230            base_terms.to_vec()
1231        } else {
1232            let bit = 1usize << mask.trailing_zeros();
1233            let prev = build_mask(mask ^ bit, base_terms, directions, cache);
1234            latent_kernel_differentiate_terms(&prev, directions[bit.trailing_zeros() as usize])
1235        };
1236        cache[mask] = Some(built.clone());
1237        built
1238    }
1239
1240    let mut cache = vec![None; 1usize << directions.len()];
1241    (0..cache.len())
1242        .map(|mask| build_mask(mask, base_terms, directions, &mut cache))
1243        .collect()
1244}
1245
/// Evaluates `log S`, where `S` is the sum of `base_terms` at `state`,
/// together with its mixed directional derivatives along `directions`, as a
/// multi-direction jet.
///
/// Each term list from `latent_kernel_term_lists_for_directions` (one per
/// direction subset) is summed in signed log space via `signed_log_sum_exp`.
/// A term contributes a log magnitude of
/// `ln|coeff| + q_exp·q + tau_exp·log_sigma_factor + qdot_power·ln(qdot) + bundle.get(k)`.
///
/// Mask 0 (the undifferentiated sum `S`) must be finite and positive. Every
/// other mask is stored relative to it, as `S_mask / S`, in a jet whose
/// constant part is 1. Composing that jet with `log` at 1.0 and adding `ln S`
/// back gives the jet of `log S`, while keeping the arithmetic normalised
/// whatever the magnitude of `S`.
///
/// # Errors
/// `NumericalFailure` if `log_kernel_bundle` fails, if a term with
/// `qdot_power > 0` meets a non-positive or non-finite `qdot`, or if the base
/// sum is not finite and positive.
fn latent_kernel_sum_log_jet(
    quadctx: &QuadratureContext,
    base_terms: &[LatentKernelPrimaryTerm],
    state: LatentKernelPrimaryState,
    directions: &[LatentKernelPrimaryDirection],
    context: &str,
) -> Result<LatentMultiDirJet, LatentSurvivalError> {
    // One term list per subset (bit mask) of `directions`; index 0 is the base.
    let term_lists = latent_kernel_term_lists_for_directions(base_terms, directions);
    // The bundle must cover the largest kernel index any derivative list uses.
    let max_k = term_lists
        .iter()
        .flat_map(|terms| terms.iter().map(|term| term.k))
        .max()
        .unwrap_or(0);
    let bundle =
        log_kernel_bundle(quadctx, state.q.exp(), state.mu, state.sigma, max_k).map_err(|e| {
            LatentSurvivalError::NumericalFailure {
                reason: format!("{context} kernel evaluation failed: {e}"),
            }
        })?;

    // Signed log-sum of one term list. An empty list is an exact zero,
    // reported as (-inf, 0).
    let evaluate_terms =
        |terms: &[LatentKernelPrimaryTerm]| -> Result<(f64, f64), LatentSurvivalError> {
            let mut log_mags = Vec::new();
            let mut signs = Vec::new();
            for term in terms {
                if term.coeff == 0.0 {
                    continue;
                }
                // ln(qdot) is only defined for qdot > 0: fail loudly instead of
                // producing NaN.
                if term.qdot_power > 0 && !(state.qdot.is_finite() && state.qdot > 0.0) {
                    return Err(LatentSurvivalError::NumericalFailure {
                        reason: format!(
                            "{context} requires positive finite qdot for exact-event directional terms, got {}",
                            state.qdot
                        ),
                    });
                }
                let log_qdot = if term.qdot_power > 0 {
                    state.qdot.ln()
                } else {
                    0.0
                };
                let log_mag = term.coeff.abs().ln()
                    + term.q_exp as f64 * state.q
                    + term.tau_exp as f64 * state.log_sigma_factor
                    + term.qdot_power as f64 * log_qdot
                    + bundle.get(term.k);
                log_mags.push(log_mag);
                signs.push(term.coeff.signum());
            }
            if log_mags.is_empty() {
                return Ok((f64::NEG_INFINITY, 0.0));
            }
            Ok(signed_log_sum_exp(&log_mags, &signs))
        };

    // The base sum is the argument of the final log, so it must be positive.
    let (base_log_sum, base_sign) = evaluate_terms(&term_lists[0])?;
    if !(base_log_sum.is_finite() && base_sign > 0.0) {
        return Err(LatentSurvivalError::NumericalFailure {
            reason: format!("{context} produced a non-positive signed kernel sum"),
        });
    }

    // Each derivative is stored relative to the base sum (S_mask / S).
    // Zero-valued lists stay exactly 0.
    let mut normalized = LatentMultiDirJet::constant(directions.len(), 1.0);
    for mask in 1..term_lists.len() {
        let (log_abs, sign) = evaluate_terms(&term_lists[mask])?;
        normalized.coeffs[mask] = if !log_abs.is_finite() || sign == 0.0 {
            0.0
        } else {
            sign * (log_abs - base_log_sum).exp()
        };
    }

    // log S = log(S / S_base) + log S_base. The jet's constant part is 1, so
    // log is composed at 1.0 and ln S_base is added to the value coefficient.
    let mut out = normalized.compose_unary(latent_unary_derivatives_log(1.0));
    out.coeffs[0] += base_log_sum;
    Ok(out)
}
1322
1323fn latent_survival_basis_direction(primary_idx: usize) -> LatentSurvivalPrimaryDirection {
1324    match primary_idx {
1325        LATENT_SURVIVAL_PRIMARY_Q_ENTRY => LatentSurvivalPrimaryDirection {
1326            dq_entry: 1.0,
1327            dq_exit: 0.0,
1328            dqdot_exit: 0.0,
1329            dq_right: 0.0,
1330            dmu: 0.0,
1331            dlog_sigma: 0.0,
1332        },
1333        LATENT_SURVIVAL_PRIMARY_Q_EXIT => LatentSurvivalPrimaryDirection {
1334            dq_entry: 0.0,
1335            dq_exit: 1.0,
1336            dqdot_exit: 0.0,
1337            dq_right: 0.0,
1338            dmu: 0.0,
1339            dlog_sigma: 0.0,
1340        },
1341        LATENT_SURVIVAL_PRIMARY_QDOT_EXIT => LatentSurvivalPrimaryDirection {
1342            dq_entry: 0.0,
1343            dq_exit: 0.0,
1344            dqdot_exit: 1.0,
1345            dq_right: 0.0,
1346            dmu: 0.0,
1347            dlog_sigma: 0.0,
1348        },
1349        LATENT_SURVIVAL_PRIMARY_Q_RIGHT => LatentSurvivalPrimaryDirection {
1350            dq_entry: 0.0,
1351            dq_exit: 0.0,
1352            dqdot_exit: 0.0,
1353            dq_right: 1.0,
1354            dmu: 0.0,
1355            dlog_sigma: 0.0,
1356        },
1357        LATENT_SURVIVAL_PRIMARY_MU => LatentSurvivalPrimaryDirection {
1358            dq_entry: 0.0,
1359            dq_exit: 0.0,
1360            dqdot_exit: 0.0,
1361            dq_right: 0.0,
1362            dmu: 1.0,
1363            dlog_sigma: 0.0,
1364        },
1365        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA => LatentSurvivalPrimaryDirection {
1366            dq_entry: 0.0,
1367            dq_exit: 0.0,
1368            dqdot_exit: 0.0,
1369            dq_right: 0.0,
1370            dmu: 0.0,
1371            dlog_sigma: 1.0,
1372        },
1373        // SAFETY: latent survival has exactly `LATENT_SURVIVAL_PRIMARY_DIM`
1374        // (= 5) primary directions, indexed 0..=4 via the module-private
1375        // `LATENT_SURVIVAL_PRIMARY_*` constants. All five are matched
1376        // above, so this wildcard fires only on an out-of-range index,
1377        // which the internal iteration bounds (`0..LATENT_SURVIVAL_PRIMARY_DIM`)
1378        // make unreachable.
1379        // SAFETY: primary_idx is bounded by LATENT_SURVIVAL_PRIMARY_DIM at every internal call site.
1380        _ => std::panic::panic_any(format!(
1381            "latent survival primary index out of bounds: primary_idx={primary_idx}, primary_dim={LATENT_SURVIVAL_PRIMARY_DIM}"
1382        )),
1383    }
1384}
1385
1386fn latent_survival_map_entry_direction(
1387    direction: LatentSurvivalPrimaryDirection,
1388) -> LatentKernelPrimaryDirection {
1389    LatentKernelPrimaryDirection {
1390        dq: direction.dq_entry,
1391        dqd: 0.0,
1392        dmu: direction.dmu,
1393        dtau: direction.dlog_sigma,
1394    }
1395}
1396
1397fn latent_survival_map_exit_direction(
1398    direction: LatentSurvivalPrimaryDirection,
1399    event_type: LatentSurvivalEventType,
1400) -> LatentKernelPrimaryDirection {
1401    LatentKernelPrimaryDirection {
1402        dq: direction.dq_exit,
1403        dqd: if matches!(event_type, LatentSurvivalEventType::ExactEvent) {
1404            direction.dqdot_exit
1405        } else {
1406            0.0
1407        },
1408        dmu: direction.dmu,
1409        dtau: direction.dlog_sigma,
1410    }
1411}
1412
1413/// Direction map for the interval-censored LEFT boundary state (mass `M_L =
1414/// exp(q_exit)`). The left boundary tracks the same `q_exit` time functional as
1415/// right-censoring (no hazard-derivative channel), plus the shared `mu`/`sigma`.
1416fn latent_survival_map_left_direction(
1417    direction: LatentSurvivalPrimaryDirection,
1418) -> LatentKernelPrimaryDirection {
1419    LatentKernelPrimaryDirection {
1420        dq: direction.dq_exit,
1421        dqd: 0.0,
1422        dmu: direction.dmu,
1423        dtau: direction.dlog_sigma,
1424    }
1425}
1426
1427/// Direction map for the interval-censored RIGHT boundary state (mass `M_R =
1428/// exp(q_right)`). The right boundary tracks the dedicated `q_right` functional
1429/// (which shares the time-block coefficients with `q_exit` but is evaluated at
1430/// the interval upper bound `R`), plus the shared `mu`/`sigma`.
1431fn latent_survival_map_right_direction(
1432    direction: LatentSurvivalPrimaryDirection,
1433) -> LatentKernelPrimaryDirection {
1434    LatentKernelPrimaryDirection {
1435        dq: direction.dq_right,
1436        dqd: 0.0,
1437        dmu: direction.dmu,
1438        dtau: direction.dlog_sigma,
1439    }
1440}
1441
1442fn latent_survival_row_primary_log_jet(
1443    quadctx: &QuadratureContext,
1444    row: &LatentSurvivalRow,
1445    q_entry: f64,
1446    q_exit: f64,
1447    qdot_exit: f64,
1448    q_right: f64,
1449    mu: f64,
1450    sigma: f64,
1451    log_sigma_factor: f64,
1452    directions: &[LatentSurvivalPrimaryDirection],
1453) -> Result<LatentMultiDirJet, String> {
1454    let entry_state = LatentKernelPrimaryState {
1455        q: q_entry,
1456        qdot: 1.0,
1457        mu,
1458        sigma,
1459        log_sigma_factor,
1460    };
1461    let entry_directions = directions
1462        .iter()
1463        .copied()
1464        .map(latent_survival_map_entry_direction)
1465        .collect::<Vec<_>>();
1466
1467    let denominator = latent_kernel_sum_log_jet(
1468        quadctx,
1469        &[LatentKernelPrimaryTerm {
1470            coeff: 1.0,
1471            q_exp: 0,
1472            qdot_power: 0,
1473            tau_exp: 0,
1474            k: 0,
1475        }],
1476        entry_state,
1477        &entry_directions,
1478        "latent survival denominator",
1479    )?;
1480
1481    // The numerator for right-censoring / exact events is a single-state log-sum
1482    // kernel at the exit mass. Interval censoring is the difference of two
1483    // single-state kernels at DIFFERENT masses (L at `q_exit`, R at `q_right`),
1484    // so it is assembled by `latent_survival_interval_numerator_log_jet` below.
1485    let numerator = match row.event_type {
1486        LatentSurvivalEventType::RightCensored | LatentSurvivalEventType::ExactEvent => {
1487            let exit_state = LatentKernelPrimaryState {
1488                q: q_exit,
1489                qdot: qdot_exit,
1490                mu,
1491                sigma,
1492                log_sigma_factor,
1493            };
1494            let exit_directions = directions
1495                .iter()
1496                .copied()
1497                .map(|dir| latent_survival_map_exit_direction(dir, row.event_type))
1498                .collect::<Vec<_>>();
1499            let numerator_terms = match row.event_type {
1500                LatentSurvivalEventType::RightCensored => vec![LatentKernelPrimaryTerm {
1501                    coeff: 1.0,
1502                    q_exp: 0,
1503                    qdot_power: 0,
1504                    tau_exp: 0,
1505                    k: 0,
1506                }],
1507                LatentSurvivalEventType::ExactEvent => {
1508                    let mut terms = Vec::new();
1509                    if row.hazard_unloaded > 0.0 {
1510                        terms.push(LatentKernelPrimaryTerm {
1511                            coeff: row.hazard_unloaded,
1512                            q_exp: 0,
1513                            qdot_power: 0,
1514                            tau_exp: 0,
1515                            k: 0,
1516                        });
1517                    }
1518                    terms.push(LatentKernelPrimaryTerm {
1519                        coeff: 1.0,
1520                        q_exp: 1,
1521                        qdot_power: 1,
1522                        tau_exp: 0,
1523                        k: 1,
1524                    });
1525                    terms
1526                }
1527                LatentSurvivalEventType::IntervalCensored => {
1528                    // Interval-censored rows are routed to the dedicated two-state
1529                    // numerator branch (the outer match arm below), so this inner
1530                    // arm is not reached; a clean error rather than a panic guards
1531                    // against a future routing change.
1532                    return Err(
1533                        "interval-censored row reached the single-state numerator branch; \
1534                         it must take the dedicated two-state branch"
1535                            .to_string(),
1536                    );
1537                }
1538            };
1539            latent_kernel_sum_log_jet(
1540                quadctx,
1541                &numerator_terms,
1542                exit_state,
1543                &exit_directions,
1544                "latent survival numerator",
1545            )?
1546        }
1547        LatentSurvivalEventType::IntervalCensored => latent_survival_interval_numerator_log_jet(
1548            quadctx,
1549            row,
1550            q_exit,
1551            q_right,
1552            mu,
1553            sigma,
1554            log_sigma_factor,
1555            directions,
1556        )?,
1557    };
1558
1559    let mut total = numerator.add(&denominator.scale(-1.0));
1560    // For interval rows the unloaded exit mass is folded into the per-boundary
1561    // coefficients `exp(-mass_unloaded_{left,right})` inside the two-state
1562    // numerator, so only the (constant) unloaded-entry term remains here; for
1563    // right-censoring / exact events the exit/entry unloaded masses are an
1564    // additive constant on the log-likelihood.
1565    match row.event_type {
1566        LatentSurvivalEventType::IntervalCensored => {
1567            total.coeffs[0] += row.mass_unloaded_entry;
1568        }
1569        _ => {
1570            total.coeffs[0] += -row.mass_unloaded_exit + row.mass_unloaded_entry;
1571        }
1572    }
1573    Ok(total)
1574}
1575
1576/// Interval-censored numerator jet `log[ c_L·K_{0,M_L} − c_R·K_{0,M_R} ]` where
1577/// `M_L = exp(q_exit)`, `M_R = exp(q_right)`, `c_L = exp(-mass_unloaded_left)`
1578/// and `c_R = exp(-mass_unloaded_right)`.
1579///
1580/// This is the dynamic-time analogue of the static
1581/// [`LatentSurvivalRowJet::interval_censored`] kernel: the interval likelihood
1582/// is the difference of two BOUNDARY survival masses, each a single-state
1583/// order-0 kernel, but at two DISTINCT cumulative masses. Because the two
1584/// boundaries respond to different time functionals (`q_exit` vs `q_right`) we
1585/// cannot fold them into one `latent_kernel_sum_log_jet` state. Instead we:
1586///   1. build each boundary's `log K_{0,M}` jet at its own state, with its own
1587///      direction map (left → `dq_exit`, right → `dq_right`; both share
1588///      `mu`/`sigma`),
1589///   2. lift each to the LINEAR domain via `exp` (a unary composition whose five
1590///      derivatives at value `v` are all `exp(v)`), scaled by its coefficient
1591///      `c_L` (resp. `−c_R`),
1592///   3. add the two linear-domain jets, and
1593///   4. drop back to the log domain via the same `log` unary composition the
1594///      single-state path uses.
1595/// Every multi-direction coefficient (value, score, neg-Hessian, 3rd, 4th)
1596/// follows by the Faà-di-Bruno composition already implemented in
1597/// `MultiDirJet::compose_unary`, so the derivative reductions are consistent
1598/// with the exact-event/right-censored branches by construction.
1599fn latent_survival_interval_numerator_log_jet(
1600    quadctx: &QuadratureContext,
1601    row: &LatentSurvivalRow,
1602    q_exit: f64,
1603    q_right: f64,
1604    mu: f64,
1605    sigma: f64,
1606    log_sigma_factor: f64,
1607    directions: &[LatentSurvivalPrimaryDirection],
1608) -> Result<LatentMultiDirJet, String> {
1609    let single_k0 = [LatentKernelPrimaryTerm {
1610        coeff: 1.0,
1611        q_exp: 0,
1612        qdot_power: 0,
1613        tau_exp: 0,
1614        k: 0,
1615    }];
1616
1617    let left_state = LatentKernelPrimaryState {
1618        q: q_exit,
1619        qdot: 1.0,
1620        mu,
1621        sigma,
1622        log_sigma_factor,
1623    };
1624    let right_state = LatentKernelPrimaryState {
1625        q: q_right,
1626        qdot: 1.0,
1627        mu,
1628        sigma,
1629        log_sigma_factor,
1630    };
1631    let left_directions = directions
1632        .iter()
1633        .copied()
1634        .map(latent_survival_map_left_direction)
1635        .collect::<Vec<_>>();
1636    let right_directions = directions
1637        .iter()
1638        .copied()
1639        .map(latent_survival_map_right_direction)
1640        .collect::<Vec<_>>();
1641
1642    let log_left = latent_kernel_sum_log_jet(
1643        quadctx,
1644        &single_k0,
1645        left_state,
1646        &left_directions,
1647        "latent survival interval left boundary",
1648    )?;
1649    let log_right = latent_kernel_sum_log_jet(
1650        quadctx,
1651        &single_k0,
1652        right_state,
1653        &right_directions,
1654        "latent survival interval right boundary",
1655    )?;
1656
1657    // Lift each boundary's log-kernel jet to the linear domain and scale by the
1658    // unloaded-mass prefactor. exp''''(v) = exp(v) for all orders, so the unary
1659    // derivative tower is `[exp(v); exp(v); exp(v); exp(v); exp(v)]`.
1660    let c_left = (-row.mass_unloaded_left).exp();
1661    let c_right = (-row.mass_unloaded_right).exp();
1662    let exp_left_value = log_left.coeff(0).exp();
1663    let exp_right_value = log_right.coeff(0).exp();
1664    let linear_left = log_left.compose_unary([exp_left_value; 5]).scale(c_left);
1665    let linear_right = log_right.compose_unary([exp_right_value; 5]).scale(c_right);
1666
1667    let linear_numerator = linear_left.add(&linear_right.scale(-1.0));
1668    let base = linear_numerator.coeff(0);
1669    if !(base.is_finite() && base > 0.0) {
1670        return Err(LatentSurvivalError::NumericalFailure {
1671            reason: format!(
1672                "latent survival interval numerator must be a positive survival-mass difference, \
1673                 got c_L*K0(M_L) - c_R*K0(M_R) = {base}; require M_L < M_R (i.e. L < R)"
1674            ),
1675        }
1676        .into());
1677    }
1678    // Drop back to the log domain. `latent_unary_derivatives_log(base)` is the
1679    // unary derivative tower of `ln` at the positive base value, so the composed
1680    // value channel is `ln(base)` and the higher coefficients are the
1681    // log-of-a-difference score / curvature, consistent with the single-state
1682    // log-sum path (which composes `ln` at its normalised base of 1).
1683    Ok(linear_numerator.compose_unary(latent_unary_derivatives_log(base)))
1684}
1685
1686fn latent_survival_row_primary_gradient_hessian(
1687    quadctx: &QuadratureContext,
1688    row: &LatentSurvivalRow,
1689    q_entry: f64,
1690    q_exit: f64,
1691    qdot_exit: f64,
1692    q_right: f64,
1693    mu: f64,
1694    sigma: f64,
1695    include_log_sigma: bool,
1696) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
1697    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1698    let mut gradient = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1699    let mut neg_hessian =
1700        Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1701    let active_primary = if include_log_sigma {
1702        LATENT_SURVIVAL_PRIMARY_DIM
1703    } else {
1704        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1705    };
1706    let log_lik = latent_survival_row_primary_log_jet(
1707        quadctx,
1708        row,
1709        q_entry,
1710        q_exit,
1711        qdot_exit,
1712        q_right,
1713        mu,
1714        sigma,
1715        log_sigma_factor,
1716        &[],
1717    )?
1718    .coeff(0);
1719    for a in 0..active_primary {
1720        let dir_a = latent_survival_basis_direction(a);
1721        gradient[a] = latent_survival_row_primary_log_jet(
1722            quadctx,
1723            row,
1724            q_entry,
1725            q_exit,
1726            qdot_exit,
1727            q_right,
1728            mu,
1729            sigma,
1730            log_sigma_factor,
1731            &[dir_a],
1732        )?
1733        .coeff(1);
1734        for b in a..active_primary {
1735            let coeff = latent_survival_row_primary_log_jet(
1736                quadctx,
1737                row,
1738                q_entry,
1739                q_exit,
1740                qdot_exit,
1741                q_right,
1742                mu,
1743                sigma,
1744                log_sigma_factor,
1745                &[dir_a, latent_survival_basis_direction(b)],
1746            )?
1747            .coeff(3);
1748            neg_hessian[[a, b]] = -coeff;
1749            neg_hessian[[b, a]] = -coeff;
1750        }
1751    }
1752    Ok((log_lik, gradient, neg_hessian))
1753}
1754
1755fn latent_survival_row_primary_third_contracted(
1756    quadctx: &QuadratureContext,
1757    row: &LatentSurvivalRow,
1758    q_entry: f64,
1759    q_exit: f64,
1760    qdot_exit: f64,
1761    q_right: f64,
1762    mu: f64,
1763    sigma: f64,
1764    direction: &Array1<f64>,
1765    include_log_sigma: bool,
1766) -> Result<Array2<f64>, String> {
1767    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1768    let active_primary = if include_log_sigma {
1769        LATENT_SURVIVAL_PRIMARY_DIM
1770    } else {
1771        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1772    };
1773    let dir = LatentSurvivalPrimaryDirection {
1774        dq_entry: direction[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1775        dq_exit: direction[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1776        dqdot_exit: direction[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1777        dq_right: direction[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1778        dmu: direction[LATENT_SURVIVAL_PRIMARY_MU],
1779        dlog_sigma: direction[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1780    };
1781    let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1782    for a in 0..active_primary {
1783        let dir_a = latent_survival_basis_direction(a);
1784        for b in a..active_primary {
1785            let coeff = latent_survival_row_primary_log_jet(
1786                quadctx,
1787                row,
1788                q_entry,
1789                q_exit,
1790                qdot_exit,
1791                q_right,
1792                mu,
1793                sigma,
1794                log_sigma_factor,
1795                &[dir_a, latent_survival_basis_direction(b), dir],
1796            )?
1797            .coeff(7);
1798            out[[a, b]] = -coeff;
1799            out[[b, a]] = -coeff;
1800        }
1801    }
1802    Ok(out)
1803}
1804
1805fn latent_survival_row_primary_fourth_contracted(
1806    quadctx: &QuadratureContext,
1807    row: &LatentSurvivalRow,
1808    q_entry: f64,
1809    q_exit: f64,
1810    qdot_exit: f64,
1811    q_right: f64,
1812    mu: f64,
1813    sigma: f64,
1814    direction_u: &Array1<f64>,
1815    direction_v: &Array1<f64>,
1816    include_log_sigma: bool,
1817) -> Result<Array2<f64>, String> {
1818    let log_sigma_factor = if sigma > 0.0 { sigma.ln() } else { 0.0 };
1819    let active_primary = if include_log_sigma {
1820        LATENT_SURVIVAL_PRIMARY_DIM
1821    } else {
1822        LATENT_SURVIVAL_PRIMARY_LOG_SIGMA
1823    };
1824    let dir_u = LatentSurvivalPrimaryDirection {
1825        dq_entry: direction_u[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1826        dq_exit: direction_u[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1827        dqdot_exit: direction_u[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1828        dq_right: direction_u[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1829        dmu: direction_u[LATENT_SURVIVAL_PRIMARY_MU],
1830        dlog_sigma: direction_u[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1831    };
1832    let dir_v = LatentSurvivalPrimaryDirection {
1833        dq_entry: direction_v[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
1834        dq_exit: direction_v[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
1835        dqdot_exit: direction_v[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
1836        dq_right: direction_v[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
1837        dmu: direction_v[LATENT_SURVIVAL_PRIMARY_MU],
1838        dlog_sigma: direction_v[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA],
1839    };
1840    let mut out = Array2::<f64>::zeros((LATENT_SURVIVAL_PRIMARY_DIM, LATENT_SURVIVAL_PRIMARY_DIM));
1841    for a in 0..active_primary {
1842        let dir_a = latent_survival_basis_direction(a);
1843        for b in a..active_primary {
1844            let coeff = latent_survival_row_primary_log_jet(
1845                quadctx,
1846                row,
1847                q_entry,
1848                q_exit,
1849                qdot_exit,
1850                q_right,
1851                mu,
1852                sigma,
1853                log_sigma_factor,
1854                &[dir_a, latent_survival_basis_direction(b), dir_u, dir_v],
1855            )?
1856            .coeff(15);
1857            out[[a, b]] = -coeff;
1858            out[[b, a]] = -coeff;
1859        }
1860    }
1861    Ok(out)
1862}
1863
/// Column layout of the flat joint coefficient vector: the time block first,
/// then the mean block, then (only when the latent SD is learned) a single
/// log-sigma column.
#[derive(Clone)]
struct LatentSurvivalJointSlices {
    /// Columns of the time/baseline coefficients (width = `x_time_exit.ncols()`).
    time: std::ops::Range<usize>,
    /// Columns multiplying the `x_mean` design (the latent mean channel).
    mean: std::ops::Range<usize>,
    /// One-column log-sigma range, or `None` when the latent SD is fixed.
    log_sigma: Option<std::ops::Range<usize>>,
    /// Total number of joint coefficients (end of the last present range).
    total: usize,
}
1871
/// Per-chunk accumulator for the joint log-likelihood and its gradient over
/// the flat joint coefficient vector.
#[derive(Clone)]
struct LatentSurvivalJointGradientAccum {
    /// Running row-weighted log-likelihood `Σ w_i ℓ_i`.
    ll: f64,
    /// Running gradient of `ll`, laid out per `LatentSurvivalJointSlices`.
    gradient: Array1<f64>,
}
1877
/// Accumulator carrying a log-likelihood, gradient and dense Hessian together.
// NOTE(review): its consumer is outside this view. Presumably this is the dense
// gradient+Hessian analogue of `LatentSurvivalJointGradientAccum`; confirm the
// Hessian sign convention (log-likelihood vs negated) at the use site.
#[derive(Clone)]
struct LatentSurvivalJointDenseAccum {
    ll: f64,
    gradient: Array1<f64>,
    hessian: Array2<f64>,
}
1884
/// Accumulator carrying only a dense joint Hessian.
// NOTE(review): its consumer is outside this view; confirm the sign convention
// and the per-row weighting applied at the use site.
#[derive(Clone)]
struct LatentSurvivalDenseHessianAccum {
    hessian: Array2<f64>,
}
1889
1890/// Process latent-survival rows in fixed contiguous chunks, using one
1891/// accumulator per rayon task and reducing those accumulators in chunk-index
1892/// order so gradient/Hessian assembly stays deterministic across runs.
1893fn deterministic_latent_survival_row_reduction<Acc, Init, Process, Combine>(
1894    n_rows: usize,
1895    init: Init,
1896    process_row: Process,
1897    mut combine: Combine,
1898) -> Result<Acc, String>
1899where
1900    Acc: Send,
1901    Init: Fn() -> Acc + Sync,
1902    Process: Fn(usize, &mut Acc) -> Result<(), String> + Sync,
1903    Combine: FnMut(&mut Acc, Acc),
1904{
1905    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1906
1907    const TARGET_CHUNK_COUNT: usize = 32;
1908    if n_rows == 0 {
1909        return Ok(init());
1910    }
1911    let chunk_size = n_rows.div_ceil(TARGET_CHUNK_COUNT).max(1);
1912    let n_chunks = n_rows.div_ceil(chunk_size);
1913    let chunk_accumulators: Vec<Acc> = (0..n_chunks)
1914        .into_par_iter()
1915        .map(|chunk_idx| -> Result<Acc, String> {
1916            let start = chunk_idx * chunk_size;
1917            let end = (start + chunk_size).min(n_rows);
1918            let mut acc = init();
1919            for row_idx in start..end {
1920                process_row(row_idx, &mut acc)?;
1921            }
1922            Ok(acc)
1923        })
1924        .collect::<Result<Vec<_>, String>>()?;
1925
1926    let mut total = init();
1927    for acc in chunk_accumulators {
1928        combine(&mut total, acc);
1929    }
1930    Ok(total)
1931}
1932
1933impl LatentSurvivalFamily {
1934    /// Assemble the per-row [`LatentSurvivalRow`] for `row_idx` from the family's
1935    /// unloaded-mass/hazard fields and the supplied per-row time quantiles.
1936    ///
1937    /// Shared by every per-row reduction (log-likelihood, gradient, Hessian,
1938    /// directional third derivatives): each previously inlined an identical
1939    /// `event_type` lookup followed by the same 12-argument
1940    /// `build_latent_survival_row` call. Behavior is unchanged.
1941    fn build_row_at(
1942        &self,
1943        row_idx: usize,
1944        q_entry: f64,
1945        q_exit: f64,
1946        qdot_exit: f64,
1947        q_right: f64,
1948    ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
1949        let event_type = latent_survival_event_type_for(self.event_target[row_idx]);
1950        build_latent_survival_row(
1951            row_idx,
1952            self.hazard_loading,
1953            event_type,
1954            q_entry,
1955            q_exit,
1956            qdot_exit,
1957            q_right,
1958            self.unloaded_mass_entry[row_idx],
1959            self.unloaded_mass_exit[row_idx],
1960            self.unloaded_mass_right[row_idx],
1961            self.unloaded_hazard_exit[row_idx],
1962        )
1963    }
1964
1965    fn joint_slices(&self) -> LatentSurvivalJointSlices {
1966        let p_time = self.x_time_exit.ncols();
1967        let p_mean = self.x_mean.ncols();
1968        let time = 0..p_time;
1969        let mean = p_time..p_time + p_mean;
1970        let log_sigma = self
1971            .latent_sd_fixed
1972            .is_none()
1973            .then_some((p_time + p_mean)..(p_time + p_mean + 1));
1974        LatentSurvivalJointSlices {
1975            total: log_sigma
1976                .as_ref()
1977                .map_or(p_time + p_mean, |range| range.end),
1978            time,
1979            mean,
1980            log_sigma,
1981        }
1982    }
1983
1984    fn row_primary_direction_from_flat(
1985        &self,
1986        row: usize,
1987        slices: &LatentSurvivalJointSlices,
1988        d_beta_flat: &Array1<f64>,
1989    ) -> Array1<f64> {
1990        let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
1991        let d_time = d_beta_flat.slice(s![slices.time.clone()]);
1992        out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
1993        out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
1994        out[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT] = self.x_time_derivative_exit.row(row).dot(&d_time);
1995        out[LATENT_SURVIVAL_PRIMARY_Q_RIGHT] = self.x_time_right.row(row).dot(&d_time);
1996        out[LATENT_SURVIVAL_PRIMARY_MU] = self
1997            .x_mean
1998            .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
1999        if let Some(range) = &slices.log_sigma {
2000            out[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA] = d_beta_flat[range.start];
2001        }
2002        out
2003    }
2004
2005    fn joint_block_ranges(&self) -> Vec<std::ops::Range<usize>> {
2006        let slices = self.joint_slices();
2007        let mut ranges = vec![slices.time.clone(), slices.mean.clone()];
2008        if let Some(log_sigma) = slices.log_sigma {
2009            ranges.push(log_sigma);
2010        }
2011        ranges
2012    }
2013
2014    fn add_pullback_primary_gradient(
2015        &self,
2016        target: &mut Array1<f64>,
2017        row: usize,
2018        slices: &LatentSurvivalJointSlices,
2019        primary_gradient: &Array1<f64>,
2020        weight: f64,
2021    ) -> Result<(), String> {
2022        for (primary_idx, time_vec) in [
2023            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
2024            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
2025            (
2026                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2027                self.x_time_derivative_exit.row(row),
2028            ),
2029            (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
2030        ] {
2031            let scale = weight * primary_gradient[primary_idx];
2032            if scale == 0.0 {
2033                continue;
2034            }
2035            for i in 0..time_vec.len() {
2036                let xi = time_vec[i];
2037                if xi != 0.0 {
2038                    target[slices.time.start + i] += scale * xi;
2039                }
2040            }
2041        }
2042
2043        let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
2044        if mean_scale != 0.0 {
2045            self.x_mean
2046                .axpy_row_into(
2047                    row,
2048                    mean_scale,
2049                    &mut target.slice_mut(s![slices.mean.clone()]),
2050                )
2051                .map_err(|error| {
2052                    format!(
2053                        "latent survival mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
2054                        slices.mean,
2055                        target.len(),
2056                        self.x_mean.ncols()
2057                    )
2058                })?;
2059        }
2060
2061        if let Some(log_sigma) = &slices.log_sigma {
2062            target[log_sigma.start] += weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA];
2063        }
2064        Ok(())
2065    }
2066
    /// Scatter a row's primary-space second-derivative matrix into the dense
    /// joint matrix `target` via the row's primary Jacobian (`target += Jᵀ H J`).
    ///
    /// Block by block:
    ///   * time × time: diagonal channel weights via `dense_outer_accumulate`,
    ///     plus all six channel pairs via `dense_symmetric_cross_accumulate`;
    ///   * mean × mean: `H[mu, mu]` times the `x_mean` row outer product
    ///     (`syr_row_into_view`);
    ///   * time × mean: `H[channel, mu]`, written into both off-diagonal blocks;
    ///   * log-sigma (only when learned): identity Jacobian, so `H[ls, ls]` lands
    ///     on the diagonal and the `H[·, ls]` entries are scattered against the
    ///     time and mean rows in both triangles.
    ///
    /// Only one triangle of each off-diagonal primary pair is read (e.g.
    /// `[[a, MU]]`, `[[MU, LOG_SIGMA]]`), so `primary_hessian` is assumed to be
    /// symmetric. NOTE(review): unlike `add_pullback_primary_gradient` there is
    /// no `weight` argument, so the caller presumably pre-scales
    /// `primary_hessian` by the row weight. Confirm the sign convention at the
    /// call site.
    ///
    /// # Errors
    /// Returns a descriptive message if the `x_mean` row cannot be scattered
    /// into, or extracted for, the mean block.
    fn add_pullback_primary_hessian(
        &self,
        target: &mut Array2<f64>,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        primary_hessian: &Array2<f64>,
    ) -> Result<(), String> {
        // Diagonal weights of the four time channels: entry, exit, exit-derivative, right.
        let time_weights = [
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
            ]],
            primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
            ]],
        ];
        // All C(4, 2) = 6 distinct time-channel pairs, each with the two design
        // matrices whose rows form the symmetric cross product.
        let time_cross_weights = [
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                &self.x_time_entry,
                &self.x_time_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                &self.x_time_entry,
                &self.x_time_derivative_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                &self.x_time_exit,
                &self.x_time_derivative_exit,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_entry,
                &self.x_time_right,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_exit,
                &self.x_time_right,
            ),
            (
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
                &self.x_time_derivative_exit,
                &self.x_time_right,
            ),
        ];
        // Scoped so the mutable time × time view is released before the
        // mean-block writes below.
        {
            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
            dense_outer_accumulate(time_target, time_weights[0], self.x_time_entry.row(row));
            dense_outer_accumulate(time_target, time_weights[1], self.x_time_exit.row(row));
            dense_outer_accumulate(
                time_target,
                time_weights[2],
                self.x_time_derivative_exit.row(row),
            );
            dense_outer_accumulate(time_target, time_weights[3], self.x_time_right.row(row));
            for (a, b, lhs, rhs) in time_cross_weights {
                let weight = primary_hessian[[a, b]];
                if weight == 0.0 {
                    continue;
                }
                dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
            }
        }

        // mean × mean block: H[mu, mu] · x_mean_row x_mean_rowᵀ.
        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
        self.x_mean
            .syr_row_into_view(
                row,
                mean_weight,
                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
            )
            .map_err(|error| {
                format!(
                    "latent survival mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
                    slices.mean,
                    target.dim(),
                    self.x_mean.ncols()
                )
            })?;

        // Materialise this row of `x_mean` once for the explicit cross-block loops below.
        let mean_row = self
            .x_mean
            .try_row_chunk(row..row + 1)
            .map_err(|error| {
                format!(
                    "latent survival mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
                    self.x_mean.nrows(),
                    self.x_mean.ncols()
                )
            })?;
        let mean_vec = mean_row.row(0);
        // time × mean cross blocks: H[channel, mu] · time_row ⊗ mean_row, in both triangles.
        let time_mean_weights = [
            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
            (
                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                self.x_time_derivative_exit.row(row),
            ),
            (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
        ];
        for (primary_idx, time_vec) in time_mean_weights {
            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
            if weight == 0.0 {
                continue;
            }
            for i in 0..time_vec.len() {
                let xi = time_vec[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
                }
            }
        }

        // Log-sigma enters the joint vector as a single column with identity
        // Jacobian, so its rows/columns are scattered directly.
        if let Some(log_sigma) = &slices.log_sigma {
            let sigma_idx = log_sigma.start;
            target[[sigma_idx, sigma_idx]] += primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
            ]];

            // time × log-sigma, both triangles.
            for (primary_idx, time_vec) in [
                (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
                (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
                (
                    LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
                    self.x_time_derivative_exit.row(row),
                ),
                (LATENT_SURVIVAL_PRIMARY_Q_RIGHT, self.x_time_right.row(row)),
            ] {
                let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_LOG_SIGMA]];
                if weight == 0.0 {
                    continue;
                }
                for i in 0..time_vec.len() {
                    let xi = time_vec[i];
                    if xi == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, sigma_idx]] += weight * xi;
                    target[[sigma_idx, slices.time.start + i]] += weight * xi;
                }
            }

            // mean × log-sigma, both triangles.
            let mean_sigma_weight = primary_hessian[[
                LATENT_SURVIVAL_PRIMARY_MU,
                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
            ]];
            if mean_sigma_weight != 0.0 {
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.mean.start + j, sigma_idx]] += mean_sigma_weight * xj;
                    target[[sigma_idx, slices.mean.start + j]] += mean_sigma_weight * xj;
                }
            }
        }
        Ok(())
    }
2253
2254    fn evaluate_exact_newton_joint_gradient_dense(
2255        &self,
2256        block_states: &[ParameterBlockState],
2257    ) -> Result<(f64, Array1<f64>), String> {
2258        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2259        let q_right = self.time_q_right(block_states)?;
2260        let sigma = self.latent_sd(block_states)?;
2261        let slices = self.joint_slices();
2262        let include_log_sigma = slices.log_sigma.is_some();
2263        let total = slices.total;
2264        let acc = deterministic_latent_survival_row_reduction(
2265            self.event_target.len(),
2266            || LatentSurvivalJointGradientAccum {
2267                ll: 0.0,
2268                gradient: Array1::<f64>::zeros(total),
2269            },
2270            |row_idx, acc| {
2271                let wi = self.weights[row_idx];
2272                if wi <= MIN_WEIGHT {
2273                    return Ok(());
2274                }
2275                let row = self.build_row_at(
2276                    row_idx,
2277                    q_entry[row_idx],
2278                    q_exit[row_idx],
2279                    qdot_exit[row_idx],
2280                    q_right[row_idx],
2281                )?;
2282                let (row_ll, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2283                    &self.quadctx,
2284                    &row,
2285                    q_entry[row_idx],
2286                    q_exit[row_idx],
2287                    qdot_exit[row_idx],
2288                    q_right[row_idx],
2289                    mu[row_idx],
2290                    sigma,
2291                    include_log_sigma,
2292                )?;
2293                acc.ll += wi * row_ll;
2294                self.add_pullback_primary_gradient(
2295                    &mut acc.gradient,
2296                    row_idx,
2297                    &slices,
2298                    &primary_gradient,
2299                    wi,
2300                )?;
2301                Ok(())
2302            },
2303            |total_acc, chunk_acc| {
2304                total_acc.ll += chunk_acc.ll;
2305                total_acc.gradient += &chunk_acc.gradient;
2306            },
2307        )?;
2308        Ok((acc.ll, acc.gradient))
2309    }
2310
2311    /// Per-row residuals of the unpenalized NLL with respect to the three
2312    /// additive baseline time-block offsets `(entry, exit, derivative)`.
2313    ///
2314    /// The baseline configuration θ enters the latent-survival working model
2315    /// only through the additive offsets on the three time channels
2316    ///   q_entry = x_time_entry·β_time + o_E(θ),
2317    ///   q_exit  = x_time_exit·β_time  + o_X(θ),
2318    ///   q̇_exit = x_time_deriv·β_time + o_D(θ),
2319    /// exactly the offset channel the transformation path carries through
2320    /// [`WorkingModelSurvival::offset_channel_residuals`]. Because
2321    /// `∂q_ch/∂o_ch = 1`, the residual `∂NLL/∂o_ch_i` equals
2322    /// `−∂(log-likelihood)/∂q_ch_i`, and the per-row primary log-likelihood
2323    /// gradient over `(q_entry, q_exit, q̇_exit)` is precisely the
2324    /// `Q_ENTRY`/`Q_EXIT`/`QDOT_EXIT` components returned by
2325    /// [`latent_survival_row_primary_gradient_hessian`]. Sampleweight-scaled to
2326    /// match the [`OffsetChannelResiduals`] contract consumed by
2327    /// `baseline_chain_rule_gradient`.
2328    ///
2329    /// At the converged (constrained) β̂ the envelope theorem makes this the
2330    /// exact θ-gradient of the profile penalized NLL `0.5·deviance + 0.5·βᵀSβ`.
2331    /// The interval upper-bound `q_right = x_time_right·β_time + o_R(θ)` channel
2332    /// DOES carry its own baseline-θ offset `o_R(θ)` (the time basis evaluated at
2333    /// the bracket upper bound `R`), distinct from the exit offset at `L`, so its
2334    /// residual `−∂(log-likelihood)/∂q_right` is returned in the dedicated
2335    /// [`OffsetChannelResiduals::right`] channel; it is exactly 0 on every
2336    /// non-interval row (the `Q_RIGHT` primary channel is inert there) and the
2337    /// baseline-θ chain rule contracts it against the `age_right`-evaluated
2338    /// η-partial.
2339    pub fn offset_channel_residuals(
2340        &self,
2341        block_states: &[ParameterBlockState],
2342    ) -> Result<crate::survival::OffsetChannelResiduals, String> {
2343        let n = self.event_target.len();
2344        if block_states.is_empty() {
2345            // Degraded-fit fallback mirroring the location-scale family: an
2346            // empty block-state slate (ARC deterministic-replay stall) yields
2347            // zero residuals so the outer baseline-θ BFGS sees ‖g‖ = 0 and
2348            // terminates cleanly at the current θ̂ rather than panicking.
2349            log::warn!(
2350                "LatentSurvivalFamily::offset_channel_residuals: block_states is empty \
2351                 (degraded fit); returning zero offset residuals (n={n})"
2352            );
2353            return Ok(crate::survival::OffsetChannelResiduals {
2354                exit: Array1::<f64>::zeros(n),
2355                entry: Array1::<f64>::zeros(n),
2356                derivative: Array1::<f64>::zeros(n),
2357                right: Array1::<f64>::zeros(n),
2358            });
2359        }
2360        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2361        let q_right = self.time_q_right(block_states)?;
2362        let sigma = self.latent_sd(block_states)?;
2363        let include_log_sigma = self.joint_slices().log_sigma.is_some();
2364        let mut entry = Array1::<f64>::zeros(n);
2365        let mut exit = Array1::<f64>::zeros(n);
2366        let mut derivative = Array1::<f64>::zeros(n);
2367        let mut right = Array1::<f64>::zeros(n);
2368        for row_idx in 0..n {
2369            let wi = self.weights[row_idx];
2370            if wi <= MIN_WEIGHT {
2371                continue;
2372            }
2373            let row = self.build_row_at(
2374                row_idx,
2375                q_entry[row_idx],
2376                q_exit[row_idx],
2377                qdot_exit[row_idx],
2378                q_right[row_idx],
2379            )?;
2380            let (_, primary_gradient, _) = latent_survival_row_primary_gradient_hessian(
2381                &self.quadctx,
2382                &row,
2383                q_entry[row_idx],
2384                q_exit[row_idx],
2385                qdot_exit[row_idx],
2386                q_right[row_idx],
2387                mu[row_idx],
2388                sigma,
2389                include_log_sigma,
2390            )?;
2391            // ∂NLL/∂o_ch = −w · ∂(log-likelihood)/∂q_ch.
2392            entry[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
2393            exit[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
2394            derivative[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
2395            // Interval upper-bound (`R`) channel. `q_right` shares the time-block
2396            // coefficients but carries its OWN baseline-θ η-offset evaluated at
2397            // `R` (`o_R(θ)`), so the profile-NLL θ-gradient must include it.
2398            // `∂(log-likelihood)/∂q_right` is exactly 0 for non-interval rows
2399            // (the `Q_RIGHT` channel is inert there), so this is 0 except on
2400            // interval-censored rows.
2401            right[row_idx] = -wi * primary_gradient[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
2402        }
2403        Ok(crate::survival::OffsetChannelResiduals {
2404            exit,
2405            entry,
2406            derivative,
2407            right,
2408        })
2409    }
2410
2411    /// Block-diagonal-only pullback: writes only time-time, mean-mean, and
2412    /// log_sigma-log_sigma rowwise contributions into per-block targets.
2413    /// Used by `evaluate()` to populate per-block working sets without ever
2414    /// materializing the cross blocks the inner solver does not consume.
2415    fn add_pullback_primary_block_diagonals(
2416        &self,
2417        row: usize,
2418        primary_hessian: &Array2<f64>,
2419        time_target: &mut Array2<f64>,
2420        mean_target: &mut Array2<f64>,
2421        log_sigma_target: Option<&mut Array2<f64>>,
2422    ) -> Result<(), String> {
2423        let h = primary_hessian;
2424        // Time block: 4 squared rows (entry/exit/qdot/right) + 6 symmetric
2425        // crosses. The interval right-boundary functional `q_right` shares the
2426        // time-block coefficients, so it accumulates into the same time target.
2427        dense_outer_accumulate(
2428            time_target,
2429            h[[
2430                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2431                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2432            ]],
2433            self.x_time_entry.row(row),
2434        );
2435        dense_outer_accumulate(
2436            time_target,
2437            h[[
2438                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2439                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2440            ]],
2441            self.x_time_exit.row(row),
2442        );
2443        dense_outer_accumulate(
2444            time_target,
2445            h[[
2446                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2447                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2448            ]],
2449            self.x_time_derivative_exit.row(row),
2450        );
2451        dense_outer_accumulate(
2452            time_target,
2453            h[[
2454                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2455                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2456            ]],
2457            self.x_time_right.row(row),
2458        );
2459        for (a, b, lhs, rhs) in [
2460            (
2461                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2462                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2463                &self.x_time_entry,
2464                &self.x_time_exit,
2465            ),
2466            (
2467                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2468                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2469                &self.x_time_entry,
2470                &self.x_time_derivative_exit,
2471            ),
2472            (
2473                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2474                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2475                &self.x_time_exit,
2476                &self.x_time_derivative_exit,
2477            ),
2478            (
2479                LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
2480                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2481                &self.x_time_entry,
2482                &self.x_time_right,
2483            ),
2484            (
2485                LATENT_SURVIVAL_PRIMARY_Q_EXIT,
2486                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2487                &self.x_time_exit,
2488                &self.x_time_right,
2489            ),
2490            (
2491                LATENT_SURVIVAL_PRIMARY_QDOT_EXIT,
2492                LATENT_SURVIVAL_PRIMARY_Q_RIGHT,
2493                &self.x_time_derivative_exit,
2494                &self.x_time_right,
2495            ),
2496        ] {
2497            let weight = h[[a, b]];
2498            if weight == 0.0 {
2499                continue;
2500            }
2501            dense_symmetric_cross_accumulate(time_target, weight, lhs.row(row), rhs.row(row));
2502        }
2503        // Mean block.
2504        let mean_weight = h[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
2505        self.x_mean
2506            .syr_row_into_view(row, mean_weight, mean_target.view_mut())
2507            .map_err(|error| {
2508                format!(
2509                    "latent survival mean block-diagonal pullback dimension mismatch: row={row}, mean_target_dim={:?}, x_mean_cols={}, error={error}",
2510                    mean_target.dim(),
2511                    self.x_mean.ncols()
2512                )
2513            })?;
2514        // Log-σ block (scalar).
2515        if let Some(target) = log_sigma_target {
2516            target[[0, 0]] += h[[
2517                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2518                LATENT_SURVIVAL_PRIMARY_LOG_SIGMA,
2519            ]];
2520        }
2521        Ok(())
2522    }
2523
2524    /// Block-diagonal evaluator used by `evaluate()`. Returns the per-row
2525    /// log-likelihood, the joint gradient (sliced into block gradients by
2526    /// the caller), and the three per-block diagonal Hessians without ever
2527    /// materializing the full joint matrix.
2528    fn evaluate_exact_newton_block_diagonals(
2529        &self,
2530        block_states: &[ParameterBlockState],
2531    ) -> Result<
2532        (
2533            f64,
2534            Array1<f64>,
2535            Array2<f64>,
2536            Array2<f64>,
2537            Option<Array2<f64>>,
2538        ),
2539        String,
2540    > {
2541        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2542        let q_right = self.time_q_right(block_states)?;
2543        let sigma = self.latent_sd(block_states)?;
2544        let slices = self.joint_slices();
2545        let include_log_sigma = slices.log_sigma.is_some();
2546        let mut ll = 0.0;
2547        let mut gradient = Array1::<f64>::zeros(slices.total);
2548        let p_time = slices.time.len();
2549        let p_mean = slices.mean.len();
2550        let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
2551        let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
2552        let mut hess_log_sigma = if include_log_sigma {
2553            Some(Array2::<f64>::zeros((1, 1)))
2554        } else {
2555            None
2556        };
2557        for row_idx in 0..self.event_target.len() {
2558            let wi = self.weights[row_idx];
2559            if wi <= MIN_WEIGHT {
2560                continue;
2561            }
2562            let row = self.build_row_at(
2563                row_idx,
2564                q_entry[row_idx],
2565                q_exit[row_idx],
2566                qdot_exit[row_idx],
2567                q_right[row_idx],
2568            )?;
2569            let (row_ll, primary_gradient, primary_hessian) =
2570                latent_survival_row_primary_gradient_hessian(
2571                    &self.quadctx,
2572                    &row,
2573                    q_entry[row_idx],
2574                    q_exit[row_idx],
2575                    qdot_exit[row_idx],
2576                    q_right[row_idx],
2577                    mu[row_idx],
2578                    sigma,
2579                    include_log_sigma,
2580                )?;
2581            ll += wi * row_ll;
2582            self.add_pullback_primary_gradient(
2583                &mut gradient,
2584                row_idx,
2585                &slices,
2586                &primary_gradient,
2587                wi,
2588            )?;
2589            self.add_pullback_primary_block_diagonals(
2590                row_idx,
2591                &(wi * primary_hessian),
2592                &mut hess_time,
2593                &mut hess_mean,
2594                hess_log_sigma.as_mut(),
2595            )?;
2596        }
2597        Ok((ll, gradient, hess_time, hess_mean, hess_log_sigma))
2598    }
2599
2600    fn evaluate_exact_newton_joint_dense(
2601        &self,
2602        block_states: &[ParameterBlockState],
2603    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
2604        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2605        let q_right = self.time_q_right(block_states)?;
2606        let sigma = self.latent_sd(block_states)?;
2607        let slices = self.joint_slices();
2608        let include_log_sigma = slices.log_sigma.is_some();
2609        let total = slices.total;
2610        let acc = deterministic_latent_survival_row_reduction(
2611            self.event_target.len(),
2612            || LatentSurvivalJointDenseAccum {
2613                ll: 0.0,
2614                gradient: Array1::<f64>::zeros(total),
2615                hessian: Array2::<f64>::zeros((total, total)),
2616            },
2617            |row_idx, acc| {
2618                let wi = self.weights[row_idx];
2619                if wi <= MIN_WEIGHT {
2620                    return Ok(());
2621                }
2622                let row = self.build_row_at(
2623                    row_idx,
2624                    q_entry[row_idx],
2625                    q_exit[row_idx],
2626                    qdot_exit[row_idx],
2627                    q_right[row_idx],
2628                )?;
2629                let (row_ll, primary_gradient, primary_hessian) =
2630                    latent_survival_row_primary_gradient_hessian(
2631                        &self.quadctx,
2632                        &row,
2633                        q_entry[row_idx],
2634                        q_exit[row_idx],
2635                        qdot_exit[row_idx],
2636                        q_right[row_idx],
2637                        mu[row_idx],
2638                        sigma,
2639                        include_log_sigma,
2640                    )?;
2641                acc.ll += wi * row_ll;
2642                self.add_pullback_primary_gradient(
2643                    &mut acc.gradient,
2644                    row_idx,
2645                    &slices,
2646                    &primary_gradient,
2647                    wi,
2648                )?;
2649                self.add_pullback_primary_hessian(
2650                    &mut acc.hessian,
2651                    row_idx,
2652                    &slices,
2653                    &(wi * primary_hessian),
2654                )?;
2655                Ok(())
2656            },
2657            |total_acc, chunk_acc| {
2658                total_acc.ll += chunk_acc.ll;
2659                total_acc.gradient += &chunk_acc.gradient;
2660                total_acc.hessian += &chunk_acc.hessian;
2661            },
2662        )?;
2663        Ok((acc.ll, acc.gradient, acc.hessian))
2664    }
2665
2666    fn exact_newton_joint_hessian_directional_derivative_dense(
2667        &self,
2668        block_states: &[ParameterBlockState],
2669        d_beta_flat: &Array1<f64>,
2670    ) -> Result<Array2<f64>, String> {
2671        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2672        let q_right = self.time_q_right(block_states)?;
2673        let sigma = self.latent_sd(block_states)?;
2674        let slices = self.joint_slices();
2675        if d_beta_flat.len() != slices.total {
2676            return Err(format!(
2677                "latent survival joint dH direction length mismatch: got {}, expected {}",
2678                d_beta_flat.len(),
2679                slices.total
2680            ));
2681        }
2682        let include_log_sigma = slices.log_sigma.is_some();
2683        let total = slices.total;
2684        let acc = deterministic_latent_survival_row_reduction(
2685            self.event_target.len(),
2686            || LatentSurvivalDenseHessianAccum {
2687                hessian: Array2::<f64>::zeros((total, total)),
2688            },
2689            |row_idx, acc| {
2690                let wi = self.weights[row_idx];
2691                if wi <= MIN_WEIGHT {
2692                    return Ok(());
2693                }
2694                let row = self.build_row_at(
2695                    row_idx,
2696                    q_entry[row_idx],
2697                    q_exit[row_idx],
2698                    qdot_exit[row_idx],
2699                    q_right[row_idx],
2700                )?;
2701                let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
2702                let third = latent_survival_row_primary_third_contracted(
2703                    &self.quadctx,
2704                    &row,
2705                    q_entry[row_idx],
2706                    q_exit[row_idx],
2707                    qdot_exit[row_idx],
2708                    q_right[row_idx],
2709                    mu[row_idx],
2710                    sigma,
2711                    &direction,
2712                    include_log_sigma,
2713                )?;
2714                self.add_pullback_primary_hessian(
2715                    &mut acc.hessian,
2716                    row_idx,
2717                    &slices,
2718                    &(wi * third),
2719                )?;
2720                Ok(())
2721            },
2722            |total_acc, chunk_acc| {
2723                total_acc.hessian += &chunk_acc.hessian;
2724            },
2725        )?;
2726        Ok(acc.hessian)
2727    }
2728
2729    fn exact_newton_joint_hessian_second_directional_derivative_dense(
2730        &self,
2731        block_states: &[ParameterBlockState],
2732        d_beta_u_flat: &Array1<f64>,
2733        d_beta_v_flat: &Array1<f64>,
2734    ) -> Result<Array2<f64>, String> {
2735        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
2736        let q_right = self.time_q_right(block_states)?;
2737        let sigma = self.latent_sd(block_states)?;
2738        let slices = self.joint_slices();
2739        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
2740            return Err(format!(
2741                "latent survival joint d2H direction length mismatch: got {} and {}, expected {}",
2742                d_beta_u_flat.len(),
2743                d_beta_v_flat.len(),
2744                slices.total
2745            ));
2746        }
2747        let include_log_sigma = slices.log_sigma.is_some();
2748        let total = slices.total;
2749        let acc = deterministic_latent_survival_row_reduction(
2750            self.event_target.len(),
2751            || LatentSurvivalDenseHessianAccum {
2752                hessian: Array2::<f64>::zeros((total, total)),
2753            },
2754            |row_idx, acc| {
2755                let wi = self.weights[row_idx];
2756                if wi <= MIN_WEIGHT {
2757                    return Ok(());
2758                }
2759                let row = self.build_row_at(
2760                    row_idx,
2761                    q_entry[row_idx],
2762                    q_exit[row_idx],
2763                    qdot_exit[row_idx],
2764                    q_right[row_idx],
2765                )?;
2766                let direction_u =
2767                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
2768                let direction_v =
2769                    self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
2770                let fourth = latent_survival_row_primary_fourth_contracted(
2771                    &self.quadctx,
2772                    &row,
2773                    q_entry[row_idx],
2774                    q_exit[row_idx],
2775                    qdot_exit[row_idx],
2776                    q_right[row_idx],
2777                    mu[row_idx],
2778                    sigma,
2779                    &direction_u,
2780                    &direction_v,
2781                    include_log_sigma,
2782                )?;
2783                self.add_pullback_primary_hessian(
2784                    &mut acc.hessian,
2785                    row_idx,
2786                    &slices,
2787                    &(wi * fourth),
2788                )?;
2789                Ok(())
2790            },
2791            |total_acc, chunk_acc| {
2792                total_acc.hessian += &chunk_acc.hessian;
2793            },
2794        )?;
2795        Ok(acc.hessian)
2796    }
2797}
2798
2799fn log_kernel_ratio(
2800    bundle: &crate::survival::lognormal_kernel::LogLognormalKernelBundle,
2801    num: usize,
2802    den: usize,
2803) -> f64 {
2804    let delta = bundle.get(num) - bundle.get(den);
2805    if delta.is_finite() {
2806        delta.exp()
2807    } else if delta > 0.0 {
2808        f64::INFINITY
2809    } else {
2810        0.0
2811    }
2812}
2813
2814fn logk_q_derivatives(
2815    quadctx: &QuadratureContext,
2816    k: usize,
2817    mass: f64,
2818    mu: f64,
2819    sigma: f64,
2820) -> Result<(f64, f64, IntegratedExpectationMode), LatentSurvivalError> {
2821    if mass <= 0.0 {
2822        return Ok((0.0, 0.0, IntegratedExpectationMode::ExactClosedForm));
2823    }
2824    let bundle = log_kernel_bundle(quadctx, mass, mu, sigma, k + 2).map_err(|e| {
2825        LatentSurvivalError::NumericalFailure {
2826            reason: format!("latent survival kernel evaluation failed: {e}"),
2827        }
2828    })?;
2829    let r1 = log_kernel_ratio(&bundle, k + 1, k);
2830    let r2 = log_kernel_ratio(&bundle, k + 2, k);
2831    let d1 = -mass * r1;
2832    let d2 = d1 + mass * mass * (r2 - r1 * r1);
2833    Ok((d1, d2, bundle.mode))
2834}
2835
2836fn latent_survival_time_jet(
2837    quadctx: &QuadratureContext,
2838    row: &LatentSurvivalRow,
2839    qdot_exit: f64,
2840    mu: f64,
2841    sigma: f64,
2842) -> Result<LatentSurvivalTimeJet, LatentSurvivalError> {
2843    let (entry_d1, entry_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_entry, mu, sigma)?;
2844    match row.event_type {
2845        LatentSurvivalEventType::RightCensored => {
2846            let (exit_d1, exit_d2, _) = logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
2847            Ok(LatentSurvivalTimeJet {
2848                grad_entry: -entry_d1,
2849                grad_exit: exit_d1,
2850                neg_hess_entry: entry_d2,
2851                neg_hess_exit: -exit_d2,
2852            })
2853        }
2854        LatentSurvivalEventType::ExactEvent => {
2855            if !(qdot_exit.is_finite() && qdot_exit > 0.0) {
2856                return Err(LatentSurvivalError::NumericalFailure {
2857                    reason: format!(
2858                        "latent survival requires positive finite baseline hazard derivative, got {qdot_exit}"
2859                    ),
2860                });
2861            }
2862            if row.hazard_unloaded > 0.0 {
2863                let bundle =
2864                    log_kernel_bundle(quadctx, row.mass_exit, mu, sigma, 3).map_err(|e| {
2865                        LatentSurvivalError::NumericalFailure {
2866                            reason: format!("latent survival kernel evaluation failed: {e}"),
2867                        }
2868                    })?;
2869                let (unloaded_d1, unloaded_d2, _) =
2870                    logk_q_derivatives(quadctx, 0, row.mass_exit, mu, sigma)?;
2871                let (loaded_log_d1, loaded_d2, _) =
2872                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
2873                let loaded_d1 = 1.0 + loaded_log_d1;
2874                let log_loaded = row.hazard_loaded.ln() + bundle.get(1);
2875                let log_unloaded = row.hazard_unloaded.ln() + bundle.get(0);
2876                let shift = log_loaded.max(log_unloaded);
2877                let loaded_weight = (log_loaded - shift).exp();
2878                let unloaded_weight = (log_unloaded - shift).exp();
2879                let normalizer = loaded_weight + unloaded_weight;
2880                if !(normalizer.is_finite() && normalizer > 0.0) {
2881                    return Err(LatentSurvivalError::NumericalFailure {
2882                        reason: "latent survival exact-event numerator became non-finite under loaded/unloaded hazard decomposition"
2883                            .to_string(),
2884                    });
2885                }
2886                let w_loaded = loaded_weight / normalizer;
2887                let w_unloaded = unloaded_weight / normalizer;
2888                let grad_exit = w_loaded * loaded_d1 + w_unloaded * unloaded_d1;
2889                let d2_exit = w_loaded * (loaded_d2 + loaded_d1 * loaded_d1)
2890                    + w_unloaded * (unloaded_d2 + unloaded_d1 * unloaded_d1)
2891                    - grad_exit * grad_exit;
2892                Ok(LatentSurvivalTimeJet {
2893                    grad_entry: -entry_d1,
2894                    grad_exit,
2895                    neg_hess_entry: entry_d2,
2896                    neg_hess_exit: -d2_exit,
2897                })
2898            } else {
2899                let (exit_d1, exit_d2, _) =
2900                    logk_q_derivatives(quadctx, 1, row.mass_exit, mu, sigma)?;
2901                Ok(LatentSurvivalTimeJet {
2902                    grad_entry: -entry_d1,
2903                    grad_exit: 1.0 + exit_d1,
2904                    neg_hess_entry: entry_d2,
2905                    neg_hess_exit: -exit_d2,
2906                })
2907            }
2908        }
2909        LatentSurvivalEventType::IntervalCensored => {
2910            Err(LatentSurvivalError::UnsupportedConfiguration {
2911                reason:
2912                    "latent survival dynamic time derivatives do not implement interval censoring"
2913                        .to_string(),
2914            })
2915        }
2916    }
2917}
2918
2919fn dense_outer_accumulate<S>(
2920    target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2921    weight: f64,
2922    x: ArrayView1<'_, f64>,
2923) where
2924    S: ndarray::DataMut<Elem = f64>,
2925{
2926    for a in 0..x.len() {
2927        let xa = x[a];
2928        if xa == 0.0 {
2929            continue;
2930        }
2931        for b in 0..x.len() {
2932            let xb = x[b];
2933            if xb == 0.0 {
2934                continue;
2935            }
2936            target[[a, b]] += weight * xa * xb;
2937        }
2938    }
2939}
2940
2941fn dense_symmetric_cross_accumulate<S>(
2942    target: &mut ndarray::ArrayBase<S, ndarray::Ix2>,
2943    weight: f64,
2944    x: ArrayView1<'_, f64>,
2945    y: ArrayView1<'_, f64>,
2946) where
2947    S: ndarray::DataMut<Elem = f64>,
2948{
2949    for a in 0..x.len() {
2950        let xa = x[a];
2951        let ya = y[a];
2952        if xa == 0.0 && ya == 0.0 {
2953            continue;
2954        }
2955        for b in 0..x.len() {
2956            let xb = x[b];
2957            let yb = y[b];
2958            let contribution = xa * yb + ya * xb;
2959            if contribution == 0.0 {
2960                continue;
2961            }
2962            target[[a, b]] += weight * contribution;
2963        }
2964    }
2965}
2966
2967fn build_latent_survival_row(
2968    row_index: usize,
2969    hazard_loading: HazardLoading,
2970    event_type: LatentSurvivalEventType,
2971    q_entry: f64,
2972    q_exit: f64,
2973    qdot_exit: f64,
2974    q_right: f64,
2975    unloaded_mass_entry: f64,
2976    unloaded_mass_exit: f64,
2977    unloaded_mass_right: f64,
2978    unloaded_hazard_exit: f64,
2979) -> Result<LatentSurvivalRow, LatentSurvivalError> {
2980    if !(q_entry.is_finite() && q_exit.is_finite()) {
2981        return Err(LatentSurvivalError::NumericalFailure {
2982            reason: format!(
2983                "latent survival requires finite q_entry and q_exit, got q_entry={q_entry}, q_exit={q_exit}"
2984            ),
2985        });
2986    }
2987    if q_exit < q_entry {
2988        return Err(LatentSurvivalError::NumericalFailure {
2989            reason: format!(
2990                "latent survival requires q_exit >= q_entry so cumulative mass is monotone, got q_entry={q_entry}, q_exit={q_exit}"
2991            ),
2992        });
2993    }
2994    if !(unloaded_mass_entry.is_finite()
2995        && unloaded_mass_exit.is_finite()
2996        && unloaded_hazard_exit.is_finite())
2997    {
2998        return Err(LatentSurvivalError::InvalidDataset {
2999            reason: format!(
3000                "latent survival requires finite unloaded components, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3001            ),
3002        });
3003    }
3004    if unloaded_mass_entry < 0.0
3005        || unloaded_mass_exit < unloaded_mass_entry
3006        || unloaded_hazard_exit < 0.0
3007    {
3008        return Err(LatentSurvivalError::InvalidDataset {
3009            reason: format!(
3010                "latent survival requires unloaded masses/hazard to be non-negative and monotone, got entry_mass={unloaded_mass_entry}, exit_mass={unloaded_mass_exit}, exit_hazard={unloaded_hazard_exit}"
3011            ),
3012        });
3013    }
3014    let mass_entry = q_entry.exp();
3015    let mass_exit = q_exit.exp();
3016    let row = match event_type {
3017        LatentSurvivalEventType::RightCensored => {
3018            validate_unloaded_components_for_loading(
3019                "latent-survival",
3020                row_index,
3021                hazard_loading,
3022                unloaded_mass_entry,
3023                unloaded_mass_exit,
3024                Some(unloaded_hazard_exit),
3025            )?;
3026            LatentSurvivalRow::right_censored(
3027                mass_entry,
3028                mass_exit,
3029                unloaded_mass_entry,
3030                unloaded_mass_exit,
3031            )
3032        }
3033        LatentSurvivalEventType::ExactEvent => {
3034            validate_unloaded_components_for_loading(
3035                "latent-survival",
3036                row_index,
3037                hazard_loading,
3038                unloaded_mass_entry,
3039                unloaded_mass_exit,
3040                Some(unloaded_hazard_exit),
3041            )?;
3042            LatentSurvivalRow::exact_event(
3043                mass_entry,
3044                mass_exit,
3045                unloaded_mass_entry,
3046                unloaded_mass_exit,
3047                mass_exit
3048                    * if qdot_exit.is_finite() && qdot_exit > 0.0 {
3049                        qdot_exit
3050                    } else {
3051                        return Err(LatentSurvivalError::NumericalFailure {
3052                            reason: format!(
3053                                "latent survival exact event requires positive finite baseline hazard derivative, got {qdot_exit}"
3054                            ),
3055                        });
3056                    },
3057                unloaded_hazard_exit,
3058            )
3059        }
3060        LatentSurvivalEventType::IntervalCensored => {
3061            // Interval `(L, R]`: `q_exit` carries the LEFT boundary transform
3062            // `log B(L)` (so `mass_left = exp(q_exit)`) and `q_right` the RIGHT
3063            // boundary `log B(R)`. The likelihood is the survival-mass
3064            // difference `log[S(L) − S(R)]`, requiring `B(L) ≤ B(R)` i.e.
3065            // `q_exit ≤ q_right`. No event hazard participates, so the unloaded
3066            // exit hazard must be the full-loading zero (validated below via the
3067            // interval-specific unloaded check at the left/right boundaries).
3068            if !q_right.is_finite() {
3069                return Err(LatentSurvivalError::NumericalFailure {
3070                    reason: format!(
3071                        "latent survival interval row {} requires a finite q_right, got {q_right}",
3072                        row_index + 1
3073                    ),
3074                });
3075            }
3076            if q_right < q_exit {
3077                return Err(LatentSurvivalError::NumericalFailure {
3078                    reason: format!(
3079                        "latent survival interval row {} requires q_right >= q_exit (R >= L) so the \
3080                         survival-mass difference is non-negative, got q_left={q_exit}, q_right={q_right}",
3081                        row_index + 1
3082                    ),
3083                });
3084            }
3085            if !(unloaded_mass_right.is_finite()) || unloaded_mass_right < unloaded_mass_exit {
3086                return Err(LatentSurvivalError::InvalidDataset {
3087                    reason: format!(
3088                        "latent survival interval row {} requires a finite unloaded right mass >= unloaded left mass, got left={unloaded_mass_exit}, right={unloaded_mass_right}",
3089                        row_index + 1
3090                    ),
3091                });
3092            }
3093            // Interval rows carry no exit-event hazard; the loaded/unloaded
3094            // contract is validated by `LatentSurvivalRow::validate` (entry <=
3095            // left <= right monotonicity on both loaded and unloaded masses).
3096            let mass_right = q_right.exp();
3097            LatentSurvivalRow::interval_censored(
3098                mass_entry,
3099                mass_exit,
3100                mass_right,
3101                unloaded_mass_entry,
3102                unloaded_mass_exit,
3103                unloaded_mass_right,
3104            )
3105        }
3106    };
3107    row.validate()
3108        .map_err(|e| LatentSurvivalError::InvalidDataset {
3109            reason: e.to_string(),
3110        })?;
3111    Ok(row)
3112}
3113
/// Binary deployment log-likelihood `ℓ(s)` of one row, expressed through the
/// row's latent-survival log survival `s = log S`, together with the scalar
/// coefficients the Newton accumulators use to chain-rule survival-row
/// derivatives into binary-row derivatives.
///
/// Produced by [`binary_from_log_survival`]; every field depends only on `s`
/// and the row's event code.
#[derive(Clone, Copy, Debug)]
struct BinaryFromLogSurvival {
    /// `ℓ(s)`: equals `s` for event=0 and `log(1 - exp(s))` for event=1.
    log_lik: f64,
    /// dℓ/ds where s = log_survival and ℓ = log_lik. For event=1 this is
    /// ℓ' = -S/(1-S); for event=0 this is ℓ' = 1 (because ℓ ≡ s).
    grad_scale: f64,
    /// Coefficient applied to `survival_jet.neg_hessian` (which equals
    /// -d²s/dβ²) when assembling the negative Hessian of `wi * log_lik`
    /// against β. The Newton accumulator computes
    ///     neg_Hess(log_lik) = grad_scale * neg_hessian + outer_scale * score²
    /// so by the chain rule this MUST equal `grad_scale` (= ℓ'). Keeping
    /// the two fields separate is purely for readability at call sites;
    /// the `assert!` in [`binary_from_log_survival`] enforces the
    /// equality.
    neg_hess_scale: f64,
    /// -ℓ''(s). For event=1 this is +S/(1-S)²; for event=0 it is 0.
    outer_scale: f64,
    /// ℓ''(s) — derivative of `grad_scale` w.r.t. s.
    grad_scale_prime: f64,
    /// ℓ'''(s) — second derivative of `grad_scale` w.r.t. s.
    grad_scale_second: f64,
    /// -ℓ'''(s) — derivative of `outer_scale` w.r.t. s.
    outer_scale_prime: f64,
    /// -ℓ''''(s) — second derivative of `outer_scale` w.r.t. s.
    outer_scale_second: f64,
}
3140
/// Closed-form derivatives of the event-row binary log-likelihood
/// `ℓ(s) = log(1 − eˢ)` with respect to the log survival `s`.
///
/// The caller passes both the survival probability `S = eˢ` (`survival`) and
/// the event probability `P = 1 − S` (`event_prob`); this helper takes them
/// as given rather than recomputing one from the other. It returns
/// `(ℓ, ℓ', ℓ'', ℓ''', ℓ'''')`.
///
/// Using `dS/ds = S`, `dP/ds = −S` and `S + P = 1`:
///
/// ```text
/// ℓ     = log P
/// ℓ'    = −S / P
/// ℓ''   = −S / P²
/// ℓ'''  = −S (1 + S) / P³
/// ℓ'''' = −(S + 4S² + S³) / P⁴
/// ```
///
/// The fourth line comes from differentiating the third:
/// `d/ds[−S(1+S)/P³] = −[(S + 2S²)·P + 3S²(1 + S)] / P⁴`. Substituting
/// `P = 1 − S` collapses the bracket to `S + 4S² + S³`.
///
/// Every consumer scale (`grad_scale`, `neg_hess_scale`, `outer_scale` and
/// their `s`-derivatives) is derived from this single function, so the
/// sign conventions cannot drift between call sites.
#[inline]
fn binary_log_survival_scales(survival: f64, event_prob: f64) -> (f64, f64, f64, f64, f64) {
    let (s, p) = (survival, event_prob);

    // Powers of P and S, built by repeated multiplication.
    let p_sq = p * p;
    let p_cube = p_sq * p;
    let p_quart = p_cube * p;
    let s_sq = s * s;
    let s_cube = s_sq * s;

    let first = -s / p;
    let second = -s / p_sq;
    let third = -s * (1.0 + s) / p_cube;
    let fourth = -(s + 4.0 * s_sq + s_cube) / p_quart;

    (event_prob.ln(), first, second, third, fourth)
}
3179
3180fn binary_from_log_survival(
3181    log_survival: f64,
3182    event: u8,
3183) -> Result<BinaryFromLogSurvival, LatentSurvivalError> {
3184    if event == 0 {
3185        // ℓ(s) = s ⇒ ℓ' = 1, ℓ'' = ℓ''' = ℓ'''' = 0.
3186        return Ok(BinaryFromLogSurvival {
3187            log_lik: log_survival,
3188            grad_scale: 1.0,
3189            neg_hess_scale: 1.0,
3190            outer_scale: 0.0,
3191            grad_scale_prime: 0.0,
3192            grad_scale_second: 0.0,
3193            outer_scale_prime: 0.0,
3194            outer_scale_second: 0.0,
3195        });
3196    }
3197    if event != 1 {
3198        return Err(LatentSurvivalError::InvalidDataset {
3199            reason: format!("latent-binary requires event targets in {{0,1}}, got {event}"),
3200        });
3201    }
3202    // Cap log S(t) strictly below zero so the event probability
3203    // `1 - exp(log S)` stays strictly positive even when the survival
3204    // probability rounds to exactly 1 (log S == 0): a zero event probability
3205    // would make the binary log-likelihood `log(event_prob)` diverge. The cap
3206    // is at the f64 resolution near 1.0, so it never perturbs a genuinely
3207    // informative survival value.
3208    const MAX_LOG_SURVIVAL: f64 = -1e-15;
3209    let log_survival = log_survival.min(MAX_LOG_SURVIVAL);
3210    let survival = log_survival.exp();
3211    let event_prob = 1.0 - survival;
3212    if !(event_prob.is_finite() && event_prob > 0.0) {
3213        return Err(LatentSurvivalError::NumericalFailure {
3214            reason: format!(
3215                "latent-binary encountered non-positive event probability from log survival {log_survival}"
3216            ),
3217        });
3218    }
3219    let (log_lik, ell_prime, ell_pp, ell_ppp, ell_pppp) =
3220        binary_log_survival_scales(survival, event_prob);
3221    let grad_scale = ell_prime;
3222    let neg_hess_scale = ell_prime; // coefficient on (-d²s/dβ²); equals ℓ'.
3223    let outer_scale = -ell_pp;
3224    let grad_scale_prime = ell_pp;
3225    let grad_scale_second = ell_ppp;
3226    let outer_scale_prime = -ell_ppp;
3227    let outer_scale_second = -ell_pppp;
3228    // The Newton accumulator at the call sites computes
3229    //     neg_Hess(log_lik) = neg_hess_scale * (-d²s/dβ²) + outer_scale * (ds/dβ)²
3230    // For this identity to hold by the chain rule, the coefficient on the
3231    // neg_hessian term must equal ℓ' (== grad_scale). Document the invariant.
3232    assert!(
3233        (grad_scale - neg_hess_scale).abs() <= 1e-15 * grad_scale.abs().max(1.0),
3234        "binary_from_log_survival invariant: neg_hess_scale ({neg_hess_scale}) must equal grad_scale ({grad_scale}) so that grad_scale and the coefficient on neg_hessian share sign"
3235    );
3236    assert!(
3237        outer_scale >= 0.0 || !outer_scale.is_finite(),
3238        "binary_from_log_survival invariant: outer_scale (= -ℓ'') must be non-negative for event=1; got {outer_scale}"
3239    );
3240    Ok(BinaryFromLogSurvival {
3241        log_lik,
3242        grad_scale,
3243        neg_hess_scale,
3244        outer_scale,
3245        grad_scale_prime,
3246        grad_scale_second,
3247        outer_scale_prime,
3248        outer_scale_second,
3249    })
3250}
3251
// Derivative plumbing for the latent-binary deployment family.
//
// Every evaluation below rebuilds each row as a right-censored latent-survival
// row (`build_right_censored_row_at`) and runs the shared survival row kernel.
// It then maps the row's log survival `s` onto the binary log-likelihood
// `ℓ(s)` via `binary_from_log_survival` and chain-rules the survival-row
// derivatives through `ℓ'`, `ℓ''`, … . The flat coefficient vector is
// `[time | mean]` (see `joint_slices`). Only the `Q_ENTRY`, `Q_EXIT` and `MU`
// primary coordinates are pulled back onto it.
impl LatentBinaryFamily {
    /// Assemble the per-row [`LatentSurvivalRow`] for a row treated as a pure
    /// right-censored survival contribution (exit time is the censoring
    /// boundary, unit exit-hazard derivative, no right / post-exit unloaded
    /// mass). Shared by every per-row binary-from-survival pullback reduction;
    /// behavior is identical to the previously inlined `RightCensored` call.
    ///
    /// # Errors
    ///
    /// Propagates any [`LatentSurvivalError`] returned by
    /// `build_latent_survival_row` for this row.
    fn build_right_censored_row_at(
        &self,
        row_idx: usize,
        q_entry: f64,
        q_exit: f64,
    ) -> Result<LatentSurvivalRow, LatentSurvivalError> {
        build_latent_survival_row(
            row_idx,
            self.hazard_loading,
            LatentSurvivalEventType::RightCensored,
            q_entry,
            q_exit,
            // NOTE(review): these two positional slots are presumably the
            // exit-hazard derivative (held at 1.0) and a right boundary that
            // simply repeats `q_exit`; confirm against the signature of
            // `build_latent_survival_row`.
            1.0,
            q_exit,
            self.unloaded_mass_entry[row_idx],
            self.unloaded_mass_exit[row_idx],
            // Trailing slots zeroed: per the doc above, deployment rows carry
            // no right / post-exit unloaded mass.
            0.0,
            0.0,
        )
    }

    /// Flat coefficient layout `[time | mean]`. The binary family carries no
    /// log-σ block, so `log_sigma` is `None` and `total = p_time + p_mean`.
    /// The time width is taken from the exit time design.
    fn joint_slices(&self) -> LatentSurvivalJointSlices {
        let p_time = self.x_time_exit.ncols();
        let p_mean = self.x_mean.ncols();
        LatentSurvivalJointSlices {
            time: 0..p_time,
            mean: p_time..p_time + p_mean,
            log_sigma: None,
            total: p_time + p_mean,
        }
    }

    /// Push a flat joint direction `d_beta_flat` forward to this row's primary
    /// coordinates. The entry and exit time-design rows dotted with the time
    /// slice give `Q_ENTRY` and `Q_EXIT`. The mean-design row dotted with the
    /// mean slice gives `MU`. Every other primary slot stays zero.
    fn row_primary_direction_from_flat(
        &self,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        d_beta_flat: &Array1<f64>,
    ) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(LATENT_SURVIVAL_PRIMARY_DIM);
        let d_time = d_beta_flat.slice(s![slices.time.clone()]);
        out[LATENT_SURVIVAL_PRIMARY_Q_ENTRY] = self.x_time_entry.row(row).dot(&d_time);
        out[LATENT_SURVIVAL_PRIMARY_Q_EXIT] = self.x_time_exit.row(row).dot(&d_time);
        out[LATENT_SURVIVAL_PRIMARY_MU] = self
            .x_mean
            .dot_row_view(row, d_beta_flat.slice(s![slices.mean.clone()]));
        out
    }

    /// Pull a per-row primary gradient back onto the flat gradient:
    /// `target += weight · Jᵀ · primary_gradient`. Here `J` maps flat
    /// coefficients to this row's `Q_ENTRY`, `Q_EXIT` and `MU` primaries;
    /// other primary components are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the mean slice of `target` does not match `x_mean`'s column
    /// count. That is a shape-invariant violation, not a data error.
    fn add_pullback_primary_gradient(
        &self,
        target: &mut Array1<f64>,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        primary_gradient: &Array1<f64>,
        weight: f64,
    ) {
        // Time block: both entry and exit time-design rows feed the same
        // `slices.time` coefficients. Zero scales / zero design entries are
        // skipped.
        for (primary_idx, time_vec) in [
            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
        ] {
            let scale = weight * primary_gradient[primary_idx];
            if scale == 0.0 {
                continue;
            }
            for i in 0..time_vec.len() {
                let xi = time_vec[i];
                if xi != 0.0 {
                    target[slices.time.start + i] += scale * xi;
                }
            }
        }

        // Mean block: axpy of the mean-design row into the mean slice.
        let mean_scale = weight * primary_gradient[LATENT_SURVIVAL_PRIMARY_MU];
        if mean_scale != 0.0 {
            self.x_mean
                .axpy_row_into(
                    row,
                    mean_scale,
                    &mut target.slice_mut(s![slices.mean.clone()]),
                )
                // INVARIANT: `slices.mean` is sized by `joint_slices` to match
                // `x_mean.ncols()`; an error means caller-side shape drift,
                // an invariant violation. A swallowed sentinel would silently
                // corrupt the joint gradient, so fail loudly instead.
                .unwrap_or_else(|error| {
                    panic!(
                        "latent binary mean gradient pullback dimension mismatch: row={row}, mean_slice={:?}, target_len={}, x_mean_cols={}, error={error}",
                        slices.mean,
                        target.len(),
                        self.x_mean.ncols()
                    )
                });
        }
    }

    /// Pull a per-row primary Hessian back onto the flat Hessian:
    /// `target += Jᵀ · primary_hessian · J`, restricted to the `Q_ENTRY`,
    /// `Q_EXIT` and `MU` primaries.
    ///
    /// Only the entries `[E,E]`, `[X,X]`, `[E,X]`, `[MU,MU]`, `[E,MU]` and
    /// `[X,MU]` of `primary_hessian` are read, so it is assumed symmetric.
    /// There is no `weight` argument: callers fold the row weight into
    /// `primary_hessian` before calling.
    ///
    /// # Panics
    ///
    /// Panics on mean-block shape mismatches, or if `row` is out of range for
    /// `x_mean`. Both are invariant violations.
    fn add_pullback_primary_hessian(
        &self,
        target: &mut Array2<f64>,
        row: usize,
        slices: &LatentSurvivalJointSlices,
        primary_hessian: &Array2<f64>,
    ) {
        // Time × time block: entry and exit outer products plus the E/X cross
        // term. The helpers presumably add `w·a·aᵀ` and `w·(a·bᵀ + b·aᵀ)`
        // respectively; confirm against their definitions.
        {
            let time_target = &mut target.slice_mut(s![slices.time.clone(), slices.time.clone()]);
            dense_outer_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                ]],
                self.x_time_entry.row(row),
            );
            dense_outer_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                ]],
                self.x_time_exit.row(row),
            );
            dense_symmetric_cross_accumulate(
                time_target,
                primary_hessian[[
                    LATENT_SURVIVAL_PRIMARY_Q_ENTRY,
                    LATENT_SURVIVAL_PRIMARY_Q_EXIT,
                ]],
                self.x_time_entry.row(row),
                self.x_time_exit.row(row),
            );
        }

        // Mean × mean block: symmetric rank-1 update with the `[MU,MU]` weight.
        let mean_weight = primary_hessian[[LATENT_SURVIVAL_PRIMARY_MU, LATENT_SURVIVAL_PRIMARY_MU]];
        self.x_mean
            .syr_row_into_view(
                row,
                mean_weight,
                target.slice_mut(s![slices.mean.clone(), slices.mean.clone()]),
            )
            .unwrap_or_else(|error| {
                // INVARIANT: the `slices.mean` × `slices.mean` slab is sized by
                // `joint_slices` to `x_mean.ncols()` × `x_mean.ncols()`;
                // an error here is caller-side shape drift, an invariant
                // violation. A swallowed sentinel would silently corrupt the
                // joint Hessian, so fail loudly instead.
                panic!(
                    "latent binary mean Hessian pullback dimension mismatch: row={row}, mean_slice={:?}, target_dim={:?}, x_mean_cols={}, error={error}",
                    slices.mean,
                    target.dim(),
                    self.x_mean.ncols()
                )
            });

        // Materialize this row of the mean design once for the time × mean
        // cross blocks below.
        let mean_row = self
            .x_mean
            .try_row_chunk(row..row + 1)
            .unwrap_or_else(|error| {
                // INVARIANT: callers in this impl iterate `row` over
                // `0..self.event_target.len()`. This assumes the mean design
                // has one row per event target, so `row..row+1` is a valid
                // single-row chunk. NOTE(review): that row-count equality is
                // established outside this impl.
                panic!(
                    "latent binary mean pullback row chunk failed: row={row}, x_mean_rows={}, x_mean_cols={}, error={error}",
                    self.x_mean.nrows(),
                    self.x_mean.ncols()
                )
            });
        let mean_vec = mean_row.row(0);
        // Time × mean cross blocks: each time channel contributes
        // `H[c, MU] · x_c · x_meanᵀ`. It is written to both off-diagonal slabs
        // so the flat Hessian stays symmetric.
        for (primary_idx, time_vec) in [
            (LATENT_SURVIVAL_PRIMARY_Q_ENTRY, self.x_time_entry.row(row)),
            (LATENT_SURVIVAL_PRIMARY_Q_EXIT, self.x_time_exit.row(row)),
        ] {
            let weight = primary_hessian[[primary_idx, LATENT_SURVIVAL_PRIMARY_MU]];
            if weight == 0.0 {
                continue;
            }
            for i in 0..time_vec.len() {
                let xi = time_vec[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..mean_vec.len() {
                    let xj = mean_vec[j];
                    if xj == 0.0 {
                        continue;
                    }
                    target[[slices.time.start + i, slices.mean.start + j]] += weight * xi * xj;
                    target[[slices.mean.start + j, slices.time.start + i]] += weight * xj * xi;
                }
            }
        }
    }

    /// Dense joint evaluation over the flat `[time | mean]` coefficients.
    ///
    /// Returns `(ll, gradient, hessian)`, summing over rows `i` with weight
    /// `w_i > MIN_WEIGHT`:
    ///
    /// ```text
    /// ll       = Σ w_i ℓ(s_i)
    /// gradient = Σ w_i J_iᵀ (ℓ'_i g_i)
    /// hessian  = Σ w_i J_iᵀ (ℓ'_i H_i + (−ℓ''_i) g_i g_iᵀ) J_i
    /// ```
    ///
    /// Here `s_i`, `g_i` and `H_i` are the survival-row log survival, primary
    /// gradient and primary curvature. Presumably `H_i` is the negative
    /// Hessian of log S (the convention documented on
    /// `BinaryFromLogSurvival`), which makes the returned matrix the negative
    /// Hessian of `ll`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the time/mean split, row construction, the
    /// survival row kernel, and `binary_from_log_survival`.
    fn evaluate_exact_newton_joint_dense(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        let mut ll = 0.0;
        let mut gradient = Array1::<f64>::zeros(slices.total);
        let mut hessian = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            ll += wi * binary.log_lik;
            // Chain rule in primary space: ∇ℓ = ℓ'·g and
            // curvature = ℓ'·H + (−ℓ'')·g gᵀ.
            let primary_gradient = binary.grad_scale * &survival_gradient;
            let mut primary_hessian = binary.grad_scale * survival_hessian;
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary_hessian[[a, b]] +=
                        binary.outer_scale * survival_gradient[a] * survival_gradient[b];
                }
            }
            self.add_pullback_primary_gradient(
                &mut gradient,
                row_idx,
                &slices,
                &primary_gradient,
                wi,
            );
            self.add_pullback_primary_hessian(
                &mut hessian,
                row_idx,
                &slices,
                &(wi * primary_hessian),
            );
        }
        Ok((ll, gradient, hessian))
    }

    /// Per-row residuals of the unpenalized NLL with respect to the baseline
    /// time-block offsets `(entry, exit)`.
    ///
    /// The latent-binary deployment likelihood is a monotone scalar transform
    /// `ℓ_bin = b(log S_row)` of the latent-survival row log-survival, so by the
    /// chain rule `∂ℓ_bin/∂q_ch = b'(log S)·∂(log S)/∂q_ch = grad_scale·g_ch`,
    /// where `g_ch` are the `Q_ENTRY`/`Q_EXIT` components of the survival row
    /// primary gradient. The baseline θ enters only the additive entry/exit time
    /// offsets (`q̇_exit` is held at the constant deployment derivative `1`, so
    /// the derivative channel carries no baseline offset and its residual is 0).
    /// Sample-weight-scaled to match the [`OffsetChannelResiduals`] contract.
    ///
    /// With empty `block_states` (a degraded fit), logs a warning and returns
    /// all-zero residuals. Rows with weight `<= MIN_WEIGHT` get zero residuals.
    pub fn offset_channel_residuals(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<crate::survival::OffsetChannelResiduals, String> {
        let n = self.event_target.len();
        if block_states.is_empty() {
            log::warn!(
                "LatentBinaryFamily::offset_channel_residuals: block_states is empty \
                 (degraded fit); returning zero offset residuals (n={n})"
            );
            return Ok(crate::survival::OffsetChannelResiduals {
                exit: Array1::<f64>::zeros(n),
                entry: Array1::<f64>::zeros(n),
                derivative: Array1::<f64>::zeros(n),
                right: Array1::<f64>::zeros(n),
            });
        }
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let mut entry = Array1::<f64>::zeros(n);
        let mut exit = Array1::<f64>::zeros(n);
        for row_idx in 0..n {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            let (row_log_survival, survival_gradient, _) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            // ∂NLL/∂o_ch = −w · grad_scale · ∂(log S)/∂q_ch.
            entry[row_idx] =
                -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
            exit[row_idx] =
                -wi * binary.grad_scale * survival_gradient[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
        }
        Ok(crate::survival::OffsetChannelResiduals {
            exit,
            entry,
            derivative: Array1::<f64>::zeros(n),
            // Latent-binary deployment has no interval upper bound; the `R`
            // channel is structurally absent (every row is right-censored).
            right: Array1::<f64>::zeros(n),
        })
    }

    /// Directional derivative `D_u` of the curvature matrix returned by
    /// `evaluate_exact_newton_joint_dense`, along the flat direction
    /// `d_beta_flat`.
    ///
    /// Per row, `u` is the row's primary direction and `T_u` is the survival
    /// kernel's third-derivative term contracted with `u`. Define
    /// `t_u = gᵀu` and `g_u = −H u`. Then
    ///
    /// ```text
    /// D_u = ℓ' T_u + ℓ'' t_u H + (−ℓ''') t_u g gᵀ + (−ℓ'')(g_u gᵀ + g g_uᵀ)
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if `d_beta_flat` does not have the flat coefficient
    /// length, or if any per-row kernel fails.
    fn exact_newton_joint_hessian_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint dH direction length mismatch: got {}, expected {}",
                d_beta_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_flat);
            let third = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction,
                false,
            )?;
            // g_u = −H·u and t_u = gᵀu, as in the doc formula.
            let g_u = -survival_hessian.dot(&direction);
            let t_u = survival_gradient.dot(&direction);
            let mut primary = binary.grad_scale * third;
            primary.scaled_add(binary.grad_scale_prime * t_u, &survival_hessian);
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += binary.outer_scale_prime
                        * t_u
                        * survival_gradient[a]
                        * survival_gradient[b]
                        + binary.outer_scale
                            * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b]);
                }
            }
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }

    /// Mixed second directional derivative `D_uv` of the curvature matrix
    /// returned by `evaluate_exact_newton_joint_dense`, along the flat
    /// directions `d_beta_u_flat` and `d_beta_v_flat`.
    ///
    /// Per row, with primary directions `u`, `v`, survival gradient `g` and
    /// curvature `H`:
    /// - `T_u`, `T_v` are the third-derivative terms contracted with `u`, `v`.
    /// - `Q_uv` is the fourth-derivative term contracted with both.
    /// - `t_u = gᵀu`, `g_u = −H u`, `l_uv = −uᵀ H v`, `g_uv = −T_v u`.
    ///
    /// ```text
    /// c_u  = ℓ'' t_u              o_u  = (−ℓ''') t_u
    /// c_uv = ℓ''' t_u t_v + ℓ'' l_uv
    /// o_uv = (−ℓ'''') t_u t_v + (−ℓ''') l_uv
    /// D_uv = ℓ' Q_uv + c_u T_v + c_v T_u + c_uv H
    ///      + o_uv g gᵀ + o_v (g_u gᵀ + g g_uᵀ) + o_u (g_v gᵀ + g g_vᵀ)
    ///      + (−ℓ'')(g_uv gᵀ + g_u g_vᵀ + g_v g_uᵀ + g g_uvᵀ)
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if either direction does not have the flat
    /// coefficient length, or if any per-row kernel fails.
    fn exact_newton_joint_hessian_second_directional_derivative_dense(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
        let slices = self.joint_slices();
        if d_beta_u_flat.len() != slices.total || d_beta_v_flat.len() != slices.total {
            return Err(format!(
                "latent binary joint d2H direction length mismatch: got {} and {}, expected {}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len(),
                slices.total
            ));
        }
        let mut out = Array2::<f64>::zeros((slices.total, slices.total));
        for row_idx in 0..self.event_target.len() {
            let wi = self.weights[row_idx];
            if wi <= MIN_WEIGHT {
                continue;
            }
            let row =
                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
            let (row_log_survival, survival_gradient, survival_hessian) =
                latent_survival_row_primary_gradient_hessian(
                    &self.quadctx,
                    &row,
                    q_entry[row_idx],
                    q_exit[row_idx],
                    1.0,
                    q_exit[row_idx],
                    mu[row_idx],
                    self.latent_sd,
                    false,
                )?;
            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
            let direction_u = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_u_flat);
            let direction_v = self.row_primary_direction_from_flat(row_idx, &slices, d_beta_v_flat);
            let third_u = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                false,
            )?;
            let third_v = latent_survival_row_primary_third_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_v,
                false,
            )?;
            let fourth = latent_survival_row_primary_fourth_contracted(
                &self.quadctx,
                &row,
                q_entry[row_idx],
                q_exit[row_idx],
                1.0,
                q_exit[row_idx],
                mu[row_idx],
                self.latent_sd,
                &direction_u,
                &direction_v,
                false,
            )?;
            // Scalar/vector building blocks named as in the doc formula.
            let g_u = -survival_hessian.dot(&direction_u);
            let g_v = -survival_hessian.dot(&direction_v);
            let g_uv = -third_v.dot(&direction_u);
            let t_u = survival_gradient.dot(&direction_u);
            let t_v = survival_gradient.dot(&direction_v);
            let l_uv = -direction_u.dot(&survival_hessian.dot(&direction_v));
            let c_u = binary.grad_scale_prime * t_u;
            let c_v = binary.grad_scale_prime * t_v;
            let c_uv = binary.grad_scale_second * t_u * t_v + binary.grad_scale_prime * l_uv;
            let o_u = binary.outer_scale_prime * t_u;
            let o_v = binary.outer_scale_prime * t_v;
            let o_uv = binary.outer_scale_second * t_u * t_v + binary.outer_scale_prime * l_uv;
            let mut primary = binary.grad_scale * fourth;
            primary.scaled_add(c_u, &third_v);
            primary.scaled_add(c_v, &third_u);
            primary.scaled_add(c_uv, &survival_hessian);
            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                for b in 0..LATENT_SURVIVAL_PRIMARY_DIM {
                    primary[[a, b]] += o_uv * survival_gradient[a] * survival_gradient[b]
                        + o_v * (g_u[a] * survival_gradient[b] + survival_gradient[a] * g_u[b])
                        + o_u * (g_v[a] * survival_gradient[b] + survival_gradient[a] * g_v[b])
                        + binary.outer_scale
                            * (g_uv[a] * survival_gradient[b]
                                + g_u[a] * g_v[b]
                                + g_v[a] * g_u[b]
                                + survival_gradient[a] * g_uv[b]);
                }
            }
            self.add_pullback_primary_hessian(&mut out, row_idx, &slices, &(wi * primary));
        }
        Ok(out)
    }
}
3747
3748/// Shared interface that both `LatentSurvivalFamily` and `LatentBinaryFamily`
3749/// expose to the joint Hessian workspace.
3750///
3751/// The two families produce the same `ExactNewtonJointHessianWorkspace`
3752/// shape — five of the six workspace methods are pure delegations to a
3753/// matching family method (dense evaluation, directional derivatives, and the
3754/// `slices` cache). The only family-specific piece is the per-row matvec body:
3755/// the survival family iterates over real (entry, exit, ḋ) triples and may
3756/// carry a log-σ block, while the binary family rewrites the same row kernel
3757/// through `binary_from_log_survival(·)` to recover the per-row binary
3758/// gradient/Hessian. That single difference is captured by `ws_matvec_into`;
3759/// every other method is shared by the generic `LatentHessianWorkspace<F>`
3760/// below.
3761trait LatentJointHessianFamily {
3762    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices;
3763
3764    fn ws_evaluate_dense(
3765        &self,
3766        block_states: &[ParameterBlockState],
3767    ) -> Result<(f64, Array1<f64>, Array2<f64>), String>;
3768
3769    fn ws_dh_directional(
3770        &self,
3771        block_states: &[ParameterBlockState],
3772        d_beta_flat: &Array1<f64>,
3773    ) -> Result<Array2<f64>, String>;
3774
3775    fn ws_dh_second_directional(
3776        &self,
3777        block_states: &[ParameterBlockState],
3778        d_beta_u: &Array1<f64>,
3779        d_beta_v: &Array1<f64>,
3780    ) -> Result<Array2<f64>, String>;
3781
3782    /// Family-specific per-row Hessian matvec body, hoisted out of the
3783    /// workspace impl. Writes `out := H · v` (with `out.fill(0.0)` already
3784    /// performed by the caller) using the family's row kernel.
3785    fn ws_matvec_into(
3786        &self,
3787        slices: &LatentSurvivalJointSlices,
3788        block_states: &[ParameterBlockState],
3789        v: &Array1<f64>,
3790        out: &mut Array1<f64>,
3791    ) -> Result<bool, String>;
3792
3793    /// Family-name fragment used in the workspace's dimension-mismatch error
3794    /// message, so callers still see "latent survival …" / "latent binary …"
3795    /// after the workspace impl was unified.
3796    fn ws_label() -> &'static str;
3797}
3798
3799impl LatentJointHessianFamily for LatentSurvivalFamily {
3800    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3801        self.joint_slices()
3802    }
3803
3804    fn ws_evaluate_dense(
3805        &self,
3806        block_states: &[ParameterBlockState],
3807    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3808        self.evaluate_exact_newton_joint_dense(block_states)
3809    }
3810
3811    fn ws_dh_directional(
3812        &self,
3813        block_states: &[ParameterBlockState],
3814        d_beta_flat: &Array1<f64>,
3815    ) -> Result<Array2<f64>, String> {
3816        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3817    }
3818
3819    fn ws_dh_second_directional(
3820        &self,
3821        block_states: &[ParameterBlockState],
3822        d_beta_u: &Array1<f64>,
3823        d_beta_v: &Array1<f64>,
3824    ) -> Result<Array2<f64>, String> {
3825        self.exact_newton_joint_hessian_second_directional_derivative_dense(
3826            block_states,
3827            d_beta_u,
3828            d_beta_v,
3829        )
3830    }
3831
3832    fn ws_matvec_into(
3833        &self,
3834        slices: &LatentSurvivalJointSlices,
3835        block_states: &[ParameterBlockState],
3836        v: &Array1<f64>,
3837        out: &mut Array1<f64>,
3838    ) -> Result<bool, String> {
3839        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
3840        let q_right = self.time_q_right(block_states)?;
3841        let sigma = self.latent_sd(block_states)?;
3842        let include_log_sigma = slices.log_sigma.is_some();
3843        for row_idx in 0..self.event_target.len() {
3844            let wi = self.weights[row_idx];
3845            if wi <= MIN_WEIGHT {
3846                continue;
3847            }
3848            let row = self.build_row_at(
3849                row_idx,
3850                q_entry[row_idx],
3851                q_exit[row_idx],
3852                qdot_exit[row_idx],
3853                q_right[row_idx],
3854            )?;
3855            let (_, _, primary_hessian) = latent_survival_row_primary_gradient_hessian(
3856                &self.quadctx,
3857                &row,
3858                q_entry[row_idx],
3859                q_exit[row_idx],
3860                qdot_exit[row_idx],
3861                q_right[row_idx],
3862                mu[row_idx],
3863                sigma,
3864                include_log_sigma,
3865            )?;
3866            let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3867            let primary_hv = primary_hessian.dot(&primary_dir);
3868            self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi)?;
3869        }
3870        Ok(true)
3871    }
3872
3873    fn ws_label() -> &'static str {
3874        "survival"
3875    }
3876}
3877
3878impl LatentJointHessianFamily for LatentBinaryFamily {
3879    fn ws_joint_slices(&self) -> LatentSurvivalJointSlices {
3880        self.joint_slices()
3881    }
3882
3883    fn ws_evaluate_dense(
3884        &self,
3885        block_states: &[ParameterBlockState],
3886    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
3887        self.evaluate_exact_newton_joint_dense(block_states)
3888    }
3889
3890    fn ws_dh_directional(
3891        &self,
3892        block_states: &[ParameterBlockState],
3893        d_beta_flat: &Array1<f64>,
3894    ) -> Result<Array2<f64>, String> {
3895        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
3896    }
3897
3898    fn ws_dh_second_directional(
3899        &self,
3900        block_states: &[ParameterBlockState],
3901        d_beta_u: &Array1<f64>,
3902        d_beta_v: &Array1<f64>,
3903    ) -> Result<Array2<f64>, String> {
3904        self.exact_newton_joint_hessian_second_directional_derivative_dense(
3905            block_states,
3906            d_beta_u,
3907            d_beta_v,
3908        )
3909    }
3910
3911    fn ws_matvec_into(
3912        &self,
3913        slices: &LatentSurvivalJointSlices,
3914        block_states: &[ParameterBlockState],
3915        v: &Array1<f64>,
3916        out: &mut Array1<f64>,
3917    ) -> Result<bool, String> {
3918        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
3919        for row_idx in 0..self.event_target.len() {
3920            let wi = self.weights[row_idx];
3921            if wi <= MIN_WEIGHT {
3922                continue;
3923            }
3924            let row =
3925                self.build_right_censored_row_at(row_idx, q_entry[row_idx], q_exit[row_idx])?;
3926            let (row_log_survival, survival_gradient, survival_hessian) =
3927                latent_survival_row_primary_gradient_hessian(
3928                    &self.quadctx,
3929                    &row,
3930                    q_entry[row_idx],
3931                    q_exit[row_idx],
3932                    1.0,
3933                    q_exit[row_idx],
3934                    mu[row_idx],
3935                    self.latent_sd,
3936                    false,
3937                )?;
3938            let binary = binary_from_log_survival(row_log_survival, self.event_target[row_idx])?;
3939            let primary_dir = self.row_primary_direction_from_flat(row_idx, slices, v);
3940            let mut primary_hv = binary.grad_scale * survival_hessian.dot(&primary_dir);
3941            let outer_dot = survival_gradient.dot(&primary_dir);
3942            for a in 0..LATENT_SURVIVAL_PRIMARY_DIM {
3943                primary_hv[a] += binary.outer_scale * survival_gradient[a] * outer_dot;
3944            }
3945            self.add_pullback_primary_gradient(out, row_idx, slices, &primary_hv, wi);
3946        }
3947        Ok(true)
3948    }
3949
3950    fn ws_label() -> &'static str {
3951        "binary"
3952    }
3953}
3954
3955/// Joint exact-Newton Hessian workspace shared by `LatentSurvivalFamily` and
3956/// `LatentBinaryFamily`. The two families plug into the workspace via
3957/// `LatentJointHessianFamily`; this struct holds the shared bookkeeping
3958/// (block states + cached slices) and routes every trait method either through
3959/// a thin family delegation or through the family's `ws_matvec_into` row
3960/// kernel.
3961struct LatentHessianWorkspace<F: LatentJointHessianFamily> {
3962    family: F,
3963    block_states: Vec<ParameterBlockState>,
3964    slices: LatentSurvivalJointSlices,
3965}
3966
3967impl<F: LatentJointHessianFamily> LatentHessianWorkspace<F> {
3968    fn new(family: F, block_states: Vec<ParameterBlockState>) -> Self {
3969        let slices = family.ws_joint_slices();
3970        Self {
3971            family,
3972            block_states,
3973            slices,
3974        }
3975    }
3976}
3977
3978impl<F> ExactNewtonJointHessianWorkspace for LatentHessianWorkspace<F>
3979where
3980    F: LatentJointHessianFamily + Send + Sync + 'static,
3981{
3982    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
3983        self.family
3984            .ws_evaluate_dense(&self.block_states)
3985            .map(|(_, _, hessian)| Some(hessian))
3986    }
3987
3988    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
3989        let mut out = Array1::<f64>::zeros(self.slices.total);
3990        self.hessian_matvec_into(v, &mut out)?;
3991        Ok(Some(out))
3992    }
3993
3994    fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
3995        if v.len() != self.slices.total || out.len() != self.slices.total {
3996            return Err(format!(
3997                "latent {} Hessian matvec dimension mismatch: v={} out={} expected={}",
3998                F::ws_label(),
3999                v.len(),
4000                out.len(),
4001                self.slices.total
4002            ));
4003        }
4004        out.fill(0.0);
4005        self.family
4006            .ws_matvec_into(&self.slices, &self.block_states, v, out)
4007    }
4008
4009    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
4010        let dense = self.family.ws_evaluate_dense(&self.block_states)?.2;
4011        Ok(Some(dense.diag().to_owned()))
4012    }
4013
4014    fn directional_derivative(
4015        &self,
4016        d_beta_flat: &Array1<f64>,
4017    ) -> Result<Option<Array2<f64>>, String> {
4018        self.family
4019            .ws_dh_directional(&self.block_states, d_beta_flat)
4020            .map(Some)
4021    }
4022
4023    fn second_directional_derivative(
4024        &self,
4025        d_beta_u: &Array1<f64>,
4026        d_beta_v: &Array1<f64>,
4027    ) -> Result<Option<Array2<f64>>, String> {
4028        self.family
4029            .ws_dh_second_directional(&self.block_states, d_beta_u, d_beta_v)
4030            .map(Some)
4031    }
4032}
4033
4034type LatentSurvivalHessianWorkspace = LatentHessianWorkspace<LatentSurvivalFamily>;
4035type LatentBinaryHessianWorkspace = LatentHessianWorkspace<LatentBinaryFamily>;
4036
4037impl CustomFamily for LatentSurvivalFamily {
4038    // Latent survival fits keep the self-limiting Jeffreys/Firth curvature
4039    // active for their under-identification regime. The trait default flipped to
4040    // OFF in gam#1395 (flat-prior exact-Newton objective); opt back in here.
4041    fn joint_jeffreys_term_required(&self) -> bool {
4042        true
4043    }
4044
4045    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4046        true
4047    }
4048
4049    fn has_explicit_joint_hessian(&self) -> bool {
4050        true
4051    }
4052
4053    /// Engage the inner self-vanishing Levenberg–Marquardt μ on a full-rank but
4054    /// indefinite / ill-conditioned penalized joint Hessian, mirroring the
4055    /// sibling [`SurvivalMarginalSlopeFamily`]. Interval-censored rows contribute
4056    /// `ℓ = log[S(L) − S(R)]`, the log of a DIFFERENCE of two survival kernels:
4057    /// unlike the log-concave exact-event / right-censored contributions, its
4058    /// per-row Hessian is legitimately INDEFINITE away from the optimum, so the
4059    /// coupled exact-joint penalized Hessian on the constrained (monotone-cone)
4060    /// time block can be full-rank (`nullity == 0`) yet indefinite or severely
4061    /// ill-conditioned at the cold-start seed. The constrained-QP path already
4062    /// REFLECTS negative-curvature modes to `|λ|` (a convex modified-Newton
4063    /// model), but with this gate OFF it adds NO diagonal floor on a full-rank
4064    /// ill-conditioned reflected model, so the trust-region Newton oscillates on
4065    /// the near-singular mode and stalls out the inner budget before any KKT
4066    /// snapshot is taken ("exited the joint Newton path before convergence — no
4067    /// math snapshot"). Arming the gate adds the SAME self-vanishing μ
4068    /// (∝ the projected KKT residual `‖∇ℓ − Sβ + ∇Φ‖` → 0 at the fixed point) the
4069    /// marginal-slope survival inner relies on, so the step is a well-damped
4070    /// modified-Newton descent that converges, while the converged β̂ is the
4071    /// EXACT unconditioned optimum (μ → 0 there) — zero REML/LAML bias, exact
4072    /// gradient unchanged.
4073    fn levenberg_on_ill_conditioning(&self) -> bool {
4074        true
4075    }
4076
4077    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4078        // `evaluate_exact_newton_joint_dense` builds a fully dense joint
4079        // Hessian over (Σ p_b)² across time, mean, and optional log-σ blocks
4080        // via per-row pullback of the latent-survival primary kernel.
4081        crate::custom_family::joint_coupled_coefficient_hessian_cost(
4082            self.event_target.len() as u64,
4083            specs,
4084        )
4085    }
4086
4087    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4088        let (ll, joint_gradient, hess_time, hess_mean, hess_log_sigma) =
4089            self.evaluate_exact_newton_block_diagonals(block_states)?;
4090        let block_ranges = self.joint_block_ranges();
4091        let mut blockworking_sets = vec![
4092            BlockWorkingSet::ExactNewton {
4093                gradient: joint_gradient.slice(s![block_ranges[0].clone()]).to_owned(),
4094                hessian: SymmetricMatrix::Dense(hess_time),
4095            },
4096            BlockWorkingSet::ExactNewton {
4097                gradient: joint_gradient.slice(s![block_ranges[1].clone()]).to_owned(),
4098                hessian: SymmetricMatrix::Dense(hess_mean),
4099            },
4100        ];
4101        if let (Some(range), Some(hessian)) = (block_ranges.get(2).cloned(), hess_log_sigma) {
4102            blockworking_sets.push(BlockWorkingSet::ExactNewton {
4103                gradient: joint_gradient.slice(s![range]).to_owned(),
4104                hessian: SymmetricMatrix::Dense(hessian),
4105            });
4106        }
4107        Ok(FamilyEvaluation {
4108            log_likelihood: ll,
4109            blockworking_sets,
4110        })
4111    }
4112
4113    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4114        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4115        let (q_entry, q_exit, qdot_exit, mu) = self.split_time_eta(block_states)?;
4116        let q_right = self.time_q_right(block_states)?;
4117        let latent_sd = self.latent_sd(block_states)?;
4118        let n = self.event_target.len();
4119        // Per-row latent-survival jet + log-lik contribution. Independent
4120        // across rows; sum via parallel reduce. `?` propagation happens
4121        // through a Result-collecting fold.
4122        let contributions: Result<Vec<f64>, String> = (0..n)
4123            .into_par_iter()
4124            .map(|i| -> Result<f64, String> {
4125                let wi = self.weights[i];
4126                if wi <= MIN_WEIGHT {
4127                    return Ok(0.0);
4128                }
4129                let row = self.build_row_at(i, q_entry[i], q_exit[i], qdot_exit[i], q_right[i])?;
4130                let jet = LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], latent_sd)
4131                    .map_err(|e| format!("LatentSurvivalFamily row {i}: {e}"))?;
4132                Ok(wi * jet.log_lik)
4133            })
4134            .collect();
4135        Ok(contributions?.into_iter().sum())
4136    }
4137
4138    fn block_linear_constraints(
4139        &self,
4140        _: &[ParameterBlockState],
4141        block_idx: usize,
4142        block_spec: &ParameterBlockSpec,
4143    ) -> Result<Option<LinearInequalityConstraints>, String> {
4144        assert!(!block_spec.name.is_empty());
4145        if block_idx == Self::BLOCK_TIME {
4146            Ok(self.time_linear_constraints.clone())
4147        } else {
4148            Ok(None)
4149        }
4150    }
4151
4152    fn exact_newton_joint_hessian(
4153        &self,
4154        block_states: &[ParameterBlockState],
4155    ) -> Result<Option<Array2<f64>>, String> {
4156        self.evaluate_exact_newton_joint_dense(block_states)
4157            .map(|(_, _, hessian)| Some(hessian))
4158    }
4159
4160    fn exact_newton_joint_hessian_workspace(
4161        &self,
4162        block_states: &[ParameterBlockState],
4163        _: &[ParameterBlockSpec],
4164    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4165        Ok(Some(Arc::new(LatentSurvivalHessianWorkspace::new(
4166            self.clone(),
4167            block_states.to_vec(),
4168        ))))
4169    }
4170
4171    fn exact_newton_joint_gradient_evaluation(
4172        &self,
4173        block_states: &[ParameterBlockState],
4174        _: &[ParameterBlockSpec],
4175    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4176        self.evaluate_exact_newton_joint_gradient_dense(block_states)
4177            .map(|(log_likelihood, gradient)| {
4178                Some(ExactNewtonJointGradientEvaluation {
4179                    log_likelihood,
4180                    gradient,
4181                })
4182            })
4183    }
4184
4185    fn exact_newton_joint_hessian_directional_derivative(
4186        &self,
4187        block_states: &[ParameterBlockState],
4188        d_beta_flat: &Array1<f64>,
4189    ) -> Result<Option<Array2<f64>>, String> {
4190        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4191            .map(Some)
4192    }
4193
4194    fn exact_newton_joint_hessiansecond_directional_derivative(
4195        &self,
4196        block_states: &[ParameterBlockState],
4197        d_beta_u_flat: &Array1<f64>,
4198        d_beta_v_flat: &Array1<f64>,
4199    ) -> Result<Option<Array2<f64>>, String> {
4200        self.exact_newton_joint_hessian_second_directional_derivative_dense(
4201            block_states,
4202            d_beta_u_flat,
4203            d_beta_v_flat,
4204        )
4205        .map(Some)
4206    }
4207
4208    fn requires_joint_outer_hyper_path(&self) -> bool {
4209        true
4210    }
4211}
4212
4213impl CustomFamily for LatentBinaryFamily {
4214    // Latent binary fits have a separation regime; keep the self-limiting
4215    // Jeffreys/Firth curvature active. The trait default flipped to OFF in
4216    // gam#1395 (flat-prior exact-Newton objective); opt back in here.
4217    fn joint_jeffreys_term_required(&self) -> bool {
4218        true
4219    }
4220
4221    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
4222        true
4223    }
4224
4225    fn has_explicit_joint_hessian(&self) -> bool {
4226        true
4227    }
4228
4229    /// Same self-vanishing Levenberg–Marquardt gate as
4230    /// [`LatentSurvivalFamily`]: the latent-binary deployment shares the
4231    /// constrained (monotone-cone) coupled time block, so a full-rank but
4232    /// ill-conditioned penalized joint Hessian at the cold-start seed must get
4233    /// the self-vanishing μ floor rather than oscillating the constrained-QP
4234    /// trust region into a snapshot-less stall. μ → 0 at the fixed point, so the
4235    /// converged β̂ is exact (no REML/LAML bias).
4236    fn levenberg_on_ill_conditioning(&self) -> bool {
4237        true
4238    }
4239
4240    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
4241        crate::custom_family::joint_coupled_coefficient_hessian_cost(
4242            self.event_target.len() as u64,
4243            specs,
4244        )
4245    }
4246
4247    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
4248        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4249        let n = self.event_target.len();
4250        let p_time = self.x_time_exit.ncols();
4251        let p_mean = self.x_mean.ncols();
4252
4253        let mut ll = 0.0;
4254        let mut grad_time = Array1::<f64>::zeros(p_time);
4255        let mut hess_time = Array2::<f64>::zeros((p_time, p_time));
4256        let mut grad_mean = Array1::<f64>::zeros(p_mean);
4257        let mut hess_mean = Array2::<f64>::zeros((p_mean, p_mean));
4258        // Reusable 1-row buffer for x_mean so we avoid allocating a fresh
4259        // Array2<f64> on every iteration via try_row_chunk(i..i+1).
4260        let mut mean_row_buf = Array2::<f64>::zeros((1, p_mean));
4261
4262        for i in 0..n {
4263            let wi = self.weights[i];
4264            if wi <= MIN_WEIGHT {
4265                continue;
4266            }
4267            if !(q_entry[i].is_finite() && q_exit[i].is_finite() && mu[i].is_finite()) {
4268                return Err(format!(
4269                    "latent-binary row {i} contains non-finite predictors: q_entry={}, q_exit={}, mu={}",
4270                    q_entry[i], q_exit[i], mu[i]
4271                ));
4272            }
4273            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4274            let survival_jet =
4275                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4276                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4277            let binary = binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?;
4278            ll += wi * binary.log_lik;
4279
4280            self.x_mean
4281                .row_chunk_into(i..i + 1, mean_row_buf.view_mut())
4282                .map_err(|e| format!("LatentBinaryFamily row {i} mean row_chunk: {e}"))?;
4283            let mean_vec = mean_row_buf.row(0);
4284            let mean_grad_scale = wi * binary.grad_scale * survival_jet.score;
4285            for j in 0..p_mean {
4286                grad_mean[j] += mean_grad_scale * mean_vec[j];
4287            }
4288            let mean_neg_hess = wi
4289                * (binary.neg_hess_scale * survival_jet.neg_hessian
4290                    + binary.outer_scale * survival_jet.score * survival_jet.score);
4291            dense_outer_accumulate(&mut hess_mean, mean_neg_hess, mean_vec);
4292
4293            let time_jet =
4294                latent_survival_time_jet(&self.quadctx, &row, 0.0, mu[i], self.latent_sd)?;
4295            let t_entry = self.x_time_entry.row(i);
4296            let t_exit = self.x_time_exit.row(i);
4297            for j in 0..p_time {
4298                grad_time[j] += wi
4299                    * binary.grad_scale
4300                    * (time_jet.grad_entry * t_entry[j] + time_jet.grad_exit * t_exit[j]);
4301            }
4302            dense_outer_accumulate(
4303                &mut hess_time,
4304                wi * binary.neg_hess_scale * time_jet.neg_hess_entry,
4305                t_entry,
4306            );
4307            dense_outer_accumulate(
4308                &mut hess_time,
4309                wi * binary.neg_hess_scale * time_jet.neg_hess_exit,
4310                t_exit,
4311            );
4312            if binary.outer_scale != 0.0 {
4313                dense_outer_accumulate(
4314                    &mut hess_time,
4315                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_entry,
4316                    t_entry,
4317                );
4318                dense_outer_accumulate(
4319                    &mut hess_time,
4320                    wi * binary.outer_scale * time_jet.grad_exit * time_jet.grad_exit,
4321                    t_exit,
4322                );
4323                dense_symmetric_cross_accumulate(
4324                    &mut hess_time,
4325                    wi * binary.outer_scale * time_jet.grad_entry * time_jet.grad_exit,
4326                    t_entry,
4327                    t_exit,
4328                );
4329            }
4330        }
4331
4332        Ok(FamilyEvaluation {
4333            log_likelihood: ll,
4334            blockworking_sets: vec![
4335                BlockWorkingSet::ExactNewton {
4336                    gradient: grad_time,
4337                    hessian: SymmetricMatrix::Dense(hess_time),
4338                },
4339                BlockWorkingSet::ExactNewton {
4340                    gradient: grad_mean,
4341                    hessian: SymmetricMatrix::Dense(hess_mean),
4342                },
4343            ],
4344        })
4345    }
4346
4347    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
4348        let (q_entry, q_exit, mu) = self.split_time_eta(block_states)?;
4349        let mut ll = 0.0;
4350        for i in 0..self.event_target.len() {
4351            let wi = self.weights[i];
4352            if wi <= MIN_WEIGHT {
4353                continue;
4354            }
4355            let row = self.build_right_censored_row_at(i, q_entry[i], q_exit[i])?;
4356            let survival_jet =
4357                LatentSurvivalRowJet::evaluate(&self.quadctx, &row, mu[i], self.latent_sd)
4358                    .map_err(|e| format!("LatentBinaryFamily row {i}: {e}"))?;
4359            ll +=
4360                wi * binary_from_log_survival(survival_jet.log_lik, self.event_target[i])?.log_lik;
4361        }
4362        Ok(ll)
4363    }
4364
4365    fn block_linear_constraints(
4366        &self,
4367        _: &[ParameterBlockState],
4368        block_idx: usize,
4369        block_spec: &ParameterBlockSpec,
4370    ) -> Result<Option<LinearInequalityConstraints>, String> {
4371        assert!(!block_spec.name.is_empty());
4372        if block_idx == Self::BLOCK_TIME {
4373            Ok(self.time_linear_constraints.clone())
4374        } else {
4375            Ok(None)
4376        }
4377    }
4378
4379    fn exact_newton_joint_hessian(
4380        &self,
4381        block_states: &[ParameterBlockState],
4382    ) -> Result<Option<Array2<f64>>, String> {
4383        self.evaluate_exact_newton_joint_dense(block_states)
4384            .map(|(_, _, hessian)| Some(hessian))
4385    }
4386
4387    fn exact_newton_joint_hessian_workspace(
4388        &self,
4389        block_states: &[ParameterBlockState],
4390        _: &[ParameterBlockSpec],
4391    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
4392        Ok(Some(Arc::new(LatentBinaryHessianWorkspace::new(
4393            self.clone(),
4394            block_states.to_vec(),
4395        ))))
4396    }
4397
4398    fn exact_newton_joint_gradient_evaluation(
4399        &self,
4400        block_states: &[ParameterBlockState],
4401        _: &[ParameterBlockSpec],
4402    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
4403        self.evaluate_exact_newton_joint_dense(block_states)
4404            .map(|(log_likelihood, gradient, _)| {
4405                Some(ExactNewtonJointGradientEvaluation {
4406                    log_likelihood,
4407                    gradient,
4408                })
4409            })
4410    }
4411
4412    fn exact_newton_joint_hessian_directional_derivative(
4413        &self,
4414        block_states: &[ParameterBlockState],
4415        d_beta_flat: &Array1<f64>,
4416    ) -> Result<Option<Array2<f64>>, String> {
4417        self.exact_newton_joint_hessian_directional_derivative_dense(block_states, d_beta_flat)
4418            .map(Some)
4419    }
4420
4421    fn exact_newton_joint_hessiansecond_directional_derivative(
4422        &self,
4423        block_states: &[ParameterBlockState],
4424        d_beta_u_flat: &Array1<f64>,
4425        d_beta_v_flat: &Array1<f64>,
4426    ) -> Result<Option<Array2<f64>>, String> {
4427        self.exact_newton_joint_hessian_second_directional_derivative_dense(
4428            block_states,
4429            d_beta_u_flat,
4430            d_beta_v_flat,
4431        )
4432        .map(Some)
4433    }
4434
4435    fn requires_joint_outer_hyper_path(&self) -> bool {
4436        true
4437    }
4438}
4439
4440#[cfg(test)]
4441mod tests {
4442    use super::*;
4443    use crate::custom_family::BlockWorkingSet;
4444    use gam_linalg::matrix::DenseDesignMatrix;
4445    use ndarray::array;
4446
4447    fn learnable_sigma_test_family() -> LatentSurvivalFamily {
4448        LatentSurvivalFamily {
4449            event_target: array![1u8, 0u8],
4450            weights: array![1.0, 0.7],
4451            latent_sd_fixed: None,
4452            hazard_loading: HazardLoading::LoadedVsUnloaded,
4453            unloaded_mass_entry: array![0.02, 0.03],
4454            unloaded_mass_exit: array![0.05, 0.08],
4455            unloaded_hazard_exit: array![0.04, 0.0],
4456            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4457            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4458            x_time_derivative_exit: array![[0.8, 0.4], [0.6, 0.5]],
4459            x_time_right: array![[1.3, 0.1], [0.9, 1.0]],
4460            time_offset_right: Array1::zeros(2),
4461            unloaded_mass_right: Array1::zeros(2),
4462            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4463            time_linear_constraints: None,
4464            quadctx: Arc::new(QuadratureContext::new()),
4465        }
4466    }
4467
4468    fn learnable_sigma_test_joint_beta() -> Array1<f64> {
4469        array![0.15, 0.25, 0.1, -0.15, 0.35_f64.ln()]
4470    }
4471
4472    fn survival_stress_test_family(n: usize) -> LatentSurvivalFamily {
4473        LatentSurvivalFamily {
4474            event_target: Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1u8 } else { 0u8 })),
4475            weights: Array1::from_iter((0..n).map(|i| 0.55 + 0.03 * ((i % 7) as f64))),
4476            latent_sd_fixed: None,
4477            hazard_loading: HazardLoading::LoadedVsUnloaded,
4478            unloaded_mass_entry: Array1::from_iter(
4479                (0..n).map(|i| 0.015 + 0.0015 * ((i % 11) as f64)),
4480            ),
4481            unloaded_mass_exit: Array1::from_iter((0..n).map(|i| 0.06 + 0.002 * ((i % 13) as f64))),
4482            unloaded_hazard_exit: Array1::from_iter((0..n).map(|i| {
4483                if i % 4 == 0 {
4484                    0.018 + 0.001 * ((i % 5) as f64)
4485                } else {
4486                    0.0
4487                }
4488            })),
4489            x_time_entry: Array2::from_shape_fn((n, 4), |(i, j)| {
4490                0.2 + 0.03 * ((i + 2 * j) % 9) as f64 - if j == 1 { 0.12 } else { 0.0 }
4491            }),
4492            x_time_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4493                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4494            }),
4495            x_time_derivative_exit: Array2::from_shape_fn((n, 4), |(i, j)| {
4496                0.45 + 0.015 * ((i + 3 * j) % 8) as f64
4497            }),
4498            x_time_right: Array2::from_shape_fn((n, 4), |(i, j)| {
4499                0.35 + 0.025 * ((2 * i + j) % 10) as f64 - if j == 2 { 0.08 } else { 0.0 }
4500            }),
4501            time_offset_right: Array1::zeros(n),
4502            unloaded_mass_right: Array1::zeros(n),
4503            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_shape_fn(
4504                (n, 3),
4505                |(i, j)| 0.1 + 0.04 * ((3 * i + j) % 7) as f64 - if j == 0 { 0.18 } else { 0.0 },
4506            ))),
4507            time_linear_constraints: None,
4508            quadctx: Arc::new(QuadratureContext::new()),
4509        }
4510    }
4511
4512    fn survival_stress_test_joint_beta() -> Array1<f64> {
4513        array![0.18, 0.11, 0.07, 0.13, -0.09, 0.05, 0.12, 0.42_f64.ln()]
4514    }
4515
4516    fn latent_survival_states_from_joint_beta(
4517        family: &LatentSurvivalFamily,
4518        joint_beta: &Array1<f64>,
4519    ) -> Vec<ParameterBlockState> {
4520        let slices = family.joint_slices();
4521        let n = family.event_target.len();
4522        let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4523        let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4524
4525        let mut eta_time = Array1::<f64>::zeros(3 * n);
4526        eta_time
4527            .slice_mut(s![0..n])
4528            .assign(&gam_linalg::faer_ndarray::fast_av(
4529                &family.x_time_entry,
4530                &beta_time,
4531            ));
4532        eta_time
4533            .slice_mut(s![n..2 * n])
4534            .assign(&gam_linalg::faer_ndarray::fast_av(
4535                &family.x_time_exit,
4536                &beta_time,
4537            ));
4538        eta_time
4539            .slice_mut(s![2 * n..3 * n])
4540            .assign(&gam_linalg::faer_ndarray::fast_av(
4541                &family.x_time_derivative_exit,
4542                &beta_time,
4543            ));
4544
4545        let mut states = vec![
4546            ParameterBlockState {
4547                beta: beta_time,
4548                eta: eta_time,
4549            },
4550            ParameterBlockState {
4551                beta: beta_mean.clone(),
4552                eta: family.x_mean.dot(&beta_mean),
4553            },
4554        ];
4555        if let Some(log_sigma) = slices.log_sigma {
4556            let beta_log_sigma = array![joint_beta[log_sigma.start]];
4557            states.push(ParameterBlockState {
4558                beta: beta_log_sigma.clone(),
4559                eta: beta_log_sigma,
4560            });
4561        }
4562        states
4563    }
4564
4565    fn max_relative_array1(left: &Array1<f64>, right: &Array1<f64>) -> f64 {
4566        left.iter()
4567            .zip(right.iter())
4568            .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4569            .fold(0.0_f64, f64::max)
4570    }
4571
4572    fn max_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4573        left.iter()
4574            .zip(right.iter())
4575            .map(|(l, r)| (l - r).abs() / l.abs().max(r.abs()).max(1e-12))
4576            .fold(0.0_f64, f64::max)
4577    }
4578
4579    fn frobenius_relative_array2(left: &Array2<f64>, right: &Array2<f64>) -> f64 {
4580        let mut diff2 = 0.0_f64;
4581        let mut scale2 = 0.0_f64;
4582        for (l, r) in left.iter().zip(right.iter()) {
4583            let d = l - r;
4584            diff2 += d * d;
4585            scale2 += l * l + r * r;
4586        }
4587        diff2.sqrt() / scale2.sqrt().max(1e-12)
4588    }
4589
4590    fn latent_survival_row_loglik_from_primary(
4591        quadctx: &QuadratureContext,
4592        row: &LatentSurvivalRow,
4593        primary: &Array1<f64>,
4594    ) -> f64 {
4595        let q_entry = primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY];
4596        let q_exit = primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT];
4597        let qdot_exit = primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT];
4598        let q_right = primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT];
4599        let mu = primary[LATENT_SURVIVAL_PRIMARY_MU];
4600        let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
4601        latent_survival_row_primary_gradient_hessian(
4602            quadctx, row, q_entry, q_exit, qdot_exit, q_right, mu, sigma, true,
4603        )
4604        .expect("row primary evaluation")
4605        .0
4606    }
4607
4608    fn latent_test_specs(n: usize, block_dims: &[(&str, usize)]) -> Vec<ParameterBlockSpec> {
4609        block_dims
4610            .iter()
4611            .map(|(name, p)| ParameterBlockSpec {
4612                name: (*name).to_string(),
4613                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((n, *p)))),
4614                offset: Array1::zeros(n),
4615                penalties: Vec::new(),
4616                nullspace_dims: Vec::new(),
4617                initial_log_lambdas: Array1::zeros(0),
4618                initial_beta: None,
4619                gauge_priority: 100,
4620                jacobian_callback: None,
4621                stacked_design: None,
4622                stacked_offset: None,
4623            })
4624            .collect()
4625    }
4626
4627    fn fixed_sigma_binary_test_family() -> LatentBinaryFamily {
4628        LatentBinaryFamily {
4629            event_target: array![1u8, 0u8],
4630            weights: array![1.0, 0.7],
4631            latent_sd: 0.35,
4632            hazard_loading: HazardLoading::LoadedVsUnloaded,
4633            unloaded_mass_entry: array![0.02, 0.03],
4634            unloaded_mass_exit: array![0.05, 0.08],
4635            x_time_entry: array![[1.0, -0.2], [0.4, 0.7]],
4636            x_time_exit: array![[1.3, 0.1], [0.9, 1.0]],
4637            x_mean: DesignMatrix::Dense(DenseDesignMatrix::from(array![[1.0, -0.3], [0.2, 0.9]])),
4638            time_linear_constraints: None,
4639            quadctx: Arc::new(QuadratureContext::new()),
4640        }
4641    }
4642
4643    fn latent_binary_states_from_joint_beta(
4644        family: &LatentBinaryFamily,
4645        joint_beta: &Array1<f64>,
4646    ) -> Vec<ParameterBlockState> {
4647        let slices = family.joint_slices();
4648        let n = family.event_target.len();
4649        let beta_time = joint_beta.slice(s![slices.time.clone()]).to_owned();
4650        let beta_mean = joint_beta.slice(s![slices.mean.clone()]).to_owned();
4651
4652        let mut eta_time = Array1::<f64>::zeros(3 * n);
4653        eta_time
4654            .slice_mut(s![0..n])
4655            .assign(&gam_linalg::faer_ndarray::fast_av(
4656                &family.x_time_entry,
4657                &beta_time,
4658            ));
4659        eta_time
4660            .slice_mut(s![n..2 * n])
4661            .assign(&gam_linalg::faer_ndarray::fast_av(
4662                &family.x_time_exit,
4663                &beta_time,
4664            ));
4665
4666        vec![
4667            ParameterBlockState {
4668                beta: beta_time,
4669                eta: eta_time,
4670            },
4671            ParameterBlockState {
4672                beta: beta_mean.clone(),
4673                eta: family.x_mean.dot(&beta_mean),
4674            },
4675        ]
4676    }
4677
4678    // --- shared latent-interval validation engine: parity / contract tests ---
4679
4680    use crate::survival::location_scale::{TimeBlockInput, TimeBlockMonotonicity};
4681
4682    /// Minimal, structurally valid `TimeBlockInput` for `n` rows and `p_time`
4683    /// columns, used to exercise the shared validation driver without standing
4684    /// up a full term-collection design.
4685    fn validation_time_block(n: usize, p_time: usize) -> TimeBlockInput {
4686        let design = |fill: f64| {
4687            DesignMatrix::Dense(DenseDesignMatrix::from(Array2::from_elem(
4688                (n, p_time),
4689                fill,
4690            )))
4691        };
4692        TimeBlockInput {
4693            design_entry: design(0.1),
4694            design_exit: design(0.2),
4695            design_derivative_exit: design(0.3),
4696            offset_entry: Array1::zeros(n),
4697            offset_exit: Array1::zeros(n),
4698            derivative_offset_exit: Array1::zeros(n),
4699            time_monotonicity: TimeBlockMonotonicity::EnforcedByCoordinateCone,
4700            penalties: Vec::new(),
4701            nullspace_dims: Vec::new(),
4702            initial_log_lambdas: None,
4703            initial_beta: None,
4704        }
4705    }
4706
4707    fn empty_meanspec() -> TermCollectionSpec {
4708        TermCollectionSpec {
4709            linear_terms: Vec::new(),
4710            random_effect_terms: Vec::new(),
4711            smooth_terms: Vec::new(),
4712        }
4713    }
4714
4715    /// A valid two-row latent-survival term spec (one exact event under loaded
4716    /// hazard, one right-censored row).
4717    fn valid_survival_spec(n: usize, p_time: usize) -> LatentSurvivalTermSpec {
4718        LatentSurvivalTermSpec {
4719            age_entry: Array1::zeros(n),
4720            age_exit: Array1::from_elem(n, 1.0),
4721            event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4722            weights: Array1::from_elem(n, 1.0),
4723            derivative_guard: 0.0,
4724            time_block: validation_time_block(n, p_time),
4725            time_design_right: None,
4726            time_offset_right: None,
4727            unloaded_mass_entry: Array1::from_elem(n, 0.01),
4728            unloaded_mass_exit: Array1::from_elem(n, 0.05),
4729            unloaded_mass_right: Array1::zeros(0),
4730            unloaded_hazard_exit: Array1::from_elem(n, 0.02),
4731            meanspec: empty_meanspec(),
4732            mean_offset: Array1::zeros(n),
4733        }
4734    }
4735
4736    /// A valid latent-binary term spec mirroring `valid_survival_spec` but
4737    /// without the per-row unloaded hazard.
4738    fn valid_binary_spec(n: usize, p_time: usize) -> LatentBinaryTermSpec {
4739        LatentBinaryTermSpec {
4740            age_entry: Array1::zeros(n),
4741            age_exit: Array1::from_elem(n, 1.0),
4742            event_target: Array1::from_shape_fn(n, |i| (i % 2) as u8),
4743            weights: Array1::from_elem(n, 1.0),
4744            derivative_guard: 0.0,
4745            time_block: validation_time_block(n, p_time),
4746            unloaded_mass_entry: Array1::from_elem(n, 0.01),
4747            unloaded_mass_exit: Array1::from_elem(n, 0.05),
4748            meanspec: empty_meanspec(),
4749            mean_offset: Array1::zeros(n),
4750        }
4751    }
4752
4753    fn loaded_frailty() -> FrailtySpec {
4754        FrailtySpec::HazardMultiplier {
4755            sigma_fixed: Some(0.3),
4756            loading: HazardLoading::LoadedVsUnloaded,
4757        }
4758    }
4759
    /// Parity test for the two validation adapters.
    ///
    /// Both adapters (`validate_latent_survival_inputs` and
    /// `validate_latent_binary_inputs`) route through the shared
    /// `validate_latent_interval_inputs` engine. Each must still emit its own
    /// context prefix (`latent-survival` / `latent-binary`). For the
    /// size-mismatch and unloaded-decomposition diagnostics, each must also emit
    /// its own message variant: hazard-aware for survival, mass-only for binary.
    ///
    /// Several assertions pin the full message text byte-for-byte. Any rewording
    /// in the shared engine or in either adapter is therefore caught here, which
    /// guards against drift between the pre- and post-unification error text.
    #[test]
    fn latent_interval_validation_parity_across_models() {
        let n = 2;
        let p_time = 2;
        let data = Array2::<f64>::zeros((n, 3));

        // 1. A clean spec validates and round-trips the resolved sigma.
        //    Survival keeps the (possibly learnable) Option; binary unwraps to
        //    the fixed scalar.
        let surv_sigma = validate_latent_survival_inputs(
            data.view(),
            &valid_survival_spec(n, p_time),
            &loaded_frailty(),
        )
        .expect("valid survival spec must validate");
        assert_eq!(surv_sigma, Some(0.3));
        let bin_sigma = validate_latent_binary_inputs(
            data.view(),
            &valid_binary_spec(n, p_time),
            &loaded_frailty(),
        )
        .expect("valid binary spec must validate");
        assert_eq!(bin_sigma, 0.3);

        // 2. Empty data (zero rows): shared driver, per-model context prefix.
        let empty = Array2::<f64>::zeros((0, 3));
        let surv_empty = validate_latent_survival_inputs(
            empty.view(),
            &valid_survival_spec(n, p_time),
            &loaded_frailty(),
        )
        .expect_err("empty data must be rejected");
        assert_eq!(
            surv_empty.to_string(),
            "latent-survival requires a non-empty dataset"
        );
        let bin_empty = validate_latent_binary_inputs(
            empty.view(),
            &valid_binary_spec(n, p_time),
            &loaded_frailty(),
        )
        .expect_err("empty data must be rejected");
        assert_eq!(
            bin_empty.to_string(),
            "latent-binary requires a non-empty dataset"
        );

        // 3. Size mismatch, triggered by a weight vector of length `n + 1`.
        //    Survival's message carries `unloaded_hazard=` and binary's does not,
        //    because only the survival spec has a per-row unloaded hazard.
        //    (Step 4 shows the other hazard-vs-mass divergence.)
        let mut surv_bad = valid_survival_spec(n, p_time);
        surv_bad.weights = Array1::from_elem(n + 1, 1.0);
        let surv_size = validate_latent_survival_inputs(data.view(), &surv_bad, &loaded_frailty())
            .expect_err("size mismatch must be rejected");
        let surv_msg = surv_size.to_string();
        assert!(
            surv_msg.starts_with("latent-survival size mismatch")
                && surv_msg.contains("unloaded_hazard="),
            "survival size-mismatch message must include unloaded_hazard: {surv_msg}"
        );
        let mut bin_bad = valid_binary_spec(n, p_time);
        bin_bad.weights = Array1::from_elem(n + 1, 1.0);
        let bin_size = validate_latent_binary_inputs(data.view(), &bin_bad, &loaded_frailty())
            .expect_err("size mismatch must be rejected");
        let bin_msg = bin_size.to_string();
        assert!(
            bin_msg.starts_with("latent-binary size mismatch")
                && !bin_msg.contains("unloaded_hazard"),
            "binary size-mismatch message must omit unloaded_hazard: {bin_msg}"
        );

        // 4. Invalid unloaded decomposition: survival reports `exit_hazard=`,
        //    binary reports only the two masses. Rows in messages are 1-based,
        //    so element index 0 is reported as "row 1".
        let mut surv_neg_hazard = valid_survival_spec(n, p_time);
        surv_neg_hazard.unloaded_hazard_exit[0] = -1.0;
        let surv_decomp =
            validate_latent_survival_inputs(data.view(), &surv_neg_hazard, &loaded_frailty())
                .expect_err("negative unloaded hazard must be rejected");
        assert_eq!(
            surv_decomp.to_string(),
            "latent-survival row 1 has invalid unloaded hazard decomposition: entry_mass=0.01, exit_mass=0.05, exit_hazard=-1"
        );
        let mut bin_bad_mass = valid_binary_spec(n, p_time);
        bin_bad_mass.unloaded_mass_exit[0] = 0.0; // exit < entry
        let bin_decomp =
            validate_latent_binary_inputs(data.view(), &bin_bad_mass, &loaded_frailty())
                .expect_err("non-monotone unloaded mass must be rejected");
        assert_eq!(
            bin_decomp.to_string(),
            "latent-binary row 1 has invalid unloaded mass decomposition: entry_mass=0.01, exit_mass=0"
        );

        // 5. Per-row diagnostics come from one engine, so identical invalid
        //    input yields identical text apart from the prefix. Only the
        //    event-target check is exercised here; element index 1 is reported
        //    as "row 2".
        let mut surv_event = valid_survival_spec(n, p_time);
        surv_event.event_target[1] = 7;
        let surv_event_err =
            validate_latent_survival_inputs(data.view(), &surv_event, &loaded_frailty())
                .expect_err("invalid event target must be rejected");
        assert_eq!(
            surv_event_err.to_string(),
            "latent-survival row 2 has invalid event target 7; expected 0 or 1"
        );
        let mut bin_event = valid_binary_spec(n, p_time);
        bin_event.event_target[1] = 7;
        let bin_event_err =
            validate_latent_binary_inputs(data.view(), &bin_event, &loaded_frailty())
                .expect_err("invalid event target must be rejected");
        assert_eq!(
            bin_event_err.to_string(),
            "latent-binary row 2 has invalid event target 7; expected 0 or 1"
        );

        // 6. Frailty policy divergence: survival accepts a learnable scale
        //    (`sigma_fixed = None` ⇒ `Ok(None)`), binary rejects it.
        let learnable = FrailtySpec::HazardMultiplier {
            sigma_fixed: None,
            loading: HazardLoading::LoadedVsUnloaded,
        };
        let surv_learnable = validate_latent_survival_inputs(
            data.view(),
            &valid_survival_spec(n, p_time),
            &learnable,
        )
        .expect("survival accepts a learnable latent scale");
        assert_eq!(surv_learnable, None);
        let bin_learnable =
            validate_latent_binary_inputs(data.view(), &valid_binary_spec(n, p_time), &learnable)
                .expect_err("binary requires a fixed latent scale");
        assert_eq!(
            bin_learnable.to_string(),
            "latent-binary currently requires a fixed hazard-multiplier sigma"
        );

        // 7. The time-block shape check is owned by the shared driver: a
        //    column-count mismatch (entry design with `p_time + 1` columns) is
        //    reported with the per-model prefix. Only the survival adapter is
        //    exercised for this case.
        let mut surv_time_bad = valid_survival_spec(n, p_time);
        surv_time_bad.time_block.design_entry = DesignMatrix::Dense(DenseDesignMatrix::from(
            Array2::from_elem((n, p_time + 1), 0.1),
        ));
        let surv_time_err =
            validate_latent_survival_inputs(data.view(), &surv_time_bad, &loaded_frailty())
                .expect_err("time block column mismatch must be rejected");
        assert!(
            surv_time_err
                .to_string()
                .starts_with("latent-survival time block column mismatch"),
            "unexpected survival time-block message: {surv_time_err}"
        );
    }
4916
4917    #[test]
4918    fn latent_survival_coefficient_cost_uses_joint_coupled_formula() {
4919        // `evaluate_exact_newton_joint_dense` builds a fully dense joint
4920        // Hessian over (Σ p_b)² across the time, mean, and log-σ blocks via
4921        // per-row pullback of the latent-survival primary kernel. The override
4922        // must reflect that joint coupling rather than the block-diagonal
4923        // default.
4924        let family = learnable_sigma_test_family();
4925        let n = family.event_target.len() as u64;
4926        let p_time = 2u64;
4927        let p_mean = 2u64;
4928        let p_log_sigma = 1u64;
4929        let specs = vec![
4930            ParameterBlockSpec {
4931                name: "time".to_string(),
4932                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4933                    n as usize,
4934                    p_time as usize,
4935                )))),
4936                offset: Array1::zeros(n as usize),
4937                penalties: Vec::new(),
4938                nullspace_dims: Vec::new(),
4939                initial_log_lambdas: Array1::zeros(0),
4940                initial_beta: None,
4941                gauge_priority: 100,
4942                jacobian_callback: None,
4943                stacked_design: None,
4944                stacked_offset: None,
4945            },
4946            ParameterBlockSpec {
4947                name: "mean".to_string(),
4948                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4949                    n as usize,
4950                    p_mean as usize,
4951                )))),
4952                offset: Array1::zeros(n as usize),
4953                penalties: Vec::new(),
4954                nullspace_dims: Vec::new(),
4955                initial_log_lambdas: Array1::zeros(0),
4956                initial_beta: None,
4957                gauge_priority: 100,
4958                jacobian_callback: None,
4959                stacked_design: None,
4960                stacked_offset: None,
4961            },
4962            ParameterBlockSpec {
4963                name: "log_sigma".to_string(),
4964                design: DesignMatrix::Dense(DenseDesignMatrix::from(Array2::zeros((
4965                    n as usize,
4966                    p_log_sigma as usize,
4967                )))),
4968                offset: Array1::zeros(n as usize),
4969                penalties: Vec::new(),
4970                nullspace_dims: Vec::new(),
4971                initial_log_lambdas: Array1::zeros(0),
4972                initial_beta: None,
4973                gauge_priority: 100,
4974                jacobian_callback: None,
4975                stacked_design: None,
4976                stacked_offset: None,
4977            },
4978        ];
4979        let p_total = p_time + p_mean + p_log_sigma;
4980        let expected_joint = n * p_total * p_total;
4981        let expected_block_diag =
4982            n * (p_time * p_time + p_mean * p_mean + p_log_sigma * p_log_sigma);
4983        assert_eq!(family.coefficient_hessian_cost(&specs), expected_joint);
4984        // Cross-block fill (time–mean, time–log_sigma, mean–log_sigma) makes
4985        // the joint cost strictly larger than the block-diagonal default.
4986        assert!(expected_joint > expected_block_diag);
4987    }
4988
4989    #[test]
4990    fn latent_family_planner_keeps_outer_hessian_at_large_n() {
4991        use crate::custom_family::custom_family_outer_derivatives;
4992        use gam_problem::{DeclaredHessianForm, Derivative};
4993
4994        let options = BlockwiseFitOptions::default();
4995        let large_n = 50_001;
4996
4997        let survival = learnable_sigma_test_family();
4998        let survival_specs =
4999            latent_test_specs(large_n, &[("time", 2), ("mean", 2), ("log_sigma", 1)]);
5000        let (surv_grad, surv_hess) =
5001            custom_family_outer_derivatives(&survival, &survival_specs, &options);
5002        assert_eq!(surv_grad, Derivative::Analytic);
5003        assert_eq!(surv_hess, DeclaredHessianForm::Either);
5004
5005        let binary = fixed_sigma_binary_test_family();
5006        let binary_specs = latent_test_specs(large_n, &[("time", 2), ("mean", 2)]);
5007        let (bin_grad, bin_hess) =
5008            custom_family_outer_derivatives(&binary, &binary_specs, &options);
5009        assert_eq!(bin_grad, Derivative::Analytic);
5010        assert_eq!(bin_hess, DeclaredHessianForm::Either);
5011    }
5012
5013    #[test]
5014    fn latent_families_arm_self_vanishing_levenberg_on_ill_conditioning() {
5015        // Regression guard for #1108. The interval-censored row contribution
5016        // `ℓ = log[S(L) − S(R)]` is the log of a DIFFERENCE of survival kernels and
5017        // is legitimately NON-concave (indefinite per-row Hessian) away from the
5018        // optimum; on the constrained (monotone-cone) coupled time block this can
5019        // make the penalized joint Hessian full-rank yet indefinite / severely
5020        // ill-conditioned at the cold-start seed. The coupled exact-joint inner
5021        // solver only adds the self-vanishing Levenberg–Marquardt diagonal floor
5022        // (the cure for a full-rank ill-conditioned reflected QP that otherwise
5023        // oscillates the trust region into a snapshot-less stall) when the family
5024        // opts in via `levenberg_on_ill_conditioning()`. Both latent families MUST
5025        // keep this armed (the default is `false`, which leaves the interval inner
5026        // solve diverging with "exited the joint Newton path before convergence").
5027        assert!(
5028            learnable_sigma_test_family().levenberg_on_ill_conditioning(),
5029            "LatentSurvivalFamily must arm the self-vanishing Levenberg floor so the \
5030             indefinite interval-censored joint Hessian converges (see #1108)"
5031        );
5032        assert!(
5033            fixed_sigma_binary_test_family().levenberg_on_ill_conditioning(),
5034            "LatentBinaryFamily must arm the self-vanishing Levenberg floor on its \
5035             constrained coupled time block (see #1108)"
5036        );
5037    }
5038
    /// Finite-difference checks for the latent-binary exact-Newton surfaces:
    ///
    /// 1. Every column of the dense joint Hessian equals minus the central FD
    ///    (step `1e-6`) of the joint gradient along that coordinate.
    /// 2. The Hessian workspace's matrix-vector product matches the dense
    ///    `H · v` to `1e-12`.
    /// 3. The workspace's directional derivative `dH[v]` matches a central FD of
    ///    the dense Hessian along `v` (step `1e-5`).
    #[test]
    fn latent_binary_exact_joint_hessian_and_workspace_matvec_match_fd() {
        let family = fixed_sigma_binary_test_family();
        // Four coefficients: the fixture has 2-column time designs and a
        // 2-column mean design, and no log-sigma block (sigma is fixed).
        let beta = array![0.15, 0.25, 0.1, -0.15];
        let states = latent_binary_states_from_joint_beta(&family, &beta);
        let h = 1e-6;

        let analytic_hessian = family
            .exact_newton_joint_hessian(&states)
            .expect("analytic latent binary joint hessian evaluation")
            .expect("latent binary should expose exact joint hessian");

        // Check 1: Hessian columns vs. negated central FD of the joint gradient.
        for j in 0..beta.len() {
            let mut beta_plus = beta.clone();
            beta_plus[j] += h;
            let gradient_plus = family
                .exact_newton_joint_gradient_evaluation(
                    &latent_binary_states_from_joint_beta(&family, &beta_plus),
                    &[],
                )
                .expect("joint gradient plus")
                .expect("joint gradient should exist")
                .gradient;

            let mut beta_minus = beta.clone();
            beta_minus[j] -= h;
            let gradient_minus = family
                .exact_newton_joint_gradient_evaluation(
                    &latent_binary_states_from_joint_beta(&family, &beta_minus),
                    &[],
                )
                .expect("joint gradient minus")
                .expect("joint gradient should exist")
                .gradient;

            // The analytic Hessian is compared against MINUS the gradient's
            // Jacobian, i.e. it is reported with the opposite sign convention.
            let fd_column = -((&gradient_plus - &gradient_minus) / (2.0 * h));
            let analytic_column = analytic_hessian.column(j).to_owned();
            let rel = max_relative_array1(&analytic_column, &fd_column);
            assert!(
                rel < 5e-4,
                "latent binary joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={fd_column:?}"
            );
        }

        // Check 2: matrix-free HVP must agree with the dense Hessian product.
        let workspace = family
            .exact_newton_joint_hessian_workspace(&states, &[])
            .expect("latent binary hessian workspace")
            .expect("workspace should exist");
        let direction = array![0.4, -0.2, 0.3, 0.1];
        let hv = workspace
            .hessian_matvec(&direction)
            .expect("workspace matvec")
            .expect("workspace should support matvec");
        let dense_hv = analytic_hessian.dot(&direction);
        assert!(
            max_relative_array1(&hv, &dense_hv) < 1e-12,
            "latent binary workspace HVP mismatch: hv={hv:?}, dense={dense_hv:?}"
        );

        // Check 3: directional derivative of the Hessian vs. a central FD of the
        // dense Hessian along `direction`.
        let dh = workspace
            .directional_derivative(&direction)
            .expect("workspace dH")
            .expect("workspace should support dH");
        let fd_step = 1e-5;
        let h_plus = family
            .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
                &family,
                &(beta.clone() + &(fd_step * &direction)),
            ))
            .expect("hessian plus")
            .expect("hessian plus should exist");
        // `beta` is consumed here; it is not used afterwards.
        let h_minus = family
            .exact_newton_joint_hessian(&latent_binary_states_from_joint_beta(
                &family,
                &(beta - &(fd_step * &direction)),
            ))
            .expect("hessian minus")
            .expect("hessian minus should exist");
        let fd_dh = (&h_plus - &h_minus) / (2.0 * fd_step);
        assert!(
            max_relative_array2(&dh, &fd_dh) < 2e-4,
            "latent binary workspace dH mismatch: dh={dh:?}, fd={fd_dh:?}"
        );
    }
5123
5124    #[test]
5125    fn latent_survival_learnable_sigma_block_matches_family_fd() {
5126        let family = learnable_sigma_test_family();
5127        let beta = learnable_sigma_test_joint_beta();
5128        let states = latent_survival_states_from_joint_beta(&family, &beta);
5129        let slices = family.joint_slices();
5130        let sigma_idx = slices
5131            .log_sigma
5132            .as_ref()
5133            .expect("learnable sigma test family should expose log_sigma")
5134            .start;
5135        let h = 2e-4;
5136
5137        let eval = family
5138            .evaluate(&states)
5139            .expect("learnable latent survival evaluation");
5140        let joint_gradient = family
5141            .exact_newton_joint_gradient_evaluation(&states, &[])
5142            .expect("joint gradient evaluation")
5143            .expect("joint gradient should exist")
5144            .gradient;
5145        let joint_hessian = family
5146            .exact_newton_joint_hessian(&states)
5147            .expect("joint hessian evaluation")
5148            .expect("joint hessian should exist");
5149        assert_eq!(eval.blockworking_sets.len(), 3);
5150
5151        let (block_grad, block_neg_hess) =
5152            match &eval.blockworking_sets[LatentSurvivalFamily::BLOCK_LOG_SIGMA] {
5153                BlockWorkingSet::ExactNewton { gradient, hessian } => {
5154                    let neg_hess = match hessian {
5155                        SymmetricMatrix::Dense(mat) => mat[[0, 0]],
5156                        _ => panic!("log_sigma block should use a dense exact-Newton Hessian"),
5157                    };
5158                    (gradient[0], neg_hess)
5159                }
5160                _ => panic!("log_sigma block should use ExactNewton"),
5161            };
5162
5163        assert!((block_grad - joint_gradient[sigma_idx]).abs() < 1e-12);
5164        assert!((block_neg_hess - joint_hessian[[sigma_idx, sigma_idx]]).abs() < 1e-12);
5165
5166        let mut beta_plus = beta.clone();
5167        beta_plus[sigma_idx] += h;
5168        let ll_plus = family
5169            .log_likelihood_only(&latent_survival_states_from_joint_beta(&family, &beta_plus))
5170            .expect("ll plus");
5171        let ll_0 = family.log_likelihood_only(&states).expect("ll base");
5172        let mut beta_minus = beta.clone();
5173        beta_minus[sigma_idx] -= h;
5174        let ll_minus = family
5175            .log_likelihood_only(&latent_survival_states_from_joint_beta(
5176                &family,
5177                &beta_minus,
5178            ))
5179            .expect("ll minus");
5180
5181        let fd_grad = (ll_plus - ll_minus) / (2.0 * h);
5182        let fd_neg_hess = -(ll_plus - 2.0 * ll_0 + ll_minus) / (h * h);
5183        assert!(
5184            (joint_gradient[sigma_idx] - fd_grad).abs()
5185                / joint_gradient[sigma_idx]
5186                    .abs()
5187                    .max(fd_grad.abs())
5188                    .max(1e-12)
5189                < 2e-3,
5190            "family log_sigma grad={}, fd={fd_grad}",
5191            joint_gradient[sigma_idx]
5192        );
5193        assert!(
5194            (joint_hessian[[sigma_idx, sigma_idx]] - fd_neg_hess).abs()
5195                / joint_hessian[[sigma_idx, sigma_idx]]
5196                    .abs()
5197                    .max(fd_neg_hess.abs())
5198                    .max(1e-10)
5199                < 2e-2,
5200            "family log_sigma neg_hess={}, fd={fd_neg_hess}",
5201            joint_hessian[[sigma_idx, sigma_idx]]
5202        );
5203    }
5204
5205    #[test]
5206    fn latent_survival_exact_joint_hessian_matches_gradient_fd() {
5207        let family = learnable_sigma_test_family();
5208        let beta = learnable_sigma_test_joint_beta();
5209        let states = latent_survival_states_from_joint_beta(&family, &beta);
5210        let h = 1e-6;
5211
5212        let analytic_hessian = family
5213            .exact_newton_joint_hessian(&states)
5214            .expect("analytic joint hessian evaluation")
5215            .expect("latent survival should expose exact joint hessian");
5216
5217        for j in 0..beta.len() {
5218            let mut beta_plus = beta.clone();
5219            beta_plus[j] += h;
5220            let gradient_plus = family
5221                .exact_newton_joint_gradient_evaluation(
5222                    &latent_survival_states_from_joint_beta(&family, &beta_plus),
5223                    &[],
5224                )
5225                .expect("joint gradient plus")
5226                .expect("joint gradient should exist")
5227                .gradient;
5228
5229            let mut beta_minus = beta.clone();
5230            beta_minus[j] -= h;
5231            let gradient_minus = family
5232                .exact_newton_joint_gradient_evaluation(
5233                    &latent_survival_states_from_joint_beta(&family, &beta_minus),
5234                    &[],
5235                )
5236                .expect("joint gradient minus")
5237                .expect("joint gradient should exist")
5238                .gradient;
5239
5240            let fd_column = (&gradient_plus - &gradient_minus) / (2.0 * h);
5241            let analytic_column = analytic_hessian.column(j).to_owned();
5242            let rel = max_relative_array1(&analytic_column, &(-fd_column));
5243            assert!(
5244                rel < 5e-4,
5245                "joint Hessian column {j} mismatch: rel={rel}, analytic={analytic_column:?}, fd={:?}",
5246                -((&gradient_plus - &gradient_minus) / (2.0 * h))
5247            );
5248        }
5249    }
5250
5251    /// FD check for `LatentSurvivalFamily::offset_channel_residuals`: each
5252    /// channel residual sums to `∂(−ℓ)/∂o_ch` for a uniform additive offset on
5253    /// that time channel (the baseline-θ enters only through these offsets).
5254    /// `o_ch` shifts `eta_time[ch-slice]` uniformly, so `Σ_i r^ch_i` is exactly
5255    /// the directional derivative of `−ℓ` along a constant offset on channel ch.
5256    /// This validates the envelope-theorem latent baseline-θ gradient primitive.
5257    #[test]
5258    fn latent_survival_offset_channel_residuals_match_finite_difference() {
5259        let family = survival_stress_test_family(24);
5260        let beta = survival_stress_test_joint_beta();
5261        let states = latent_survival_states_from_joint_beta(&family, &beta);
5262        let n = family.event_target.len();
5263
5264        let residuals = family
5265            .offset_channel_residuals(&states)
5266            .expect("offset channel residuals");
5267        let sum_entry: f64 = residuals.entry.sum();
5268        let sum_exit: f64 = residuals.exit.sum();
5269        let sum_deriv: f64 = residuals.derivative.sum();
5270
5271        // `−ℓ` after shifting one time channel's eta by a constant δ.
5272        let neg_ll_with_offset = |channel: usize, delta: f64| -> f64 {
5273            let mut shifted = states.clone();
5274            let slice = match channel {
5275                0 => s![0..n],
5276                1 => s![n..2 * n],
5277                2 => s![2 * n..3 * n],
5278                _ => unreachable!(),
5279            };
5280            shifted[LatentSurvivalFamily::BLOCK_TIME]
5281                .eta
5282                .slice_mut(slice)
5283                .mapv_inplace(|v| v + delta);
5284            let (ll, _) = family
5285                .evaluate_exact_newton_joint_gradient_dense(&shifted)
5286                .expect("shifted joint gradient evaluation");
5287            -ll
5288        };
5289
5290        let h = 1e-6;
5291        let fd_entry = (neg_ll_with_offset(0, h) - neg_ll_with_offset(0, -h)) / (2.0 * h);
5292        let fd_exit = (neg_ll_with_offset(1, h) - neg_ll_with_offset(1, -h)) / (2.0 * h);
5293        let fd_deriv = (neg_ll_with_offset(2, h) - neg_ll_with_offset(2, -h)) / (2.0 * h);
5294
5295        assert!(
5296            (sum_entry - fd_entry).abs() <= 1e-5 * fd_entry.abs().max(1.0),
5297            "entry-channel residual sum mismatch: analytic={sum_entry}, fd={fd_entry}"
5298        );
5299        assert!(
5300            (sum_exit - fd_exit).abs() <= 1e-5 * fd_exit.abs().max(1.0),
5301            "exit-channel residual sum mismatch: analytic={sum_exit}, fd={fd_exit}"
5302        );
5303        assert!(
5304            (sum_deriv - fd_deriv).abs() <= 1e-5 * fd_deriv.abs().max(1.0),
5305            "derivative-channel residual sum mismatch: analytic={sum_deriv}, fd={fd_deriv}"
5306        );
5307    }
5308
5309    #[test]
5310    fn latent_survival_exact_joint_parallel_stress_is_repeatable() {
5311        let family = survival_stress_test_family(96);
5312        let beta = survival_stress_test_joint_beta();
5313        let states = latent_survival_states_from_joint_beta(&family, &beta);
5314        let direction_u = array![0.03, -0.02, 0.01, 0.04, -0.015, 0.025, -0.005, 0.02];
5315        let direction_v = array![-0.01, 0.035, -0.025, 0.015, 0.02, -0.01, 0.03, -0.015];
5316
5317        let (ll_a, grad_a) = family
5318            .evaluate_exact_newton_joint_gradient_dense(&states)
5319            .expect("stress joint gradient evaluation");
5320        let (ll_b, grad_b) = family
5321            .evaluate_exact_newton_joint_gradient_dense(&states)
5322            .expect("repeat stress joint gradient evaluation");
5323        assert_eq!(ll_a.to_bits(), ll_b.to_bits());
5324        assert_eq!(grad_a, grad_b);
5325
5326        let (joint_ll_a, joint_grad_a, hess_a) = family
5327            .evaluate_exact_newton_joint_dense(&states)
5328            .expect("stress joint dense evaluation");
5329        let (joint_ll_b, joint_grad_b, hess_b) = family
5330            .evaluate_exact_newton_joint_dense(&states)
5331            .expect("repeat stress joint dense evaluation");
5332        assert_eq!(joint_ll_a.to_bits(), joint_ll_b.to_bits());
5333        assert_eq!(joint_grad_a, joint_grad_b);
5334        assert_eq!(hess_a, hess_b);
5335        assert!(hess_a.iter().all(|value| value.is_finite()));
5336        assert!(max_relative_array2(&hess_a, &hess_a.t().to_owned()) < 1e-12);
5337
5338        let dh_a = family
5339            .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
5340            .expect("stress joint dH evaluation");
5341        let dh_b = family
5342            .exact_newton_joint_hessian_directional_derivative_dense(&states, &direction_u)
5343            .expect("repeat stress joint dH evaluation");
5344        assert_eq!(dh_a, dh_b);
5345        assert!(dh_a.iter().all(|value| value.is_finite()));
5346        assert!(max_relative_array2(&dh_a, &dh_a.t().to_owned()) < 1e-12);
5347
5348        let d2h_a = family
5349            .exact_newton_joint_hessian_second_directional_derivative_dense(
5350                &states,
5351                &direction_u,
5352                &direction_v,
5353            )
5354            .expect("stress joint d2H evaluation");
5355        let d2h_b = family
5356            .exact_newton_joint_hessian_second_directional_derivative_dense(
5357                &states,
5358                &direction_u,
5359                &direction_v,
5360            )
5361            .expect("repeat stress joint d2H evaluation");
5362        assert_eq!(d2h_a, d2h_b);
5363        assert!(d2h_a.iter().all(|value| value.is_finite()));
5364        assert!(max_relative_array2(&d2h_a, &d2h_a.t().to_owned()) < 1e-12);
5365    }
5366
5367    #[test]
5368    fn latent_survival_exact_joint_dh_matches_hessian_fd() {
5369        let family = learnable_sigma_test_family();
5370        let beta = learnable_sigma_test_joint_beta();
5371        let states = latent_survival_states_from_joint_beta(&family, &beta);
5372        let h = 2e-4;
5373        let direction = array![0.07, -0.03, 0.05, 0.02, -0.04];
5374
5375        let analytic = family
5376            .exact_newton_joint_hessian_directional_derivative(&states, &direction)
5377            .expect("analytic joint dH evaluation")
5378            .expect("latent survival should expose exact joint dH");
5379
5380        let hessian_plus = family
5381            .exact_newton_joint_hessian(&latent_survival_states_from_joint_beta(
5382                &family,
5383                &(beta.clone() + h * &direction),
5384            ))
5385            .expect("joint hessian plus")
5386            .expect("joint hessian should exist");
5387        let hessian_minus = family
5388            .exact_newton_joint_hessian(&latent_survival_states_from_joint_beta(
5389                &family,
5390                &(beta.clone() - h * &direction),
5391            ))
5392            .expect("joint hessian minus")
5393            .expect("joint hessian should exist");
5394
5395        let fd = (&hessian_plus - &hessian_minus) / (2.0 * h);
5396        let rel = frobenius_relative_array2(&analytic, &fd);
5397        assert!(rel < 2e-3, "joint dH mismatch: rel={rel}");
5398    }
5399
5400    #[test]
5401    fn latent_survival_exact_joint_d2h_matches_directional_fd() {
5402        let family = learnable_sigma_test_family();
5403        let beta = learnable_sigma_test_joint_beta();
5404        let states = latent_survival_states_from_joint_beta(&family, &beta);
5405        let h = 5e-4;
5406        let direction_u = array![0.07, -0.03, 0.05, 0.02, -0.04];
5407        let direction_v = array![-0.02, 0.06, -0.01, 0.03, 0.05];
5408
5409        let analytic = family
5410            .exact_newton_joint_hessiansecond_directional_derivative(
5411                &states,
5412                &direction_u,
5413                &direction_v,
5414            )
5415            .expect("analytic joint d2H evaluation")
5416            .expect("latent survival should expose exact joint d2H");
5417        let swapped = family
5418            .exact_newton_joint_hessiansecond_directional_derivative(
5419                &states,
5420                &direction_v,
5421                &direction_u,
5422            )
5423            .expect("swapped analytic joint d2H evaluation")
5424            .expect("latent survival should expose exact joint d2H");
5425        let symmetry_rel = max_relative_array2(&analytic, &swapped);
5426        assert!(
5427            symmetry_rel < 1e-10,
5428            "joint d2H should be symmetric in directions, got rel={symmetry_rel}"
5429        );
5430
5431        let dh_plus = family
5432            .exact_newton_joint_hessian_directional_derivative(
5433                &latent_survival_states_from_joint_beta(
5434                    &family,
5435                    &(beta.clone() + h * &direction_v),
5436                ),
5437                &direction_u,
5438            )
5439            .expect("joint dH plus")
5440            .expect("joint dH should exist");
5441        let dh_minus = family
5442            .exact_newton_joint_hessian_directional_derivative(
5443                &latent_survival_states_from_joint_beta(
5444                    &family,
5445                    &(beta.clone() - h * &direction_v),
5446                ),
5447                &direction_u,
5448            )
5449            .expect("joint dH minus")
5450            .expect("joint dH should exist");
5451
5452        let fd = (&dh_plus - &dh_minus) / (2.0 * h);
5453        let rel = frobenius_relative_array2(&analytic, &fd);
5454        assert!(rel < 2.5e-2, "joint d2H mismatch: rel={rel}");
5455    }
5456
5457    #[test]
5458    fn latent_survival_row_primary_derivatives_match_fd() {
5459        let quadctx = QuadratureContext::new();
5460        let row = LatentSurvivalRow::exact_event(0.35, 1.4, 0.1, 0.45, 0.8, 0.12);
5461        // [q_entry, q_exit, qdot_exit, q_right, mu, log_sigma]. This is an
5462        // exact-event row, so the `q_right` channel is inert (the likelihood
5463        // does not depend on it); the FD loop below confirms its gradient/Hessian
5464        // entries are zero.
5465        let primary = array![
5466            0.35f64.ln(),
5467            1.4f64.ln(),
5468            0.8,
5469            1.6f64.ln(),
5470            -0.2,
5471            0.4f64.ln()
5472        ];
5473        let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5474        let h_grad = 1e-6;
5475        let h_hess = 2e-4;
5476
5477        let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5478            &quadctx,
5479            &row,
5480            primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5481            primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5482            primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5483            primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5484            primary[LATENT_SURVIVAL_PRIMARY_MU],
5485            sigma,
5486            true,
5487        )
5488        .expect("analytic row primary gradient/hessian");
5489
5490        for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5491            let mut plus = primary.clone();
5492            plus[j] += h_grad;
5493            let mut minus = primary.clone();
5494            minus[j] -= h_grad;
5495            let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5496                - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5497                / (2.0 * h_grad);
5498            let rel_grad =
5499                (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5500            assert!(
5501                rel_grad < 2e-4,
5502                "row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5503                gradient[j]
5504            );
5505
5506            for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5507                let mut pp = primary.clone();
5508                pp[j] += h_hess;
5509                pp[k] += h_hess;
5510                let mut pm = primary.clone();
5511                pm[j] += h_hess;
5512                pm[k] -= h_hess;
5513                let mut mp = primary.clone();
5514                mp[j] -= h_hess;
5515                mp[k] += h_hess;
5516                let mut mm = primary.clone();
5517                mm[j] -= h_hess;
5518                mm[k] -= h_hess;
5519                let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5520                    - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5521                    - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5522                    + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5523                    / (4.0 * h_hess * h_hess);
5524                let analytic = neg_hessian[[j, k]];
5525                let abs_err = (analytic - fd_neg_hess).abs();
5526                let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5527                assert!(
5528                    abs_err < 2e-5 || rel < 2e-3,
5529                    "row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5530                );
5531            }
5532        }
5533    }
5534
5535    #[test]
5536    fn latent_survival_interval_row_primary_derivatives_match_fd() {
5537        // Interval-censored row jet `ℓ = log[S(L) − S(R)] − log S(entry)`. The
5538        // dynamic two-state numerator differentiates BOTH boundary masses
5539        // `M_L = exp(q_exit)` (left, `q_exit`) and `M_R = exp(q_right)` (right,
5540        // `q_right`) independently — channels that the static
5541        // `LatentSurvivalRowJet::interval_censored` (μ-only) never exercises. This
5542        // FD-verifies the gradient AND neg-Hessian of the interval contribution
5543        // w.r.t. ALL six primary coordinates (q_entry, q_exit/L, qdot_exit,
5544        // q_right/R, mu, log_sigma) on a WELL-POSED bracket where `S(L) − S(R)` is
5545        // comfortably positive (M_L = e^{−0.4} ≈ 0.67 well below M_R = e^{0.5} ≈
5546        // 1.65, so the survival-mass difference is large and the log-of-a-
5547        // difference curvature is well-conditioned).
5548        let quadctx = QuadratureContext::new();
5549        // Bracket masses: entry < L < R with comfortable gaps.
5550        let q_entry = -1.2_f64; // M_entry = e^{−1.2} ≈ 0.30
5551        let q_exit = -0.4_f64; // L: M_L = e^{−0.4} ≈ 0.67
5552        let q_right = 0.5_f64; // R: M_R = e^{0.5} ≈ 1.65 (> M_L)
5553        let mu = -0.15_f64;
5554        let log_sigma = 0.3_f64; // σ ≈ 1.35
5555        // Small, monotone unloaded masses (entry ≤ left ≤ right); qdot is inert
5556        // for the interval contribution.
5557        let row = LatentSurvivalRow::interval_censored(
5558            q_entry.exp(), // mass_entry (consistency only; jet reads q's)
5559            q_exit.exp(),  // mass_left
5560            q_right.exp(), // mass_right
5561            0.01,          // mass_unloaded_entry
5562            0.02,          // mass_unloaded_left
5563            0.05,          // mass_unloaded_right
5564        );
5565        assert!(matches!(
5566            row.event_type,
5567            LatentSurvivalEventType::IntervalCensored
5568        ));
5569
5570        // [q_entry, q_exit/L, qdot_exit, q_right/R, mu, log_sigma]. qdot_exit is
5571        // inert for interval rows (no hazard-derivative channel); the FD loop
5572        // confirms its gradient/Hessian entries are 0.
5573        let primary = array![q_entry, q_exit, 0.7, q_right, mu, log_sigma];
5574        let sigma = primary[LATENT_SURVIVAL_PRIMARY_LOG_SIGMA].exp();
5575        let h_grad = 1e-6;
5576        let h_hess = 2e-4;
5577
5578        let (_, gradient, neg_hessian) = latent_survival_row_primary_gradient_hessian(
5579            &quadctx,
5580            &row,
5581            primary[LATENT_SURVIVAL_PRIMARY_Q_ENTRY],
5582            primary[LATENT_SURVIVAL_PRIMARY_Q_EXIT],
5583            primary[LATENT_SURVIVAL_PRIMARY_QDOT_EXIT],
5584            primary[LATENT_SURVIVAL_PRIMARY_Q_RIGHT],
5585            primary[LATENT_SURVIVAL_PRIMARY_MU],
5586            sigma,
5587            true,
5588        )
5589        .expect("analytic interval row primary gradient/hessian");
5590
5591        // The interval contribution must be a positive survival-mass difference
5592        // at this bracket, so the value channel is finite.
5593        let value = latent_survival_row_loglik_from_primary(&quadctx, &row, &primary);
5594        assert!(
5595            value.is_finite(),
5596            "interval row log-likelihood must be finite on a well-posed bracket, got {value}"
5597        );
5598
5599        for j in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5600            let mut plus = primary.clone();
5601            plus[j] += h_grad;
5602            let mut minus = primary.clone();
5603            minus[j] -= h_grad;
5604            let fd_grad = (latent_survival_row_loglik_from_primary(&quadctx, &row, &plus)
5605                - latent_survival_row_loglik_from_primary(&quadctx, &row, &minus))
5606                / (2.0 * h_grad);
5607            let rel_grad =
5608                (gradient[j] - fd_grad).abs() / gradient[j].abs().max(fd_grad.abs()).max(1e-12);
5609            assert!(
5610                rel_grad < 2e-4,
5611                "interval row primary grad[{j}] mismatch: analytic={}, fd={fd_grad}, rel={rel_grad}",
5612                gradient[j]
5613            );
5614
5615            for k in 0..LATENT_SURVIVAL_PRIMARY_DIM {
5616                let mut pp = primary.clone();
5617                pp[j] += h_hess;
5618                pp[k] += h_hess;
5619                let mut pm = primary.clone();
5620                pm[j] += h_hess;
5621                pm[k] -= h_hess;
5622                let mut mp = primary.clone();
5623                mp[j] -= h_hess;
5624                mp[k] += h_hess;
5625                let mut mm = primary.clone();
5626                mm[j] -= h_hess;
5627                mm[k] -= h_hess;
5628                let fd_neg_hess = -(latent_survival_row_loglik_from_primary(&quadctx, &row, &pp)
5629                    - latent_survival_row_loglik_from_primary(&quadctx, &row, &pm)
5630                    - latent_survival_row_loglik_from_primary(&quadctx, &row, &mp)
5631                    + latent_survival_row_loglik_from_primary(&quadctx, &row, &mm))
5632                    / (4.0 * h_hess * h_hess);
5633                let analytic = neg_hessian[[j, k]];
5634                let abs_err = (analytic - fd_neg_hess).abs();
5635                let rel = abs_err / analytic.abs().max(fd_neg_hess.abs()).max(1e-10);
5636                assert!(
5637                    abs_err < 5e-5 || rel < 3e-3,
5638                    "interval row primary neg_hess[{j},{k}] mismatch: analytic={analytic}, fd={fd_neg_hess}, abs_err={abs_err}, rel={rel}"
5639                );
5640            }
5641        }
5642    }
5643}