gam_gpu/
numerics_device.rs1pub const PROBIT_NUMERICS_CU: &str = r#"
// -------- shared probit numerics -----------------------------------------
// All math is in double precision. Do not compile with --use_fast_math.
//
// `log_ndtr(x)` = log Φ(x). For x < 0 it uses the erfcx representation
//   log Φ(x) = -u² + log(½ · erfcx(u)),  u = -x / √2,
// which preserves digits all the way into the deep left tail (matching
// the CPU `normal_logcdf`). For x ≥ 0 it uses log1p(-½·erfc(x/√2)).
//
// `log_ndtr_and_mills(x, *log_cdf, *lambda)` returns both log Φ(x) and the
// Mills ratio φ(x)/Φ(x) in a single pass. For x < 0 the erfcx path keeps
// the ratio stable even when Φ(x) underflows to zero.
28
29#ifndef PROBIT_NUMERICS_INCLUDED
30#define PROBIT_NUMERICS_INCLUDED
31
32#define INV_SQRT_2PI 0.3989422804014327
33#define SQRT_2 1.4142135623730951
34
// erfcx_nonnegative(x) = exp(x²)·erfc(x), the scaled complementary error
// function, restricted to the domain x >= 0.
//
// Delegates to the CUDA math library's double-precision erfcx(), which is
// accurate to a few ulps over the whole positive axis. Forming exp(x*x)
// explicitly would amplify the rounding error of x*x by up to ~x² ulps, and
// a truncated asymptotic series is only accurate for large x.
//
// Special values:
//   NaN           -> NaN (propagated)
//   +inf          -> 0
//   -inf          -> +inf
//   finite x <= 0 -> 1.0 (argument clamped to the nonnegative domain; the
//                    true erfcx(x) exceeds 1 for x < 0, so callers must not
//                    rely on this function for negative arguments)
extern "C" __device__ __forceinline__ double erfcx_nonnegative(double x) {
    const double inf = 1.0 / 0.0;
    if (isnan(x)) return x;
    if (x == inf) return 0.0;
    if (x == -inf) return inf;
    if (x <= 0.0) return 1.0;
    return erfcx(x);
}
56
// log_ndtr(x) = log Φ(x), where Φ is the standard normal CDF.
//
//   x < 0 : log Φ(x) = -u² + log(½·erfcx(u)),  u = -x/√2.
//           Stays accurate deep in the left tail, where Φ(x) itself would
//           underflow to zero.
//   x ≥ 0 : log Φ(x) = log1p(-Q(x)),  Q(x) = ½·erfc(x/√2) ∈ (0, ½] is the
//           upper-tail probability. log1p keeps full relative precision for
//           large x, where Φ(x) rounds to 1 and log(Φ(x)) would collapse to 0.
//
// Special values: +inf -> 0, -inf -> -inf, NaN -> NaN.
extern "C" __device__ __forceinline__ double log_ndtr(double x) {
    const double inf = 1.0 / 0.0;
    if (x == inf) return 0.0;
    if (x == -inf) return -inf;
    if (isnan(x)) return x;
    if (x < 0.0) {
        double u = -x / SQRT_2;
        double ex = erfcx_nonnegative(u);
        // Guard against log(0); erfcx(u) ~ 1/(u·√π) only drops this low for
        // |x| on the order of 1e300.
        if (ex < 1e-300) ex = 1e-300;
        return -u * u + log(0.5 * ex);
    }
    // x >= 0 (including -0.0): Φ(x) = 1 - Q(x) with Q(x) <= ½.
    double q = 0.5 * erfc(x / SQRT_2);
    return log1p(-q);
}
73
// log_ndtr_and_mills(x, log_cdf, lambda): writes log Φ(x) to *log_cdf and
// the ratio λ(x) = φ(x)/Φ(x) to *lambda in a single pass.
//
//   x < 0 : with u = -x/√2 and ex = erfcx(u),
//             log Φ(x) = -u² + log(½·ex),   λ(x) = √(2/π) / ex,
//           both stable even when Φ(x) underflows to zero.
//   x ≥ 0 : with Q(x) = ½·erfc(x/√2) ∈ (0, ½] the upper-tail probability,
//             log Φ(x) = log1p(-Q),          λ(x) = φ(x) / (1 - Q).
//           log1p keeps full relative precision where Φ(x) rounds to 1.
//
// Special values: +inf -> (0, 0); -inf -> (-inf, +inf); NaN -> (NaN, NaN).
// Both output pointers are written unconditionally and must be valid.
extern "C" __device__ __forceinline__ void
log_ndtr_and_mills(double x, double *log_cdf, double *lambda) {
    const double inf = 1.0 / 0.0;
    if (x == inf) { *log_cdf = 0.0; *lambda = 0.0; return; }
    if (x == -inf) { *log_cdf = -inf; *lambda = inf; return; }
    if (isnan(x)) { *log_cdf = x; *lambda = x; return; }
    if (x < 0.0) {
        double u = -x / SQRT_2;
        double ex = erfcx_nonnegative(u);
        // Guard against log(0) / division by 0 for |x| on the order of 1e300.
        if (ex < 1e-300) ex = 1e-300;
        *log_cdf = -u * u + log(0.5 * ex);
        const double sqrt_2_over_pi = 0.7978845608028654; // √(2/π)
        *lambda = sqrt_2_over_pi / ex;
        return;
    }
    // x >= 0 (including -0.0): Φ(x) = 1 - Q(x) ∈ [½, 1], so no clamping is needed.
    double q = 0.5 * erfc(x / SQRT_2);
    double pdf = INV_SQRT_2PI * exp(-0.5 * x * x);
    *log_cdf = log1p(-q);
    *lambda = pdf / (1.0 - q);
}
96
97#endif // PROBIT_NUMERICS_INCLUDED
98"#;