1use statrs::function::erf::erfc;
2
/// Standard normal probability density, `φ(x) = exp(-x²/2) / √(2π)`.
///
/// `NaN` propagates; `±∞` gives `0.0` because the exponential underflows.
#[inline]
pub fn normal_pdf(x: f64) -> f64 {
    // 1/√(2π), precomputed so evaluation is one `exp` and one multiply.
    const FRAC_1_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    let exponent = -0.5 * x * x;
    FRAC_1_SQRT_2PI * exponent.exp()
}
9
10#[inline]
19pub fn normal_cdf(x: f64) -> f64 {
20 0.5 * statrs::function::erf::erfc(-x / std::f64::consts::SQRT_2)
21}
22
23#[inline]
31pub fn erfcx_nonnegative(x: f64) -> f64 {
32 if !x.is_finite() {
33 return if x.is_sign_positive() {
34 0.0
35 } else {
36 f64::INFINITY
37 };
38 }
39 if x <= 0.0 {
40 return 1.0;
41 }
42 if x < 26.0 {
43 ((x * x).min(700.0)).exp() * erfc(x)
44 } else {
45 let inv = 1.0 / x;
46 let inv2 = inv * inv;
47 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
48 + 6.5625 * inv2 * inv2 * inv2 * inv2;
49 inv * poly / std::f64::consts::PI.sqrt()
50 }
51}
52
/// Numerically stable `ln(1 - exp(-a))` for `a ≥ 0`.
///
/// Uses Mächler's split at `a = ln 2`:
/// * at or below `ln 2`, `ln(-expm1(-a))` avoids the cancellation in `1 - exp(-a)`;
/// * above it, `ln_1p(-exp(-a))` avoids rounding `1 - exp(-a)` to 1.
///
/// Returns `-∞` at `a = 0`.
///
/// # Panics
///
/// Panics if `a` is negative or `NaN`.
#[inline]
pub fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        return f64::NEG_INFINITY;
    }
    if a <= core::f64::consts::LN_2 {
        // 1 - e^{-a} is small: expm1 computes it without cancellation.
        (-(-a).exp_m1()).ln()
    } else {
        // e^{-a} < 1/2: ln_1p keeps the small negative result accurate.
        (-(-a).exp()).ln_1p()
    }
}
65
66pub fn signed_log_sum_exp(log_mags: &[f64], signs: &[f64]) -> (f64, f64) {
84 let mut has_pos_inf = false;
88 let mut has_neg_inf = false;
89 for (idx, &lm) in log_mags.iter().enumerate() {
90 if lm == f64::INFINITY {
91 if signs[idx] > 0.0 {
92 has_pos_inf = true;
93 } else if signs[idx] < 0.0 {
94 has_neg_inf = true;
95 }
96 }
97 }
98 match (has_pos_inf, has_neg_inf) {
99 (true, true) => return (f64::NAN, 0.0),
101 (true, false) => return (f64::INFINITY, 1.0),
103 (false, true) => return (f64::INFINITY, -1.0),
105 (false, false) => {}
106 }
107
108 let mut pos_max = f64::NEG_INFINITY;
109 let mut neg_max = f64::NEG_INFINITY;
110 for (idx, &lm) in log_mags.iter().enumerate() {
111 if signs[idx] > 0.0 {
112 pos_max = pos_max.max(lm);
113 } else if signs[idx] < 0.0 {
114 neg_max = neg_max.max(lm);
115 }
116 }
117
118 let mut pos_sum = 0.0_f64;
119 let mut neg_sum = 0.0_f64;
120 for (idx, &lm) in log_mags.iter().enumerate() {
121 if !lm.is_finite() {
122 continue;
123 }
124 if signs[idx] > 0.0 {
125 pos_sum += (lm - pos_max).exp();
126 } else if signs[idx] < 0.0 {
127 neg_sum += (lm - neg_max).exp();
128 }
129 }
130
131 let log_pos = if pos_sum > 0.0 {
132 pos_max + pos_sum.ln()
133 } else {
134 f64::NEG_INFINITY
135 };
136 let log_neg = if neg_sum > 0.0 {
137 neg_max + neg_sum.ln()
138 } else {
139 f64::NEG_INFINITY
140 };
141
142 if log_neg == f64::NEG_INFINITY {
143 return (log_pos, 1.0);
144 }
145 if log_pos == f64::NEG_INFINITY {
146 return (log_neg, -1.0);
147 }
148 if log_pos > log_neg {
149 let gap = log_pos - log_neg;
150 (log_pos + log1mexp_positive(gap), 1.0)
151 } else if log_neg > log_pos {
152 let gap = log_neg - log_pos;
153 (log_neg + log1mexp_positive(gap), -1.0)
154 } else {
155 (f64::NEG_INFINITY, 0.0)
156 }
157}
158
159#[inline]
166pub fn normal_logcdf(x: f64) -> f64 {
167 if x == f64::INFINITY {
168 return 0.0;
169 }
170 if x == f64::NEG_INFINITY {
171 return f64::NEG_INFINITY;
172 }
173 if x.is_nan() {
174 return f64::NAN;
175 }
176 if x < 0.0 {
177 let u = -x / std::f64::consts::SQRT_2;
178 -u * u + (0.5 * erfcx_nonnegative(u).max(1e-300)).ln()
179 } else {
180 normal_cdf(x).clamp(1e-300, 1.0).ln()
181 }
182}
183
184#[inline]
188pub fn normal_logsf(x: f64) -> f64 {
189 normal_logcdf(-x)
190}
191
192#[inline]
199pub fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
200 if x == f64::INFINITY {
201 return (0.0, 0.0);
202 }
203 if x == f64::NEG_INFINITY {
204 return (f64::NEG_INFINITY, f64::INFINITY);
205 }
206 if x.is_nan() {
207 return (f64::NAN, f64::NAN);
208 }
209 if x < 0.0 {
210 let u = -x / std::f64::consts::SQRT_2;
211 let ex = erfcx_nonnegative(u).max(1e-300);
212 let log_cdf = -u * u + (0.5 * ex).ln();
213 let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
214 (log_cdf, lambda)
215 } else {
216 let cdf = normal_cdf(x).clamp(1e-300, 1.0);
217 let lambda = normal_pdf(x) / cdf;
218 (cdf.ln(), lambda)
219 }
220}
221
222#[inline]
224pub fn standard_normal_quantile(p: f64) -> Result<f64, String> {
225 if !(p.is_finite() && p > 0.0 && p < 1.0) {
226 return Err(format!("normal quantile requires p in (0,1), got {p}"));
227 }
228
229 const A: [f64; 6] = [
230 -3.969_683_028_665_376e1,
231 2.209_460_984_245_205e2,
232 -2.759_285_104_469_687e2,
233 1.383_577_518_672_69e2,
234 -3.066_479_806_614_716e1,
235 2.506_628_277_459_239,
236 ];
237 const B: [f64; 5] = [
238 -5.447_609_879_822_406e1,
239 1.615_858_368_580_409e2,
240 -1.556_989_798_598_866e2,
241 6.680_131_188_771_972e1,
242 -1.328_068_155_288_572e1,
243 ];
244 const C: [f64; 6] = [
245 -7.784_894_002_430_293e-3,
246 -3.223_964_580_411_365e-1,
247 -2.400_758_277_161_838,
248 -2.549_732_539_343_734,
249 4.374_664_141_464_968,
250 2.938_163_982_698_783,
251 ];
252 const D: [f64; 4] = [
253 7.784_695_709_041_462e-3,
254 3.224_671_290_700_398e-1,
255 2.445_134_137_142_996,
256 3.754_408_661_907_416,
257 ];
258 const P_LOW: f64 = 0.02425;
259 const P_HIGH: f64 = 1.0 - P_LOW;
260
261 let mut x = if p < P_LOW {
262 let q = (-2.0 * p.ln()).sqrt();
263 (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
264 / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
265 } else if p <= P_HIGH {
266 let q = p - 0.5;
267 let r = q * q;
268 (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
269 / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
270 } else {
271 let q = (-2.0 * (1.0 - p).ln()).sqrt();
272 -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
273 / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
274 };
275 for _ in 0..2 {
276 let density = normal_pdf(x);
277 if !(density.is_finite() && density > 0.0) {
278 break;
279 }
280 let residual = if x > 0.0 {
290 (1.0 - p) - 0.5 * erfc(x / std::f64::consts::SQRT_2)
291 } else {
292 normal_cdf(x) - p
293 };
294 let correction = residual / density;
295 let denominator = 1.0 + 0.5 * x * correction;
296 if !(correction.is_finite() && denominator.is_finite() && denominator != 0.0) {
297 break;
298 }
299 let step = correction / denominator;
300 if !step.is_finite() {
301 break;
302 }
303 x -= step;
304 if step.abs() <= 2.0 * f64::EPSILON * x.abs().max(1.0) {
305 break;
306 }
307 }
308 Ok(x)
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for checks expected to hold to near machine precision.
    const TOL: f64 = 1e-12;

    /// Relative error of `got` against `expected`. The denominator is floored at
    /// `1e-300` so that a zero expectation cannot cause a division by zero.
    fn rel_err(got: f64, expected: f64) -> f64 {
        let scale = expected.abs().max(1e-300);
        (got - expected).abs() / scale
    }

    // --- normal_pdf -------------------------------------------------------------

    #[test]
    fn normal_pdf_at_zero() {
        // The density peaks at 1/√(2π).
        let peak = (2.0 * std::f64::consts::PI).sqrt().recip();
        let diff = (normal_pdf(0.0) - peak).abs();
        assert!(diff < TOL);
    }

    #[test]
    fn normal_pdf_symmetry() {
        for x in [0.5_f64, 1.0, 2.0, 3.0, 5.0] {
            let (right, left) = (normal_pdf(x), normal_pdf(-x));
            assert_eq!(right, left, "symmetry failed at x={x}");
        }
    }

    #[test]
    fn normal_pdf_positive() {
        for x in [-5.0_f64, -1.0, 0.0, 1.0, 5.0] {
            let density = normal_pdf(x);
            assert!(density > 0.0, "pdf should be positive at x={x}");
        }
    }

    // --- normal_cdf -------------------------------------------------------------

    #[test]
    fn normal_cdf_at_zero_is_half() {
        let gap = (normal_cdf(0.0) - 0.5).abs();
        assert!(gap < TOL);
    }

    #[test]
    fn normal_cdf_symmetry() {
        // Φ(x) + Φ(-x) = 1.
        for x in [0.5_f64, 1.0, 2.0, 3.0] {
            let sum = normal_cdf(-x) + normal_cdf(x);
            assert!((sum - 1.0).abs() < TOL, "cdf symmetry failed at x={x}: sum={sum}");
        }
    }

    #[test]
    fn normal_cdf_bounds() {
        let center = normal_cdf(0.0);
        assert!(normal_cdf(10.0) > 0.9999);
        // Φ(-10) ≈ 7.6e-24.
        assert!(normal_cdf(-10.0) < 1e-22);
        assert!(center > 0.0);
        assert!(center < 1.0);
    }

    #[test]
    fn normal_cdf_at_1_96_near_0975() {
        // 1.959963985 is Φ⁻¹(0.975) to ten significant figures.
        let p = normal_cdf(1.959_963_985);
        assert!((p - 0.975).abs() < 1e-8, "p={p}");
    }

    // --- erfcx_nonnegative ------------------------------------------------------

    #[test]
    fn erfcx_at_nonpositive_returns_one() {
        for x in [0.0_f64, -1.0, -100.0] {
            assert_eq!(erfcx_nonnegative(x), 1.0);
        }
    }

    #[test]
    fn erfcx_positive_inf_returns_zero() {
        let limit = erfcx_nonnegative(f64::INFINITY);
        assert_eq!(limit, 0.0);
    }

    #[test]
    fn erfcx_negative_inf_returns_inf() {
        let limit = erfcx_nonnegative(f64::NEG_INFINITY);
        assert_eq!(limit, f64::INFINITY);
    }

    #[test]
    fn erfcx_small_positive_matches_direct() {
        use statrs::function::erf::erfc;
        // Below the asymptotic cut-over the direct product is still representable.
        for x in [0.1_f64, 0.5, 1.0, 5.0, 10.0, 25.0] {
            let got = erfcx_nonnegative(x);
            let expected = erfc(x) * (x * x).exp();
            let err = rel_err(got, expected);
            assert!(err < 1e-10, "x={x}: got={got} expected={expected} rel={err}");
        }
    }

    #[test]
    fn erfcx_large_x_positive_and_finite() {
        let got = erfcx_nonnegative(50.0);
        assert!(got.is_finite() && got > 0.0, "erfcx(50)={got}");
        // Leading asymptotic term 1/(x√π); the first correction is only ~2e-4 at x = 50.
        let asymptotic = (50.0 * std::f64::consts::PI.sqrt()).recip();
        assert!(rel_err(got, asymptotic) < 1e-3, "got={got} asymptotic={asymptotic}");
    }

    // --- log1mexp_positive ------------------------------------------------------

    #[test]
    fn log1mexp_at_zero_is_neg_inf() {
        let value = log1mexp_positive(0.0);
        assert_eq!(value, f64::NEG_INFINITY);
    }

    #[test]
    fn log1mexp_recovers_log_one_minus_exp() {
        // Inputs straddle the ln 2 branch point and include it.
        let inputs = [0.001_f64, 0.5, std::f64::consts::LN_2, 1.0, 5.0, 20.0];
        for a in inputs {
            let lm = log1mexp_positive(a);
            let roundtrip = (-a).exp() + lm.exp();
            assert!(
                (roundtrip - 1.0).abs() < 1e-14,
                "a={a}: exp(log1mexp(a)) + exp(-a) = {roundtrip}, expected 1.0"
            );
        }
    }

    #[test]
    fn log1mexp_at_ln2_is_neg_ln2() {
        // ln(1 - 1/2) = -ln 2.
        let ln2 = std::f64::consts::LN_2;
        let got = log1mexp_positive(ln2);
        assert!((got + ln2).abs() < TOL, "got={got}");
    }

    // --- signed_log_sum_exp -----------------------------------------------------

    #[test]
    fn slse_all_positive_single() {
        let result = signed_log_sum_exp(&[2.0], &[1.0]);
        assert!((result.0 - 2.0).abs() < TOL);
        assert!((result.1 - 1.0).abs() < TOL);
    }

    #[test]
    fn slse_difference_recovers_log2() {
        // exp(ln 3) - exp(0) = 2.
        let log3 = 3.0_f64.ln();
        let log1 = 0.0_f64;
        let (lm, sg) = signed_log_sum_exp(&[log3, log1], &[1.0, -1.0]);
        assert!((lm - 2.0_f64.ln()).abs() < TOL, "lm={lm}");
        assert!((sg - 1.0).abs() < TOL, "sg={sg}");
    }

    #[test]
    fn slse_cancellation_gives_neg_inf() {
        let ln2 = 2.0_f64.ln();
        let outcome = signed_log_sum_exp(&[ln2, ln2], &[1.0, -1.0]);
        assert_eq!(outcome, (f64::NEG_INFINITY, 0.0));
    }

    #[test]
    fn slse_empty_returns_neg_inf() {
        // An empty sum is zero: log-magnitude -inf, reported with sign +1.
        let outcome = signed_log_sum_exp(&[], &[]);
        assert_eq!(outcome, (f64::NEG_INFINITY, 1.0));
    }

    #[test]
    fn slse_pos_inf_dominates() {
        let outcome = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[1.0, -1.0]);
        assert_eq!(outcome, (f64::INFINITY, 1.0));
    }

    #[test]
    fn slse_neg_inf_dominates() {
        let outcome = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[-1.0, 1.0]);
        assert_eq!(outcome, (f64::INFINITY, -1.0));
    }

    #[test]
    fn slse_both_inf_signs_gives_nan() {
        // ∞ - ∞ is undefined.
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, f64::INFINITY], &[1.0, -1.0]);
        assert!(lm.is_nan());
        assert_eq!(sg, 0.0);
    }

    // --- normal_logcdf / normal_logsf -------------------------------------------

    #[test]
    fn logcdf_at_zero_is_log_half() {
        let expected = 0.5_f64.ln();
        let got = normal_logcdf(0.0);
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logcdf_pos_inf_is_zero() {
        let value = normal_logcdf(f64::INFINITY);
        assert_eq!(value, 0.0);
    }

    #[test]
    fn logcdf_neg_inf_is_neg_inf() {
        let value = normal_logcdf(f64::NEG_INFINITY);
        assert_eq!(value, f64::NEG_INFINITY);
    }

    #[test]
    fn logcdf_nan_is_nan() {
        let value = normal_logcdf(f64::NAN);
        assert!(value.is_nan());
    }

    #[test]
    fn logcdf_matches_log_cdf_for_moderate_x() {
        // In this range the naive ln(Φ(x)) is still accurate enough to serve as a reference.
        for x in [-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0] {
            let expected = normal_cdf(x).ln();
            let got = normal_logcdf(x);
            assert!((got - expected).abs() < 1e-10, "x={x}: got={got} expected={expected}");
        }
    }

    #[test]
    fn logcdf_deep_left_tail_stays_finite() {
        // ln Φ(-20) ≈ -203.9; a naive ln(Φ) would still work here, but the check pins
        // that the log-space path stays finite and well below -100.
        let got = normal_logcdf(-20.0);
        assert!(got.is_finite() && got < -100.0, "logcdf(-20)={got}");
    }

    #[test]
    fn logsf_at_zero_is_log_half() {
        let expected = 0.5_f64.ln();
        let got = normal_logsf(0.0);
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logsf_mirrors_logcdf() {
        // Exact equality: the survival function is defined by reflection.
        for x in [-3.0_f64, -1.0, 0.0, 1.0, 3.0] {
            assert_eq!(normal_logsf(x), normal_logcdf(-x));
        }
    }

    // --- signed_probit_logcdf_and_mills_ratio -----------------------------------

    #[test]
    fn probit_at_pos_inf() {
        let outcome = signed_probit_logcdf_and_mills_ratio(f64::INFINITY);
        assert_eq!(outcome, (0.0, 0.0));
    }

    #[test]
    fn probit_at_neg_inf() {
        let outcome = signed_probit_logcdf_and_mills_ratio(f64::NEG_INFINITY);
        assert_eq!(outcome, (f64::NEG_INFINITY, f64::INFINITY));
    }

    #[test]
    fn probit_nan_propagates() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::NAN);
        assert!(lc.is_nan() && mr.is_nan());
    }

    #[test]
    fn probit_at_zero_logcdf_and_mills() {
        // φ(0)/Φ(0) = 2/√(2π) = √(2/π).
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(0.0);
        assert!((lc - 0.5_f64.ln()).abs() < TOL, "lc={lc}");
        assert!((mr - 0.797_884_560_802_865).abs() < 1e-10, "mr={mr}");
    }

    #[test]
    fn probit_positive_branch_matches_logcdf() {
        for x in [0.5_f64, 1.0, 2.0, 3.0] {
            let lc_ref = normal_logcdf(x);
            let mr_ref = normal_pdf(x) / normal_cdf(x);
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            assert!((lc - lc_ref).abs() < 1e-10, "x={x}: lc={lc} lc_ref={lc_ref}");
            assert!((mr - mr_ref).abs() < 1e-10, "x={x}: mr={mr} mr_ref={mr_ref}");
        }
    }

    #[test]
    fn probit_negative_branch_matches_logcdf() {
        for x in [-0.5_f64, -1.0, -2.0, -5.0] {
            let lc_ref = normal_logcdf(x);
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            assert!((lc - lc_ref).abs() < 1e-10, "x={x}: lc={lc} lc_ref={lc_ref}");
            assert!(mr.is_finite() && mr > 0.0, "x={x}: mr={mr}");
        }
    }

    // --- standard_normal_quantile -----------------------------------------------

    #[test]
    fn quantile_rejects_out_of_range() {
        for p in [0.0_f64, 1.0, -0.1, 1.1, f64::NAN] {
            assert!(standard_normal_quantile(p).is_err());
        }
    }

    #[test]
    fn quantile_at_half_is_near_zero() {
        let q = standard_normal_quantile(0.5).unwrap();
        assert!(q.abs() < 1e-10, "quantile(0.5)={q}");
    }

    #[test]
    fn quantile_at_0975_is_near_196() {
        let q = standard_normal_quantile(0.975).unwrap();
        assert!((q - 1.959_963_985).abs() < 1e-7, "q={q}");
    }

    #[test]
    fn quantile_antisymmetry() {
        // Φ⁻¹(1 - p) = -Φ⁻¹(p).
        let (q_lo, q_hi) = (
            standard_normal_quantile(0.1).unwrap(),
            standard_normal_quantile(0.9).unwrap(),
        );
        assert!((q_lo + q_hi).abs() < 1e-10, "q_lo={q_lo} q_hi={q_hi}");
    }

    #[test]
    fn quantile_roundtrip_cdf() {
        let probabilities = [0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999];
        for p in probabilities {
            let q = standard_normal_quantile(p).unwrap();
            let p_back = normal_cdf(q);
            assert!(
                (p_back - p).abs() < 1e-10,
                "roundtrip failed at p={p}: q={q} p_back={p_back}"
            );
        }
    }
}