Skip to main content

gam_models/survival/location_scale/
residual_dist.rs

1use super::*;
2
3/// Layer 2 defense: compute q0 = -eta_t * exp(-eta_ls) with log-space
4/// overflow detection.  When log|q0| = ln|eta_t| + (-eta_ls) exceeds the
5/// clamp ceiling, the product would overflow; we saturate to ±MAX instead.
6#[inline]
7pub(crate) fn survival_q0_from_eta(eta_t: f64, eta_ls: f64) -> f64 {
8    if eta_t == 0.0 {
9        return 0.0;
10    }
11    let log_abs = eta_t.abs().ln() + (-eta_ls).min(EXP_NEG_STABLE_MAX_ARG);
12    if log_abs > EXP_NEG_STABLE_MAX_ARG {
13        if eta_t > 0.0 { -f64::MAX } else { f64::MAX }
14    } else {
15        -eta_t * exp_sigma_inverse_from_eta_scalar(eta_ls)
16    }
17}
18
19#[inline]
20pub(crate) fn probit_survival_value(eta: f64) -> f64 {
21    if eta.is_nan() {
22        f64::NAN
23    } else if eta == f64::INFINITY {
24        0.0
25    } else if eta == f64::NEG_INFINITY {
26        1.0
27    } else {
28        0.5 * erfc(eta / std::f64::consts::SQRT_2)
29    }
30}
31
32#[inline]
33pub(crate) fn probit_log_survival_and_ratio_derivatives(eta: f64) -> (f64, f64, f64, f64, f64) {
34    if eta.is_nan() {
35        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
36    }
37    if eta == f64::NEG_INFINITY {
38        return (0.0, 0.0, 0.0, 0.0, 0.0);
39    }
40    let x = eta / std::f64::consts::SQRT_2;
41    let (log_survival, ratio) = if eta >= 0.0 {
42        // erfcx(x) = exp(x²)·erfc(x); compute once and reuse for both
43        // log-survival and the hazard ratio.
44        let erfcx_val = erfcx_nonnegative(x);
45        let log_surv = -0.5 * eta * eta + (0.5 * erfcx_val).ln();
46        let r = std::f64::consts::FRAC_2_SQRT_PI / (std::f64::consts::SQRT_2 * erfcx_val);
47        (log_surv, r)
48    } else {
49        let survival = probit_survival_value(eta);
50        (survival.ln(), normal_pdf(eta) / survival)
51    };
52    let dr = ratio * (ratio - eta);
53    let ddr = 2.0 * ratio.powi(3) - 3.0 * eta * ratio.powi(2) + (eta * eta - 1.0) * ratio;
54    let dddr = 6.0 * ratio.powi(4) - 12.0 * eta * ratio.powi(3)
55        + (7.0 * eta * eta - 4.0) * ratio.powi(2)
56        + (-eta * eta * eta + 3.0 * eta) * ratio;
57    (log_survival, ratio, dr, ddr, dddr)
58}
59
/// Residual distribution choices handled by this module.
///
/// Each variant is paired with one standard link by
/// `residual_distribution_link`: Gaussian ↔ Probit, Gumbel ↔ CLogLog,
/// Logistic ↔ Logit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ResidualDistribution {
    /// Evaluated with `normal_cdf` / `normal_pdf`; paired with the probit link.
    Gaussian,
    /// Evaluated through the complementary log-log inverse link.
    Gumbel,
    /// Evaluated through the logit inverse link.
    Logistic,
}
66
/// Pointwise evaluation of a residual distribution: its CDF, its PDF, and the
/// first four derivatives of the PDF, all at a scalar argument `z`.
pub trait ResidualDistributionOps {
    /// Cumulative distribution function F(z).
    fn cdf(&self, z: f64) -> f64;
    /// Probability density f(z).
    fn pdf(&self, z: f64) -> f64;
    /// First derivative of the PDF, f'(z).
    fn pdf_derivative(&self, z: f64) -> f64;
    /// Second derivative of the PDF, f''(z).
    fn pdfsecond_derivative(&self, z: f64) -> f64;
    /// Third derivative of the PDF, f'''(z).
    fn pdfthird_derivative(&self, z: f64) -> f64;

    /// Fourth derivative of the residual-distribution PDF, f''''(z).
    ///
    /// This is the m4 ingredient for the outer REML Hessian's Q[v_k, v_l] term.
    /// The second directional derivative of the inner Hessian (used by the outer
    /// Hessian drift) requires the 4th derivative of the composed likelihood
    /// F_αβγδ via the Arbogast chain rule. That chain rule's leading term
    /// m4·u_α·u_β·u_γ·u_δ needs this quantity.
    ///
    /// See response.md Section 6 for the mathematical derivation.
    fn pdffourth_derivative(&self, z: f64) -> f64;
}
85
86impl ResidualDistributionOps for ResidualDistribution {
87    fn cdf(&self, z: f64) -> f64 {
88        match self {
89            ResidualDistribution::Gaussian => normal_cdf(z),
90            ResidualDistribution::Gumbel => {
91                component_inverse_link_jet(gam_problem::LinkComponent::CLogLog, z).mu
92            }
93            ResidualDistribution::Logistic => {
94                component_inverse_link_jet(gam_problem::LinkComponent::Logit, z).mu
95            }
96        }
97    }
98
99    fn pdf(&self, z: f64) -> f64 {
100        match self {
101            ResidualDistribution::Gaussian => normal_pdf(z),
102            ResidualDistribution::Gumbel => {
103                component_inverse_link_jet(gam_problem::LinkComponent::CLogLog, z).d1
104            }
105            ResidualDistribution::Logistic => {
106                component_inverse_link_jet(gam_problem::LinkComponent::Logit, z).d1
107            }
108        }
109    }
110
111    fn pdf_derivative(&self, z: f64) -> f64 {
112        match self {
113            ResidualDistribution::Gaussian => -z * normal_pdf(z),
114            ResidualDistribution::Gumbel => {
115                component_inverse_link_jet(gam_problem::LinkComponent::CLogLog, z).d2
116            }
117            ResidualDistribution::Logistic => {
118                component_inverse_link_jet(gam_problem::LinkComponent::Logit, z).d2
119            }
120        }
121    }
122
123    fn pdfsecond_derivative(&self, z: f64) -> f64 {
124        match self {
125            ResidualDistribution::Gaussian => {
126                let f = normal_pdf(z);
127                (z * z - 1.0) * f
128            }
129            ResidualDistribution::Gumbel => {
130                component_inverse_link_jet(gam_problem::LinkComponent::CLogLog, z).d3
131            }
132            ResidualDistribution::Logistic => {
133                component_inverse_link_jet(gam_problem::LinkComponent::Logit, z).d3
134            }
135        }
136    }
137
138    fn pdfthird_derivative(&self, z: f64) -> f64 {
139        match self {
140            ResidualDistribution::Gaussian => {
141                let f = normal_pdf(z);
142                -(z * z * z - 3.0 * z) * f
143            }
144            ResidualDistribution::Gumbel => inverse_link_pdfthird_derivative_for_inverse_link(
145                &InverseLink::Standard(StandardLink::CLogLog),
146                z,
147            )
148            .expect("standard cloglog inverse-link third derivative should evaluate"),
149            ResidualDistribution::Logistic => inverse_link_pdfthird_derivative_for_inverse_link(
150                &InverseLink::Standard(StandardLink::Logit),
151                z,
152            )
153            .expect("standard logit inverse-link third derivative should evaluate"),
154        }
155    }
156
157    /// Fourth derivative of the residual-distribution PDF.
158    ///
159    /// # Derivations
160    ///
161    /// **Gaussian**: f(z) = φ(z). The n-th derivative of the Gaussian PDF is
162    /// (-1)^n He_n(z) φ(z) where He_n is the probabilist's Hermite polynomial.
163    /// He_4(z) = z⁴ - 6z² + 3, so f''''(z) = (z⁴ - 6z² + 3) φ(z).
164    ///
165    /// **Logistic**: f(z) = s(1-s) with s = σ(z). The k-th derivative of f is
166    /// f · P_k(s) where P_k satisfies the Euler-polynomial recurrence
167    /// P_{k+1}(s) = (1-2s) P_k(s) + s(1-s) P_k'(s).
168    /// P_4(s) = 1 - 30s + 150s² - 240s³ + 120s⁴.
169    ///
170    /// **Gumbel**: f(z) = exp(z - e^z). Let e = e^z. The k-th derivative of f
171    /// is f · Q_k(e) where Q_k satisfies Q_{k+1}(e) = (1-e) Q_k(e) + e Q_k'(e).
172    /// Q_4(e) = 1 - 15e + 25e² - 10e³ + e⁴.
173    fn pdffourth_derivative(&self, z: f64) -> f64 {
174        match self {
175            ResidualDistribution::Gaussian => {
176                let f = normal_pdf(z);
177                let z2 = z * z;
178                // He_4(z) = z^4 - 6z^2 + 3
179                (z2 * z2 - 6.0 * z2 + 3.0) * f
180            }
181            ResidualDistribution::Gumbel => inverse_link_pdffourth_derivative_for_inverse_link(
182                &InverseLink::Standard(StandardLink::CLogLog),
183                z,
184            )
185            .expect("standard cloglog inverse-link fourth derivative should evaluate"),
186            ResidualDistribution::Logistic => inverse_link_pdffourth_derivative_for_inverse_link(
187                &InverseLink::Standard(StandardLink::Logit),
188                z,
189            )
190            .expect("standard logit inverse-link fourth derivative should evaluate"),
191        }
192    }
193}
194
195#[inline]
196pub(crate) fn residual_distribution_link(distribution: ResidualDistribution) -> StandardLink {
197    match distribution {
198        ResidualDistribution::Gaussian => StandardLink::Probit,
199        ResidualDistribution::Gumbel => StandardLink::CLogLog,
200        ResidualDistribution::Logistic => StandardLink::Logit,
201    }
202}
203
204#[inline]
205pub fn residual_distribution_inverse_link(distribution: ResidualDistribution) -> InverseLink {
206    InverseLink::Standard(residual_distribution_link(distribution))
207}
208
209/// Maps an `InverseLink` to its `ResidualDistribution` counterpart when the
210/// link is one of the three standard survival residual-distribution links
211/// (Probit/Logit/CLogLog). Returns `None` for stateful / mixture links (Sas,
212/// BetaLogistic, Mixture, LatentCLogLog) and for non-residual-distribution
213/// standard links — those carry their full state via `payload.link` and have
214/// no `ResidualDistribution` representation.
215#[inline]
216pub fn residual_distribution_from_inverse_link(link: &InverseLink) -> Option<ResidualDistribution> {
217    match link {
218        InverseLink::Standard(StandardLink::Probit) => Some(ResidualDistribution::Gaussian),
219        InverseLink::Standard(StandardLink::CLogLog) => Some(ResidualDistribution::Gumbel),
220        InverseLink::Standard(StandardLink::Logit) => Some(ResidualDistribution::Logistic),
221        _ => None,
222    }
223}
224
225/// Fourth derivative of the inverse-link PDF (= 5th derivative of the CDF).
226///
227/// This is the f'''' quantity used in the 4th derivative of log f(u), which
228/// in turn enters the m4 ingredient of the Arbogast chain rule for
229/// the outer REML Hessian Q[v_k, v_l] term.
230///
231/// For the three standard survival residual distributions (Probit, Logit,
232/// CLogLog), uses the closed-form ResidualDistribution implementations.
233/// For all other inverse links (SAS, BetaLogistic, Mixture), delegates
234/// to the generic `inverse_link_pdffourth_derivative_for_inverse_link`
235/// dispatcher in mixture_link.rs.
236pub(crate) fn inverse_link_pdffourth_derivative(
237    inverse_link: &InverseLink,
238    eta: f64,
239) -> Result<f64, SurvivalLocationScaleError> {
240    match inverse_link {
241        InverseLink::Standard(StandardLink::Probit) => {
242            Ok(ResidualDistribution::Gaussian.pdffourth_derivative(eta))
243        }
244        InverseLink::Standard(StandardLink::Logit) => {
245            Ok(ResidualDistribution::Logistic.pdffourth_derivative(eta))
246        }
247        InverseLink::Standard(StandardLink::CLogLog) => {
248            Ok(ResidualDistribution::Gumbel.pdffourth_derivative(eta))
249        }
250        _ => gam_solve::mixture_link::inverse_link_pdffourth_derivative_for_inverse_link(
251            inverse_link,
252            eta,
253        )
254        .map_err(|e| SurvivalLocationScaleError::NumericalFailure {
255            reason: format!("inverse link fourth-derivative evaluation failed at eta={eta}: {e}"),
256        }),
257    }
258}