Skip to main content

gam_gpu/
numerics_device.rs

//! Shared device-side probit numerics for NVRTC kernels.
//!
//! [`PROBIT_NUMERICS_CU`] is prepended to every NVRTC kernel source that needs
//! stable probit / normal-CDF arithmetic.  Keeping a single copy here means a
//! numerics fix is a one-line change instead of a coordination problem across
//! multiple kernel source strings.
//!
//! Device functions provided (all `__device__ __forceinline__`, double precision):
//!   - `erfcx_nonnegative(x)` — scaled complementary error function
//!     exp(x²)·erfc(x), for x ≥ 0
//!   - `log_ndtr(x)` — log Φ(x), numerically stable in the deep left tail
//!   - `log_ndtr_and_mills(x, *log_cdf, *lambda)` — writes log Φ(x) and
//!     φ(x)/Φ(x) through the two out-pointers in a single pass
12
/// Device-side probit numerics injected at the top of every NVRTC kernel that
/// needs them.  Prepend this string to a kernel-specific body before passing to
/// `cudarc::nvrtc::compile_ptx` or `PtxModuleCache::get_or_compile`.
///
/// The CUDA source is wrapped in a `PROBIT_NUMERICS_INCLUDED` include guard, so
/// prepending it more than once to the same translation unit is harmless.  It
/// defines the object-like macros `INV_SQRT_2PI` and `SQRT_2` and the
/// `extern "C"` device functions `erfcx_nonnegative`, `log_ndtr` and
/// `log_ndtr_and_mills`; kernel bodies should not define those names themselves.
pub const PROBIT_NUMERICS_CU: &str = r#"
// -------- shared probit numerics -----------------------------------------
// All math is double precision.  FMA contraction is disabled at compile time
// (`--fmad=false`, passed by `device_cache`), and the source deliberately uses
// no fast-math or single-precision intrinsics; the numerics_host tests guard
// this.
//
// Notation: Φ = standard normal CDF, φ = standard normal PDF,
//           erfcx(u) = exp(u²)·erfc(u).
//
// `erfcx_nonnegative(u)` = erfcx(u) for u ≥ 0 (finite negative u is clamped
// to erfcx(0) = 1).  It delegates to the CUDA math library's double-precision
// `erfcx`, which avoids forming exp(u²)·erfc(u) directly: that product
// amplifies the rounding error of u² by a factor of roughly u².
//
// `log_ndtr(x)` = log Φ(x).
//   x < 0 : log Φ(x) = -u² + log(½·erfcx(u)),  u = -x/√2
//           keeps relative precision deep into the left tail, where Φ(x)
//           itself underflows (matches the CPU `normal_logcdf`).
//   x ≥ 0 : log Φ(x) = log1p(-½·erfc(x/√2))
//           keeps relative precision in the right tail, where Φ(x) rounds
//           to 1 and a plain log(Φ(x)) would return exactly 0 for large x.
//
// `log_ndtr_and_mills(x, *log_cdf, *lambda)` writes log Φ(x) and the
// (inverse) Mills ratio λ = φ(x)/Φ(x) in a single pass.  For x < 0,
// λ = √(2/π) / erfcx(u), which stays finite even when Φ(x) underflows.

#ifndef PROBIT_NUMERICS_INCLUDED
#define PROBIT_NUMERICS_INCLUDED

#define INV_SQRT_2PI 0.3989422804014327
#define SQRT_2       1.4142135623730951

// erfcx(x) = exp(x²)·erfc(x), intended for x ≥ 0.
// Special values: NaN -> NaN, +inf -> 0, -inf -> +inf, any other x ≤ 0 -> 1.
extern "C" __device__ __forceinline__ double erfcx_nonnegative(double x) {
    if (isnan(x)) return x;                      // propagate NaN
    if (x == -(1.0 / 0.0)) return (1.0 / 0.0);   // erfcx(-inf) = +inf
    if (x <= 0.0) return 1.0;                    // clamp to erfcx(0)
    // CUDA math-library erfcx handles +inf (-> +0) and large x directly,
    // without an explicit exp(x²) whose argument error grows like x².
    return erfcx(x);
}

// log Φ(x).  +inf -> 0, -inf -> -inf, NaN -> NaN.
extern "C" __device__ __forceinline__ double log_ndtr(double x) {
    if (x ==  (1.0 / 0.0)) return 0.0;
    if (x == -(1.0 / 0.0)) return -(1.0 / 0.0);
    if (isnan(x)) return x;
    if (x < 0.0) {
        // Left tail: Φ(x) = ½·erfcx(u)·exp(-u²), u = -x/√2 > 0.
        double u  = -x / SQRT_2;
        double ex = erfcx_nonnegative(u);
        if (ex < 1e-300) ex = 1e-300;   // guard log(0) for extreme |x|
        return -u * u + log(0.5 * ex);
    }
    // Right half: Φ(x) = 1 - q with q = ½·erfc(x/√2) ∈ [0, ½].  log1p(-q)
    // keeps the tiny negative result instead of rounding Φ(x) to 1 first.
    double q = 0.5 * erfc(x / SQRT_2);
    return log1p(-q);
}

// Writes log Φ(x) to *log_cdf and λ = φ(x)/Φ(x) to *lambda.
// Special values: +inf -> (0, 0); -inf -> (-inf, +inf); NaN -> (NaN, NaN).
// Both pointers are dereferenced unconditionally and must be writable.
extern "C" __device__ __forceinline__ void
log_ndtr_and_mills(double x, double *log_cdf, double *lambda) {
    if (x ==  (1.0 / 0.0)) { *log_cdf = 0.0;          *lambda = 0.0;          return; }
    if (x == -(1.0 / 0.0)) { *log_cdf = -(1.0 / 0.0); *lambda = (1.0 / 0.0);  return; }
    if (isnan(x))          { *log_cdf = x;            *lambda = x;            return; }
    if (x < 0.0) {
        // Left tail via erfcx, u = -x/√2 > 0:
        //   log Φ(x) = -u² + log(½·erfcx(u)),   φ(x)/Φ(x) = √(2/π) / erfcx(u).
        const double sqrt_2_over_pi = 0.7978845608028654; // √(2/π)
        double u  = -x / SQRT_2;
        double ex = erfcx_nonnegative(u);
        if (ex < 1e-300) ex = 1e-300;   // guard log(0) and division for extreme |x|
        *log_cdf = -u * u + log(0.5 * ex);
        *lambda  = sqrt_2_over_pi / ex;
        return;
    }
    // Right half: Φ(x) = 1 - q, q = ½·erfc(x/√2) ∈ [0, ½], so Φ(x) ∈ [½, 1]
    // and the division below is always well-conditioned.
    double q   = 0.5 * erfc(x / SQRT_2);
    double pdf = INV_SQRT_2PI * exp(-0.5 * x * x);
    *log_cdf = log1p(-q);
    *lambda  = pdf / (1.0 - q);
}

#endif // PROBIT_NUMERICS_INCLUDED
"#;