Skip to main content

gam_gpu/
numerics_host.rs

1//! Host-side scalar special functions shared by the CPU parity references of
2//! the GPU backends.
3//!
4//! The CUDA kernels emit their own NVRTC-visible numerics (see
5//! [`crate::numerics_device`]); this module is the matching **host** side
6//! used by the CPU parity oracles (`bms_flex_row`'s test oracle) and the
7//! CPU reference path (`pirls_row`'s probit CDF). Keeping a single definition
8//! here means the host `erfc` cannot drift between backends.
9
/// Complementary error function `erfc(x) = 1 − erf(x)` evaluated on the host.
///
/// This is a thin wrapper over `libm::erfc`; it exists so that every
/// host-side consumer in this module (`erfcx_nonnegative`, `log_ndtr`,
/// `log_ndtr_and_mills`) shares one definition of `erfc`.
///
/// # Accuracy and device parity
///
/// `libm::erfc` is the SunOS msun double-precision implementation
/// (accurate to within ~1 ulp across the entire real line). The CUDA kernel
/// side calls device `erfc`, which is itself msun-derived, so the host CPU
/// reference matches the device path to within a ULP.
///
/// # History
///
/// The previous branchless Cody 1969 Chebyshev rational here was only
/// ~1.2e-7 accurate in relative terms. That ate seven digits of every probit
/// `Mills = φ/Φ = pdf / (½·erfc(-x/√2))` evaluation. As a consequence, any
/// sufficiently tight finite-difference probe of
/// `∂neglog/∂e = -w·s·Mills` broke against itself at the ~2e-7 floor instead
/// of the genuine 5-point-stencil truncation floor near 1e-12: the analytic
/// side computes the derivative from this same `cdf`, while the FD side
/// differences `log cdf` and thereby cancels the erfc bias.
pub fn erfc(x: f64) -> f64 {
    libm::erfc(x)
}
27
28// ── Host oracle for the shared device probit numerics (issue #1175) ──────────
29//
30// The functions below are the CPU-side, device-free mirror of the CUDA source
31// in [`crate::numerics_device::PROBIT_NUMERICS_CU`]. They are written
32// LINE-FOR-LINE against that kernel source — the SAME branch structure, the
33// SAME clamps (`1e-300`, `[0,1]`), the SAME asymptotic `erfcx` polynomial, and
34// the SAME four constants — differing only in that they call the host `libm`
35// transcendentals (`erfc`/`exp`/`log`) where the kernel calls the device
36// `erfc`/`exp`/`log`. Both sides are the SunOS *msun* double-precision
37// implementations, so the host oracle matches the device to within ~1 ULP per
38// transcendental (issue #1175 items 4–5). This mirrors the #1017
39// `emulate_certified_encode_row` pattern: a CPU emulator that is BOTH the
40// fallback and the exactness oracle a device launch is pinned to.
41//
42// Correctness *without a GPU* (CPU-verifiable): the test harness below asserts
43// (a) these constants are bit-identical to the literals in the kernel source
44// (the "constants cannot drift" lock, #1175 item 4), (b) the kernel source uses
45// only msun transcendentals and no fast-math intrinsics (transcendental-parity
46// intent), and (c) the host oracle satisfies the defining probit identities to
47// a stated ULP bound. Confirming a *device launch* reproduces this oracle to
48// round-off still needs CUDA hardware.
49
/// `1/√(2π)`, matching `INV_SQRT_2PI` in the kernel source bit-for-bit.
///
/// Normalising factor of the standard normal density; `log_ndtr_and_mills`
/// forms `φ(x) = INV_SQRT_2PI · exp(−x²/2)` with it on the `x ≥ 0` branch.
pub const INV_SQRT_2PI: f64 = 0.3989422804014327;
/// `√2`, matching `SQRT_2` in the kernel source bit-for-bit.
///
/// Divisor that maps a standard-normal argument `x` onto the `erfc`/`erfcx`
/// argument `−x/√2` in `log_ndtr` and `log_ndtr_and_mills`.
pub const SQRT_2: f64 = 1.4142135623730951;
/// `1/√π`, matching `inv_sqrt_pi` in the kernel source bit-for-bit.
///
/// Leading factor of the large-`x` asymptotic series evaluated by
/// `erfcx_nonnegative`.
pub const INV_SQRT_PI: f64 = 0.5641895835477563;
/// `√(2/π)`, matching `sqrt_2_over_pi` in the kernel source bit-for-bit.
///
/// Numerator of the left-tail Mills ratio `√(2/π)/erfcx(u)` computed by
/// `log_ndtr_and_mills` for `x < 0`.
pub const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
58
59/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)` for `x ≥ 0`,
60/// the host oracle for the device `erfcx_nonnegative`. Returns `1.0` for
61/// `x ≤ 0`, `0.0` at `+∞`, `+∞` at `−∞`. For `0 < x < 26` evaluates the direct
62/// `exp(min(x², 700))·erfc(x)` form; beyond that (where `exp(x²)` would
63/// overflow) it switches to the same 4-term asymptotic expansion as the kernel.
64pub fn erfcx_nonnegative(x: f64) -> f64 {
65    if !x.is_finite() {
66        return if x > 0.0 { 0.0 } else { f64::INFINITY };
67    }
68    if x <= 0.0 {
69        return 1.0;
70    }
71    if x < 26.0 {
72        let mut xx = x * x;
73        if xx > 700.0 {
74            xx = 700.0;
75        }
76        return libm::exp(xx) * erfc(x);
77    }
78    let inv = 1.0 / x;
79    let inv2 = inv * inv;
80    let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
81        + 6.5625 * inv2 * inv2 * inv2 * inv2;
82    inv * poly * INV_SQRT_PI
83}
84
85/// `log Φ(x)` for the standard normal CDF, the host oracle for the device
86/// `log_ndtr`. For `x < 0` uses the `erfcx` representation
87/// `log Φ(x) = −u² + log(½·erfcx(u))`, `u = −x/√2`, keeping digits into the
88/// deep left tail; for `x ≥ 0` uses `log(½·erfc(−x/√2))` with the same clamps as
89/// the kernel. Propagates `±∞`/`NaN` exactly as the device path does.
90pub fn log_ndtr(x: f64) -> f64 {
91    if x == f64::INFINITY {
92        return 0.0;
93    }
94    if x == f64::NEG_INFINITY {
95        return f64::NEG_INFINITY;
96    }
97    if x.is_nan() {
98        return x;
99    }
100    if x < 0.0 {
101        let u = -x / SQRT_2;
102        let mut ex = erfcx_nonnegative(u);
103        if ex < 1e-300 {
104            ex = 1e-300;
105        }
106        -u * u + libm::log(0.5 * ex)
107    } else {
108        let mut c = 0.5 * erfc(-x / SQRT_2);
109        if c < 1e-300 {
110            c = 1e-300;
111        }
112        if c > 1.0 {
113            c = 1.0;
114        }
115        libm::log(c)
116    }
117}
118
119/// Joint `(log Φ(x), Mills ratio φ(x)/Φ(x))`, the host oracle for the device
120/// `log_ndtr_and_mills`. The `x < 0` branch computes the Mills ratio as
121/// `√(2/π)/erfcx(u)`, which stays finite even when `Φ(x)` underflows; the
122/// `x ≥ 0` branch forms `pdf/cdf` directly. Boundary values mirror the kernel:
123/// `(+0, +0)` at `+∞`, `(−∞, +∞)` at `−∞`, `(NaN, NaN)` at `NaN`.
124pub fn log_ndtr_and_mills(x: f64) -> (f64, f64) {
125    if x == f64::INFINITY {
126        return (0.0, 0.0);
127    }
128    if x == f64::NEG_INFINITY {
129        return (f64::NEG_INFINITY, f64::INFINITY);
130    }
131    if x.is_nan() {
132        return (x, x);
133    }
134    if x < 0.0 {
135        let u = -x / SQRT_2;
136        let mut ex = erfcx_nonnegative(u);
137        if ex < 1e-300 {
138            ex = 1e-300;
139        }
140        let log_cdf = -u * u + libm::log(0.5 * ex);
141        let lambda = SQRT_2_OVER_PI / ex;
142        (log_cdf, lambda)
143    } else {
144        let mut cdf = 0.5 * erfc(-x / SQRT_2);
145        if cdf < 1e-300 {
146            cdf = 1e-300;
147        }
148        if cdf > 1.0 {
149            cdf = 1.0;
150        }
151        let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
152        let log_cdf = libm::log(cdf);
153        let lambda = pdf / cdf;
154        (log_cdf, lambda)
155    }
156}
157
#[cfg(test)]
mod probit_parity_tests {
    //! CPU-verifiable floating-point-order & transcendental parity harness for
    //! the shared probit numerics (issue #1175). Everything here runs without a
    //! GPU: it pins the host oracle constants to the kernel-source literals,
    //! audits the kernel source for msun-only transcendentals (no fast-math),
    //! and checks the host oracle against the defining probit identities within
    //! stated ULP bounds. A *device* reproducing this oracle to round-off still
    //! requires CUDA hardware and is asserted by the on-device parity gates.
    use super::*;
    use crate::numerics_device::PROBIT_NUMERICS_CU;

    /// Scale used by [`ulp`]: the gap between `1.0` and the next `f64`.
    const EPS: f64 = f64::EPSILON; // 2.220446049250313e-16

    /// Relative error of `got` vs `want`, expressed in ULP of `want`.
    ///
    /// One "ULP" here is `EPS·|want|`. When `want` is exactly zero a relative
    /// error is undefined, so the absolute error is scaled by `EPS` instead.
    fn ulp(got: f64, want: f64) -> f64 {
        if want == 0.0 {
            (got - want).abs() / EPS
        } else {
            (got - want).abs() / (EPS * want.abs())
        }
    }

    /// Extract the first f64 literal appearing after `needle` in `src`.
    ///
    /// The search starts at the first occurrence of `needle`, skips forward to
    /// the first character that can begin a number (`-`, `.`, or a digit), and
    /// consumes characters while they can belong to a decimal/exponent
    /// literal.
    ///
    /// # Panics
    ///
    /// Panics if `needle` is absent, if nothing numeric follows it, or if the
    /// extracted text does not parse as an `f64`.
    fn literal_after(src: &str, needle: &str) -> f64 {
        let start = src
            .find(needle)
            .unwrap_or_else(|| panic!("kernel source is missing marker {needle:?}"))
            + needle.len();
        let tail = &src[start..];
        // Skip separators between the marker and the number ('=', whitespace).
        let num_start = tail
            .find(|c: char| c == '-' || c == '.' || c.is_ascii_digit())
            .unwrap_or_else(|| panic!("no numeric literal follows {needle:?}"));
        let rest = &tail[num_start..];
        let end = rest
            .find(|c: char| !(c.is_ascii_digit() || matches!(c, '.' | 'e' | 'E' | '+' | '-')))
            .unwrap_or(rest.len());
        rest[..end]
            .parse::<f64>()
            .unwrap_or_else(|e| panic!("failed to parse literal after {needle:?}: {e}"))
    }

    /// #1175 item 4 pattern ("constants cannot drift"): every constant the host
    /// oracle uses is bit-identical to the literal baked into the kernel source.
    /// A one-bit edit on either side fails this immediately.
    #[test]
    fn host_constants_match_kernel_source_bit_for_bit() {
        for (needle, host) in [
            ("#define INV_SQRT_2PI", INV_SQRT_2PI),
            ("#define SQRT_2", SQRT_2),
            ("inv_sqrt_pi =", INV_SQRT_PI),
            ("sqrt_2_over_pi =", SQRT_2_OVER_PI),
        ] {
            let device = literal_after(PROBIT_NUMERICS_CU, needle);
            assert_eq!(
                device.to_bits(),
                host.to_bits(),
                "constant {needle:?} drifted: kernel={device:?} host={host:?}"
            );
        }
    }

    /// Transcendental-parity intent: the kernel evaluates its transcendentals
    /// through the msun `erfc`/`exp`/`log` (which the host `libm` mirrors) and
    /// contains NO fast-math intrinsic or single-precision variant. FMA
    /// contraction is separately disabled at compile time via
    /// `device_cache`'s `--fmad=false`; this guards the source itself.
    #[test]
    fn kernel_source_uses_msun_transcendentals_only() {
        for good in ["erfc(", "exp(", "log("] {
            assert!(
                PROBIT_NUMERICS_CU.contains(good),
                "kernel source should call msun `{good}`"
            );
        }
        for bad in [
            "__expf",
            "__logf",
            "expf(",
            "logf(",
            "erfcf(",
            "__fdividef",
            "__frcp",
            "use_fast_math",
            "ffast-math",
            "__dmul_",
            "__dadd_",
            "__fmaf",
        ] {
            assert!(
                !PROBIT_NUMERICS_CU.contains(bad),
                "kernel source must not use fast-math / single-precision `{bad}`"
            );
        }
    }

    /// `erfc` boundary + symmetry: `erfc(0)=1` exactly and
    /// `erfc(-x) = 2 - erfc(x)` to ≤ 2 ULP across a moderate grid.
    #[test]
    fn erfc_boundary_and_symmetry() {
        assert_eq!(erfc(0.0), 1.0);
        let mut worst = 0.0_f64;
        for i in 0..300 {
            let x = i as f64 * 0.01;
            worst = worst.max(ulp(erfc(-x), 2.0 - erfc(x)));
        }
        assert!(worst <= 2.0, "erfc symmetry drift {worst:.3} ULP > 2");
    }

    /// Defining identity `erfcx(x)·exp(-x²) = erfc(x)` to ≤ 4 ULP on a grid
    /// over `0 < x < 25`, which lies wholly inside the direct (`x < 26`)
    /// branch of the host oracle, plus the exact boundary values.
    #[test]
    fn erfcx_matches_definition() {
        assert_eq!(erfcx_nonnegative(0.0), 1.0);
        assert_eq!(erfcx_nonnegative(-3.0), 1.0);
        assert_eq!(erfcx_nonnegative(f64::INFINITY), 0.0);
        assert_eq!(erfcx_nonnegative(f64::NEG_INFINITY), f64::INFINITY);
        let mut worst = 0.0_f64;
        let mut x = 0.1;
        while x < 25.0 {
            worst = worst.max(ulp(erfcx_nonnegative(x) * libm::exp(-x * x), erfc(x)));
            x += 0.1;
        }
        assert!(worst <= 4.0, "erfcx definition drift {worst:.3} ULP > 4");
    }

    /// Branch-switch continuity: at `x = 26` the oracle leaves the direct
    /// `exp(x²)·erfc(x)` form for the asymptotic series. Both forms are still
    /// representable there (`exp(676)` is finite and `erfc(26)` is a normal
    /// number), so the series value is compared against the direct product.
    /// The first omitted series term, `945/(32·x¹⁰)`, is about `2.1e-13` at
    /// `x = 26`, so agreement to `1e-12` relative is required.
    #[test]
    fn erfcx_asymptotic_branch_meets_direct_form_at_switch() {
        let x = 26.0_f64;
        let direct = libm::exp(x * x) * erfc(x);
        let asymptotic = erfcx_nonnegative(x);
        assert!(
            direct.is_finite() && direct > 0.0,
            "direct erfcx reference at x={x} is not finite-positive: {direct}"
        );
        let rel = ((asymptotic - direct) / direct).abs();
        assert!(
            rel <= 1e-12,
            "erfcx branch switch discontinuity {rel:e} > 1e-12 at x={x}"
        );
    }

    /// `log_ndtr` boundary + bulk identity `log Φ(x) = log(½·erfc(-x/√2))` to
    /// ≤ 2 ULP for `|x| ≤ 3`, and `Φ(x)+Φ(-x)=1` to ≤ 4e-16.
    #[test]
    fn log_ndtr_matches_log_cdf_and_reflects() {
        assert_eq!(log_ndtr(0.0), libm::log(0.5));
        assert_eq!(log_ndtr(f64::INFINITY), 0.0);
        assert_eq!(log_ndtr(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert!(log_ndtr(f64::NAN).is_nan());

        let mut worst_bulk = 0.0_f64;
        for i in -30..=30 {
            let x = i as f64 * 0.1;
            let cdf = 0.5 * erfc(-x / SQRT_2);
            worst_bulk = worst_bulk.max(ulp(log_ndtr(x), libm::log(cdf)));
        }
        assert!(
            worst_bulk <= 2.0,
            "log_ndtr vs log-cdf drift {worst_bulk:.3} ULP > 2"
        );

        let mut worst_refl = 0.0_f64;
        for i in 0..60 {
            let x = i as f64 * 0.1;
            let s = libm::exp(log_ndtr(x)) + libm::exp(log_ndtr(-x));
            worst_refl = worst_refl.max((s - 1.0).abs());
        }
        assert!(
            worst_refl <= 4e-16,
            "Φ(x)+Φ(-x) reflection drift {worst_refl:e} > 4e-16"
        );
    }

    /// `log_ndtr_and_mills` agrees with `log_ndtr` on the log-CDF channel and
    /// satisfies the Mills identity `λ(x)·Φ(x) = φ(x)` to ≤ 32 ULP for
    /// `|x| ≤ 5`; the deep left tail stays finite (no `-∞`/`NaN`); and the
    /// `±∞`/`NaN` boundary values match the documented contract.
    #[test]
    fn log_ndtr_and_mills_identity_and_deep_tail() {
        for i in -50..=50 {
            let x = i as f64 * 0.1;
            let (log_cdf, lambda) = log_ndtr_and_mills(x);
            assert_eq!(
                log_cdf.to_bits(),
                log_ndtr(x).to_bits(),
                "joint log-CDF channel diverged from log_ndtr at x={x}"
            );
            // Φ(x) recovered from the log channel, and φ(x) formed directly.
            let cdf = libm::exp(log_cdf);
            let pdf = INV_SQRT_2PI * libm::exp(-0.5 * x * x);
            assert!(
                ulp(lambda * cdf, pdf) <= 32.0,
                "Mills identity drift {:.3} ULP > 32 at x={x}",
                ulp(lambda * cdf, pdf)
            );
        }
        for &x in &[-10.0, -20.0, -30.0, -38.0] {
            let (log_cdf, lambda) = log_ndtr_and_mills(x);
            assert!(
                log_cdf.is_finite() && log_cdf < 0.0,
                "deep-tail log Φ({x}) not finite-negative: {log_cdf}"
            );
            assert!(
                lambda.is_finite() && lambda > x.abs() * 0.9,
                "deep-tail Mills({x}) should track |x|: {lambda}"
            );
        }
        assert_eq!(log_ndtr_and_mills(f64::INFINITY), (0.0, 0.0));
        assert_eq!(
            log_ndtr_and_mills(f64::NEG_INFINITY),
            (f64::NEG_INFINITY, f64::INFINITY)
        );
        // NaN cannot be compared with `assert_eq!`; check each channel.
        let (nan_log_cdf, nan_lambda) = log_ndtr_and_mills(f64::NAN);
        assert!(
            nan_log_cdf.is_nan() && nan_lambda.is_nan(),
            "NaN input must propagate to both channels: ({nan_log_cdf}, {nan_lambda})"
        );
    }
}