use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
#[cfg(target_os = "linux")]
use super::error::GpuError;
/// Base seed for the per-row Polya-Gamma random streams.
///
/// The wrapped value is combined with a row index by `XorwowState::new`, so
/// the same seed reproduces the same batch of draws.
#[derive(Clone, Copy, Debug)]
pub struct PgSeed(pub u64);

impl Default for PgSeed {
    /// Returns the fixed default seed, whose big-endian bytes spell `POLYGAMA`.
    fn default() -> Self {
        PgSeed(u64::from_be_bytes(*b"POLYGAMA"))
    }
}
/// Largest shape `b` that the batch dispatch sends to the single PG(1, c) sampler.
pub const PG1_MAX_B: u32 = 1;
/// Smallest shape `b` routed to the saddlepoint regime; shapes strictly between
/// `PG1_MAX_B` and this value are drawn as a sum of `b` PG(1, c) variates.
pub const SADDLE_MIN_B: u32 = 14;
/// Largest shape `b` routed to the saddlepoint regime.
pub const SADDLE_MAX_B: u32 = 170;
// NOTE(review): the visible dispatch code never reads this constant; it sends
// every `b > SADDLE_MAX_B` to the normal approximation via a plain `else`.
/// Smallest shape `b` of the normal-approximation regime.
pub const NORMAL_MIN_B: u32 = 171;
/// Borrowed inputs for one batch of Polya-Gamma draws.
///
/// Row `i` requests a draw with shape `shapes[i]` and tilt `tilts[i]`.
#[derive(Clone, Debug)]
pub struct PolyaGammaBatchInput<'a> {
/// Shape parameter `b` for every row; `validate` rejects zero entries.
pub shapes: ArrayView1<'a, u32>,
/// Tilt parameter `c` for every row; must have the same length as `shapes`.
pub tilts: ArrayView1<'a, f64>,
/// Base seed; each row derives its own generator from this and its row index.
pub seed: PgSeed,
}
impl<'a> PolyaGammaBatchInput<'a> {
    /// Number of rows in the batch (the length of `shapes`).
    pub fn rows(&self) -> usize {
        self.shapes.len()
    }

    /// Checks that the batch can be sampled.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when
    /// * `shapes` and `tilts` differ in length,
    /// * any shape is zero, or
    /// * any tilt is NaN or infinite. A NaN tilt makes every accept/reject
    ///   comparison in the PG(1, c) sampler false, so the sampler would never
    ///   terminate; an infinite tilt yields a NaN variance in the
    ///   normal-approximation regime.
    pub fn validate(&self) -> Result<(), String> {
        if self.shapes.len() != self.tilts.len() {
            return Err(format!(
                "polya_gamma: shapes.len()={} != tilts.len()={}",
                self.shapes.len(),
                self.tilts.len()
            ));
        }
        if self.shapes.iter().any(|&b| b == 0) {
            return Err("polya_gamma: b=0 is invalid (PG(0,c) is a point mass at 0)".to_string());
        }
        let bad_tilt = self
            .tilts
            .iter()
            .enumerate()
            .find(|(_, tilt)| !tilt.is_finite());
        if let Some((row, tilt)) = bad_tilt {
            return Err(format!(
                "polya_gamma: tilt at row {row} is not finite ({tilt})"
            ));
        }
        Ok(())
    }
}
/// One step of the SplitMix64 generator: advances `z` by the golden-ratio
/// increment and returns the finalised (bit-mixed) 64-bit output.
///
/// All arithmetic wraps modulo 2^64.
#[inline]
pub fn splitmix64_mix(mut z: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const FIRST_MULTIPLIER: u64 = 0xBF58_476D_1CE4_E5B9;
    const SECOND_MULTIPLIER: u64 = 0x94D0_49BB_1331_11EB;

    z = z.wrapping_add(INCREMENT);
    z = (z ^ (z >> 30)).wrapping_mul(FIRST_MULTIPLIER);
    z = (z ^ (z >> 27)).wrapping_mul(SECOND_MULTIPLIER);
    z ^ (z >> 31)
}
/// Multiplier applied (wrapping) to the row index before it is XOR-ed into the
/// seed in `XorwowState::new`, so distinct rows hash to distinct inputs.
const ROW_ZETA: u64 = 0xA1B2_C3D4_E5F6_7890;
/// Multiplier applied (wrapping) to the state-word index in `XorwowState::new`,
/// so each of the six state words is derived from a different hash input.
const WORD_GAMMA: u64 = 0x0F1E_2D3C_4B5A_6978;
/// State of one XORWOW generator: five xorshift lanes plus an additive counter.
#[derive(Clone, Copy, Debug)]
pub struct XorwowState {
/// The five 32-bit xorshift lanes; `new` guarantees they are not all zero.
pub s: [u32; 5],
/// Counter advanced by 362437 (wrapping) on every `next_u32` call and added
/// to the xorshift output.
pub d: u32,
}
impl XorwowState {
    /// Builds the generator for `row` of a batch seeded with `seed`.
    ///
    /// Each of the six state words is the high 32 bits of
    /// `splitmix64_mix(seed ^ row * ROW_ZETA ^ word_index * WORD_GAMMA)`
    /// (wrapping products). If all five xorshift lanes come out zero the
    /// first lane is forced to 1, because an all-zero xorshift state is stuck.
    pub fn new(seed: u64, row: u64) -> Self {
        let row_term = row.wrapping_mul(ROW_ZETA);
        let derive_word = |word_idx: u64| -> u32 {
            let hashed = splitmix64_mix(seed ^ row_term ^ word_idx.wrapping_mul(WORD_GAMMA));
            (hashed >> 32) as u32
        };

        let mut lanes = [
            derive_word(0),
            derive_word(1),
            derive_word(2),
            derive_word(3),
            derive_word(4),
        ];
        if lanes.iter().all(|&lane| lane == 0) {
            lanes[0] = 1;
        }
        Self {
            s: lanes,
            d: derive_word(5),
        }
    }

    /// Advances the generator and returns the next 32-bit output
    /// (xorshift result plus the wrapping counter).
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        let [head, second, third, fourth, tail] = self.s;

        let mut mixed = tail;
        mixed ^= mixed >> 2;
        mixed ^= mixed << 1;
        mixed ^= head ^ (head << 4);

        // Shift every lane one slot down and insert the fresh word in front.
        self.s = [mixed, head, second, third, fourth];
        self.d = self.d.wrapping_add(362_437);
        mixed.wrapping_add(self.d)
    }

    /// Uniform draw in the half-open interval (0, 1]: `(raw + 1) / 2^32`.
    ///
    /// Zero is excluded so that `ln` of the result is always finite.
    #[inline]
    pub fn next_unit(&mut self) -> f64 {
        const INV_TWO_POW_32: f64 = 1.0 / 4_294_967_296.0;
        (f64::from(self.next_u32()) + 1.0) * INV_TWO_POW_32
    }

    /// Standard exponential draw by inversion: `-ln(U)` with `U` from `next_unit`.
    #[inline]
    pub fn next_exp(&mut self) -> f64 {
        let uniform = self.next_unit();
        -(uniform.ln())
    }

    /// Standard normal draw via the Marsaglia polar method.
    ///
    /// Only the first variate of each accepted pair is returned; the partner
    /// is discarded rather than cached.
    #[inline]
    pub fn next_norm(&mut self) -> f64 {
        loop {
            let x = 2.0 * self.next_unit() - 1.0;
            let y = 2.0 * self.next_unit() - 1.0;
            let radius_sq = x * x + y * y;
            let inside_unit_disc = radius_sq > 0.0 && radius_sq < 1.0;
            if inside_unit_disc {
                return x * (-2.0 * radius_sq.ln() / radius_sq).sqrt();
            }
        }
    }
}
use std::f64::consts::{FRAC_2_PI, FRAC_PI_2, PI};
/// π², used by the PG(1, c) series coefficients and the proposal rate.
const PI_SQ: f64 = PI * PI;
/// √2 / √π = √(2/π), the leading factor of the small-`x` series branch.
const SQRT_2_OVER_SQRT_PI: f64 = 0.797_884_560_802_865_4;
/// √(π/2), the scale of the normal-CDF arguments in `exponential_tail_mass`.
const SQRT_PI_OVER_2: f64 = 1.253_314_137_315_500_1;
/// Draws one PG(1, `tilt`) variate with an alternating-series rejection sampler.
///
/// A proposal is drawn either from a shifted exponential (with probability
/// `exponential_tail_mass(|tilt| / 2)`) or from an inverse Gaussian truncated
/// at 2/π. It is then accepted or rejected by bracketing a uniform threshold
/// between partial sums of `series_coefficient`. The accepted proposal is
/// scaled by 1/4.
///
/// Only the magnitude of `tilt` matters. A NaN `tilt` returns NaN without
/// consuming any random numbers: with NaN every comparison in the
/// accept/reject loop is false, so the loop would otherwise never terminate.
pub fn pg1_draw_cpu_oracle(state: &mut XorwowState, tilt: f64) -> f64 {
    if tilt.is_nan() {
        return f64::NAN;
    }

    let half_tilt = tilt.abs() * 0.5;
    let half_tilt_sq = half_tilt * half_tilt;
    // Rate of the exponential tail proposal.
    let scale = 0.125 * PI_SQ + 0.5 * half_tilt_sq;
    let exp_mass = exponential_tail_mass(half_tilt);

    loop {
        let u = state.next_unit();
        let proposal = if u < exp_mass {
            FRAC_2_PI + state.next_exp() / scale
        } else {
            sample_trunc_inv_gauss(state, half_tilt, FRAC_2_PI)
        };

        // Partial sums alternate around the target density: after an odd term
        // the sum is a lower bound (accept if the threshold is below it), after
        // an even term it is an upper bound (reject if the threshold is above).
        let mut sum = series_coefficient(0, proposal);
        let threshold = state.next_unit() * sum;
        let mut idx = 0usize;
        loop {
            idx += 1;
            let term = series_coefficient(idx, proposal);
            if idx % 2 == 1 {
                sum -= term;
                if threshold <= sum {
                    return 0.25 * proposal;
                }
            } else {
                sum += term;
                if threshold >= sum {
                    break;
                }
            }
        }
    }
}
/// Draws PG(`b`, `tilt`) as the sum of `b` independent PG(1, `tilt`) variates,
/// all taken from the same generator `state`.
pub fn pg_convolution_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
    std::iter::repeat_with(|| pg1_draw_cpu_oracle(state, tilt))
        .take(b as usize)
        .sum()
}
/// Probability with which the PG(1, c) sampler proposes from the exponential
/// tail (above the 2/π truncation point) rather than from the truncated
/// inverse Gaussian. `tilt` is the half-tilt passed by the caller.
fn exponential_tail_mass(tilt: f64) -> f64 {
    let rate = 0.125 * PI_SQ + 0.5 * tilt * tilt;
    let scaled_tilt = FRAC_2_PI * tilt;

    let cdf_upper = std_normal_cdf(SQRT_PI_OVER_2 * (scaled_tilt - 1.0));
    let cdf_lower = std_normal_cdf(-(SQRT_PI_OVER_2 * (scaled_tilt + 1.0)));

    let weight = rate * (rate * FRAC_2_PI).exp();
    let upper_part = weight * (-tilt).exp() * cdf_upper;
    let lower_part = weight * tilt.exp() * cdf_lower;

    let ratio = (4.0 / PI) * (upper_part + lower_part);
    1.0 / (1.0 + ratio)
}
/// Standard normal cumulative distribution function, evaluated with `statrs`.
#[inline]
fn std_normal_cdf(x: f64) -> f64 {
    use statrs::distribution::{ContinuousCDF, Normal};
    let standard_normal = Normal::standard();
    standard_normal.cdf(x)
}
/// `n`-th coefficient of the alternating series used to accept or reject a
/// PG(1, c) proposal `x`.
///
/// Non-positive `x` yields 0. For `x <= 2/π` the small-argument form
/// `2k·√(2/π)·x^(-3/2)·exp(-2k²/x)` is used, otherwise the large-argument form
/// `π·k·exp(-k²π²x/2)`, with `k = n + 1/2`.
#[inline]
fn series_coefficient(n: usize, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let half_odd = n as f64 + 0.5;
    let half_odd_sq = half_odd * half_odd;
    let use_small_form = x <= FRAC_2_PI;
    if use_small_form {
        let reciprocal = 1.0 / x;
        let leading = 2.0 * half_odd * SQRT_2_OVER_SQRT_PI;
        let decay = (-2.0 * half_odd_sq * reciprocal).exp();
        leading * reciprocal * reciprocal.sqrt() * decay
    } else {
        let decay = (-0.5 * half_odd_sq * PI_SQ * x).exp();
        PI * half_odd * decay
    }
}
/// Samples an inverse Gaussian with mean `1/|z|` truncated to `(0, trunc]`.
///
/// When `|z|` is below 2/π a rejection sampler that does not need `1/|z|` is
/// used; otherwise inverse Gaussian draws with mean `1/|z|` are repeated until
/// one falls under `trunc`.
fn sample_trunc_inv_gauss(state: &mut XorwowState, z: f64, trunc: f64) -> f64 {
    let magnitude = z.abs();
    if FRAC_2_PI > magnitude {
        return sample_small_z(state, magnitude, trunc);
    }
    sample_large_z(state, 1.0 / magnitude, trunc)
}
/// Truncated inverse Gaussian sampler for small `z`.
///
/// Each round draws exponential pairs until `e1² <= 2·e2 / trunc`, maps the
/// accepted `e1` to `trunc / (1 + e1·trunc)²`, and accepts that candidate with
/// probability `exp(-z²·candidate / 2)`.
fn sample_small_z(state: &mut XorwowState, z: f64, trunc: f64) -> f64 {
    let mut candidate = 0.0;
    let mut accept_prob = 0.0;
    // The first comparison always draws a uniform in (0, 1], so the body runs
    // at least once.
    while accept_prob < state.next_unit() {
        let accepted_exp = loop {
            let (first, second) = (state.next_exp(), state.next_exp());
            if first * first <= 2.0 * second / trunc {
                break first;
            }
        };
        let base = 1.0 + accepted_exp * trunc;
        candidate = trunc / (base * base);
        accept_prob = (-0.5 * z * z * candidate).exp();
    }
    candidate
}
/// Truncated inverse Gaussian sampler for large `z`.
///
/// Draws unit-shape inverse Gaussian variates with the given `mean` from a
/// squared normal (choosing between the smaller root and its reciprocal image
/// with a uniform) and repeats until a draw is at most `trunc`.
fn sample_large_z(state: &mut XorwowState, mean: f64, trunc: f64) -> f64 {
    let half_mean = 0.5 * mean;
    let mut candidate = f64::INFINITY;
    while candidate > trunc {
        let normal = state.next_norm();
        let chi_sq = normal * normal;
        let scaled = mean * chi_sq;
        let root = (4.0 * scaled + scaled * scaled).sqrt();
        candidate = mean + half_mean * scaled - half_mean * root;
        let keep_small_root = mean / (mean + candidate);
        if state.next_unit() > keep_small_root {
            candidate = mean * mean / candidate;
        }
    }
    candidate
}
/// Solves `K'(t) = x` for `t`, where `K'(t) = tanh(v)/v` with `v = sqrt(-2t)`
/// for `t < 0` and `K'(t) = tan(v)/v` with `v = sqrt(2t)` for `t > 0`.
///
/// The equation is solved for `v` with up to 16 Newton steps; the result is
/// `-v²/2` when `x < 1` and `+v²/2` when `x > 1`. Inputs within `1e-9` of 1
/// return 0 directly.
pub fn saddlepoint_solve(x: f64) -> f64 {
if (x - 1.0).abs() < 1e-9 {
return 0.0;
}
if x < 1.0 {
// Starting point: the larger of the small-v Taylor estimate and the
// large-v asymptote tanh(v)/v ≈ 1/v.
let v_taylor = (3.0 * (1.0 - x)).sqrt();
let v_asym = 1.0 / x.max(1e-12);
let mut v = v_taylor.max(v_asym).max(1e-6);
for _ in 0..16 {
let tanh_v = v.tanh();
let f = tanh_v / v - x;
let sech_sq = 1.0 - tanh_v * tanh_v;
// d/dv [tanh(v)/v]
let df = (sech_sq - tanh_v / v) / v;
v -= f / df;
// NOTE(review): this stops when the iterate itself is tiny, not when
// the Newton step is small — confirm that is the intended criterion.
if v.abs() < 1e-12 {
break;
}
}
-0.5 * v * v
} else {
// Starting point: Taylor estimate or an estimate near the pole of tan at
// π/2, kept strictly inside (0, π/2).
let v_taylor = (3.0 * (x - 1.0)).sqrt();
let v_pole = FRAC_PI_2 - 2.0 / (x.max(1e-12) * PI);
let mut v = v_taylor.max(v_pole).min(0.499 * PI).max(1e-6);
for _ in 0..16 {
let tan_v = v.tan();
let f = tan_v / v - x;
let sec_sq = 1.0 + tan_v * tan_v;
// d/dv [tan(v)/v]
let df = (sec_sq - tan_v / v) / v;
// Clamp the iterate so it never reaches the pole at π/2.
v = (v - f / df).max(1e-6).min(0.499_999 * PI);
// NOTE(review): `f64::max`/`f64::min` discard a NaN operand and the clamp
// bounds ±infinity, so `v` is always finite here and this fallback
// looks unreachable.
if !v.is_finite() {
v = (3.0 * (x - 1.0)).sqrt().min(0.49 * PI);
break;
}
}
0.5 * v * v
}
}
/// Second derivative `K''(t)` of the cumulant function whose first derivative
/// is `tan(v)/v` with `v = sqrt(2t)` for `t > 0` and `tanh(v)/v` with
/// `v = sqrt(-2t)` for `t < 0`.
///
/// For `|t| < 1e-4` the value comes from the Taylor expansion
/// `2/3 + 16t/15 + 136t²/105 + 3968t³/2835`. The closed forms subtract two
/// terms of size `1/(2|t|)` that differ by only about `2/3`, so near zero they
/// lose almost all significant digits (around 1e-13 the result is only
/// accurate to about 1e-3). Outside that radius the closed forms are used
/// unchanged.
pub fn saddlepoint_kpp(t: f64) -> f64 {
    /// Below this radius the series is more accurate than the closed form.
    const SERIES_RADIUS: f64 = 1e-4;

    if t.abs() < SERIES_RADIUS {
        // Horner form of the series; the first omitted term is O(t^4).
        return 2.0 / 3.0
            + t * (16.0 / 15.0 + t * (136.0 / 105.0 + t * (3968.0 / 2835.0)));
    }
    if t < 0.0 {
        let v = (-2.0 * t).sqrt();
        let tanh_v = v.tanh();
        let sech_sq = 1.0 - tanh_v * tanh_v;
        (tanh_v / (v * v * v)) - (sech_sq / (v * v))
    } else {
        let v = (2.0 * t).sqrt();
        let tan_v = v.tan();
        let sec_sq = 1.0 + tan_v * tan_v;
        (sec_sq / (v * v)) - (tan_v / (v * v * v))
    }
}
/// Host draw for the saddlepoint shape regime.
///
/// Currently a direct delegate to `pg_convolution_cpu_oracle`, i.e. the sum of
/// `b` PG(1, `tilt`) draws; no saddlepoint approximation is evaluated here.
pub fn pg_saddlepoint_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
pg_convolution_cpu_oracle(state, b, tilt)
}
/// Third cumulant of PG(`b`, `c`), evaluated analytically.
///
/// With `u = |c| / 2` the cumulant is
/// `b · (3·tanh u − 3u·sech² u − 2u²·sech² u·tanh u) / (64 u⁵)`,
/// which tends to `b / 60` as `c → 0`. The bracket is O(u⁵) built from O(u)
/// terms, so for `u < 0.05` the Taylor expansion
/// `b · (1/60 − 17u²/840 + 31u⁴/1890 − 691u⁶/62370)` is used instead.
///
/// The previous central finite difference returned exactly 0 for
/// `|c| < 1e-6` (both samples hit the same clamp) and rounding noise for
/// other small tilts, because its step shrank with `c²`.
#[inline]
fn pg_third_cumulant(b: f64, c: f64) -> f64 {
    /// Below this `u` the series is more accurate than the closed form.
    const SERIES_RADIUS: f64 = 0.05;

    let u = 0.5 * c.abs();
    let u_sq = u * u;
    if u < SERIES_RADIUS {
        let series = 1.0 / 60.0
            + u_sq * (-17.0 / 840.0 + u_sq * (31.0 / 1890.0 + u_sq * (-691.0 / 62_370.0)));
        return b * series;
    }
    let tanh_u = u.tanh();
    let sech_sq = 1.0 - tanh_u * tanh_u;
    let bracket = 3.0 * tanh_u - 3.0 * u * sech_sq - 2.0 * u_sq * sech_sq * tanh_u;
    b * bracket / (64.0 * u_sq * u_sq * u)
}
/// Draws an approximate PG(`b`, `tilt`) variate from a skew-corrected normal.
///
/// A standard normal `z` is mapped to
/// `mean + sd · (z + γ₁/6 · (z² − 1))`, where `γ₁ = κ₃ / sd³` is the skewness.
/// Non-positive results are reflected about zero (plus `1e-300`) so the
/// returned value is strictly positive.
pub fn pg_saddlepoint_normal_skew_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
    let shape = f64::from(b);
    let center = pg_mean(shape, tilt);
    let spread = pg_variance(shape, tilt).sqrt();
    let skewness = pg_third_cumulant(shape, tilt) / (spread * spread * spread);

    let z = state.next_norm();
    let candidate = center + spread * (z + skewness / 6.0 * (z * z - 1.0));
    if candidate <= 0.0 {
        -candidate + 1e-300
    } else {
        candidate
    }
}
/// Mean of PG(`b`, `c`): `b · tanh(|c|/2) / (2|c|)`, with the limit `b / 4`
/// used when `|c| < 1e-8`.
#[inline]
pub fn pg_mean(b: f64, c: f64) -> f64 {
    let magnitude = c.abs();
    if magnitude < 1e-8 {
        return 0.25 * b;
    }
    let half_tanh = (0.5 * magnitude).tanh();
    b * half_tanh / (2.0 * magnitude)
}
/// Variance of PG(`b`, `c`): `b · (sinh|c| − |c|) / (2|c|³ · (1 + cosh|c|))`.
///
/// * `|c| < 1e-6` uses the limit `b / 24`.
/// * `|c| >= 50` uses the asymptote `b / (2|c|³)`. In that range
///   `(sinh c − c) / (1 + cosh c)` equals 1 to within double precision, and
///   the direct formula overflows to `inf / inf = NaN` once `|c|` exceeds
///   about 710.
#[inline]
pub fn pg_variance(b: f64, c: f64) -> f64 {
    /// From here on the hyperbolic ratio is indistinguishable from 1.
    const ASYMPTOTIC_FROM: f64 = 50.0;

    let c_abs = c.abs();
    if c_abs < 1e-6 {
        return b / 24.0;
    }
    let c_cubed = c_abs * c_abs * c_abs;
    if c_abs >= ASYMPTOTIC_FROM {
        return b / (2.0 * c_cubed);
    }
    let cosh_c = c_abs.cosh();
    let sinh_c = c_abs.sinh();
    b * (sinh_c - c_abs) / (2.0 * c_cubed * (1.0 + cosh_c))
}
/// Draws an approximate PG(`b`, `tilt`) variate from a normal distribution
/// with the exact PG mean and variance.
///
/// Non-positive results are reflected about zero (plus `1e-300`) so the
/// returned value is strictly positive.
pub fn pg_normal_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
    let shape = f64::from(b);
    let center = pg_mean(shape, tilt);
    let spread = pg_variance(shape, tilt).sqrt();
    let candidate = center + spread * state.next_norm();
    if candidate <= 0.0 {
        -candidate + 1e-300
    } else {
        candidate
    }
}
/// Draws one Polya-Gamma variate per row on the host.
///
/// Every row gets its own generator seeded from `(input.seed, row index)` and
/// is routed by shape: `b <= PG1_MAX_B` to the PG(1, c) sampler,
/// `b < SADDLE_MIN_B` to the convolution, `b <= SADDLE_MAX_B` to the
/// saddlepoint oracle, and anything larger to the normal approximation.
///
/// # Errors
///
/// Returns the message produced by `PolyaGammaBatchInput::validate`.
pub fn draw_batch_cpu(input: &PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
    input.validate()?;
    let base_seed = input.seed.0;
    let draws: Vec<f64> = input
        .shapes
        .iter()
        .zip(input.tilts.iter())
        .enumerate()
        .map(|(row, (&shape, &tilt))| {
            let mut state = XorwowState::new(base_seed, row as u64);
            match shape {
                b if b <= PG1_MAX_B => pg1_draw_cpu_oracle(&mut state, tilt),
                b if b < SADDLE_MIN_B => pg_convolution_cpu_oracle(&mut state, b, tilt),
                b if b <= SADDLE_MAX_B => pg_saddlepoint_cpu_oracle(&mut state, b, tilt),
                b => pg_normal_cpu_oracle(&mut state, b, tilt),
            }
        })
        .collect();
    Ok(Array1::from_vec(draws))
}
/// Draws one Polya-Gamma variate per row, preferring the GPU when available.
///
/// On Linux, if a global GPU runtime exists, the batch is sent to the CUDA
/// path. A `GpuError::NotYetImplemented` result falls back to the host
/// implementation; any other GPU error is converted to a `String` and
/// returned. On other targets, or without a runtime, the host path is used.
///
/// # Errors
///
/// Returns the validation message or the stringified GPU error.
pub fn draw_batch(input: PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
    input.validate()?;
    #[cfg(target_os = "linux")]
    {
        let gpu_available = super::runtime::GpuRuntime::global().is_some();
        if gpu_available {
            match linux_cuda::draw_batch_gpu(&input) {
                // Not implemented on the device: fall through to the host path.
                Err(GpuError::NotYetImplemented { .. }) => {}
                outcome => return outcome.map_err(String::from),
            }
        }
    }
    draw_batch_cpu(&input)
}
/// Performs one Polya-Gamma Gibbs update of logistic-regression coefficients.
///
/// Steps:
/// 1. linear predictor `psi = design · beta`;
/// 2. latent weights `omega[i] ~ PG(1, psi[i])` via [`draw_batch`] with `seed`;
/// 3. precision `Q = prior_precision + designᵀ · diag(omega) · design` and
///    right-hand side `m = designᵀ · (targets − 1/2)`;
/// 4. Cholesky factor `Q = L·Lᵀ`, mean `Q⁻¹·m`, and perturbation `L⁻ᵀ·η` with
///    `η` standard normal drawn from a generator seeded with `(norm_seed, 0)`.
///
/// The right-hand side contains no prior-mean term, i.e. the prior mean is
/// taken to be zero.
///
/// # Errors
///
/// Returns a message when the argument shapes disagree, when the PG draw
/// fails, or when `Q` is not positive definite.
pub fn logistic_gibbs_step(
    design: ArrayView2<'_, f64>,
    targets: ArrayView1<'_, u8>,
    prior_precision: ArrayView2<'_, f64>,
    beta: ArrayView1<'_, f64>,
    seed: PgSeed,
    norm_seed: u64,
) -> Result<Array1<f64>, String> {
    let (n, p) = design.dim();
    if targets.len() != n {
        return Err(format!(
            "logistic_gibbs_step: y.len()={} != n={n}",
            targets.len()
        ));
    }
    if prior_precision.dim() != (p, p) {
        return Err(format!(
            "logistic_gibbs_step: Q_0 shape {:?} != ({p}, {p})",
            prior_precision.dim()
        ));
    }
    if beta.len() != p {
        return Err(format!(
            "logistic_gibbs_step: beta.len()={} != p={p}",
            beta.len()
        ));
    }

    // Linear predictor, used as the PG tilt of every observation.
    let mut psi = Array1::<f64>::zeros(n);
    for i in 0..n {
        let mut acc = 0.0;
        for j in 0..p {
            acc += design[[i, j]] * beta[j];
        }
        psi[i] = acc;
    }

    let shapes = Array1::<u32>::from_elem(n, 1);
    let omega = draw_batch(PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: psi.view(),
        seed,
    })?;

    // m = Xᵀ (y − 1/2)
    let mut m = Array1::<f64>::zeros(p);
    for i in 0..n {
        let r = targets[i] as f64 - 0.5;
        for j in 0..p {
            m[j] += design[[i, j]] * r;
        }
    }

    // Q = Q_0 + Xᵀ Ω X
    let mut q = prior_precision.to_owned();
    for i in 0..n {
        let w = omega[i];
        for a in 0..p {
            let weighted = w * design[[i, a]];
            for b in 0..p {
                q[[a, b]] += weighted * design[[i, b]];
            }
        }
    }

    // `q` is not needed afterwards, so it is moved into the factorisation
    // instead of being cloned.
    let l = cholesky_lower_inplace(q)
        .map_err(|e| format!("logistic_gibbs_step Cholesky: {e}"))?;
    let mean = solve_lower_then_upper(&l, &m);

    let mut norm_state = XorwowState::new(norm_seed, 0);
    let mut eta = Array1::<f64>::zeros(p);
    for j in 0..p {
        eta[j] = norm_state.next_norm();
    }
    // L⁻ᵀ η has covariance Q⁻¹.
    let perturb = solve_upper_transpose(&l, &eta);

    let mut beta_new = Array1::<f64>::zeros(p);
    for j in 0..p {
        beta_new[j] = mean[j] + perturb[j];
    }
    Ok(beta_new)
}
/// Computes the lower Cholesky factor `L` of `a` (so that `a = L·Lᵀ`),
/// overwriting `a` and zeroing its strict upper triangle.
///
/// Only the lower triangle of `a` is read. The caller is expected to pass a
/// square matrix.
///
/// # Errors
///
/// Returns a message when a pivot is not a finite, strictly positive number.
/// NaN and infinite pivots are rejected explicitly: a NaN would pass a plain
/// `<= 0.0` test and silently poison the whole factor.
fn cholesky_lower_inplace(mut a: Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    for i in 0..n {
        for j in 0..=i {
            let mut sum = a[[i, j]];
            for k in 0..j {
                sum -= a[[i, k]] * a[[j, k]];
            }
            if i == j {
                if !sum.is_finite() || sum <= 0.0 {
                    return Err(format!("non-SPD diagonal {sum} at row {i}"));
                }
                a[[i, j]] = sum.sqrt();
            } else {
                a[[i, j]] = sum / a[[j, j]];
            }
        }
        for j in (i + 1)..n {
            a[[i, j]] = 0.0;
        }
    }
    Ok(a)
}
fn solve_lower_then_upper(l: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
let n = l.nrows();
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
let mut s = rhs[i];
for k in 0..i {
s -= l[[i, k]] * y[k];
}
y[i] = s / l[[i, i]];
}
let mut x = Array1::<f64>::zeros(n);
for i in (0..n).rev() {
let mut s = y[i];
for k in (i + 1)..n {
s -= l[[k, i]] * x[k];
}
x[i] = s / l[[i, i]];
}
x
}
/// Solves `Lᵀ·x = rhs` by back substitution, reading the upper-triangular
/// system from the transpose of the lower-triangular factor `l`.
fn solve_upper_transpose(l: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
    let dim = l.nrows();
    let mut solution = Array1::<f64>::zeros(dim);
    for row in (0..dim).rev() {
        let residual =
            ((row + 1)..dim).fold(rhs[row], |acc, col| acc - l[[col, row]] * solution[col]);
        solution[row] = residual / l[[row, row]];
    }
    solution
}
#[cfg(target_os = "linux")]
mod linux_cuda {
use super::{
PG1_MAX_B, PgSeed, PolyaGammaBatchInput, SADDLE_MAX_B, SADDLE_MIN_B, XorwowState,
pg_convolution_cpu_oracle, pg_normal_cpu_oracle,
};
use crate::gpu::error::{GpuError, GpuResultExt};
use crate::gpu::solver::context_and_stream;
use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
use ndarray::Array1;
use std::sync::Arc;
/// Device source handed to `PtxModuleCache::get_or_compile` by `module`.
///
/// Despite the `PTX_` name the text is CUDA C, not PTX assembly. It defines
/// device copies of the host helpers (SplitMix64 seeding, XORWOW, the PG(1, c)
/// sampler, a saddlepoint solver) and three kernels: `pg1_kernel`,
/// `sp_kernel` and `normal_kernel`. Each kernel thread handles one entry of
/// the `rows` index map and seeds its generator from `(seed, row)`.
// NOTE(review): `normal_kernel` switches to the small-tilt limits at
// `c < 1e-8` for both mean and variance, whereas the host `pg_variance`
// switches at `1e-6`; `pg1_draw` caps the series loop at 64 iterations while
// the host loop is unbounded; `saddlepoint_t` runs 6 Newton steps where the
// host `saddlepoint_solve` runs 16 with a different starting point. Host and
// device results can therefore differ — confirm whether exact parity is
// required.
pub(super) const PTX_SOURCE: &str = r#"
extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
z += 0x9E3779B97F4A7C15ULL;
unsigned long long x = z;
x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
return x ^ (x >> 31);
}
// Per-row XORWOW state. Layout mirrors curand_kernel.h::curandStateXORWOW_t
// for the five 32-bit state lanes plus the addition counter. We omit the
// boxmuller_extra/boxmuller_flag cache since our normal draws use the
// polar method (which discards the second variate).
struct XorwowState {
unsigned int s0, s1, s2, s3, s4, d;
};
extern "C" __device__ void xorwow_seed(struct XorwowState* st, unsigned long long seed, unsigned long long row) {
const unsigned long long ROW_ZETA = 0xA1B2C3D4E5F67890ULL;
const unsigned long long WORD_GAMMA = 0x0F1E2D3C4B5A6978ULL;
unsigned int words[6];
for (int w = 0; w < 6; ++w) {
unsigned long long composite = seed ^ (row * ROW_ZETA) ^ ((unsigned long long)w * WORD_GAMMA);
unsigned long long h = splitmix64_mix(composite);
words[w] = (unsigned int)(h >> 32);
}
if ((words[0] | words[1] | words[2] | words[3] | words[4]) == 0u) {
words[0] = 1u;
}
st->s0 = words[0]; st->s1 = words[1]; st->s2 = words[2];
st->s3 = words[3]; st->s4 = words[4]; st->d = words[5];
}
extern "C" __device__ unsigned int xorwow_next(struct XorwowState* st) {
unsigned int t = st->s4;
unsigned int s = st->s0;
st->s4 = st->s3;
st->s3 = st->s2;
st->s2 = st->s1;
st->s1 = s;
t ^= (t >> 2);
t ^= (t << 1);
t ^= s ^ (s << 4);
st->s0 = t;
st->d += 362437u;
return t + st->d;
}
extern "C" __device__ double xorwow_unit(struct XorwowState* st) {
unsigned int raw = xorwow_next(st);
return ((double)raw + 1.0) * (1.0 / 4294967296.0);
}
extern "C" __device__ double xorwow_exp(struct XorwowState* st) {
return -log(xorwow_unit(st));
}
extern "C" __device__ double xorwow_norm(struct XorwowState* st) {
// Marsaglia polar — discard the partner variate, matches host oracle
// byte-for-byte (host also discards).
for (;;) {
double u = 2.0 * xorwow_unit(st) - 1.0;
double v = 2.0 * xorwow_unit(st) - 1.0;
double s = u * u + v * v;
if (s > 0.0 && s < 1.0) {
double factor = sqrt(-2.0 * log(s) / s);
return u * factor;
}
}
}
// ── Devroye PG(1, c) helpers ─────────────────────────────────────────────
#define PG_FRAC_2_PI (0.63661977236758134307553505349006)
#define PG_PI (3.14159265358979323846)
#define PG_PI_SQ (9.86960440108935861883)
#define PG_SQRT_2_OVER_PI (0.79788456080286535588)
#define PG_SQRT_PI_OVER_2 (1.25331413731550025121)
extern "C" __device__ double std_normal_cdf(double x) {
// 0.5 · erfc(-x / sqrt(2)).
return 0.5 * erfc(-x * 0.7071067811865475);
}
extern "C" __device__ double pg_series(int n, double x) {
if (x <= 0.0) return 0.0;
double k = (double)n + 0.5;
double k_sq = k * k;
if (x <= PG_FRAC_2_PI) {
double inv_x = 1.0 / x;
return (2.0 * k * PG_SQRT_2_OVER_PI) * inv_x * sqrt(inv_x) * exp(-2.0 * k_sq * inv_x);
} else {
// Right branch — corrected coefficient PI · k (not PI / 2).
return PG_PI * k * exp(-0.5 * k_sq * PG_PI_SQ * x);
}
}
extern "C" __device__ double pg_exp_tail_mass(double tilt) {
double base = 0.125 * PG_PI_SQ + 0.5 * tilt * tilt;
double upper = PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt - 1.0);
double lower = -(PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt + 1.0));
double base_factor = base * exp(base * PG_FRAC_2_PI);
double p_upper = base_factor * exp(-tilt) * std_normal_cdf(upper);
double p_lower = base_factor * exp( tilt) * std_normal_cdf(lower);
double exp_terms = (4.0 / PG_PI) * (p_upper + p_lower);
return 1.0 / (1.0 + exp_terms);
}
extern "C" __device__ double sample_small_z(struct XorwowState* st, double z, double trunc) {
double accept = 0.0;
double sample = 0.0;
while (accept < xorwow_unit(st)) {
double exp_sample;
for (;;) {
double e1 = xorwow_exp(st);
double e2 = xorwow_exp(st);
if (e1 * e1 <= 2.0 * e2 / trunc) { exp_sample = e1; break; }
}
sample = 1.0 + exp_sample * trunc;
sample = trunc / (sample * sample);
accept = exp(-0.5 * z * z * sample);
}
return sample;
}
extern "C" __device__ double sample_large_z(struct XorwowState* st, double mean, double trunc) {
double sample = 1.0e300;
while (sample > trunc) {
double n = xorwow_norm(st);
double n_sq = n * n;
double half_mean = 0.5 * mean;
double mn_sq = mean * n_sq;
double disc = sqrt(4.0 * mn_sq + mn_sq * mn_sq);
sample = mean + half_mean * mn_sq - half_mean * disc;
if (xorwow_unit(st) > mean / (mean + sample)) {
sample = mean * mean / sample;
}
}
return sample;
}
extern "C" __device__ double sample_trunc_inv_gauss(struct XorwowState* st, double z, double trunc) {
double az = fabs(z);
if (PG_FRAC_2_PI > az) {
return sample_small_z(st, az, trunc);
} else {
return sample_large_z(st, 1.0 / az, trunc);
}
}
extern "C" __device__ double pg1_draw(struct XorwowState* st, double tilt) {
double half_tilt = fabs(tilt) * 0.5;
double scale = 0.125 * PG_PI_SQ + 0.5 * half_tilt * half_tilt;
double exp_mass = pg_exp_tail_mass(half_tilt);
for (;;) {
double u = xorwow_unit(st);
double proposal;
if (u < exp_mass) {
proposal = PG_FRAC_2_PI + xorwow_exp(st) / scale;
} else {
proposal = sample_trunc_inv_gauss(st, half_tilt, PG_FRAC_2_PI);
}
double sum = pg_series(0, proposal);
double threshold = xorwow_unit(st) * sum;
int idx = 0;
// The alternating-series tail. Bounded iteration cap (64) is
// overwhelmingly safe: PSW 2013 show termination in <10 iters
// with probability >1 - 1e-30 for any tilt; the cap exists only
// to guarantee forward progress under hardware fault.
for (int outer = 0; outer < 64; ++outer) {
idx += 1;
double term = pg_series(idx, proposal);
if (idx & 1) {
sum -= term;
if (threshold <= sum) {
return 0.25 * proposal;
}
} else {
sum += term;
if (threshold >= sum) {
break;
}
}
}
}
}
// ── Saddlepoint helpers (math §9) ────────────────────────────────────────
extern "C" __device__ double saddlepoint_t(double x) {
if (fabs(x - 1.0) < 1.0e-9) return 0.0;
if (x < 1.0) {
double v = sqrt(3.0 * (1.0 - x)); if (v < 1.0e-6) v = 1.0e-6;
for (int it = 0; it < 6; ++it) {
double tanh_v = tanh(v);
double f = tanh_v / v - x;
double sech_sq = 1.0 - tanh_v * tanh_v;
double df = (sech_sq - tanh_v / v) / v;
v -= f / df;
if (fabs(v) < 1.0e-12) break;
}
return -0.5 * v * v;
} else {
double v = sqrt(3.0 * (x - 1.0));
if (v > 0.49 * PG_PI) v = 0.49 * PG_PI;
if (v < 1.0e-6) v = 1.0e-6;
for (int it = 0; it < 6; ++it) {
double tan_v = tan(v);
double f = tan_v / v - x;
double sec_sq = 1.0 + tan_v * tan_v;
double df = (sec_sq - tan_v / v) / v;
v -= f / df;
if (v < 1.0e-6) v = 1.0e-6;
if (v > 0.499999 * PG_PI) v = 0.499999 * PG_PI;
}
return 0.5 * v * v;
}
}
// ── Kernels ──────────────────────────────────────────────────────────────
extern "C" __global__ void pg1_kernel(
unsigned long long seed,
unsigned int n,
const unsigned int* __restrict__ rows, // index map into shapes/tilts/out, length n
const double* __restrict__ tilts,
double* __restrict__ out)
{
unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
if (slot >= n) return;
unsigned int row = rows[slot];
struct XorwowState st;
xorwow_seed(&st, seed, (unsigned long long)row);
double c = tilts[row];
out[row] = pg1_draw(&st, c);
}
extern "C" __global__ void sp_kernel(
unsigned long long seed,
unsigned int n,
const unsigned int* __restrict__ rows,
const unsigned int* __restrict__ shapes,
const double* __restrict__ tilts,
double* __restrict__ out)
{
unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
if (slot >= n) return;
unsigned int row = rows[slot];
struct XorwowState st;
xorwow_seed(&st, seed, (unsigned long long)row);
unsigned int b = shapes[row];
double c = tilts[row];
// Convolution-equivalent device fallback: sum b PG(1, c) draws. This
// is correct in distribution; the *true* saddlepoint envelope ships
// with phase 3 hill-climb. Until then, the kernel is callable and
// produces draws that pass the §12 KS test — the only thing the
// saddlepoint is supposed to buy is throughput at large b.
double acc = 0.0;
for (unsigned int j = 0; j < b; ++j) {
acc += pg1_draw(&st, c);
}
// Touch saddlepoint_t so the helper isn’t DCE’d before phase 3 wiring;
// the value is unused (multiplied by zero) so this is free.
double sp_warm = saddlepoint_t(0.5);
out[row] = acc + 0.0 * sp_warm;
}
extern "C" __global__ void normal_kernel(
unsigned long long seed,
unsigned int n,
const unsigned int* __restrict__ rows,
const unsigned int* __restrict__ shapes,
const double* __restrict__ tilts,
double* __restrict__ out)
{
unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
if (slot >= n) return;
unsigned int row = rows[slot];
struct XorwowState st;
xorwow_seed(&st, seed, (unsigned long long)row);
double b = (double)shapes[row];
double c = fabs(tilts[row]);
double mean;
double var;
if (c < 1.0e-8) {
mean = 0.25 * b;
var = b / 24.0;
} else {
mean = b * tanh(0.5 * c) / (2.0 * c);
double cosh_c = cosh(c);
double sinh_c = sinh(c);
var = b * (sinh_c - c) / (2.0 * c * c * c * (1.0 + cosh_c));
}
double sd = sqrt(var);
double draw = mean + sd * xorwow_norm(&st);
if (draw <= 0.0) draw = -draw + 1.0e-300;
out[row] = draw;
}
"#;
/// Threads per block for every kernel launch in this module; the grid size is
/// the row count divided by this value, rounded up.
const THREADS_PER_BLOCK: u32 = 128;
/// Returns the module built from `PTX_SOURCE` for `ctx`, going through a
/// function-local static `PtxModuleCache` keyed by the name `"polya_gamma"`.
// NOTE(review): the caching and compilation behaviour lives in
// `PtxModuleCache::get_or_compile`, which is not visible here.
fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
static CACHE: crate::gpu::common::PtxModuleCache =
crate::gpu::common::PtxModuleCache::new();
CACHE.get_or_compile(ctx, "polya_gamma", PTX_SOURCE)
}
/// Draws one Polya-Gamma variate per row, using the CUDA kernels where one
/// exists for the row's shape.
///
/// Rows are split by shape into three device groups (`pg1_kernel`,
/// `sp_kernel`, `normal_kernel`) and one host group (shapes strictly between
/// `PG1_MAX_B` and `SADDLE_MIN_B`). Tilts and shapes are uploaded once, each
/// non-empty device group is launched with its own row-index map, the output
/// buffer is downloaded, and the host rows are then filled in with the CPU
/// convolution oracle seeded from `(seed, row)`.
///
/// The input is not validated here; the only visible caller validates first.
///
/// # Errors
///
/// Returns a `GpuError` when acquiring the context, compiling the module, a
/// transfer, an allocation or a launch fails.
pub(super) fn draw_batch_gpu(
input: &PolyaGammaBatchInput<'_>,
) -> Result<Array1<f64>, GpuError> {
let n = input.rows();
// Nothing to do for an empty batch; avoid touching the device at all.
if n == 0 {
return Ok(Array1::<f64>::zeros(0));
}
let (ctx, stream) =
context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
let compiled = module(&ctx)?;
let module_handle: &Arc<CudaModule> = compiled;
// Row-index maps, one per sampling regime.
let mut pg1_rows: Vec<u32> = Vec::new();
let mut sp_rows: Vec<u32> = Vec::new();
let mut normal_rows: Vec<u32> = Vec::new();
let mut host_rows: Vec<u32> = Vec::new();
for (i, &b) in input.shapes.iter().enumerate() {
// NOTE(review): `as u32` silently truncates for batches with more than
// u32::MAX rows — confirm an upstream bound or reject such batches.
let idx = i as u32;
if b <= PG1_MAX_B {
pg1_rows.push(idx);
} else if b < SADDLE_MIN_B {
host_rows.push(idx);
} else if b <= SADDLE_MAX_B {
sp_rows.push(idx);
} else {
normal_rows.push(idx);
}
}
// The views may be non-contiguous, in which case the elements are copied
// one by one into a contiguous host buffer.
let tilts_vec: Vec<f64> = match input.tilts.as_slice() {
Some(s) => s.to_vec(),
None => input.tilts.iter().copied().collect(),
};
let shapes_vec: Vec<u32> = match input.shapes.as_slice() {
Some(s) => s.to_vec(),
None => input.shapes.iter().copied().collect(),
};
let tilts_dev = stream
.clone_htod(&tilts_vec)
.gpu_ctx("polya_gamma upload tilts")?;
let shapes_dev = stream
.clone_htod(&shapes_vec)
.gpu_ctx("polya_gamma upload shapes")?;
// One output slot per input row; rows handled on the host stay zero until
// they are overwritten after the download.
let mut out_dev = stream
.alloc_zeros::<f64>(n)
.gpu_ctx("polya_gamma alloc out")?;
if !pg1_rows.is_empty() {
let rows_dev = stream
.clone_htod(&pg1_rows)
.gpu_ctx("polya_gamma upload pg1 rows")?;
launch_pg1(
&stream,
module_handle,
input.seed,
&rows_dev,
&tilts_dev,
&mut out_dev,
)?;
}
if !sp_rows.is_empty() {
let rows_dev = stream
.clone_htod(&sp_rows)
.gpu_ctx("polya_gamma upload sp rows")?;
launch_sp(
&stream,
module_handle,
input.seed,
&rows_dev,
&shapes_dev,
&tilts_dev,
&mut out_dev,
)?;
}
if !normal_rows.is_empty() {
let rows_dev = stream
.clone_htod(&normal_rows)
.gpu_ctx("polya_gamma upload normal rows")?;
launch_normal(
&stream,
module_handle,
input.seed,
&rows_dev,
&shapes_dev,
&tilts_dev,
&mut out_dev,
)?;
}
let mut out_host = stream
.clone_dtoh(&out_dev)
.gpu_ctx("polya_gamma download out")?;
// Fill in the rows that have no device kernel.
for &row in &host_rows {
let i = row as usize;
let mut st = XorwowState::new(input.seed.0, row as u64);
let b = input.shapes[i];
let c = input.tilts[i];
// NOTE(review): `host_rows` only holds shapes below `SADDLE_MIN_B`, so the
// `else` arm looks unreachable with the partition above.
out_host[i] = if b <= SADDLE_MAX_B {
pg_convolution_cpu_oracle(&mut st, b, c)
} else {
pg_normal_cpu_oracle(&mut st, b, c)
};
}
Ok(Array1::from_vec(out_host))
}
/// Launches `pg1_kernel` with one thread per entry of `rows`.
///
/// Arguments are pushed in the order of the kernel's parameter list in
/// `PTX_SOURCE`: seed, row count, row-index map, tilts, output.
///
/// # Errors
///
/// Returns a `GpuError` when the function cannot be loaded or the launch fails.
fn launch_pg1(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
seed: PgSeed,
rows: &cudarc::driver::CudaSlice<u32>,
tilts: &cudarc::driver::CudaSlice<f64>,
out: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("pg1_kernel")
.gpu_ctx("polya_gamma load pg1_kernel")?;
let n = rows.len() as u32;
// Ceiling division so a partially filled last block is still launched.
let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let seed_arg: u64 = seed.0;
// SAFETY (review): the argument list matches the kernel signature, and the
// kernel only dereferences `tilts[row]` / `out[row]` for `row` values read
// from `rows` — the caller must guarantee every such index is in bounds for
// both buffers (TODO confirm at call sites).
unsafe {
stream
.launch_builder(&func)
.arg(&seed_arg)
.arg(&n)
.arg(rows)
.arg(tilts)
.arg(out)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("polya_gamma launch pg1_kernel")
}
/// Launches `sp_kernel` with one thread per entry of `rows`.
///
/// Arguments are pushed in the order of the kernel's parameter list in
/// `PTX_SOURCE`: seed, row count, row-index map, shapes, tilts, output.
///
/// # Errors
///
/// Returns a `GpuError` when the function cannot be loaded or the launch fails.
fn launch_sp(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
seed: PgSeed,
rows: &cudarc::driver::CudaSlice<u32>,
shapes: &cudarc::driver::CudaSlice<u32>,
tilts: &cudarc::driver::CudaSlice<f64>,
out: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("sp_kernel")
.gpu_ctx("polya_gamma load sp_kernel")?;
let n = rows.len() as u32;
// Ceiling division so a partially filled last block is still launched.
let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let seed_arg: u64 = seed.0;
// SAFETY (review): the argument list matches the kernel signature, and the
// kernel only dereferences `shapes[row]`, `tilts[row]` and `out[row]` for
// `row` values read from `rows` — the caller must guarantee every such index
// is in bounds for all three buffers (TODO confirm at call sites).
unsafe {
stream
.launch_builder(&func)
.arg(&seed_arg)
.arg(&n)
.arg(rows)
.arg(shapes)
.arg(tilts)
.arg(out)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("polya_gamma launch sp_kernel")
}
/// Launches `normal_kernel` with one thread per entry of `rows`.
///
/// Arguments are pushed in the order of the kernel's parameter list in
/// `PTX_SOURCE`: seed, row count, row-index map, shapes, tilts, output.
///
/// # Errors
///
/// Returns a `GpuError` when the function cannot be loaded or the launch fails.
fn launch_normal(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
seed: PgSeed,
rows: &cudarc::driver::CudaSlice<u32>,
shapes: &cudarc::driver::CudaSlice<u32>,
tilts: &cudarc::driver::CudaSlice<f64>,
out: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("normal_kernel")
.gpu_ctx("polya_gamma load normal_kernel")?;
let n = rows.len() as u32;
// Ceiling division so a partially filled last block is still launched.
let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let seed_arg: u64 = seed.0;
// SAFETY (review): the argument list matches the kernel signature, and the
// kernel only dereferences `shapes[row]`, `tilts[row]` and `out[row]` for
// `row` values read from `rows` — the caller must guarantee every such index
// is in bounds for all three buffers (TODO confirm at call sites).
unsafe {
stream
.launch_builder(&func)
.arg(&seed_arg)
.arg(&n)
.arg(rows)
.arg(shapes)
.arg(tilts)
.arg(out)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("polya_gamma launch normal_kernel")
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Reference mean of PG(`b`, `c`) for the moment tests; forwards to `pg_mean`.
fn theoretical_mean(b: f64, c: f64) -> f64 {
pg_mean(b, c)
}
/// Reference variance of PG(`b`, `c`) for the moment tests; forwards to
/// `pg_variance`.
fn theoretical_variance(b: f64, c: f64) -> f64 {
pg_variance(b, c)
}
/// The empirical mean of 25 000 PG(1, c) draws (one generator per draw) must
/// match the theoretical mean within a per-tilt relative tolerance.
#[test]
fn pg1_cpu_oracle_matches_devroye_mean() {
    let n = 25_000;
    let cases = [(0.0_f64, 0.05), (1.0, 0.10), (3.0, 0.10)];
    for &(c, tol) in cases.iter() {
        let total = (0..n).fold(0.0, |acc, row| {
            let mut state = XorwowState::new(0xC0FFEE_u64, row as u64);
            acc + pg1_draw_cpu_oracle(&mut state, c)
        });
        let emp = total / n as f64;
        let th = theoretical_mean(1.0, c);
        let rel = (emp - th).abs() / th.max(1e-12);
        assert!(
            rel < tol,
            "PG(1,{c}) XORWOW oracle: emp {emp}, theory {th}, rel {rel}"
        );
    }
}
/// The empirical variance of 100 000 PG(1, c) draws must be within 5 % of the
/// theoretical variance for several tilts.
#[test]
fn pg1_cpu_oracle_variance_matches_theory() {
    let n = 100_000;
    for &c in [0.0_f64, 0.5, 2.0, 5.0].iter() {
        let (mut first_moment, mut second_moment) = (0.0, 0.0);
        for row in 0..n {
            let mut state = XorwowState::new(0xDEADBEEF_u64, row as u64);
            let draw = pg1_draw_cpu_oracle(&mut state, c);
            first_moment += draw;
            second_moment += draw * draw;
        }
        let mean = first_moment / n as f64;
        let var = second_moment / n as f64 - mean * mean;
        let th_var = theoretical_variance(1.0, c);
        let rel = (var - th_var).abs() / th_var.max(1e-12);
        assert!(
            rel < 0.05,
            "PG(1,{c}) var: emp {var}, theory {th_var}, rel {rel}"
        );
    }
}
/// The same `(seed, row)` must reproduce the same stream, and different rows
/// must produce different streams.
#[test]
fn xorwow_seeding_is_deterministic() {
    let mut a = XorwowState::new(42, 7);
    let mut b = XorwowState::new(42, 7);
    for _ in 0..1024 {
        assert_eq!(a.next_u32(), b.next_u32());
    }

    // Compare two *fresh* generators. Reusing `a` here would compare a stream
    // already advanced by 1024 steps with a new one, which differs even if
    // both rows were seeded identically, so the check could never fail.
    let mut row_seven = XorwowState::new(42, 7);
    let mut row_eight = XorwowState::new(42, 8);
    let same = (0..32).all(|_| row_seven.next_u32() == row_eight.next_u32());
    assert!(!same, "different rows must produce different streams");
}
/// Every uniform draw must lie in the half-open interval (0, 1].
#[test]
fn xorwow_unit_in_open_zero_closed_one() {
    let mut state = XorwowState::new(123, 0);
    let mut remaining = 10_000;
    while remaining > 0 {
        let u = state.next_unit();
        let in_range = u > 0.0 && u <= 1.0;
        assert!(in_range, "u={u} outside (0,1]");
        remaining -= 1;
    }
}
/// Plugging the solver's `t` back into `K'(t)` must reproduce `x` to a
/// relative error below 1e-6 on both sides of `x = 1`.
#[test]
fn saddlepoint_solve_round_trips() {
    // K'(t): tanh(v)/v for t < 0, tan(v)/v for t > 0, and 1 at t = 0.
    let k_prime = |t: f64| -> f64 {
        if t.abs() < 1e-14 {
            return 1.0;
        }
        if t < 0.0 {
            let v = (-2.0 * t).sqrt();
            v.tanh() / v
        } else {
            let v = (2.0 * t).sqrt();
            v.tan() / v
        }
    };
    let inputs = [0.05_f64, 0.3, 0.7, 0.99, 1.01, 1.5, 3.0, 8.0];
    for &x in inputs.iter() {
        let t = saddlepoint_solve(x);
        let kp = k_prime(t);
        let rel = (kp - x).abs() / x.max(1e-12);
        assert!(
            rel < 1e-6,
            "saddlepoint_solve(x={x}) → t={t}; K'(t)={kp}, rel={rel}"
        );
    }
}
/// `K''(t)` must be finite and strictly positive on both sides of zero.
#[test]
fn saddlepoint_kpp_is_positive() {
    let probes = [-2.0_f64, -0.5, -1e-5, 0.0, 1e-5, 0.5, 1.0];
    for &t in probes.iter() {
        let v = saddlepoint_kpp(t);
        let well_formed = v.is_finite() && v > 0.0;
        assert!(well_formed, "K''({t}) = {v}");
    }
}
/// For `b = 500`, `c = 1` the normal-approximation oracle must reproduce the
/// theoretical mean within 2 % and the variance within 5 %.
#[test]
fn pg_normal_oracle_matches_moments_at_large_b() {
    let b = 500u32;
    let c = 1.0_f64;
    let n = 100_000;

    let (mut first_moment, mut second_moment) = (0.0, 0.0);
    for row in 0..n {
        let mut state = XorwowState::new(0xBEEF_u64, row as u64);
        let draw = pg_normal_cpu_oracle(&mut state, b, c);
        first_moment += draw;
        second_moment += draw * draw;
    }
    let mean = first_moment / n as f64;
    let var = second_moment / n as f64 - mean * mean;

    let th_mean = theoretical_mean(b as f64, c);
    let th_var = theoretical_variance(b as f64, c);
    let m_rel = (mean - th_mean).abs() / th_mean;
    let v_rel = (var - th_var).abs() / th_var;
    assert!(
        m_rel < 0.02,
        "normal oracle mean: emp {mean}, theory {th_mean}, rel {m_rel}"
    );
    assert!(
        v_rel < 0.05,
        "normal oracle var: emp {var}, theory {th_var}, rel {v_rel}"
    );
}
/// A batch with one shape from each regime must yield four positive, finite
/// draws on the host path.
#[test]
fn batch_dispatch_handles_mixed_regimes() {
    let shapes = ndarray::array![1u32, 5u32, 50u32, 300u32];
    let tilts = ndarray::array![0.5_f64, 0.5, 0.5, 0.5];
    let out = draw_batch_cpu(&PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed: PgSeed(42),
    })
    .expect("CPU dispatch");

    assert_eq!(out.len(), 4);
    out.iter().for_each(|v| {
        assert!(
            v.is_finite() && *v > 0.0,
            "PG draw must be positive finite: {v}"
        );
    });
}
/// One Gibbs step from `beta = 0` on synthetic logistic data must return a
/// finite coefficient vector that has moved measurably away from the origin.
#[test]
fn logistic_gibbs_step_reduces_marginal_error() {
    let n = 200;
    let p = 3;
    let mut design = Array2::<f64>::zeros((n, p));
    let mut targets = Array1::<u8>::zeros(n);
    for i in 0..n {
        // Two deterministic covariates in [-1, 1) plus an intercept column.
        let x1 = ((i as f64) / (n as f64)) * 2.0 - 1.0;
        let x2 = (((i * 7) % n) as f64 / n as f64) * 2.0 - 1.0;
        design[[i, 0]] = x1;
        design[[i, 1]] = x2;
        design[[i, 2]] = 1.0;

        // Bernoulli label from a hashed uniform in [0, 1).
        let success_prob = 1.0 / (1.0 + (-(1.5 * x1 - 0.7 * x2 + 0.3)).exp());
        let hashed = splitmix64_mix(i as u64 ^ 0xABCD_EF);
        let uniform = ((hashed >> 11) as f64) / ((1u64 << 53) as f64);
        targets[i] = u8::from(uniform < success_prob);
    }

    let q0 = Array2::<f64>::eye(p) * 0.1;
    let start = Array1::<f64>::zeros(p);
    let new_beta = logistic_gibbs_step(
        design.view(),
        targets.view(),
        q0.view(),
        start.view(),
        PgSeed(1),
        9,
    )
    .expect("Gibbs step");

    assert_eq!(new_beta.len(), p);
    let disp: f64 = new_beta.iter().map(|b| b * b).sum::<f64>().sqrt();
    assert!(
        disp > 0.05 && disp.is_finite(),
        "Gibbs step displacement {disp} not meaningfully nonzero"
    );
}
/// Two-sample Kolmogorov–Smirnov statistic: the largest absolute difference
/// between the empirical CDFs of `a` and `b`. Both slices are sorted in place.
///
/// The CDFs are compared only after *all* observations equal to the current
/// value have been consumed from both samples. Advancing one sample at a time
/// through a tie reports a spurious gap (identical samples would give
/// `D = 1/n` instead of 0).
///
/// Returns 0 when either slice is empty.
///
/// # Panics
///
/// Panics if either slice contains NaN.
fn ks_two_sample(a: &mut [f64], b: &mut [f64]) -> f64 {
    a.sort_by(|x, y| x.partial_cmp(y).expect("KS samples must not contain NaN"));
    b.sort_by(|x, y| x.partial_cmp(y).expect("KS samples must not contain NaN"));
    let (na, nb) = (a.len() as f64, b.len() as f64);
    let (mut i, mut j) = (0usize, 0usize);
    let mut d_max = 0.0_f64;
    while i < a.len() && j < b.len() {
        // Step both CDFs past the smallest value not yet consumed.
        let value = a[i].min(b[j]);
        while i < a.len() && a[i] <= value {
            i += 1;
        }
        while j < b.len() && b[j] <= value {
            j += 1;
        }
        let gap = (i as f64 / na - j as f64 / nb).abs();
        d_max = d_max.max(gap);
    }
    d_max
}
/// Large-sample critical value of the two-sample KS statistic at
/// significance 0.01: `1.6276 · sqrt((n_a + n_b) / (n_a · n_b))`.
fn ks_critical_001(n_a: usize, n_b: usize) -> f64 {
    const C_ALPHA_001: f64 = 1.6276;
    let (first, second) = (n_a as f64, n_b as f64);
    let pooled = (first + second) / (first * second);
    C_ALPHA_001 * pooled.sqrt()
}
/// The XORWOW-driven PG(1, c) oracle and the reference sampler in
/// `crate::inference::polya_gamma` must agree in distribution: their
/// two-sample KS distance may not exceed twice the 1 % critical value.
#[test]
fn pg1_cpu_oracle_matches_inference_module_distribution() {
    use crate::inference::polya_gamma::PolyaGamma;
    use rand::{SeedableRng, rngs::StdRng};

    let pg = PolyaGamma::new();
    let n_dev = 5_000;
    let n_ref = 5_000;
    for &c in [0.0_f64, 1.5, 4.0].iter() {
        let oracle_seed = 0xDEADBEEF_u64 ^ c.to_bits();
        let mut from_oracle = Vec::with_capacity(n_dev);
        for row in 0..n_dev {
            let mut state = XorwowState::new(oracle_seed, row as u64);
            from_oracle.push(pg1_draw_cpu_oracle(&mut state, c));
        }

        let mut rng = StdRng::seed_from_u64(0xABCD_u64 ^ c.to_bits());
        let mut from_reference = Vec::with_capacity(n_ref);
        for _ in 0..n_ref {
            from_reference.push(pg.draw(&mut rng, c));
        }

        let d = ks_two_sample(&mut from_oracle, &mut from_reference);
        let crit = ks_critical_001(n_dev, n_ref);
        assert!(
            d <= 2.0 * crit,
            "PG(1, c={c}) two-sample KS d={d} > 2·crit={}; XORWOW oracle and reference disagree in distribution",
            2.0 * crit
        );
    }
}
/// Compares two ways of summing `b = 8` draws of `pg1_draw_cpu_oracle` at
/// tilt `c = 1.2`: all draws taken from one XORWOW stream per row, versus each
/// draw taken from its own freshly seeded stream. The two samples of 4 000
/// sums must be within twice the KS critical value of each other.
#[test]
fn pg_convolution_identity_at_small_b() {
    let n: usize = 4_000;
    let b: u32 = 8;
    let c: f64 = 1.2;
    // One stream per row, advanced `b` times.
    let sum_from_single_stream = |row: usize| -> f64 {
        let mut state = XorwowState::new(0x1111_u64, row as u64);
        (0..b).map(|_| pg1_draw_cpu_oracle(&mut state, c)).sum()
    };
    // A separate stream for each of the `b` terms; the term index is folded
    // into the seed.
    let sum_from_fresh_streams = |row: usize| -> f64 {
        (0..b)
            .map(|term| {
                let mut state = XorwowState::new(0x2222_u64 ^ (term as u64), row as u64);
                pg1_draw_cpu_oracle(&mut state, c)
            })
            .sum::<f64>()
    };
    let mut left: Vec<f64> = (0..n).map(sum_from_single_stream).collect();
    let mut right: Vec<f64> = (0..n).map(sum_from_fresh_streams).collect();
    let d = ks_two_sample(&mut left, &mut right);
    let crit = ks_critical_001(n, n);
    assert!(
        d <= 2.0 * crit,
        "PG({b}, {c}) convolution identity KS d={d} > 2·crit={}",
        2.0 * crit
    );
}
/// Draws 50 000 samples from `pg_normal_cpu_oracle` at `b = 500`, `c = 2.0`
/// (above `NORMAL_MIN_B`) and checks the empirical mean and variance against
/// `pg_mean` / `pg_variance` to within 2 % and 5 % relative error.
#[test]
fn pg_normal_kernel_matches_moments_at_b_500() {
    let b = 500u32;
    let c = 2.0_f64;
    let n = 50_000_u64;
    // Accumulate the first two raw moments in row order.
    let (sum, sum_sq) = (0..n).fold((0.0_f64, 0.0_f64), |(acc, acc_sq), row| {
        let mut state = XorwowState::new(0xCAFE_u64, row);
        let x = pg_normal_cpu_oracle(&mut state, b, c);
        (acc + x, acc_sq + x * x)
    });
    let count = n as f64;
    let mean = sum / count;
    let var = sum_sq / count - mean * mean;
    let th_mean = pg_mean(b as f64, c);
    let th_var = pg_variance(b as f64, c);
    let m_rel = (mean - th_mean).abs() / th_mean;
    let v_rel = (var - th_var).abs() / th_var;
    assert!(
        m_rel < 0.02,
        "normal kernel mean: emp {mean}, theory {th_mean}, rel {m_rel}"
    );
    assert!(
        v_rel < 0.05,
        "normal kernel var: emp {var}, theory {th_var}, rel {v_rel}"
    );
}
/// Simulates 400 logistic observations from known coefficients `beta_star`,
/// runs 200 iterations of `logistic_gibbs_step` starting from zero, averages
/// the iterates after a 50-step burn-in, and requires the cosine similarity
/// between that average and `beta_star` to exceed 0.85.
#[test]
fn logistic_gibbs_chain_converges_to_mle_direction() {
    use rand::{RngExt, SeedableRng, rngs::StdRng};
    let n = 400;
    let p = 3;
    let beta_star = [1.5_f64, -0.7, 0.3];
    let mut design = Array2::<f64>::zeros((n, p));
    let mut targets = Array1::<u8>::zeros(n);
    let mut rng = StdRng::seed_from_u64(0xFEED);
    let rows = n as f64;
    for i in 0..n {
        // Two deterministic covariates in [-1, 1) plus an intercept column.
        let x1 = ((i as f64) / rows) * 2.0 - 1.0;
        let x2 = (((i * 13) % n) as f64 / rows) * 2.0 - 1.0;
        design[[i, 0]] = x1;
        design[[i, 1]] = x2;
        design[[i, 2]] = 1.0;
        let eta = beta_star[0] * x1 + beta_star[1] * x2 + beta_star[2];
        let p_y = 1.0 / (1.0 + (-eta).exp());
        // One uniform per row decides the Bernoulli outcome.
        let u: f64 = rng.random();
        targets[i] = u8::from(u < p_y);
    }
    let q0 = Array2::<f64>::eye(p) * 0.01;
    let mut beta = Array1::<f64>::zeros(p);
    let mut accum = Array1::<f64>::zeros(p);
    let steps = 200;
    let burn = 50;
    for k in 0..steps {
        // Both seeds are shifted by the iteration index.
        beta = logistic_gibbs_step(
            design.view(),
            targets.view(),
            q0.view(),
            beta.view(),
            PgSeed(0xC0DE + k as u64),
            0xCAFE + k as u64,
        )
        .expect("Gibbs step");
        if k < burn {
            continue;
        }
        accum
            .iter_mut()
            .zip(beta.iter())
            .for_each(|(total, coef)| *total += *coef);
    }
    let kept = (steps - burn) as f64;
    accum.iter_mut().for_each(|total| *total /= kept);
    let dot: f64 = accum
        .iter()
        .zip(beta_star.iter())
        .map(|(est, truth)| est * truth)
        .sum();
    let norm_est: f64 = accum.iter().map(|v| v * v).sum::<f64>().sqrt();
    let norm_truth: f64 = beta_star.iter().map(|v| v * v).sum::<f64>().sqrt();
    let cos = dot / (norm_est * norm_truth);
    assert!(
        cos > 0.85,
        "Gibbs chain posterior-mean direction does not align with β*: cos = {cos}, accum = {accum:?}, β* = {beta_star:?}"
    );
}
/// Performance gate for the all-`b = 1` case: times one `draw_batch` call and
/// one `draw_batch_cpu` call on 200 000 rows (tilts spread over [-3, 3)) and
/// requires the wall-clock ratio CPU / GPU to be at least 50.
///
/// Skips with a log line when `GpuRuntime::global()` is `None`. A 16-row
/// `draw_batch` call is issued before the timers start as a warm-up.
#[test]
#[cfg(target_os = "linux")]
fn polya_gamma_hill_climb_pg1_50x() {
    if super::super::runtime::GpuRuntime::global().is_none() {
        eprintln!("[polya_gamma_hill_climb_pg1_50x] no CUDA runtime on host — skipping");
        return;
    }
    let n = 200_000usize;
    let seed = PgSeed(0x50_4F_4C_59_47_41_4D_41);
    let shapes = Array1::<u32>::from_elem(n, 1);
    let mut tilts = Array1::<f64>::zeros(n);
    for (i, tilt) in tilts.iter_mut().enumerate() {
        *tilt = ((i as f64) / (n as f64)) * 6.0 - 3.0;
    }

    // Warm-up call, kept outside the timed region.
    let warm_shapes = Array1::<u32>::from_elem(16, 1);
    let warm_tilts = Array1::<f64>::zeros(16);
    let warm_input = PolyaGammaBatchInput {
        shapes: warm_shapes.view(),
        tilts: warm_tilts.view(),
        seed,
    };
    draw_batch(warm_input).expect("warm");

    let gpu_input = PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    };
    let cpu_input = PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    };

    let gpu_clock = std::time::Instant::now();
    let _gpu = draw_batch(gpu_input).expect("GPU draw_batch");
    let dt_gpu = gpu_clock.elapsed().as_secs_f64();

    let cpu_clock = std::time::Instant::now();
    let _cpu = draw_batch_cpu(&cpu_input).expect("CPU draw_batch");
    let dt_cpu = cpu_clock.elapsed().as_secs_f64();

    let speedup = dt_cpu / dt_gpu;
    println!(
        "polya_gamma_hill_climb_pg1: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
    );
    assert!(
        speedup >= 50.0,
        "PG(1) GPU speedup {speedup:.1}× < 50× hill-climb gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
    );
}
/// Performance gate for a mixed workload: every fifth row has shape 1 and the
/// rest shape 250, with tilts spread over [-2, 2). Times one `draw_batch` call
/// and one `draw_batch_cpu` call on 200 000 rows and requires the wall-clock
/// ratio CPU / GPU to be at least 20.
///
/// Skips with a log line when `GpuRuntime::global()` is `None`. A 16-row
/// shape-250 `draw_batch` call is issued before the timers start as a warm-up.
#[test]
#[cfg(target_os = "linux")]
fn polya_gamma_hill_climb_mixed_nb_20x() {
    if super::super::runtime::GpuRuntime::global().is_none() {
        eprintln!("[polya_gamma_hill_climb_mixed_nb_20x] no CUDA runtime on host — skipping");
        return;
    }
    let n = 200_000usize;
    let seed = PgSeed(0xDEAD_BEEF_CAFE_BABE);
    let mut shapes = Array1::<u32>::zeros(n);
    let mut tilts = Array1::<f64>::zeros(n);
    for (i, (shape, tilt)) in shapes.iter_mut().zip(tilts.iter_mut()).enumerate() {
        *shape = if i.is_multiple_of(5) { 1 } else { 250 };
        *tilt = ((i as f64) / (n as f64)) * 4.0 - 2.0;
    }

    // Warm-up call, kept outside the timed region.
    let warm_shapes = Array1::<u32>::from_elem(16, 250);
    let warm_tilts = Array1::<f64>::zeros(16);
    let warm_input = PolyaGammaBatchInput {
        shapes: warm_shapes.view(),
        tilts: warm_tilts.view(),
        seed,
    };
    draw_batch(warm_input).expect("warm");

    let gpu_input = PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    };
    let cpu_input = PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    };

    let gpu_clock = std::time::Instant::now();
    let _g = draw_batch(gpu_input).expect("GPU mixed");
    let dt_gpu = gpu_clock.elapsed().as_secs_f64();

    let cpu_clock = std::time::Instant::now();
    let _c = draw_batch_cpu(&cpu_input).expect("CPU mixed");
    let dt_cpu = cpu_clock.elapsed().as_secs_f64();

    let speedup = dt_cpu / dt_gpu;
    println!(
        "polya_gamma_hill_climb_mixed: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
    );
    assert!(
        speedup >= 20.0,
        "Mixed NB GPU speedup {speedup:.1}× < 20× gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
    );
}
/// KS-compares `pg_saddlepoint_normal_skew_oracle` against
/// `pg_convolution_cpu_oracle` on 100 000 draws each, for every combination of
/// `b ∈ {14, 30, 80, 170}` (including both `SADDLE_MIN_B` and `SADDLE_MAX_B`)
/// and `c ∈ {0.5, 2.0}`. The KS distance must not exceed the
/// `ks_critical_001` threshold (no slack factor here).
#[test]
fn pg_saddlepoint_skew_normal_passes_ks_vs_convolution() {
    let n: usize = 100_000;
    let crit = ks_critical_001(n, n);
    for &b in &[14u32, 30, 80, 170] {
        for &c in &[0.5_f64, 2.0] {
            // The shape is folded into both seeds so each `b` uses its own streams.
            let skew_seed = 0xA5A5_A5A5_u64 ^ u64::from(b);
            let conv_seed = 0x5A5A_5A5A_u64 ^ u64::from(b);
            let mut from_skew: Vec<f64> = Vec::with_capacity(n);
            for row in 0..n {
                let mut state = XorwowState::new(skew_seed, row as u64);
                from_skew.push(pg_saddlepoint_normal_skew_oracle(&mut state, b, c));
            }
            let mut from_conv: Vec<f64> = Vec::with_capacity(n);
            for row in 0..n {
                let mut state = XorwowState::new(conv_seed, row as u64);
                from_conv.push(pg_convolution_cpu_oracle(&mut state, b, c));
            }
            let d = ks_two_sample(&mut from_skew, &mut from_conv);
            assert!(
                d <= crit,
                "Phase 3 saddlepoint KS FAIL at (b={b}, c={c}): d={d} > crit={crit}. Revert to convolution-only sp_kernel."
            );
        }
    }
}
/// Row-by-row parity check between `draw_batch` and `draw_batch_cpu` for 256
/// shape-1 rows with tilts spread over [-3, 3) and a shared seed. Each pair of
/// outputs must agree to a relative error below 1e-6.
///
/// When `GpuRuntime::global()` is `None` the test logs a skip line and returns,
/// matching the other GPU-gated tests in this module, so a vacuous pass is
/// visible in the test output.
#[test]
#[cfg(target_os = "linux")]
fn pg1_gpu_matches_cpu_oracle_when_runtime_available() {
    if super::super::runtime::GpuRuntime::global().is_none() {
        eprintln!(
            "[pg1_gpu_matches_cpu_oracle_when_runtime_available] no CUDA runtime on host — skipping"
        );
        return;
    }
    let n = 256usize;
    let shapes = Array1::<u32>::from_elem(n, 1);
    let mut tilts = Array1::<f64>::zeros(n);
    for (i, tilt) in tilts.iter_mut().enumerate() {
        *tilt = ((i as f64) / (n as f64)) * 6.0 - 3.0;
    }
    let seed = PgSeed(0x9E37_79B9_7F4A_7C15);
    let gpu = draw_batch(PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    })
    .expect("GPU draw_batch");
    let cpu = draw_batch_cpu(&PolyaGammaBatchInput {
        shapes: shapes.view(),
        tilts: tilts.view(),
        seed,
    })
    .expect("CPU draw_batch");
    assert_eq!(gpu.len(), cpu.len());
    for i in 0..n {
        let g = gpu[i];
        let c = cpu[i];
        // Floor the denominator so a tiny CPU value cannot blow up the ratio.
        let rel = (g - c).abs() / c.max(1e-12);
        assert!(
            rel < 1e-6,
            "pg1 GPU/CPU divergence at row {i}, tilt={}: gpu={g}, cpu={c}, rel={rel}",
            tilts[i]
        );
    }
}
}