Skip to main content

gam_models/survival/
lognormal_kernel.rs

1//! Shared analytic kernel for latent-variable families with lognormal structure.
2//!
3//! The kernel object `K_{k,m}(μ, σ) := E[exp(k·U − m·exp(U))]`, where
4//! `U ~ N(μ, σ²)`, is the only special function required by all latent families.
5//!
6//! It satisfies exact μ-recurrences (see [`kernel_ratio_jet`]) and the
7//! corresponding heat-equation σ-identities, so fixed-σ latent families reduce
8//! to evaluating kernel bundles at shifted arguments.
9//!
10//! Row likelihoods for binary and survival models are small signed sums of
11//! kernel terms; [`LogKernelSumJet`] evaluates their log-derivatives from
12//! log-space kernel bundles and treats non-positive signed sums as invalid rows.
13
14use crate::model_types::EstimationError;
15use crate::probability::signed_log_sum_exp;
16use crate::quadrature::{
17    IntegratedExpectationMode, QuadratureContext, lognormal_laplace_unit_log_term_shared,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21
22// ─── Typed errors ────────────────────────────────────────────────────────────
23
/// Errors produced by the lognormal-kernel frailty/marginal-slope validators.
///
/// Public boundaries that historically returned `Result<_, String>` continue to
/// do so via `.map_err(|e| e.to_string())`; the `Display` impl reproduces the
/// original error strings byte-for-byte.
///
/// In the visible part of this file the only producer is
/// [`FrailtySpec::validate_for_marginal_slope_typed`].
#[derive(Debug, Clone)]
pub enum LognormalKernelError {
    /// The chosen frailty modifier is not finite-state exact with the
    /// requested marginal-slope family.
    ///
    /// `reason` holds the complete human-readable message handed to callers.
    InvalidSpec { reason: String },
}

// NOTE(review): presumably expands to the `Display` / `std::error::Error`
// impls for the `reason`-carrying variants listed below — confirm at the macro
// definition, which is not visible here.
impl_reason_error_boilerplate! {
    LognormalKernelError {
        InvalidSpec,
    }
}
41
42// ─── Frailty specification ───────────────────────────────────────────────────
43
/// How the hazard multiplier frailty loads onto the hazard components.
///
/// Serialized in kebab-case (`rename_all = "kebab-case"`), i.e. as `"full"`
/// and `"loaded-vs-unloaded"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum HazardLoading {
    /// Frailty multiplies the entire hazard: h(t|U) = exp(U) · h_0(t).
    Full,
    /// Frailty multiplies only the disease-like component; an exogenous
    /// background ("Makeham") component is unloaded:
    ///   h(t|U) = exp(U) · h_loaded(t) + h_unloaded(t).
    /// This is the faithful model for Gompertz-Makeham.
    LoadedVsUnloaded,
}
56
/// Frailty modifier specification at the family level.
///
/// Two structurally different exact modifiers exist:
///
/// 1. **GaussianShift**: additive Gaussian on the final transformation index.
///    Exact for probit families — the existing sextic microcell kernel survives
///    unchanged (just scale denested cell coefficients by 1/√(1+σ²)).
///
/// 2. **HazardMultiplier**: lognormal multiplier on the loaded cumulative hazard.
///    Exact for PH/cloglog families — row likelihoods are finite sums of
///    K_{k,m}(μ, σ) kernel terms.
///
/// These are mathematically distinct families.  Do not mix them.
///
/// Serialized as an internally tagged enum: the variant name is stored
/// under the `frailty_kind` key in kebab-case (`"none"`, `"gaussian-shift"`,
/// `"hazard-multiplier"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "frailty_kind", rename_all = "kebab-case")]
pub enum FrailtySpec {
    /// No frailty modifier.
    None,
    /// Gaussian shift on the final scalar index: U ~ N(0, σ²) added to η.
    /// Exact for probit: E[Φ(η + U)] = Φ(η / √(1+σ²)).
    /// The existing sextic microcell kernel is preserved.
    GaussianShift {
        /// Fixed σ, or None if learnable.
        sigma_fixed: Option<f64>,
    },
    /// Lognormal hazard multiplier: conditional hazard h(t|U) involves exp(U).
    /// Exact for PH/cloglog/survival via K_{k,m} kernel.
    HazardMultiplier {
        /// Fixed σ, or None if learnable.
        sigma_fixed: Option<f64>,
        /// How the multiplier loads onto hazard components.
        loading: HazardLoading,
    },
}
91
92impl FrailtySpec {
93    /// Whether this spec requests an actual frailty modifier.
94    ///
95    /// [`FrailtySpec::None`] is the canonical "no frailty" value. The CLI/config
96    /// layer normalizes a missing `--frailty-kind` to `Some(FrailtySpec::None)`
97    /// (see `resolve_cli_frailty_spec`), so `Option::is_some` on
98    /// `config.frailty` is NOT a valid test for "the user requested frailty" — it
99    /// fires on the null `Some(FrailtySpec::None)`. Family/mode guards that only
100    /// support the no-frailty case must reject on this predicate instead, or they
101    /// misclassify every ordinary CLI fit as a frailty request.
102    #[inline]
103    pub fn is_active(&self) -> bool {
104        !matches!(self, Self::None)
105    }
106
107    /// Validate that this frailty spec is compatible with score_warp/linkwiggle
108    /// cubic marginal-slope families.
109    ///
110    /// - `GaussianShift` is exact: the sextic microcell kernel is preserved
111    ///   (probit scaling by 1/τ, τ = √(1+σ²)).
112    /// - `HazardMultiplier` is exact only for PH/cloglog rowwise families.
113    ///   It is NOT finite-state exact with score_warp/linkwiggle cubic
114    ///   marginal-slope, because the multiplicative frailty breaks the
115    ///   polynomial kernel closure that the cubic cell derivatives require.
116    ///
117    /// Returns an error if the combination is not exactly integrable.
118    pub fn validate_for_marginal_slope(&self) -> Result<(), String> {
119        self.validate_for_marginal_slope_typed()
120            .map_err(|e| e.to_string())
121    }
122
123    /// Typed variant of [`Self::validate_for_marginal_slope`] used internally;
124    /// the `String`-returning entry point above is preserved as a one-line
125    /// shim for external callers.
126    pub fn validate_for_marginal_slope_typed(&self) -> Result<(), LognormalKernelError> {
127        match self {
128            Self::None | Self::GaussianShift { .. } => Ok(()),
129            Self::HazardMultiplier { .. } => Err(LognormalKernelError::InvalidSpec {
130                reason:
131                    "HazardMultiplier frailty is not finite-state exact with score_warp/linkwiggle \
132                     cubic marginal-slope families. Use GaussianShift frailty (exact probit scaling) \
133                     or use the standalone latent-cloglog/latent-survival families instead."
134                        .to_string(),
135            }),
136        }
137    }
138}
139
140// ─── Probit frailty scaling ──────────────────────────────────────────────────
141
/// Returns `(s, α)` with `s = 1/√(1+σ²)` and `α = σ²/(1+σ²)`.
///
/// For `|σ| > 1` both quantities are formed from `1/|σ|`, so `σ²` is never
/// computed and cannot overflow; otherwise the direct formulas are used.
#[inline]
fn probit_frailty_scale_components(sigma: f64) -> (f64, f64) {
    let magnitude = sigma.abs();
    let use_reciprocal_form = magnitude > 1.0;
    if !use_reciprocal_form {
        // Direct form (also the path taken for NaN, which then propagates).
        let one_plus_sq = 1.0 + sigma * sigma;
        return (1.0 / one_plus_sq.sqrt(), sigma * sigma / one_plus_sq);
    }
    // Reciprocal form: s = r/√(1+r²), α = 1/(1+r²) with r = 1/|σ|.
    let recip = 1.0 / magnitude;
    let one_plus_recip_sq = 1.0 + recip * recip;
    (recip / one_plus_recip_sq.sqrt(), 1.0 / one_plus_recip_sq)
}
155
156/// Probit frailty scaling factor **with** t-derivatives (t = log σ).
157///
158/// Provides exact closed-form derivatives of s = 1/√(1+σ²) with respect to
159/// t = log(σ) for learnable Gaussian-shift frailty in the marginal-slope
160/// families.  For Gaussian frailty on the final probit index
161/// E[Φ(η + U)] = Φ(η · s) with s = 1/√(1+σ²); writing α = σ²/(1+σ²) the
162/// derivatives are ∂_t s = −α·s and ∂_{tt} s = α(3α−2)·s.
163#[derive(Clone, Copy, Debug)]
164pub struct ProbitFrailtyScaleJet {
165    /// s = 1/√(1+σ²)
166    pub s: f64,
167    /// α = σ²/(1+σ²)  — shared auxiliary for all derivative levels.
168    pub alpha: f64,
169    /// ∂_t s = -α·s
170    pub ds: f64,
171    /// ∂_{tt} s = α(3α−2)·s
172    pub d2s: f64,
173}
174
175impl ProbitFrailtyScaleJet {
176    /// Build the jet from σ (not from t = log σ).
177    ///
178    /// At σ = 0 the jet degenerates to (s=1, α=0, ds=0, d2s=0), which is
179    /// correct: zero frailty means s ≡ 1 independent of t.
180    pub fn new(sigma: f64) -> Self {
181        let (s, alpha) = probit_frailty_scale_components(sigma);
182        Self {
183            s,
184            alpha,
185            ds: -alpha * s,
186            d2s: alpha * (3.0 * alpha - 2.0) * s,
187        }
188    }
189
190    /// Build the jet from t = log(σ) directly.
191    pub fn from_log_sigma(log_sigma: f64) -> Self {
192        Self::new(log_sigma.exp())
193    }
194}
195
196#[inline]
197fn worst_mode(
198    a: IntegratedExpectationMode,
199    b: IntegratedExpectationMode,
200) -> IntegratedExpectationMode {
201    if a.rank() >= b.rank() { a } else { b }
202}
203
204// ─── Log-space kernel infrastructure ──────────────────────────────────────────
205//
206// The runtime kernel path stays in log-space until the final ratios are formed,
207// avoiding the overflow/underflow and cancellation problems that come from
208// exponentiating individual terms too early.
209
210/// Returns `log K_{k,m}(μ,σ)` directly, without exponentiation.
211///
212/// The value is always finite (or `NEG_INFINITY` when the kernel is zero), so
213/// it cannot overflow or underflow.
214#[inline]
215fn validate_kernel_inputs(m: f64, mu: f64, sigma: f64) -> Result<(), EstimationError> {
216    if !m.is_finite() || m < 0.0 {
217        crate::bail_invalid_estim!("lognormal kernel requires finite m >= 0, got {m}");
218    }
219    if !mu.is_finite() || !sigma.is_finite() || sigma < 0.0 {
220        crate::bail_invalid_estim!(
221            "lognormal kernel requires finite mu and sigma >= 0, got mu={mu}, sigma={sigma}"
222        );
223    }
224    Ok::<(), _>(())
225}
226
227#[inline]
228pub fn log_kernel_term(
229    quadctx: &QuadratureContext,
230    k: usize,
231    m: f64,
232    mu: f64,
233    sigma: f64,
234) -> Result<(f64, IntegratedExpectationMode), EstimationError> {
235    validate_kernel_inputs(m, mu, sigma)?;
236    let kf = k as f64;
237    let sigma2 = sigma * sigma;
238    if !sigma2.is_finite() {
239        crate::bail_invalid_estim!(
240            "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
241        );
242    }
243    let prefix_bound = kf * mu.abs() + 0.5 * kf * kf * sigma2;
244    if !prefix_bound.is_finite() {
245        crate::bail_invalid_estim!(
246            "lognormal kernel prefix is outside the finite exact-derivative range: k={k}, mu={mu}, sigma={sigma}"
247        );
248    }
249    let prefix = kf * mu + 0.5 * kf * kf * sigma2;
250    if m == 0.0 {
251        return Ok((prefix, IntegratedExpectationMode::ExactClosedForm));
252    }
253    let log_m = m.ln();
254    let shifted_bound = mu.abs() + kf * sigma2 + log_m.abs();
255    if !shifted_bound.is_finite() {
256        crate::bail_invalid_estim!(
257            "lognormal kernel shifted location is outside the finite exact-derivative range: k={k}, m={m}, mu={mu}, sigma={sigma}"
258        );
259    }
260    let shifted_mu = mu + kf * sigma2 + log_m;
261    // Survival carried in log space: prefix + ln S(shifted_mu, σ). This keeps the
262    // kernel's true magnitude when S underflows in value space at large σ — the
263    // old `laplace <= 0.0 → −∞` collapse discarded a large-but-finite log-value
264    // (#798) and the value-space asymptotic was biased low at σ ≥ 8 (#799).
265    let (log_laplace, mode) = lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
266    Ok((prefix + log_laplace, mode))
267}
268
269/// Kernel bundle storing `log K_{k,m}` values instead of `K_{k,m}`.
270#[derive(Clone, Debug)]
271pub struct LogLognormalKernelBundle {
272    pub log_values: Vec<f64>,
273    pub mode: IntegratedExpectationMode,
274}
275
276impl LogLognormalKernelBundle {
277    #[inline]
278    pub fn get(&self, k: usize) -> f64 {
279        self.log_values[k]
280    }
281
282    #[inline]
283    pub fn len(&self) -> usize {
284        self.log_values.len()
285    }
286}
287
288/// Builds a log-space kernel bundle for `k = 0, 1, …, max_k` at fixed
289/// `(m, μ, σ)`.
290pub fn log_kernel_bundle(
291    quadctx: &QuadratureContext,
292    m: f64,
293    mu: f64,
294    sigma: f64,
295    max_k: usize,
296) -> Result<LogLognormalKernelBundle, EstimationError> {
297    validate_kernel_inputs(m, mu, sigma)?;
298    let mut log_values = Vec::with_capacity(max_k + 1);
299    let sigma2 = sigma * sigma;
300    if !sigma2.is_finite() {
301        crate::bail_invalid_estim!(
302            "lognormal kernel sigma is outside the finite exact-derivative range: sigma={sigma}"
303        );
304    }
305    let max_kf = max_k as f64;
306    let prefix_bound = max_kf * mu.abs() + 0.5 * max_kf * max_kf * sigma2;
307    if !prefix_bound.is_finite() {
308        crate::bail_invalid_estim!(
309            "lognormal kernel bundle prefix is outside the finite exact-derivative range: max_k={max_k}, mu={mu}, sigma={sigma}"
310        );
311    }
312    if m == 0.0 {
313        let mut prefix = 0.0;
314        for k in 0..=max_k {
315            log_values.push(prefix);
316            prefix += mu + (k as f64 + 0.5) * sigma2;
317        }
318        return Ok(LogLognormalKernelBundle {
319            log_values,
320            mode: IntegratedExpectationMode::ExactClosedForm,
321        });
322    }
323
324    let log_m = m.ln();
325    let shifted_bound = mu.abs() + max_kf * sigma2 + log_m.abs();
326    if !shifted_bound.is_finite() {
327        crate::bail_invalid_estim!(
328            "lognormal kernel bundle shifted location is outside the finite exact-derivative range: max_k={max_k}, m={m}, mu={mu}, sigma={sigma}"
329        );
330    }
331    let mut shifted_mu = mu + log_m;
332    let mut prefix = 0.0;
333    let mut mode = IntegratedExpectationMode::ExactClosedForm;
334    for k in 0..=max_k {
335        let (log_laplace, val_mode) =
336            lognormal_laplace_unit_log_term_shared(quadctx, shifted_mu, sigma);
337        log_values.push(if log_laplace.is_finite() {
338            prefix + log_laplace
339        } else {
340            f64::NEG_INFINITY
341        });
342        mode = worst_mode(mode, val_mode);
343        prefix += mu + (k as f64 + 0.5) * sigma2;
344        shifted_mu += sigma2;
345    }
346    Ok(LogLognormalKernelBundle { log_values, mode })
347}
348
349/// Computes the value-space derivative ratios `∂ⁿ_μ K_{k,m} / K_{k,m}`
350/// from a log-space bundle.
351///
352/// Returns `[1, K'/K, K''/K, K'''/K, K''''/K]` where only the first
353/// `order + 1` entries are valid.
354///
355/// The recurrences are applied in ratio form, with each `K_{k+r}/K_k`
356/// computed as `exp(log K_{k+r} − log K_k)`, which remains finite even when
357/// the individual kernel values would overflow or underflow.
358pub fn kernel_ratio_jet(
359    log_bundle: &LogLognormalKernelBundle,
360    k: usize,
361    m: f64,
362    order: usize,
363) -> [f64; 5] {
364    let kf = k as f64;
365    let log_k0 = log_bundle.get(k);
366
367    // Precompute ratios K_{k+r}/K_k for r = 1..=order, each from a single
368    // log-difference.  This avoids redundant exp() calls when the same ratio
369    // appears in multiple derivative orders.
370    let mut rk = [0.0f64; 5]; // rk[0] unused; rk[r] = K_{k+r}/K_k
371    for r in 1..=order.min(4) {
372        let delta = log_bundle.get(k + r) - log_k0;
373        rk[r] = if delta.is_finite() {
374            delta.exp()
375        } else if delta > 0.0 {
376            f64::INFINITY
377        } else {
378            0.0
379        };
380    }
381
382    let mut jet = [0.0; 5];
383    jet[0] = 1.0;
384
385    if order >= 1 {
386        jet[1] = kf - m * rk[1];
387    }
388    if order >= 2 {
389        jet[2] = kf * kf - (2.0 * kf + 1.0) * m * rk[1] + m * m * rk[2];
390    }
391    if order >= 3 {
392        jet[3] = kf * kf * kf - (3.0 * kf * kf + 3.0 * kf + 1.0) * m * rk[1]
393            + 3.0 * (kf + 1.0) * m * m * rk[2]
394            - m * m * m * rk[3];
395    }
396    if order >= 4 {
397        let k2 = kf * kf;
398        let k3 = k2 * kf;
399        let k4 = k3 * kf;
400        let m2 = m * m;
401        let m3 = m2 * m;
402        let m4 = m3 * m;
403        jet[4] = k4 - (4.0 * k3 + 6.0 * k2 + 4.0 * kf + 1.0) * m * rk[1]
404            + (6.0 * k2 + 12.0 * kf + 7.0) * m2 * rk[2]
405            - (4.0 * kf + 6.0) * m3 * rk[3]
406            + m4 * rk[4];
407    }
408
409    jet
410}
411
412// `LatentCLogLogJet5` + `latent_cloglog_jet5` / `latent_cloglog_inverse_link_jet`
413// moved DOWN to `crate::quadrature` (#1135), co-located with their analytic
414// backend, so the `solver` link layer names them without importing up into
415// `families::survival`. Re-exported here so the in-family callers (e.g.
416// `family_runtime`) keep resolving.
417pub use crate::quadrature::{
418    LatentCLogLogJet5, latent_cloglog_inverse_link_jet, latent_cloglog_jet5,
419};
420
421// ─── LogKernelSumJet: log-sum derivatives from log-space bundles ─────────────
422
/// A single signed term in a kernel sum: coefficient × K_{k,m}.
///
/// The location and scale `(μ, σ)` are not stored here; they are supplied to
/// the evaluator and shared by all terms of a sum.
#[derive(Clone, Copy, Debug)]
pub struct KernelSumTerm {
    /// Multiplicative coefficient (can be negative for difference terms).
    pub coeff: f64,
    /// Kernel order parameter k.
    pub k: usize,
    /// Kernel mass parameter m (≥ 0).
    pub m: f64,
}
433
/// Derivatives of `log(Σ_j a_j · K_{k_j, m_j}(μ, σ))` with respect to μ.
///
/// This is the workhorse for row-level log-likelihood derivatives in all
/// latent families.  The numerator and denominator of a row likelihood are
/// each a small signed sum of kernel terms.
///
/// The value path is assembled from log-space kernel bundles and ratio jets,
/// so individual kernel terms are never exponentiated before the final signed
/// sum. That avoids the old overflow/underflow problems from value-space
/// kernels. When the signed sum is zero or negative, this returns an invalid
/// row (`value = -∞`) instead of trying to continue with a floored surrogate.
/// Signed two-term differences (e.g. interval censoring `K_{0,M_L} − K_{0,M_R}`)
/// are still combined through the shared sign-aware log-sum path.
#[derive(Clone, Copy, Debug)]
pub struct LogKernelSumJet {
    /// log(Σ a_j K_j)
    pub value: f64,
    /// d/dμ log(Σ a_j K_j)
    pub d1: f64,
    /// d²/dμ² log(Σ a_j K_j)
    pub d2: f64,
    /// d³/dμ³ log(Σ a_j K_j)
    pub d3: f64,
    /// d⁴/dμ⁴ log(Σ a_j K_j)
    pub d4: f64,
    /// Integration mode of the kernel bundles behind this jet; when several
    /// bundles are involved, the one with the highest `rank()`.
    pub mode: IntegratedExpectationMode,
}
461
462impl LogKernelSumJet {
463    #[inline]
464    fn non_positive(mode: IntegratedExpectationMode) -> Self {
465        Self {
466            value: f64::NEG_INFINITY,
467            d1: 0.0,
468            d2: 0.0,
469            d3: 0.0,
470            d4: 0.0,
471            mode,
472        }
473    }
474
475    #[inline]
476    fn from_log_value_and_ratios(
477        value: f64,
478        ratio: [f64; 5],
479        mode: IntegratedExpectationMode,
480    ) -> Self {
481        let r1 = ratio[1];
482        let r2 = ratio[2];
483        let r3 = ratio[3];
484        let r4 = ratio[4];
485        Self {
486            value,
487            d1: r1,
488            d2: r2 - r1 * r1,
489            d3: r3 - 3.0 * r1 * r2 + 2.0 * r1 * r1 * r1,
490            d4: r4 - 4.0 * r1 * r3 - 3.0 * r2 * r2 + 12.0 * r1 * r1 * r2 - 6.0 * r1.powi(4),
491            mode,
492        }
493    }
494
495    #[inline]
496    fn term_log_mag_and_ratio(
497        bundle: &LogLognormalKernelBundle,
498        term: KernelSumTerm,
499    ) -> (f64, [f64; 5]) {
500        (
501            term.coeff.abs().ln() + bundle.get(term.k),
502            // d4 is used by the exact log-sigma curvature, so this must carry
503            // ratios through order 4 rather than truncating at order 3.
504            kernel_ratio_jet(bundle, term.k, term.m, 4),
505        )
506    }
507
508    fn evaluate_two_terms(
509        quadctx: &QuadratureContext,
510        t0: KernelSumTerm,
511        t1: KernelSumTerm,
512        mu: f64,
513        sigma: f64,
514    ) -> Result<Self, EstimationError> {
515        let max_k_needed = t0.k.max(t1.k) + 4;
516        let bundle0 = log_kernel_bundle(quadctx, t0.m, mu, sigma, max_k_needed)?;
517        let mut overall_mode = bundle0.mode;
518        let bundle1_owned = if (t0.m - t1.m).abs() < 1e-300 {
519            None
520        } else {
521            let bundle1 = log_kernel_bundle(quadctx, t1.m, mu, sigma, max_k_needed)?;
522            overall_mode = worst_mode(overall_mode, bundle1.mode);
523            Some(bundle1)
524        };
525        let bundle1 = bundle1_owned.as_ref().unwrap_or(&bundle0);
526
527        let (log_mag0, ratio0) = Self::term_log_mag_and_ratio(&bundle0, t0);
528        let (log_mag1, ratio1) = Self::term_log_mag_and_ratio(bundle1, t1);
529        let log_mags = [log_mag0, log_mag1];
530        let signs = [t0.coeff.signum(), t1.coeff.signum()];
531        let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
532        if !log_s.is_finite() || sign_s <= 0.0 {
533            return Ok(Self::non_positive(overall_mode));
534        }
535
536        let w0 = sign_s * signs[0] * (log_mag0 - log_s).exp();
537        let w1 = sign_s * signs[1] * (log_mag1 - log_s).exp();
538        let wr1 = w0 * ratio0[1] + w1 * ratio1[1];
539        let wr2 = w0 * ratio0[2] + w1 * ratio1[2];
540        let wr3 = w0 * ratio0[3] + w1 * ratio1[3];
541        let wr4 = w0 * ratio0[4] + w1 * ratio1[4];
542
543        Ok(Self {
544            value: log_s,
545            d1: wr1,
546            d2: wr2 - wr1 * wr1,
547            d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
548            d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
549                - 6.0 * wr1.powi(4),
550            mode: overall_mode,
551        })
552    }
553
554    /// Evaluate for a single positive kernel term (fast path).
555    ///
556    /// Computes `log(K_{k,m})` and its μ-derivatives from exact recurrences,
557    /// entirely in log-space.
558    pub fn single_term(
559        quadctx: &QuadratureContext,
560        k: usize,
561        m: f64,
562        mu: f64,
563        sigma: f64,
564    ) -> Result<Self, EstimationError> {
565        let max_k_needed = k + 4;
566        let lb = log_kernel_bundle(quadctx, m, mu, sigma, max_k_needed)?;
567        Ok(Self::from_log_value_and_ratios(
568            lb.get(k),
569            kernel_ratio_jet(&lb, k, m, 4),
570            lb.mode,
571        ))
572    }
573
574    /// Evaluate `log(Σ a_j K_j)` and its μ-derivatives for a small signed sum.
575    ///
576    /// All terms share the same `(μ, σ)`.  Both the value and derivative
577    /// ratios are computed entirely in log-space.  The runtime latent-survival
578    /// rows in this repo are almost always one-term or two-term sums, so those
579    /// cases stay on dedicated stack paths; the heap-backed logic below is only
580    /// for genuinely longer symbolic sums:
581    ///
582    /// 1. Per-term log-magnitudes `log|a_j| + log K_{k_j,m_j}` and signs.
583    /// 2. Sign-aware log-sum-exp to get `log|S|` and `sign(S)`.
584    /// 3. Importance weights `w_j = a_j K_j / S` formed in log-space.
585    /// 4. Weighted ratio sums `R_n = Σ w_j · (∂ⁿK_j / K_j)` for the
586    ///    final log-derivatives.
587    pub fn evaluate(
588        quadctx: &QuadratureContext,
589        terms: &[KernelSumTerm],
590        mu: f64,
591        sigma: f64,
592    ) -> Result<Self, EstimationError> {
593        if terms.is_empty() {
594            // Empty sums are a caller-contract violation, not a degenerate row.
595            // Return an input error so callers can report the malformed kernel sum.
596            crate::bail_invalid_estim!("KernelSumJet requires at least one term");
597        }
598
599        // Fast path for single term.
600        if terms.len() == 1 {
601            let t = &terms[0];
602            if t.coeff <= 0.0 {
603                // Negative or zero coefficient: the sum is non-positive, so
604                // log(sum) is undefined.  Return −∞ (impossible observation),
605                // matching the general path's sign_s ≤ 0 branch.
606                return Ok(Self::non_positive(
607                    IntegratedExpectationMode::ExactClosedForm,
608                ));
609            }
610            let jet = Self::single_term(quadctx, t.k, t.m, mu, sigma)?;
611            return Ok(Self {
612                value: t.coeff.ln() + jet.value,
613                d1: jet.d1,
614                d2: jet.d2,
615                d3: jet.d3,
616                d4: jet.d4,
617                mode: jet.mode,
618            });
619        }
620        if terms.len() == 2 {
621            return Self::evaluate_two_terms(quadctx, terms[0], terms[1], mu, sigma);
622        }
623
624        let max_k_needed = terms.iter().map(|t| t.k).max().unwrap_or(0) + 4;
625
626        // Build log-bundles for each unique mass.
627        let mut log_bundles: Vec<(f64, LogLognormalKernelBundle)> = Vec::with_capacity(2);
628        let mut overall_mode = IntegratedExpectationMode::ExactClosedForm;
629        for term in terms {
630            if !log_bundles
631                .iter()
632                .any(|(m, _)| (*m - term.m).abs() < 1e-300)
633            {
634                let b = log_kernel_bundle(quadctx, term.m, mu, sigma, max_k_needed)?;
635                overall_mode = worst_mode(overall_mode, b.mode);
636                log_bundles.push((term.m, b));
637            }
638        }
639
640        let get_lb = |m: f64| -> &LogLognormalKernelBundle {
641            &log_bundles
642                .iter()
643                .find(|(bm, _)| (*bm - m).abs() < 1e-300)
644                .unwrap()
645                .1
646        };
647
648        // Per-term: log magnitude, sign, and ratio jet.
649        let mut log_mags: Vec<f64> = Vec::with_capacity(terms.len());
650        let mut signs: Vec<f64> = Vec::with_capacity(terms.len());
651        let mut ratios: Vec<[f64; 5]> = Vec::with_capacity(terms.len());
652        for term in terms {
653            let lb = get_lb(term.m);
654            log_mags.push(term.coeff.abs().ln() + lb.get(term.k));
655            signs.push(term.coeff.signum());
656            ratios.push(kernel_ratio_jet(lb, term.k, term.m, 4));
657        }
658
659        // Sign-aware log-sum-exp: compute log|S| and sign(S).
660        let (log_s, sign_s) = signed_log_sum_exp(&log_mags, &signs);
661
662        if !log_s.is_finite() || sign_s <= 0.0 {
663            // Sum is zero or negative — degenerate row.
664            return Ok(Self::non_positive(overall_mode));
665        }
666
667        // Importance weights w_j = sign(S) · sign(a_j) · exp(log|a_j K_j| − log|S|).
668        // When S > 0 and all terms have well-defined kernels, Σ w_j = 1.
669        let mut wr1 = 0.0;
670        let mut wr2 = 0.0;
671        let mut wr3 = 0.0;
672        let mut wr4 = 0.0;
673        for i in 0..terms.len() {
674            let w = sign_s * signs[i] * (log_mags[i] - log_s).exp();
675            wr1 += w * ratios[i][1];
676            wr2 += w * ratios[i][2];
677            wr3 += w * ratios[i][3];
678            wr4 += w * ratios[i][4];
679        }
680
681        Ok(Self {
682            value: log_s,
683            d1: wr1,
684            d2: wr2 - wr1 * wr1,
685            d3: wr3 - 3.0 * wr1 * wr2 + 2.0 * wr1 * wr1 * wr1,
686            d4: wr4 - 4.0 * wr1 * wr3 - 3.0 * wr2 * wr2 + 12.0 * wr1 * wr1 * wr2
687                - 6.0 * wr1.powi(4),
688            mode: overall_mode,
689        })
690    }
691}
692
693// ─── Latent survival sufficient statistics ───────────────────────────────────
694
/// Kind of observation behind a compiled latent-survival row.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LatentSurvivalEventType {
    /// Right-censored: observed alive in the observation window.
    RightCensored,
    /// Exact event: event observed at a known time.
    ExactEvent,
    /// Interval-censored: event known to occur in (t_left, t_right].
    IntervalCensored,
}

impl fmt::Display for LatentSurvivalEventType {
    /// Writes the snake_case label of the event type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::RightCensored => "right_censored",
            Self::ExactEvent => "exact_event",
            Self::IntervalCensored => "interval_censored",
        };
        f.write_str(label)
    }
}
715
/// Row-level sufficient statistics for one latent survival observation.
///
/// This is the canonical row representation used by both fitted-family
/// evaluation and saved-model prediction.
///
/// For the full-loading model (frailty multiplies entire hazard):
///   mass_loaded = total cumulative hazard mass
///   mass_unloaded = 0
///
/// For the loaded-vs-unloaded model (Gompertz-Makeham):
///   mass_loaded = integrated disease hazard component
///   mass_unloaded = integrated background hazard component (not frailty-modified)
///
/// The unloaded mass contributes a simple exp(-M_U) prefactor.
///
/// Fields that do not apply to a row's `event_type` are set to `0.0` by the
/// constructors (`right_censored`, `exact_event`, `interval_censored`).
#[derive(Clone, Copy, Debug)]
pub struct LatentSurvivalRow {
    /// Which kind of observation this row encodes.
    pub event_type: LatentSurvivalEventType,
    /// Cumulative nuisance mass at entry: B(a_in).
    /// Zero if there is no left truncation.
    pub mass_entry: f64,
    /// Cumulative nuisance mass at exit/event: B(a_out) or B(a_event).
    pub mass_exit: f64,
    /// For interval censoring: mass at left boundary B(a_L).
    pub mass_left: f64,
    /// For interval censoring: mass at right boundary B(a_R).
    pub mass_right: f64,
    /// For interval censoring: unloaded mass at left boundary.
    pub mass_unloaded_left: f64,
    /// For interval censoring: unloaded mass at right boundary.
    pub mass_unloaded_right: f64,
    /// Unloaded (background) cumulative mass at entry (0 for full loading).
    pub mass_unloaded_entry: f64,
    /// Unloaded (background) cumulative mass at exit.
    pub mass_unloaded_exit: f64,
    /// Loaded instantaneous hazard at event time (for exact events).
    pub hazard_loaded: f64,
    /// Unloaded instantaneous hazard at event time (for exact events).
    pub hazard_unloaded: f64,
}
755
756impl LatentSurvivalRow {
757    /// Delayed-entry right-censored row with explicit loaded/unloaded masses.
758    ///
759    /// `mass_entry` and `mass_exit` are cumulative loaded masses `B_L(a_in)`
760    /// and `B_L(a_out)` for this row object. They are not an increment over
761    /// `(a_in, a_out]`.
762    pub fn right_censored(
763        mass_entry: f64,
764        mass_exit: f64,
765        mass_unloaded_entry: f64,
766        mass_unloaded_exit: f64,
767    ) -> Self {
768        Self {
769            event_type: LatentSurvivalEventType::RightCensored,
770            mass_entry,
771            mass_exit,
772            mass_left: 0.0,
773            mass_right: 0.0,
774            mass_unloaded_left: 0.0,
775            mass_unloaded_right: 0.0,
776            mass_unloaded_entry,
777            mass_unloaded_exit,
778            hazard_loaded: 0.0,
779            hazard_unloaded: 0.0,
780        }
781    }
782
783    /// Delayed-entry exact-event row with explicit loaded/unloaded hazard parts.
784    pub fn exact_event(
785        mass_entry: f64,
786        mass_exit: f64,
787        mass_unloaded_entry: f64,
788        mass_unloaded_exit: f64,
789        hazard_loaded: f64,
790        hazard_unloaded: f64,
791    ) -> Self {
792        Self {
793            event_type: LatentSurvivalEventType::ExactEvent,
794            mass_entry,
795            mass_exit,
796            mass_left: 0.0,
797            mass_right: 0.0,
798            mass_unloaded_left: 0.0,
799            mass_unloaded_right: 0.0,
800            mass_unloaded_entry,
801            mass_unloaded_exit,
802            hazard_loaded,
803            hazard_unloaded,
804        }
805    }
806
807    /// Delayed-entry interval-censored row with explicit loaded/unloaded masses.
808    pub fn interval_censored(
809        mass_entry: f64,
810        mass_left: f64,
811        mass_right: f64,
812        mass_unloaded_entry: f64,
813        mass_unloaded_left: f64,
814        mass_unloaded_right: f64,
815    ) -> Self {
816        Self {
817            event_type: LatentSurvivalEventType::IntervalCensored,
818            mass_entry,
819            mass_exit: 0.0,
820            mass_left,
821            mass_right,
822            mass_unloaded_left,
823            mass_unloaded_right,
824            mass_unloaded_entry,
825            mass_unloaded_exit: 0.0,
826            hazard_loaded: 0.0,
827            hazard_unloaded: 0.0,
828        }
829    }
830
831    pub fn validate(&self) -> Result<(), EstimationError> {
832        let fields = [
833            ("mass_entry", self.mass_entry),
834            ("mass_exit", self.mass_exit),
835            ("mass_left", self.mass_left),
836            ("mass_right", self.mass_right),
837            ("mass_unloaded_left", self.mass_unloaded_left),
838            ("mass_unloaded_right", self.mass_unloaded_right),
839            ("mass_unloaded_entry", self.mass_unloaded_entry),
840            ("mass_unloaded_exit", self.mass_unloaded_exit),
841            ("hazard_loaded", self.hazard_loaded),
842            ("hazard_unloaded", self.hazard_unloaded),
843        ];
844        for (name, value) in fields {
845            if !value.is_finite() || value < 0.0 {
846                crate::bail_invalid_estim!(
847                    "latent survival row has invalid {name}={value}; expected a finite non-negative value"
848                );
849            }
850        }
851
852        match self.event_type {
853            LatentSurvivalEventType::RightCensored => {
854                if self.mass_exit < self.mass_entry {
855                    crate::bail_invalid_estim!(
856                        "latent survival right-censored row requires mass_exit >= mass_entry, got {} < {}",
857                        self.mass_exit,
858                        self.mass_entry
859                    );
860                }
861                if self.mass_unloaded_exit < self.mass_unloaded_entry {
862                    crate::bail_invalid_estim!(
863                        "latent survival right-censored row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
864                        self.mass_unloaded_exit,
865                        self.mass_unloaded_entry
866                    );
867                }
868                if self.mass_left > 0.0
869                    || self.mass_right > 0.0
870                    || self.mass_unloaded_left > 0.0
871                    || self.mass_unloaded_right > 0.0
872                    || self.hazard_loaded > 0.0
873                    || self.hazard_unloaded > 0.0
874                {
875                    crate::bail_invalid_estim!("latent survival right-censored row cannot carry interval masses or event hazards"
876                            .to_string(),);
877                }
878            }
879            LatentSurvivalEventType::ExactEvent => {
880                if self.mass_exit < self.mass_entry {
881                    crate::bail_invalid_estim!(
882                        "latent survival exact-event row requires mass_exit >= mass_entry, got {} < {}",
883                        self.mass_exit,
884                        self.mass_entry
885                    );
886                }
887                if self.mass_unloaded_exit < self.mass_unloaded_entry {
888                    crate::bail_invalid_estim!(
889                        "latent survival exact-event row requires unloaded exit mass >= unloaded entry mass, got {} < {}",
890                        self.mass_unloaded_exit,
891                        self.mass_unloaded_entry
892                    );
893                }
894                if self.mass_left > 0.0
895                    || self.mass_right > 0.0
896                    || self.mass_unloaded_left > 0.0
897                    || self.mass_unloaded_right > 0.0
898                {
899                    crate::bail_invalid_estim!(
900                        "latent survival exact-event row cannot carry interval masses"
901                    );
902                }
903                if self.hazard_loaded == 0.0 && self.hazard_unloaded == 0.0 {
904                    crate::bail_invalid_estim!("latent survival exact-event row requires a positive loaded or unloaded hazard"
905                            .to_string(),);
906                }
907            }
908            LatentSurvivalEventType::IntervalCensored => {
909                if self.mass_left < self.mass_entry || self.mass_right < self.mass_left {
910                    crate::bail_invalid_estim!(
911                        "latent survival interval row requires mass_entry <= mass_left <= mass_right, got entry={}, left={}, right={}",
912                        self.mass_entry,
913                        self.mass_left,
914                        self.mass_right
915                    );
916                }
917                if self.mass_unloaded_left < self.mass_unloaded_entry
918                    || self.mass_unloaded_right < self.mass_unloaded_left
919                {
920                    crate::bail_invalid_estim!(
921                        "latent survival interval row requires unloaded_entry <= unloaded_left <= unloaded_right, got entry={}, left={}, right={}",
922                        self.mass_unloaded_entry,
923                        self.mass_unloaded_left,
924                        self.mass_unloaded_right
925                    );
926                }
927                if self.mass_exit > 0.0
928                    || self.mass_unloaded_exit > 0.0
929                    || self.hazard_loaded > 0.0
930                    || self.hazard_unloaded > 0.0
931                {
932                    crate::bail_invalid_estim!(
933                        "latent survival interval row cannot carry exit masses or event hazards"
934                            .to_string(),
935                    );
936                }
937            }
938        }
939
940        Ok(())
941    }
942}
943
944fn exact_event_kernel_jet(
945    quadctx: &QuadratureContext,
946    row: &LatentSurvivalRow,
947    mu: f64,
948    sigma: f64,
949) -> Result<LogKernelSumJet, EstimationError> {
950    if row.hazard_loaded < 0.0 || row.hazard_unloaded < 0.0 {
951        crate::bail_invalid_estim!(
952            "latent survival exact-event hazards must be non-negative, got loaded={} unloaded={}",
953            row.hazard_loaded,
954            row.hazard_unloaded
955        );
956    }
957    match (row.hazard_unloaded > 0.0, row.hazard_loaded > 0.0) {
958        (true, true) => {
959            let terms = [
960                KernelSumTerm {
961                    coeff: row.hazard_unloaded,
962                    k: 0,
963                    m: row.mass_exit,
964                },
965                KernelSumTerm {
966                    coeff: row.hazard_loaded,
967                    k: 1,
968                    m: row.mass_exit,
969                },
970            ];
971            LogKernelSumJet::evaluate(quadctx, &terms, mu, sigma)
972        }
973        (true, false) => {
974            let jet = LogKernelSumJet::single_term(quadctx, 0, row.mass_exit, mu, sigma)?;
975            Ok(LogKernelSumJet {
976                value: row.hazard_unloaded.ln() + jet.value,
977                d1: jet.d1,
978                d2: jet.d2,
979                d3: jet.d3,
980                d4: jet.d4,
981                mode: jet.mode,
982            })
983        }
984        (false, true) => {
985            let jet = LogKernelSumJet::single_term(quadctx, 1, row.mass_exit, mu, sigma)?;
986            Ok(LogKernelSumJet {
987                value: row.hazard_loaded.ln() + jet.value,
988                d1: jet.d1,
989                d2: jet.d2,
990                d3: jet.d3,
991                d4: jet.d4,
992                mode: jet.mode,
993            })
994        }
995        (false, false) => Err(EstimationError::InvalidInput(
996            "latent survival exact-event row requires a positive loaded or unloaded hazard"
997                .to_string(),
998        )),
999    }
1000}
1001
/// Per-row log-likelihood of the latent survival model together with its
/// derivatives in `μ` and `log σ`.
///
/// Conditional on the latent effect, the cumulative hazard is
///   `Λ(a | U) = B(a) · exp(U)`,  `U ~ N(μ, σ²)`,
/// so every row likelihood is algebra on the kernels `K_{k,m}(μ, σ)`.
#[derive(Debug, Clone, Copy)]
pub struct LatentSurvivalRowJet {
    /// Row log-likelihood `ℓ`.
    pub log_lik: f64,
    /// First derivative of `ℓ` with respect to `μ`.
    pub score: f64,
    /// Negated second derivative of `ℓ` with respect to `μ`.
    pub neg_hessian: f64,
    /// Third derivative of `ℓ` with respect to `μ`.
    pub d3: f64,
    /// First derivative of `ℓ` with respect to `log σ`.
    pub score_log_sigma: f64,
    /// Negated second derivative of `ℓ` with respect to `log σ`.
    pub neg_hessian_log_sigma: f64,
}
1017
1018#[inline]
1019fn log_sigma_score_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1020    let sigma2 = sigma * sigma;
1021    sigma2 * (jet.d2 + jet.d1 * jet.d1)
1022}
1023
1024#[inline]
1025fn log_sigma_neg_hessian_from_log_sum(jet: &LogKernelSumJet, sigma: f64) -> f64 {
1026    let sigma2 = sigma * sigma;
1027    let sigma4 = sigma2 * sigma2;
1028    let d1 = jet.d1;
1029    let d2 = jet.d2;
1030    let d3 = jet.d3;
1031    let d4 = jet.d4;
1032    let s2_over_s = d2 + d1 * d1;
1033    // For S = Σ a_j K_j, D = σ ∂_σ, and D S = σ² S_μμ:
1034    // D² log S = 2σ² (S''/S) + σ⁴ (S''''/S - (S''/S)²).
1035    // Express the final parenthesized term directly in log-derivatives to
1036    // avoid the larger cancellation in `r4 - r2²`.
1037    let s4_over_s_minus_s2_sq = d4 + 4.0 * d1 * d3 + 2.0 * d2 * d2 + 4.0 * d1 * d1 * d2;
1038    -(2.0 * sigma2 * s2_over_s + sigma4 * s4_over_s_minus_s2_sq)
1039}
1040
1041impl LatentSurvivalRowJet {
1042    pub fn evaluate(
1043        quadctx: &QuadratureContext,
1044        row: &LatentSurvivalRow,
1045        mu: f64,
1046        sigma: f64,
1047    ) -> Result<Self, EstimationError> {
1048        row.validate()?;
1049        match row.event_type {
1050            LatentSurvivalEventType::RightCensored => Self::right_censored(quadctx, mu, sigma, row),
1051            LatentSurvivalEventType::ExactEvent => Self::exact_event(quadctx, mu, sigma, row),
1052            LatentSurvivalEventType::IntervalCensored => {
1053                Self::interval_censored(quadctx, mu, sigma, row)
1054            }
1055        }
1056    }
1057
1058    /// Right-censoring with loaded/unloaded mass decomposition.
1059    ///
1060    /// Full formula:
1061    ///   `ℓ = -M_U_exit + log K_{0,M_L_exit} + M_U_entry - log K_{0,M_L_entry}`
1062    ///
1063    /// When `mass_unloaded_exit == 0` and `mass_unloaded_entry == 0`, this
1064    /// falls back to the original formula using `mass_exit` / `mass_entry`.
1065    fn right_censored(
1066        quadctx: &QuadratureContext,
1067        mu: f64,
1068        sigma: f64,
1069        row: &LatentSurvivalRow,
1070    ) -> Result<Self, EstimationError> {
1071        let has_unloaded =
1072            row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300;
1073
1074        // Loaded mass for the kernel terms: when unloaded mass is present,
1075        // mass_exit contains only the loaded component; otherwise it is the
1076        // total mass.
1077        let mass_exit_loaded = row.mass_exit;
1078        let mass_entry_loaded = row.mass_entry;
1079
1080        // Unloaded mass contributes a simple additive constant to log-lik
1081        let unloaded_offset = if has_unloaded {
1082            -row.mass_unloaded_exit + row.mass_unloaded_entry
1083        } else {
1084            0.0
1085        };
1086
1087        let num = LogKernelSumJet::single_term(quadctx, 0, mass_exit_loaded, mu, sigma)?;
1088        if mass_entry_loaded > 1e-300 {
1089            let den = LogKernelSumJet::single_term(quadctx, 0, mass_entry_loaded, mu, sigma)?;
1090            Ok(Self {
1091                log_lik: unloaded_offset + num.value - den.value,
1092                score: num.d1 - den.d1,
1093                neg_hessian: -(num.d2 - den.d2),
1094                d3: num.d3 - den.d3,
1095                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1096                    - log_sigma_score_from_log_sum(&den, sigma),
1097                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1098                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1099            })
1100        } else {
1101            Ok(Self {
1102                log_lik: unloaded_offset + num.value,
1103                score: num.d1,
1104                neg_hessian: -num.d2,
1105                d3: num.d3,
1106                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1107                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1108            })
1109        }
1110    }
1111
1112    /// Exact event with loaded/unloaded hazard decomposition.
1113    ///
1114    /// `ℓ = log(h_U · K_{0,M_L} + h_L · K_{1,M_L}) - M_U_event + M_U_entry - log K_{0,M_L_entry}`
1115    fn exact_event(
1116        quadctx: &QuadratureContext,
1117        mu: f64,
1118        sigma: f64,
1119        row: &LatentSurvivalRow,
1120    ) -> Result<Self, EstimationError> {
1121        let unloaded_offset =
1122            if row.mass_unloaded_exit.abs() > 1e-300 || row.mass_unloaded_entry.abs() > 1e-300 {
1123                -row.mass_unloaded_exit + row.mass_unloaded_entry
1124            } else {
1125                0.0
1126            };
1127        let num = exact_event_kernel_jet(quadctx, row, mu, sigma)?;
1128
1129        if row.mass_entry > 1e-300 {
1130            let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1131            Ok(Self {
1132                log_lik: unloaded_offset + num.value - den.value,
1133                score: num.d1 - den.d1,
1134                neg_hessian: -(num.d2 - den.d2),
1135                d3: num.d3 - den.d3,
1136                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1137                    - log_sigma_score_from_log_sum(&den, sigma),
1138                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1139                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1140            })
1141        } else {
1142            Ok(Self {
1143                log_lik: unloaded_offset + num.value,
1144                score: num.d1,
1145                neg_hessian: -num.d2,
1146                d3: num.d3,
1147                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1148                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1149            })
1150        }
1151    }
1152
1153    /// Interval event: `ℓ = log(K_{0,M_L} − K_{0,M_R}) − log K_{0,M_in}`.
1154    fn interval_censored(
1155        quadctx: &QuadratureContext,
1156        mu: f64,
1157        sigma: f64,
1158        row: &LatentSurvivalRow,
1159    ) -> Result<Self, EstimationError> {
1160        let num_terms = [
1161            KernelSumTerm {
1162                coeff: (-row.mass_unloaded_left).exp(),
1163                k: 0,
1164                m: row.mass_left,
1165            },
1166            KernelSumTerm {
1167                coeff: -(-row.mass_unloaded_right).exp(),
1168                k: 0,
1169                m: row.mass_right,
1170            },
1171        ];
1172        let num = LogKernelSumJet::evaluate(quadctx, &num_terms, mu, sigma)?;
1173
1174        if row.mass_entry > 1e-300 {
1175            let den = LogKernelSumJet::single_term(quadctx, 0, row.mass_entry, mu, sigma)?;
1176            Ok(Self {
1177                log_lik: num.value + row.mass_unloaded_entry - den.value,
1178                score: num.d1 - den.d1,
1179                neg_hessian: -(num.d2 - den.d2),
1180                d3: num.d3 - den.d3,
1181                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma)
1182                    - log_sigma_score_from_log_sum(&den, sigma),
1183                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma)
1184                    - log_sigma_neg_hessian_from_log_sum(&den, sigma),
1185            })
1186        } else {
1187            Ok(Self {
1188                log_lik: num.value + row.mass_unloaded_entry,
1189                score: num.d1,
1190                neg_hessian: -num.d2,
1191                d3: num.d3,
1192                score_log_sigma: log_sigma_score_from_log_sum(&num, sigma),
1193                neg_hessian_log_sigma: log_sigma_neg_hessian_from_log_sum(&num, sigma),
1194            })
1195        }
1196    }
1197}
1198
1199#[cfg(test)]
1200mod tests {
1201    use super::*;
1202
1203    fn latent_binomial_row_log_lik(
1204        ctx: &QuadratureContext,
1205        eta: f64,
1206        sigma: f64,
1207        y: f64,
1208        weight: f64,
1209    ) -> f64 {
1210        let mu = latent_cloglog_jet5(ctx, eta, sigma)
1211            .expect("latent jet")
1212            .mean;
1213        let mu = mu.clamp(1e-12, 1.0 - 1e-12);
1214        weight * (y * mu.ln() + (1.0 - y) * (1.0 - mu).ln())
1215    }
1216
1217    #[test]
1218    fn kernel_ratio_jet_d1_fd_check() {
1219        let ctx = QuadratureContext::new();
1220        let mu = 0.3;
1221        let sigma = 0.5;
1222        let m = 1.0;
1223        let k = 0usize;
1224        let h = 1e-5;
1225
1226        let bundle = log_kernel_bundle(&ctx, m, mu, sigma, k + 4).unwrap();
1227        let log_k = bundle.get(k);
1228        let ratios = kernel_ratio_jet(&bundle, k, m, 2);
1229        let kc = log_k.exp();
1230        let d1 = kc * ratios[1];
1231        let d2 = kc * ratios[2];
1232
1233        let kp = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0.exp();
1234        let km = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0.exp();
1235        let fd_d1 = (kp - km) / (2.0 * h);
1236        assert!(
1237            (d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-4,
1238            "d1: jet={d1}, fd={fd_d1}",
1239        );
1240
1241        let fd_d2 = (kp - 2.0 * kc + km) / (h * h);
1242        assert!(
1243            (d2 - fd_d2).abs() / fd_d2.abs().max(1e-15) < 1e-3,
1244            "d2: jet={d2}, fd={fd_d2}",
1245        );
1246    }
1247
1248    #[test]
1249    fn survival_right_censored_score_fd() {
1250        let ctx = QuadratureContext::new();
1251        let mu = -0.5;
1252        let sigma = 0.3;
1253        let h = 1e-6;
1254        let row = LatentSurvivalRow::right_censored(0.0, 2.0, 0.0, 0.0);
1255        let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1256            .unwrap()
1257            .log_lik;
1258        let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1259            .unwrap()
1260            .log_lik;
1261        let fd_score = (ll_p - ll_m) / (2.0 * h);
1262        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1263        assert!(
1264            (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1265            "score={}, fd={fd_score}",
1266            jet.score
1267        );
1268    }
1269
1270    #[test]
1271    fn survival_exact_event_score_fd() {
1272        let ctx = QuadratureContext::new();
1273        let mu = 0.2;
1274        let sigma = 0.5;
1275        let h = 1e-6;
1276        let row = LatentSurvivalRow::exact_event(0.0, 1.5, 0.0, 0.0, (-0.3f64).exp(), 0.0);
1277        let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1278            .unwrap()
1279            .log_lik;
1280        let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1281            .unwrap()
1282            .log_lik;
1283        let fd_score = (ll_p - ll_m) / (2.0 * h);
1284        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1285        assert!(
1286            (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1287            "score={}, fd={fd_score}",
1288            jet.score
1289        );
1290    }
1291
1292    #[test]
1293    fn survival_exact_event_loaded_vs_unloaded_score_fd() {
1294        let ctx = QuadratureContext::new();
1295        let mu = -0.1;
1296        let sigma = 0.4;
1297        let h = 1e-6;
1298        let row = LatentSurvivalRow::exact_event(0.3, 1.2, 0.2, 0.6, 0.9, 0.15);
1299        let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1300            .unwrap()
1301            .log_lik;
1302        let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1303            .unwrap()
1304            .log_lik;
1305        let fd_score = (ll_p - ll_m) / (2.0 * h);
1306        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1307        assert!(
1308            (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1309            "score={}, fd={fd_score}",
1310            jet.score
1311        );
1312    }
1313
1314    #[test]
1315    fn survival_right_censored_loaded_vs_unloaded_score_fd() {
1316        let ctx = QuadratureContext::new();
1317        let mu = 0.15;
1318        let sigma: f64 = 0.35;
1319        let h = 1e-6;
1320        let row = LatentSurvivalRow::right_censored(0.4, 1.7, 0.1, 0.5);
1321        let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1322            .unwrap()
1323            .log_lik;
1324        let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1325            .unwrap()
1326            .log_lik;
1327        let fd_score = (ll_p - ll_m) / (2.0 * h);
1328        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1329        assert!(
1330            (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1331            "score={}, fd={fd_score}",
1332            jet.score
1333        );
1334    }
1335
1336    #[test]
1337    fn survival_interval_censored_score_fd() {
1338        let ctx = QuadratureContext::new();
1339        let mu = 0.0;
1340        let sigma = 0.6;
1341        let h = 1e-6;
1342        let row = LatentSurvivalRow::interval_censored(0.0, 1.0, 2.0, 0.0, 0.0, 0.0);
1343        let ll_p = LatentSurvivalRowJet::evaluate(&ctx, &row, mu + h, sigma)
1344            .unwrap()
1345            .log_lik;
1346        let ll_m = LatentSurvivalRowJet::evaluate(&ctx, &row, mu - h, sigma)
1347            .unwrap()
1348            .log_lik;
1349        let fd_score = (ll_p - ll_m) / (2.0 * h);
1350        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1351        assert!(
1352            (jet.score - fd_score).abs() / fd_score.abs().max(1e-15) < 1e-3,
1353            "score={}, fd={fd_score}",
1354            jet.score
1355        );
1356    }
1357
1358    #[test]
1359    fn survival_interval_censored_neg_hessian_fd() {
1360        // Second μ-derivative of ℓ = log[S(L) − S(R)] for the interval kernel,
1361        // FD-checked. `neg_hessian` stores −d²ℓ/dμ², so compare against the
1362        // negated central second difference.
1363        let ctx = QuadratureContext::new();
1364        let mu = -0.2;
1365        let sigma = 0.55;
1366        let h = 2e-4;
1367        let row = LatentSurvivalRow::interval_censored(0.0, 0.7, 1.9, 0.0, 0.0, 0.0);
1368        let ll = |m: f64| {
1369            LatentSurvivalRowJet::evaluate(&ctx, &row, m, sigma)
1370                .unwrap()
1371                .log_lik
1372        };
1373        let fd_d2 = (ll(mu + h) - 2.0 * ll(mu) + ll(mu - h)) / (h * h);
1374        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1375        assert!(
1376            (jet.neg_hessian - (-fd_d2)).abs() / fd_d2.abs().max(1e-12) < 1e-2,
1377            "interval neg_hessian={}, fd(-d2)={}",
1378            jet.neg_hessian,
1379            -fd_d2
1380        );
1381    }
1382
1383    #[test]
1384    fn survival_interval_censored_log_sigma_score_fd() {
1385        // σ-recovery for interval data is driven by `score_log_sigma`, the
1386        // derivative of ℓ = log[S(L) − S(R)] w.r.t. log σ. FD-check it directly
1387        // against the row log-likelihood (this is the channel the interval fit's
1388        // latent_sd estimate moves along, the test's primary metric).
1389        let ctx = QuadratureContext::new();
1390        let mu = 0.1;
1391        let sigma: f64 = 0.6;
1392        let h = 1e-5;
1393        let row = LatentSurvivalRow::interval_censored(0.0, 0.8, 2.1, 0.0, 0.0, 0.0);
1394        let ll_at = |s: f64| {
1395            LatentSurvivalRowJet::evaluate(&ctx, &row, mu, s)
1396                .unwrap()
1397                .log_lik
1398        };
1399        // d/d(log σ) = σ · d/dσ, so FD over log σ directly.
1400        let fd_dlogsigma =
1401            (ll_at((sigma.ln() + h).exp()) - ll_at((sigma.ln() - h).exp())) / (2.0 * h);
1402        let jet = LatentSurvivalRowJet::evaluate(&ctx, &row, mu, sigma).unwrap();
1403        assert!(
1404            (jet.score_log_sigma - fd_dlogsigma).abs() / fd_dlogsigma.abs().max(1e-12) < 1e-3,
1405            "interval score_log_sigma={}, fd={fd_dlogsigma}",
1406            jet.score_log_sigma
1407        );
1408    }
1409
1410    #[test]
1411    fn log_kernel_single_term_log_sigma_derivatives_match_ghq_reference() {
1412        let ctx = QuadratureContext::new();
1413        let mu = 0.2;
1414        let sigma = 1.0;
1415        let jet = LogKernelSumJet::single_term(&ctx, 0, 1.0, mu, sigma).unwrap();
1416        let ghq = crate::inference::quadrature::cloglog_ghq_derivatives_adaptive(&ctx, mu, sigma);
1417        let survival = (1.0 - ghq.l).max(1e-300);
1418        let survival_sigma_over_survival = -ghq.l_sigma / survival;
1419        let ref_score = sigma * survival_sigma_over_survival;
1420        let ref_neg_hessian = -(ref_score
1421            + sigma
1422                * sigma
1423                * (-ghq.l_sigmasigma / survival - survival_sigma_over_survival.powi(2)));
1424
1425        assert!(
1426            (log_sigma_score_from_log_sum(&jet, sigma) - ref_score).abs()
1427                / ref_score.abs().max(1e-12)
1428                < 1e-4,
1429            "log-sigma score={}, ref={ref_score}",
1430            log_sigma_score_from_log_sum(&jet, sigma)
1431        );
1432        assert!(
1433            (log_sigma_neg_hessian_from_log_sum(&jet, sigma) - ref_neg_hessian).abs()
1434                / ref_neg_hessian.abs().max(1e-12)
1435                < 1e-3,
1436            "log-sigma neg_hessian={}, ref={ref_neg_hessian}",
1437            log_sigma_neg_hessian_from_log_sum(&jet, sigma)
1438        );
1439    }
1440
1441    #[test]
1442    fn log_kernel_sum_jet_single_term_d1_fd() {
1443        let ctx = QuadratureContext::new();
1444        let mu = 0.5;
1445        let sigma = 0.4;
1446        let m = 1.0;
1447        let k = 0usize;
1448        let h = 1e-6;
1449
1450        let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();
1451        let val_p = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0;
1452        let val_m = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0;
1453        let fd_d1 = (val_p - val_m) / (2.0 * h);
1454        assert!(
1455            (jet.d1 - fd_d1).abs() / fd_d1.abs().max(1e-15) < 1e-3,
1456            "d1={}, fd={fd_d1}",
1457            jet.d1
1458        );
1459    }
1460
1461    #[test]
1462    fn log_kernel_sum_jet_single_term_d4_fd() {
1463        let ctx = QuadratureContext::new();
1464        let mu = 0.35;
1465        let sigma = 0.45;
1466        let m = 1.2;
1467        let k = 1usize;
1468        let h = 2e-3;
1469
1470        let jet = LogKernelSumJet::single_term(&ctx, k, m, mu, sigma).unwrap();
1471        let v_pp = log_kernel_term(&ctx, k, m, mu + 2.0 * h, sigma).unwrap().0;
1472        let v_p = log_kernel_term(&ctx, k, m, mu + h, sigma).unwrap().0;
1473        let v_0 = log_kernel_term(&ctx, k, m, mu, sigma).unwrap().0;
1474        let v_m = log_kernel_term(&ctx, k, m, mu - h, sigma).unwrap().0;
1475        let v_mm = log_kernel_term(&ctx, k, m, mu - 2.0 * h, sigma).unwrap().0;
1476        let fd_d4 = (v_mm - 4.0 * v_m + 6.0 * v_0 - 4.0 * v_p + v_pp) / h.powi(4);
1477        assert!(
1478            (jet.d4 - fd_d4).abs() / jet.d4.abs().max(fd_d4.abs()).max(1e-8) < 2e-2,
1479            "d4={}, fd={fd_d4}",
1480            jet.d4
1481        );
1482    }
1483
1484    #[test]
1485    fn latent_cloglog_jet_matches_point_limit_at_zero_sigma() {
1486        let ctx = QuadratureContext::new();
1487        let eta = -0.4;
1488        let jet = latent_cloglog_jet5(&ctx, eta, 0.0).expect("latent jet");
1489        let t = eta.exp();
1490        let d1 = (eta - t).exp();
1491        let d2 = (1.0 - t) * d1;
1492        let d3 = (t * t - 3.0 * t + 1.0) * d1;
1493        let d4 = (-t * t * t + 6.0 * t * t - 7.0 * t + 1.0) * d1;
1494        let d5 = (t.powi(4) - 10.0 * t.powi(3) + 25.0 * t * t - 15.0 * t + 1.0) * d1;
1495        assert!((jet.mean - (1.0 - (-t).exp())).abs() < 1e-12);
1496        assert!((jet.d1 - d1).abs() < 1e-12);
1497        assert!((jet.d2 - d2).abs() < 1e-12);
1498        assert!((jet.d3 - d3).abs() < 1e-12);
1499        assert!((jet.d4 - d4).abs() < 1e-12);
1500        assert!((jet.d5 - d5).abs() < 1e-12);
1501    }
1502
1503    #[test]
1504    fn latent_cloglog_jet_matches_exact_kernel_recurrence() {
1505        let ctx = QuadratureContext::new();
1506        let cases = [(-4.0, 0.15), (-1.2, 0.35), (0.4, 0.6), (1.3, 0.9)];
1507
1508        for (eta, sigma) in cases {
1509            let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
1510            let bundle = log_kernel_bundle(&ctx, 1.0, eta, sigma, 5).expect("kernel bundle");
1511            let k0 = bundle.get(0);
1512            let k1 = bundle.get(1).exp();
1513            let k2 = bundle.get(2).exp();
1514            let k3 = bundle.get(3).exp();
1515            let k4 = bundle.get(4).exp();
1516            let k5 = bundle.get(5).exp();
1517
1518            let mean = if k0.is_finite() { -k0.exp_m1() } else { 1.0 };
1519            let d1 = k1;
1520            let d2 = k1 - k2;
1521            let d3 = k1 - 3.0 * k2 + k3;
1522            let d4 = k1 - 7.0 * k2 + 6.0 * k3 - k4;
1523            let d5 = k1 - 15.0 * k2 + 25.0 * k3 - 10.0 * k4 + k5;
1524
1525            assert!((jet.mean - mean).abs() < 1e-12);
1526            assert!((jet.d1 - d1).abs() < 1e-12);
1527            assert!((jet.d2 - d2).abs() < 1e-12);
1528            assert!((jet.d3 - d3).abs() < 1e-12);
1529            assert!((jet.d4 - d4).abs() < 1e-12);
1530            assert!((jet.d5 - d5).abs() < 1e-12);
1531        }
1532    }
1533
1534    #[test]
1535    fn latent_cloglog_binomial_row_neg_hessian_matches_fd() {
1536        let ctx = QuadratureContext::new();
1537        let eta = 0.4;
1538        let sigma = 0.6;
1539        let y = 0.35;
1540        let weight = 2.0;
1541        let h = 1e-4;
1542
1543        let jet = latent_cloglog_jet5(&ctx, eta, sigma).expect("latent jet");
1544        let mu = jet.mean.clamp(1e-12, 1.0 - 1e-12);
1545        let ellmu = y / mu - (1.0 - y) / (1.0 - mu);
1546        let ellmumu = -y / (mu * mu) - (1.0 - y) / ((1.0 - mu) * (1.0 - mu));
1547        let neg_hessian = -weight * (ellmumu * jet.d1 * jet.d1 + ellmu * jet.d2);
1548
1549        let ll_minus = latent_binomial_row_log_lik(&ctx, eta - h, sigma, y, weight);
1550        let ll0 = latent_binomial_row_log_lik(&ctx, eta, sigma, y, weight);
1551        let ll_plus = latent_binomial_row_log_lik(&ctx, eta + h, sigma, y, weight);
1552        let neg_hessian_fd = -(ll_plus - 2.0 * ll0 + ll_minus) / (h * h);
1553
1554        let err = (neg_hessian - neg_hessian_fd).abs();
1555        let tol = 2e-5_f64.max(3e-3 * neg_hessian_fd.abs());
1556        assert!(
1557            err <= tol,
1558            "latent cloglog Bernoulli row curvature mismatch: analytic={} fd={}",
1559            neg_hessian,
1560            neg_hessian_fd
1561        );
1562    }
1563}