use rand::{Rng, SeedableRng, thread_rng};
use rng::RngDraw;
use statrs::distribution::{Exp, Gamma, InverseGamma, Normal, Uniform};
use std::f64::consts::PI;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
/// π², used to scale the terms of the series representation.
const PI_SQ: f64 = PI * PI;
/// 1 / (2π²): the global prefactor applied to the truncated series sum.
const PI2_SQ_RECIP: f64 = 1.0 / (PI_SQ + PI_SQ);
/// Sampler for the Pólya-Gamma distribution `PG(b, c)`.
///
/// The shape `b` is stored in the sampler and can be changed with
/// [`PolyaGamma::set_shape`]. The tilt `c` is supplied per draw.
/// Auxiliary distributions are built once in [`PolyaGamma::new`] and reused
/// across draws.
#[derive(Debug, Clone)]
pub struct PolyaGamma {
/// `Exp::new(1.0)`.
exp: Exp,
/// `Normal::standard()`.
std_norm: Normal,
/// `Uniform::standard()`.
unif: Uniform,
/// `Gamma::new(shape, 1.0)`, rebuilt by `init_gamma` whenever the shape changes.
/// NOTE(review): presumably consumed through `sample_gamma` by the non-integer
/// series sampler (which debug-asserts that its shape equals `b`). Confirm.
gamma: Gamma,
/// `InverseGamma::new(0.5, 2 * (k + 0.5)^2)` for `k = 0..50`, precomputed in `new`.
/// NOTE(review): the consumer is not visible here; presumably the Devroye sampler.
inv_gamma: Vec<InverseGamma>,
/// `Exp::new((k + 0.5)^2 * PI^2 / 2)` for `k = 0..50`, precomputed in `new`.
/// NOTE(review): the consumer is not visible here; presumably the Devroye sampler.
series_exp: Vec<Exp>,
/// Current shape parameter `b` used by `draw` and the vector variants.
shape: f64,
}
impl PolyaGamma {
    /// Creates a sampler for `PG(shape, c)`.
    ///
    /// Besides `Gamma(shape, 1)`, this precomputes two families of `PRECOMPUTE_K`
    /// distributions: inverse-gamma and exponential, indexed by `k + 1/2` for
    /// `k = 0..PRECOMPUTE_K`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` is not strictly positive and finite. An infinite shape is
    /// rejected because the integer-shape draw path would then try to sum
    /// `usize::MAX` terms (`inf as usize` saturates).
    pub fn new(shape: f64) -> Self {
        assert!(
            shape > 0.0 && shape.is_finite(),
            "Shape parameter must be positive and finite"
        );
        // Number of series terms whose auxiliary distributions are cached up front.
        const PRECOMPUTE_K: usize = 50;
        Self {
            exp: Exp::new(1.0).expect("Exp(1) is always valid"),
            std_norm: Normal::standard(),
            unif: Uniform::standard(),
            gamma: Gamma::new(shape, 1.0)
                .expect("Gamma(shape, 1) is valid for a positive, finite shape"),
            inv_gamma: (0..PRECOMPUTE_K)
                .map(|k| {
                    let k = k as f64 + 0.5;
                    InverseGamma::new(0.5, 2.0 * k * k)
                        .expect("InverseGamma(0.5, 2k^2) is always valid because k >= 0.5")
                })
                .collect(),
            series_exp: (0..PRECOMPUTE_K)
                .map(|k| {
                    let k = k as f64 + 0.5;
                    Exp::new(k * k * PI_SQ / 2.0)
                        .expect("Exp(k^2 * PI^2 / 2) is always valid because k >= 0.5")
                })
                .collect(),
            shape,
        }
    }

    /// Changes the shape parameter `b` used by subsequent draws and rebuilds the
    /// cached `Gamma(shape, 1)` distribution.
    ///
    /// # Panics
    ///
    /// Panics if `shape` is not strictly positive and finite (same contract as
    /// [`PolyaGamma::new`]). Validation happens before any field is modified, so a
    /// rejected call leaves the sampler unchanged.
    pub fn set_shape(&mut self, shape: f64) {
        assert!(
            shape > 0.0 && shape.is_finite(),
            "Shape parameter must be positive and finite"
        );
        self.shape = shape;
        self.init_gamma(shape);
    }

    /// Draws one sample from `PG(shape, tilt)` using the sampler's current shape.
    pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
        self.draw_internal(rng, self.shape, tilt)
    }

    /// Draws one `PG(shape, c[i])` sample per tilt, sequentially from `rng`.
    ///
    /// The output has the same length and order as `c`; an empty `c` yields an
    /// empty vector.
    pub fn draw_vec<R: Rng + ?Sized>(&self, rng: &mut R, c: &[f64]) -> Vec<f64> {
        let b = self.shape;
        c.iter().map(|&ci| self.draw_internal(rng, b, ci)).collect()
    }

    /// Parallel counterpart of [`PolyaGamma::draw_vec`] driven by a seedable generator.
    ///
    /// A single `u64` is drawn from `rng`. Chunk `i` of `CHUNK_SIZE` consecutive tilts
    /// then gets its own generator `R::seed_from_u64(seed.wrapping_add(i))`. Chunk
    /// boundaries and seeds do not depend on the rayon pool size, and the output keeps
    /// the order of `c`. An empty `c` returns an empty vector without touching `rng`.
    ///
    /// NOTE(review): for integer shapes large enough to trigger `draw_internal`'s
    /// parallel path, that dispatch threshold depends on `rayon::current_num_threads()`.
    #[cfg(feature = "rayon")]
    pub fn draw_vec_par_deterministic<R: SeedableRng + Rng>(
        &self,
        rng: &mut R,
        c: &[f64],
    ) -> Vec<f64> {
        // Fixed work unit so that seeding is independent of the thread count.
        const CHUNK_SIZE: usize = 32;
        if c.is_empty() {
            return Vec::new();
        }
        let b = self.shape;
        let seed = rng.next_u64();
        c.par_chunks(CHUNK_SIZE)
            .enumerate()
            .flat_map(|(i, chunk)| {
                let mut chunk_rng = R::seed_from_u64(seed.wrapping_add(i as u64));
                chunk
                    .iter()
                    .map(|&ci| self.draw_internal(&mut chunk_rng, b, ci))
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    /// Parallel counterpart of [`PolyaGamma::draw_vec`] using each worker's
    /// `thread_rng`.
    ///
    /// The output keeps the order of `c`, but the values are not reproducible; use
    /// [`PolyaGamma::draw_vec_par_deterministic`] when seeding matters.
    #[cfg(feature = "rayon")]
    pub fn draw_vec_par(&self, c: &[f64]) -> Vec<f64> {
        let b = self.shape;
        c.into_par_iter()
            .map_init(thread_rng, |rng, &ci| self.draw_internal(rng, b, ci))
            .collect()
    }
}
impl PolyaGamma {
    /// Draws one `PG(b, c)` sample, dispatching on the shape `b`:
    ///
    /// * `b == 1`: a single call to `sample_polya_gamma_devroye`.
    /// * integer `b`: the sum of `b` independent `PG(1, c)` draws. With the `rayon`
    ///   feature this is split across workers once `b >= 20 * current_num_threads()`.
    /// * otherwise: the truncated series of [`Self::draw_non_integer_b`].
    ///
    /// All randomness is derived from `rng`, including the parallel path, which seeds
    /// its workers from `rng`. A seeded `rng` therefore yields reproducible output for
    /// a fixed rayon pool size. This assumes the sampling helpers draw only from the
    /// generator they are given.
    ///
    /// # Panics
    ///
    /// Panics if `b` is not strictly positive and finite.
    #[inline]
    fn draw_internal<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
        assert!(
            b > 0.0 && b.is_finite(),
            "Shape parameter b must be positive and finite"
        );
        if b == 1.0 {
            return self.sample_polya_gamma_devroye(rng, c);
        }
        if b == b.floor() {
            let terms = b as usize;
            // Only parallelise when every worker gets at least ~20 draws; below that the
            // scheduling overhead dominates.
            #[cfg(feature = "rayon")]
            if b >= (rayon::current_num_threads() * 20) as f64 {
                return self.draw_integer_b_par_seeded(rng, terms, c);
            }
            return self.draw_integer_b(rng, terms, c);
        }
        self.draw_non_integer_b(rng, b, c)
    }

    /// Samples `PG(b, c)` for integer `b` as the sum of `b` independent `PG(1, c)`
    /// draws taken sequentially from `rng`.
    fn draw_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
        (0..b)
            .map(|_| self.sample_polya_gamma_devroye(rng, c))
            .sum()
    }

    /// Parallel integer-shape sampler whose randomness comes entirely from `rng`.
    ///
    /// The `b` draws are split as evenly as possible across
    /// `rayon::current_num_threads()` workers. Each worker runs its own `StdRng` seeded
    /// with a `u64` taken from `rng`. The partial sums are added sequentially in worker
    /// order, so floating-point rounding does not depend on work stealing.
    #[cfg(feature = "rayon")]
    fn draw_integer_b_par_seeded<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
        let threads = rayon::current_num_threads();
        let base = b / threads;
        let rem = b % threads;
        // Seeds are drawn sequentially from the caller's generator, before any parallel
        // work starts.
        let seeds: Vec<u64> = (0..threads).map(|_| rng.next_u64()).collect();
        let partials: Vec<f64> = seeds
            .into_par_iter()
            .enumerate()
            .map(|(i, seed)| {
                let mut local = rand::rngs::StdRng::seed_from_u64(seed);
                let count = base + usize::from(i < rem);
                (0..count)
                    .map(|_| self.sample_polya_gamma_devroye(&mut local, c))
                    .sum::<f64>()
            })
            .collect();
        partials.iter().sum()
    }

    /// Parallel integer-shape sampler seeded from the calling thread's `thread_rng`.
    /// The result is not reproducible.
    #[cfg(feature = "rayon")]
    // Kept so that any existing callers of the old private signature still compile;
    // `draw_internal` itself uses `draw_integer_b_par_seeded`.
    #[allow(dead_code)]
    fn draw_integer_b_par(&self, b: usize, c: f64) -> f64 {
        self.draw_integer_b_par_seeded(&mut thread_rng(), b, c)
    }

    /// Samples `PG(b, c)` for non-integer `b` from the truncated series
    ///
    /// `PG(b, c) = 1 / (2π²) · Σ_{k≥1} g_k / ((k − 1/2)² + c² / (4π²))`
    ///
    /// with one `sample_gamma` draw per term `g_k`. NOTE(review): `sample_gamma`
    /// presumably samples `self.gamma = Gamma(b, 1)`, which the debug assertion below
    /// checks. Confirm.
    ///
    /// The loop stops once the next term's scale `b / ((k + 1/2)² + c2)` drops below
    /// `1e-6`. Every term is positive, so the discarded tail biases the result slightly
    /// low.
    fn draw_non_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
        debug_assert!(b > 0.0, "`b` has to be strictly positive");
        debug_assert!(
            b.fract() != 0.0,
            "`b` is an integer – use the integer routine"
        );
        debug_assert!(self.gamma.shape() == b);
        const TOL: f64 = 1e-6;
        // c² / (4π²): the tilt's contribution to every denominator.
        let c2 = (c / (2.0 * PI)).powi(2);
        let mut sum = 0.0;
        let mut k: usize = 1;
        loop {
            let half = k as f64 - 0.5;
            let g = self.sample_gamma(rng);
            sum += g / (half * half + c2);
            let next_half = k as f64 + 0.5;
            if b / (next_half * next_half + c2) < TOL {
                break;
            }
            k += 1;
        }
        sum * PI2_SQ_RECIP
    }

    /// Rebuilds the cached `Gamma(b, 1)` distribution for a new shape `b`.
    ///
    /// # Panics
    ///
    /// Panics if `Gamma::new(b, 1.0)` rejects `b`.
    fn init_gamma(&mut self, b: f64) {
        self.gamma = Gamma::new(b, 1.0).expect("Gamma shape/scale parameters are valid");
    }
}
mod devroye;
pub mod regression;
pub(crate) mod rng;
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;

    /// Monte-Carlo estimate of `E[PG(b, c)]` from `n` draws with a fixed seed.
    fn empirical_mean(b: f64, c: f64, n: usize, seed: u64) -> f64 {
        let sampler = PolyaGamma::new(b);
        let mut generator = StdRng::seed_from_u64(seed);
        let total: f64 = std::iter::repeat_with(|| sampler.draw(&mut generator, c))
            .take(n)
            .sum();
        total / n as f64
    }

    /// Closed-form mean `b · tanh(c / 2) / (2c)`, using the `c → 0` limit `b / 4`
    /// near zero.
    fn theoretical_mean(b: f64, c: f64) -> f64 {
        if c.abs() < 1e-12 {
            return b / 4.0;
        }
        b * (0.5 * c).tanh() / (2.0 * c)
    }

    /// Relative deviation of `estimate` from `reference`.
    fn relative_error(estimate: f64, reference: f64) -> f64 {
        (estimate - reference).abs() / reference
    }

    #[test]
    fn non_integer_b_mean_matches_theory() {
        const SHAPE: f64 = 1.7;
        const DRAWS: usize = 25_000;

        // Untilted case.
        let (emp0, th0) = (
            empirical_mean(SHAPE, 0.0, DRAWS, 1),
            theoretical_mean(SHAPE, 0.0),
        );
        assert!(
            relative_error(emp0, th0) < 0.05,
            "PG({}, 0): empirical {}, theory {}",
            SHAPE,
            emp0,
            th0
        );

        // Tilted case: looser tolerance.
        let (emp1, th1) = (
            empirical_mean(SHAPE, 1.0, DRAWS, 2),
            theoretical_mean(SHAPE, 1.0),
        );
        assert!(
            relative_error(emp1, th1) < 0.10,
            "PG({}, 1): empirical {}, theory {}",
            SHAPE,
            emp1,
            th1
        );
    }
}