gam_gpu/numerics_host.rs
1//! Host-side scalar special functions shared by the CPU parity references of
2//! the GPU backends.
3//!
4//! The CUDA kernels emit their own NVRTC-visible numerics (see
5//! [`crate::numerics_device`]); this module is the matching **host** side
6//! used by the CPU parity oracles (`bms_flex_row`'s test oracle) and the
7//! CPU reference path (`pirls_row`'s probit CDF). Keeping a single definition
8//! here means the host `erfc` cannot drift between backends.
9
10/// Complementary error function `erfc(x) = 1 − erf(x)` evaluated on the host.
11///
12/// Routes to `libm::erfc`, the SunOS msun double-precision implementation
13/// (accurate to within ~1 ulp across the entire real line). The CUDA kernel
14/// side calls device `erfc`, which is itself msun-derived, so the host CPU
15/// reference matches the device path to within a ULP. The previous
16/// branchless Cody 1969 Chebyshev rational here was only ~1.2e-7 accurate
17/// in relative terms; that ate seven digits of every probit `Mills =
18/// φ/Φ = pdf / (½·erfc(-x/√2))` evaluation and made any sufficiently
19/// tight finite-difference probe of `∂neglog/∂e = -w·s·Mills` (which the
20/// analytic side computes from this same `cdf`, while the FD side
21/// differences `log cdf` and cancels the erfc bias) break against itself
22/// at the ~2e-7 floor instead of the genuine 5-point-stencil truncation
23/// floor near 1e-12.
24pub fn erfc(x: f64) -> f64 {
25 libm::erfc(x)
26}
27
28// ── Host oracle for the shared device probit numerics (issue #1175) ──────────
29//
30// The functions below are the CPU-side, device-free mirror of the CUDA source
31// in [`crate::numerics_device::PROBIT_NUMERICS_CU`]. They are written
32// LINE-FOR-LINE against that kernel source — the SAME branch structure, the
33// SAME clamps (`1e-300`, `[0,1]`), the SAME asymptotic `erfcx` polynomial, and
34// the SAME four constants — differing only in that they call the host `libm`
35// transcendentals (`erfc`/`exp`/`log`) where the kernel calls the device
36// `erfc`/`exp`/`log`. Both sides are the SunOS *msun* double-precision
37// implementations, so the host oracle matches the device to within ~1 ULP per
38// transcendental (issue #1175 items 4–5). This mirrors the #1017
39// `emulate_certified_encode_row` pattern: a CPU emulator that is BOTH the
40// fallback and the exactness oracle a device launch is pinned to.
41//
42// Correctness *without a GPU* (CPU-verifiable): the test harness below asserts
43// (a) these constants are bit-identical to the literals in the kernel source
44// (the "constants cannot drift" lock, #1175 item 4), (b) the kernel source uses
45// only msun transcendentals and no fast-math intrinsics (transcendental-parity
46// intent), and (c) the host oracle satisfies the defining probit identities to
47// a stated ULP bound. Confirming a *device launch* reproduces this oracle to
48// round-off still needs CUDA hardware.
49
/// Reciprocal of `√(2π)`. Must stay bit-identical to the `INV_SQRT_2PI`
/// literal in the kernel source.
pub const INV_SQRT_2PI: f64 = 0.3989422804014327_f64;
/// Square root of two. Must stay bit-identical to the `SQRT_2` literal in the
/// kernel source.
pub const SQRT_2: f64 = 1.4142135623730951_f64;
/// Reciprocal of `√π`. Must stay bit-identical to the `inv_sqrt_pi` literal in
/// the kernel source.
pub const INV_SQRT_PI: f64 = 0.5641895835477563_f64;
/// `√(2/π)`. Must stay bit-identical to the `sqrt_2_over_pi` literal in the
/// kernel source.
pub const SQRT_2_OVER_PI: f64 = 0.7978845608028654_f64;
58
59/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)` for `x ≥ 0`,
60/// the host oracle for the device `erfcx_nonnegative`. Returns `1.0` for
61/// `x ≤ 0`, `0.0` at `+∞`, `+∞` at `−∞`. For `0 < x < 26` evaluates the direct
62/// `exp(min(x², 700))·erfc(x)` form; beyond that (where `exp(x²)` would
63/// overflow) it switches to the same 4-term asymptotic expansion as the kernel.
64pub fn erfcx_nonnegative(x: f64) -> f64 {
65 if !x.is_finite() {
66 return if x > 0.0 { 0.0 } else { f64::INFINITY };
67 }
68 if x <= 0.0 {
69 return 1.0;
70 }
71 if x < 26.0 {
72 let mut xx = x * x;
73 if xx > 700.0 {
74 xx = 700.0;
75 }
76 return libm::exp(xx) * erfc(x);
77 }
78 let inv = 1.0 / x;
79 let inv2 = inv * inv;
80 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
81 + 6.5625 * inv2 * inv2 * inv2 * inv2;
82 inv * poly * INV_SQRT_PI
83}
84
85/// `log Φ(x)` for the standard normal CDF, the host oracle for the device
86/// `log_ndtr`. For `x < 0` uses the `erfcx` representation
87/// `log Φ(x) = −u² + log(½·erfcx(u))`, `u = −x/√2`, keeping digits into the
88/// deep left tail; for `x ≥ 0` uses `log(½·erfc(−x/√2))` with the same clamps as
89/// the kernel. Propagates `±∞`/`NaN` exactly as the device path does.
90pub fn log_ndtr(x: f64) -> f64 {
91 if x == f64::INFINITY {
92 return 0.0;
93 }
94 if x == f64::NEG_INFINITY {
95 return f64::NEG_INFINITY;
96 }
97 if x.is_nan() {
98 return x;
99 }
100 if x < 0.0 {
101 let u = -x / SQRT_2;
102 let mut ex = erfcx_nonnegative(u);
103 if ex < 1e-300 {
104 ex = 1e-300;
105 }
106 -u * u + libm::log(0.5 * ex)
107 } else {
108 let mut c = 0.5 * erfc(-x / SQRT_2);
109 if c < 1e-300 {
110 c = 1e-300;
111 }
112 if c > 1.0 {
113 c = 1.0;
114 }
115 libm::log(c)
116 }
117}
118
119/// Joint `(log Φ(x), Mills ratio φ(x)/Φ(x))`, the host oracle for the device
120/// `log_ndtr_and_mills`. The `x < 0` branch computes the Mills ratio as
121/// `√(2/π)/erfcx(u)`, which stays finite even when `Φ(x)` underflows; the
122/// `x ≥ 0` branch forms `pdf/cdf` directly. Boundary values mirror the kernel:
123/// `(+0, +0)` at `+∞`, `(−∞, +∞)` at `−∞`, `(NaN, NaN)` at `NaN`.
124pub fn log_ndtr_and_mills(x: f64) -> (f64, f64) {
125 if x == f64::INFINITY {
126 return (0.0, 0.0);
127 }
128 if x == f64::NEG_INFINITY {
129 return (f64::NEG_INFINITY, f64::INFINITY);
130 }
131 if x.is_nan() {
132 return (x, x);
133 }
134 if x < 0.0 {
135 let u = -x / SQRT_2;
136 let mut ex = erfcx_nonnegative(u);
137 if ex < 1e-300 {
138 ex = 1e-300;
139 }
140 let log_cdf = -u * u + libm::log(0.5 * ex);
141 let lambda = SQRT_2_OVER_PI / ex;
142 (log_cdf, lambda)
143 } else {
144 let mut cdf = 0.5 * erfc(-x / SQRT_2);
145 if cdf < 1e-300 {
146 cdf = 1e-300;
147 }
148 if cdf > 1.0 {
149 cdf = 1.0;
150 }
151 let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
152 let log_cdf = libm::log(cdf);
153 let lambda = pdf / cdf;
154 (log_cdf, lambda)
155 }
156}
157
#[cfg(test)]
mod probit_parity_tests {
    //! GPU-free parity harness for the shared probit numerics (issue #1175).
    //!
    //! Three kinds of checks live here, all runnable on the CPU alone:
    //! the host constants are pinned to the literals in the kernel source,
    //! the kernel source text is audited for msun-only transcendentals (no
    //! fast-math), and the host oracle is checked against the defining probit
    //! identities within stated ULP bounds. Whether a *device* launch
    //! reproduces this oracle to round-off still needs CUDA hardware and is
    //! asserted by the on-device parity gates.
    use super::*;
    use crate::numerics_device::PROBIT_NUMERICS_CU;

    const EPS: f64 = f64::EPSILON; // 2.220446049250313e-16

    /// Error of `got` against `want`, measured in units of `EPS·|want|`.
    /// When `want` is exactly zero the absolute error is measured in units of
    /// `EPS` instead.
    fn ulp(got: f64, want: f64) -> f64 {
        let abs_err = (got - want).abs();
        let unit = if want == 0.0 { EPS } else { EPS * want.abs() };
        abs_err / unit
    }

    /// Parses the first floating-point literal that follows `needle` in `src`.
    ///
    /// Panics if the marker is absent, if no numeric-looking character follows
    /// it, or if the extracted text does not parse as an `f64`.
    fn literal_after(src: &str, needle: &str) -> f64 {
        let marker_at = match src.find(needle) {
            Some(pos) => pos,
            None => panic!("kernel source is missing marker {needle:?}"),
        };
        let tail = &src[marker_at + needle.len()..];

        // Step over whatever separates the marker from the number
        // ('=', whitespace, ...) by locating the first sign, dot or digit.
        let opens_number = |c: char| matches!(c, '-' | '.' | '0'..='9');
        let rest = match tail.find(opens_number) {
            Some(pos) => &tail[pos..],
            None => panic!("no numeric literal follows {needle:?}"),
        };

        // The literal runs until the first character that cannot be part of a
        // decimal/exponent float spelling.
        let in_number = |c: char| matches!(c, '0'..='9' | '.' | 'e' | 'E' | '+' | '-');
        let end = rest.find(|c: char| !in_number(c)).unwrap_or(rest.len());

        match rest[..end].parse::<f64>() {
            Ok(value) => value,
            Err(e) => panic!("failed to parse literal after {needle:?}: {e}"),
        }
    }

    /// "Constants cannot drift" lock (#1175 item 4): each constant used by the
    /// host oracle must equal, bit for bit, the literal found in the kernel
    /// source. Editing a single bit on either side trips this test.
    #[test]
    fn host_constants_match_kernel_source_bit_for_bit() {
        let pinned: [(&str, f64); 4] = [
            ("#define INV_SQRT_2PI", INV_SQRT_2PI),
            ("#define SQRT_2", SQRT_2),
            ("inv_sqrt_pi =", INV_SQRT_PI),
            ("sqrt_2_over_pi =", SQRT_2_OVER_PI),
        ];
        pinned.iter().for_each(|&(needle, host)| {
            let device = literal_after(PROBIT_NUMERICS_CU, needle);
            assert_eq!(
                device.to_bits(),
                host.to_bits(),
                "constant {needle:?} drifted: kernel={device:?} host={host:?}"
            );
        });
    }

    /// Source-level audit for transcendental parity: the kernel text must call
    /// `erfc`/`exp`/`log` (mirrored on the host by `libm`) and must not mention
    /// any fast-math intrinsic or single-precision variant. FMA contraction is
    /// handled separately at compile time via `device_cache`'s `--fmad=false`;
    /// this test only inspects the source string.
    #[test]
    fn kernel_source_uses_msun_transcendentals_only() {
        let required: &[&str] = &["erfc(", "exp(", "log("];
        let forbidden: &[&str] = &[
            "__expf",
            "__logf",
            "expf(",
            "logf(",
            "erfcf(",
            "__fdividef",
            "__frcp",
            "use_fast_math",
            "ffast-math",
            "__dmul_",
            "__dadd_",
            "__fmaf",
        ];

        for &good in required {
            assert!(
                PROBIT_NUMERICS_CU.contains(good),
                "kernel source should call msun `{good}`"
            );
        }
        for &bad in forbidden {
            assert!(
                !PROBIT_NUMERICS_CU.contains(bad),
                "kernel source must not use fast-math / single-precision `{bad}`"
            );
        }
    }

    /// `erfc(0) = 1` exactly, and the reflection `erfc(-x) = 2 - erfc(x)`
    /// holds to ≤ 2 ULP on the grid `x = 0.00, 0.01, …, 2.99`.
    #[test]
    fn erfc_boundary_and_symmetry() {
        assert_eq!(erfc(0.0), 1.0);
        let worst = (0..300_i32)
            .map(|i| {
                let x = f64::from(i) * 0.01;
                ulp(erfc(-x), 2.0 - erfc(x))
            })
            .fold(0.0_f64, f64::max);
        assert!(worst <= 2.0, "erfc symmetry drift {worst:.3} ULP > 2");
    }

    /// Boundary values of `erfcx_nonnegative`, plus the defining identity
    /// `erfcx(x)·exp(-x²) = erfc(x)` to ≤ 4 ULP while stepping `x` from 0.1
    /// in increments of 0.1 below 25 (inside the oracle's direct branch).
    #[test]
    fn erfcx_matches_definition() {
        assert_eq!(erfcx_nonnegative(0.0), 1.0);
        assert_eq!(erfcx_nonnegative(-3.0), 1.0);
        assert_eq!(erfcx_nonnegative(f64::INFINITY), 0.0);
        assert_eq!(erfcx_nonnegative(f64::NEG_INFINITY), f64::INFINITY);

        // The grid is produced by repeated addition on purpose: the visited
        // points are the accumulated sums, not `k * 0.1`.
        let mut worst = 0.0_f64;
        let mut x = 0.1;
        while x < 25.0 {
            let unscaled = erfcx_nonnegative(x) * libm::exp(-x * x);
            let drift = ulp(unscaled, erfc(x));
            worst = worst.max(drift);
            x += 0.1;
        }
        assert!(worst <= 4.0, "erfcx definition drift {worst:.3} ULP > 4");
    }

    /// `log_ndtr` boundary values; the bulk identity
    /// `log Φ(x) = log(½·erfc(-x/√2))` to ≤ 2 ULP for `|x| ≤ 3`; and the
    /// reflection `Φ(x) + Φ(-x) = 1` to ≤ 4e-16 for `x = 0.0, 0.1, …, 5.9`.
    #[test]
    fn log_ndtr_matches_log_cdf_and_reflects() {
        assert_eq!(log_ndtr(0.0), libm::log(0.5));
        assert_eq!(log_ndtr(f64::INFINITY), 0.0);
        assert_eq!(log_ndtr(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert!(log_ndtr(f64::NAN).is_nan());

        let worst_bulk = (-30..=30_i32)
            .map(|i| {
                let x = f64::from(i) * 0.1;
                let cdf = 0.5 * erfc(-x / SQRT_2);
                ulp(log_ndtr(x), libm::log(cdf))
            })
            .fold(0.0_f64, f64::max);
        assert!(
            worst_bulk <= 2.0,
            "log_ndtr vs log-cdf drift {worst_bulk:.3} ULP > 2"
        );

        let worst_refl = (0..60_i32)
            .map(|i| {
                let x = f64::from(i) * 0.1;
                let total = libm::exp(log_ndtr(x)) + libm::exp(log_ndtr(-x));
                (total - 1.0).abs()
            })
            .fold(0.0_f64, f64::max);
        assert!(
            worst_refl <= 4e-16,
            "Φ(x)+Φ(-x) reflection drift {worst_refl:e} > 4e-16"
        );
    }

    /// The joint routine's log-CDF channel must be bit-identical to
    /// `log_ndtr`, and the Mills identity `λ(x)·Φ(x) = φ(x)` must hold to
    /// ≤ 32 ULP for `|x| ≤ 5`. In the deep left tail both outputs must remain
    /// finite, with the Mills ratio tracking `|x|`. Infinite inputs are pinned
    /// to their boundary pairs.
    #[test]
    fn log_ndtr_and_mills_identity_and_deep_tail() {
        for i in -50..=50_i32 {
            let x = f64::from(i) * 0.1;
            let (log_cdf, lambda) = log_ndtr_and_mills(x);
            assert_eq!(
                log_cdf.to_bits(),
                log_ndtr(x).to_bits(),
                "joint log-CDF channel diverged from log_ndtr at x={x}"
            );

            // Recover Φ(x) from its log and compare λ·Φ against the density.
            let cdf = libm::exp(log_cdf);
            let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
            let drift = ulp(lambda * cdf, pdf);
            assert!(
                drift <= 32.0,
                "Mills identity drift {:.3} ULP > 32 at x={x}",
                drift
            );
        }

        for x in [-10.0_f64, -20.0, -30.0, -38.0] {
            let (log_cdf, lambda) = log_ndtr_and_mills(x);
            assert!(
                log_cdf.is_finite() && log_cdf < 0.0,
                "deep-tail log Φ({x}) not finite-negative: {log_cdf}"
            );
            assert!(
                lambda.is_finite() && lambda > x.abs() * 0.9,
                "deep-tail Mills({x}) should track |x|: {lambda}"
            );
        }

        assert_eq!(log_ndtr_and_mills(f64::INFINITY), (0.0, 0.0));
        assert_eq!(
            log_ndtr_and_mills(f64::NEG_INFINITY),
            (f64::NEG_INFINITY, f64::INFINITY)
        );
    }
}
355}