/// Complementary error function, `erfc(x) = 1 - erf(x)`.
///
/// Thin pass-through to `libm::erfc`; the other routines in this file call
/// this wrapper rather than `libm` directly.
///
/// NOTE(review): presumably routed through `libm` (instead of the platform
/// math library) so host results stay reproducible against the device
/// kernel — confirm.
pub fn erfc(x: f64) -> f64 {
libm::erfc(x)
}
// The literals below are compared bit-for-bit against the kernel source by
// `host_constants_match_kernel_source_bit_for_bit`; do not re-round them.

/// `1 / sqrt(2π)` — scale factor used for the standard normal density.
pub const INV_SQRT_2PI: f64 = 0.3989422804014327;
/// `sqrt(2)` — converts a normal deviate `x` into the `erfc` argument `x / sqrt(2)`.
pub const SQRT_2: f64 = 1.4142135623730951;
/// `1 / sqrt(π)` — leading factor of the asymptotic `erfcx` series.
pub const INV_SQRT_PI: f64 = 0.5641895835477563;
/// `sqrt(2 / π)` — numerator of the density/CDF ratio when expressed through `erfcx`.
pub const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
/// Scaled complementary error function `exp(x²) · erfc(x)`, meant for `x >= 0`.
///
/// Evaluation is split by range:
/// - non-finite `x`: `0.0` for `+∞`, `+∞` for everything else (`-∞` and NaN);
/// - finite `x <= 0`: `1.0` (the true value only at `x == 0`; negative
///   arguments are outside the intended domain);
/// - `0 < x < 26`: the direct product `exp(x²) · erfc(x)`;
/// - `x >= 26`: the five-term asymptotic series
///   `(1 / (x·√π)) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`.
pub fn erfcx_nonnegative(x: f64) -> f64 {
    if !x.is_finite() {
        // Only +∞ maps to zero; -∞ and NaN both fall through to +∞.
        return if x == f64::INFINITY { 0.0 } else { f64::INFINITY };
    }
    if x <= 0.0 {
        return 1.0;
    }
    if x >= 26.0 {
        // Asymptotic branch. Terms are accumulated strictly left to right so
        // the rounding sequence is fixed.
        let recip = 1.0 / x;
        let recip_sq = recip * recip;
        let term1 = 0.5 * recip_sq;
        let term2 = 0.75 * recip_sq * recip_sq;
        let term3 = 1.875 * recip_sq * recip_sq * recip_sq;
        let term4 = 6.5625 * recip_sq * recip_sq * recip_sq * recip_sq;
        let series = 1.0 - term1 + term2 - term3 + term4;
        return recip * series * INV_SQRT_PI;
    }
    // Direct branch. The 700 cap never binds here (26² = 676); it looks like
    // a guard against `exp` overflow should the branch threshold ever move —
    // TODO confirm intent.
    let exponent = (x * x).min(700.0);
    libm::exp(exponent) * erfc(x)
}
pub fn log_ndtr(x: f64) -> f64 {
if x == f64::INFINITY {
return 0.0;
}
if x == f64::NEG_INFINITY {
return f64::NEG_INFINITY;
}
if x.is_nan() {
return x;
}
if x < 0.0 {
let u = -x / SQRT_2;
let mut ex = erfcx_nonnegative(u);
if ex < 1e-300 {
ex = 1e-300;
}
-u * u + libm::log(0.5 * ex)
} else {
let mut c = 0.5 * erfc(-x / SQRT_2);
if c < 1e-300 {
c = 1e-300;
}
if c > 1.0 {
c = 1.0;
}
libm::log(c)
}
}
pub fn log_ndtr_and_mills(x: f64) -> (f64, f64) {
if x == f64::INFINITY {
return (0.0, 0.0);
}
if x == f64::NEG_INFINITY {
return (f64::NEG_INFINITY, f64::INFINITY);
}
if x.is_nan() {
return (x, x);
}
if x < 0.0 {
let u = -x / SQRT_2;
let mut ex = erfcx_nonnegative(u);
if ex < 1e-300 {
ex = 1e-300;
}
let log_cdf = -u * u + libm::log(0.5 * ex);
let lambda = SQRT_2_OVER_PI / ex;
(log_cdf, lambda)
} else {
let mut cdf = 0.5 * erfc(-x / SQRT_2);
if cdf < 1e-300 {
cdf = 1e-300;
}
if cdf > 1.0 {
cdf = 1.0;
}
let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
let log_cdf = libm::log(cdf);
let lambda = pdf / cdf;
(log_cdf, lambda)
}
}
#[cfg(test)]
mod probit_parity_tests {
use super::*;
use crate::numerics_device::PROBIT_NUMERICS_CU;
/// Machine epsilon for `f64`; the unit in which `ulp` reports error.
const EPS: f64 = f64::EPSILON;
/// Error of `got` against `want`, expressed in multiples of `EPS`.
///
/// The error is relative to `|want|`, except when `want` is exactly zero,
/// where the absolute error is used instead. This is an epsilon-scaled error
/// measure, not a count of representable values between the two numbers.
fn ulp(got: f64, want: f64) -> f64 {
    let abs_err = (got - want).abs();
    let scale = if want == 0.0 { EPS } else { EPS * want.abs() };
    abs_err / scale
}
/// Parses the first numeric literal that follows the marker `needle` in `src`.
///
/// When `needle` ends in an identifier character (for example
/// `#define SQRT_2`), an occurrence only counts if the next character of
/// `src` is *not* an identifier character. Without that boundary check the
/// marker would also match as a prefix of a longer name such as
/// `SQRT_2_OVER_PI` and the literal of the wrong definition would be read.
/// Markers ending in punctuation (for example `inv_sqrt_pi =`) match as-is.
///
/// The literal starts at the first `-`, `.` or ASCII digit after the marker
/// and runs while characters are digits or one of `. e E + -`.
///
/// # Panics
///
/// Panics if no acceptable occurrence of `needle` exists, if nothing
/// numeric follows it, or if the extracted text does not parse as `f64`.
fn literal_after(src: &str, needle: &str) -> f64 {
    fn is_ident(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    // Only markers that end inside an identifier need a boundary after them.
    let needs_boundary = matches!(needle.chars().next_back(), Some(c) if is_ident(c));
    let start = src
        .match_indices(needle)
        .map(|(at, _)| at + needle.len())
        .find(|&end| {
            !needs_boundary || !matches!(src[end..].chars().next(), Some(c) if is_ident(c))
        })
        .unwrap_or_else(|| panic!("kernel source is missing marker {needle:?}"));
    let tail = &src[start..];
    let num_start = tail
        .find(|c: char| c == '-' || c == '.' || c.is_ascii_digit())
        .unwrap_or_else(|| panic!("no numeric literal follows {needle:?}"));
    let rest = &tail[num_start..];
    let end = rest
        .find(|c: char| !(c.is_ascii_digit() || matches!(c, '.' | 'e' | 'E' | '+' | '-')))
        .unwrap_or(rest.len());
    rest[..end]
        .parse::<f64>()
        .unwrap_or_else(|e| panic!("failed to parse literal after {needle:?}: {e}"))
}
/// Every host constant must equal, bit for bit, the literal found after its
/// marker in the kernel source text.
#[test]
fn host_constants_match_kernel_source_bit_for_bit() {
    let expected: [(&str, f64); 4] = [
        ("#define INV_SQRT_2PI", INV_SQRT_2PI),
        ("#define SQRT_2", SQRT_2),
        ("inv_sqrt_pi =", INV_SQRT_PI),
        ("sqrt_2_over_pi =", SQRT_2_OVER_PI),
    ];
    for &(needle, host) in &expected {
        let device = literal_after(PROBIT_NUMERICS_CU, needle);
        let (device_bits, host_bits) = (device.to_bits(), host.to_bits());
        assert_eq!(
            device_bits, host_bits,
            "constant {needle:?} drifted: kernel={device:?} host={host:?}"
        );
    }
}
/// The kernel source must contain the double-precision `erfc(`, `exp(` and
/// `log(` calls and none of the listed fast-math / single-precision spellings.
#[test]
fn kernel_source_uses_msun_transcendentals_only() {
    const REQUIRED: &[&str] = &["erfc(", "exp(", "log("];
    const FORBIDDEN: &[&str] = &[
        "__expf",
        "__logf",
        "expf(",
        "logf(",
        "erfcf(",
        "__fdividef",
        "__frcp",
        "use_fast_math",
        "ffast-math",
        "__dmul_",
        "__dadd_",
        "__fmaf",
    ];

    for &good in REQUIRED {
        let present = PROBIT_NUMERICS_CU.contains(good);
        assert!(present, "kernel source should call msun `{good}`");
    }
    for &bad in FORBIDDEN {
        let present = PROBIT_NUMERICS_CU.contains(bad);
        assert!(
            !present,
            "kernel source must not use fast-math / single-precision `{bad}`"
        );
    }
}
/// `erfc(0)` is exactly 1, and `erfc(-x)` agrees with `2 - erfc(x)` to within
/// 2 epsilon-units on the grid `x = 0.00, 0.01, …, 2.99`.
#[test]
fn erfc_boundary_and_symmetry() {
    assert_eq!(erfc(0.0), 1.0);
    let worst = (0..300_i32)
        .map(|i| f64::from(i) * 0.01)
        .map(|x| ulp(erfc(-x), 2.0 - erfc(x)))
        .fold(0.0_f64, f64::max);
    assert!(worst <= 2.0, "erfc symmetry drift {worst:.3} ULP > 2");
}
/// Checks the fixed special values of `erfcx_nonnegative`, then verifies that
/// `erfcx(x) · exp(-x²)` reproduces `erfc(x)` to within 4 epsilon-units while
/// stepping `x` from 0.1 in increments of 0.1 below 25.
#[test]
fn erfcx_matches_definition() {
    let special_cases = [
        (0.0, 1.0),
        (-3.0, 1.0),
        (f64::INFINITY, 0.0),
        (f64::NEG_INFINITY, f64::INFINITY),
    ];
    for (input, expected) in special_cases {
        assert_eq!(erfcx_nonnegative(input), expected);
    }

    // The grid is built by repeated addition, so each point carries the
    // accumulated rounding of the previous ones.
    let worst = std::iter::successors(Some(0.1_f64), |x| Some(x + 0.1))
        .take_while(|&x| x < 25.0)
        .map(|x| ulp(erfcx_nonnegative(x) * libm::exp(-x * x), erfc(x)))
        .fold(0.0_f64, f64::max);
    assert!(worst <= 4.0, "erfcx definition drift {worst:.3} ULP > 4");
}
/// `log_ndtr` must hit its special values exactly, track `ln(0.5·erfc(-x/√2))`
/// to within 2 epsilon-units on `[-3, 3]`, and satisfy `Φ(x) + Φ(-x) = 1` to
/// within `4e-16` on `[0, 5.9]`.
#[test]
fn log_ndtr_matches_log_cdf_and_reflects() {
    // Exact special values.
    assert_eq!(log_ndtr(0.0), libm::log(0.5));
    assert_eq!(log_ndtr(f64::INFINITY), 0.0);
    assert_eq!(log_ndtr(f64::NEG_INFINITY), f64::NEG_INFINITY);
    assert!(log_ndtr(f64::NAN).is_nan());

    // Bulk agreement with the directly computed log-CDF.
    let worst_bulk = (-30..=30_i32)
        .map(|i| f64::from(i) * 0.1)
        .map(|x| ulp(log_ndtr(x), libm::log(0.5 * erfc(-x / SQRT_2))))
        .fold(0.0_f64, f64::max);
    assert!(
        worst_bulk <= 2.0,
        "log_ndtr vs log-cdf drift {worst_bulk:.3} ULP > 2"
    );

    // Reflection identity after exponentiating both halves.
    let worst_refl = (0..60_i32)
        .map(|i| f64::from(i) * 0.1)
        .map(|x| (libm::exp(log_ndtr(x)) + libm::exp(log_ndtr(-x)) - 1.0).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        worst_refl <= 4e-16,
        "Φ(x)+Φ(-x) reflection drift {worst_refl:e} > 4e-16"
    );
}
/// The joint routine must return the same log-CDF bits as `log_ndtr`, satisfy
/// `lambda · Φ(x) = φ(x)` to within 32 epsilon-units on `[-5, 5]`, stay finite
/// with `lambda > 0.9·|x|` in the deep lower tail, and return the fixed pairs
/// at `±∞`.
#[test]
fn log_ndtr_and_mills_identity_and_deep_tail() {
    for x in (-50..=50_i32).map(|i| f64::from(i) * 0.1) {
        let (log_cdf, lambda) = log_ndtr_and_mills(x);
        assert_eq!(
            log_cdf.to_bits(),
            log_ndtr(x).to_bits(),
            "joint log-CDF channel diverged from log_ndtr at x={x}"
        );
        let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
        let drift = ulp(lambda * libm::exp(log_cdf), pdf);
        assert!(
            drift <= 32.0,
            "Mills identity drift {:.3} ULP > 32 at x={x}",
            drift
        );
    }

    for x in [-10.0, -20.0, -30.0, -38.0] {
        let (log_cdf, lambda) = log_ndtr_and_mills(x);
        let log_ok = log_cdf.is_finite() && log_cdf < 0.0;
        assert!(log_ok, "deep-tail log Φ({x}) not finite-negative: {log_cdf}");
        let lambda_ok = lambda.is_finite() && lambda > x.abs() * 0.9;
        assert!(lambda_ok, "deep-tail Mills({x}) should track |x|: {lambda}");
    }

    let at_pos_inf = log_ndtr_and_mills(f64::INFINITY);
    assert_eq!(at_pos_inf, (0.0, 0.0));
    let at_neg_inf = log_ndtr_and_mills(f64::NEG_INFINITY);
    assert_eq!(at_neg_inf, (f64::NEG_INFINITY, f64::INFINITY));
}
}