Skip to main content

gam_models/survival/
lognormal_kernel.rs

1//! Shared analytic kernel for latent-variable families with lognormal structure.
2//!
3//! The kernel object `K_{k,m}(μ, σ) := E[exp(k·U − m·exp(U))]`, where
4//! `U ~ N(μ, σ²)`, is the only special function required by all latent families.
5//!
6//! It satisfies exact μ-recurrences (see [`kernel_ratio_jet`]) and the
7//! corresponding heat-equation σ-identities, so fixed-σ latent families reduce
8//! to evaluating kernel bundles at shifted arguments.
9//!
10//! Row likelihoods for binary and survival models are small signed sums of
11//! kernel terms; [`LogKernelSumJet`] evaluates their log-derivatives from
12//! log-space kernel bundles and treats non-positive signed sums as invalid rows.
13
14use crate::model_types::EstimationError;
15use crate::probability::signed_log_sum_exp;
16use crate::quadrature::{
17    IntegratedExpectationMode, QuadratureContext, lognormal_laplace_unit_log_term_shared,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21
22// ─── Typed errors ────────────────────────────────────────────────────────────
23
/// Errors produced by the lognormal-kernel frailty/marginal-slope validators.
///
/// Public boundaries that historically returned `Result<_, String>` continue to
/// do so via `.map_err(|e| e.to_string())`; the `Display` impl reproduces the
/// original error strings byte-for-byte.
///
/// In this file the only producer is
/// [`FrailtySpec::validate_for_marginal_slope_typed`].
#[derive(Debug, Clone)]
pub enum LognormalKernelError {
    /// The chosen frailty modifier is not finite-state exact with the
    /// requested marginal-slope family. `reason` carries the complete
    /// human-readable explanation.
    InvalidSpec { reason: String },
}
35
// Shared error boilerplate for `LognormalKernelError`, listing every variant
// that carries a `reason` string.
// NOTE(review): the macro is defined outside this file. Callers here rely on
// `to_string()` for this error type, so it presumably supplies the `Display`
// (and `std::error::Error`) impls — confirm at the macro definition.
impl_reason_error_boilerplate! {
    LognormalKernelError {
        InvalidSpec,
    }
}
41
42// ─── Frailty specification ───────────────────────────────────────────────────
43
/// How the hazard multiplier frailty loads onto the hazard components.
///
/// Serialized with kebab-case variant names (`full`, `loaded-vs-unloaded`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum HazardLoading {
    /// Frailty multiplies the entire hazard: h(t|U) = exp(U) · h_0(t).
    Full,
    /// Frailty multiplies only the disease-like component; an exogenous
    /// background ("Makeham") component is unloaded:
    ///   h(t|U) = exp(U) · h_loaded(t) + h_unloaded(t).
    /// This is the faithful model for Gompertz-Makeham.
    LoadedVsUnloaded,
}
56
/// Frailty modifier specification at the family level.
///
/// Two structurally different exact modifiers exist:
///
/// 1. **GaussianShift**: additive Gaussian on the final transformation index.
///    Exact for probit families — the existing sextic microcell kernel survives
///    unchanged (just scale denested cell coefficients by 1/√(1+σ²)).
///
/// 2. **HazardMultiplier**: lognormal multiplier on the loaded cumulative hazard.
///    Exact for PH/cloglog families — row likelihoods are finite sums of
///    K_{k,m}(μ, σ) kernel terms.
///
/// These are mathematically distinct families.  Do not mix them.
///
/// Serialized as an internally tagged enum: the variant name is stored in
/// kebab-case under the `frailty_kind` key, next to the variant's own fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "frailty_kind", rename_all = "kebab-case")]
pub enum FrailtySpec {
    /// No frailty modifier.
    None,
    /// Gaussian shift on the final scalar index: U ~ N(0, σ²) added to η.
    /// Exact for probit: E[Φ(η + U)] = Φ(η / √(1+σ²)).
    /// The existing sextic microcell kernel is preserved.
    GaussianShift {
        /// Fixed σ, or `None` if σ is learnable.
        sigma_fixed: Option<f64>,
    },
    /// Lognormal hazard multiplier: conditional hazard h(t|U) involves exp(U).
    /// Exact for PH/cloglog/survival via K_{k,m} kernel.
    HazardMultiplier {
        /// Fixed σ, or `None` if σ is learnable.
        sigma_fixed: Option<f64>,
        /// How the multiplier loads onto hazard components.
        loading: HazardLoading,
    },
}
91
92impl FrailtySpec {
93    /// Validate that this frailty spec is compatible with score_warp/linkwiggle
94    /// cubic marginal-slope families.
95    ///
96    /// - `GaussianShift` is exact: the sextic microcell kernel is preserved
97    ///   (probit scaling by 1/τ, τ = √(1+σ²)).
98    /// - `HazardMultiplier` is exact only for PH/cloglog rowwise families.
99    ///   It is NOT finite-state exact with score_warp/linkwiggle cubic
100    ///   marginal-slope, because the multiplicative frailty breaks the
101    ///   polynomial kernel closure that the cubic cell derivatives require.
102    ///
103    /// Returns an error if the combination is not exactly integrable.
104    pub fn validate_for_marginal_slope(&self) -> Result<(), String> {
105        self.validate_for_marginal_slope_typed()
106            .map_err(|e| e.to_string())
107    }
108
109    /// Typed variant of [`Self::validate_for_marginal_slope`] used internally;
110    /// the `String`-returning entry point above is preserved as a one-line
111    /// shim for external callers.
112    pub fn validate_for_marginal_slope_typed(&self) -> Result<(), LognormalKernelError> {
113        match self {
114            Self::None | Self::GaussianShift { .. } => Ok(()),
115            Self::HazardMultiplier { .. } => Err(LognormalKernelError::InvalidSpec {
116                reason:
117                    "HazardMultiplier frailty is not finite-state exact with score_warp/linkwiggle \
118                     cubic marginal-slope families. Use GaussianShift frailty (exact probit scaling) \
119                     or use the standalone latent-cloglog/latent-survival families instead."
120                        .to_string(),
121            }),
122        }
123    }
124}
125
126// ─── Probit frailty scaling ──────────────────────────────────────────────────
127
/// Returns `(s, α)` with `s = 1/√(1+σ²)` and `α = σ²/(1+σ²)`.
///
/// For `|σ| > 1` both quantities are rewritten in terms of `1/|σ|`, so the
/// square that gets formed is of a number below one and cannot overflow.
#[inline]
fn probit_frailty_scale_components(sigma: f64) -> (f64, f64) {
    let magnitude = sigma.abs();
    if magnitude > 1.0 {
        // s = (1/|σ|) / √(1 + 1/σ²),  α = 1 / (1 + 1/σ²).
        let recip = magnitude.recip();
        let one_plus_recip_sq = 1.0 + recip * recip;
        return (recip / one_plus_recip_sq.sqrt(), one_plus_recip_sq.recip());
    }

    // Direct form; σ² ≤ 1 here (NaN inputs also land here and propagate).
    let sigma_sq = sigma * sigma;
    let one_plus_sigma_sq = 1.0 + sigma_sq;
    (one_plus_sigma_sq.sqrt().recip(), sigma_sq / one_plus_sigma_sq)
}
141
/// Probit frailty scaling factor **with** t-derivatives (t = log σ).
///
/// Provides exact closed-form derivatives of s = 1/√(1+σ²) with respect to
/// t = log(σ) for learnable Gaussian-shift frailty in the marginal-slope
/// families.  For Gaussian frailty on the final probit index
/// E[Φ(η + U)] = Φ(η · s) with s = 1/√(1+σ²); writing α = σ²/(1+σ²) the
/// derivatives are ∂_t s = −α·s and ∂_{tt} s = α(3α−2)·s.
///
/// The second derivative follows from ∂_t α = 2α(1−α):
/// ∂_{tt} s = −(∂_t α)·s − α·∂_t s = −2α(1−α)s + α²s = α(3α−2)·s.
#[derive(Clone, Copy, Debug)]
pub struct ProbitFrailtyScaleJet {
    /// s = 1/√(1+σ²)
    pub s: f64,
    /// α = σ²/(1+σ²)  — shared auxiliary for all derivative levels.
    pub alpha: f64,
    /// ∂_t s = -α·s
    pub ds: f64,
    /// ∂_{tt} s = α(3α−2)·s
    pub d2s: f64,
}
160
161impl ProbitFrailtyScaleJet {
162    /// Build the jet from σ (not from t = log σ).
163    ///
164    /// At σ = 0 the jet degenerates to (s=1, α=0, ds=0, d2s=0), which is
165    /// correct: zero frailty means s ≡ 1 independent of t.
166    pub fn new(sigma: f64) -> Self {
167        let (s, alpha) = probit_frailty_scale_components(sigma);
168        Self {
169            s,
170            alpha,
171            ds: -alpha * s,
172            d2s: alpha * (3.0 * alpha - 2.0) * s,
173        }
174    }
175
176    /// Build the jet from t = log(σ) directly.
177    pub fn from_log_sigma(log_sigma: f64) -> Self {
178        Self::new(log_sigma.exp())
179    }
180}
181
182#[inline]
183fn worst_mode(
184    a: IntegratedExpectationMode,
185    b: IntegratedExpectationMode,
186) -> IntegratedExpectationMode {
187    if a.rank() >= b.rank() { a } else { b }
188}
189
190// ─── Log-space kernel infrastructure ──────────────────────────────────────────
191//
192// The runtime kernel path stays in log-space until the final ratios are formed,
193// avoiding the overflow/underflow and cancellation problems that come from
194// exponentiating individual terms too early.
195
196/// Returns `log K_{k,m}(μ,σ)` directly, without exponentiation.
197///
198/// The value is always finite (or `NEG_INFINITY` when the kernel is zero), so
199/// it cannot overflow or underflow.
200#[inline]
201fn validate_kernel_inputs(m: f64, mu: f64, sigma: f64) -> Result<(), EstimationError> {
202    if !m.is_finite() || m < 0.0 {
203        crate::bail_invalid_estim!("lognormal kernel requires finite m >= 0, got {m}");
204    }
205    if !mu.is_finite() || !sigma.is_finite() || sigma < 0.0 {
206        crate::bail_invalid_estim!(
207            "lognormal kernel requires finite mu and sigma >= 0, got mu={mu}, sigma={sigma}"
208        );
209    }
210    Ok::<(), _>(())
211}
212
213#[inline]
214pub fn log_kernel_term(
215    quadctx: &QuadratureContext,
216    k: usize,
217    m: f64,
218    mu: f64,
219    sigma: f64,
220) -> Result<(f64, IntegratedExpectationMode), EstimationError> {
221    validate_kernel_inputs(m, mu, sigma)?;
222    let kf = k as f64;
223    let sigma2 = sigma * sigma;
224    if !sigma2.is_finite() {
225        crate::bail_invalid_estim!(
226            "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
227        );
228    }
229    let prefix_bound = kf * mu.abs() + 0.5 * kf * kf * sigma2;
230    if !prefix_bound.is_finite() {
231        crate::bail_invalid_estim!(
232            "lognormal kernel prefix is outside the finite exact-derivative range: k={k}, mu={mu}, sigma={sigma}"
233        );
234    }
235    let prefix = kf * mu + 0.5 * kf * kf * sigma2;
236    if m == 0.0 {
237        return Ok((prefix, IntegratedExpectationMode::ExactClosedForm));
238    }
239    let log_m = m.ln();
240    let shifted_bound = mu.abs() + kf * sigma2 + log_m.abs();
241    if !shifted_bound.is_finite() {
242        crate::bail_invalid_estim!(
243            "lognormal kernel shifted location is outside the finite exact-derivative range: k={k}, m={m}, mu={mu}, sigma={sigma}"
244        );
245    }
246    let shifted_mu = mu + kf * sigma2 + log_m;
247    // Survival carried in log space: prefix + ln S(shifted_mu, σ). This keeps the
248    // kernel's true magnitude when S underflows in value space at large σ — the
249    // old `laplace <= 0.0 → −∞` collapse discarded a large-but-finite log-value
250    // (#798) and the value-space asymptotic was biased low at σ ≥ 8 (#799).
251    let (log_laplace, mode) = lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
252    Ok((prefix + log_laplace, mode))
253}
254
/// Kernel bundle storing `log K_{k,m}` values instead of `K_{k,m}`.
///
/// All entries share one `(m, μ, σ)`; only the order `k` varies.
#[derive(Clone, Debug)]
pub struct LogLognormalKernelBundle {
    /// `log_values[k] = log K_{k,m}(μ, σ)` for `k = 0, 1, …`.
    pub log_values: Vec<f64>,
    /// Evaluation mode recorded for the bundle as a whole.
    pub mode: IntegratedExpectationMode,
}
261
262impl LogLognormalKernelBundle {
263    #[inline]
264    pub fn get(&self, k: usize) -> f64 {
265        self.log_values[k]
266    }
267
268    #[inline]
269    pub fn len(&self) -> usize {
270        self.log_values.len()
271    }
272}
273
274/// Builds a log-space kernel bundle for `k = 0, 1, …, max_k` at fixed
275/// `(m, μ, σ)`.
276pub fn log_kernel_bundle(
277    quadctx: &QuadratureContext,
278    m: f64,
279    mu: f64,
280    sigma: f64,
281    max_k: usize,
282) -> Result<LogLognormalKernelBundle, EstimationError> {
283    validate_kernel_inputs(m, mu, sigma)?;
284    let mut log_values = Vec::with_capacity(max_k + 1);
285    let sigma2 = sigma * sigma;
286    if !sigma2.is_finite() {
287        crate::bail_invalid_estim!(
288            "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
289        );
290    }
291    let max_kf = max_k as f64;
292    let prefix_bound = max_kf * mu.abs() + 0.5 * max_kf * max_kf * sigma2;
293    if !prefix_bound.is_finite() {
294        crate::bail_invalid_estim!(
295            "lognormal kernel bundle prefix is outside the finite exact-derivative range: max_k={max_k}, mu={mu}, sigma={sigma}"
296        );
297    }
298    if m == 0.0 {
299        let mut prefix = 0.0;
300        for k in 0..=max_k {
301            log_values.push(prefix);
302            prefix += mu + (k as f64 + 0.5) * sigma2;
303        }
304        return Ok(LogLognormalKernelBundle {
305            log_values,
306            mode: IntegratedExpectationMode::ExactClosedForm,
307        });
308    }
309
310    let log_m = m.ln();
311    let shifted_bound = mu.abs() + max_kf * sigma2 + log_m.abs();
312    if !shifted_bound.is_finite() {
313        crate::bail_invalid_estim!(
314            "lognormal kernel bundle shifted location is outside the finite exact-derivative range: max_k={max_k}, m={m}, mu={mu}, sigma={sigma}"
315        );
316    }
317    let mut shifted_mu = mu + log_m;
318    let mut prefix = 0.0;
319    let mut mode = IntegratedExpectationMode::ExactClosedForm;
320    for k in 0..=max_k {
321        let (log_laplace, val_mode) =
322            lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
323        log_values.push(if log_laplace.is_finite() {
324            prefix + log_laplace
325        } else {
326            f64::NEG_INFINITY
327        });
328        mode = worst_mode(mode, val_mode);
329        prefix += mu + (k as f64 + 0.5) * sigma2;
330        shifted_mu += sigma2;
331    }
332    Ok(LogLognormalKernelBundle { log_values, mode })
333}
334
335/// Computes the value-space derivative ratios `∂ⁿ_μ K_{k,m} / K_{k,m}`
336/// from a log-space bundle.
337///
338/// Returns `[1, K'/K, K''/K, K'''/K, K''''/K]` where only the first
339/// `order + 1` entries are valid.
340///
341/// The recurrences are applied in ratio form, with each `K_{k+r}/K_k`
342/// computed as `exp(log K_{k+r} − log K_k)`, which remains finite even when
343/// the individual kernel values would overflow or underflow.
344pub fn kernel_ratio_jet(
345    log_bundle: &LogLognormalKernelBundle,
346    k: usize,
347    m: f64,
348    order: usize,
349) -> [f64; 5] {
350    let kf = k as f64;
351    let log_k0 = log_bundle.get(k);
352
353    // Precompute ratios K_{k+r}/K_k for r = 1..=order, each from a single
354    // log-difference.  This avoids redundant exp() calls when the same ratio
355    // appears in multiple derivative orders.
356    let mut rk = [0.0f64; 5]; // rk[0] unused; rk[r] = K_{k+r}/K_k
357    for r in 1..=order.min(4) {
358        let delta = log_bundle.get(k + r) - log_k0;
359        rk[r] = if delta.is_finite() {
360            delta.exp()
361        } else if delta > 0.0 {
362            f64::INFINITY
363        } else {
364            0.0
365        };
366    }
367
368    let mut jet = [0.0; 5];
369    jet[0] = 1.0;
370
371    if order >= 1 {
372        jet[1] = kf - m * rk[1];
373    }
374    if order >= 2 {
375        jet[2] = kf * kf - (2.0 * kf + 1.0) * m * rk[1] + m * m * rk[2];
376    }
377    if order >= 3 {
378        jet[3] = kf * kf * kf - (3.0 * kf * kf + 3.0 * kf + 1.0) * m * rk[1]
379            + 3.0 * (kf + 1.0) * m * m * rk[2]
380            - m * m * m * rk[3];
381    }
382    if order >= 4 {
383        let k2 = kf * kf;
384        let k3 = k2 * kf;
385        let k4 = k3 * kf;
386        let m2 = m * m;
387        let m3 = m2 * m;
388        let m4 = m3 * m;
389        jet[4] = k4 - (4.0 * k3 + 6.0 * k2 + 4.0 * kf + 1.0) * m * rk[1]
390            + (6.0 * k2 + 12.0 * kf + 7.0) * m2 * rk[2]
391            - (4.0 * kf + 6.0) * m3 * rk[3]
392            + m4 * rk[4];
393    }
394
395    jet
396}
397
398// `LatentCLogLogJet5` + `latent_cloglog_jet5` / `latent_cloglog_inverse_link_jet`
399// moved DOWN to `crate::quadrature` (#1135), co-located with their analytic
400// backend, so the `solver` link layer names them without importing up into
401// `families::survival`. Re-exported here so the in-family callers (e.g.
402// `family_runtime`) keep resolving.
403pub use crate::quadrature::{
404    LatentCLogLogJet5, latent_cloglog_inverse_link_jet, latent_cloglog_jet5,
405};
406
407// ─── LogKernelSumJet: log-sum derivatives from log-space bundles ─────────────
408
/// A single signed term in a kernel sum: coefficient × K_{k,m}.
///
/// All terms of one sum are evaluated at the same `(μ, σ)`; only `coeff`, `k`
/// and `m` vary per term.
#[derive(Clone, Copy, Debug)]
pub struct KernelSumTerm {
    /// Multiplicative coefficient (can be negative for difference terms).
    pub coeff: f64,
    /// Kernel order parameter k.
    pub k: usize,
    /// Kernel mass parameter m (≥ 0).
    pub m: f64,
}
419
/// Derivatives of `log(Σ_j a_j · K_{k_j, m_j}(μ, σ))` with respect to μ.
///
/// This is the workhorse for row-level log-likelihood derivatives in all
/// latent families.  The numerator and denominator of a row likelihood are
/// each a small signed sum of kernel terms.
///
/// The value path is assembled from log-space kernel bundles and ratio jets,
/// so individual kernel terms are never exponentiated before the final signed
/// sum. That avoids the old overflow/underflow problems from value-space
/// kernels. When the signed sum is zero or negative, this returns an invalid
/// row (`value = -∞`, all derivatives zero) instead of trying to continue with
/// a floored surrogate.
/// Signed two-term differences (e.g. interval censoring `K_{0,M_L} − K_{0,M_R}`)
/// are still combined through the shared sign-aware log-sum path.
#[derive(Clone, Copy, Debug)]
pub struct LogKernelSumJet {
    /// log(Σ a_j K_j)
    pub value: f64,
    /// d/dμ log(Σ a_j K_j)
    pub d1: f64,
    /// d²/dμ² log(Σ a_j K_j)
    pub d2: f64,
    /// d³/dμ³ log(Σ a_j K_j)
    pub d3: f64,
    /// d⁴/dμ⁴ log(Σ a_j K_j)
    pub d4: f64,
    /// Worst evaluation mode among the kernel bundles that fed this jet.
    pub mode: IntegratedExpectationMode,
}
447
448impl LogKernelSumJet {
449    #[inline]
450    fn non_positive(mode: IntegratedExpectationMode) -> Self {
451        Self {
452            value: f64::NEG_INFINITY,
453            d1: 0.0,
454            d2: 0.0,
455            d3: 0.0,
456            d4: 0.0,
457            mode,
458        }
459    }
460
461    #[inline]
462    fn from_log_value_and_ratios(
463        value: f64,
464        ratio: [f64; 5],
465        mode: IntegratedExpectationMode,
466    ) -> Self {
467        let r1 = ratio[1];
468        let r2 = ratio[2];
469        let r3 = ratio[3];
470        let r4 = ratio[4];
471        Self {
472            value,
473            d1: r1,
474            d2: r2 - r1 * r1,
475            d3: r3 - 3.0 * r1 * r2 + 2.0 * r1 * r1 * r1,
476            d4: r4 - 4.0 * r1 * r3 - 3.0 * r2 * r2 + 12.0 * r1 * r1 * r2 - 6.0 * r1.powi(4),
477            mode,
478        }
479    }
480
481    #[inline]
482    fn term_log_mag_and_ratio(
483        bundle: &LogLognormalKernelBundle,
484        term: KernelSumTerm,
485    ) -> (f64, [f64; 5]) {
486        (
487            term.coeff.abs().ln() + bundle.get(term.k),
488            // d4 is used by the exact log-sigma curvature, so this must carry
489            // ratios through order 4 rather than truncating at order 3.
490            kernel_ratio_jet(bundle, term.k, term.m, 4),
491        )
492    }
493
494    fn evaluate_two_terms(
495        quadctx: &QuadratureContext,
496        t0: KernelSumTerm,
497        t1: KernelSumTerm,
498        mu: f64,
499        sigma: f64,
500    ) -> Result<Self, EstimationError> {
501        let max_k_needed = t0.k.max(t1.k) + 4;
502        let bundle0 = log_kernel_bundle(quadctx, t0.m, mu, sigma, max_k_needed)?;
503        let mut overall_mode = bundle0.mode;
504        let bundle1_owned = if (t0.m - t1.m).abs() < 1e-300 {
505            None
506        } else {
507            let bundle1 = log_kernel_bundle(quadctx, t1.m, mu, sigma, max_k_needed)?;
508            overall_mode = worst_mode(overall_mode, bundle1.mode);
509            Some(bundle1)
510        };
511        let bundle1 = bundle1_owned.as_ref().unwrap_or(&bundle0);
512
513        let (log_mag0, ratio0) = Self::term_log_mag_and_ratio(&bundle0, t0);
514        let (log_mag1, ratio1) = Self::term_log_mag_and_ratio(bundle1, t1);
515        let log_mags = [log_mag0, log_mag1];
516        let signs = [t0.coeff.signum(), t1.coeff.signum()];
517        let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
518        if !log_s.is_finite() || sign_s <= 0.0 {
519            return Ok(Self::non_positive(overall_mode));
520        }
521
522        let w0 = sign_s * signs[0] * (log_mag0 - log_s).exp();
523        let w1 = sign_s * signs[1] * (log_mag1 - log_s).exp();
524        let wr1 = w0 * ratio0[1] + w1 * ratio1[1];
525        let wr2 = w0 * ratio0[2] + w1 * ratio1[2];
526        let wr3 = w0 * ratio0[3] + w1 * ratio1[3];
527        let wr4 = w0 * ratio0[4] + w1 * ratio1[4];
528
529        Ok(Self {
530            value: log_s,
531            d1: wr1,
532            d2: wr2 - wr1 * wr1,
533            d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
534            d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
535                - 6.0 * wr1.powi(4),
536            mode: overall_mode,
537        })
538    }
539
540    /// Evaluate for a single positive kernel term (fast path).
541    ///
542    /// Computes `log(K_{k,m})` and its μ-derivatives from exact recurrences,
543    /// entirely in log-space.
544    pub fn single_term(
545        quadctx: &QuadratureContext,
546        k: usize,
547        m: f64,
548        mu: f64,
549        sigma: f64,
550    ) -> Result<Self, EstimationError> {
551        let max_k_needed = k + 4;
552        let lb = log_kernel_bundle(quadctx, m, mu, sigma, max_k_needed)?;
553        Ok(Self::from_log_value_and_ratios(
554            lb.get(k),
555            kernel_ratio_jet(&lb, k, m, 4),
556            lb.mode,
557        ))
558    }
559
560    /// Evaluate `log(Σ a_j K_j)` and its μ-derivatives for a small signed sum.
561    ///
562    /// All terms share the same `(μ, σ)`.  Both the value and derivative
563    /// ratios are computed entirely in log-space.  The runtime latent-survival
564    /// rows in this repo are almost always one-term or two-term sums, so those
565    /// cases stay on dedicated stack paths; the heap-backed logic below is only
566    /// for genuinely longer symbolic sums:
567    ///
568    /// 1. Per-term log-magnitudes `log|a_j| + log K_{k_j,m_j}` and signs.
569    /// 2. Sign-aware log-sum-exp to get `log|S|` and `sign(S)`.
570    /// 3. Importance weights `w_j = a_j K_j / S` formed in log-space.
571    /// 4. Weighted ratio sums `R_n = Σ w_j · (∂ⁿK_j / K_j)` for the
572    ///    final log-derivatives.
573    pub fn evaluate(
574        quadctx: &QuadratureContext,
575        terms: &[KernelSumTerm],
576        mu: f64,
577        sigma: f64,
578    ) -> Result<Self, EstimationError> {
579        if terms.is_empty() {
580            // Empty sums are a caller-contract violation, not a degenerate row.
581            // Return an input error so callers can report the malformed kernel sum.
582            crate::bail_invalid_estim!("KernelSumJet requires at least one term");
583        }
584
585        // Fast path for single term.
586        if terms.len() == 1 {
587            let t = &terms[0];
588            if t.coeff <= 0.0 {
589                // Negative or zero coefficient: the sum is non-positive, so
590                // log(sum) is undefined.  Return −∞ (impossible observation),
591                // matching the general path's sign_s ≤ 0 branch.
592                return Ok(Self::non_positive(
593                    IntegratedExpectationMode::ExactClosedForm,
594                ));
595            }
596            let jet = Self::single_term(quadctx, t.k, t.m, mu, sigma)?;
597            return Ok(Self {
598                value: t.coeff.ln() + jet.value,
599                d1: jet.d1,
600                d2: jet.d2,
601                d3: jet.d3,
602                d4: jet.d4,
603                mode: jet.mode,
604            });
605        }
606        if terms.len() == 2 {
607            return Self::evaluate_two_terms(quadctx, terms[0], terms[1], mu, sigma);
608        }
609
610        let max_k_needed = terms.iter().map(|t| t.k).max().unwrap_or(0) + 4;
611
612        // Build log-bundles for each unique mass.
613        let mut log_bundles: Vec<(f64, LogLognormalKernelBundle)> = Vec::with_capacity(2);
614        let mut overall_mode = IntegratedExpectationMode::ExactClosedForm;
615        for term in terms {
616            if !log_bundles
617                .iter()
618                .any(|(m, _)| (*m - term.m).abs() < 1e-300)
619            {
620                let b = log_kernel_bundle(quadctx, term.m, mu, sigma, max_k_needed)?;
621                overall_mode = worst_mode(overall_mode, b.mode);
622                log_bundles.push((term.m, b));
623            }
624        }
625
626        let get_lb = |m: f64| -> &LogLognormalKernelBundle {
627            &log_bundles
628                .iter()
629                .find(|(bm, _)| (*bm - m).abs() < 1e-300)
630                .unwrap()
631                .1
632        };
633
634        // Per-term: log magnitude, sign, and ratio jet.
635        let mut log_mags: Vec<f64> = Vec::with_capacity(terms.len());
636        let mut signs: Vec<f64> = Vec::with_capacity(terms.len());
637        let mut ratios: Vec<[f64; 5]> = Vec::with_capacity(terms.len());
638        for term in terms {
639            let lb = get_lb(term.m);
640            log_mags.push(term.coeff.abs().ln() + lb.get(term.k));
641            signs.push(term.coeff.signum());
642            ratios.push(kernel_ratio_jet(lb, term.k, term.m, 4));
643        }
644
645        // Sign-aware log-sum-exp: compute log|S| and sign(S).
646        let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
647
648        if !log_s.is_finite() || sign_s <= 0.0 {
649            // Sum is zero or negative — degenerate row.
650            return Ok(Self::non_positive(overall_mode));
651        }
652
653        // Importance weights w_j = sign(S) · sign(a_j) · exp(log|a_j K_j| − log|S|).
654        // When S > 0 and all terms have well-defined kernels, Σ w_j = 1.
655        let mut wr1 = 0.0;
656        let mut wr2 = 0.0;
657        let mut wr3 = 0.0;
658        let mut wr4 = 0.0;
659        for i in 0..terms.len() {
660            let w = sign_s * signs[i] * (log_mags[i] - log_s).exp();
661            wr1 += w * ratios[i][1];
662            wr2 += w * ratios[i][2];
663            wr3 += w * ratios[i][3];
664            wr4 += w * ratios[i][4];
665        }
666
667        Ok(Self {
668            value: log_s,
669            d1: wr1,
670            d2: wr2 - wr1 * wr1,
671            d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
672            d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
673                - 6.0 * wr1.powi(4),
674            mode: overall_mode,
675        })
676    }
677}
678
679// ─── Latent survival sufficient statistics ───────────────────────────────────
680
/// Event type for compiled survival sufficient statistics.
///
/// The `Display` impl renders each variant as a snake_case label
/// (`right_censored`, `exact_event`, `interval_censored`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentSurvivalEventType {
    /// Right-censored: observed alive in the observation window.
    RightCensored,
    /// Exact event: event observed at a known time.
    ExactEvent,
    /// Interval-censored: event known to occur in (t_left, t_right].
    IntervalCensored,
}
691
692impl fmt::Display for LatentSurvivalEventType {
693    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
694        match self {
695            Self::RightCensored => write!(f, "right_censored"),
696            Self::ExactEvent => write!(f, "exact_event"),
697            Self::IntervalCensored => write!(f, "interval_censored"),
698        }
699    }
700}
701
/// Row-level sufficient statistics for one latent survival observation.
///
/// This is the canonical row representation used by both fitted-family
/// evaluation and saved-model prediction.
///
/// For the full-loading model (frailty multiplies entire hazard):
///   mass_loaded = total cumulative hazard mass
///   mass_unloaded = 0
///
/// For the loaded-vs-unloaded model (Gompertz-Makeham):
///   mass_loaded = integrated disease hazard component
///   mass_unloaded = integrated background hazard component (not frailty-modified)
///
/// The unloaded mass contributes a simple exp(-M_U) prefactor.
///
/// Fields that do not apply to a row's event type are left at zero by the
/// constructors. `validate` requires every numeric field to be finite and
/// non-negative.
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRow {
    /// Which observation scheme the row represents.
    pub event_type: LatentSurvivalEventType,
    /// Cumulative nuisance mass at entry: B(a_in).
    /// Zero if there is no left truncation.
    pub mass_entry: f64,
    /// Cumulative nuisance mass at exit/event: B(a_out) or B(a_event).
    pub mass_exit: f64,
    /// For interval censoring: mass at left boundary B(a_L).
    pub mass_left: f64,
    /// For interval censoring: mass at right boundary B(a_R).
    pub mass_right: f64,
    /// For interval censoring: unloaded mass at left boundary.
    pub mass_unloaded_left: f64,
    /// For interval censoring: unloaded mass at right boundary.
    pub mass_unloaded_right: f64,
    /// Unloaded (background) cumulative mass at entry (0 for full loading).
    pub mass_unloaded_entry: f64,
    /// Unloaded (background) cumulative mass at exit.
    pub mass_unloaded_exit: f64,
    /// Loaded instantaneous hazard at event time (for exact events).
    pub hazard_loaded: f64,
    /// Unloaded instantaneous hazard at event time (for exact events).
    pub hazard_unloaded: f64,
}
741
742impl LatentSurvivalRow {
743    /// Delayed-entry right-censored row with explicit loaded/unloaded masses.
744    ///
745    /// `mass_entry` and `mass_exit` are cumulative loaded masses `B_L(a_in)`
746    /// and `B_L(a_out)` for this row object. They are not an increment over
747    /// `(a_in, a_out]`.
748    pub fn right_censored(
749        mass_entry: f64,
750        mass_exit: f64,
751        mass_unloaded_entry: f64,
752        mass_unloaded_exit: f64,
753    ) -> Self {
754        Self {
755            event_type: LatentSurvivalEventType::RightCensored,
756            mass_entry,
757            mass_exit,
758            mass_left: 0.0,
759            mass_right: 0.0,
760            mass_unloaded_left: 0.0,
761            mass_unloaded_right: 0.0,
762            mass_unloaded_entry,
763            mass_unloaded_exit,
764            hazard_loaded: 0.0,
765            hazard_unloaded: 0.0,
766        }
767    }
768
769    /// Delayed-entry exact-event row with explicit loaded/unloaded hazard parts.
770    pub fn exact_event(
771        mass_entry: f64,
772        mass_exit: f64,
773        mass_unloaded_entry: f64,
774        mass_unloaded_exit: f64,
775        hazard_loaded: f64,
776        hazard_unloaded: f64,
777    ) -> Self {
778        Self {
779            event_type: LatentSurvivalEventType::ExactEvent,
780            mass_entry,
781            mass_exit,
782            mass_left: 0.0,
783            mass_right: 0.0,
784            mass_unloaded_left: 0.0,
785            mass_unloaded_right: 0.0,
786            mass_unloaded_entry,
787            mass_unloaded_exit,
788            hazard_loaded,
789            hazard_unloaded,
790        }
791    }
792
793    /// Delayed-entry interval-censored row with explicit loaded/unloaded masses.
794    pub fn interval_censored(
795        mass_entry: f64,
796        mass_left: f64,
797        mass_right: f64,
798        mass_unloaded_entry: f64,
799        mass_unloaded_left: f64,
800        mass_unloaded_right: f64,
801    ) -> Self {
802        Self {
803            event_type: LatentSurvivalEventType::IntervalCensored,
804            mass_entry,
805            mass_exit: 0.0,
806            mass_left,
807            mass_right,
808            mass_unloaded_left,
809            mass_unloaded_right,
810            mass_unloaded_entry,
811            mass_unloaded_exit: 0.0,
812            hazard_loaded: 0.0,
813            hazard_unloaded: 0.0,
814        }
815    }
816
817    pub fn validate(&self) -> Result<(), EstimationError> {
818        let fields = [
819            ("mass_entry", self.mass_entry),
820            ("mass_exit", self.mass_exit),
821            ("mass_left", self.mass_left),
822            ("mass_right", self.mass_right),
823            ("mass_unloaded_left", self.mass_unloaded_left),
824            ("mass_unloaded_right", self.mass_unloaded_right),
825            ("mass_unloaded_entry", self.mass_unloaded_entry),
826            ("mass_unloaded_exit", self.mass_unloaded_exit),
827            ("hazard_loaded", self.hazard_loaded),
828            ("hazard_unloaded", self.hazard_unloaded),
829        ];
830        for (name, value) in fields {
831            if !value.is_finite() || value < 0.0 {
832                crate::bail_invalid_estim!(
833                    "latent survival row has invalid {name}={value}; expected a finite non-negative value"
834                );
835            }
836        }
837
838        match self.event_type {
839            LatentSurvivalEventType::RightCensored => {
840                if self.mass_exit < self.mass_entry {
841                    crate::bail_invalid_estim!(
842                        "latent survival right-censored row requires mass_exit >= mass_entry, got {} < {}",
843                        self.mass_exit,
844                        self.mass_entry
845                    );
846                }
847                if self.mass_unloaded_exit < self.mass_unloaded_entry {
848                    crate::bail_invalid_estim!(
849                        "latent survival right-censored row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
850                        self.mass_unloaded_exit,
851                        self.mass_unloaded_entry
852                    );
853                }
854                if self.mass_left > 0.0
855                    || self.mass_right > 0.0
856                    || self.mass_unloaded_left > 0.0
857                    || self.mass_unloaded_right > 0.0
858                    || self.hazard_loaded > 0.0
859                    || self.hazard_unloaded > 0.0
860                {
861                    crate::bail_invalid_estim!("latent survival right-censored row cannot carry interval masses or event hazards"
862                            .to_string(),);
863                }
864            }
865            LatentSurvivalEventType::ExactEvent => {
866                if self.mass_exit < self.mass_entry {
867                    crate::bail_invalid_estim!(
868                        "latent survival exact-event row requires mass_exit >= mass_entry, got {} < {}",
869                        self.mass_exit,
870                        self.mass_entry
871                    );
872                }
873                if self.mass_unloaded_exit < self.mass_unloaded_entry {
874                    crate::bail_invalid_estim!(
875                        "latent survival exact-event row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
876                        self.mass_unloaded_exit,
877                        self.mass_unloaded_entry
878                    );
879                }
880                if self.mass_left > 0.0
881                    || self.mass_right > 0.0
882                    || self.mass_unloaded_left > 0.0
883                    || self.mass_unloaded_right > 0.0
884                {
885                    crate::bail_invalid_estim!(
886                        "latent survival exact-event row cannot carry interval masses"
887                    );
888                }
889                if self.hazard_loaded == 0.0 && self.hazard_unloaded == 0.0 {
890                    crate::bail_invalid_estim!("latent survival exact-event row requires a positive loaded or unloaded hazard"
891                            .to_string(),);
892                }
893            }
894            LatentSurvivalEventType::IntervalCensored => {
895                if self.mass_left < self.mass_entry || self.mass_right < self.mass_left {
896                    crate::bail_invalid_estim!(
897                        "latent survival interval row requires mass_entry <= mass_left <= mass_right, got entry={}, left={}, right={}",
898                        self.mass_entry,
899                        self.mass_left,
900                        self.mass_right
901                    );
902                }
903                if self.mass_unloaded_left < self.mass_unloaded_entry
904                    || self.mass_unloaded_right < self.mass_unloaded_left
905                {
906                    crate::bail_invalid_estim!(
907                        "latent survival interval row requires unloaded_entry <= unloaded_left <= unloaded_right, got entry={}, left={}, right={}",
908                        self.mass_unloaded_entry,
909                        self.mass_unloaded_left,
910                        self.mass_unloaded_right
911                    );
912                }
913                if self.mass_exit > 0.0
914                    || self.mass_unloaded_exit > 0.0
915                    || self.hazard_loaded > 0.0
916                    || self.hazard_unloaded > 0.0
917                {
918                    crate::bail_invalid_estim!(
919                        "latent survival interval row cannot carry exit masses or event hazards"
920                            .to_string(),
921                    );
922                }
923            }
924        }
925
926        Ok(())
927    }
928}
929
930fn exact_event_kernel_jet(
931    quadctx: &QuadratureContext,
932    row: &LatentSurvivalRow,
933    mu: f64,
934    sigma: f64,
935) -> Result<LogKernelSumJet, EstimationError> {
936    if row.hazard_loaded < 0.0 || row.hazard_unloaded < 0.0 {
937        crate::bail_invalid_estim!(
938            "latent survival exact-event hazards must be non-negative, got loaded={} unloaded={}",
939            row.hazard_loaded,
940            row.hazard_unloaded
941        );
942    }
943    match (row.hazard_unloaded > 0.0, row.hazard_loaded > 0.0) {
944        (true, true) => {
945            let terms = [
946                KernelSumTerm {
947                    coeff: row.hazard_unloaded,
948                    k: 0,
949                    m: row.mass_exit,
950                },
951                KernelSumTerm {
952                    coeff: row.hazard_loaded,
953                    k: 1,
954                    m: row.mass_exit,
955                },
956            ];
957            LogKernelSumJet::evaluate(quadctx, &terms, mu, sigma)
958        }
959        (true, false) => {
960            let jet = LogKernelSumJet::single_term(quadctx, 0, row.mass_exit, mu, sigma)?;
961            Ok(LogKernelSumJet {
962                value: row.hazard_unloaded.ln() + jet.value,
963                d1: jet.d1,
964                d2: jet.d2,
965                d3: jet.d3,
966                d4: jet.d4,
967                mode: jet.mode,
968            })
969        }
970        (false, true) => {
971            let jet = LogKernelSumJet::single_term(quadctx, 1, row.mass_exit, mu, sigma)?;
972            Ok(LogKernelSumJet {
973                value: row.hazard_loaded.ln() + jet.value,
974                d1: jet.d1,
975                d2: jet.d2,
976                d3: jet.d3,
977                d4: jet.d4,
978                mode: jet.mode,
979            })
980        }
981        (false, false) => Err(EstimationError::InvalidInput(
982            "latent survival exact-event row requires a positive loaded or unloaded hazard"
983                .to_string(),
984        )),
985    }
986}
987
/// Row-level log-likelihood and μ-derivatives for the latent survival model.
///
/// The conditional model is:
///   `Λ(a | U) = B(a) · exp(U)`,  `U ~ N(μ, σ²)`
///
/// All likelihoods reduce to algebra on `K_{k,m}(μ, σ)`.
///
/// Besides the μ-derivatives, the jet also carries the first and (negated)
/// second derivative of the row log-likelihood with respect to `log σ`.
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRowJet {
    /// Row log-likelihood `ℓ`.
    pub log_lik: f64,
    /// First derivative `dℓ/dμ`.
    pub score: f64,
    /// Negated second derivative `−d²ℓ/dμ²`.
    pub neg_hessian: f64,
    /// Third derivative `d³ℓ/dμ³`.
    pub d3: f64,
    /// First derivative of `ℓ` with respect to `log σ`.
    pub score_log_sigma: f64,
    /// Negated second derivative of `ℓ` with respect to `log σ`.
    pub neg_hessian_log_sigma: f64,
}
1003
1004#[inline]
1005fn log_sigma_score_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1006    let sigma2 = sigma * sigma;
1007    sigma2 * (jet.d2 + jet.d1 * jet.d1)
1008}
1009
1010#[inline]
1011fn log_sigma_neg_hessian_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1012    let sigma2 = sigma * sigma;
1013    let sigma4 = sigma2 * sigma2;
1014    let d1 = jet.d1;
1015    let d2 = jet.d2;
1016    let d3 = jet.d3;
1017    let d4 = jet.d4;
1018    let s2_over_s = d2 + d1 * d1;
1019    // For S = Σ a_j K_j, D = σ ∂_σ, and D S = σ² S_μμ:
1020    // D² log S = 2σ² (S''/S) + σ⁴ (S''''/S - (S''/S)²).
1021    // Express the final parenthesized term directly in log-derivatives to
1022    // avoid the larger cancellation in `r4 - r2²`.
1023    let s4_over_s_minus_s2_sq = d4 + 4.0 * d1 * d3 + 2.0 * d2 * d2 + 4.0 * d1 * d1 * d2;
1024    -(2.0 * sigma2 * s2_over_s + sigma4 * s4_over_s_minus_s2_sq)
1025}
1026
1027impl LatentSurvivalRowJet {
1028    pub fn evaluate(
1029        quadctx: &QuadratureContext,
1030        row: &LatentSurvivalRow,
1031        mu: f64,
1032        sigma: f64,
1033    ) -> Result<Self, EstimationError> {
1034        row.validate()?;
1035        match row.event_type {
1036            LatentSurvivalEventType::RightCensored => Self::right_censored(quadctx, mu, sigma, row),
1037            LatentSurvivalEventType::ExactEvent => Self::exact_event(quadctx, mu, sigma, row),
1038            LatentSurvivalEventType::IntervalCensored => {
1039                Self::interval_censored(quadctx, mu, sigma, row)
1040            }
1041        }
1042    }
1043
1044    /// Right-censoring with loaded/unloaded mass decomposition.
1045    ///
1046    /// Full formula:
1047    ///   `ℓ = -M_U_exit + log K_{0,M_L_exit} + M_U_entry - log K_{0,M_L_entry}`
1048    ///
1049    /// When `mass_unloaded_exit == 0` and `mass_unloaded_entry == 0`, this
1050    /// falls back to the original formula using `mass_exit` / `mass_entry`.
1051    fn right_censored(
1052        quadctx: &QuadratureContext,
1053        mu: f64,
1054        sigma: f64,
1055        row: &LatentSurvivalRow,
1056    ) -> Result<Self, EstimationError> {
1057        let has_unloaded =
1058            row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300;
1059
1060        // Loaded mass for the kernel terms: when unloaded mass is present,
1061        // mass_exit contains only the loaded component; otherwise it is the
1062        // total mass.
1063        let mass_exit_loaded = row.mass_exit;
1064        let mass_entry_loaded = row.mass_entry;
1065
1066        // Unloaded mass contributes a simple additive constant to log-lik
1067        let unloaded_offset = if has_unloaded {
1068            -row.mass_unloaded_exit + row.mass_unloaded_entry
1069        } else {
1070            0.0
1071        };
1072
1073        let num = LogKernelSumJet::single_term(quadctx, 0, mass_exit_loaded, mu, sigma)?;
1074        if mass_entry_loaded > 1e-300 {
1075            let den = LogKernelSumJet::single_term(quadctx, 0, mass_entry_loaded, mu, sigma)?;
1076            Ok(Self {
1077                log_lik: unloaded_offset + num.value - den.value,
1078                score: num.d1 - den.d1,
1079                neg_hessian: -(num.d2 - den.d2),
1080                d3: num.d3 - den.d3,
1081                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1082                    - log_sigma_score_from_log_sum(&den, sigma),
1083                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1084                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1085            })
1086        } else {
1087            Ok(Self {
1088                log_lik: unloaded_offset + num.value,
1089                score: num.d1,
1090                neg_hessian: -num.d2,
1091                d3: num.d3,
1092                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1093                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1094            })
1095        }
1096    }
1097
1098    /// Exact event with loaded/unloaded hazard decomposition.
1099    ///
1100    /// `ℓ = log(h_U · K_{0,M_L} + h_L · K_{1,M_L}) - M_U_event + M_U_entry - log K_{0,M_L_entry}`
1101    fn exact_event(
1102        quadctx: &QuadratureContext,
1103        mu: f64,
1104        sigma: f64,
1105        row: &LatentSurvivalRow,
1106    ) -> Result<Self, EstimationError> {
1107        let unloaded_offset =
1108            if row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300 {
1109                -row.mass_unloaded_exit + row.mass_unloaded_entry
1110            } else {
1111                0.0
1112            };
1113        let num = exact_event_kernel_jet(quadctx, row, mu, sigma)?;
1114
1115        if row.mass_entry > 1e-300 {
1116            let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1117            Ok(Self {
1118                log_lik: unloaded_offset + num.value - den.value,
1119                score: num.d1 - den.d1,
1120                neg_hessian: -(num.d2 - den.d2),
1121                d3: num.d3 - den.d3,
1122                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1123                    - log_sigma_score_from_log_sum(&den, sigma),
1124                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1125                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1126            })
1127        } else {
1128            Ok(Self {
1129                log_lik: unloaded_offset + num.value,
1130                score: num.d1,
1131                neg_hessian: -num.d2,
1132                d3: num.d3,
1133                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1134                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1135            })
1136        }
1137    }
1138
1139    /// Interval event: `ℓ = log(K_{0,M_L} − K_{0,M_R}) − log K_{0,M_in}`.
1140    fn interval_censored(
1141        quadctx: &QuadratureContext,
1142        mu: f64,
1143        sigma: f64,
1144        row: &LatentSurvivalRow,
1145    ) -> Result<Self, EstimationError> {
1146        let num_terms = [
1147            KernelSumTerm {
1148                coeff: (-row.mass_unloaded_left).exp(),
1149                k: 0,
1150                m: row.mass_left,
1151            },
1152            KernelSumTerm {
1153                coeff: -(-row.mass_unloaded_right).exp(),
1154                k: 0,
1155                m: row.mass_right,
1156            },
1157        ];
1158        let num = LogKernelSumJet::evaluate(quadctx, &num_terms, mu, sigma)?;
1159
1160        if row.mass_entry > 1e-300 {
1161            let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1162            Ok(Self {
1163                log_lik: num.value + row.mass_unloaded_entry - den.value,
1164                score: num.d1 - den.d1,
1165                neg_hessian: -(num.d2 - den.d2),
1166                d3: num.d3 - den.d3,
1167                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1168                    - log_sigma_score_from_log_sum(&den, sigma),
1169                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1170                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1171            })
1172        } else {
1173            Ok(Self {
1174                log_lik: num.value + row.mass_unloaded_entry,
1175                score: num.d1,
1176                neg_hessian: -num.d2,
1177                d3: num.d3,
1178                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1179                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1180            })
1181        }
1182    }
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighted Bernoulli log-likelihood of one row under the latent cloglog
    /// mean, with the mean clamped to `[1e-12, 1 − 1e-12]` before taking logs.
    fn latent_binomial_row_log_lik(
        ctx: &QuadratureContext,
        eta: f64,
        sigma: f64,
        y: f64,
        weight: f64,
    ) -> f64 {
        let jet = latent_cloglog_jet5(ctx, eta, sigma).expect("latent jet");
        let p = jet.mean.clamp(1e-12, 1.0 - 1e-12);
        let log_p = p.ln();
        let log_not_p = (1.0 - p).ln();
        weight * (y * log_p + (1.0 - y) * log_not_p)
    }

    /// Row log-likelihood at `(mu, sigma)`; panics if evaluation fails.
    fn row_log_lik(
        ctx: &QuadratureContext,
        row: &LatentSurvivalRow,
        mu: f64,
        sigma: f64,
    ) -> f64 {
        LatentSurvivalRowJet::evaluate(ctx, row, mu, sigma)
            .unwrap()
            .log_lik
    }

    /// Compares the analytic μ-score of `row` with a central difference of the
    /// row log-likelihood (step `1e-6`, relative tolerance `1e-3`).
    fn assert_row_score_matches_fd(row: &LatentSurvivalRow, mu: f64, sigma: f64) {
        let ctx = QuadratureContext::new();
        let h = 1e-6;
        let fd_score =
            (row_log_lik(&ctx, row, mu + h, sigma) - row_log_lik(&ctx, row, mu - h, sigma))
                / (2.0 * h);
        let analytic = LatentSurvivalRowJet::evaluate(&ctx, row, mu, sigma).unwrap();
        assert!(
            (analytic.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
            "score={}, fd={fd_score}",
            analytic.score
        );
    }

    #[test]
    fn kernel_ratio_jet_d1_fd_check() {
        let ctx = QuadratureContext::new();
        let (mu, sigma, m) = (0.3, 0.5, 1.0);
        let k = 0usize;
        let h = 1e-5;

        // Analytic first and second μ-derivatives of K itself (not log K).
        let bundle = log_kernel_bundle(&ctx, m, mu, sigma, k + 4).unwrap();
        let ratios = kernel_ratio_jet(&bundle, k, m, 2);
        let kc = bundle.get(k).exp();
        let d1 = kc * ratios[1];
        let d2 = kc * ratios[2];

        // Kernel values at the shifted locations for the finite differences.
        let kernel_at = |x: f64| log_kernel_term(&ctx, k, m, x, sigma).unwrap().0.exp();
        let kp = kernel_at(mu + h);
        let km = kernel_at(mu - h);

        let fd_d1 = (kp - km) / (2.0 * h);
        assert!(
            (d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-4,
            "d1: jet={d1}, fd={fd_d1}",
        );

        let fd_d2 = (kp - 2.0 * kc + km) / (h * h);
        assert!(
            (d2 - fd_d2).abs() / fd_d2.abs().max(1e-15) < 1e-3,
            "d2: jet={d2}, fd={fd_d2}",
        );
    }

    #[test]
    fn survival_right_censored_score_fd() {
        let row = LatentSurvivalRow::right_censored(0.0, 2.0, 0.0, 0.0);
        assert_row_score_matches_fd(&row, -0.5, 0.3);
    }

    #[test]
    fn survival_exact_event_score_fd() {
        let row = LatentSurvivalRow::exact_event(0.0, 1.5, 0.0, 0.0, (-0.3f64).exp(), 0.0);
        assert_row_score_matches_fd(&row, 0.2, 0.5);
    }

    #[test]
    fn survival_exact_event_loaded_vs_unloaded_score_fd() {
        let row = LatentSurvivalRow::exact_event(0.3, 1.2, 0.2, 0.6, 0.9, 0.15);
        assert_row_score_matches_fd(&row, -0.1, 0.4);
    }

    #[test]
    fn survival_right_censored_loaded_vs_unloaded_score_fd() {
        let row = LatentSurvivalRow::right_censored(0.4, 1.7, 0.1, 0.5);
        assert_row_score_matches_fd(&row, 0.15, 0.35);
    }

    #[test]
    fn survival_interval_censored_score_fd() {
        let row = LatentSurvivalRow::interval_censored(0.0, 1.0, 2.0, 0.0, 0.0, 0.0);
        assert_row_score_matches_fd(&row, 0.0, 0.6);
    }

    #[test]
    fn survival_interval_censored_neg_hessian_fd() {
        // `neg_hessian` holds −d²ℓ/dμ² of the interval log-likelihood, so it is
        // compared with the negated central second difference of `log_lik`.
        let ctx = QuadratureContext::new();
        let (mu, sigma) = (-0.2, 0.55);
        let h = 2e-4;
        let row = LatentSurvivalRow::interval_censored(0.0, 0.7, 1.9, 0.0, 0.0, 0.0);

        let ll_plus = row_log_lik(&ctx, &row, mu + h, sigma);
        let ll_mid = row_log_lik(&ctx, &row, mu, sigma);
        let ll_minus = row_log_lik(&ctx, &row, mu - h, sigma);
        let fd_d2 = (ll_plus - 2.0 * ll_mid + ll_minus) / (h * h);

        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
        assert!(
            (jet.neg_hessian - (-fd_d2)).abs() / fd_d2.abs().max(1e-12) < 1e-2,
            "interval neg_hessian={}, fd(-d2)={}",
            jet.neg_hessian,
            -fd_d2
        );
    }

    #[test]
    fn survival_interval_censored_log_sigma_score_fd() {
        // `score_log_sigma` is the derivative of the interval log-likelihood
        // with respect to log σ; check it against a central difference taken
        // directly in log σ (d/d(log σ) = σ · d/dσ).
        let ctx = QuadratureContext::new();
        let mu = 0.1;
        let sigma: f64 = 0.6;
        let h = 1e-5;
        let row = LatentSurvivalRow::interval_censored(0.0, 0.8, 2.1, 0.0, 0.0, 0.0);

        let log_sigma = sigma.ln();
        let ll_up = row_log_lik(&ctx, &row, mu, (log_sigma + h).exp());
        let ll_down = row_log_lik(&ctx, &row, mu, (log_sigma - h).exp());
        let fd_dlogsigma = (ll_up - ll_down) / (2.0 * h);

        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
        assert!(
            (jet.score_log_sigma - fd_dlogsigma).abs() / fd_dlogsigma.abs().max(1e-12) < 1e-3,
            "interval score_log_sigma={}, fd={fd_dlogsigma}",
            jet.score_log_sigma
        );
    }

    #[test]
    fn log_kernel_single_term_log_sigma_derivatives_match_ghq_reference() {
        let ctx = QuadratureContext::new();
        let mu = 0.2;
        let sigma = 1.0;
        let jet = LogKernelSumJet::single_term(&ctx, 0, 1.0, mu, sigma).unwrap();

        // Reference values built from the adaptive GHQ derivatives of the
        // cloglog mean `l`: survival is `1 − l`, so its σ-derivatives are the
        // negated derivatives of `l`.
        let ghq = crate::inference::quadrature::cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);
        let survival = (1.0 - ghq.l).max(1e-300);
        let survival_sigma_over_survival = -ghq.l_sigma / survival;
        let ref_score = sigma * survival_sigma_over_survival;
        let ref_neg_hessian = -(ref_score
            + sigma
                * sigma
                * (-ghq.l_sigmasigma / survival - survival_sigma_over_survival.powi(2)));

        let score = log_sigma_score_from_log_sum(&jet, sigma);
        assert!(
            (score - ref_score).abs() / ref_score.abs().max(1e-12) < 1e-4,
            "log-sigma score={}, ref={ref_score}",
            score
        );

        let neg_hessian = log_sigma_neg_hessian_from_log_sum(&jet, sigma);
        assert!(
            (neg_hessian - ref_neg_hessian).abs() / ref_neg_hessian.abs().max(1e-12) < 1e-3,
            "log-sigma neg_hessian={}, ref={ref_neg_hessian}",
            neg_hessian
        );
    }

    #[test]
    fn log_kernel_sum_jet_single_term_d1_fd() {
        let ctx = QuadratureContext::new();
        let (mu, sigma, m) = (0.5, 0.4, 1.0);
        let k = 0usize;
        let h = 1e-6;

        let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();

        let log_kernel_at = |x: f64| log_kernel_term(&ctx, k, m, x, sigma).unwrap().0;
        let fd_d1 = (log_kernel_at(mu + h) - log_kernel_at(mu - h)) / (2.0 * h);
        assert!(
            (jet.d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-3,
            "d1={}, fd={fd_d1}",
            jet.d1
        );
    }

    #[test]
    fn log_kernel_sum_jet_single_term_d4_fd() {
        let ctx = QuadratureContext::new();
        let (mu, sigma, m) = (0.35, 0.45, 1.2);
        let k = 1usize;
        let h: f64 = 2e-3;

        let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();

        // Five-point central stencil for the fourth derivative of log K.
        let log_kernel_at = |x: f64| log_kernel_term(&ctx, k, m, x, sigma).unwrap().0;
        let v_pp = log_kernel_at(mu + 2.0 * h);
        let v_p = log_kernel_at(mu + h);
        let v_0 = log_kernel_at(mu);
        let v_m = log_kernel_at(mu - h);
        let v_mm = log_kernel_at(mu - 2.0 * h);
        let fd_d4 = (v_mm - 4.0 * v_m + 6.0 * v_0 - 4.0 * v_p + v_pp) / h.powi(4);

        assert!(
            (jet.d4 - fd_d4).abs() / jet.d4.abs().max(fd_d4.abs()).max(1e-8) < 2e-2,
            "d4={}, fd={fd_d4}",
            jet.d4
        );
    }

    #[test]
    fn latent_cloglog_jet_matches_point_limit_at_zero_sigma() {
        // At σ = 0 the latent mean is the plain cloglog inverse link
        // 1 − exp(−e^η); its η-derivatives are polynomials in t = e^η times
        // d1 = exp(η − t).
        let ctx = QuadratureContext::new();
        let eta: f64 = -0.4;
        let jet = latent_cloglog_jet5(&ctx, eta, 0.0).expect("latent jet");

        let t = eta.exp();
        let d1 = (eta - t).exp();
        let d2 = (1.0 - t) * d1;
        let d3 = (t * t - 3.0 * t + 1.0) * d1;
        let d4 = (-t * t * t + 6.0 * t * t - 7.0 * t + 1.0) * d1;
        let d5 = (t.powi(4) - 10.0 * t.powi(3) + 25.0 * t * t - 15.0 * t + 1.0) * d1;

        assert!((jet.mean - (1.0 - (-t).exp())).abs() < 1e-12);
        assert!((jet.d1 - d1).abs() < 1e-12);
        assert!((jet.d2 - d2).abs() < 1e-12);
        assert!((jet.d3 - d3).abs() < 1e-12);
        assert!((jet.d4 - d4).abs() < 1e-12);
        assert!((jet.d5 - d5).abs() < 1e-12);
    }

    #[test]
    fn latent_cloglog_jet_matches_exact_kernel_recurrence() {
        let ctx = QuadratureContext::new();

        for (eta, sigma) in [(-4.0, 0.15), (-1.2, 0.35), (0.4, 0.6), (1.3, 0.9)] {
            let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
            let bundle = log_kernel_bundle(&ctx, 1.0, eta, sigma, 5).expect("kernel bundle");

            // The bundle stores log K_{k,1}; K_0 stays in log space so the
            // mean 1 − K_0 can be formed with `exp_m1`.
            let log_k0 = bundle.get(0);
            let [k1, k2, k3, k4, k5] = [1, 2, 3, 4, 5].map(|order| bundle.get(order).exp());

            let mean = if log_k0.is_finite() {
                -log_k0.exp_m1()
            } else {
                1.0
            };
            let d1 = k1;
            let d2 = k1 - k2;
            let d3 = k1 - 3.0 * k2 + k3;
            let d4 = k1 - 7.0 * k2 + 6.0 * k3 - k4;
            let d5 = k1 - 15.0 * k2 + 25.0 * k3 - 10.0 * k4 + k5;

            assert!((jet.mean - mean).abs() < 1e-12);
            assert!((jet.d1 - d1).abs() < 1e-12);
            assert!((jet.d2 - d2).abs() < 1e-12);
            assert!((jet.d3 - d3).abs() < 1e-12);
            assert!((jet.d4 - d4).abs() < 1e-12);
            assert!((jet.d5 - d5).abs() < 1e-12);
        }
    }

    #[test]
    fn latent_cloglog_binomial_row_neg_hessian_matches_fd() {
        let ctx = QuadratureContext::new();
        let (eta, sigma) = (0.4, 0.6);
        let (y, weight) = (0.35, 2.0);
        let h = 1e-4;

        // Analytic curvature via the chain rule through the latent mean.
        let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
        let mu = jet.mean.clamp(1e-12, 1.0 - 1e-12);
        let one_minus_mu = 1.0 - mu;
        let ellmu = y / mu - (1.0 - y) / one_minus_mu;
        let ellmumu = -y / (mu * mu) - (1.0 - y) / (one_minus_mu * one_minus_mu);
        let neg_hessian = -weight * (ellmumu * jet.d1 * jet.d1 + ellmu * jet.d2);

        // Negated central second difference of the row log-likelihood in η.
        let ll_at = |x: f64| latent_binomial_row_log_lik(&ctx, x, sigma, y, weight);
        let ll_minus = ll_at(eta - h);
        let ll0 = ll_at(eta);
        let ll_plus = ll_at(eta + h);
        let neg_hessian_fd = -(ll_plus - 2.0 * ll0 + ll_minus) / (h * h);

        let err = (neg_hessian - neg_hessian_fd).abs();
        let tol = 2e-5_f64.max(3e-3 * neg_hessian_fd.abs());
        assert!(
            err <= tol,
            "latent cloglog Bernoulli row curvature mismatch: analytic={} fd={}",
            neg_hessian,
            neg_hessian_fd
        );
    }
}