use crate::inference::polya_gamma_core::{PgRng, draw_pg1};
use rand::{Rng, RngExt};
use rand_distr::{Distribution, Exp as RandExp, Normal as RandNormal};
struct RandPgRng<'a, R: Rng + ?Sized> {
rng: &'a mut R,
exp: &'a RandExp<f64>,
std_norm_sampler: &'a RandNormal<f64>,
}
/// Forwards each primitive randomness request of [`PgRng`] to the borrowed
/// generator and the borrowed distribution samplers.
impl<R> PgRng for RandPgRng<'_, R>
where
    R: Rng + ?Sized,
{
    /// A uniform `f64` drawn directly from the wrapped generator.
    #[inline]
    fn next_unit(&mut self) -> f64 {
        self.rng.random()
    }

    /// One variate from the borrowed exponential sampler.
    #[inline]
    fn next_exp(&mut self) -> f64 {
        let generator = &mut *self.rng;
        self.exp.sample(generator)
    }

    /// One variate from the borrowed standard-normal sampler.
    #[inline]
    fn next_norm(&mut self) -> f64 {
        let generator = &mut *self.rng;
        self.std_norm_sampler.sample(generator)
    }
}
/// Sampler for Pólya-Gamma `PG(1, c)` variates.
///
/// Owns the helper distributions that each draw needs, so they are constructed
/// once (in `new`) and then only borrowed by every call to `draw`.
#[derive(Debug, Clone)]
pub struct PolyaGamma {
    /// Exponential distribution with rate 1.0 (`Exp(1)`).
    exp: RandExp<f64>,
    /// Standard normal distribution `N(0, 1)`.
    std_norm_sampler: RandNormal<f64>,
}
impl PolyaGamma {
    /// Creates a sampler with its `Exp(1)` and `N(0, 1)` helper distributions pre-built.
    ///
    /// # Panics
    ///
    /// Not in practice: both distributions are built from fixed, valid parameters
    /// (rate `1.0`; mean `0.0`, standard deviation `1.0`), so the `expect`s only guard
    /// against an upstream contract change.
    pub fn new() -> Self {
        Self {
            exp: RandExp::new(1.0).expect("Exp(1) valid"),
            std_norm_sampler: RandNormal::new(0.0, 1.0).expect("N(0,1) valid"),
        }
    }

    /// Draws one `PG(1, tilt)` variate, using `rng` as the source of randomness.
    ///
    /// `tilt` is the tilting parameter `c`; the distribution's mean is
    /// `tanh(c / 2) / (2c)`, with limit `1/4` at `c = 0`.
    pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
        // Bundle the caller's generator with our cached samplers behind the
        // generic `PgRng` interface expected by the core routine.
        let mut source = RandPgRng {
            rng,
            exp: &self.exp,
            std_norm_sampler: &self.std_norm_sampler,
        };
        draw_pg1(&mut source, tilt)
    }
}

impl Default for PolyaGamma {
    /// Equivalent to [`PolyaGamma::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::StdRng};

    /// Sample mean of `n` draws from `PG(1, c)` using a `StdRng` seeded with `seed`.
    fn empirical_mean(c: f64, n: usize, seed: u64) -> f64 {
        let pg = PolyaGamma::new();
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
    }

    /// Exact mean of `PG(1, c)`: `tanh(c/2) / (2c)`, with the `c -> 0` limit `1/4`.
    fn theoretical_mean(c: f64) -> f64 {
        if c.abs() < 1e-12 {
            0.25
        } else {
            (0.5 * c).tanh() / (2.0 * c)
        }
    }

    #[test]
    fn pg1_mean_matches_theory() {
        let n = 25_000;
        for (c, tol) in [(0.0, 0.05), (1.0, 0.10), (3.0, 0.10)] {
            let emp = empirical_mean(c, n, 42);
            let th = theoretical_mean(c);
            assert!(
                (emp - th).abs() / th.max(1e-12) < tol,
                "PG(1,{c}): empirical {emp}, theory {th}",
            );
        }
    }

    /// Exact variance of `PG(1, c)`: `(sinh c - c) / (2 c^3 (1 + cosh c))`.
    ///
    /// For small `|c|` the numerator `sinh c - c` (true value ~`c^3 / 6`) cancels
    /// catastrophically, so there we use the Taylor expansion
    /// `1/24 - c^2/120 + 17 c^4/13440`, whose truncation error is `O(c^6)`.
    fn theoretical_variance(c: f64) -> f64 {
        if c.abs() < 1e-3 {
            let c2 = c * c;
            1.0 / 24.0 - c2 / 120.0 + 17.0 * c2 * c2 / 13440.0
        } else {
            (c.sinh() - c) / (2.0 * c * c * c * (1.0 + c.cosh()))
        }
    }

    #[test]
    fn pg1_moments_high_precision() {
        let pg = PolyaGamma::new();
        let k = 1_000_000usize;
        for c in [0.0_f64, 0.1, 1.0, 3.0, 10.0, 30.0] {
            // `to_bits` already yields a `u64`; derive a distinct seed per tilt.
            let mut rng = StdRng::seed_from_u64(0xC0FFEE ^ c.to_bits().wrapping_mul(7));
            let mut sum = 0.0_f64;
            let mut sum_sq = 0.0_f64;
            for _ in 0..k {
                let s = pg.draw(&mut rng, c);
                sum += s;
                sum_sq += s * s;
            }
            let mean = sum / k as f64;
            let var = sum_sq / k as f64 - mean * mean;
            let th_mean = theoretical_mean(c);
            let th_var = theoretical_variance(c);
            let mean_rel = (mean - th_mean).abs() / th_mean.max(1e-12);
            let var_rel = (var - th_var).abs() / th_var.max(1e-12);
            assert!(
                mean_rel < 5e-3,
                "PG(1,{c}) mean: emp {mean:.6e}, theory {th_mean:.6e}, rel {mean_rel:.3e}",
            );
            assert!(
                var_rel < 5e-3,
                "PG(1,{c}) var: emp {var:.6e}, theory {th_var:.6e}, rel {var_rel:.3e}",
            );
        }
    }
}