Skip to main content

gam_math/
probability.rs

1use statrs::function::erf::erfc;
2
/// Standard normal probability density `φ(x) = e^{−x²/2} / √(2π)`.
///
/// The result is even in `x`. It becomes `0.0` once `exp` underflows (including
/// at `x = ±∞`), and `NaN` propagates.
#[inline]
pub fn normal_pdf(x: f64) -> f64 {
    // 1/√(2π), rounded to f64.
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    let exponent = -0.5 * x * x;
    exponent.exp() * INV_SQRT_2PI
}
9
10/// Standard normal CDF Phi(x) evaluated via the exact special-function identity
11///
12///   Phi(x) = 0.5 * erfc(-x / sqrt(2)).
13///
14/// This is the exact Gaussian CDF semantics used throughout the codebase. The
15/// numerical `erfc` implementation may use internal approximations, but the
16/// returned function is the standard normal CDF itself rather than a separate
17/// polynomial surrogate surface.
18#[inline]
19pub fn normal_cdf(x: f64) -> f64 {
20    0.5 * statrs::function::erf::erfc(-x / std::f64::consts::SQRT_2)
21}
22
23/// Scaled complementary error function `erfcx(x) = exp(x²) · erfc(x)`,
24/// specialized to `x ≥ 0`.  Returns `1.0` for `x ≤ 0` and `0.0` for
25/// `x = +∞`.  For `0 < x < 26` uses the direct `exp(x²)·erfc(x)` form;
26/// beyond that the (otherwise overflowing) `exp(x²)` is replaced by a
27/// 4-term asymptotic expansion `(1/(x√π))·(1 − 1/(2x²) + 3/(4x⁴) − …)`,
28/// keeping relative accuracy near machine epsilon. The non-negative
29/// restriction lets the caller skip the reflection identity.
30#[inline]
31pub fn erfcx_nonnegative(x: f64) -> f64 {
32    if !x.is_finite() {
33        return if x.is_sign_positive() {
34            0.0
35        } else {
36            f64::INFINITY
37        };
38    }
39    if x <= 0.0 {
40        return 1.0;
41    }
42    if x < 26.0 {
43        ((x * x).min(700.0)).exp() * erfc(x)
44    } else {
45        let inv = 1.0 / x;
46        let inv2 = inv * inv;
47        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
48            + 6.5625 * inv2 * inv2 * inv2 * inv2;
49        inv * poly / std::f64::consts::PI.sqrt()
50    }
51}
52
/// `ln(1 − e^{−a})` for `a ≥ 0`, evaluated without cancellation.
///
/// The formula is split at `a = ln 2`:
/// * for `0 < a ≤ ln 2` it computes `ln(−expm1(−a))`;
/// * for `a > ln 2` it computes `ln1p(−e^{−a})`.
///
/// `a = 0` yields `−∞`, because `1 − e⁰ = 0`. `a = +∞` yields `-0.0`.
///
/// # Panics
///
/// Panics if `a < 0` or `a` is NaN, because the `a >= 0.0` assertion fails.
#[inline]
pub fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        return f64::NEG_INFINITY;
    }
    if a <= std::f64::consts::LN_2 {
        // Small a: expm1 keeps the small quantity 1 − e^{−a} accurate.
        (-(-a).exp_m1()).ln()
    } else {
        // Large a: e^{−a} ≤ ½, and ln1p keeps ln(1 − small) accurate.
        (-(-a).exp()).ln_1p()
    }
}
65
66/// Numerically stable signed log-sum-exp.  Given pairs
67/// `(log|aⱼ|, sign(aⱼ))` (with `signs[j] ∈ {−1, 0, +1}`), returns
68/// `(log|S|, sign(S))` for `S = Σⱼ signs[j]·exp(log_mags[j])`.  Positive
69/// and negative magnitudes are reduced separately with the standard
70/// log-sum-exp trick (subtract the max, sum, log, add back); the two
71/// partial sums are then combined via `log(|p − n|) =
72/// max(log p, log n) + log1mexp(|log p − log n|)`, preserving accuracy
73/// even when `p ≈ n` (catastrophic cancellation regime).  When all
74/// signs are zero or all magnitudes are `−∞`, returns
75/// `(NEG_INFINITY, 0.0)`.
76///
77/// A `+∞` log-magnitude denotes an infinite-magnitude term (`exp(+∞) = +∞`)
78/// and dominates the sum: if it appears only with positive sign the result
79/// is `(+∞, +1)`; only with negative sign, `(+∞, −1)` (a log-magnitude of
80/// `+∞` with sign `−1` encodes the value `−∞`); with both signs the sum is
81/// the indeterminate `+∞ − ∞`, returned as `(NaN, 0.0)`.  A `−∞`
82/// log-magnitude is `exp(−∞) = 0` and is correctly dropped.
83pub fn signed_log_sum_exp(log_mags: &[f64], signs: &[f64]) -> (f64, f64) {
84    // Infinite-magnitude terms dominate any finite contribution, so resolve
85    // them before the finite log-sum-exp reduction below. `−∞` log-magnitudes
86    // are `exp(−∞) = 0` and need no special handling.
87    let mut has_pos_inf = false;
88    let mut has_neg_inf = false;
89    for (idx, &lm) in log_mags.iter().enumerate() {
90        if lm == f64::INFINITY {
91            if signs[idx] > 0.0 {
92                has_pos_inf = true;
93            } else if signs[idx] < 0.0 {
94                has_neg_inf = true;
95            }
96        }
97    }
98    match (has_pos_inf, has_neg_inf) {
99        // P = +∞, N = +∞ ⇒ indeterminate +∞ − ∞.
100        (true, true) => return (f64::NAN, 0.0),
101        // P = +∞, N < ∞ ⇒ S = +∞.
102        (true, false) => return (f64::INFINITY, 1.0),
103        // N = +∞, P < ∞ ⇒ S = −∞, encoded as log-magnitude +∞ with sign −1.
104        (false, true) => return (f64::INFINITY, -1.0),
105        (false, false) => {}
106    }
107
108    let mut pos_max = f64::NEG_INFINITY;
109    let mut neg_max = f64::NEG_INFINITY;
110    for (idx, &lm) in log_mags.iter().enumerate() {
111        if signs[idx] > 0.0 {
112            pos_max = pos_max.max(lm);
113        } else if signs[idx] < 0.0 {
114            neg_max = neg_max.max(lm);
115        }
116    }
117
118    let mut pos_sum = 0.0_f64;
119    let mut neg_sum = 0.0_f64;
120    for (idx, &lm) in log_mags.iter().enumerate() {
121        if !lm.is_finite() {
122            continue;
123        }
124        if signs[idx] > 0.0 {
125            pos_sum += (lm - pos_max).exp();
126        } else if signs[idx] < 0.0 {
127            neg_sum += (lm - neg_max).exp();
128        }
129    }
130
131    let log_pos = if pos_sum > 0.0 {
132        pos_max + pos_sum.ln()
133    } else {
134        f64::NEG_INFINITY
135    };
136    let log_neg = if neg_sum > 0.0 {
137        neg_max + neg_sum.ln()
138    } else {
139        f64::NEG_INFINITY
140    };
141
142    if log_pos == f64::NEG_INFINITY && log_neg == f64::NEG_INFINITY {
143        // Both partial sums are empty: no terms at all, all signs zero, or every
144        // magnitude `−∞` (each `exp(−∞) = 0`). The signed sum is exactly `0`, so
145        // the contract requires `(−∞, 0.0)` — NOT the positive-sum convention,
146        // which would mislabel a zero as `+1` and corrupt any downstream cascade
147        // that reads back the sign.
148        return (f64::NEG_INFINITY, 0.0);
149    }
150    if log_neg == f64::NEG_INFINITY {
151        return (log_pos, 1.0);
152    }
153    if log_pos == f64::NEG_INFINITY {
154        return (log_neg, -1.0);
155    }
156    if log_pos > log_neg {
157        let gap = log_pos - log_neg;
158        (log_pos + log1mexp_positive(gap), 1.0)
159    } else if log_neg > log_pos {
160        let gap = log_neg - log_pos;
161        (log_neg + log1mexp_positive(gap), -1.0)
162    } else {
163        (f64::NEG_INFINITY, 0.0)
164    }
165}
166
167/// Numerically stable `ln Φ(x)` for the standard normal CDF.  For `x ≥ 0`
168/// computes `ln(Φ(x))` directly with a small floor against underflow; for
169/// `x < 0` rewrites
170/// `ln Φ(x) = −u² + ln(½·erfcx(u))`, `u = −x/√2`,
171/// which preserves digits all the way into the deep left tail (no
172/// `ln(0)`).  Returns `±∞` and `NaN` at the corresponding inputs.
173#[inline]
174pub fn normal_logcdf(x: f64) -> f64 {
175    if x == f64::INFINITY {
176        return 0.0;
177    }
178    if x == f64::NEG_INFINITY {
179        return f64::NEG_INFINITY;
180    }
181    if x.is_nan() {
182        return f64::NAN;
183    }
184    if x < 0.0 {
185        let u = -x / std::f64::consts::SQRT_2;
186        -u * u + (0.5 * erfcx_nonnegative(u).max(1e-300)).ln()
187    } else {
188        normal_cdf(x).clamp(1e-300, 1.0).ln()
189    }
190}
191
192/// Numerically stable `ln(1 − Φ(x)) = ln Φ(−x)` for the standard normal
193/// survival function.  Delegates to `normal_logcdf(-x)` so the deep-right
194/// tail benefits from the same `erfcx`-based representation.
195#[inline]
196pub fn normal_logsf(x: f64) -> f64 {
197    normal_logcdf(-x)
198}
199
200/// Joint evaluation of `ln Φ(x)` and the Mills-ratio analogue
201/// `φ(x) / Φ(x)`, signed for the symmetric branch.  Used by the latent
202/// probit families where the inverse-link gradient needs the ratio and
203/// the likelihood needs the log-CDF on the same `x`; computing both in
204/// one call shares the `erfcx` evaluation that dominates the cost in the
205/// deep tail.
206#[inline]
207pub fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
208    if x == f64::INFINITY {
209        return (0.0, 0.0);
210    }
211    if x == f64::NEG_INFINITY {
212        return (f64::NEG_INFINITY, f64::INFINITY);
213    }
214    if x.is_nan() {
215        return (f64::NAN, f64::NAN);
216    }
217    if x < 0.0 {
218        let u = -x / std::f64::consts::SQRT_2;
219        let ex = erfcx_nonnegative(u).max(1e-300);
220        let log_cdf = -u * u + (0.5 * ex).ln();
221        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
222        (log_cdf, lambda)
223    } else {
224        let cdf = normal_cdf(x).clamp(1e-300, 1.0);
225        let lambda = normal_pdf(x) / cdf;
226        (cdf.ln(), lambda)
227    }
228}
229
230/// Standard normal quantile Φ⁻¹(p) using Acklam's rational approximation.
231#[inline]
232pub fn standard_normal_quantile(p: f64) -> Result<f64, String> {
233    if !(p.is_finite() && p > 0.0 && p < 1.0) {
234        return Err(format!("normal quantile requires p in (0,1), got {p}"));
235    }
236
237    const A: [f64; 6] = [
238        -3.969_683_028_665_376e1,
239        2.209_460_984_245_205e2,
240        -2.759_285_104_469_687e2,
241        1.383_577_518_672_69e2,
242        -3.066_479_806_614_716e1,
243        2.506_628_277_459_239,
244    ];
245    const B: [f64; 5] = [
246        -5.447_609_879_822_406e1,
247        1.615_858_368_580_409e2,
248        -1.556_989_798_598_866e2,
249        6.680_131_188_771_972e1,
250        -1.328_068_155_288_572e1,
251    ];
252    const C: [f64; 6] = [
253        -7.784_894_002_430_293e-3,
254        -3.223_964_580_411_365e-1,
255        -2.400_758_277_161_838,
256        -2.549_732_539_343_734,
257        4.374_664_141_464_968,
258        2.938_163_982_698_783,
259    ];
260    const D: [f64; 4] = [
261        7.784_695_709_041_462e-3,
262        3.224_671_290_700_398e-1,
263        2.445_134_137_142_996,
264        3.754_408_661_907_416,
265    ];
266    const P_LOW: f64 = 0.02425;
267    const P_HIGH: f64 = 1.0 - P_LOW;
268
269    let mut x = if p < P_LOW {
270        let q = (-2.0 * p.ln()).sqrt();
271        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
272            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
273    } else if p <= P_HIGH {
274        let q = p - 0.5;
275        let r = q * q;
276        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
277            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
278    } else {
279        let q = (-2.0 * (1.0 - p).ln()).sqrt();
280        -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
281            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
282    };
283    for _ in 0..2 {
284        let density = normal_pdf(x);
285        if !(density.is_finite() && density > 0.0) {
286            break;
287        }
288        // Residual F(x) − p, formed without catastrophic cancellation in
289        // either tail. For an upper-tail iterate `x > 0`, `normal_cdf(x)`
290        // saturates to ~1, so the direct `normal_cdf(x) − p` annihilates the
291        // tiny residual the polish must act on; instead use the upper-tail
292        // complement `F(x) − p = (1 − p) − 0.5·erfc(x/√2)`, where both terms
293        // are the small upper-tail quantities (`1 − p` is exact by Sterbenz
294        // for `p ∈ [½,1)`). For `x ≤ 0`, `normal_cdf(x) = 0.5·erfc(|x|/√2)` is
295        // itself the faithfully carried small lower-tail value, so the direct
296        // form is already cancellation-free.
297        let residual = if x > 0.0 {
298            (1.0 - p) - 0.5 * erfc(x / std::f64::consts::SQRT_2)
299        } else {
300            normal_cdf(x) - p
301        };
302        let correction = residual / density;
303        let denominator = 1.0 + 0.5 * x * correction;
304        if !(correction.is_finite() && denominator.is_finite() && denominator != 0.0) {
305            break;
306        }
307        let step = correction / denominator;
308        if !step.is_finite() {
309            break;
310        }
311        x -= step;
312        if step.abs() <= 2.0 * f64::EPSILON * x.abs().max(1.0) {
313            break;
314        }
315    }
316    Ok(x)
317}
318
// Unit tests for the standard-normal helpers above: pdf, cdf, erfcx,
// log1mexp, signed log-sum-exp, log-cdf/log-sf, the probit log-cdf + Mills
// ratio pair, and the quantile function.
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for tests that compare against closed forms.
    const TOL: f64 = 1e-12;

    /// Relative error of `got` vs `expected`. `|expected|` is floored at 1e-300
    /// so an exact-zero reference does not divide by zero.
    fn rel_err(got: f64, expected: f64) -> f64 {
        (got - expected).abs() / expected.abs().max(1e-300)
    }

    // ── normal_pdf ────────────────────────────────────────────────────────────

    #[test]
    fn normal_pdf_at_zero() {
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!((normal_pdf(0.0) - expected).abs() < TOL);
    }

    #[test]
    fn normal_pdf_symmetry() {
        // Exact equality: negating x only flips signs of exactly-scaled terms.
        for &x in &[0.5, 1.0, 2.0, 3.0, 5.0] {
            assert_eq!(normal_pdf(x), normal_pdf(-x), "symmetry failed at x={x}");
        }
    }

    #[test]
    fn normal_pdf_positive() {
        for &x in &[-5.0, -1.0, 0.0, 1.0, 5.0] {
            assert!(normal_pdf(x) > 0.0, "pdf should be positive at x={x}");
        }
    }

    // ── normal_cdf ────────────────────────────────────────────────────────────

    #[test]
    fn normal_cdf_at_zero_is_half() {
        assert!((normal_cdf(0.0) - 0.5).abs() < TOL);
    }

    #[test]
    fn normal_cdf_symmetry() {
        // Φ(x) + Φ(−x) = 1.
        for &x in &[0.5, 1.0, 2.0, 3.0] {
            let sum = normal_cdf(x) + normal_cdf(-x);
            assert!(
                (sum - 1.0).abs() < TOL,
                "cdf symmetry failed at x={x}: sum={sum}"
            );
        }
    }

    #[test]
    fn normal_cdf_bounds() {
        // Φ(−10) ≈ 7.6e−24, comfortably below the 1e−22 bound.
        assert!(normal_cdf(10.0) > 0.9999);
        assert!(normal_cdf(-10.0) < 1e-22);
        assert!(normal_cdf(0.0) > 0.0);
        assert!(normal_cdf(0.0) < 1.0);
    }

    #[test]
    fn normal_cdf_at_1_96_near_0975() {
        // Phi(1.96) ≈ 0.975 — canonical two-sided 5% critical value.
        let p = normal_cdf(1.959_963_985);
        assert!((p - 0.975).abs() < 1e-8, "p={p}");
    }

    // ── erfcx_nonnegative ─────────────────────────────────────────────────────

    #[test]
    fn erfcx_at_nonpositive_returns_one() {
        // Non-positive inputs are clamped to the x = 0 value, not reflected.
        assert_eq!(erfcx_nonnegative(0.0), 1.0);
        assert_eq!(erfcx_nonnegative(-1.0), 1.0);
        assert_eq!(erfcx_nonnegative(-100.0), 1.0);
    }

    #[test]
    fn erfcx_positive_inf_returns_zero() {
        assert_eq!(erfcx_nonnegative(f64::INFINITY), 0.0);
    }

    #[test]
    fn erfcx_negative_inf_returns_inf() {
        assert_eq!(erfcx_nonnegative(f64::NEG_INFINITY), f64::INFINITY);
    }

    #[test]
    fn erfcx_small_positive_matches_direct() {
        // NOTE(review): for x < 26 this reference is the same formula the
        // direct branch uses. It guards branch selection, not erfc accuracy.
        use statrs::function::erf::erfc;
        for &x in &[0.1_f64, 0.5, 1.0, 5.0, 10.0, 25.0] {
            let got = erfcx_nonnegative(x);
            let expected = (x * x).exp() * erfc(x);
            let err = rel_err(got, expected);
            assert!(
                err < 1e-10,
                "x={x}: got={got} expected={expected} rel={err}"
            );
        }
    }

    #[test]
    fn erfcx_large_x_positive_and_finite() {
        // For x >= 26 the asymptotic branch must remain positive and finite.
        let got = erfcx_nonnegative(50.0);
        assert!(got.is_finite() && got > 0.0, "erfcx(50)={got}");
        // Leading asymptotic term: 1/(x*sqrt(pi)). At x = 50 the next term is
        // about −2e−4 relative, hence the loose 1e−3 tolerance.
        let asymptotic = 1.0 / (50.0 * std::f64::consts::PI.sqrt());
        assert!(
            rel_err(got, asymptotic) < 1e-3,
            "got={got} asymptotic={asymptotic}"
        );
    }

    // ── log1mexp_positive ─────────────────────────────────────────────────────

    #[test]
    fn log1mexp_at_zero_is_neg_inf() {
        assert_eq!(log1mexp_positive(0.0), f64::NEG_INFINITY);
    }

    #[test]
    fn log1mexp_recovers_log_one_minus_exp() {
        // Verify exp(log1mexp(a)) + exp(-a) ≈ 1 for several a > 0. This
        // roundtrip avoids computing `(1 - exp(-a)).ln()` directly, which
        // suffers catastrophic cancellation for large a (e.g. a=20 where
        // `1.0 - exp(-20)` loses 9 decimal digits from the subtraction).
        for &a in &[0.001_f64, 0.5, std::f64::consts::LN_2, 1.0, 5.0, 20.0] {
            let lm = log1mexp_positive(a);
            let roundtrip = lm.exp() + (-a).exp();
            assert!(
                (roundtrip - 1.0).abs() < 1e-14,
                "a={a}: exp(log1mexp(a)) + exp(-a) = {roundtrip}, expected 1.0"
            );
        }
    }

    #[test]
    fn log1mexp_at_ln2_is_neg_ln2() {
        // ln(1 − e^{−ln 2}) = ln(½) = −ln 2; exercises the branch boundary.
        let ln2 = std::f64::consts::LN_2;
        let got = log1mexp_positive(ln2);
        assert!((got - (-ln2)).abs() < TOL, "got={got}");
    }

    // ── signed_log_sum_exp ────────────────────────────────────────────────────

    #[test]
    fn slse_all_positive_single() {
        let (lm, sg) = signed_log_sum_exp(&[2.0], &[1.0]);
        assert!((lm - 2.0).abs() < TOL);
        assert!((sg - 1.0).abs() < TOL);
    }

    #[test]
    fn slse_difference_recovers_log2() {
        // 3 - 1 = 2 → log|2| = ln(2), sign = +1.
        let log3 = 3.0_f64.ln();
        let log1 = 0.0_f64; // ln(1)
        let (lm, sg) = signed_log_sum_exp(&[log3, log1], &[1.0, -1.0]);
        assert!((lm - 2.0_f64.ln()).abs() < TOL, "lm={lm}");
        assert!((sg - 1.0).abs() < TOL, "sg={sg}");
    }

    #[test]
    fn slse_cancellation_gives_neg_inf() {
        // a - a = 0 → log|0| = -∞.
        let ln2 = 2.0_f64.ln();
        let (lm, sg) = signed_log_sum_exp(&[ln2, ln2], &[1.0, -1.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_empty_returns_neg_inf_with_zero_sign() {
        // With no terms the sum is exactly 0, so the contract is `(−∞, 0.0)`:
        // a zero sum reports sign 0, never the +1 positive-sum convention.
        let (lm, sg) = signed_log_sum_exp(&[], &[]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_all_zero_signs_return_zero_sign() {
        // A single term whose sign is 0 contributes nothing; S = 0 ⇒ (−∞, 0.0).
        let (lm, sg) = signed_log_sum_exp(&[0.0], &[0.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_all_neg_inf_magnitudes_return_zero_sign() {
        // Every magnitude is exp(−∞) = 0 regardless of sign, so the sum is 0 and
        // the reported sign must be 0.0, not +1.0.
        let (lm, sg) = signed_log_sum_exp(&[f64::NEG_INFINITY, f64::NEG_INFINITY], &[1.0, -1.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_pos_inf_dominates() {
        // A positive +∞ term swamps any finite negative term.
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[1.0, -1.0]);
        assert_eq!(lm, f64::INFINITY);
        assert_eq!(sg, 1.0);
    }

    #[test]
    fn slse_neg_inf_dominates() {
        // A negative +∞-magnitude term gives S = −∞, encoded as (+∞, −1).
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[-1.0, 1.0]);
        assert_eq!(lm, f64::INFINITY);
        assert_eq!(sg, -1.0);
    }

    #[test]
    fn slse_both_inf_signs_gives_nan() {
        // +∞ − ∞ is indeterminate and is reported as (NaN, 0.0).
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, f64::INFINITY], &[1.0, -1.0]);
        assert!(lm.is_nan());
        assert_eq!(sg, 0.0);
    }

    // ── normal_logcdf ─────────────────────────────────────────────────────────

    #[test]
    fn logcdf_at_zero_is_log_half() {
        let got = normal_logcdf(0.0);
        let expected = 0.5_f64.ln();
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logcdf_pos_inf_is_zero() {
        assert_eq!(normal_logcdf(f64::INFINITY), 0.0);
    }

    #[test]
    fn logcdf_neg_inf_is_neg_inf() {
        assert_eq!(normal_logcdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
    }

    #[test]
    fn logcdf_nan_is_nan() {
        assert!(normal_logcdf(f64::NAN).is_nan());
    }

    #[test]
    fn logcdf_matches_log_cdf_for_moderate_x() {
        // In this range ln(normal_cdf(x)) is itself accurate enough to serve
        // as a reference.
        for &x in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0] {
            let got = normal_logcdf(x);
            let expected = normal_cdf(x).ln();
            assert!(
                (got - expected).abs() < 1e-10,
                "x={x}: got={got} expected={expected}"
            );
        }
    }

    #[test]
    fn logcdf_deep_left_tail_stays_finite() {
        // Deep in the left tail the erfcx-based branch must yield a finite,
        // large-negative value. Φ(−20) ≈ 2.8e−89 and ln Φ(−20) ≈ −203.9.
        // normal_cdf itself only underflows much further out (around x ≈ −38).
        let got = normal_logcdf(-20.0);
        assert!(got.is_finite() && got < -100.0, "logcdf(-20)={got}");
    }

    // ── normal_logsf ─────────────────────────────────────────────────────────

    #[test]
    fn logsf_at_zero_is_log_half() {
        let got = normal_logsf(0.0);
        let expected = 0.5_f64.ln();
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logsf_mirrors_logcdf() {
        // logsf(x) = logcdf(-x) by definition.
        for &x in &[-3.0_f64, -1.0, 0.0, 1.0, 3.0] {
            assert_eq!(normal_logsf(x), normal_logcdf(-x));
        }
    }

    // ── signed_probit_logcdf_and_mills_ratio ──────────────────────────────────

    #[test]
    fn probit_at_pos_inf() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::INFINITY);
        assert_eq!(lc, 0.0);
        assert_eq!(mr, 0.0);
    }

    #[test]
    fn probit_at_neg_inf() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::NEG_INFINITY);
        assert_eq!(lc, f64::NEG_INFINITY);
        assert_eq!(mr, f64::INFINITY);
    }

    #[test]
    fn probit_nan_propagates() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::NAN);
        assert!(lc.is_nan() && mr.is_nan());
    }

    #[test]
    fn probit_at_zero_logcdf_and_mills() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(0.0);
        assert!((lc - 0.5_f64.ln()).abs() < TOL, "lc={lc}");
        // phi(0)/Phi(0) = 0.3989.../0.5 ≈ 0.7979.
        assert!((mr - 0.797_884_560_802_865).abs() < 1e-10, "mr={mr}");
    }

    #[test]
    fn probit_positive_branch_matches_logcdf() {
        // The joint evaluator must agree with the standalone log-cdf and with
        // the direct pdf/cdf ratio on the x >= 0 branch.
        for &x in &[0.5_f64, 1.0, 2.0, 3.0] {
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            let lc_ref = normal_logcdf(x);
            let mr_ref = normal_pdf(x) / normal_cdf(x);
            assert!(
                (lc - lc_ref).abs() < 1e-10,
                "x={x}: lc={lc} lc_ref={lc_ref}"
            );
            assert!(
                (mr - mr_ref).abs() < 1e-10,
                "x={x}: mr={mr} mr_ref={mr_ref}"
            );
        }
    }

    #[test]
    fn probit_negative_branch_matches_logcdf() {
        // On the erfcx branch only the log-cdf is checked against a
        // reference. The ratio is just required to be finite and positive.
        for &x in &[-0.5_f64, -1.0, -2.0, -5.0] {
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            let lc_ref = normal_logcdf(x);
            assert!(
                (lc - lc_ref).abs() < 1e-10,
                "x={x}: lc={lc} lc_ref={lc_ref}"
            );
            assert!(mr.is_finite() && mr > 0.0, "x={x}: mr={mr}");
        }
    }

    // ── standard_normal_quantile ──────────────────────────────────────────────

    #[test]
    fn quantile_rejects_out_of_range() {
        // The domain is the open interval (0, 1); NaN is rejected too.
        assert!(standard_normal_quantile(0.0).is_err());
        assert!(standard_normal_quantile(1.0).is_err());
        assert!(standard_normal_quantile(-0.1).is_err());
        assert!(standard_normal_quantile(1.1).is_err());
        assert!(standard_normal_quantile(f64::NAN).is_err());
    }

    #[test]
    fn quantile_at_half_is_near_zero() {
        let q = standard_normal_quantile(0.5).unwrap();
        assert!(q.abs() < 1e-10, "quantile(0.5)={q}");
    }

    #[test]
    fn quantile_at_0975_is_near_196() {
        let q = standard_normal_quantile(0.975).unwrap();
        assert!((q - 1.959_963_985).abs() < 1e-7, "q={q}");
    }

    #[test]
    fn quantile_antisymmetry() {
        // Φ⁻¹(1 − p) = −Φ⁻¹(p).
        let q_lo = standard_normal_quantile(0.1).unwrap();
        let q_hi = standard_normal_quantile(0.9).unwrap();
        assert!((q_lo + q_hi).abs() < 1e-10, "q_lo={q_lo} q_hi={q_hi}");
    }

    #[test]
    fn quantile_roundtrip_cdf() {
        // Covers both tails (p < 0.02425 and p > 0.97575) and the central
        // region. Φ(Φ⁻¹(p)) must reproduce p after the Halley polish.
        for &p in &[
            0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999,
        ] {
            let q = standard_normal_quantile(p).unwrap();
            let p_back = normal_cdf(q);
            assert!(
                (p_back - p).abs() < 1e-10,
                "roundtrip failed at p={p}: q={q} p_back={p_back}"
            );
        }
    }
}