Skip to main content

gam_math/
probability.rs

1use statrs::function::erf::erfc;
2
/// Density of the standard normal distribution,
/// `φ(x) = exp(−x² / 2) / √(2π)`.
///
/// Non-finite inputs follow IEEE arithmetic: `±∞` gives `0.0` and `NaN`
/// gives `NaN`.
#[inline]
pub fn normal_pdf(x: f64) -> f64 {
    // 1 / √(2π), spelled out to full double precision.
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    let exponent = -0.5 * x * x;
    INV_SQRT_2PI * exponent.exp()
}
9
10/// Standard normal CDF Phi(x) evaluated via the exact special-function identity
11///
12///   Phi(x) = 0.5 * erfc(-x / sqrt(2)).
13///
14/// This is the exact Gaussian CDF semantics used throughout the codebase. The
15/// numerical `erfc` implementation may use internal approximations, but the
16/// returned function is the standard normal CDF itself rather than a separate
17/// polynomial surrogate surface.
18#[inline]
19pub fn normal_cdf(x: f64) -> f64 {
20    0.5 * statrs::function::erf::erfc(-x / std::f64::consts::SQRT_2)
21}
22
23/// Scaled complementary error function `erfcx(x) = exp(x²) · erfc(x)`,
24/// specialized to `x ≥ 0`.  Returns `1.0` for `x ≤ 0` and `0.0` for
25/// `x = +∞`.  For `0 < x < 26` uses the direct `exp(x²)·erfc(x)` form;
26/// beyond that the (otherwise overflowing) `exp(x²)` is replaced by a
27/// 4-term asymptotic expansion `(1/(x√π))·(1 − 1/(2x²) + 3/(4x⁴) − …)`,
28/// keeping relative accuracy near machine epsilon. The non-negative
29/// restriction lets the caller skip the reflection identity.
30#[inline]
31pub fn erfcx_nonnegative(x: f64) -> f64 {
32    if !x.is_finite() {
33        return if x.is_sign_positive() {
34            0.0
35        } else {
36            f64::INFINITY
37        };
38    }
39    if x <= 0.0 {
40        return 1.0;
41    }
42    if x < 26.0 {
43        ((x * x).min(700.0)).exp() * erfc(x)
44    } else {
45        let inv = 1.0 / x;
46        let inv2 = inv * inv;
47        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
48            + 6.5625 * inv2 * inv2 * inv2 * inv2;
49        inv * poly / std::f64::consts::PI.sqrt()
50    }
51}
52
/// Evaluates `ln(1 − exp(−a))` for `a ≥ 0` while avoiding cancellation.
///
/// Two algebraically equal forms are used, split at `a = ln 2`:
///
/// * `a > ln 2`: `exp(−a) < ½`, so `ln_1p(−exp(−a))` is accurate.
/// * `0 < a ≤ ln 2`: `1 − exp(−a)` is formed as `−expm1(−a)` before the log.
///
/// `a = 0` (including `−0.0`) returns `−∞`.
///
/// # Panics
///
/// Panics if `a` is negative or `NaN`.
#[inline]
pub fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        // 1 − exp(0) = 0.
        return f64::NEG_INFINITY;
    }
    let neg_a = -a;
    if a <= std::f64::consts::LN_2 {
        (-neg_a.exp_m1()).ln()
    } else {
        (-neg_a.exp()).ln_1p()
    }
}
65
66/// Numerically stable signed log-sum-exp.  Given pairs
67/// `(log|aⱼ|, sign(aⱼ))` (with `signs[j] ∈ {−1, 0, +1}`), returns
68/// `(log|S|, sign(S))` for `S = Σⱼ signs[j]·exp(log_mags[j])`.  Positive
69/// and negative magnitudes are reduced separately with the standard
70/// log-sum-exp trick (subtract the max, sum, log, add back); the two
71/// partial sums are then combined via `log(|p − n|) =
72/// max(log p, log n) + log1mexp(|log p − log n|)`, preserving accuracy
73/// even when `p ≈ n` (catastrophic cancellation regime).  When all
74/// signs are zero or all magnitudes are `−∞`, returns
75/// `(NEG_INFINITY, 0.0)`.
76///
77/// A `+∞` log-magnitude denotes an infinite-magnitude term (`exp(+∞) = +∞`)
78/// and dominates the sum: if it appears only with positive sign the result
79/// is `(+∞, +1)`; only with negative sign, `(+∞, −1)` (a log-magnitude of
80/// `+∞` with sign `−1` encodes the value `−∞`); with both signs the sum is
81/// the indeterminate `+∞ − ∞`, returned as `(NaN, 0.0)`.  A `−∞`
82/// log-magnitude is `exp(−∞) = 0` and is correctly dropped.
83pub fn signed_log_sum_exp(log_mags: &[f64], signs: &[f64]) -> (f64, f64) {
84    // Infinite-magnitude terms dominate any finite contribution, so resolve
85    // them before the finite log-sum-exp reduction below. `−∞` log-magnitudes
86    // are `exp(−∞) = 0` and need no special handling.
87    let mut has_pos_inf = false;
88    let mut has_neg_inf = false;
89    for (idx, &lm) in log_mags.iter().enumerate() {
90        if lm == f64::INFINITY {
91            if signs[idx] > 0.0 {
92                has_pos_inf = true;
93            } else if signs[idx] < 0.0 {
94                has_neg_inf = true;
95            }
96        }
97    }
98    match (has_pos_inf, has_neg_inf) {
99        // P = +∞, N = +∞ ⇒ indeterminate +∞ − ∞.
100        (true, true) => return (f64::NAN, 0.0),
101        // P = +∞, N < ∞ ⇒ S = +∞.
102        (true, false) => return (f64::INFINITY, 1.0),
103        // N = +∞, P < ∞ ⇒ S = −∞, encoded as log-magnitude +∞ with sign −1.
104        (false, true) => return (f64::INFINITY, -1.0),
105        (false, false) => {}
106    }
107
108    let mut pos_max = f64::NEG_INFINITY;
109    let mut neg_max = f64::NEG_INFINITY;
110    for (idx, &lm) in log_mags.iter().enumerate() {
111        if signs[idx] > 0.0 {
112            pos_max = pos_max.max(lm);
113        } else if signs[idx] < 0.0 {
114            neg_max = neg_max.max(lm);
115        }
116    }
117
118    let mut pos_sum = 0.0_f64;
119    let mut neg_sum = 0.0_f64;
120    for (idx, &lm) in log_mags.iter().enumerate() {
121        if !lm.is_finite() {
122            continue;
123        }
124        if signs[idx] > 0.0 {
125            pos_sum += (lm - pos_max).exp();
126        } else if signs[idx] < 0.0 {
127            neg_sum += (lm - neg_max).exp();
128        }
129    }
130
131    let log_pos = if pos_sum > 0.0 {
132        pos_max + pos_sum.ln()
133    } else {
134        f64::NEG_INFINITY
135    };
136    let log_neg = if neg_sum > 0.0 {
137        neg_max + neg_sum.ln()
138    } else {
139        f64::NEG_INFINITY
140    };
141
142    if log_neg == f64::NEG_INFINITY {
143        return (log_pos, 1.0);
144    }
145    if log_pos == f64::NEG_INFINITY {
146        return (log_neg, -1.0);
147    }
148    if log_pos > log_neg {
149        let gap = log_pos - log_neg;
150        (log_pos + log1mexp_positive(gap), 1.0)
151    } else if log_neg > log_pos {
152        let gap = log_neg - log_pos;
153        (log_neg + log1mexp_positive(gap), -1.0)
154    } else {
155        (f64::NEG_INFINITY, 0.0)
156    }
157}
158
159/// Numerically stable `ln Φ(x)` for the standard normal CDF.  For `x ≥ 0`
160/// computes `ln(Φ(x))` directly with a small floor against underflow; for
161/// `x < 0` rewrites
162/// `ln Φ(x) = −u² + ln(½·erfcx(u))`, `u = −x/√2`,
163/// which preserves digits all the way into the deep left tail (no
164/// `ln(0)`).  Returns `±∞` and `NaN` at the corresponding inputs.
165#[inline]
166pub fn normal_logcdf(x: f64) -> f64 {
167    if x == f64::INFINITY {
168        return 0.0;
169    }
170    if x == f64::NEG_INFINITY {
171        return f64::NEG_INFINITY;
172    }
173    if x.is_nan() {
174        return f64::NAN;
175    }
176    if x < 0.0 {
177        let u = -x / std::f64::consts::SQRT_2;
178        -u * u + (0.5 * erfcx_nonnegative(u).max(1e-300)).ln()
179    } else {
180        normal_cdf(x).clamp(1e-300, 1.0).ln()
181    }
182}
183
184/// Numerically stable `ln(1 − Φ(x)) = ln Φ(−x)` for the standard normal
185/// survival function.  Delegates to `normal_logcdf(-x)` so the deep-right
186/// tail benefits from the same `erfcx`-based representation.
187#[inline]
188pub fn normal_logsf(x: f64) -> f64 {
189    normal_logcdf(-x)
190}
191
192/// Joint evaluation of `ln Φ(x)` and the Mills-ratio analogue
193/// `φ(x) / Φ(x)`, signed for the symmetric branch.  Used by the latent
194/// probit families where the inverse-link gradient needs the ratio and
195/// the likelihood needs the log-CDF on the same `x`; computing both in
196/// one call shares the `erfcx` evaluation that dominates the cost in the
197/// deep tail.
198#[inline]
199pub fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
200    if x == f64::INFINITY {
201        return (0.0, 0.0);
202    }
203    if x == f64::NEG_INFINITY {
204        return (f64::NEG_INFINITY, f64::INFINITY);
205    }
206    if x.is_nan() {
207        return (f64::NAN, f64::NAN);
208    }
209    if x < 0.0 {
210        let u = -x / std::f64::consts::SQRT_2;
211        let ex = erfcx_nonnegative(u).max(1e-300);
212        let log_cdf = -u * u + (0.5 * ex).ln();
213        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
214        (log_cdf, lambda)
215    } else {
216        let cdf = normal_cdf(x).clamp(1e-300, 1.0);
217        let lambda = normal_pdf(x) / cdf;
218        (cdf.ln(), lambda)
219    }
220}
221
222/// Standard normal quantile Φ⁻¹(p) using Acklam's rational approximation.
223#[inline]
224pub fn standard_normal_quantile(p: f64) -> Result<f64, String> {
225    if !(p.is_finite() && p > 0.0 && p < 1.0) {
226        return Err(format!("normal quantile requires p in (0,1), got {p}"));
227    }
228
229    const A: [f64; 6] = [
230        -3.969_683_028_665_376e1,
231        2.209_460_984_245_205e2,
232        -2.759_285_104_469_687e2,
233        1.383_577_518_672_69e2,
234        -3.066_479_806_614_716e1,
235        2.506_628_277_459_239,
236    ];
237    const B: [f64; 5] = [
238        -5.447_609_879_822_406e1,
239        1.615_858_368_580_409e2,
240        -1.556_989_798_598_866e2,
241        6.680_131_188_771_972e1,
242        -1.328_068_155_288_572e1,
243    ];
244    const C: [f64; 6] = [
245        -7.784_894_002_430_293e-3,
246        -3.223_964_580_411_365e-1,
247        -2.400_758_277_161_838,
248        -2.549_732_539_343_734,
249        4.374_664_141_464_968,
250        2.938_163_982_698_783,
251    ];
252    const D: [f64; 4] = [
253        7.784_695_709_041_462e-3,
254        3.224_671_290_700_398e-1,
255        2.445_134_137_142_996,
256        3.754_408_661_907_416,
257    ];
258    const P_LOW: f64 = 0.02425;
259    const P_HIGH: f64 = 1.0 - P_LOW;
260
261    let mut x = if p < P_LOW {
262        let q = (-2.0 * p.ln()).sqrt();
263        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
264            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
265    } else if p <= P_HIGH {
266        let q = p - 0.5;
267        let r = q * q;
268        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
269            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
270    } else {
271        let q = (-2.0 * (1.0 - p).ln()).sqrt();
272        -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
273            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
274    };
275    for _ in 0..2 {
276        let density = normal_pdf(x);
277        if !(density.is_finite() && density > 0.0) {
278            break;
279        }
280        // Residual F(x) − p, formed without catastrophic cancellation in
281        // either tail. For an upper-tail iterate `x > 0`, `normal_cdf(x)`
282        // saturates to ~1, so the direct `normal_cdf(x) − p` annihilates the
283        // tiny residual the polish must act on; instead use the upper-tail
284        // complement `F(x) − p = (1 − p) − 0.5·erfc(x/√2)`, where both terms
285        // are the small upper-tail quantities (`1 − p` is exact by Sterbenz
286        // for `p ∈ [½,1)`). For `x ≤ 0`, `normal_cdf(x) = 0.5·erfc(|x|/√2)` is
287        // itself the faithfully carried small lower-tail value, so the direct
288        // form is already cancellation-free.
289        let residual = if x > 0.0 {
290            (1.0 - p) - 0.5 * erfc(x / std::f64::consts::SQRT_2)
291        } else {
292            normal_cdf(x) - p
293        };
294        let correction = residual / density;
295        let denominator = 1.0 + 0.5 * x * correction;
296        if !(correction.is_finite() && denominator.is_finite() && denominator != 0.0) {
297            break;
298        }
299        let step = correction / denominator;
300        if !step.is_finite() {
301            break;
302        }
303        x -= step;
304        if step.abs() <= 2.0 * f64::EPSILON * x.abs().max(1.0) {
305            break;
306        }
307    }
308    Ok(x)
309}