1use statrs::function::erf::erfc;
2
/// Standard normal probability density function, `φ(x) = exp(-x²/2) / √(2π)`.
///
/// Evaluates the density of N(0, 1) at `x`. The result is symmetric in `x`,
/// positive for moderate `|x|`, and underflows to `0.0` in the far tails.
#[inline]
pub fn normal_pdf(x: f64) -> f64 {
    // 1/√(2π), spelled out so no square root is taken at runtime.
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    let half_neg_square = -0.5 * x * x;
    half_neg_square.exp() * INV_SQRT_2PI
}
9
10#[inline]
19pub fn normal_cdf(x: f64) -> f64 {
20 0.5 * statrs::function::erf::erfc(-x / std::f64::consts::SQRT_2)
21}
22
23#[inline]
31pub fn erfcx_nonnegative(x: f64) -> f64 {
32 if !x.is_finite() {
33 return if x.is_sign_positive() {
34 0.0
35 } else {
36 f64::INFINITY
37 };
38 }
39 if x <= 0.0 {
40 return 1.0;
41 }
42 if x < 26.0 {
43 ((x * x).min(700.0)).exp() * erfc(x)
44 } else {
45 let inv = 1.0 / x;
46 let inv2 = inv * inv;
47 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
48 + 6.5625 * inv2 * inv2 * inv2 * inv2;
49 inv * poly / std::f64::consts::PI.sqrt()
50 }
51}
52
/// Computes `ln(1 − exp(−a))` for `a ≥ 0` without catastrophic cancellation.
///
/// Uses Mächler's switch at `a = ln 2`:
/// * `a ≤ ln 2`: `exp(−a)` is close to 1, so use `ln(−expm1(−a))`.
/// * `a > ln 2`: `exp(−a)` is at most ½, so use `ln_1p(−exp(−a))`.
///
/// Returns `−∞` for `a == 0` (the log of zero).
///
/// # Panics
///
/// Panics if `a` is negative or NaN.
#[inline]
pub fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        return f64::NEG_INFINITY;
    }
    if a <= core::f64::consts::LN_2 {
        (-(-a).exp_m1()).ln()
    } else {
        (-(-a).exp()).ln_1p()
    }
}
65
66pub fn signed_log_sum_exp(log_mags: &[f64], signs: &[f64]) -> (f64, f64) {
84 let mut has_pos_inf = false;
88 let mut has_neg_inf = false;
89 for (idx, &lm) in log_mags.iter().enumerate() {
90 if lm == f64::INFINITY {
91 if signs[idx] > 0.0 {
92 has_pos_inf = true;
93 } else if signs[idx] < 0.0 {
94 has_neg_inf = true;
95 }
96 }
97 }
98 match (has_pos_inf, has_neg_inf) {
99 (true, true) => return (f64::NAN, 0.0),
101 (true, false) => return (f64::INFINITY, 1.0),
103 (false, true) => return (f64::INFINITY, -1.0),
105 (false, false) => {}
106 }
107
108 let mut pos_max = f64::NEG_INFINITY;
109 let mut neg_max = f64::NEG_INFINITY;
110 for (idx, &lm) in log_mags.iter().enumerate() {
111 if signs[idx] > 0.0 {
112 pos_max = pos_max.max(lm);
113 } else if signs[idx] < 0.0 {
114 neg_max = neg_max.max(lm);
115 }
116 }
117
118 let mut pos_sum = 0.0_f64;
119 let mut neg_sum = 0.0_f64;
120 for (idx, &lm) in log_mags.iter().enumerate() {
121 if !lm.is_finite() {
122 continue;
123 }
124 if signs[idx] > 0.0 {
125 pos_sum += (lm - pos_max).exp();
126 } else if signs[idx] < 0.0 {
127 neg_sum += (lm - neg_max).exp();
128 }
129 }
130
131 let log_pos = if pos_sum > 0.0 {
132 pos_max + pos_sum.ln()
133 } else {
134 f64::NEG_INFINITY
135 };
136 let log_neg = if neg_sum > 0.0 {
137 neg_max + neg_sum.ln()
138 } else {
139 f64::NEG_INFINITY
140 };
141
142 if log_pos == f64::NEG_INFINITY && log_neg == f64::NEG_INFINITY {
143 return (f64::NEG_INFINITY, 0.0);
149 }
150 if log_neg == f64::NEG_INFINITY {
151 return (log_pos, 1.0);
152 }
153 if log_pos == f64::NEG_INFINITY {
154 return (log_neg, -1.0);
155 }
156 if log_pos > log_neg {
157 let gap = log_pos - log_neg;
158 (log_pos + log1mexp_positive(gap), 1.0)
159 } else if log_neg > log_pos {
160 let gap = log_neg - log_pos;
161 (log_neg + log1mexp_positive(gap), -1.0)
162 } else {
163 (f64::NEG_INFINITY, 0.0)
164 }
165}
166
167#[inline]
174pub fn normal_logcdf(x: f64) -> f64 {
175 if x == f64::INFINITY {
176 return 0.0;
177 }
178 if x == f64::NEG_INFINITY {
179 return f64::NEG_INFINITY;
180 }
181 if x.is_nan() {
182 return f64::NAN;
183 }
184 if x < 0.0 {
185 let u = -x / std::f64::consts::SQRT_2;
186 -u * u + (0.5 * erfcx_nonnegative(u).max(1e-300)).ln()
187 } else {
188 normal_cdf(x).clamp(1e-300, 1.0).ln()
189 }
190}
191
192#[inline]
196pub fn normal_logsf(x: f64) -> f64 {
197 normal_logcdf(-x)
198}
199
200#[inline]
207pub fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
208 if x == f64::INFINITY {
209 return (0.0, 0.0);
210 }
211 if x == f64::NEG_INFINITY {
212 return (f64::NEG_INFINITY, f64::INFINITY);
213 }
214 if x.is_nan() {
215 return (f64::NAN, f64::NAN);
216 }
217 if x < 0.0 {
218 let u = -x / std::f64::consts::SQRT_2;
219 let ex = erfcx_nonnegative(u).max(1e-300);
220 let log_cdf = -u * u + (0.5 * ex).ln();
221 let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
222 (log_cdf, lambda)
223 } else {
224 let cdf = normal_cdf(x).clamp(1e-300, 1.0);
225 let lambda = normal_pdf(x) / cdf;
226 (cdf.ln(), lambda)
227 }
228}
229
230#[inline]
232pub fn standard_normal_quantile(p: f64) -> Result<f64, String> {
233 if !(p.is_finite() && p > 0.0 && p < 1.0) {
234 return Err(format!("normal quantile requires p in (0,1), got {p}"));
235 }
236
237 const A: [f64; 6] = [
238 -3.969_683_028_665_376e1,
239 2.209_460_984_245_205e2,
240 -2.759_285_104_469_687e2,
241 1.383_577_518_672_69e2,
242 -3.066_479_806_614_716e1,
243 2.506_628_277_459_239,
244 ];
245 const B: [f64; 5] = [
246 -5.447_609_879_822_406e1,
247 1.615_858_368_580_409e2,
248 -1.556_989_798_598_866e2,
249 6.680_131_188_771_972e1,
250 -1.328_068_155_288_572e1,
251 ];
252 const C: [f64; 6] = [
253 -7.784_894_002_430_293e-3,
254 -3.223_964_580_411_365e-1,
255 -2.400_758_277_161_838,
256 -2.549_732_539_343_734,
257 4.374_664_141_464_968,
258 2.938_163_982_698_783,
259 ];
260 const D: [f64; 4] = [
261 7.784_695_709_041_462e-3,
262 3.224_671_290_700_398e-1,
263 2.445_134_137_142_996,
264 3.754_408_661_907_416,
265 ];
266 const P_LOW: f64 = 0.02425;
267 const P_HIGH: f64 = 1.0 - P_LOW;
268
269 let mut x = if p < P_LOW {
270 let q = (-2.0 * p.ln()).sqrt();
271 (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
272 / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
273 } else if p <= P_HIGH {
274 let q = p - 0.5;
275 let r = q * q;
276 (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
277 / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
278 } else {
279 let q = (-2.0 * (1.0 - p).ln()).sqrt();
280 -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
281 / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
282 };
283 for _ in 0..2 {
284 let density = normal_pdf(x);
285 if !(density.is_finite() && density > 0.0) {
286 break;
287 }
288 let residual = if x > 0.0 {
298 (1.0 - p) - 0.5 * erfc(x / std::f64::consts::SQRT_2)
299 } else {
300 normal_cdf(x) - p
301 };
302 let correction = residual / density;
303 let denominator = 1.0 + 0.5 * x * correction;
304 if !(correction.is_finite() && denominator.is_finite() && denominator != 0.0) {
305 break;
306 }
307 let step = correction / denominator;
308 if !step.is_finite() {
309 break;
310 }
311 x -= step;
312 if step.abs() <= 2.0 * f64::EPSILON * x.abs().max(1.0) {
313 break;
314 }
315 }
316 Ok(x)
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    // Default absolute tolerance for results expected to agree to near machine precision.
    const TOL: f64 = 1e-12;

    // Relative error |got - expected| / |expected|; the denominator is floored at 1e-300
    // so an expected value of zero does not divide by zero.
    fn rel_err(got: f64, expected: f64) -> f64 {
        (got - expected).abs() / expected.abs().max(1e-300)
    }

    // ---- normal_pdf ----

    #[test]
    fn normal_pdf_at_zero() {
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!((normal_pdf(0.0) - expected).abs() < TOL);
    }

    #[test]
    fn normal_pdf_symmetry() {
        // Exact equality: the pdf only depends on x * x.
        for &x in &[0.5, 1.0, 2.0, 3.0, 5.0] {
            assert_eq!(normal_pdf(x), normal_pdf(-x), "symmetry failed at x={x}");
        }
    }

    #[test]
    fn normal_pdf_positive() {
        for &x in &[-5.0, -1.0, 0.0, 1.0, 5.0] {
            assert!(normal_pdf(x) > 0.0, "pdf should be positive at x={x}");
        }
    }

    // ---- normal_cdf ----

    #[test]
    fn normal_cdf_at_zero_is_half() {
        assert!((normal_cdf(0.0) - 0.5).abs() < TOL);
    }

    #[test]
    fn normal_cdf_symmetry() {
        // Φ(x) + Φ(-x) = 1.
        for &x in &[0.5, 1.0, 2.0, 3.0] {
            let sum = normal_cdf(x) + normal_cdf(-x);
            assert!(
                (sum - 1.0).abs() < TOL,
                "cdf symmetry failed at x={x}: sum={sum}"
            );
        }
    }

    #[test]
    fn normal_cdf_bounds() {
        // Φ(-10) ≈ 7.6e-24, so the left tail must not be flushed to a cancellation artefact.
        assert!(normal_cdf(10.0) > 0.9999);
        assert!(normal_cdf(-10.0) < 1e-22);
        assert!(normal_cdf(0.0) > 0.0);
        assert!(normal_cdf(0.0) < 1.0);
    }

    #[test]
    fn normal_cdf_at_1_96_near_0975() {
        let p = normal_cdf(1.959_963_985);
        assert!((p - 0.975).abs() < 1e-8, "p={p}");
    }

    // ---- erfcx_nonnegative ----

    #[test]
    fn erfcx_at_nonpositive_returns_one() {
        // Non-positive inputs are clamped to erfcx(0) = 1.
        assert_eq!(erfcx_nonnegative(0.0), 1.0);
        assert_eq!(erfcx_nonnegative(-1.0), 1.0);
        assert_eq!(erfcx_nonnegative(-100.0), 1.0);
    }

    #[test]
    fn erfcx_positive_inf_returns_zero() {
        assert_eq!(erfcx_nonnegative(f64::INFINITY), 0.0);
    }

    #[test]
    fn erfcx_negative_inf_returns_inf() {
        assert_eq!(erfcx_nonnegative(f64::NEG_INFINITY), f64::INFINITY);
    }

    #[test]
    fn erfcx_small_positive_matches_direct() {
        // For these x, exp(x²) does not overflow, so the direct product is a usable
        // reference. Below x = 26 this is the same formula the implementation uses,
        // so this mainly guards against regressions in that branch.
        use statrs::function::erf::erfc;
        for &x in &[0.1_f64, 0.5, 1.0, 5.0, 10.0, 25.0] {
            let got = erfcx_nonnegative(x);
            let expected = (x * x).exp() * erfc(x);
            let err = rel_err(got, expected);
            assert!(
                err < 1e-10,
                "x={x}: got={got} expected={expected} rel={err}"
            );
        }
    }

    #[test]
    fn erfcx_large_x_positive_and_finite() {
        // x = 50 takes the asymptotic-series branch.
        let got = erfcx_nonnegative(50.0);
        assert!(got.is_finite() && got > 0.0, "erfcx(50)={got}");
        // Leading term 1/(x√π); the next correction is 1/(2x²) = 2e-4, inside the 1e-3 tolerance.
        let asymptotic = 1.0 / (50.0 * std::f64::consts::PI.sqrt());
        assert!(
            rel_err(got, asymptotic) < 1e-3,
            "got={got} asymptotic={asymptotic}"
        );
    }

    // ---- log1mexp_positive ----

    #[test]
    fn log1mexp_at_zero_is_neg_inf() {
        assert_eq!(log1mexp_positive(0.0), f64::NEG_INFINITY);
    }

    #[test]
    fn log1mexp_recovers_log_one_minus_exp() {
        // Covers both branches of the ln 2 switch, plus the switch point itself.
        for &a in &[0.001_f64, 0.5, std::f64::consts::LN_2, 1.0, 5.0, 20.0] {
            let lm = log1mexp_positive(a);
            let roundtrip = lm.exp() + (-a).exp();
            assert!(
                (roundtrip - 1.0).abs() < 1e-14,
                "a={a}: exp(log1mexp(a)) + exp(-a) = {roundtrip}, expected 1.0"
            );
        }
    }

    #[test]
    fn log1mexp_at_ln2_is_neg_ln2() {
        // ln(1 - 1/2) = -ln 2.
        let ln2 = std::f64::consts::LN_2;
        let got = log1mexp_positive(ln2);
        assert!((got - (-ln2)).abs() < TOL, "got={got}");
    }

    // ---- signed_log_sum_exp ----

    #[test]
    fn slse_all_positive_single() {
        let (lm, sg) = signed_log_sum_exp(&[2.0], &[1.0]);
        assert!((lm - 2.0).abs() < TOL);
        assert!((sg - 1.0).abs() < TOL);
    }

    #[test]
    fn slse_difference_recovers_log2() {
        // 3 - 1 = 2, computed in log space.
        let log3 = 3.0_f64.ln();
        let log1 = 0.0_f64; // ln(1)
        let (lm, sg) = signed_log_sum_exp(&[log3, log1], &[1.0, -1.0]);
        assert!((lm - 2.0_f64.ln()).abs() < TOL, "lm={lm}");
        assert!((sg - 1.0).abs() < TOL, "sg={sg}");
    }

    #[test]
    fn slse_cancellation_gives_neg_inf() {
        // 2 - 2 = 0 exactly: log-magnitude -inf with sign 0.
        let ln2 = 2.0_f64.ln();
        let (lm, sg) = signed_log_sum_exp(&[ln2, ln2], &[1.0, -1.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_empty_returns_neg_inf_with_zero_sign() {
        // An empty sum is zero.
        let (lm, sg) = signed_log_sum_exp(&[], &[]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_all_zero_signs_return_zero_sign() {
        // Zero-sign terms are excluded from the sum.
        let (lm, sg) = signed_log_sum_exp(&[0.0], &[0.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_all_neg_inf_magnitudes_return_zero_sign() {
        // exp(-inf) = 0 on both sides.
        let (lm, sg) = signed_log_sum_exp(&[f64::NEG_INFINITY, f64::NEG_INFINITY], &[1.0, -1.0]);
        assert_eq!(lm, f64::NEG_INFINITY);
        assert_eq!(sg, 0.0);
    }

    #[test]
    fn slse_pos_inf_dominates() {
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[1.0, -1.0]);
        assert_eq!(lm, f64::INFINITY);
        assert_eq!(sg, 1.0);
    }

    #[test]
    fn slse_neg_inf_dominates() {
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, 1.0], &[-1.0, 1.0]);
        assert_eq!(lm, f64::INFINITY);
        assert_eq!(sg, -1.0);
    }

    #[test]
    fn slse_both_inf_signs_gives_nan() {
        // ∞ - ∞ is undefined.
        let (lm, sg) = signed_log_sum_exp(&[f64::INFINITY, f64::INFINITY], &[1.0, -1.0]);
        assert!(lm.is_nan());
        assert_eq!(sg, 0.0);
    }

    // ---- normal_logcdf ----

    #[test]
    fn logcdf_at_zero_is_log_half() {
        let got = normal_logcdf(0.0);
        let expected = 0.5_f64.ln();
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logcdf_pos_inf_is_zero() {
        assert_eq!(normal_logcdf(f64::INFINITY), 0.0);
    }

    #[test]
    fn logcdf_neg_inf_is_neg_inf() {
        assert_eq!(normal_logcdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
    }

    #[test]
    fn logcdf_nan_is_nan() {
        assert!(normal_logcdf(f64::NAN).is_nan());
    }

    #[test]
    fn logcdf_matches_log_cdf_for_moderate_x() {
        // Where Φ(x) is well conditioned, ln(normal_cdf(x)) is a valid reference.
        for &x in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0] {
            let got = normal_logcdf(x);
            let expected = normal_cdf(x).ln();
            assert!(
                (got - expected).abs() < 1e-10,
                "x={x}: got={got} expected={expected}"
            );
        }
    }

    #[test]
    fn logcdf_deep_left_tail_stays_finite() {
        // ln Φ(-20) ≈ -203.9.
        let got = normal_logcdf(-20.0);
        assert!(got.is_finite() && got < -100.0, "logcdf(-20)={got}");
    }

    // ---- normal_logsf ----

    #[test]
    fn logsf_at_zero_is_log_half() {
        let got = normal_logsf(0.0);
        let expected = 0.5_f64.ln();
        assert!((got - expected).abs() < TOL, "got={got}");
    }

    #[test]
    fn logsf_mirrors_logcdf() {
        // Exact equality: normal_logsf is defined as normal_logcdf(-x).
        for &x in &[-3.0_f64, -1.0, 0.0, 1.0, 3.0] {
            assert_eq!(normal_logsf(x), normal_logcdf(-x));
        }
    }

    // ---- signed_probit_logcdf_and_mills_ratio ----

    #[test]
    fn probit_at_pos_inf() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::INFINITY);
        assert_eq!(lc, 0.0);
        assert_eq!(mr, 0.0);
    }

    #[test]
    fn probit_at_neg_inf() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::NEG_INFINITY);
        assert_eq!(lc, f64::NEG_INFINITY);
        assert_eq!(mr, f64::INFINITY);
    }

    #[test]
    fn probit_nan_propagates() {
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(f64::NAN);
        assert!(lc.is_nan() && mr.is_nan());
    }

    #[test]
    fn probit_at_zero_logcdf_and_mills() {
        // φ(0)/Φ(0) = 2/√(2π) = √(2/π) ≈ 0.7978845608.
        let (lc, mr) = signed_probit_logcdf_and_mills_ratio(0.0);
        assert!((lc - 0.5_f64.ln()).abs() < TOL, "lc={lc}");
        assert!((mr - 0.797_884_560_802_865).abs() < 1e-10, "mr={mr}");
    }

    #[test]
    fn probit_positive_branch_matches_logcdf() {
        for &x in &[0.5_f64, 1.0, 2.0, 3.0] {
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            let lc_ref = normal_logcdf(x);
            let mr_ref = normal_pdf(x) / normal_cdf(x);
            assert!(
                (lc - lc_ref).abs() < 1e-10,
                "x={x}: lc={lc} lc_ref={lc_ref}"
            );
            assert!(
                (mr - mr_ref).abs() < 1e-10,
                "x={x}: mr={mr} mr_ref={mr_ref}"
            );
        }
    }

    #[test]
    fn probit_negative_branch_matches_logcdf() {
        // The Mills ratio is only checked for sign and finiteness here.
        for &x in &[-0.5_f64, -1.0, -2.0, -5.0] {
            let (lc, mr) = signed_probit_logcdf_and_mills_ratio(x);
            let lc_ref = normal_logcdf(x);
            assert!(
                (lc - lc_ref).abs() < 1e-10,
                "x={x}: lc={lc} lc_ref={lc_ref}"
            );
            assert!(mr.is_finite() && mr > 0.0, "x={x}: mr={mr}");
        }
    }

    // ---- standard_normal_quantile ----

    #[test]
    fn quantile_rejects_out_of_range() {
        // The open interval excludes both endpoints; NaN is rejected too.
        assert!(standard_normal_quantile(0.0).is_err());
        assert!(standard_normal_quantile(1.0).is_err());
        assert!(standard_normal_quantile(-0.1).is_err());
        assert!(standard_normal_quantile(1.1).is_err());
        assert!(standard_normal_quantile(f64::NAN).is_err());
    }

    #[test]
    fn quantile_at_half_is_near_zero() {
        let q = standard_normal_quantile(0.5).unwrap();
        assert!(q.abs() < 1e-10, "quantile(0.5)={q}");
    }

    #[test]
    fn quantile_at_0975_is_near_196() {
        let q = standard_normal_quantile(0.975).unwrap();
        assert!((q - 1.959_963_985).abs() < 1e-7, "q={q}");
    }

    #[test]
    fn quantile_antisymmetry() {
        // Φ⁻¹(p) = -Φ⁻¹(1 - p).
        let q_lo = standard_normal_quantile(0.1).unwrap();
        let q_hi = standard_normal_quantile(0.9).unwrap();
        assert!((q_lo + q_hi).abs() < 1e-10, "q_lo={q_lo} q_hi={q_hi}");
    }

    #[test]
    fn quantile_roundtrip_cdf() {
        // Spans both tail regions (p < 0.02425, p > 0.97575) and the central region.
        for &p in &[
            0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999,
        ] {
            let q = standard_normal_quantile(p).unwrap();
            let p_back = normal_cdf(q);
            assert!(
                (p_back - p).abs() < 1e-10,
                "roundtrip failed at p={p}: q={q} p_back={p_back}"
            );
        }
    }
}