
Module gpu_polya_gamma


GPU Pólya–Gamma sampler primitive — INCOMPATIBLE with shipped probit BMS (different model).

This module implements a stand-alone, device-resident Pólya–Gamma sampler plus a synthetic logistic Gibbs harness used to validate the sampler. It is not wired into the shipped probit families (bms_flex, bernoulli_marginal_slope) because PG augmentation is exact only for the Bernoulli logistic likelihood (Polson, Scott & Windle 2013); the probit paths use a different likelihood and do not call this module.

The block 7 math design splits the device sampler into three regimes (math §7), each kernel laid out to avoid warp divergence inside the launch:

  • pg1_kernel — exact Devroye (math §8) for shape b = 1. This covers pure Bernoulli rows. Each row owns a curand-style XORWOW state seeded statelessly from (seed, row_index) so two runs with the same seed produce bit-identical draws regardless of grid layout. The alternating-series accept/reject uses the corrected right-tail coefficient π · k (not π / 2) — the math team’s Phase-1 fix.
  • sp_kernel — saddlepoint rejection (math §9) for 13 < b ≤ 170. This solves K'(t) = x via six Newton iterations on tanh(v)/v or tan(v)/v and uses an IG + Gamma envelope for the accept/reject.
  • normal_kernel — Lyapunov-CLT closed-form approximation (math §10) for b > 170. Mean and variance use the analytic PG(b, c) limit, no rejection loop, no warp divergence.

The host dispatcher partitions an input vector of (b_i, c_i) rows into three contiguous index lists (one per device regime) and launches one kernel per regime. The 2 ≤ b ≤ 13 band is handled on host via the sum-of-PG(1, c) convolution identity: at small b the summation cost is negligible, and keeping it off-device avoids a fourth kernel that would see almost no traffic in practice.
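A minimal sketch of the partition step, assuming integer shapes b_i ≥ 1 and the regime boundaries documented for the kernels and for pg_convolution_cpu_oracle (b = 1; 2 ≤ b ≤ 13 on host; 13 < b ≤ 170; b > 170). The `Regime`, `Partition`, `classify` and `partition` names are hypothetical stand-ins, not this module's API.

```rust
/// Sampling regime for one row (hypothetical names, not the module's types).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Regime {
    Pg1,             // b = 1: exact Devroye kernel
    HostConvolution, // 2 <= b <= 13: sum of b PG(1, c) draws on the host
    Saddlepoint,     // 13 < b <= 170: saddlepoint rejection kernel
    Normal,          // b > 170: Lyapunov-CLT kernel
}

/// Map an integer shape b to its regime using the documented thresholds.
fn classify(b: u32) -> Regime {
    match b {
        0 => panic!("PG shape b must be at least 1"),
        1 => Regime::Pg1,
        2..=13 => Regime::HostConvolution,
        14..=170 => Regime::Saddlepoint,
        _ => Regime::Normal,
    }
}

/// Row indices per regime; each list is one contiguous buffer a launch can consume.
#[derive(Debug, Default)]
struct Partition {
    pg1: Vec<usize>,
    host: Vec<usize>,
    saddle: Vec<usize>,
    normal: Vec<usize>,
}

/// Single pass over the shapes, preserving row order inside each list.
fn partition(shapes: &[u32]) -> Partition {
    let mut p = Partition::default();
    for (row, &b) in shapes.iter().enumerate() {
        let list = match classify(b) {
            Regime::Pg1 => &mut p.pg1,
            Regime::HostConvolution => &mut p.host,
            Regime::Saddlepoint => &mut p.saddle,
            Regime::Normal => &mut p.normal,
        };
        list.push(row);
    }
    p
}
```

Only the shapes decide the regime; the tilts c_i travel alongside the index lists unchanged.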

§What this primitive intentionally does NOT do

  • It does not plug into the BMS marginal-slope (probit) model: the PG augmentation identity is logit-only, so wiring it in silently would change numerical results for shipped fits.
  • It does not define a public production family. The Gibbs harness in logistic_gibbs_step is a validation oracle for the sampler primitive, not a fit method. The CPU reference src/inference/polya_gamma.rs and the NUTS/HMC infrastructure remain the supported posterior-inference paths.

§Stateless XORWOW seeding

Each row’s XORWOW state (s0, s1, s2, s3, s4, counter) is materialised by feeding splitmix64(seed ⊕ row · ZETA ⊕ word · GAMMA) for word indices 0 through 5: five 32-bit lanes plus a 32-bit counter. The host reference RNG (xorwow_state_from) reproduces the same byte sequence the kernel emits, so CPU/GPU parity tests can compare draw-by-draw at the same (seed, row).
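A host-side sketch of that derivation, assuming the low 32 bits of each SplitMix64 output become one word. The finalizer constants are those of the reference SplitMix64; `ZETA` is a hypothetical placeholder and `GAMMA` is assumed to be the usual 64-bit golden-ratio constant, since the module's actual values are not given here. `xorwow_words` is a hypothetical name, not `xorwow_state_from`.

```rust
/// SplitMix64 finalizer with the reference constants.
fn splitmix64_mix(mut z: u64) -> u64 {
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// HYPOTHETICAL row multiplier: the module's ZETA value is not documented here.
const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
/// ASSUMED word multiplier: the 64-bit golden-ratio constant.
const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

/// Words 0..=4 are the XORWOW lanes and word 5 is the counter. The result depends
/// only on (seed, row), never on grid layout or launch order.
fn xorwow_words(seed: u64, row: u64) -> [u32; 6] {
    let mut words = [0u32; 6];
    for (word, slot) in words.iter_mut().enumerate() {
        let x = seed ^ row.wrapping_mul(ZETA) ^ (word as u64).wrapping_mul(GAMMA);
        *slot = splitmix64_mix(x) as u32; // ASSUMPTION: keep the low 32 bits
    }
    words
}
```

Because no state is carried between rows, any thread can rebuild row `r`'s generator from `(seed, r)` alone, which is what makes the draws independent of grid layout.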

Structs§

PgSeed
Stateless seed for the per-row XORWOW PRNG. The same seed reproduces bit-identical draws across runs and across CPU/GPU implementations.
PolyaGammaBatchInput
Inputs for the dispatched batched sampler.
XorwowState
Compact per-row XORWOW state. Layout matches curand_kernel.h’s curandStateXORWOW_t for the five state lanes plus the addition counter; the Box–Muller cache is omitted because the PG sampler does not use it.
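A sketch of the step such a state drives, following Marsaglia's xorwow recurrence with the 362437 Weyl increment that curand also uses. The field names `v`/`d` and the `next_open01` helper (a uniform on the open interval (0, 1), convenient before taking logs) are illustrative assumptions rather than this module's `XorwowState` API.

```rust
/// Five XORWOW lanes plus the Weyl ("addition") counter, mirroring the `v[5]`
/// and `d` fields of curand's XORWOW state; the Box–Muller cache is left out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct XorwowState {
    v: [u32; 5],
    d: u32,
}

impl XorwowState {
    /// One xorwow step: an xorshift over five lanes combined with a Weyl sequence.
    fn next_u32(&mut self) -> u32 {
        let t = self.v[0] ^ (self.v[0] >> 2);
        self.v[0] = self.v[1];
        self.v[1] = self.v[2];
        self.v[2] = self.v[3];
        self.v[3] = self.v[4];
        self.v[4] = (self.v[4] ^ (self.v[4] << 4)) ^ (t ^ (t << 1));
        self.d = self.d.wrapping_add(362_437);
        self.v[4].wrapping_add(self.d)
    }

    /// Uniform strictly inside (0, 1) (hypothetical helper): (x + 0.5) / 2^32.
    fn next_open01(&mut self) -> f64 {
        (self.next_u32() as f64 + 0.5) / 4_294_967_296.0
    }
}
```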

Constants§

NORMAL_MIN_B
PG1_MAX_B
Regime split thresholds (math §7).
SADDLE_MAX_B
SADDLE_MIN_B

Functions§

draw_batch
Top-level entry point: dispatches to the GPU when available and to the CPU otherwise. Both paths use the same per-row XORWOW seeding, so the GPU result matches the CPU result bit-for-bit up to IEEE rounding of exp/log/tan/tanh/sqrt (the device evaluators agree with the CPU libm to within 1 ULP).
draw_batch_cpu
Per-row CPU draw using the appropriate regime. Used by the harness when the GPU runtime is unavailable, and as the per-row oracle for the dispatched device path’s parity tests.
logistic_gibbs_step
Single Gibbs step for the synthetic Bernoulli-logistic model y_i | β ~ Bernoulli(σ(x_iᵀ β)) with prior β ~ N(0, Q_0⁻¹).
pg1_draw_cpu_oracle
CPU oracle for one PG(1, c) draw using a XorwowState directly. The device kernel performs the same arithmetic byte-for-byte (modulo IEEE rounding of transcendentals, which agree to <1 ULP for the inputs we touch).
pg_convolution_cpu_oracle
Higher-shape draw on host via convolution: for integer b, PG(b, c) =ᵈ Σ_{j=1}^{b} PG_j(1, c), where the PG_j(1, c) are independent. Used by host for the 2 ≤ b ≤ 13 band and as the parity oracle for the saddlepoint kernel at modest b.
pg_mean
Mean of PG(b, c) (PSW 2013, eq. 4): E = b · tanh(c/2)/(2c), with limit b/4 as c → 0.
pg_normal_cpu_oracle
Lyapunov-CLT closed-form draw for b > NORMAL_MIN_B. Truncated at zero because PG support is (0, +∞).
pg_saddlepoint_cpu_oracle
Saddlepoint host draw for PG(b, c) with 13 < b ≤ 170. This is the reference the device sp_kernel matches in distribution; both fall back to the convolution oracle when b is small enough that the saddlepoint approximation has noticeable bias (validated by the §12.4 test).
pg_variance
Variance of PG(b, c): Var = b · (sinh c − c)/(2 c³ (1 + cosh c)), with limit b/24 as c → 0.
saddlepoint_kpp
Saddlepoint curvature K''(t): the second derivative of the cumulant generating function at the t that solves K'(t) = x, i.e. the variance of the tilted distribution. Because K''(t) is a variance, it is positive on both branches.
saddlepoint_solve
Solve K'(t) = x for the saddlepoint t given x in (0, 1). K'(t) is a continuous, strictly increasing function of t on the appropriate branch; the math team’s parameterisation substitutes v = sqrt(|2t|) for t, so the Newton iteration runs on a monotone, bounded variable.
splitmix64_mix
SplitMix64 finalizer (matches reml_trace::splitmix64_mix). Thin wrapper over the canonical implementation in gam_linalg::utils::splitmix64_hash.
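For reference alongside logistic_gibbs_step, the textbook two-block Pólya–Gamma update of Polson, Scott & Windle (2013) for y_i | β ~ Bernoulli(σ(x_iᵀ β)) with prior β ~ N(0, Q_0⁻¹). This is the standard form, not a transcription of the harness, whose internals are not shown here. With b = 1 per row, every ω_i draw falls in the pg1_kernel regime.

```latex
\kappa_i = y_i - \tfrac{1}{2},
\qquad
\omega_i \mid \beta \sim \mathrm{PG}\!\left(1,\; x_i^{\top}\beta\right), \quad i = 1, \dots, n,

\beta \mid \omega, y \sim \mathcal{N}\!\left(m_\omega,\; V_\omega\right),
\qquad
V_\omega = \left(X^{\top}\Omega X + Q_0\right)^{-1},
\qquad
m_\omega = V_\omega X^{\top}\kappa,
\qquad
\Omega = \operatorname{diag}(\omega_1, \dots, \omega_n).
```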
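A sketch of the convolution identity behind pg_convolution_cpu_oracle. The PG(1, c) draws come from a truncated version of the infinite-series definition in PSW 2013, a swapped-in stand-in for the module's exact Devroye sampler chosen only because it fits in a few lines; truncation drops a small positive tail, so it is approximate. All names here (`SplitMix64`, `pg1_truncated_series`, `pg_convolution`) are hypothetical.

```rust
use std::f64::consts::PI;

/// Minimal SplitMix64 generator, used only to keep the sketch self-contained.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform strictly inside (0, 1), built from the top 52 bits.
    fn next_open01(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }

    /// Exp(1) = Gamma(1, 1) draw by inversion.
    fn next_exp(&mut self) -> f64 {
        -self.next_open01().ln()
    }
}

/// Approximate PG(1, c) by truncating the PSW series
/// PG(1, c) = (1 / 2π²) Σ_k g_k / ((k − 1/2)² + c² / (4π²)), with g_k ~ Gamma(1, 1).
fn pg1_truncated_series(c: f64, terms: usize, rng: &mut SplitMix64) -> f64 {
    let shift = c * c / (4.0 * PI * PI);
    let mut sum = 0.0;
    for k in 1..=terms {
        let h = k as f64 - 0.5;
        sum += rng.next_exp() / (h * h + shift);
    }
    sum / (2.0 * PI * PI)
}

/// Convolution identity: for integer b, PG(b, c) is the sum of b independent PG(1, c) draws.
fn pg_convolution(b: u32, c: f64, mut pg1: impl FnMut(f64) -> f64) -> f64 {
    (0..b).map(|_| pg1(c)).sum()
}
```

The cost is b PG(1, c) draws per row, which is why the identity only pays off for the small-b band.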
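A sketch of pg_mean, pg_variance and the normal-regime draw. The variance is evaluated through the identities sinh c / (1 + cosh c) = tanh(c/2) and 1 + cosh c = 2 cosh²(c/2), so large c does not produce inf/inf, and a Taylor series takes over near c = 0 where sinh c − c cancels. Reading "truncated at zero" as redraw-until-positive is an assumption (the module may clamp instead), and `pg_normal_draw` is a hypothetical name.

```rust
/// Mean of PG(b, c): b · tanh(c/2) / (2c), with limit b/4 as c → 0. Even in c.
fn pg_mean(b: f64, c: f64) -> f64 {
    let c = c.abs();
    if c < 1e-4 {
        b * (0.25 - c * c / 48.0) // Taylor: tanh(c/2)/(2c) = 1/4 − c²/48 + O(c⁴)
    } else {
        b * (0.5 * c).tanh() / (2.0 * c)
    }
}

/// Variance of PG(b, c): b (sinh c − c) / (2c³ (1 + cosh c)), limit b/24 as c → 0.
fn pg_variance(b: f64, c: f64) -> f64 {
    let c = c.abs();
    if c < 1e-2 {
        // Taylor: (b/24)(1 − c²/5 + 17c⁴/560) + O(c⁶), avoiding cancellation in sinh c − c
        let c2 = c * c;
        b / 24.0 * (1.0 - c2 / 5.0 + 17.0 * c2 * c2 / 560.0)
    } else {
        let h = 0.5 * c;
        let ch = h.cosh(); // cosh² may overflow to inf for huge c; c / inf = 0 is the right limit
        b * (h.tanh() - c / (2.0 * ch * ch)) / (2.0 * c * c * c)
    }
}

/// CLT draw mean + sd · z with z ~ N(0, 1), redrawn until positive (ASSUMPTION: rejection).
fn pg_normal_draw(b: f64, c: f64, mut std_normal: impl FnMut() -> f64) -> f64 {
    let (m, sd) = (pg_mean(b, c), pg_variance(b, c).sqrt());
    loop {
        let x = m + sd * std_normal();
        if x > 0.0 {
            return x;
        }
    }
}
```

The redraw essentially never fires in this regime: at c = 0 the mean sits √(1.5 b) ≈ 16 standard deviations above zero when b = 171.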
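The saddlepoint quantities behind saddlepoint_kpp can be derived from the PG(b, c) Laplace transform in PSW 2013. The derivation below works in the PG(b, c) scale with u = √(c²/4 − t/2); this is an assumption, since the math team's v = √|2t| parameterisation presumably differs by a rescaling that is not documented here. At t = 0 it reproduces the pg_mean and pg_variance formulas, and the last two lines show why K''(t) is positive on both branches.

```latex
% From E[e^{-s\omega}] = \cosh^{b}(c/2) / \cosh^{b}\big(\sqrt{c^2/4 + s/2}\big), with s = -t:
K(t) = b\log\cosh\!\left(\tfrac{c}{2}\right) - b\log\cosh u,
\qquad u = \sqrt{\tfrac{c^2}{4} - \tfrac{t}{2}},

K'(t) = \frac{b\,\tanh u}{4u},
\qquad
K''(t) = \frac{b\left(\tanh u - u\,\operatorname{sech}^2 u\right)}{16\,u^3}.

% At t = 0 (u = c/2):
K'(0) = \frac{b\,\tanh(c/2)}{2c},
\qquad
K''(0) = \frac{b\,(\sinh c - c)}{2c^3\,(1 + \cosh c)}.

% For t > c^2/2 write u = i v, with v = \sqrt{t/2 - c^2/4} \in (0, \pi/2):
K'(t) = \frac{b\,\tan v}{4v},
\qquad
K''(t) = \frac{b\left(v\,\sec^2 v - \tan v\right)}{16\,v^3}.

% Positivity on both branches:
\tanh u - u\,\operatorname{sech}^2 u = \frac{\tfrac{1}{2}\sinh(2u) - u}{\cosh^2 u} > 0,
\qquad
v\,\sec^2 v - \tan v = \frac{v - \tfrac{1}{2}\sin(2v)}{\cos^2 v} > 0.
```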
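A minimal sketch of a Newton solve on the tanh branch used by saddlepoint_solve and sp_kernel, assuming the equation has already been reduced to tanh(v)/v = y with y in (0, 1); how x maps to y under the module's scaling is not shown. The helper names are hypothetical. The documented kernel runs a fixed six iterations, while this sketch keeps each Newton step inside a shrinking bracket and falls back to bisection, a swapped-in safeguard rather than the documented method.

```rust
/// g(v) = tanh(v) / v for v >= 0, with g(0) = 1 by continuity.
fn tanh_over_v(v: f64) -> f64 {
    if v < 1e-4 {
        1.0 - v * v / 3.0 // Taylor: tanh(v)/v = 1 − v²/3 + O(v⁴)
    } else {
        v.tanh() / v
    }
}

/// g'(v) = (v sech²v − tanh v) / v², which is negative for v > 0.
fn d_tanh_over_v(v: f64) -> f64 {
    if v < 1e-4 {
        -2.0 * v / 3.0
    } else {
        let sech = 1.0 / v.cosh(); // underflows to 0 for very large v, which is harmless
        (v * sech * sech - v.tanh()) / (v * v)
    }
}

/// Solve tanh(v)/v = y for v > 0, given y in (0, 1).
fn solve_tanh_branch(y: f64) -> f64 {
    assert!(y > 0.0 && y < 1.0, "y must lie in (0, 1)");
    // g(0) = 1 > y and g(1/y) = y · tanh(1/y) <= y, so the root lies in [0, 1/y].
    let (mut lo, mut hi) = (0.0_f64, 1.0 / y);
    let mut v = 0.5 * hi;
    for _ in 0..200 {
        let f = tanh_over_v(v) - y;
        if f.abs() <= 1e-15 * y {
            break;
        }
        // g is decreasing, so f > 0 means the root lies to the right of v.
        if f > 0.0 {
            lo = v;
        } else {
            hi = v;
        }
        // Accept the Newton step only if it stays strictly inside the bracket.
        let newton = v - f / d_tanh_over_v(v);
        v = if newton > lo && newton < hi { newton } else { 0.5 * (lo + hi) };
        if hi - lo <= 4.0 * f64::EPSILON * hi {
            break;
        }
    }
    v
}
```

Because g is monotone, the bracket always contains the root, so a poor starting point costs extra bisection steps rather than divergence.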