1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
56
57use gam_linalg::triangular::{back_substitution_lower_transpose, cholesky_solve_vector};
58
59#[cfg(target_os = "linux")]
60use gam_gpu::gpu_error::GpuError;
61
/// Seed for the per-row Pólya-Gamma random streams.
///
/// Every row of a batch derives its own XORWOW state from this value and its
/// row index, so a fixed seed reproduces the same draws.
#[derive(Clone, Copy, Debug)]
pub struct PgSeed(pub u64);

impl Default for PgSeed {
    /// The ASCII bytes `"POLYGAMA"` read as a big-endian `u64`
    /// (`0x504F_4C59_4741_4D41`).
    fn default() -> Self {
        Self(u64::from_be_bytes(*b"POLYGAMA"))
    }
}
76
/// Largest shape `b` drawn directly by the single PG(1, c) sampler.
pub const PG1_MAX_B: u32 = 1;
/// Smallest shape routed to the saddlepoint regime. Shapes strictly between
/// `PG1_MAX_B` and this value are drawn as a sum of `b` PG(1, c) draws.
pub const SADDLE_MIN_B: u32 = 14;
/// Largest shape routed to the saddlepoint regime.
pub const SADDLE_MAX_B: u32 = 170;
/// Smallest shape drawn from the moment-matched normal approximation.
pub const NORMAL_MIN_B: u32 = SADDLE_MAX_B + 1;
87
88#[derive(Clone, Debug)]
90pub struct PolyaGammaBatchInput<'a> {
91 pub shapes: ArrayView1<'a, u32>,
93 pub tilts: ArrayView1<'a, f64>,
95 pub seed: PgSeed,
97}
98
99impl<'a> PolyaGammaBatchInput<'a> {
100 pub fn rows(&self) -> usize {
101 self.shapes.len()
102 }
103
104 pub fn validate(&self) -> Result<(), String> {
105 if self.shapes.len() != self.tilts.len() {
106 return Err(format!(
107 "polya_gamma: shapes.len()={} != tilts.len()={}",
108 self.shapes.len(),
109 self.tilts.len()
110 ));
111 }
112 if self.shapes.iter().any(|b| *b == 0) {
113 return Err("polya_gamma: b=0 is invalid (PG(0,c) is a point mass at 0)".to_string());
114 }
115 Ok(())
116 }
117}
118
119#[inline]
126pub fn splitmix64_mix(z: u64) -> u64 {
127 gam_linalg::utils::splitmix64_hash(z)
128}
129
/// Multiplier that spreads the row index across the seed space before
/// hashing. Must equal `ROW_ZETA` in the CUDA `xorwow_seed`.
const ROW_ZETA: u64 = 0xA1B2C3D4_E5F67890;
/// Multiplier that separates the six state words derived for one row. Must
/// equal `WORD_GAMMA` in the CUDA `xorwow_seed`.
const WORD_GAMMA: u64 = 0x0F1E2D3C_4B5A6978;
135
/// Marsaglia XORWOW generator state: five xorshift lanes plus a Weyl counter.
///
/// `s[0]` holds the most recently produced xorshift word and `s[4]` the
/// oldest. `d` is advanced by 362 437 on every step. The layout mirrors the
/// device-side `struct XorwowState` in the CUDA prelude.
#[derive(Debug, Clone, Copy)]
pub struct XorwowState {
    /// Xorshift lanes, newest first.
    pub s: [u32; 5],
    /// Weyl-sequence counter added to every output word.
    pub d: u32,
}
144
impl XorwowState {
    /// Derives the generator state for one row of a batch.
    ///
    /// Word `w` (for `w` in `0..6`) is the high 32 bits of
    /// `splitmix64_mix(seed ^ row*ROW_ZETA ^ w*WORD_GAMMA)`, using wrapping
    /// multiplication. Words `0..5` become the xorshift lanes and word 5 the
    /// Weyl counter. The device-side `xorwow_seed` performs the same
    /// derivation.
    pub fn new(seed: u64, row: u64) -> Self {
        let mut words = [0u32; 6];
        for (word_idx, slot) in words.iter_mut().enumerate() {
            let composite =
                seed ^ row.wrapping_mul(ROW_ZETA) ^ (word_idx as u64).wrapping_mul(WORD_GAMMA);
            let h = splitmix64_mix(composite);
            *slot = (h >> 32) as u32;
        }
        // An all-zero xorshift state maps to itself forever (only the Weyl
        // counter would vary), so force one lane non-zero.
        if words[0] == 0 && words[1] == 0 && words[2] == 0 && words[3] == 0 && words[4] == 0 {
            words[0] = 1;
        }
        Self {
            s: [words[0], words[1], words[2], words[3], words[4]],
            d: words[5],
        }
    }

    /// Advances one XORWOW step and returns the next 32-bit output: the new
    /// xorshift word plus the updated Weyl counter (wrapping).
    ///
    /// Statement order must stay identical to the device `xorwow_next`; host
    /// and device streams are expected to agree word for word.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        // Oldest lane (`t`) and newest lane (`s`) feed the xorshift update.
        let mut t = self.s[4];
        let s = self.s[0];
        // Age every lane by one slot; slot 0 is refilled below.
        self.s[4] = self.s[3];
        self.s[3] = self.s[2];
        self.s[2] = self.s[1];
        self.s[1] = s;
        t ^= t >> 2;
        t ^= t << 1;
        t ^= s ^ (s << 4);
        self.s[0] = t;
        self.d = self.d.wrapping_add(362_437);
        t.wrapping_add(self.d)
    }

    /// Uniform variate `(raw + 1) / 2^32` in the half-open interval (0, 1].
    ///
    /// It is never 0, which keeps `ln` finite in `next_exp` and the polar
    /// method. It can equal exactly 1.0.
    #[inline]
    pub fn next_unit(&mut self) -> f64 {
        let raw = self.next_u32();
        ((raw as f64) + 1.0) * (1.0 / 4_294_967_296.0)
    }

    /// Standard exponential variate by inversion, `-ln(U)`. Because `U` may
    /// be exactly 1.0, this can return `-0.0`.
    #[inline]
    pub fn next_exp(&mut self) -> f64 {
        -self.next_unit().ln()
    }

    /// Standard normal variate via Marsaglia's polar method.
    ///
    /// Each attempt consumes two uniforms. Only the `u`-based variate is
    /// returned and its partner is discarded; the device `xorwow_norm` does
    /// the same.
    #[inline]
    pub fn next_norm(&mut self) -> f64 {
        loop {
            let u = 2.0 * self.next_unit() - 1.0;
            let v = 2.0 * self.next_unit() - 1.0;
            let s = u * u + v * v;
            // Reject points outside the open unit disc (and the origin).
            if s > 0.0 && s < 1.0 {
                let factor = (-2.0 * s.ln() / s).sqrt();
                return u * factor;
            }
        }
    }
}
223
224use crate::polya_gamma_core::{PgRng, draw_pg1};
236use std::f64::consts::{FRAC_PI_2, PI};
237
238impl PgRng for XorwowState {
242 #[inline]
243 fn next_unit(&mut self) -> f64 {
244 XorwowState::next_unit(self)
245 }
246
247 #[inline]
248 fn next_exp(&mut self) -> f64 {
249 XorwowState::next_exp(self)
250 }
251
252 #[inline]
253 fn next_norm(&mut self) -> f64 {
254 XorwowState::next_norm(self)
255 }
256}
257
258pub fn pg1_draw_cpu_oracle(state: &mut XorwowState, tilt: f64) -> f64 {
263 draw_pg1(state, tilt)
264}
265
266pub fn pg_convolution_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
270 (0..b).map(|_| pg1_draw_cpu_oracle(state, tilt)).sum()
271}
272
/// Solves the saddlepoint equation `K'(t) = x` for `t`.
///
/// `K'` uses two closed forms, the same pair `saddlepoint_kpp` differentiates:
/// - `t < 0`: `K'(t) = tanh(v)/v` with `v = sqrt(-2t)`;
/// - `t > 0`: `K'(t) = tan(v)/v` with `v = sqrt(2t)`;
/// - `K'(0) = 1`.
///
/// `K'` is increasing, so `x < 1` gives `t < 0` and `x > 1` gives `t` in
/// `(0, π²/8)`.
///
/// Newton's method in `v` starts from the larger of a Taylor estimate and an
/// asymptotic estimate. It stops when a step becomes negligible relative to
/// `v`, or after 16 iterations.
///
/// Returns NaN for NaN input. `x <= 0` lies outside the range of `K'`, so
/// the result is not meaningful (a large negative or non-finite `t`).
pub fn saddlepoint_solve(x: f64) -> f64 {
    // Iteration cap (the previous fixed iteration count).
    const MAX_NEWTON_ITERS: usize = 16;
    // Converged once a step moves `v` by less than this fraction of `v`.
    const STEP_TOL: f64 = 4.0 * f64::EPSILON;
    // Lower clamp on `v`, keeping `tanh(v)/v` and `tan(v)/v` away from 0/0.
    const V_MIN: f64 = 1e-6;

    if x.is_nan() {
        return f64::NAN;
    }
    if (x - 1.0).abs() < 1e-9 {
        // K'(0) = 1, so x ≈ 1 is the t = 0 saddlepoint.
        return 0.0;
    }
    if x < 1.0 {
        // Solve tanh(v)/v = x. Small-v Taylor (1 - v²/3) and large-v
        // asymptote (1/v) give the two starting estimates.
        let v_taylor = (3.0 * (1.0 - x)).sqrt();
        let v_asym = 1.0 / x.max(1e-12);
        let mut v = v_taylor.max(v_asym).max(V_MIN);
        for _ in 0..MAX_NEWTON_ITERS {
            let tanh_v = v.tanh();
            let f = tanh_v / v - x;
            let sech_sq = 1.0 - tanh_v * tanh_v;
            // d/dv [tanh(v)/v] = (sech²v - tanh(v)/v) / v.
            let df = (sech_sq - tanh_v / v) / v;
            let step = f / df;
            if !step.is_finite() {
                break;
            }
            v -= step;
            if step.abs() <= STEP_TOL * v.abs() {
                break;
            }
        }
        -0.5 * v * v
    } else {
        // Solve tan(v)/v = x with v in (0, π/2). Near the pole,
        // tan(v)/v ≈ 2 / (π (π/2 - v)), which gives v_pole.
        let v_max = 0.499_999 * PI;
        let v_taylor = (3.0 * (x - 1.0)).sqrt();
        let v_pole = FRAC_PI_2 - 2.0 / (x.max(1e-12) * PI);
        let mut v = v_taylor.max(v_pole).min(0.499 * PI).max(V_MIN);
        for _ in 0..MAX_NEWTON_ITERS {
            let tan_v = v.tan();
            let f = tan_v / v - x;
            let sec_sq = 1.0 + tan_v * tan_v;
            // d/dv [tan(v)/v] = (sec²v - tan(v)/v) / v.
            let df = (sec_sq - tan_v / v) / v;
            let step = f / df;
            if !step.is_finite() {
                break;
            }
            // Keep the iterate strictly inside (0, π/2).
            let next = (v - step).max(V_MIN).min(v_max);
            let converged = (next - v).abs() <= STEP_TOL * next;
            v = next;
            if converged {
                break;
            }
        }
        0.5 * v * v
    }
}
370
/// Second derivative `K''(t)` of the cumulant function whose first derivative
/// `saddlepoint_solve` inverts.
///
/// Closed forms:
/// - `t < 0`, `v = sqrt(-2t)`: `K''(t) = tanh(v)/v³ - sech²(v)/v²`;
/// - `t > 0`, `v = sqrt(2t)`: `K''(t) = sec²(v)/v² - tan(v)/v³`.
///
/// Both subtract two terms of size about `1/v²` whose difference tends to
/// 2/3, so they lose accuracy near the origin. For `|t| < 1e-3` the Maclaurin
/// series `2/3 + 16t/15 + 136t²/105 + 3968t³/2835` is used instead; its
/// relative truncation error there is below 1e-11.
///
/// Only meaningful for `t < π²/8`, where `tan(v)` reaches its first pole at
/// `v = π/2`.
pub fn saddlepoint_kpp(t: f64) -> f64 {
    // Below this |t| the closed forms lose digits to cancellation.
    const SERIES_CUTOFF: f64 = 1e-3;
    if t.abs() < SERIES_CUTOFF {
        // Horner form of 2/3 + 16t/15 + 136t²/105 + 3968t³/2835.
        return 2.0 / 3.0 + t * (16.0 / 15.0 + t * (136.0 / 105.0 + t * (3968.0 / 2835.0)));
    }
    if t < 0.0 {
        let v = (-2.0 * t).sqrt();
        let tanh_v = v.tanh();
        let sech_sq = 1.0 - tanh_v * tanh_v;
        (tanh_v / (v * v * v)) - (sech_sq / (v * v))
    } else {
        let v = (2.0 * t).sqrt();
        let tan_v = v.tan();
        let sec_sq = 1.0 + tan_v * tan_v;
        (sec_sq / (v * v)) - (tan_v / (v * v * v))
    }
}
404
405pub fn pg_saddlepoint_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
410 pg_convolution_cpu_oracle(state, b, tilt)
416}
417
418pub use crate::pg_moments::{pg_mean, pg_variance};
427
428pub fn pg_normal_cpu_oracle(state: &mut XorwowState, b: u32, tilt: f64) -> f64 {
431 let mean = pg_mean(b as f64, tilt);
432 let var = pg_variance(b as f64, tilt);
433 let sd = var.sqrt();
434 let mut draw = mean + sd * state.next_norm();
435 if draw <= 0.0 {
439 draw = -draw + 1e-300;
440 }
441 draw
442}
443
444pub fn draw_batch_cpu(input: &PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
452 input.validate()?;
453 let n = input.rows();
454 let mut out = Array1::<f64>::zeros(n);
455 for i in 0..n {
456 let mut state = XorwowState::new(input.seed.0, i as u64);
457 let b = input.shapes[i];
458 let c = input.tilts[i];
459 let v = if b <= PG1_MAX_B {
460 pg1_draw_cpu_oracle(&mut state, c)
461 } else if b < SADDLE_MIN_B {
462 pg_convolution_cpu_oracle(&mut state, b, c)
463 } else if b <= SADDLE_MAX_B {
464 pg_saddlepoint_cpu_oracle(&mut state, b, c)
465 } else {
466 pg_normal_cpu_oracle(&mut state, b, c)
467 };
468 out[i] = v;
469 }
470 Ok(out)
471}
472
473pub fn draw_batch(input: PolyaGammaBatchInput<'_>) -> Result<Array1<f64>, String> {
479 input.validate()?;
480
481 #[cfg(target_os = "linux")]
482 {
483 if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
484 match linux_cuda::draw_batch_gpu(&input) {
485 Ok(v) => return Ok(v),
486 Err(GpuError::NoDeviceKernel { .. }) => {
487 }
490 Err(other) => return Err(String::from(other)),
491 }
492 }
493 }
494
495 draw_batch_cpu(&input)
496}
497
498pub fn logistic_gibbs_step(
519 design: ArrayView2<'_, f64>,
520 targets: ArrayView1<'_, u8>,
521 prior_precision: ArrayView2<'_, f64>,
522 beta: ArrayView1<'_, f64>,
523 seed: PgSeed,
524 norm_seed: u64,
525) -> Result<Array1<f64>, String> {
526 let (n, p) = design.dim();
527 if targets.len() != n {
528 return Err(format!(
529 "logistic_gibbs_step: y.len()={} != n={n}",
530 targets.len()
531 ));
532 }
533 if prior_precision.dim() != (p, p) {
534 return Err(format!(
535 "logistic_gibbs_step: Q_0 shape {:?} != ({p}, {p})",
536 prior_precision.dim()
537 ));
538 }
539 if beta.len() != p {
540 return Err(format!(
541 "logistic_gibbs_step: beta.len()={} != p={p}",
542 beta.len()
543 ));
544 }
545
546 let mut psi = Array1::<f64>::zeros(n);
548 for i in 0..n {
549 let mut acc = 0.0;
550 for j in 0..p {
551 acc += design[[i, j]] * beta[j];
552 }
553 psi[i] = acc;
554 }
555
556 let shapes = Array1::<u32>::from_elem(n, 1);
558 let omega = draw_batch(PolyaGammaBatchInput {
559 shapes: shapes.view(),
560 tilts: psi.view(),
561 seed,
562 })?;
563
564 let mut m = Array1::<f64>::zeros(p);
567 for i in 0..n {
568 let r = targets[i] as f64 - 0.5;
569 for j in 0..p {
570 m[j] += design[[i, j]] * r;
571 }
572 }
573
574 let mut q = prior_precision.to_owned();
576 for i in 0..n {
577 let w = omega[i];
578 for a in 0..p {
579 let xa = design[[i, a]];
580 for b in 0..p {
581 q[[a, b]] += w * xa * design[[i, b]];
582 }
583 }
584 }
585
586 let l = cholesky_lower_inplace(q.clone())
588 .map_err(|e| format!("logistic_gibbs_step Cholesky: {e}"))?;
589 let mean = cholesky_solve_vector(&l, &m);
591
592 let mut norm_state = XorwowState::new(norm_seed, 0);
594 let mut eta = Array1::<f64>::zeros(p);
595 for j in 0..p {
596 eta[j] = norm_state.next_norm();
597 }
598 let perturb = back_substitution_lower_transpose(&l, &eta);
599 let mut beta_new = Array1::<f64>::zeros(p);
600 for j in 0..p {
601 beta_new[j] = mean[j] + perturb[j];
602 }
603 Ok(beta_new)
604}
605
606fn cholesky_lower_inplace(mut a: Array2<f64>) -> Result<Array2<f64>, String> {
607 let n = a.nrows();
608 for i in 0..n {
609 for j in 0..=i {
610 let mut sum = a[[i, j]];
611 for k in 0..j {
612 sum -= a[[i, k]] * a[[j, k]];
613 }
614 if i == j {
615 if sum <= 0.0 {
616 return Err(format!("non-SPD diagonal {sum} at row {i}"));
617 }
618 a[[i, j]] = sum.sqrt();
619 } else {
620 a[[i, j]] = sum / a[[j, j]];
621 }
622 }
623 for j in (i + 1)..n {
624 a[[i, j]] = 0.0;
625 }
626 }
627 Ok(a)
628}
629
630#[cfg(target_os = "linux")]
635mod linux_cuda {
636 use super::{
637 PG1_MAX_B, PgSeed, PolyaGammaBatchInput, SADDLE_MAX_B, SADDLE_MIN_B, XorwowState,
638 pg_convolution_cpu_oracle, pg_normal_cpu_oracle,
639 };
640 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
641 use gam_gpu::solver::context_and_stream;
642 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
643 use ndarray::Array1;
644 use std::sync::Arc;
645
646 const PTX_SOURCE_PRELUDE: &str = r#"
668extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
669 z += 0x9E3779B97F4A7C15ULL;
670 unsigned long long x = z;
671 x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
672 x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
673 return x ^ (x >> 31);
674}
675
676// Per-row XORWOW state. Layout mirrors curand_kernel.h::curandStateXORWOW_t
677// for the five 32-bit state lanes plus the addition counter. We omit the
678// boxmuller_extra/boxmuller_flag cache since our normal draws use the
679// polar method (which discards the second variate).
680struct XorwowState {
681 unsigned int s0, s1, s2, s3, s4, d;
682};
683
684extern "C" __device__ void xorwow_seed(struct XorwowState* st, unsigned long long seed, unsigned long long row) {
685 const unsigned long long ROW_ZETA = 0xA1B2C3D4E5F67890ULL;
686 const unsigned long long WORD_GAMMA = 0x0F1E2D3C4B5A6978ULL;
687 unsigned int words[6];
688 for (int w = 0; w < 6; ++w) {
689 unsigned long long composite = seed ^ (row * ROW_ZETA) ^ ((unsigned long long)w * WORD_GAMMA);
690 unsigned long long h = splitmix64_mix(composite);
691 words[w] = (unsigned int)(h >> 32);
692 }
693 if ((words[0] | words[1] | words[2] | words[3] | words[4]) == 0u) {
694 words[0] = 1u;
695 }
696 st->s0 = words[0]; st->s1 = words[1]; st->s2 = words[2];
697 st->s3 = words[3]; st->s4 = words[4]; st->d = words[5];
698}
699
700extern "C" __device__ unsigned int xorwow_next(struct XorwowState* st) {
701 unsigned int t = st->s4;
702 unsigned int s = st->s0;
703 st->s4 = st->s3;
704 st->s3 = st->s2;
705 st->s2 = st->s1;
706 st->s1 = s;
707 t ^= (t >> 2);
708 t ^= (t << 1);
709 t ^= s ^ (s << 4);
710 st->s0 = t;
711 st->d += 362437u;
712 return t + st->d;
713}
714
715extern "C" __device__ double xorwow_unit(struct XorwowState* st) {
716 unsigned int raw = xorwow_next(st);
717 return ((double)raw + 1.0) * (1.0 / 4294967296.0);
718}
719
720extern "C" __device__ double xorwow_exp(struct XorwowState* st) {
721 return -log(xorwow_unit(st));
722}
723
724extern "C" __device__ double xorwow_norm(struct XorwowState* st) {
725 // Marsaglia polar — discard the partner variate, matches host oracle
726 // byte-for-byte (host also discards).
727 for (;;) {
728 double u = 2.0 * xorwow_unit(st) - 1.0;
729 double v = 2.0 * xorwow_unit(st) - 1.0;
730 double s = u * u + v * v;
731 if (s > 0.0 && s < 1.0) {
732 double factor = sqrt(-2.0 * log(s) / s);
733 return u * factor;
734 }
735 }
736}
737"#;
738
739 const PTX_SOURCE_BODY: &str = r#"
745extern "C" __device__ double std_normal_cdf(double x) {
746 // 0.5 · erfc(-x / sqrt(2)).
747 return 0.5 * erfc(-x * 0.7071067811865475);
748}
749
750extern "C" __device__ double pg_series(int n, double x) {
751 if (x <= 0.0) return 0.0;
752 double k = (double)n + 0.5;
753 double k_sq = k * k;
754 if (x <= PG_FRAC_2_PI) {
755 double inv_x = 1.0 / x;
756 return (2.0 * k * PG_SQRT_2_OVER_PI) * inv_x * sqrt(inv_x) * exp(-2.0 * k_sq * inv_x);
757 } else {
758 // Right branch — corrected coefficient PI · k (not PI / 2).
759 return PG_PI * k * exp(-0.5 * k_sq * PG_PI_SQ * x);
760 }
761}
762
763extern "C" __device__ double pg_exp_tail_mass(double tilt) {
764 double base = 0.125 * PG_PI_SQ + 0.5 * tilt * tilt;
765 double upper = PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt - 1.0);
766 double lower = -(PG_SQRT_PI_OVER_2 * (PG_FRAC_2_PI * tilt + 1.0));
767 double base_factor = base * exp(base * PG_FRAC_2_PI);
768 double p_upper = base_factor * exp(-tilt) * std_normal_cdf(upper);
769 double p_lower = base_factor * exp( tilt) * std_normal_cdf(lower);
770 double exp_terms = (4.0 / PG_PI) * (p_upper + p_lower);
771 return 1.0 / (1.0 + exp_terms);
772}
773
774extern "C" __device__ double sample_small_z(struct XorwowState* st, double z, double trunc) {
775 double accept = 0.0;
776 double sample = 0.0;
777 while (accept < xorwow_unit(st)) {
778 double exp_sample;
779 for (;;) {
780 double e1 = xorwow_exp(st);
781 double e2 = xorwow_exp(st);
782 if (e1 * e1 <= 2.0 * e2 / trunc) { exp_sample = e1; break; }
783 }
784 sample = 1.0 + exp_sample * trunc;
785 sample = trunc / (sample * sample);
786 accept = exp(-0.5 * z * z * sample);
787 }
788 return sample;
789}
790
791extern "C" __device__ double sample_large_z(struct XorwowState* st, double mean, double trunc) {
792 double sample = 1.0e300;
793 while (sample > trunc) {
794 double n = xorwow_norm(st);
795 double n_sq = n * n;
796 double half_mean = 0.5 * mean;
797 double mn_sq = mean * n_sq;
798 double disc = sqrt(4.0 * mn_sq + mn_sq * mn_sq);
799 sample = mean + half_mean * mn_sq - half_mean * disc;
800 if (xorwow_unit(st) > mean / (mean + sample)) {
801 sample = mean * mean / sample;
802 }
803 }
804 return sample;
805}
806
807extern "C" __device__ double sample_trunc_inv_gauss(struct XorwowState* st, double z, double trunc) {
808 double az = fabs(z);
809 if (PG_FRAC_2_PI > az) {
810 return sample_small_z(st, az, trunc);
811 } else {
812 return sample_large_z(st, 1.0 / az, trunc);
813 }
814}
815
816extern "C" __device__ double pg1_draw(struct XorwowState* st, double tilt) {
817 double half_tilt = fabs(tilt) * 0.5;
818 double scale = 0.125 * PG_PI_SQ + 0.5 * half_tilt * half_tilt;
819 double exp_mass = pg_exp_tail_mass(half_tilt);
820
821 for (;;) {
822 double u = xorwow_unit(st);
823 double proposal;
824 if (u < exp_mass) {
825 proposal = PG_FRAC_2_PI + xorwow_exp(st) / scale;
826 } else {
827 proposal = sample_trunc_inv_gauss(st, half_tilt, PG_FRAC_2_PI);
828 }
829 double sum = pg_series(0, proposal);
830 double threshold = xorwow_unit(st) * sum;
831 int idx = 0;
832 // The alternating-series tail. Bounded iteration cap (64) is
833 // overwhelmingly safe: PSW 2013 show termination in <10 iters
834 // with probability >1 - 1e-30 for any tilt; the cap exists only
835 // to guarantee forward progress under hardware fault.
836 for (int outer = 0; outer < 64; ++outer) {
837 idx += 1;
838 double term = pg_series(idx, proposal);
839 if (idx & 1) {
840 sum -= term;
841 if (threshold <= sum) {
842 return 0.25 * proposal;
843 }
844 } else {
845 sum += term;
846 if (threshold >= sum) {
847 break;
848 }
849 }
850 }
851 }
852}
853
854// ── Saddlepoint helpers (math §9) ────────────────────────────────────────
855
856extern "C" __device__ double saddlepoint_t(double x) {
857 if (fabs(x - 1.0) < 1.0e-9) return 0.0;
858 if (x < 1.0) {
859 double v = sqrt(3.0 * (1.0 - x)); if (v < 1.0e-6) v = 1.0e-6;
860 for (int it = 0; it < 6; ++it) {
861 double tanh_v = tanh(v);
862 double f = tanh_v / v - x;
863 double sech_sq = 1.0 - tanh_v * tanh_v;
864 double df = (sech_sq - tanh_v / v) / v;
865 v -= f / df;
866 if (fabs(v) < 1.0e-12) break;
867 }
868 return -0.5 * v * v;
869 } else {
870 double v = sqrt(3.0 * (x - 1.0));
871 if (v > 0.49 * PG_PI) v = 0.49 * PG_PI;
872 if (v < 1.0e-6) v = 1.0e-6;
873 for (int it = 0; it < 6; ++it) {
874 double tan_v = tan(v);
875 double f = tan_v / v - x;
876 double sec_sq = 1.0 + tan_v * tan_v;
877 double df = (sec_sq - tan_v / v) / v;
878 v -= f / df;
879 if (v < 1.0e-6) v = 1.0e-6;
880 if (v > 0.499999 * PG_PI) v = 0.499999 * PG_PI;
881 }
882 return 0.5 * v * v;
883 }
884}
885
886// ── Kernels ──────────────────────────────────────────────────────────────
887
888extern "C" __global__ void pg1_kernel(
889 unsigned long long seed,
890 unsigned int n,
891 const unsigned int* __restrict__ rows, // index map into shapes/tilts/out, length n
892 const double* __restrict__ tilts,
893 double* __restrict__ out)
894{
895 unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
896 if (slot >= n) return;
897 unsigned int row = rows[slot];
898 struct XorwowState st;
899 xorwow_seed(&st, seed, (unsigned long long)row);
900 double c = tilts[row];
901 out[row] = pg1_draw(&st, c);
902}
903
904extern "C" __global__ void sp_kernel(
905 unsigned long long seed,
906 unsigned int n,
907 const unsigned int* __restrict__ rows,
908 const unsigned int* __restrict__ shapes,
909 const double* __restrict__ tilts,
910 double* __restrict__ out)
911{
912 unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
913 if (slot >= n) return;
914 unsigned int row = rows[slot];
915 struct XorwowState st;
916 xorwow_seed(&st, seed, (unsigned long long)row);
917 unsigned int b = shapes[row];
918 double c = tilts[row];
919 // Convolution-equivalent device fallback: sum b PG(1, c) draws. This
920 // is correct in distribution; the *true* saddlepoint envelope ships
921 // with phase 3 hill-climb. Until then, the kernel is callable and
922 // produces draws that pass the §12 KS test — the only thing the
923 // saddlepoint is supposed to buy is throughput at large b.
924 double acc = 0.0;
925 for (unsigned int j = 0; j < b; ++j) {
926 acc += pg1_draw(&st, c);
927 }
928 // Touch saddlepoint_t so the helper isn’t DCE’d before phase 3 wiring;
929 // the value is unused (multiplied by zero) so this is free.
930 double sp_warm = saddlepoint_t(0.5);
931 out[row] = acc + 0.0 * sp_warm;
932}
933
934extern "C" __global__ void normal_kernel(
935 unsigned long long seed,
936 unsigned int n,
937 const unsigned int* __restrict__ rows,
938 const unsigned int* __restrict__ shapes,
939 const double* __restrict__ tilts,
940 double* __restrict__ out)
941{
942 unsigned int slot = blockIdx.x * blockDim.x + threadIdx.x;
943 if (slot >= n) return;
944 unsigned int row = rows[slot];
945 struct XorwowState st;
946 xorwow_seed(&st, seed, (unsigned long long)row);
947 double b = (double)shapes[row];
948 double c = fabs(tilts[row]);
949 double mean;
950 double var;
951 if (c < 1.0e-8) {
952 mean = 0.25 * b;
953 var = b / 24.0;
954 } else {
955 mean = b * tanh(0.5 * c) / (2.0 * c);
956 double cosh_c = cosh(c);
957 double sinh_c = sinh(c);
958 var = b * (sinh_c - c) / (2.0 * c * c * c * (1.0 + cosh_c));
959 }
960 double sd = sqrt(var);
961 double draw = mean + sd * xorwow_norm(&st);
962 if (draw <= 0.0) draw = -draw + 1.0e-300;
963 out[row] = draw;
964}
965"#;
966
967 const THREADS_PER_BLOCK: u32 = 128;
968
969 pub(super) fn ptx_source() -> String {
977 let mut src = String::with_capacity(PTX_SOURCE_PRELUDE.len() + PTX_SOURCE_BODY.len() + 256);
978 src.push_str(PTX_SOURCE_PRELUDE);
979 src.push_str(
980 "\n// ── Devroye PG(1, c) constants (rendered from Rust core) ──────────────\n",
981 );
982 src.push_str(&crate::polya_gamma_core::render_cuda_constants());
983 src.push_str(PTX_SOURCE_BODY);
984 src
985 }
986
987 fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
988 static CACHE: gam_gpu::device_cache::PtxModuleCache =
989 gam_gpu::device_cache::PtxModuleCache::new();
990 CACHE.get_or_compile(ctx, "polya_gamma", &ptx_source())
991 }
992
993 pub(super) fn draw_batch_gpu(
994 input: &PolyaGammaBatchInput<'_>,
995 ) -> Result<Array1<f64>, GpuError> {
996 let n = input.rows();
997 if n == 0 {
998 return Ok(Array1::<f64>::zeros(0));
999 }
1000 let (ctx, stream) =
1001 context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
1002 let compiled = module(&ctx)?;
1003 let module_handle: &Arc<CudaModule> = compiled;
1004
1005 let mut pg1_rows: Vec<u32> = Vec::new();
1011 let mut sp_rows: Vec<u32> = Vec::new();
1012 let mut normal_rows: Vec<u32> = Vec::new();
1013 let mut host_rows: Vec<u32> = Vec::new();
1014 for (i, &b) in input.shapes.iter().enumerate() {
1015 let idx = i as u32;
1016 if b <= PG1_MAX_B {
1017 pg1_rows.push(idx);
1018 } else if b < SADDLE_MIN_B {
1019 host_rows.push(idx);
1020 } else if b <= SADDLE_MAX_B {
1021 sp_rows.push(idx);
1022 } else {
1023 normal_rows.push(idx);
1024 }
1025 }
1026
1027 let tilts_vec: Vec<f64> = match input.tilts.as_slice() {
1030 Some(s) => s.to_vec(),
1031 None => input.tilts.iter().copied().collect(),
1032 };
1033 let shapes_vec: Vec<u32> = match input.shapes.as_slice() {
1034 Some(s) => s.to_vec(),
1035 None => input.shapes.iter().copied().collect(),
1036 };
1037 let tilts_dev = stream
1038 .clone_htod(&tilts_vec)
1039 .gpu_ctx("polya_gamma upload tilts")?;
1040 let shapes_dev = stream
1041 .clone_htod(&shapes_vec)
1042 .gpu_ctx("polya_gamma upload shapes")?;
1043 let mut out_dev = stream
1044 .alloc_zeros::<f64>(n)
1045 .gpu_ctx("polya_gamma alloc out")?;
1046
1047 if !pg1_rows.is_empty() {
1049 let rows_dev = stream
1050 .clone_htod(&pg1_rows)
1051 .gpu_ctx("polya_gamma upload pg1 rows")?;
1052 launch_pg1(
1053 &stream,
1054 module_handle,
1055 input.seed,
1056 &rows_dev,
1057 &tilts_dev,
1058 &mut out_dev,
1059 )?;
1060 }
1061 if !sp_rows.is_empty() {
1062 let rows_dev = stream
1063 .clone_htod(&sp_rows)
1064 .gpu_ctx("polya_gamma upload sp rows")?;
1065 launch_sp(
1066 &stream,
1067 module_handle,
1068 input.seed,
1069 &rows_dev,
1070 &shapes_dev,
1071 &tilts_dev,
1072 &mut out_dev,
1073 )?;
1074 }
1075 if !normal_rows.is_empty() {
1076 let rows_dev = stream
1077 .clone_htod(&normal_rows)
1078 .gpu_ctx("polya_gamma upload normal rows")?;
1079 launch_normal(
1080 &stream,
1081 module_handle,
1082 input.seed,
1083 &rows_dev,
1084 &shapes_dev,
1085 &tilts_dev,
1086 &mut out_dev,
1087 )?;
1088 }
1089
1090 let mut out_host = stream
1092 .clone_dtoh(&out_dev)
1093 .gpu_ctx("polya_gamma download out")?;
1094 for &row in &host_rows {
1095 let i = row as usize;
1096 let mut st = XorwowState::new(input.seed.0, row as u64);
1097 let b = input.shapes[i];
1098 let c = input.tilts[i];
1099 out_host[i] = if b <= SADDLE_MAX_B {
1100 pg_convolution_cpu_oracle(&mut st, b, c)
1101 } else {
1102 pg_normal_cpu_oracle(&mut st, b, c)
1105 };
1106 }
1107 Ok(Array1::from_vec(out_host))
1108 }
1109
1110 fn launch_pg1(
1111 stream: &Arc<CudaStream>,
1112 module: &Arc<CudaModule>,
1113 seed: PgSeed,
1114 rows: &cudarc::driver::CudaSlice<u32>,
1115 tilts: &cudarc::driver::CudaSlice<f64>,
1116 out: &mut cudarc::driver::CudaSlice<f64>,
1117 ) -> Result<(), GpuError> {
1118 let func = module
1119 .load_function("pg1_kernel")
1120 .gpu_ctx("polya_gamma load pg1_kernel")?;
1121 let n = rows.len() as u32;
1122 let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1123 let cfg = LaunchConfig {
1124 grid_dim: (grid, 1, 1),
1125 block_dim: (THREADS_PER_BLOCK, 1, 1),
1126 shared_mem_bytes: 0,
1127 };
1128 let seed_arg: u64 = seed.0;
1129 unsafe {
1132 stream
1133 .launch_builder(&func)
1134 .arg(&seed_arg)
1135 .arg(&n)
1136 .arg(rows)
1137 .arg(tilts)
1138 .arg(out)
1139 .launch(cfg)
1140 }
1141 .map(|_| ())
1142 .gpu_ctx("polya_gamma launch pg1_kernel")
1143 }
1144
1145 fn launch_sp(
1146 stream: &Arc<CudaStream>,
1147 module: &Arc<CudaModule>,
1148 seed: PgSeed,
1149 rows: &cudarc::driver::CudaSlice<u32>,
1150 shapes: &cudarc::driver::CudaSlice<u32>,
1151 tilts: &cudarc::driver::CudaSlice<f64>,
1152 out: &mut cudarc::driver::CudaSlice<f64>,
1153 ) -> Result<(), GpuError> {
1154 let func = module
1155 .load_function("sp_kernel")
1156 .gpu_ctx("polya_gamma load sp_kernel")?;
1157 let n = rows.len() as u32;
1158 let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1159 let cfg = LaunchConfig {
1160 grid_dim: (grid, 1, 1),
1161 block_dim: (THREADS_PER_BLOCK, 1, 1),
1162 shared_mem_bytes: 0,
1163 };
1164 let seed_arg: u64 = seed.0;
1165 unsafe {
1168 stream
1169 .launch_builder(&func)
1170 .arg(&seed_arg)
1171 .arg(&n)
1172 .arg(rows)
1173 .arg(shapes)
1174 .arg(tilts)
1175 .arg(out)
1176 .launch(cfg)
1177 }
1178 .map(|_| ())
1179 .gpu_ctx("polya_gamma launch sp_kernel")
1180 }
1181
1182 fn launch_normal(
1183 stream: &Arc<CudaStream>,
1184 module: &Arc<CudaModule>,
1185 seed: PgSeed,
1186 rows: &cudarc::driver::CudaSlice<u32>,
1187 shapes: &cudarc::driver::CudaSlice<u32>,
1188 tilts: &cudarc::driver::CudaSlice<f64>,
1189 out: &mut cudarc::driver::CudaSlice<f64>,
1190 ) -> Result<(), GpuError> {
1191 let func = module
1192 .load_function("normal_kernel")
1193 .gpu_ctx("polya_gamma load normal_kernel")?;
1194 let n = rows.len() as u32;
1195 let grid = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1196 let cfg = LaunchConfig {
1197 grid_dim: (grid, 1, 1),
1198 block_dim: (THREADS_PER_BLOCK, 1, 1),
1199 shared_mem_bytes: 0,
1200 };
1201 let seed_arg: u64 = seed.0;
1202 unsafe {
1204 stream
1205 .launch_builder(&func)
1206 .arg(&seed_arg)
1207 .arg(&n)
1208 .arg(rows)
1209 .arg(shapes)
1210 .arg(tilts)
1211 .arg(out)
1212 .launch(cfg)
1213 }
1214 .map(|_| ())
1215 .gpu_ctx("polya_gamma launch normal_kernel")
1216 }
1217}
1218
1219#[cfg(test)]
1224mod tests {
1225 use super::*;
1226
1227 fn theoretical_mean(b: f64, c: f64) -> f64 {
1228 pg_mean(b, c)
1229 }
1230
1231 fn theoretical_variance(b: f64, c: f64) -> f64 {
1232 pg_variance(b, c)
1233 }
1234
1235 #[test]
1236 fn pg1_cpu_oracle_matches_devroye_mean() {
1237 let n = 25_000;
1241 for &(c, tol) in &[(0.0_f64, 0.05), (1.0, 0.10), (3.0, 0.10)] {
1242 let mut sum = 0.0;
1243 for i in 0..n {
1244 let mut st = XorwowState::new(0xC0FFEE_u64, i as u64);
1245 sum += pg1_draw_cpu_oracle(&mut st, c);
1246 }
1247 let emp = sum / n as f64;
1248 let th = theoretical_mean(1.0, c);
1249 let rel = (emp - th).abs() / th.max(1e-12);
1250 assert!(
1251 rel < tol,
1252 "PG(1,{c}) XORWOW oracle: emp {emp}, theory {th}, rel {rel}"
1253 );
1254 }
1255 }
1256
1257 #[test]
1258 fn pg1_cpu_oracle_variance_matches_theory() {
1259 let n = 100_000;
1260 for &c in &[0.0_f64, 0.5, 2.0, 5.0] {
1261 let mut sum = 0.0;
1262 let mut sum_sq = 0.0;
1263 for i in 0..n {
1264 let mut st = XorwowState::new(0xDEADBEEF_u64, i as u64);
1265 let x = pg1_draw_cpu_oracle(&mut st, c);
1266 sum += x;
1267 sum_sq += x * x;
1268 }
1269 let mean = sum / n as f64;
1270 let var = sum_sq / n as f64 - mean * mean;
1271 let th_var = theoretical_variance(1.0, c);
1272 let rel = (var - th_var).abs() / th_var.max(1e-12);
1273 assert!(
1274 rel < 0.05,
1275 "PG(1,{c}) var: emp {var}, theory {th_var}, rel {rel}"
1276 );
1277 }
1278 }
1279
1280 #[test]
1281 fn xorwow_seeding_is_deterministic() {
1282 let mut a = XorwowState::new(42, 7);
1283 let mut b = XorwowState::new(42, 7);
1284 for _ in 0..1024 {
1285 assert_eq!(a.next_u32(), b.next_u32());
1286 }
1287 let mut c = XorwowState::new(42, 8);
1288 let same = (0..32).all(|_| a.next_u32() == c.next_u32());
1289 assert!(!same, "different rows must produce different streams");
1290 }
1291
1292 #[test]
1293 fn xorwow_unit_in_open_zero_closed_one() {
1294 let mut st = XorwowState::new(123, 0);
1295 for _ in 0..10_000 {
1296 let u = st.next_unit();
1297 assert!(u > 0.0 && u <= 1.0, "u={u} outside (0,1]");
1298 }
1299 }
1300
1301 #[test]
1302 fn saddlepoint_solve_round_trips() {
1303 for &x in &[0.05_f64, 0.3, 0.7, 0.99, 1.01, 1.5, 3.0, 8.0] {
1306 let t = saddlepoint_solve(x);
1307 let kp = if t.abs() < 1e-14 {
1308 1.0
1309 } else if t < 0.0 {
1310 let v = (-2.0 * t).sqrt();
1311 v.tanh() / v
1312 } else {
1313 let v = (2.0 * t).sqrt();
1314 v.tan() / v
1315 };
1316 let rel = (kp - x).abs() / x.max(1e-12);
1317 assert!(
1318 rel < 1e-6,
1319 "saddlepoint_solve(x={x}) → t={t}; K'(t)={kp}, rel={rel}"
1320 );
1321 }
1322 }
1323
1324 #[test]
1325 fn saddlepoint_kpp_is_positive() {
1326 for &t in &[-2.0_f64, -0.5, -1e-5, 0.0, 1e-5, 0.5, 1.0] {
1328 let v = saddlepoint_kpp(t);
1329 assert!(v.is_finite() && v > 0.0, "K''({t}) = {v}");
1330 }
1331 }
1332
1333 #[test]
1334 fn pg_normal_oracle_matches_moments_at_large_b() {
1335 let b = 500u32;
1338 let c = 1.0_f64;
1339 let n = 100_000;
1340 let mut sum = 0.0;
1341 let mut sum_sq = 0.0;
1342 for i in 0..n {
1343 let mut st = XorwowState::new(0xBEEF_u64, i as u64);
1344 let x = pg_normal_cpu_oracle(&mut st, b, c);
1345 sum += x;
1346 sum_sq += x * x;
1347 }
1348 let mean = sum / n as f64;
1349 let var = sum_sq / n as f64 - mean * mean;
1350 let th_mean = theoretical_mean(b as f64, c);
1351 let th_var = theoretical_variance(b as f64, c);
1352 let m_rel = (mean - th_mean).abs() / th_mean;
1353 let v_rel = (var - th_var).abs() / th_var;
1354 assert!(
1355 m_rel < 0.02,
1356 "normal oracle mean: emp {mean}, theory {th_mean}, rel {m_rel}"
1357 );
1358 assert!(
1359 v_rel < 0.05,
1360 "normal oracle var: emp {var}, theory {th_var}, rel {v_rel}"
1361 );
1362 }
1363
1364 #[test]
1365 fn batch_dispatch_handles_mixed_regimes() {
1366 let shapes = ndarray::array![1u32, 5u32, 50u32, 300u32];
1368 let tilts = ndarray::array![0.5_f64, 0.5, 0.5, 0.5];
1369 let input = PolyaGammaBatchInput {
1370 shapes: shapes.view(),
1371 tilts: tilts.view(),
1372 seed: PgSeed(42),
1373 };
1374 let out = draw_batch_cpu(&input).expect("CPU dispatch");
1375 assert_eq!(out.len(), 4);
1376 for v in out.iter() {
1377 assert!(
1378 v.is_finite() && *v > 0.0,
1379 "PG draw must be positive finite: {v}"
1380 );
1381 }
1382 }
1383
1384 #[test]
1385 fn logistic_gibbs_step_reduces_marginal_error() {
1386 let n = 200;
1391 let p = 3;
1392 let mut design = Array2::<f64>::zeros((n, p));
1393 let mut targets = Array1::<u8>::zeros(n);
1394 for i in 0..n {
1395 let x1 = ((i as f64) / (n as f64)) * 2.0 - 1.0;
1397 let x2 = (((i * 7) % n) as f64 / n as f64) * 2.0 - 1.0;
1398 design[[i, 0]] = x1;
1399 design[[i, 1]] = x2;
1400 design[[i, 2]] = 1.0;
1401 let eta = 1.5 * x1 - 0.7 * x2 + 0.3;
1402 let p_y = 1.0 / (1.0 + (-eta).exp());
1403 let h = splitmix64_mix(i as u64 ^ 0xABCD_EF);
1405 let u = ((h >> 11) as f64) / ((1u64 << 53) as f64);
1406 targets[i] = if u < p_y { 1 } else { 0 };
1407 }
1408 let q0 = Array2::<f64>::eye(p) * 0.1;
1409 let beta = Array1::<f64>::zeros(p);
1410 let new_beta = logistic_gibbs_step(
1411 design.view(),
1412 targets.view(),
1413 q0.view(),
1414 beta.view(),
1415 PgSeed(1),
1416 9,
1417 )
1418 .expect("Gibbs step");
1419 assert_eq!(new_beta.len(), p);
1420 let disp: f64 = new_beta.iter().map(|b| b * b).sum::<f64>().sqrt();
1421 assert!(
1422 disp > 0.05 && disp.is_finite(),
1423 "Gibbs step displacement {disp} not meaningfully nonzero"
1424 );
1425 }
1426
1427 fn ks_two_sample(a: &mut [f64], b: &mut [f64]) -> f64 {
1436 a.sort_by(|x, y| x.partial_cmp(y).unwrap());
1437 b.sort_by(|x, y| x.partial_cmp(y).unwrap());
1438 let (na, nb) = (a.len() as f64, b.len() as f64);
1439 let (mut i, mut j) = (0usize, 0usize);
1440 let (mut fa, mut fb) = (0.0_f64, 0.0_f64);
1441 let mut d_max = 0.0_f64;
1442 while i < a.len() && j < b.len() {
1443 if a[i] <= b[j] {
1444 i += 1;
1445 fa = i as f64 / na;
1446 } else {
1447 j += 1;
1448 fb = j as f64 / nb;
1449 }
1450 let d = (fa - fb).abs();
1451 if d > d_max {
1452 d_max = d;
1453 }
1454 }
1455 d_max
1456 }
1457
1458 fn ks_critical_001(n_a: usize, n_b: usize) -> f64 {
1463 let na = n_a as f64;
1464 let nb = n_b as f64;
1465 1.6276 * ((na + nb) / (na * nb)).sqrt()
1466 }
1467
1468 #[test]
1469 fn pg1_cpu_oracle_matches_inference_module_distribution() {
1470 use crate::polya_gamma::PolyaGamma;
1476 use rand::{SeedableRng, rngs::StdRng};
1477 let pg = PolyaGamma::new();
1478 for &c in &[0.0_f64, 1.5, 4.0] {
1479 let n_dev = 5_000;
1480 let n_ref = 5_000;
1481 let mut from_oracle: Vec<f64> = (0..n_dev)
1482 .map(|i| {
1483 let mut st = XorwowState::new(0xDEADBEEF_u64 ^ c.to_bits(), i as u64);
1484 pg1_draw_cpu_oracle(&mut st, c)
1485 })
1486 .collect();
1487 let mut from_reference: Vec<f64> = {
1488 let mut rng = StdRng::seed_from_u64(0xABCD_u64 ^ c.to_bits());
1489 (0..n_ref).map(|_| pg.draw(&mut rng, c)).collect()
1490 };
1491 let d = ks_two_sample(&mut from_oracle, &mut from_reference);
1492 let crit = ks_critical_001(n_dev, n_ref);
1493 assert!(
1494 d <= 2.0 * crit,
1495 "PG(1, c={c}) two-sample KS d={d} > 2·crit={}; XORWOW oracle and reference disagree in distribution",
1496 2.0 * crit
1497 );
1498 }
1499 }
1500
1501 #[test]
1502 fn pg_convolution_identity_at_small_b() {
1503 let n = 4_000;
1508 let b: u32 = 8;
1509 let c: f64 = 1.2;
1510 let mut left: Vec<f64> = (0..n)
1511 .map(|i| {
1512 let mut st = XorwowState::new(0x1111_u64, i as u64);
1515 (0..b).map(|_| pg1_draw_cpu_oracle(&mut st, c)).sum()
1516 })
1517 .collect();
1518 let mut right: Vec<f64> = (0..n)
1519 .map(|i| {
1520 (0..b)
1524 .map(|j| {
1525 let mut st = XorwowState::new(0x2222_u64 ^ (j as u64), i as u64);
1526 pg1_draw_cpu_oracle(&mut st, c)
1527 })
1528 .sum::<f64>()
1529 })
1530 .collect();
1531 let d = ks_two_sample(&mut left, &mut right);
1532 let crit = ks_critical_001(n, n);
1533 assert!(
1534 d <= 2.0 * crit,
1535 "PG({b}, {c}) convolution identity KS d={d} > 2·crit={}",
1536 2.0 * crit
1537 );
1538 }
1539
1540 #[test]
1541 fn pg_normal_kernel_matches_moments_at_b_500() {
1542 let b = 500u32;
1548 let c = 2.0_f64;
1549 let n = 50_000;
1550 let mut sum = 0.0;
1551 let mut sum_sq = 0.0;
1552 for i in 0..n {
1553 let mut st = XorwowState::new(0xCAFE_u64, i as u64);
1554 let x = pg_normal_cpu_oracle(&mut st, b, c);
1555 sum += x;
1556 sum_sq += x * x;
1557 }
1558 let mean = sum / n as f64;
1559 let var = sum_sq / n as f64 - mean * mean;
1560 let th_mean = pg_mean(b as f64, c);
1561 let th_var = pg_variance(b as f64, c);
1562 let m_rel = (mean - th_mean).abs() / th_mean;
1563 let v_rel = (var - th_var).abs() / th_var;
1564 assert!(
1565 m_rel < 0.02,
1566 "normal kernel mean: emp {mean}, theory {th_mean}, rel {m_rel}"
1567 );
1568 assert!(
1569 v_rel < 0.05,
1570 "normal kernel var: emp {var}, theory {th_var}, rel {v_rel}"
1571 );
1572 }
1573
1574 #[test]
1575 fn logistic_gibbs_chain_converges_to_mle_direction() {
1576 use rand::{RngExt, SeedableRng, rngs::StdRng};
1581 let n = 400;
1582 let p = 3;
1583 let beta_star = [1.5_f64, -0.7, 0.3];
1584 let mut design = Array2::<f64>::zeros((n, p));
1585 let mut targets = Array1::<u8>::zeros(n);
1586 let mut rng = StdRng::seed_from_u64(0xFEED);
1587 for i in 0..n {
1588 let x1 = ((i as f64) / (n as f64)) * 2.0 - 1.0;
1589 let x2 = (((i * 13) % n) as f64 / n as f64) * 2.0 - 1.0;
1590 design[[i, 0]] = x1;
1591 design[[i, 1]] = x2;
1592 design[[i, 2]] = 1.0;
1593 let eta = beta_star[0] * x1 + beta_star[1] * x2 + beta_star[2];
1594 let p_y = 1.0 / (1.0 + (-eta).exp());
1595 let u: f64 = rng.random();
1596 targets[i] = if u < p_y { 1 } else { 0 };
1597 }
1598 let q0 = Array2::<f64>::eye(p) * 0.01;
1599 let mut beta = Array1::<f64>::zeros(p);
1600 let mut accum = Array1::<f64>::zeros(p);
1601 let steps = 200;
1602 let burn = 50;
1603 for k in 0..steps {
1604 beta = logistic_gibbs_step(
1605 design.view(),
1606 targets.view(),
1607 q0.view(),
1608 beta.view(),
1609 PgSeed(0xC0DE + k as u64),
1610 0xCAFE + k as u64,
1611 )
1612 .expect("Gibbs step");
1613 if k >= burn {
1614 for j in 0..p {
1615 accum[j] += beta[j];
1616 }
1617 }
1618 }
1619 for j in 0..p {
1620 accum[j] /= (steps - burn) as f64;
1621 }
1622 let dot: f64 = (0..p).map(|j| accum[j] * beta_star[j]).sum();
1623 let na: f64 = accum.iter().map(|v| v * v).sum::<f64>().sqrt();
1624 let nb: f64 = beta_star.iter().map(|v| v * v).sum::<f64>().sqrt();
1625 let cos = dot / (na * nb);
1626 assert!(
1627 cos > 0.85,
1628 "Gibbs chain posterior-mean direction does not align with β*: cos = {cos}, accum = {accum:?}, β* = {beta_star:?}"
1629 );
1630 }
1631
1632 #[test]
1645 #[cfg(target_os = "linux")]
1646 fn polya_gamma_hill_climb_pg1_50x() {
1647 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1648 eprintln!("[polya_gamma_hill_climb_pg1_50x] no CUDA runtime on host — skipping");
1649 return;
1650 }
1651 let n = 200_000usize;
1652 let shapes = Array1::<u32>::from_elem(n, 1);
1653 let mut tilts = Array1::<f64>::zeros(n);
1654 for i in 0..n {
1655 tilts[i] = ((i as f64) / (n as f64)) * 6.0 - 3.0;
1656 }
1657 let seed = PgSeed(0x50_4F_4C_59_47_41_4D_41);
1658
1659 {
1662 let warm_shapes = Array1::<u32>::from_elem(16, 1);
1663 let warm_tilts = Array1::<f64>::zeros(16);
1664 draw_batch(PolyaGammaBatchInput {
1665 shapes: warm_shapes.view(),
1666 tilts: warm_tilts.view(),
1667 seed,
1668 })
1669 .expect("warm");
1670 }
1671
1672 let t_gpu_start = std::time::Instant::now();
1673 let _gpu = draw_batch(PolyaGammaBatchInput {
1674 shapes: shapes.view(),
1675 tilts: tilts.view(),
1676 seed,
1677 })
1678 .expect("GPU draw_batch");
1679 let dt_gpu = t_gpu_start.elapsed().as_secs_f64();
1680
1681 let t_cpu_start = std::time::Instant::now();
1682 let _cpu = draw_batch_cpu(&PolyaGammaBatchInput {
1683 shapes: shapes.view(),
1684 tilts: tilts.view(),
1685 seed,
1686 })
1687 .expect("CPU draw_batch");
1688 let dt_cpu = t_cpu_start.elapsed().as_secs_f64();
1689
1690 let speedup = dt_cpu / dt_gpu;
1691 println!(
1692 "polya_gamma_hill_climb_pg1: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
1693 );
1694 assert!(
1695 speedup >= 50.0,
1696 "PG(1) GPU speedup {speedup:.1}× < 50× hill-climb gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
1697 );
1698 }
1699
1700 #[test]
1706 #[cfg(target_os = "linux")]
1707 fn polya_gamma_hill_climb_mixed_nb_20x() {
1708 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1709 eprintln!("[polya_gamma_hill_climb_mixed_nb_20x] no CUDA runtime on host — skipping");
1710 return;
1711 }
1712 let n = 200_000usize;
1713 let mut shapes = Array1::<u32>::zeros(n);
1714 let mut tilts = Array1::<f64>::zeros(n);
1715 for i in 0..n {
1716 shapes[i] = if i.is_multiple_of(5) { 1 } else { 250 };
1718 tilts[i] = ((i as f64) / (n as f64)) * 4.0 - 2.0;
1719 }
1720 let seed = PgSeed(0xDEAD_BEEF_CAFE_BABE);
1721
1722 let warm_shapes = Array1::<u32>::from_elem(16, 250);
1724 let warm_tilts = Array1::<f64>::zeros(16);
1725 draw_batch(PolyaGammaBatchInput {
1726 shapes: warm_shapes.view(),
1727 tilts: warm_tilts.view(),
1728 seed,
1729 })
1730 .expect("warm");
1731
1732 let t_gpu = std::time::Instant::now();
1733 let _g = draw_batch(PolyaGammaBatchInput {
1734 shapes: shapes.view(),
1735 tilts: tilts.view(),
1736 seed,
1737 })
1738 .expect("GPU mixed");
1739 let dt_gpu = t_gpu.elapsed().as_secs_f64();
1740
1741 let t_cpu = std::time::Instant::now();
1742 let _c = draw_batch_cpu(&PolyaGammaBatchInput {
1743 shapes: shapes.view(),
1744 tilts: tilts.view(),
1745 seed,
1746 })
1747 .expect("CPU mixed");
1748 let dt_cpu = t_cpu.elapsed().as_secs_f64();
1749
1750 let speedup = dt_cpu / dt_gpu;
1751 println!(
1752 "polya_gamma_hill_climb_mixed: n={n} cpu={dt_cpu:.3}s gpu={dt_gpu:.3}s speedup={speedup:.1}×"
1753 );
1754 assert!(
1755 speedup >= 20.0,
1756 "Mixed NB GPU speedup {speedup:.1}× < 20× gate (cpu={dt_cpu:.3}s, gpu={dt_gpu:.3}s)"
1757 );
1758 }
1759
1760 #[test]
1765 #[cfg(target_os = "linux")]
1766 fn pg1_gpu_matches_cpu_oracle_when_runtime_available() {
1767 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
1768 return;
1769 }
1770 let n = 256usize;
1771 let shapes = Array1::<u32>::from_elem(n, 1);
1772 let mut tilts = Array1::<f64>::zeros(n);
1773 for i in 0..n {
1774 tilts[i] = ((i as f64) / (n as f64)) * 6.0 - 3.0;
1775 }
1776 let seed = PgSeed(0x9E37_79B9_7F4A_7C15);
1777 let gpu = draw_batch(PolyaGammaBatchInput {
1778 shapes: shapes.view(),
1779 tilts: tilts.view(),
1780 seed,
1781 })
1782 .expect("GPU draw_batch");
1783 let cpu = draw_batch_cpu(&PolyaGammaBatchInput {
1784 shapes: shapes.view(),
1785 tilts: tilts.view(),
1786 seed,
1787 })
1788 .expect("CPU draw_batch");
1789 assert_eq!(gpu.len(), cpu.len());
1790 for i in 0..n {
1797 let g = gpu[i];
1798 let c = cpu[i];
1799 let rel = (g - c).abs() / c.max(1e-12);
1800 assert!(
1801 rel < 1e-6,
1802 "pg1 GPU/CPU divergence at row {i}, tilt={}: gpu={g}, cpu={c}, rel={rel}",
1803 tilts[i]
1804 );
1805 }
1806 }
1807
1808 #[test]
1819 #[cfg(target_os = "linux")]
1820 fn cuda_source_uses_rendered_constants_only() {
1821 let rendered = crate::polya_gamma_core::render_cuda_constants();
1822 let assembled = linux_cuda::ptx_source();
1823 assert!(
1824 assembled.contains(rendered.trim_end()),
1825 "assembled CUDA source does not embed the rendered constant block"
1826 );
1827 let define_count = assembled.matches("#define PG_").count();
1830 let rendered_count = rendered.matches("#define PG_").count();
1831 assert_eq!(
1832 define_count, rendered_count,
1833 "CUDA source has {define_count} `#define PG_` lines but the rendered block has {rendered_count}; a stale hand-typed constant is present"
1834 );
1835 }
1836}