GPU Pólya–Gamma sampler primitive — INCOMPATIBLE with shipped probit BMS (different model).
This module implements a stand-alone, device-resident Pólya–Gamma sampler
plus a synthetic logistic Gibbs harness used to validate the sampler.
It is not wired into the shipped BMS paths because those are probit
families: PG augmentation is exact only for the Bernoulli logistic
likelihood (Polson, Scott & Windle 2013). Probit paths (bms_flex,
bernoulli_marginal_slope) use a different likelihood and do not call
this module.
The block 7 math design splits the device sampler into three regimes (math §7), each kernel laid out to avoid warp divergence inside the launch:
- pg1_kernel: exact Devroye sampler (math §8) for shape b = 1. This covers pure Bernoulli rows. Each row owns a curand-style XORWOW state seeded statelessly from (seed, row_index), so two runs with the same seed produce bit-identical draws regardless of grid layout. The alternating-series accept/reject uses the corrected right-tail coefficient π · k (not π / 2), the math team’s Phase-1 fix.
- sp_kernel: saddlepoint rejection (math §9) for 13 < b ≤ 170. This solves K'(t) = x via six Newton iterations on tanh(v)/v or tan(v)/v and uses an IG + Gamma envelope for the accept/reject.
- normal_kernel: Lyapunov-CLT closed-form approximation (math §10) for b > 170. Mean and variance use the analytic PG(b, c) limit, with no rejection loop and no warp divergence.
The host dispatcher partitions an input vector of (b_i, c_i) rows
into three contiguous index lists (one per regime) and launches one
kernel per regime. The 2 ≤ b ≤ 13 band is handled on host via the
sum-of-PG(1, c) convolution identity — at small b the sum cost is
negligible and keeping it off-device avoids a fourth kernel that would
see almost no traffic in practice.
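The partition described above can be sketched as follows. This is a minimal illustration, not the crate's dispatcher: it assumes integer shapes, and the constant names and the `RegimePartition` / `partition_by_regime` identifiers are hypothetical, with threshold values inferred from the regime bounds quoted in this section.

```rust
// Hypothetical thresholds inferred from the prose; the crate's real constants
// (PG1_MAX_B, SADDLE_MIN_B, SADDLE_MAX_B, NORMAL_MIN_B) may differ in type or value.
const PG1_B_SKETCH: u32 = 1;
const HOST_CONV_MAX_B_SKETCH: u32 = 13;
const SADDLE_MAX_B_SKETCH: u32 = 170;

/// Row indices grouped by sampling regime (hypothetical type).
#[derive(Debug, Default, PartialEq)]
struct RegimePartition {
    pg1: Vec<usize>,              // b = 1: exact Devroye kernel
    host_convolution: Vec<usize>, // 2 <= b <= 13: sum of PG(1, c) draws on host
    saddlepoint: Vec<usize>,      // 13 < b <= 170: saddlepoint rejection kernel
    normal: Vec<usize>,           // b > 170: CLT approximation kernel
}

/// Walk the shapes once and push each row index into its regime list.
/// Indices inside each list stay in ascending row order.
fn partition_by_regime(shapes: &[u32]) -> RegimePartition {
    let mut part = RegimePartition::default();
    for (row, &b) in shapes.iter().enumerate() {
        assert!(b >= 1, "PG shape must be positive (row {})", row);
        if b == PG1_B_SKETCH {
            part.pg1.push(row);
        } else if b <= HOST_CONV_MAX_B_SKETCH {
            part.host_convolution.push(row);
        } else if b <= SADDLE_MAX_B_SKETCH {
            part.saddlepoint.push(row);
        } else {
            part.normal.push(row);
        }
    }
    part
}
```

Each device list would then feed one kernel launch, so every thread in a launch follows the same code path.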
§What this primitive intentionally does NOT do
- It does not plug into BMS marginal slope (probit model) — the PG augmentation identity is logit-only; doing so silently would change numerical results for shipped fits.
- It does not define a public production family. The Gibbs harness in logistic_gibbs_step is a validation oracle for the sampler primitive, not a fit method. The CPU reference src/inference/polya_gamma.rs and the NUTS/HMC infrastructure remain the supported posterior-inference paths.
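For context, a Pólya–Gamma Gibbs sweep for a Bernoulli-logistic model with prior β ~ N(0, Q_0⁻¹) alternates the two conditional draws below. This is the standard form from Polson, Scott & Windle (2013); that the harness follows exactly this ordering and parameterisation is an assumption.

```latex
\begin{aligned}
\omega_i \mid \beta &\sim \mathrm{PG}\!\left(1,\; x_i^\top \beta\right), \qquad i = 1, \dots, n, \\
\beta \mid \omega, y &\sim \mathcal{N}\!\left(m_\omega,\; V_\omega\right), \\
V_\omega &= \left( X^\top \Omega X + Q_0 \right)^{-1}, \qquad \Omega = \operatorname{diag}(\omega_1, \dots, \omega_n), \\
m_\omega &= V_\omega X^\top \kappa, \qquad \kappa_i = y_i - \tfrac{1}{2}.
\end{aligned}
```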
§Stateless XORWOW seeding
Each row’s XORWOW state (s0, s1, s2, s3, s4, counter) is materialised
by feeding splitmix64( seed ⊕ row · ZETA ⊕ word · GAMMA ) for word
indices 0 through 5: five 32-bit lanes plus a 32-bit counter. The host
reference RNG (xorwow_state_from) reproduces the same byte sequence
the kernel emits, so CPU/GPU parity tests can compare draw-by-draw at
the same (seed, row).
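The seeding rule can be sketched as below. Several details are assumptions made for illustration: the `ZETA` and `GAMMA` values are placeholders (the real multipliers are not given here), the mixer uses the widely published SplitMix64 finalizer constants, each 32-bit word is taken from the low half of the 64-bit hash, and `xorwow_words_sketch` is a hypothetical name rather than the crate's `xorwow_state_from`.

```rust
// Placeholder odd multipliers; the crate's actual ZETA and GAMMA are not stated.
const ZETA: u64 = 0x9E37_79B9_7F4A_7C15;
const GAMMA: u64 = 0xD1B5_4A32_D192_ED03;

/// SplitMix64 finalizer with the commonly published constants
/// (assumed, not confirmed, to match the crate's splitmix64_mix).
fn splitmix64_mix(mut z: u64) -> u64 {
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Derive six 32-bit words (five XORWOW lanes, then the counter) from
/// (seed, row) alone, with no dependence on thread or block indices.
fn xorwow_words_sketch(seed: u64, row: u64) -> [u32; 6] {
    let mut words = [0u32; 6];
    for (word, slot) in words.iter_mut().enumerate() {
        let key = seed
            ^ row.wrapping_mul(ZETA)
            ^ (word as u64).wrapping_mul(GAMMA);
        // Assumption: keep the low 32 bits of the 64-bit hash.
        *slot = splitmix64_mix(key) as u32;
    }
    words
}
```

Because the state is a pure function of (seed, row), any grid layout and the host reference arrive at the same starting state for a given row.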
Structs§
- PgSeed
  Stateless seed for the per-row XORWOW PRNG. The same seed reproduces bit-identical draws across runs and across CPU/GPU implementations.
- PolyaGammaBatchInput
  Inputs for the dispatched batched sampler.
- XorwowState
  Compact per-row XORWOW state. Layout matches curand_kernel.h’s curandStateXORWOW_t for the five state lanes plus the addition counter; we omit the boxmuller cache (PG sampler doesn’t use it).
Constants§
- NORMAL_MIN_B
- PG1_MAX_B
  Regime split thresholds (math §7).
- SADDLE_MAX_B
- SADDLE_MIN_B
Functions§
- draw_batch
  Top-level entry point: dispatches to GPU when available, otherwise CPU. Both paths use the same per-row XORWOW seeding so the GPU result is a bit-equivalent of the CPU result up to IEEE rounding of exp/log/tan/tanh/sqrt (which the device evaluators round to within 1 ULP of the CPU libm).
- draw_batch_cpu
  Per-row CPU draw using the appropriate regime. Used by the harness when the GPU runtime is unavailable, and as the per-row oracle for the dispatched device path’s parity tests.
- logistic_gibbs_step
  Single Gibbs step for the synthetic Bernoulli-logistic model y_i | β ~ Bernoulli(σ(x_iᵀ β)) with prior β ~ N(0, Q_0⁻¹).
- pg1_draw_cpu_oracle
  CPU oracle for one PG(1, c) draw using a XorwowState directly. The device kernel performs the same arithmetic byte-for-byte (modulo IEEE rounding of transcendentals, which agree to <1 ULP for the inputs we touch).
- pg_convolution_cpu_oracle
  Higher-shape draw on host via convolution: PG(b, c) =d Σ_{j=1..b} PG(1, c). Used by host for the 2 ≤ b ≤ 13 band and as the parity oracle for the saddlepoint kernel at modest b.
- pg_mean
  Mean of PG(b, c) (PSW 2013, eq. 4): E = b · tanh(c/2)/(2c), limit b/4.
- pg_normal_cpu_oracle
  Lyapunov-CLT closed-form draw for b > NORMAL_MIN_B. Truncated at zero because PG support is (0, +∞).
- pg_saddlepoint_cpu_oracle
  Saddlepoint host draw for PG(b, c) with 13 < b ≤ 170. This is the reference the device sp_kernel matches in distribution; both fall back to the convolution oracle when b is small enough that the saddlepoint approximation has noticeable bias (validated by §12.4 test).
- pg_variance
  Variance of PG(b, c): Var = b · (sinh c − c)/(2 c³ (1 + cosh c)), limit b/24.
- saddlepoint_kpp
  Saddlepoint approximation K''(t), the variance of the tilted distribution from K'(t) = x. K''(t) is variance, so positive on both branches.
- saddlepoint_solve
  Solve K'(t) = x for the saddlepoint t given x in (0, 1). K'(t) is a continuous strictly increasing function of t on the appropriate branch; the math team’s parameterisation eliminates v = sqrt(|2t|) so the Newton iteration is on a monotone bounded variable.
- splitmix64_mix
  SplitMix64 finalizer (matches reml_trace::splitmix64_mix). Thin wrapper over the canonical implementation in gam_linalg::utils::splitmix64_hash.
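The closed-form moments quoted for pg_mean and pg_variance can be evaluated as in the sketch below, which shows how the c → 0 limits b/4 and b/24 arise. The function names, the f64 signatures, and the small-|c| switch points are assumptions for illustration and are not taken from the crate.

```rust
/// Mean of PG(b, c): b * tanh(c/2) / (2c), tending to b/4 as c -> 0.
fn pg_mean_sketch(b: f64, c: f64) -> f64 {
    let c = c.abs(); // both formulas are even in c
    if c < 1e-4 {
        // tanh(x)/x = 1 - x^2/3 + O(x^4) with x = c/2
        b * 0.25 * (1.0 - c * c / 12.0)
    } else {
        b * (0.5 * c).tanh() / (2.0 * c)
    }
}

/// Variance of PG(b, c): b * (sinh c - c) / (2 c^3 (1 + cosh c)),
/// tending to b/24 as c -> 0.
fn pg_variance_sketch(b: f64, c: f64) -> f64 {
    let c = c.abs();
    if c < 1e-2 {
        // Series avoids cancellation in sinh(c) - c:
        // Var / b = (1/24) * (1 - c^2/5 + O(c^4))
        b / 24.0 * (1.0 - c * c / 5.0)
    } else {
        // Note: sinh and cosh overflow for very large |c|; a production
        // version would need a rescaled form there.
        b * (c.sinh() - c) / (2.0 * c.powi(3) * (1.0 + c.cosh()))
    }
}
```

These moments are what a moment-matched normal approximation for large b would consume, and they also give a cheap sanity check on empirical draws from any of the regimes.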