Skip to main content

gam_inference/
gpu_polya_gamma.rs

1//! GPU Pólya–Gamma sampler primitive — INCOMPATIBLE with shipped probit BMS
2//! (different model).
3//!
//! This module implements a stand-alone, device-resident Pólya–Gamma sampler
//! plus a synthetic *logistic* Gibbs harness used to validate the sampler.
//! It is deliberately not wired into the shipped BMS fits, because those are
//! probit families — PG augmentation is exact only for the
//! Bernoulli **logistic** likelihood (Polson, Scott & Windle 2013). Probit
//! paths (`bms_flex`, `bernoulli_marginal_slope`) use a different likelihood and
//! do not call this module.
10//!
11//! The block 7 math design splits the device sampler into three regimes
12//! (math §7), each kernel laid out to avoid warp divergence inside the
13//! launch:
14//!
15//! * **`pg1_kernel`** — exact Devroye (math §8) for shape `b = 1`. This
16//!   covers pure Bernoulli rows. Each row owns a `curand`-style XORWOW
17//!   state seeded statelessly from `(seed, row_index)` so two runs with
18//!   the same seed produce bit-identical draws regardless of grid layout.
19//!   The alternating-series accept/reject uses the corrected right-tail
20//!   coefficient `π · k` (not `π / 2`) — the math team’s Phase-1 fix.
21//! * **`sp_kernel`** — saddlepoint rejection (math §9) for `13 < b ≤ 170`.
22//!   This solves `K'(t) = x` via six Newton iterations on `tanh(v)/v` or
23//!   `tan(v)/v` and uses an IG + Gamma envelope for the accept/reject.
24//! * **`normal_kernel`** — Lyapunov-CLT closed-form approximation
25//!   (math §10) for `b > 170`. Mean and variance use the analytic
26//!   PG(b, c) limit, no rejection loop, no warp divergence.
27//!
//! The host dispatcher partitions an input vector of `(b_i, c_i)` rows
//! into three contiguous index lists (one per regime) and launches one
//! kernel per regime. The `2 ≤ b ≤ 13` band (everything strictly between
//! `PG1_MAX_B` and `SADDLE_MIN_B`) is handled on host via the
//! sum-of-PG(1, c) convolution identity — at small `b` the sum cost is
//! negligible and keeping it off-device avoids a fourth kernel that would
//! see almost no traffic in practice.
34//!
35//! ## What this primitive intentionally does NOT do
36//!
37//! * It does **not** plug into BMS marginal slope (probit model) — the PG
38//!   augmentation identity is logit-only; doing so silently would change
39//!   numerical results for shipped fits.
40//! * It does **not** define a public production family. The
41//!   Gibbs harness in [`logistic_gibbs_step`] is a *validation oracle* for
42//!   the sampler primitive, not a fit method. The CPU reference
43//!   `src/inference/polya_gamma.rs` and the NUTS/HMC infrastructure remain
44//!   the supported posterior-inference paths.
45//!
46//! ## Stateless XORWOW seeding
47//!
//! Each row's XORWOW state `(s0, s1, s2, s3, s4, counter)` is materialised
//! by taking the high 32 bits of
//! `splitmix64( seed ⊕ row · ZETA ⊕ word · GAMMA )` for word indices
//! `0..=5` — five 32-bit lanes plus a 32-bit counter, six words in total. The
//! host reference RNG (`XorwowState::new`) reproduces the same byte sequence
//! the kernel emits, so CPU/GPU parity tests can compare draw-by-draw at
//! the same `(seed, row)`.
54
55use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
56
57use gam_linalg::triangular::{back_substitution_lower_transpose, cholesky_solve_vector};
58
59#[cfg(target_os = "linux")]
60use gam_gpu::gpu_error::GpuError;
61
62// ────────────────────────────────────────────────────────────────────────
63// Public types
64// ────────────────────────────────────────────────────────────────────────
65
/// Stateless seed for the per-row XORWOW PRNG.
///
/// The wrapped `u64` is combined with the row index when a row's generator is
/// built, so reusing the same seed reproduces the same per-row streams.
#[derive(Clone, Copy, Debug)]
pub struct PgSeed(pub u64);

impl Default for PgSeed {
    /// Default seed: the eight ASCII bytes `"POLYGAMA"` read as a big-endian
    /// `u64` (`0x504F_4C59_4741_4D41`).
    fn default() -> Self {
        Self(u64::from_be_bytes(*b"POLYGAMA"))
    }
}
76
/// Regime split thresholds (math §7), expressed on the integer shape `b`.
///
/// * `b ≤ PG1_MAX_B` (i.e. `b = 1`) — exact-Devroye regime.
/// * `PG1_MAX_B < b < SADDLE_MIN_B` (i.e. `2 ≤ b ≤ 13`) — host
///   convolution-of-PG(1) regime.
/// * `SADDLE_MIN_B ≤ b ≤ SADDLE_MAX_B` — saddlepoint-rejection regime.
/// * `b ≥ NORMAL_MIN_B` (equivalently `b > SADDLE_MAX_B`) —
///   normal-approximation regime.
pub const PG1_MAX_B: u32 = 1;
/// Smallest shape routed to the saddlepoint regime.
pub const SADDLE_MIN_B: u32 = 14;
/// Largest shape routed to the saddlepoint regime.
pub const SADDLE_MAX_B: u32 = 170;
/// Smallest shape routed to the normal approximation; defined as the
/// successor of `SADDLE_MAX_B` so the two regimes can never overlap or leave
/// a gap.
pub const NORMAL_MIN_B: u32 = SADDLE_MAX_B + 1;
87
88/// Inputs for the dispatched batched sampler.
89#[derive(Clone, Debug)]
90pub struct PolyaGammaBatchInput<'a> {
91    /// Shape parameters `b_i`. Must be ≥ 1.
92    pub shapes: ArrayView1<'a, u32>,
93    /// Tilt parameters `c_i = ψ_i`. Sign is irrelevant (sampler uses |c|).
94    pub tilts: ArrayView1<'a, f64>,
95    /// Stateless RNG seed.
96    pub seed: PgSeed,
97}
98
99impl<'a> PolyaGammaBatchInput<'a> {
100    pub fn rows(&self) -> usize {
101        self.shapes.len()
102    }
103
104    pub fn validate(&self) -> Result<(), String> {
105        if self.shapes.len() != self.tilts.len() {
106            return Err(format!(
107                "polya_gamma: shapes.len()={} != tilts.len()={}",
108                self.shapes.len(),
109                self.tilts.len()
110            ));
111        }
112        if self.shapes.iter().any(|b| *b == 0) {
113            return Err("polya_gamma: b=0 is invalid (PG(0,c) is a point mass at 0)".to_string());
114        }
115        Ok(())
116    }
117}
118
119// ────────────────────────────────────────────────────────────────────────
120// SplitMix64 finalizer + per-row XORWOW seeding
121// ────────────────────────────────────────────────────────────────────────
122
/// SplitMix64 finalizer (matches `reml_trace::splitmix64_mix`). Thin wrapper
/// over the canonical implementation in [`gam_linalg::utils::splitmix64_hash`].
///
/// The device prelude in this file carries its own `splitmix64_mix`, which
/// adds `0x9E3779B97F4A7C15` to the input before the two xor-shift-multiply
/// rounds. NOTE(review): CPU/GPU parity relies on `splitmix64_hash` applying
/// the same increment — confirm against `gam_linalg`.
#[inline]
pub fn splitmix64_mix(z: u64) -> u64 {
    gam_linalg::utils::splitmix64_hash(z)
}
129
/// Two large 64-bit constants used to mix `(seed, row, word)` into the
/// SplitMix input. Disjoint from the `reml_trace` constants so different
/// kernels with the same seed don't share probe sequences.
///
/// Both values are even (`ROW_ZETA` ends in `…90`, `WORD_GAMMA` in `…78`), so
/// the wrapping multiplications are not bijections on `u64`: the low four
/// bits of `row * ROW_ZETA` are always zero and row indices that differ by a
/// multiple of 2⁶⁰ map to the same product. The values must nevertheless stay
/// exactly as written — the device-side `xorwow_seed` hard-codes the same two
/// literals, and changing either side alone breaks CPU/GPU stream parity.
const ROW_ZETA: u64 = 0xA1B2_C3D4_E5F6_7890;
const WORD_GAMMA: u64 = 0x0F1E_2D3C_4B5A_6978;
135
136/// Compact per-row XORWOW state. Layout matches `curand_kernel.h`’s
137/// `curandStateXORWOW_t` for the five state lanes plus the addition
138/// counter; we omit the boxmuller cache (PG sampler doesn’t use it).
139#[derive(Clone, Copy, Debug)]
140pub struct XorwowState {
141    pub s: [u32; 5],
142    pub d: u32,
143}
144
145impl XorwowState {
146    /// Stateless seeding from `(seed, row)`. Each of the six state words
147    /// is the high or low half of a SplitMix64 hash of
148    /// `splitmix64(seed ⊕ row·ROW_ZETA ⊕ word·WORD_GAMMA)`. The first
149    /// non-zero state word is enforced so we never enter the all-zero
150    /// XORWOW absorbing fixed point.
151    pub fn new(seed: u64, row: u64) -> Self {
152        let mut words = [0u32; 6];
153        for (word_idx, slot) in words.iter_mut().enumerate() {
154            let composite =
155                seed ^ row.wrapping_mul(ROW_ZETA) ^ (word_idx as u64).wrapping_mul(WORD_GAMMA);
156            let h = splitmix64_mix(composite);
157            *slot = (h >> 32) as u32;
158        }
159        // XORWOW absorbs at all-zeros; flip the low bit of s[0] if it ever
160        // happens (probability 2⁻³² but cheap to guard).
161        if words[0] == 0 && words[1] == 0 && words[2] == 0 && words[3] == 0 && words[4] == 0 {
162            words[0] = 1;
163        }
164        Self {
165            s: [words[0], words[1], words[2], words[3], words[4]],
166            d: words[5],
167        }
168    }
169
170    /// Single XORWOW advance. Returns the next 32-bit output and mutates
171    /// the state. Matches Marsaglia’s 2003 XORWOW formulation, which is
172    /// also what `curand_kernel.h::xorwow` computes.
173    #[inline]
174    pub fn next_u32(&mut self) -> u32 {
175        let mut t = self.s[4];
176        let s = self.s[0];
177        self.s[4] = self.s[3];
178        self.s[3] = self.s[2];
179        self.s[2] = self.s[1];
180        self.s[1] = s;
181        t ^= t >> 2;
182        t ^= t << 1;
183        t ^= s ^ (s << 4);
184        self.s[0] = t;
185        self.d = self.d.wrapping_add(362_437);
186        t.wrapping_add(self.d)
187    }
188
189    /// Uniform double in (0, 1] — same `(u32 + 1) / 2^32` convention the
190    /// kernel uses (matches `curand_uniform_double` upper-open interval
191    /// convention; we use the upper-closed variant so a zero u32 never
192    /// produces exactly zero, which would crash `log(u)` in the Exp draw).
193    #[inline]
194    pub fn next_unit(&mut self) -> f64 {
195        let raw = self.next_u32();
196        ((raw as f64) + 1.0) * (1.0 / 4_294_967_296.0)
197    }
198
199    /// Standard exponential via inverse CDF: `-ln(U)`. `U` is on (0, 1]
200    /// so `-ln(U)` is in `[0, +inf)`, never `+inf` from a zero argument.
201    #[inline]
202    pub fn next_exp(&mut self) -> f64 {
203        -self.next_unit().ln()
204    }
205
206    /// Standard normal via Marsaglia polar method. Discards the second
207    /// variate the polar pair produces (cleaner than caching it across
208    /// calls — we’d need a per-row scratch slot, which the device kernel
209    /// can’t afford to spill).
210    #[inline]
211    pub fn next_norm(&mut self) -> f64 {
212        loop {
213            let u = 2.0 * self.next_unit() - 1.0;
214            let v = 2.0 * self.next_unit() - 1.0;
215            let s = u * u + v * v;
216            if s > 0.0 && s < 1.0 {
217                let factor = (-2.0 * s.ln() / s).sqrt();
218                return u * factor;
219            }
220        }
221    }
222}
223
224// ────────────────────────────────────────────────────────────────────────
225// CPU host reference — Devroye PG(1, c) via the shared sampler core
226// ────────────────────────────────────────────────────────────────────────
227//
228// The CPU oracle for parity tests has to *use the same RNG bytes as the
229// device kernel*, so it drives the shared Devroye core
230// (`crate::polya_gamma_core`) through the bit-exact `XorwowState`
231// rather than through the production `rand::Rng` adapter. The math (Devroye
232// 1986; PSW 2013) is the single shared implementation — there is no second
233// copy of the tail mass / series / inverse-Gaussian helpers to drift.
234
235use crate::polya_gamma_core::{PgRng, draw_pg1};
236use std::f64::consts::{FRAC_PI_2, PI};
237
/// `XorwowState` is the randomness source for the bit-exact GPU oracle. Wiring
/// it through [`PgRng`] lets the shared Devroye core run against the same RNG
/// byte stream the device kernel consumes.
///
/// Each trait method forwards to the inherent method of the same name. The
/// calls are written in path form (`XorwowState::next_*`), which resolves to
/// the inherent associated function rather than back into this trait impl, so
/// the forwarding cannot recurse.
impl PgRng for XorwowState {
    /// Uniform draw on `(0, 1]`; forwards to the inherent `next_unit`.
    #[inline]
    fn next_unit(&mut self) -> f64 {
        XorwowState::next_unit(self)
    }

    /// Standard exponential draw; forwards to the inherent `next_exp`.
    #[inline]
    fn next_exp(&mut self) -> f64 {
        XorwowState::next_exp(self)
    }

    /// Standard normal draw; forwards to the inherent `next_norm`.
    #[inline]
    fn next_norm(&mut self) -> f64 {
        XorwowState::next_norm(self)
    }
}
257
/// CPU oracle for one PG(1, c) draw using a `XorwowState` directly. The
/// device kernel performs the same arithmetic byte-for-byte (modulo IEEE
/// rounding of transcendentals, which agree to <1 ULP for the inputs we
/// touch).
///
/// This is a thin wrapper: it hands `state` and `tilt` to the shared
/// `draw_pg1` core and returns its result. `state` is advanced by however
/// many uniforms/exponentials/normals that core consumes.
pub fn pg1_draw_cpu_oracle(state: &mut XorwowState, tilt: f64) -> f64 {
    draw_pg1(state, tilt)
}
265
266/// Higher-shape draw on host via convolution: PG(b, c) =_d Σ_{j=1..b} PG(1, c).
267/// Used by host for the `2 ≤ b ≤ 13` band and as the parity oracle for the
268/// saddlepoint kernel at modest `b`.
269pub fn pg_convolution_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
270    (0..b).map(|_| pg1_draw_cpu_oracle(state, tilt)).sum()
271}
272
273// ────────────────────────────────────────────────────────────────────────
274// Saddlepoint regime (math §9, 13 < b ≤ 170) — host oracle
275// ────────────────────────────────────────────────────────────────────────
276//
277// We sample a tilted-J*(b, z) variate via saddlepoint rejection. The
278// envelope is an IG / Gamma mixture; the saddlepoint approximation to the
279// log density gives a tight acceptance ratio across the full b range. The
280// host implementation here is also the *oracle* used to validate the
281// device sp_kernel.
282
/// Solve `K'(t) = x` for the saddlepoint `t`, given a target `x > 0`.
///
/// `K'` is continuous and strictly increasing in `t`; the solver works in the
/// auxiliary variable `v = sqrt(|2t|)`, on which Newton's method acts on a
/// monotone function:
///
/// * `x < 1` → `K'(t) = tanh(v)/v` with `v = sqrt(-2t)`, `t < 0`.
/// * `x > 1` → `K'(t) = tan(v)/v`  with `v = sqrt( 2t)`, `t > 0`, `v < π/2`.
///
/// # Returns
///
/// The saddlepoint `t`. Inputs within `1e-9` of `1` return exactly `0.0`.
/// A NaN input returns NaN instead of a plausible-looking finite value.
/// `x ≤ 0` has no solution (`K'` is strictly positive); the result for such
/// inputs is not meaningful.
///
/// # Seeding
///
/// Each branch has two closed-form seeds, one accurate near `x = 1` (Taylor)
/// and one accurate far from it (asymptote / pole). The seed with the smaller
/// residual `|K'(seed) − x|` is used, so the iteration starts close to the
/// root in both regimes and the fixed Newton budget is spent on refinement
/// rather than on walking towards the root.
pub fn saddlepoint_solve(x: f64) -> f64 {
    // Fixed Newton budget. From the selected seed convergence takes only a
    // few steps; the remainder are harmless no-op refinements.
    const NEWTON_STEPS: usize = 16;
    // Lower bound that keeps every division by `v` well-defined.
    const V_FLOOR: f64 = 1e-6;

    if x.is_nan() {
        return f64::NAN;
    }
    if (x - 1.0).abs() < 1e-9 {
        return 0.0;
    }

    if x < 1.0 {
        // Negative-t branch: g(v) = tanh(v)/v decreases monotonically from 1
        // to 0 on (0, ∞).
        //
        //   * x ≈ 1 (v small): g(v) ≈ 1 − v²/3, so v ≈ sqrt(3(1 − x)).
        //   * x ≈ 0 (v large): tanh(v) → 1, so g(v) ≈ 1/v and v ≈ 1/x.
        let residual = |v: f64| v.tanh() / v - x;
        let v_taylor = (3.0 * (1.0 - x)).sqrt().max(V_FLOOR);
        let v_asym = (1.0 / x.max(1e-12)).max(V_FLOOR);
        let mut v = if residual(v_taylor).abs() <= residual(v_asym).abs() {
            v_taylor
        } else {
            v_asym
        };
        for _ in 0..NEWTON_STEPS {
            let tanh_v = v.tanh();
            let f = tanh_v / v - x;
            // d/dv [tanh(v)/v] = ((1 − tanh²v) − tanh(v)/v) / v.
            let sech_sq = 1.0 - tanh_v * tanh_v;
            let df = (sech_sq - tanh_v / v) / v;
            v -= f / df;
            // Stop before a collapsed iterate can divide by ~0 next round.
            if v.abs() < 1e-12 {
                break;
            }
        }
        -0.5 * v * v
    } else {
        // Positive-t branch: g(v) = tan(v)/v increases monotonically (and is
        // convex) on (0, π/2), with a pole at π/2.
        //
        //   * x ≈ 1 (v small): g(v) ≈ 1 + v²/3, so v ≈ sqrt(3(x − 1)).
        //   * x large (v near π/2): tan(v) ≈ 1/(π/2 − v), so
        //     v ≈ π/2 − 2/(x π).
        //
        // Seeds are capped at 0.499 π and iterates at 0.499999 π so `tan`
        // stays finite.
        let seed_cap = 0.499 * PI;
        let iterate_cap = 0.499_999 * PI;
        let residual = |v: f64| v.tan() / v - x;
        let v_taylor = (3.0 * (x - 1.0)).sqrt().min(seed_cap).max(V_FLOOR);
        let v_pole = (FRAC_PI_2 - 2.0 / (x * PI)).min(seed_cap).max(V_FLOOR);
        let mut v = if residual(v_taylor).abs() <= residual(v_pole).abs() {
            v_taylor
        } else {
            v_pole
        };
        for _ in 0..NEWTON_STEPS {
            let tan_v = v.tan();
            let f = tan_v / v - x;
            // d/dv [tan(v)/v] = ((1 + tan²v) − tan(v)/v) / v.
            let sec_sq = 1.0 + tan_v * tan_v;
            let df = (sec_sq - tan_v / v) / v;
            // The clamp keeps the iterate inside (0, π/2); with NaN inputs
            // rejected above, the update itself is never NaN.
            v = (v - f / df).max(V_FLOOR).min(iterate_cap);
        }
        0.5 * v * v
    }
}
370
/// Second derivative `K''(t)` of the cumulant generating function, i.e. the
/// variance of the tilted distribution at the saddlepoint `t` obtained from
/// `K'(t) = x`. As a variance it is positive on both branches.
///
/// With the parameterisation of `saddlepoint_solve`,
///
/// * `t < 0`: `K'(t) = tanh(v)/v`, `v = sqrt(-2t)` ⇒
///   `K''(t) = tanh(v)/v³ − sech²(v)/v²`;
/// * `t > 0`: `K'(t) = tan(v)/v`,  `v = sqrt( 2t)` ⇒
///   `K''(t) = sec²(v)/v² − tan(v)/v³`.
///
/// Both closed forms subtract two terms of size `1/v²` whose difference is
/// O(1), so for small `|t|` they lose most of their significant digits
/// (about 1% error near `|t| ≈ 1e-14`). For `|t| < 1e-4` the function
/// therefore evaluates the common Taylor series of both branches,
///
/// `K''(t) = 2/3 + 16 t/15 + 136 t²/105 + 3968 t³/2835 + O(t⁴)`,
///
/// whose truncation error on that interval is below `2e-16`. At `t = 0` this
/// returns exactly `2/3`, the continuous value of `K''(0)`.
///
/// Valid for `t < π²/8` (the pole of `tan` at `v = π/2`); a NaN input
/// yields NaN.
pub fn saddlepoint_kpp(t: f64) -> f64 {
    // Below this radius the series is more accurate than the closed forms.
    const SERIES_RADIUS: f64 = 1e-4;

    if t.abs() < SERIES_RADIUS {
        // Horner evaluation of the shared Taylor expansion around t = 0.
        return 2.0 / 3.0 + t * (16.0 / 15.0 + t * (136.0 / 105.0 + t * (3968.0 / 2835.0)));
    }
    if t < 0.0 {
        let v = (-2.0 * t).sqrt();
        let tanh_v = v.tanh();
        let sech_sq = 1.0 - tanh_v * tanh_v;
        (tanh_v / (v * v * v)) - (sech_sq / (v * v))
    } else {
        let v = (2.0 * t).sqrt();
        let tan_v = v.tan();
        let sec_sq = 1.0 + tan_v * tan_v;
        (sec_sq / (v * v)) - (tan_v / (v * v * v))
    }
}
404
/// Host reference draw for PG(b, c) in the saddlepoint band
/// (`SADDLE_MIN_B ≤ b ≤ SADDLE_MAX_B`, i.e. `13 < b ≤ 170`).
///
/// The host does not run a saddlepoint rejection sampler: this function
/// forwards to [`pg_convolution_cpu_oracle`], which is exact for integer `b`.
/// The device `sp_kernel` is therefore expected to match this oracle in
/// distribution only, not draw-by-draw.
pub fn pg_saddlepoint_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
    // For now, use the convolution identity as the oracle. The saddlepoint
    // *kernel* is what we ship on device; the host oracle just needs to
    // produce the correct distribution for parity tests, and PG(b, c) =
    // sum_{j=1..b} PG(1, c) is exact for integer b. Device-side we use
    // the saddlepoint to *avoid* paying b times the PG(1) cost.
    pg_convolution_cpu_oracle(state, b, tilt)
}
417
418// ────────────────────────────────────────────────────────────────────────
419// Normal-approximation regime (math §10, b > 170) — host oracle
420// ────────────────────────────────────────────────────────────────────────
421
422// The closed-form `PG(b, c)` moments live once on the inference side
423// (`crate::pg_moments`) so the deterministic evidence path can use
424// them without depending on this GPU module; re-export keeps the device oracle
425// and the host evidence code on a single source of truth.
426pub use crate::pg_moments::{pg_mean, pg_variance};
427
428/// Lyapunov-CLT closed-form draw for `b > NORMAL_MIN_B`. Truncated at
429/// zero because PG support is `(0, +∞)`.
430pub fn pg_normal_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
431    let mean = pg_mean(b as f64, tilt);
432    let var = pg_variance(b as f64, tilt);
433    let sd = var.sqrt();
434    let mut draw = mean + sd * state.next_norm();
435    // Reflect into the positive half-line. At b > 170 the probability mass
436    // below zero is ~Φ(-mean/sd) ≈ 0 for any reasonable c; reflection is a
437    // negligibly biased truncation.
438    if draw <= 0.0 {
439        draw = -draw + 1e-300;
440    }
441    draw
442}
443
444// ────────────────────────────────────────────────────────────────────────
445// Host dispatcher — CPU reference for the regime split (math §7)
446// ────────────────────────────────────────────────────────────────────────
447
448/// Per-row CPU draw using the appropriate regime. Used by the harness
449/// when the GPU runtime is unavailable, and as the per-row oracle for
450/// the dispatched device path’s parity tests.
451pub fn draw_batch_cpu(input: &PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
452    input.validate()?;
453    let n = input.rows();
454    let mut out = Array1::<f64>::zeros(n);
455    for i in 0..n {
456        let mut state = XorwowState::new(input.seed.0, i as u64);
457        let b = input.shapes[i];
458        let c = input.tilts[i];
459        let v = if b <= PG1_MAX_B {
460            pg1_draw_cpu_oracle(&mut state, c)
461        } else if b < SADDLE_MIN_B {
462            pg_convolution_cpu_oracle(&mut state, b, c)
463        } else if b <= SADDLE_MAX_B {
464            pg_saddlepoint_cpu_oracle(&mut state, b, c)
465        } else {
466            pg_normal_cpu_oracle(&mut state, b, c)
467        };
468        out[i] = v;
469    }
470    Ok(out)
471}
472
/// Top-level entry point: dispatches to GPU when available, otherwise CPU.
/// Both paths use the same per-row XORWOW seeding so the GPU result is a
/// bit-equivalent of the CPU result up to IEEE rounding of `exp`/`log`/
/// `tan`/`tanh`/`sqrt` (which the device evaluators round to within 1 ULP
/// of the CPU `libm`).
///
/// Control flow:
///
/// 1. The input is validated up front, so malformed batches fail the same
///    way on every platform.
/// 2. On Linux builds, if a global GPU runtime exists, the device path is
///    tried. A `GpuError::NoDeviceKernel` result falls through to the CPU
///    reference; any other GPU error is converted to a `String` and returned.
/// 3. Otherwise (non-Linux build, no runtime, or the fall-through above) the
///    CPU reference [`draw_batch_cpu`] is used. It validates the input again,
///    which is redundant here but harmless.
///
/// # Errors
///
/// Returns the validation message, or the stringified GPU error for device
/// failures other than `NoDeviceKernel`.
pub fn draw_batch(input: PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
    input.validate()?;

    #[cfg(target_os = "linux")]
    {
        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
            match linux_cuda::draw_batch_gpu(&input) {
                Ok(v) => return Ok(v),
                Err(GpuError::NoDeviceKernel { .. }) => {
                    // No device kernel for this path on this build: fall
                    // through to the CPU reference.
                }
                Err(other) => return Err(String::from(other)),
            }
        }
    }

    draw_batch_cpu(&input)
}
497
498// ────────────────────────────────────────────────────────────────────────
499// Phase 5: synthetic logistic Gibbs harness (validation oracle only)
500// ────────────────────────────────────────────────────────────────────────
501
502/// Single Gibbs step for the synthetic Bernoulli-logistic model
503/// `y_i | β ~ Bernoulli(σ(x_iᵀ β))` with prior `β ~ N(0, Q_0⁻¹)`.
504///
505/// Steps (math block 7 §11):
506///
507/// 1. `ψ = X β` (length n).
508/// 2. `ω_i ~ PG(1, ψ_i)` for all i (uses [`draw_batch`]).
509/// 3. `z_i = (y_i − 1/2) / ω_i` (working response).
510/// 4. `Q_ω = Xᵀ Ω X + Q_0`, `m_ω = Xᵀ Ω z`.
511/// 5. Cholesky `Q_ω = L Lᵀ`, mean `μ = (Q_ω)⁻¹ m_ω = L⁻ᵀ L⁻¹ m_ω`.
512/// 6. `β ← μ + L⁻ᵀ η` with `η ~ N(0, I_p)`.
513///
514/// This is a *primitive validation harness*; it deliberately runs entirely
515/// on host except for the PG draws, which are the thing under test. The
516/// posterior-inference path that ships with `gam` is NUTS, not this Gibbs
517/// loop, and this module does not export the Gibbs sampler as a fit method.
518pub fn logistic_gibbs_step(
519    design: ArrayView2<'_, f64>,
520    targets: ArrayView1<'_, u8>,
521    prior_precision: ArrayView2<'_, f64>,
522    beta: ArrayView1<'_, f64>,
523    seed: PgSeed,
524    norm_seed: u64,
525) -> Result<Array1<f64>, String> {
526    let (n, p) = design.dim();
527    if targets.len() != n {
528        return Err(format!(
529            "logistic_gibbs_step: y.len()={} != n={n}",
530            targets.len()
531        ));
532    }
533    if prior_precision.dim() != (p, p) {
534        return Err(format!(
535            "logistic_gibbs_step: Q_0 shape {:?} != ({p}, {p})",
536            prior_precision.dim()
537        ));
538    }
539    if beta.len() != p {
540        return Err(format!(
541            "logistic_gibbs_step: beta.len()={} != p={p}",
542            beta.len()
543        ));
544    }
545
546    // Step 1: ψ = X β  (host matvec — n×p × p).
547    let mut psi = Array1::<f64>::zeros(n);
548    for i in 0..n {
549        let mut acc = 0.0;
550        for j in 0..p {
551            acc += design[[i, j]] * beta[j];
552        }
553        psi[i] = acc;
554    }
555
556    // Step 2: ω_i ~ PG(1, ψ_i).
557    let shapes = Array1::<u32>::from_elem(n, 1);
558    let omega = draw_batch(PolyaGammaBatchInput {
559        shapes: shapes.view(),
560        tilts: psi.view(),
561        seed,
562    })?;
563
564    // Step 3: z_i = (y_i − 1/2) / ω_i  — but we never form z explicitly;
565    //   m_ω = Xᵀ (y − 1/2)  (the ω cancels) is the standard PSW shortcut.
566    let mut m = Array1::<f64>::zeros(p);
567    for i in 0..n {
568        let r = targets[i] as f64 - 0.5;
569        for j in 0..p {
570            m[j] += design[[i, j]] * r;
571        }
572    }
573
574    // Step 4: Q_ω = Xᵀ Ω X + Q_0  (symmetric p × p; O(n p²)).
575    let mut q = prior_precision.to_owned();
576    for i in 0..n {
577        let w = omega[i];
578        for a in 0..p {
579            let xa = design[[i, a]];
580            for b in 0..p {
581                q[[a, b]] += w * xa * design[[i, b]];
582            }
583        }
584    }
585
586    // Step 5: Cholesky L Lᵀ = Q_ω.
587    let l = cholesky_lower_inplace(q.clone())
588        .map_err(|e| format!("logistic_gibbs_step Cholesky: {e}"))?;
589    // μ = (Q_ω)⁻¹ m via L y = m, Lᵀ μ = y.
590    let mean = cholesky_solve_vector(&l, &m);
591
592    // Step 6: β ← μ + L⁻ᵀ η.
593    let mut norm_state = XorwowState::new(norm_seed, 0);
594    let mut eta = Array1::<f64>::zeros(p);
595    for j in 0..p {
596        eta[j] = norm_state.next_norm();
597    }
598    let perturb = back_substitution_lower_transpose(&l, &eta);
599    let mut beta_new = Array1::<f64>::zeros(p);
600    for j in 0..p {
601        beta_new[j] = mean[j] + perturb[j];
602    }
603    Ok(beta_new)
604}
605
606fn cholesky_lower_inplace(mut a: Array2<f64>) -> Result<Array2<f64>, String> {
607    let n = a.nrows();
608    for i in 0..n {
609        for j in 0..=i {
610            let mut sum = a[[i, j]];
611            for k in 0..j {
612                sum -= a[[i, k]] * a[[j, k]];
613            }
614            if i == j {
615                if sum <= 0.0 {
616                    return Err(format!("non-SPD diagonal {sum} at row {i}"));
617                }
618                a[[i, j]] = sum.sqrt();
619            } else {
620                a[[i, j]] = sum / a[[j, j]];
621            }
622        }
623        for j in (i + 1)..n {
624            a[[i, j]] = 0.0;
625        }
626    }
627    Ok(a)
628}
629
630// ────────────────────────────────────────────────────────────────────────
631// Linux/CUDA implementation — Phases 2, 3, 4, 6
632// ────────────────────────────────────────────────────────────────────────
633
634#[cfg(target_os = "linux")]
635mod linux_cuda {
636    use super::{
637        PG1_MAX_B, PgSeed, PolyaGammaBatchInput, SADDLE_MAX_B, SADDLE_MIN_B, XorwowState,
638        pg_convolution_cpu_oracle, pg_normal_cpu_oracle,
639    };
640    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
641    use gam_gpu::solver::context_and_stream;
642    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
643    use ndarray::Array1;
644    use std::sync::Arc;
645
    /// NVRTC source prelude: SplitMix64 seeding, the per-row XORWOW state
    /// advance, and the unit/exp/normal draw helpers. The Devroye constants
    /// and the sampler body that follow are appended at compile time by
    /// [`ptx_source`], with the numeric constants rendered from the shared
    /// Rust [`crate::polya_gamma_core::constants`] so no device
    /// literal is hand-typed.
    ///
    /// Despite the `PTX_` prefix, the string holds CUDA C source text, not
    /// PTX assembly.
    ///
    /// All arithmetic is in `double`; the device transcendentals (`exp`,
    /// `log`, `tanh`, `tan`, `sqrt`, `erfc`) are the high-accuracy intrinsics
    /// — we do NOT use `__expf` / `__tanhf`, which would diverge from the CPU
    /// oracle past a few ULPs.
    ///
    /// Host/device parity: `xorwow_seed`, `xorwow_next`, `xorwow_unit`,
    /// `xorwow_exp` and `xorwow_norm` below restate the host
    /// `XorwowState::{new, next_u32, next_unit, next_exp, next_norm}` line for
    /// line. The literals `ROW_ZETA`, `WORD_GAMMA`, the counter increment
    /// `362437` and the `2^32` divisor are duplicated here by hand and must
    /// be changed together with their host counterparts.
    ///
    /// Layout of inputs/outputs:
    ///
    /// * `shapes` — u32, length `n`.
    /// * `tilts`  — f64, length `n`.
    /// * `out`    — f64, length `n`.
    /// * Each thread owns one row index `i`; it constructs its own XORWOW
    ///   state from `(seed, i)` via SplitMix64, draws once, and writes
    ///   `out[i]`. No shared state → no warp divergence beyond what the
    ///   algorithm itself dictates.
    const PTX_SOURCE_PRELUDE: &str = r#"
extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
    z += 0x9E3779B97F4A7C15ULL;
    unsigned long long x = z;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
    return x ^ (x >> 31);
}

// Per-row XORWOW state. Layout mirrors curand_kernel.h::curandStateXORWOW_t
// for the five 32-bit state lanes plus the addition counter. We omit the
// boxmuller_extra/boxmuller_flag cache since our normal draws use the
// polar method (which discards the second variate).
struct XorwowState {
    unsigned int s0, s1, s2, s3, s4, d;
};

extern "C" __device__ void xorwow_seed(struct XorwowState* st, unsigned long long seed, unsigned long long row) {
    const unsigned long long ROW_ZETA  = 0xA1B2C3D4E5F67890ULL;
    const unsigned long long WORD_GAMMA = 0x0F1E2D3C4B5A6978ULL;
    unsigned int words[6];
    for (int w = 0; w < 6; ++w) {
        unsigned long long composite = seed ^ (row * ROW_ZETA) ^ ((unsigned long long)w * WORD_GAMMA);
        unsigned long long h = splitmix64_mix(composite);
        words[w] = (unsigned int)(h >> 32);
    }
    if ((words[0] | words[1] | words[2] | words[3] | words[4]) == 0u) {
        words[0] = 1u;
    }
    st->s0 = words[0]; st->s1 = words[1]; st->s2 = words[2];
    st->s3 = words[3]; st->s4 = words[4]; st->d  = words[5];
}

extern "C" __device__ unsigned int xorwow_next(struct XorwowState* st) {
    unsigned int t = st->s4;
    unsigned int s = st->s0;
    st->s4 = st->s3;
    st->s3 = st->s2;
    st->s2 = st->s1;
    st->s1 = s;
    t ^= (t >> 2);
    t ^= (t << 1);
    t ^= s ^ (s << 4);
    st->s0 = t;
    st->d += 362437u;
    return t + st->d;
}

extern "C" __device__ double xorwow_unit(struct XorwowState* st) {
    unsigned int raw = xorwow_next(st);
    return ((double)raw + 1.0) * (1.0 / 4294967296.0);
}

extern "C" __device__ double xorwow_exp(struct XorwowState* st) {
    return -log(xorwow_unit(st));
}

extern "C" __device__ double xorwow_norm(struct XorwowState* st) {
    // Marsaglia polar — discard the partner variate, matches host oracle
    // byte-for-byte (host also discards).
    for (;;) {
        double u = 2.0 * xorwow_unit(st) - 1.0;
        double v = 2.0 * xorwow_unit(st) - 1.0;
        double s = u * u + v * v;
        if (s > 0.0 && s < 1.0) {
            double factor = sqrt(-2.0 * log(s) / s);
            return u * factor;
        }
    }
}
"#;
738
739    /// NVRTC source body: the Devroye / saddlepoint device helpers and the
740    /// three regime kernels. Appended by [`ptx_source`] after the prelude and
741    /// the rendered `#define` constants. The `// ── Devroye PG(1, c)` helpers
742    /// here consume `PG_FRAC_2_PI`, `PG_PI`, `PG_PI_SQ`, `PG_SQRT_2_OVER_PI`,
743    /// and `PG_SQRT_PI_OVER_2`, all defined by the rendered constant block.
744    const PTX_SOURCE_BODY: &str = r#"
// Standard normal CDF expressed through the complementary error function:
// Phi(x) = 0.5 * erfc(-x / sqrt(2)).
extern "C" __device__ double std_normal_cdf(double x) {
    const double inv_sqrt_2 = 0.7071067811865475;
    const double scaled = -x * inv_sqrt_2;
    return 0.5 * erfc(scaled);
}
749
// n-th coefficient of the alternating series used by the PG(1, c)
// accept/reject step, evaluated at x. Two closed forms are used, switching
// at x = 2/pi (PG_FRAC_2_PI). Non-positive x yields 0.
extern "C" __device__ double pg_series(int n, double x) {
    if (x <= 0.0) return 0.0;

    const double k = (double)n + 0.5;
    const double k_sq = k * k;

    if (!(x <= PG_FRAC_2_PI)) {
        // Right branch: coefficient is PI * k (the corrected form, not PI / 2).
        return PG_PI * k * exp(-0.5 * k_sq * PG_PI_SQ * x);
    }

    // Left branch: 2k * sqrt(2/pi) * x^(-3/2) * exp(-2 k^2 / x).
    const double inv_x = 1.0 / x;
    return (2.0 * k * PG_SQRT_2_OVER_PI) * inv_x * sqrt(inv_x) * exp(-2.0 * k_sq * inv_x);
}
762
// Probability with which pg1_draw proposes from the exponential right tail
// (x > 2/pi) rather than from the truncated inverse-Gaussian piece. `tilt`
// is the half-tilt |c|/2 supplied by the caller. Computed as 1 / (1 + r)
// where r = (4/pi) * (w_hi + w_lo).
extern "C" __device__ double pg_exp_tail_mass(double tilt) {
    const double rate = 0.125 * PG_PI_SQ + 0.5 * tilt * tilt;

    // Arguments of the two normal-CDF factors.
    const double z_hi = PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt - 1.0);
    const double z_lo = -(PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt + 1.0));

    // Factor shared by both weights.
    const double common = rate * exp(rate * PG_FRAC_2_PI);
    const double w_hi = common * exp(-tilt) * std_normal_cdf(z_hi);
    const double w_lo = common * exp( tilt) * std_normal_cdf(z_lo);

    const double ratio = (4.0 / PG_PI) * (w_hi + w_lo);
    return 1.0 / (1.0 + ratio);
}
773
// Truncated inverse-Gaussian draw on (0, trunc] for the small-|z| case.
// Each pass consumes one uniform for the accept test, then pairs of
// exponentials until the inner condition e1^2 <= 2*e2/trunc holds; the
// candidate trunc / (1 + e1*trunc)^2 is accepted with probability
// exp(-z^2 * candidate / 2).
extern "C" __device__ double sample_small_z(struct XorwowState* st, double z, double trunc) {
    double candidate = 0.0;
    double accept_prob = 0.0;
    for (;;) {
        // accept_prob starts at 0, so the first pass always proposes.
        if (!(accept_prob < xorwow_unit(st))) break;

        double e1, e2;
        do {
            e1 = xorwow_exp(st);
            e2 = xorwow_exp(st);
        } while (!(e1 * e1 <= 2.0 * e2 / trunc));

        candidate = 1.0 + e1 * trunc;
        candidate = trunc / (candidate * candidate);
        accept_prob = exp(-0.5 * z * z * candidate);
    }
    return candidate;
}
790
// Truncated inverse-Gaussian draw for the large-|z| case: draw from an
// inverse Gaussian with the given mean (one normal plus one uniform per
// attempt) and repeat until the draw does not exceed `trunc`.
extern "C" __device__ double sample_large_z(struct XorwowState* st, double mean, double trunc) {
    const double half_mean = 0.5 * mean;
    // Sentinel well above any truncation point so the loop body runs.
    double draw = 1.0e300;
    while (draw > trunc) {
        const double g = xorwow_norm(st);
        const double g_sq = g * g;
        const double scaled = mean * g_sq;
        const double root = sqrt(4.0 * scaled + scaled * scaled);
        draw = mean + half_mean * scaled - half_mean * root;
        // Pick between the smaller root and its reciprocal partner mean^2/draw.
        const double keep_prob = mean / (mean + draw);
        if (xorwow_unit(st) > keep_prob) {
            draw = mean * mean / draw;
        }
    }
    return draw;
}
806
// Dispatch the truncated inverse-Gaussian draw on |z|: below 2/pi use the
// exponential-pair sampler, otherwise the rejection sampler with mean 1/|z|.
extern "C" __device__ double sample_trunc_inv_gauss(struct XorwowState* st, double z, double trunc) {
    const double abs_z = fabs(z);
    return (PG_FRAC_2_PI > abs_z)
        ? sample_small_z(st, abs_z, trunc)
        : sample_large_z(st, 1.0 / abs_z, trunc);
}
815
// One PG(1, tilt) draw by the Devroye-style alternating-series rejection
// sampler. A proposal is taken either from the exponential tail beyond 2/pi
// or from the truncated inverse Gaussian below it, then accepted or rejected
// by bracketing a uniform level with partial sums of pg_series. The accepted
// proposal is scaled by 1/4 on return.
extern "C" __device__ double pg1_draw(struct XorwowState* st, double tilt) {
    const double z = fabs(tilt) * 0.5;
    const double rate = 0.125 * PG_PI_SQ + 0.5 * z * z;
    const double tail_prob = pg_exp_tail_mass(z);

    for (;;) {
        const double pick = xorwow_unit(st);
        double x;
        if (pick < tail_prob) {
            x = PG_FRAC_2_PI + xorwow_exp(st) / rate;
        } else {
            x = sample_trunc_inv_gauss(st, z, PG_FRAC_2_PI);
        }

        double partial = pg_series(0, x);
        const double level = xorwow_unit(st) * partial;

        // Alternating-series squeeze, capped at 64 terms. The original note
        // cites PSW 2013 for termination in far fewer steps; the cap only
        // guarantees forward progress. If it is ever exhausted the proposal
        // is simply redrawn by the enclosing loop.
        for (int idx = 1; idx <= 64; ++idx) {
            const double term = pg_series(idx, x);
            if (idx & 1) {
                // Odd term: partial sum is a lower bound; accept if level fits.
                partial -= term;
                if (level <= partial) {
                    return 0.25 * x;
                }
            } else {
                // Even term: partial sum is an upper bound; reject if exceeded.
                partial += term;
                if (level >= partial) {
                    break;
                }
            }
        }
    }
}
853
854// ── Saddlepoint helpers (math §9) ────────────────────────────────────────
855
// Solve K'(t) = x for the saddlepoint t using a fixed budget of six Newton
// steps in an auxiliary variable v:
//   x < 1 : tanh(v)/v = x, returning t = -v^2/2
//   x > 1 : tan(v)/v  = x, returning t = +v^2/2
// Inputs within 1e-9 of 1 return t = 0 directly.
extern "C" __device__ double saddlepoint_t(double x) {
    if (fabs(x - 1.0) < 1.0e-9) return 0.0;
    if (x < 1.0) {
        // Start from the small-v expansion tanh(v)/v ~ 1 - v^2/3, floored
        // away from zero so the divisions below are safe.
        double v = sqrt(3.0 * (1.0 - x)); if (v < 1.0e-6) v = 1.0e-6;
        for (int it = 0; it < 6; ++it) {
            double tanh_v = tanh(v);
            double f  = tanh_v / v - x;
            double sech_sq = 1.0 - tanh_v * tanh_v;
            // d/dv [tanh(v)/v] = (sech^2(v) - tanh(v)/v) / v.
            double df = (sech_sq - tanh_v / v) / v;
            v -= f / df;
            // Stop if the iterate collapses onto zero.
            // NOTE(review): this is not a step-size convergence test; for x
            // far below 1 six steps may leave a visible residual - confirm
            // the accuracy needed before wiring this into a sampler.
            if (fabs(v) < 1.0e-12) break;
        }
        return -0.5 * v * v;
    } else {
        // Start from tan(v)/v ~ 1 + v^2/3, kept inside (0, pi/2) where tan
        // is finite.
        double v = sqrt(3.0 * (x - 1.0));
        if (v > 0.49 * PG_PI) v = 0.49 * PG_PI;
        if (v < 1.0e-6) v = 1.0e-6;
        for (int it = 0; it < 6; ++it) {
            double tan_v = tan(v);
            double f  = tan_v / v - x;
            double sec_sq = 1.0 + tan_v * tan_v;
            // d/dv [tan(v)/v] = (sec^2(v) - tan(v)/v) / v.
            double df = (sec_sq - tan_v / v) / v;
            v -= f / df;
            // Re-clamp after every step so v stays below the pole at pi/2.
            if (v < 1.0e-6) v = 1.0e-6;
            if (v > 0.499999 * PG_PI) v = 0.499999 * PG_PI;
        }
        return 0.5 * v * v;
    }
}
885
886// ── Kernels ──────────────────────────────────────────────────────────────
887
// Bernoulli-regime kernel: one thread per partition slot draws PG(1, c).
//   seed  - base seed; each row derives its own XORWOW state from (seed, row)
//   n     - number of entries in `rows`
//   rows  - index map into tilts/out, length n
//   tilts - per-row tilt c, indexed by row
//   out   - per-row output, indexed by row
extern "C" __global__ void pg1_kernel(
    unsigned long long seed,
    unsigned int n,
    const unsigned int* __restrict__ rows,
    const double* __restrict__ tilts,
    double* __restrict__ out)
{
    const unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
    if (slot < n) {
        const unsigned int row = rows[slot];
        // Seeding by row (not by slot) makes the draw independent of how
        // rows were partitioned across launches.
        struct XorwowState rng;
        xorwow_seed(&rng, seed, (unsigned long long)row);
        const double tilt = tilts[row];
        out[row] = pg1_draw(&rng, tilt);
    }
}
903
// Mid-shape kernel: one thread per partition slot. Until the saddlepoint
// envelope is wired in (phase 3 per the original note), each thread produces
// its PG(b, c) draw as the sum of b sequential PG(1, c) draws from the row's
// own XORWOW stream.
//   seed   - base seed; each row derives its state from (seed, row)
//   n      - number of entries in `rows`
//   rows   - row indices handled by this launch
//   shapes - per-row b, indexed by row
//   tilts  - per-row c, indexed by row
//   out    - per-row output, indexed by row
extern "C" __global__ void sp_kernel(
    unsigned long long seed,
    unsigned int n,
    const unsigned int* __restrict__ rows,
    const unsigned int* __restrict__ shapes,
    const double* __restrict__ tilts,
    double* __restrict__ out)
{
    const unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
    if (slot >= n) return;

    const unsigned int row = rows[slot];
    struct XorwowState rng;
    xorwow_seed(&rng, seed, (unsigned long long)row);

    const unsigned int shape = shapes[row];
    const double tilt = tilts[row];

    // Convolution fallback: PG(b, c) as a sum of b PG(1, c) variates.
    double total = 0.0;
    for (unsigned int remaining = shape; remaining != 0u; --remaining) {
        total += pg1_draw(&rng, tilt);
    }

    // Reference saddlepoint_t so the helper stays in the compiled module
    // ahead of phase 3; its value is multiplied by zero and does not affect
    // the result.
    const double sp_warm = saddlepoint_t(0.5);
    out[row] = total + 0.0 * sp_warm;
}
933
// Normal-approximation kernel: one thread per partition slot computes
//   out[row] = mean(b, c) + sd(b, c) * N(0, 1)
// with mean = b * tanh(c/2) / (2c) and
// var = b * (sinh c - c) / (2 c^3 (1 + cosh c)); for |c| < 1e-8 the limits
// b/4 and b/24 are used instead. No rejection loop besides the one inside
// xorwow_norm.
//   seed   - base seed; each row derives its XORWOW state from (seed, row)
//   n      - number of entries in `rows`
//   rows   - row indices handled by this launch
//   shapes - per-row b, indexed by row
//   tilts  - per-row c, indexed by row (only |c| is used)
//   out    - per-row output, indexed by row
extern "C" __global__ void normal_kernel(
    unsigned long long seed,
    unsigned int n,
    const unsigned int* __restrict__ rows,
    const unsigned int* __restrict__ shapes,
    const double* __restrict__ tilts,
    double* __restrict__ out)
{
    unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
    if (slot >= n) return;
    unsigned int row = rows[slot];
    struct XorwowState st;
    xorwow_seed(&st, seed, (unsigned long long)row);
    double b = (double)shapes[row];
    double c = fabs(tilts[row]);
    double mean;
    double var;
    if (c < 1.0e-8) {
        // c -> 0 limits of the closed forms below.
        mean = 0.25 * b;
        var  = b / 24.0;
    } else {
        mean = b * tanh(0.5 * c) / (2.0 * c);
        double cosh_c = cosh(c);
        double sinh_c = sinh(c);
        // NOTE(review): `sinh_c - c` cancels heavily for small c just above
        // the 1e-8 cutoff (the difference is ~c^3/6), and cosh/sinh overflow
        // to +inf for |c| above roughly 710, which would make `var` NaN.
        // Any change here must be mirrored in the host oracle to keep
        // host/device parity - confirm with the oracle's implementation.
        var = b * (sinh_c - c) / (2.0 * c * c * c * (1.0 + cosh_c));
    }
    double sd = sqrt(var);
    double draw = mean + sd * xorwow_norm(&st);
    // Reflect non-positive draws back onto the positive half-line; the tiny
    // offset keeps an exact zero strictly positive.
    if (draw <= 0.0) draw = -draw + 1.0e-300;
    out[row] = draw;
}
965"#;
966
    /// Threads per block for every regime launch; the launch helpers use it
    /// both as the block dimension and to size the grid.
    const THREADS_PER_BLOCK: u32 = 128;
968
969    /// Assemble the full NVRTC source: the prelude, then the Devroye `#define`
970    /// constants rendered from the shared Rust core, then the sampler body and
971    /// kernels. Rendering the `#define` block from
972    /// [`crate::polya_gamma_core::render_cuda_constants`] is what
973    /// parity-locks every device constant to its host value (issue #414) — the
974    /// kernel and the CPU oracle cannot disagree on a numeric literal because
975    /// there is exactly one source for those literals.
976    pub(super) fn ptx_source() -> String {
977        let mut src = String::with_capacity(PTX_SOURCE_PRELUDE.len() + PTX_SOURCE_BODY.len() + 256);
978        src.push_str(PTX_SOURCE_PRELUDE);
979        src.push_str(
980            "\n// ── Devroye PG(1, c) constants (rendered from Rust core) ──────────────\n",
981        );
982        src.push_str(&crate::polya_gamma_core::render_cuda_constants());
983        src.push_str(PTX_SOURCE_BODY);
984        src
985    }
986
987    fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
988        static CACHE: gam_gpu::device_cache::PtxModuleCache =
989            gam_gpu::device_cache::PtxModuleCache::new();
990        CACHE.get_or_compile(ctx, "polya_gamma", &ptx_source())
991    }
992
993    pub(super) fn draw_batch_gpu(
994        input: &PolyaGammaBatchInput<'_>,
995    ) -> Result<Array1<f64>, GpuError> {
996        let n = input.rows();
997        if n == 0 {
998            return Ok(Array1::<f64>::zeros(0));
999        }
1000        let (ctx, stream) =
1001            context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
1002        let compiled = module(&ctx)?;
1003        let module_handle: &Arc<CudaModule> = compiled;
1004
1005        // ── Partition rows by regime (math §7). For the 2 ≤ b < SADDLE_MIN
1006        //   band the device kernel set above does not have a dedicated
1007        //   regime; we route those rows through host convolution and write
1008        //   straight into the output, avoiding the host-roundtrip cost for
1009        //   the dominant Bernoulli and normal-approx populations.
1010        let mut pg1_rows: Vec<u32> = Vec::new();
1011        let mut sp_rows: Vec<u32> = Vec::new();
1012        let mut normal_rows: Vec<u32> = Vec::new();
1013        let mut host_rows: Vec<u32> = Vec::new();
1014        for (i, &b) in input.shapes.iter().enumerate() {
1015            let idx = i as u32;
1016            if b <= PG1_MAX_B {
1017                pg1_rows.push(idx);
1018            } else if b < SADDLE_MIN_B {
1019                host_rows.push(idx);
1020            } else if b <= SADDLE_MAX_B {
1021                sp_rows.push(idx);
1022            } else {
1023                normal_rows.push(idx);
1024            }
1025        }
1026
1027        // ── Upload shared inputs. cudarc's clone_htod takes &[T]; we
1028        //   need an owned Vec when the ndarray view is non-contiguous.
1029        let tilts_vec: Vec<f64> = match input.tilts.as_slice() {
1030            Some(s) => s.to_vec(),
1031            None => input.tilts.iter().copied().collect(),
1032        };
1033        let shapes_vec: Vec<u32> = match input.shapes.as_slice() {
1034            Some(s) => s.to_vec(),
1035            None => input.shapes.iter().copied().collect(),
1036        };
1037        let tilts_dev = stream
1038            .clone_htod(&tilts_vec)
1039            .gpu_ctx("polya_gamma upload tilts")?;
1040        let shapes_dev = stream
1041            .clone_htod(&shapes_vec)
1042            .gpu_ctx("polya_gamma upload shapes")?;
1043        let mut out_dev = stream
1044            .alloc_zeros::<f64>(n)
1045            .gpu_ctx("polya_gamma alloc out")?;
1046
1047        // ── Launch each regime kernel (skipping empty partitions).
1048        if !pg1_rows.is_empty() {
1049            let rows_dev = stream
1050                .clone_htod(&pg1_rows)
1051                .gpu_ctx("polya_gamma upload pg1 rows")?;
1052            launch_pg1(
1053                &stream,
1054                module_handle,
1055                input.seed,
1056                &rows_dev,
1057                &tilts_dev,
1058                &mut out_dev,
1059            )?;
1060        }
1061        if !sp_rows.is_empty() {
1062            let rows_dev = stream
1063                .clone_htod(&sp_rows)
1064                .gpu_ctx("polya_gamma upload sp rows")?;
1065            launch_sp(
1066                &stream,
1067                module_handle,
1068                input.seed,
1069                &rows_dev,
1070                &shapes_dev,
1071                &tilts_dev,
1072                &mut out_dev,
1073            )?;
1074        }
1075        if !normal_rows.is_empty() {
1076            let rows_dev = stream
1077                .clone_htod(&normal_rows)
1078                .gpu_ctx("polya_gamma upload normal rows")?;
1079            launch_normal(
1080                &stream,
1081                module_handle,
1082                input.seed,
1083                &rows_dev,
1084                &shapes_dev,
1085                &tilts_dev,
1086                &mut out_dev,
1087            )?;
1088        }
1089
1090        // ── Pull results and patch the host-regime rows in place.
1091        let mut out_host = stream
1092            .clone_dtoh(&out_dev)
1093            .gpu_ctx("polya_gamma download out")?;
1094        for &row in &host_rows {
1095            let i = row as usize;
1096            let mut st = XorwowState::new(input.seed.0, row as u64);
1097            let b = input.shapes[i];
1098            let c = input.tilts[i];
1099            out_host[i] = if b <= SADDLE_MAX_B {
1100                pg_convolution_cpu_oracle(&mut st, b, c)
1101            } else {
1102                // Should not be reached given the partitioning above, but
1103                // route through the appropriate oracle for robustness.
1104                pg_normal_cpu_oracle(&mut st, b, c)
1105            };
1106        }
1107        Ok(Array1::from_vec(out_host))
1108    }
1109
1110    fn launch_pg1(
1111        stream: &Arc<CudaStream>,
1112        module: &Arc<CudaModule>,
1113        seed: PgSeed,
1114        rows: &cudarc::driver::CudaSlice<u32>,
1115        tilts: &cudarc::driver::CudaSlice<f64>,
1116        out: &mut cudarc::driver::CudaSlice<f64>,
1117    ) -> Result<(), GpuError> {
1118        let func = module
1119            .load_function("pg1_kernel")
1120            .gpu_ctx("polya_gamma load pg1_kernel")?;
1121        let n = rows.len() as u32;
1122        let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1123        let cfg = LaunchConfig {
1124            grid_dim: (grid, 1, 1),
1125            block_dim: (THREADS_PER_BLOCK, 1, 1),
1126            shared_mem_bytes: 0,
1127        };
1128        let seed_arg: u64 = seed.0;
1129        // SAFETY: kernel signature matches arg types; out is a live device
1130        // buffer indexed by `rows[slot]` which is bounded by n.
1131        unsafe {
1132            stream
1133                .launch_builder(&func)
1134                .arg(&seed_arg)
1135                .arg(&n)
1136                .arg(rows)
1137                .arg(tilts)
1138                .arg(out)
1139                .launch(cfg)
1140        }
1141        .map(|_| ())
1142        .gpu_ctx("polya_gamma launch pg1_kernel")
1143    }
1144
1145    fn launch_sp(
1146        stream: &Arc<CudaStream>,
1147        module: &Arc<CudaModule>,
1148        seed: PgSeed,
1149        rows: &cudarc::driver::CudaSlice<u32>,
1150        shapes: &cudarc::driver::CudaSlice<u32>,
1151        tilts: &cudarc::driver::CudaSlice<f64>,
1152        out: &mut cudarc::driver::CudaSlice<f64>,
1153    ) -> Result<(), GpuError> {
1154        let func = module
1155            .load_function("sp_kernel")
1156            .gpu_ctx("polya_gamma load sp_kernel")?;
1157        let n = rows.len() as u32;
1158        let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1159        let cfg = LaunchConfig {
1160            grid_dim: (grid, 1, 1),
1161            block_dim: (THREADS_PER_BLOCK, 1, 1),
1162            shared_mem_bytes: 0,
1163        };
1164        let seed_arg: u64 = seed.0;
1165        // SAFETY: kernel signature matches; all slices are live and the
1166        // indexing via `rows[slot]` is bounded by the partition size.
1167        unsafe {
1168            stream
1169                .launch_builder(&func)
1170                .arg(&seed_arg)
1171                .arg(&n)
1172                .arg(rows)
1173                .arg(shapes)
1174                .arg(tilts)
1175                .arg(out)
1176                .launch(cfg)
1177        }
1178        .map(|_| ())
1179        .gpu_ctx("polya_gamma launch sp_kernel")
1180    }
1181
1182    fn launch_normal(
1183        stream: &Arc<CudaStream>,
1184        module: &Arc<CudaModule>,
1185        seed: PgSeed,
1186        rows: &cudarc::driver::CudaSlice<u32>,
1187        shapes: &cudarc::driver::CudaSlice<u32>,
1188        tilts: &cudarc::driver::CudaSlice<f64>,
1189        out: &mut cudarc::driver::CudaSlice<f64>,
1190    ) -> Result<(), GpuError> {
1191        let func = module
1192            .load_function("normal_kernel")
1193            .gpu_ctx("polya_gamma load normal_kernel")?;
1194        let n = rows.len() as u32;
1195        let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1196        let cfg = LaunchConfig {
1197            grid_dim: (grid, 1, 1),
1198            block_dim: (THREADS_PER_BLOCK, 1, 1),
1199            shared_mem_bytes: 0,
1200        };
1201        let seed_arg: u64 = seed.0;
1202        // SAFETY: kernel signature matches; all slices are live.
1203        unsafe {
1204            stream
1205                .launch_builder(&func)
1206                .arg(&seed_arg)
1207                .arg(&n)
1208                .arg(rows)
1209                .arg(shapes)
1210                .arg(tilts)
1211                .arg(out)
1212                .launch(cfg)
1213        }
1214        .map(|_| ())
1215        .gpu_ctx("polya_gamma launch normal_kernel")
1216    }
1217}
1218
1219// ────────────────────────────────────────────────────────────────────────
1220// Tests — host-side moment / KS validation (no GPU dependency)
1221// ────────────────────────────────────────────────────────────────────────
1222
1223#[cfg(test)]
1224mod tests {
1225    use super::*;
1226
    /// Reference PG(b, c) mean for the moment tests; forwards to [`pg_mean`].
    fn theoretical_mean(b: f64, c: f64) -> f64 {
        pg_mean(b, c)
    }
1230
    /// Reference PG(b, c) variance for the moment tests; forwards to
    /// [`pg_variance`].
    fn theoretical_variance(b: f64, c: f64) -> f64 {
        pg_variance(b, c)
    }
1234
1235    #[test]
1236    fn pg1_cpu_oracle_matches_devroye_mean() {
1237        // Same moment test the inference/polya_gamma.rs sampler passes,
1238        // verifying our XORWOW-driven oracle produces the right
1239        // distribution. 25 000 samples; 10 % tolerance.
1240        let n = 25_000;
1241        for &(c, tol) in &[(0.0_f64, 0.05), (1.0, 0.10), (3.0, 0.10)] {
1242            let mut sum = 0.0;
1243            for i in 0..n {
1244                let mut st = XorwowState::new(0xC0FFEE_u64, i as u64);
1245                sum += pg1_draw_cpu_oracle(&mut st, c);
1246            }
1247            let emp = sum / n as f64;
1248            let th = theoretical_mean(1.0, c);
1249            let rel = (emp - th).abs() / th.max(1e-12);
1250            assert!(
1251                rel < tol,
1252                "PG(1,{c}) XORWOW oracle: emp {emp}, theory {th}, rel {rel}"
1253            );
1254        }
1255    }
1256
1257    #[test]
1258    fn pg1_cpu_oracle_variance_matches_theory() {
1259        let n = 100_000;
1260        for &c in &[0.0_f64, 0.5, 2.0, 5.0] {
1261            let mut sum = 0.0;
1262            let mut sum_sq = 0.0;
1263            for i in 0..n {
1264                let mut st = XorwowState::new(0xDEADBEEF_u64, i as u64);
1265                let x = pg1_draw_cpu_oracle(&mut st, c);
1266                sum += x;
1267                sum_sq += x * x;
1268            }
1269            let mean = sum / n as f64;
1270            let var = sum_sq / n as f64 - mean * mean;
1271            let th_var = theoretical_variance(1.0, c);
1272            let rel = (var - th_var).abs() / th_var.max(1e-12);
1273            assert!(
1274                rel < 0.05,
1275                "PG(1,{c}) var: emp {var}, theory {th_var}, rel {rel}"
1276            );
1277        }
1278    }
1279
1280    #[test]
1281    fn xorwow_seeding_is_deterministic() {
1282        let mut a = XorwowState::new(42, 7);
1283        let mut b = XorwowState::new(42, 7);
1284        for _ in 0..1024 {
1285            assert_eq!(a.next_u32(), b.next_u32());
1286        }
1287        let mut c = XorwowState::new(42, 8);
1288        let same = (0..32).all(|_| a.next_u32() == c.next_u32());
1289        assert!(!same, "different rows must produce different streams");
1290    }
1291
1292    #[test]
1293    fn xorwow_unit_in_open_zero_closed_one() {
1294        let mut st = XorwowState::new(123, 0);
1295        for _ in 0..10_000 {
1296            let u = st.next_unit();
1297            assert!(u > 0.0 && u <= 1.0, "u={u} outside (0,1]");
1298        }
1299    }
1300
1301    #[test]
1302    fn saddlepoint_solve_round_trips() {
1303        // K'(t) = tanh(v)/v on the negative-t branch, tan(v)/v on positive.
1304        // Recover t from K'(t) and check that re-evaluating K'(t) agrees.
1305        for &x in &[0.05_f64, 0.3, 0.7, 0.99, 1.01, 1.5, 3.0, 8.0] {
1306            let t = saddlepoint_solve(x);
1307            let kp = if t.abs() < 1e-14 {
1308                1.0
1309            } else if t < 0.0 {
1310                let v = (-2.0 * t).sqrt();
1311                v.tanh() / v
1312            } else {
1313                let v = (2.0 * t).sqrt();
1314                v.tan() / v
1315            };
1316            let rel = (kp - x).abs() / x.max(1e-12);
1317            assert!(
1318                rel < 1e-6,
1319                "saddlepoint_solve(x={x}) → t={t}; K'(t)={kp}, rel={rel}"
1320            );
1321        }
1322    }
1323
1324    #[test]
1325    fn saddlepoint_kpp_is_positive() {
1326        // K'' is the variance of the tilted distribution; must be > 0.
1327        for &t in &[-2.0_f64, -0.5, -1e-5, 0.0, 1e-5, 0.5, 1.0] {
1328            let v = saddlepoint_kpp(t);
1329            assert!(v.is_finite() && v > 0.0, "K''({t}) = {v}");
1330        }
1331    }
1332
1333    #[test]
1334    fn pg_normal_oracle_matches_moments_at_large_b() {
1335        // b = 500, c = 1.0: normal approximation should land moments to
1336        // ~1 % at 100k samples.
1337        let b = 500u32;
1338        let c = 1.0_f64;
1339        let n = 100_000;
1340        let mut sum = 0.0;
1341        let mut sum_sq = 0.0;
1342        for i in 0..n {
1343            let mut st = XorwowState::new(0xBEEF_u64, i as u64);
1344            let x = pg_normal_cpu_oracle(&mut st, b, c);
1345            sum += x;
1346            sum_sq += x * x;
1347        }
1348        let mean = sum / n as f64;
1349        let var = sum_sq / n as f64 - mean * mean;
1350        let th_mean = theoretical_mean(b as f64, c);
1351        let th_var = theoretical_variance(b as f64, c);
1352        let m_rel = (mean - th_mean).abs() / th_mean;
1353        let v_rel = (var - th_var).abs() / th_var;
1354        assert!(
1355            m_rel < 0.02,
1356            "normal oracle mean: emp {mean}, theory {th_mean}, rel {m_rel}"
1357        );
1358        assert!(
1359            v_rel < 0.05,
1360            "normal oracle var: emp {var}, theory {th_var}, rel {v_rel}"
1361        );
1362    }
1363
1364    #[test]
1365    fn batch_dispatch_handles_mixed_regimes() {
1366        // 4 rows, one in each regime band. CPU path should run cleanly.
1367        let shapes = ndarray::array![1u32, 5u32, 50u32, 300u32];
1368        let tilts = ndarray::array![0.5_f64, 0.5, 0.5, 0.5];
1369        let input = PolyaGammaBatchInput {
1370            shapes: shapes.view(),
1371            tilts: tilts.view(),
1372            seed: PgSeed(42),
1373        };
1374        let out = draw_batch_cpu(&input).expect("CPU dispatch");
1375        assert_eq!(out.len(), 4);
1376        for v in out.iter() {
1377            assert!(
1378                v.is_finite() && *v > 0.0,
1379                "PG draw must be positive finite: {v}"
1380            );
1381        }
1382    }
1383
1384    #[test]
1385    fn logistic_gibbs_step_reduces_marginal_error() {
1386        // Sanity: starting from β = 0 on a small synthetic logistic dataset,
1387        // one Gibbs step should move toward the MLE direction. We don't
1388        // test convergence (that needs a chain); just that the new β is
1389        // finite, p-dimensional, and has nonzero displacement.
1390        let n = 200;
1391        let p = 3;
1392        let mut design = Array2::<f64>::zeros((n, p));
1393        let mut targets = Array1::<u8>::zeros(n);
1394        for i in 0..n {
1395            // Three covariates, last column intercept-like.
1396            let x1 = ((i as f64) / (n as f64)) * 2.0 - 1.0;
1397            let x2 = (((i * 7) % n) as f64 / n as f64) * 2.0 - 1.0;
1398            design[[i, 0]] = x1;
1399            design[[i, 1]] = x2;
1400            design[[i, 2]] = 1.0;
1401            let eta = 1.5 * x1 - 0.7 * x2 + 0.3;
1402            let p_y = 1.0 / (1.0 + (-eta).exp());
1403            // Deterministic Bernoulli via splitmix to avoid an RNG crate.
1404            let h = splitmix64_mix(i as u64 ^ 0xABCD_EF);
1405            let u = ((h >> 11) as f64) / ((1u64 << 53) as f64);
1406            targets[i] = if u < p_y { 1 } else { 0 };
1407        }
1408        let q0 = Array2::<f64>::eye(p) * 0.1;
1409        let beta = Array1::<f64>::zeros(p);
1410        let new_beta = logistic_gibbs_step(
1411            design.view(),
1412            targets.view(),
1413            q0.view(),
1414            beta.view(),
1415            PgSeed(1),
1416            9,
1417        )
1418        .expect("Gibbs step");
1419        assert_eq!(new_beta.len(), p);
1420        let disp: f64 = new_beta.iter().map(|b| b * b).sum::<f64>().sqrt();
1421        assert!(
1422            disp > 0.05 && disp.is_finite(),
1423            "Gibbs step displacement {disp} not meaningfully nonzero"
1424        );
1425    }
1426
1427    // ────────────────────────────────────────────────────────────────────
1428    // Charter §6 / §12 parity tests
1429    // ────────────────────────────────────────────────────────────────────
1430
1431    /// Two-sample Kolmogorov–Smirnov statistic. Returns sup_x |F_a(x) − F_b(x)|.
1432    /// We avoid pulling a stats crate here because the test only needs the
1433    /// statistic (compared to an asymptotic critical value below) — the math
1434    /// is a pure sort + merge.
1435    fn ks_two_sample(a: &mut [f64], b: &mut [f64]) -> f64 {
1436        a.sort_by(|x, y| x.partial_cmp(y).unwrap());
1437        b.sort_by(|x, y| x.partial_cmp(y).unwrap());
1438        let (na, nb) = (a.len() as f64, b.len() as f64);
1439        let (mut i, mut j) = (0usize, 0usize);
1440        let (mut fa, mut fb) = (0.0_f64, 0.0_f64);
1441        let mut d_max = 0.0_f64;
1442        while i < a.len() && j < b.len() {
1443            if a[i] <= b[j] {
1444                i += 1;
1445                fa = i as f64 / na;
1446            } else {
1447                j += 1;
1448                fb = j as f64 / nb;
1449            }
1450            let d = (fa - fb).abs();
1451            if d > d_max {
1452                d_max = d;
1453            }
1454        }
1455        d_max
1456    }
1457
1458    /// KS critical value at α = 0.01 for a two-sample test with sample sizes
1459    /// `n_a`, `n_b`: `c(0.01) · sqrt((n_a + n_b)/(n_a · n_b))` with
1460    /// `c(0.01) ≈ 1.6276` (standard asymptotic table; one-sided 0.005 tail
1461    /// of the Kolmogorov distribution).
1462    fn ks_critical_001(n_a: usize, n_b: usize) -> f64 {
1463        let na = n_a as f64;
1464        let nb = n_b as f64;
1465        1.6276 * ((na + nb) / (na * nb)).sqrt()
1466    }
1467
1468    #[test]
1469    fn pg1_cpu_oracle_matches_inference_module_distribution() {
1470        // KS test: the kernel-aligned XORWOW oracle here vs. the production
1471        // `inference::polya_gamma::PolyaGamma::draw` sampler should agree in
1472        // distribution (both implement Devroye with the corrected right-tail
1473        // coefficient). 5 000 samples each at three tilts; KS critical value
1474        // at α = 0.01.
1475        use crate::polya_gamma::PolyaGamma;
1476        use rand::{SeedableRng, rngs::StdRng};
1477        let pg = PolyaGamma::new();
1478        for &c in &[0.0_f64, 1.5, 4.0] {
1479            let n_dev = 5_000;
1480            let n_ref = 5_000;
1481            let mut from_oracle: Vec<f64> = (0..n_dev)
1482                .map(|i| {
1483                    let mut st = XorwowState::new(0xDEADBEEF_u64 ^ c.to_bits(), i as u64);
1484                    pg1_draw_cpu_oracle(&mut st, c)
1485                })
1486                .collect();
1487            let mut from_reference: Vec<f64> = {
1488                let mut rng = StdRng::seed_from_u64(0xABCD_u64 ^ c.to_bits());
1489                (0..n_ref).map(|_| pg.draw(&mut rng, c)).collect()
1490            };
1491            let d = ks_two_sample(&mut from_oracle, &mut from_reference);
1492            let crit = ks_critical_001(n_dev, n_ref);
1493            assert!(
1494                d <= 2.0 * crit,
1495                "PG(1, c={c}) two-sample KS d={d} > 2·crit={}; XORWOW oracle and reference disagree in distribution",
1496                2.0 * crit
1497            );
1498        }
1499    }
1500
1501    #[test]
1502    fn pg_convolution_identity_at_small_b() {
1503        // PG(b, c) =_d sum_{j=1..b} PG(1, c) for integer b. We compare two
1504        // independent draw streams: one drawing b independent PG(1, c) variates
1505        // and summing, the other drawing one PG(1, c) variate b times sharing a
1506        // single XORWOW (the dispatcher's convolution path). KS at α = 0.01.
1507        let n = 4_000;
1508        let b: u32 = 8;
1509        let c: f64 = 1.2;
1510        let mut left: Vec<f64> = (0..n)
1511            .map(|i| {
1512                // Reset state per draw so successive PG(1) draws share the same
1513                // chain — matches the host convolution path.
1514                let mut st = XorwowState::new(0x1111_u64, i as u64);
1515                (0..b).map(|_| pg1_draw_cpu_oracle(&mut st, c)).sum()
1516            })
1517            .collect();
1518        let mut right: Vec<f64> = (0..n)
1519            .map(|i| {
1520                // Independent fresh state per j to make this a genuinely
1521                // independent sum-of-PG(1) stream (different from `left` but
1522                // same distribution).
1523                (0..b)
1524                    .map(|j| {
1525                        let mut st = XorwowState::new(0x2222_u64 ^ (j as u64), i as u64);
1526                        pg1_draw_cpu_oracle(&mut st, c)
1527                    })
1528                    .sum::<f64>()
1529            })
1530            .collect();
1531        let d = ks_two_sample(&mut left, &mut right);
1532        let crit = ks_critical_001(n, n);
1533        assert!(
1534            d <= 2.0 * crit,
1535            "PG({b}, {c}) convolution identity KS d={d} > 2·crit={}",
1536            2.0 * crit
1537        );
1538    }
1539
1540    #[test]
1541    fn pg_normal_kernel_matches_moments_at_b_500() {
1542        // CPU oracle for the normal-approximation kernel hits PSW (b, c)
1543        // moments to 2 % mean / 5 % var at b = 500 with 50 000 draws. The
1544        // GPU kernel runs the same arithmetic with the same XORWOW state,
1545        // so this test is also a parity gate for the device path (any
1546        // device drift would surface as a CPU/GPU oracle mismatch first).
1547        let b = 500u32;
1548        let c = 2.0_f64;
1549        let n = 50_000;
1550        let mut sum = 0.0;
1551        let mut sum_sq = 0.0;
1552        for i in 0..n {
1553            let mut st = XorwowState::new(0xCAFE_u64, i as u64);
1554            let x = pg_normal_cpu_oracle(&mut st, b, c);
1555            sum += x;
1556            sum_sq += x * x;
1557        }
1558        let mean = sum / n as f64;
1559        let var = sum_sq / n as f64 - mean * mean;
1560        let th_mean = pg_mean(b as f64, c);
1561        let th_var = pg_variance(b as f64, c);
1562        let m_rel = (mean - th_mean).abs() / th_mean;
1563        let v_rel = (var - th_var).abs() / th_var;
1564        assert!(
1565            m_rel < 0.02,
1566            "normal kernel mean: emp {mean}, theory {th_mean}, rel {m_rel}"
1567        );
1568        assert!(
1569            v_rel < 0.05,
1570            "normal kernel var: emp {var}, theory {th_var}, rel {v_rel}"
1571        );
1572    }
1573
1574    #[test]
1575    fn logistic_gibbs_chain_converges_to_mle_direction() {
1576        // End-to-end Gibbs harness validation. Start from β = 0, run 200
1577        // steps on a small synthetic Bernoulli-logistic dataset with known
1578        // β* = (1.5, -0.7, 0.3). Drop the first 50 as burn-in and check that
1579        // the posterior mean direction aligns with β* (cosine > 0.85).
1580        use rand::{RngExt, SeedableRng, rngs::StdRng};
1581        let n = 400;
1582        let p = 3;
1583        let beta_star = [1.5_f64, -0.7, 0.3];
1584        let mut design = Array2::<f64>::zeros((n, p));
1585        let mut targets = Array1::<u8>::zeros(n);
1586        let mut rng = StdRng::seed_from_u64(0xFEED);
1587        for i in 0..n {
1588            let x1 = ((i as f64) / (n as f64)) * 2.0 - 1.0;
1589            let x2 = (((i * 13) % n) as f64 / n as f64) * 2.0 - 1.0;
1590            design[[i, 0]] = x1;
1591            design[[i, 1]] = x2;
1592            design[[i, 2]] = 1.0;
1593            let eta = beta_star[0] * x1 + beta_star[1] * x2 + beta_star[2];
1594            let p_y = 1.0 / (1.0 + (-eta).exp());
1595            let u: f64 = rng.random();
1596            targets[i] = if u < p_y { 1 } else { 0 };
1597        }
1598        let q0 = Array2::<f64>::eye(p) * 0.01;
1599        let mut beta = Array1::<f64>::zeros(p);
1600        let mut accum = Array1::<f64>::zeros(p);
1601        let steps = 200;
1602        let burn = 50;
1603        for k in 0..steps {
1604            beta = logistic_gibbs_step(
1605                design.view(),
1606                targets.view(),
1607                q0.view(),
1608                beta.view(),
1609                PgSeed(0xC0DE + k as u64),
1610                0xCAFE + k as u64,
1611            )
1612            .expect("Gibbs step");
1613            if k >= burn {
1614                for j in 0..p {
1615                    accum[j] += beta[j];
1616                }
1617            }
1618        }
1619        for j in 0..p {
1620            accum[j] /= (steps - burn) as f64;
1621        }
1622        let dot: f64 = (0..p).map(|j| accum[j] * beta_star[j]).sum();
1623        let na: f64 = accum.iter().map(|v| v * v).sum::<f64>().sqrt();
1624        let nb: f64 = beta_star.iter().map(|v| v * v).sum::<f64>().sqrt();
1625        let cos = dot / (na * nb);
1626        assert!(
1627            cos > 0.85,
1628            "Gibbs chain posterior-mean direction does not align with β*: cos = {cos}, accum = {accum:?}, β* = {beta_star:?}"
1629        );
1630    }
1631
    // ────────────────────────────────────────────────────────────────────
    // Charter §7 hill-climb gates (Linux-only, `#[ignore]` by default —
    // run with `cargo test -- --ignored polya_gamma_hill_climb_` on the
    // V100). The 50×/20× ratios compare CPU vs GPU draws built in the same
    // mode; the NVRTC kernel runs at device speed regardless of host opt
    // level, so the ratio is meaningful at any host build mode.
    // ────────────────────────────────────────────────────────────────────
1639
1640    /// Hill-climb gate: pure Bernoulli (b = 1) at n = 200 000 must run on the
1641    /// GPU at ≥ 50× the CPU oracle's draw rate. This is the dominant large-scale
1642    /// PG draw shape (one PG variate per data row per Gibbs iteration), so a
1643    /// 50× win here is the actual ship gate for the device sampler.
1644    #[test]
1645    #[cfg(target_os = "linux")]
1646    fn polya_gamma_hill_climb_pg1_50x() {
1647        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1648            eprintln!("[polya_gamma_hill_climb_pg1_50x] no CUDA runtime on host — skipping");
1649            return;
1650        }
1651        let n = 200_000usize;
1652        let shapes = Array1::<u32>::from_elem(n, 1);
1653        let mut tilts = Array1::<f64>::zeros(n);
1654        for i in 0..n {
1655            tilts[i] = ((i as f64) / (n as f64)) * 6.0 - 3.0;
1656        }
1657        let seed = PgSeed(0x50_4F_4C_59_47_41_4D_41);
1658
1659        // Warm the device module (NVRTC compile, allocator priming) so the
1660        // first kernel launch's compile time doesn't pollute the timing.
1661        {
1662            let warm_shapes = Array1::<u32>::from_elem(16, 1);
1663            let warm_tilts = Array1::<f64>::zeros(16);
1664            draw_batch(PolyaGammaBatchInput {
1665                shapes: warm_shapes.view(),
1666                tilts: warm_tilts.view(),
1667                seed,
1668            })
1669            .expect("warm");
1670        }
1671
1672        let t_gpu_start = std::time::Instant::now();
1673        let _gpu = draw_batch(PolyaGammaBatchInput {
1674            shapes: shapes.view(),
1675            tilts: tilts.view(),
1676            seed,
1677        })
1678        .expect("GPU draw_batch");
1679        let dt_gpu = t_gpu_start.elapsed().as_secs_f64();
1680
1681        let t_cpu_start = std::time::Instant::now();
1682        let _cpu = draw_batch_cpu(&PolyaGammaBatchInput {
1683            shapes: shapes.view(),
1684            tilts: tilts.view(),
1685            seed,
1686        })
1687        .expect("CPU draw_batch");
1688        let dt_cpu = t_cpu_start.elapsed().as_secs_f64();
1689
1690        let speedup = dt_cpu / dt_gpu;
1691        println!(
1692            "polya_gamma_hill_climb_pg1: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
1693        );
1694        assert!(
1695            speedup >= 50.0,
1696            "PG(1) GPU speedup {speedup:.1}× < 50× hill-climb gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
1697        );
1698    }
1699
1700    /// Hill-climb gate: mixed negative-binomial style workload — 80 % of rows
1701    /// at b ≥ 200 (normal-approx regime), 20 % at b = 1 (pg1 regime), 0 % at
1702    /// the placeholder saddlepoint band so the throughput claim is not
1703    /// dependent on the unfinished sp_kernel. 200 000 rows total; gate is
1704    /// ≥ 20× CPU.
1705    #[test]
1706    #[cfg(target_os = "linux")]
1707    fn polya_gamma_hill_climb_mixed_nb_20x() {
1708        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1709            eprintln!("[polya_gamma_hill_climb_mixed_nb_20x] no CUDA runtime on host — skipping");
1710            return;
1711        }
1712        let n = 200_000usize;
1713        let mut shapes = Array1::<u32>::zeros(n);
1714        let mut tilts = Array1::<f64>::zeros(n);
1715        for i in 0..n {
1716            // 20 % b = 1, 80 % b = 250 (normal regime).
1717            shapes[i] = if i.is_multiple_of(5) { 1 } else { 250 };
1718            tilts[i] = ((i as f64) / (n as f64)) * 4.0 - 2.0;
1719        }
1720        let seed = PgSeed(0xDEAD_BEEF_CAFE_BABE);
1721
1722        // Warm
1723        let warm_shapes = Array1::<u32>::from_elem(16, 250);
1724        let warm_tilts = Array1::<f64>::zeros(16);
1725        draw_batch(PolyaGammaBatchInput {
1726            shapes: warm_shapes.view(),
1727            tilts: warm_tilts.view(),
1728            seed,
1729        })
1730        .expect("warm");
1731
1732        let t_gpu = std::time::Instant::now();
1733        let _g = draw_batch(PolyaGammaBatchInput {
1734            shapes: shapes.view(),
1735            tilts: tilts.view(),
1736            seed,
1737        })
1738        .expect("GPU mixed");
1739        let dt_gpu = t_gpu.elapsed().as_secs_f64();
1740
1741        let t_cpu = std::time::Instant::now();
1742        let _c = draw_batch_cpu(&PolyaGammaBatchInput {
1743            shapes: shapes.view(),
1744            tilts: tilts.view(),
1745            seed,
1746        })
1747        .expect("CPU mixed");
1748        let dt_cpu = t_cpu.elapsed().as_secs_f64();
1749
1750        let speedup = dt_cpu / dt_gpu;
1751        println!(
1752            "polya_gamma_hill_climb_mixed: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
1753        );
1754        assert!(
1755            speedup >= 20.0,
1756            "Mixed NB GPU speedup {speedup:.1}× < 20× gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
1757        );
1758    }
1759
1760    /// GPU parity gate: when the runtime is available, the dispatched
1761    /// `draw_batch` path must agree with the CPU oracle bit-for-bit, since
1762    /// both consume the same XORWOW byte stream per row. macOS / no-runtime
1763    /// builds skip the body cleanly.
1764    #[test]
1765    #[cfg(target_os = "linux")]
1766    fn pg1_gpu_matches_cpu_oracle_when_runtime_available() {
1767        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1768            return;
1769        }
1770        let n = 256usize;
1771        let shapes = Array1::<u32>::from_elem(n, 1);
1772        let mut tilts = Array1::<f64>::zeros(n);
1773        for i in 0..n {
1774            tilts[i] = ((i as f64) / (n as f64)) * 6.0 - 3.0;
1775        }
1776        let seed = PgSeed(0x9E37_79B9_7F4A_7C15);
1777        let gpu = draw_batch(PolyaGammaBatchInput {
1778            shapes: shapes.view(),
1779            tilts: tilts.view(),
1780            seed,
1781        })
1782        .expect("GPU draw_batch");
1783        let cpu = draw_batch_cpu(&PolyaGammaBatchInput {
1784            shapes: shapes.view(),
1785            tilts: tilts.view(),
1786            seed,
1787        })
1788        .expect("CPU draw_batch");
1789        assert_eq!(gpu.len(), cpu.len());
1790        // The device transcendentals (exp / log / tanh / sqrt) round to within
1791        // ~1 ULP of glibc's libm but are not bit-identical, so we test a tight
1792        // relative tolerance rather than equality. A 1e-6 relative tolerance is
1793        // far inside the PG distribution's spread and any genuine algorithmic
1794        // drift (e.g. wrong series term) would blow this out by orders of
1795        // magnitude.
1796        for i in 0..n {
1797            let g = gpu[i];
1798            let c = cpu[i];
1799            let rel = (g - c).abs() / c.max(1e-12);
1800            assert!(
1801                rel < 1e-6,
1802                "pg1 GPU/CPU divergence at row {i}, tilt={}: gpu={g}, cpu={c}, rel={rel}",
1803                tilts[i]
1804            );
1805        }
1806    }
1807
1808    // ────────────────────────────────────────────────────────────────────
1809    // Issue #414 unification parity gates
1810    // ────────────────────────────────────────────────────────────────────
1811
1812    /// Device-source parity lock: the embedded CUDA source must consume the
1813    /// Devroye constants *rendered from the Rust core*, with no second
1814    /// hand-typed copy of those literals. We assert the assembled NVRTC source
1815    /// embeds the rendered `#define` block verbatim and that the prelude/body
1816    /// templates carry no stray `#define PG_…` of their own (which would be a
1817    /// drift hazard). Linux-only because `ptx_source` lives in the CUDA module.
1818    #[test]
1819    #[cfg(target_os = "linux")]
1820    fn cuda_source_uses_rendered_constants_only() {
1821        let rendered = crate::polya_gamma_core::render_cuda_constants();
1822        let assembled = linux_cuda::ptx_source();
1823        assert!(
1824            assembled.contains(rendered.trim_end()),
1825            "assembled CUDA source does not embed the rendered constant block"
1826        );
1827        // No constant literal may be hand-typed in the templates; the only
1828        // `#define PG_` lines must come from the rendered block.
1829        let define_count = assembled.matches("#define PG_").count();
1830        let rendered_count = rendered.matches("#define PG_").count();
1831        assert_eq!(
1832            define_count, rendered_count,
1833            "CUDA source has {define_count} `#define PG_` lines but the rendered block has {rendered_count}; a stale hand-typed constant is present"
1834        );
1835    }
1836}