use std::cell::UnsafeCell;
use num_traits::PrimInt;
use rand::Rng;
use rand_distr::Distribution;
use crate::simd_rng::SimdRng;
/// Binomial sampler with parameters `n` (trials) and `p` (success probability)
/// that hands out samples of integer type `T` from a small internal buffer.
///
/// All mutable state lives in `UnsafeCell`s so that sampling works through
/// `&self`; as a consequence the type is not `Sync`.
pub struct SimdBinomial<T: PrimInt> {
/// Number of trials.
n: u32,
/// Per-trial success probability. Not validated by the constructors.
p: f64,
/// Sixteen pre-drawn samples served one at a time by the sampling methods.
buffer: UnsafeCell<[T; 16]>,
/// Position of the next unread entry in `buffer`; `16` means "exhausted".
index: UnsafeCell<usize>,
/// Generator owned by the sampler and used by the `*_fast` methods.
simd_rng: UnsafeCell<SimdRng>,
}
impl<T: PrimInt> SimdBinomial<T> {
    /// Creates a sampler for `n` trials with success probability `p`, taking
    /// its generator from `crate::simd_rng::Unseeded`.
    ///
    /// `p` is not validated; callers are expected to pass a value in `[0, 1]`.
    pub fn new(n: u32, p: f64) -> Self {
        Self::from_seed_source(n, p, &crate::simd_rng::Unseeded)
    }

    /// Creates a sampler whose internal generator is derived from `seed` via
    /// `crate::simd_rng::Deterministic`.
    pub fn with_seed(n: u32, p: f64, seed: u64) -> Self {
        Self::from_seed_source(n, p, &crate::simd_rng::Deterministic::new(seed))
    }

    /// Creates a sampler whose internal generator is produced by `seed.rng()`.
    ///
    /// The sample buffer starts out exhausted, so the first buffered draw
    /// triggers a refill.
    pub fn from_seed_source(n: u32, p: f64, seed: &impl crate::simd_rng::SeedExt) -> Self {
        Self {
            n,
            p,
            buffer: UnsafeCell::new([T::zero(); 16]),
            // 16 == buffer length, i.e. "nothing left to read".
            index: UnsafeCell::new(16),
            simd_rng: UnsafeCell::new(seed.rng()),
        }
    }

    /// Returns the next buffered sample, refilling the buffer from the
    /// internal generator when it is exhausted.
    #[inline]
    pub fn sample_fast(&self) -> T {
        // SAFETY: the type is `!Sync` (it contains `UnsafeCell`s) and no method
        // hands out references into the cells, so this is the only access.
        // The value is copied out instead of keeping a `&mut` alive, because
        // `refill_buffer_fast` writes the same cell through its own pointer.
        let exhausted = unsafe { *self.index.get() >= 16 };
        if exhausted {
            self.refill_buffer_fast();
        }
        // SAFETY: same exclusivity argument as above; the reference is created
        // after the refill finished and dropped before returning. The index is
        // below 16 here because a refill resets it to 0.
        unsafe {
            let index = &mut *self.index.get();
            let value = (*self.buffer.get())[*index];
            *index += 1;
            value
        }
    }

    /// Regenerates all buffered samples with the internal generator and
    /// rewinds the read position.
    fn refill_buffer_fast(&self) {
        // SAFETY: single-threaded access (`!Sync`), and callers hold no
        // references into `simd_rng` or `buffer` while this runs.
        let rng = unsafe { &mut *self.simd_rng.get() };
        // SAFETY: as above; `buffer` and `simd_rng` are distinct cells.
        let buf = unsafe { &mut *self.buffer.get() };
        self.fill_slice(rng, buf);
        // SAFETY: no reference to `index` is live at this point.
        unsafe {
            *self.index.get() = 0;
        }
    }

    /// Fills `out` with fresh samples drawn from the internal generator.
    ///
    /// The internal sample buffer and its read position are left untouched.
    #[inline]
    pub fn fill_slice_fast(&self, out: &mut [T]) {
        // SAFETY: single-threaded access (`!Sync`); no other reference to the
        // generator exists while this method runs.
        let rng = unsafe { &mut *self.simd_rng.get() };
        self.fill_slice(rng, out);
    }

    /// Fills `out` with binomial samples drawn from `rng`.
    ///
    /// * `n == 0` yields all zeros.
    /// * When `n * min(p, 1 - p) < 30` (and that probability is strictly
    ///   positive) each sample is built from geometric waiting times: the
    ///   number of waits that fit into `n` trials is the success count. For
    ///   `p > 0.5` the complement is sampled and mirrored back.
    /// * Otherwise each sample is the number of successes in `n` explicit
    ///   Bernoulli draws, which costs `n` uniform draws per sample.
    ///
    /// A result that does not fit into `T` is stored as zero.
    pub fn fill_slice<R: Rng + ?Sized>(&self, rng: &mut R, out: &mut [T]) {
        if self.n == 0 {
            out.fill(T::zero());
            return;
        }

        let flipped = self.p > 0.5;
        let p_eff = if flipped { 1.0 - self.p } else { self.p };
        let np_eff = (self.n as f64) * p_eff;

        if np_eff < 30.0 && p_eff > 0.0 && p_eff < 1.0 {
            // `ln_1p` keeps `ln(1 - p)` strictly negative for tiny `p`, where
            // `1.0 - p` would round to exactly 1.0 and make `log_q` zero.
            let log_q = (-p_eff).ln_1p();
            let limit = i64::from(self.n);
            for x in out.iter_mut() {
                let mut count: u32 = 0;
                let mut total: i64 = 0;
                loop {
                    let u: f64 = rng.random();
                    // The float-to-int cast saturates (e.g. for `u == 0.0`), so
                    // the additions must saturate too instead of overflowing.
                    let g = ((u.ln() / log_q).floor() as i64).saturating_add(1);
                    total = total.saturating_add(g);
                    if total > limit {
                        break;
                    }
                    count += 1;
                }
                // Every wait is at least 1, hence `count <= n` and the
                // subtraction cannot underflow.
                let result = if flipped { self.n - count } else { count };
                *x = num_traits::cast(result).unwrap_or(T::zero());
            }
            return;
        }

        for x in out.iter_mut() {
            let mut count = 0u32;
            for _ in 0..self.n {
                let u: f64 = rng.random();
                if u < self.p {
                    count += 1;
                }
            }
            *x = num_traits::cast(count).unwrap_or(T::zero());
        }
    }

    /// Regenerates all buffered samples from the caller-supplied `rng` and
    /// rewinds the read position.
    fn refill_buffer<R: Rng + ?Sized>(&self, rng: &mut R) {
        // SAFETY: single-threaded access (`!Sync`), and callers hold no
        // reference into `buffer` while this runs.
        let buf = unsafe { &mut *self.buffer.get() };
        self.fill_slice(rng, buf);
        // SAFETY: no reference to `index` is live at this point.
        unsafe {
            *self.index.get() = 0;
        }
    }
}
impl<T: PrimInt> Clone for SimdBinomial<T> {
    /// Builds a new sampler with the same `n` and `p` through [`Self::new`].
    ///
    /// Only the parameters are carried over: the clone gets its own generator
    /// and an exhausted buffer, so it does not replay the original's stream.
    fn clone(&self) -> Self {
        let (trials, prob) = (self.n, self.p);
        Self::new(trials, prob)
    }
}
impl<T: PrimInt> Distribution<T> for SimdBinomial<T> {
    /// Returns the next buffered sample, refilling the buffer from the
    /// caller-supplied `rng` when it is exhausted.
    ///
    /// The buffer is shared with `sample_fast`, so values still buffered from
    /// an earlier refill are served before `rng` is consulted.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
        // SAFETY: the type is `!Sync` (it contains `UnsafeCell`s) and no method
        // hands out references into the cells, so this is the only access.
        // The value is copied out instead of keeping a `&mut` alive, because
        // `refill_buffer` writes the same cell through its own pointer.
        let exhausted = unsafe { *self.index.get() >= 16 };
        if exhausted {
            self.refill_buffer(rng);
        }
        // SAFETY: same exclusivity argument as above; the reference is created
        // after the refill finished and dropped before returning. The index is
        // below 16 here because a refill resets it to 0.
        unsafe {
            let idx = &mut *self.index.get();
            let val = (*self.buffer.get())[*idx];
            *idx += 1;
            val
        }
    }
}
impl<T: PrimInt> crate::traits::DistributionExt for SimdBinomial<T> {
    /// Probability mass at `x`; zero for negative, non-integer or `> n` inputs.
    ///
    /// Evaluated in log space via `crate::special::ln_gamma`.
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 || x.fract() != 0.0 {
            return 0.0;
        }
        let k = x as i64;
        let n = self.n as i64;
        if k > n {
            return 0.0;
        }
        // Degenerate distributions: the log-space formula below would compute
        // `0 * ln(0)`, which is NaN, so answer them directly.
        if self.p == 0.0 {
            return if k == 0 { 1.0 } else { 0.0 };
        }
        if self.p == 1.0 {
            return if k == n { 1.0 } else { 0.0 };
        }
        let log_pmf = crate::special::ln_gamma((n + 1) as f64)
            - crate::special::ln_gamma((k + 1) as f64)
            - crate::special::ln_gamma((n - k + 1) as f64)
            + k as f64 * self.p.ln()
            // `ln_1p(-p)` is `ln(1 - p)` without the rounding loss for small `p`.
            + (n - k) as f64 * (-self.p).ln_1p();
        log_pmf.exp()
    }

    /// Cumulative probability `P(X <= floor(x))`.
    ///
    /// Delegates to `crate::special::beta_i`, presumably the regularized
    /// incomplete beta function — confirm against its definition.
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        let k = x.floor() as i64;
        let n = self.n as i64;
        if k >= n {
            return 1.0;
        }
        crate::special::beta_i((n - k) as f64, (k + 1) as f64, 1.0 - self.p)
    }

    /// Smallest integer `k` (as `f64`) in `[0, n]` with `cdf(k) >= prob`.
    ///
    /// Starts from a normal-approximation guess (`crate::special::ndtri` is
    /// presumably the inverse standard-normal CDF — confirm) and then walks
    /// up or down until the defining inequality holds.
    fn inv_cdf(&self, prob: f64) -> f64 {
        if prob <= 0.0 {
            return 0.0;
        }
        if prob >= 1.0 {
            return self.n as f64;
        }
        let mean = self.n as f64 * self.p;
        let sd = (self.n as f64 * self.p * (1.0 - self.p)).sqrt();
        let mut k = (mean + crate::special::ndtri(prob) * sd)
            .round()
            .clamp(0.0, self.n as f64) as i64;
        // The guess may be too low ...
        while self.cdf(k as f64) < prob && k < self.n as i64 {
            k += 1;
        }
        // ... or too high; back off to the smallest qualifying value.
        while k > 0 && self.cdf((k - 1) as f64) >= prob {
            k -= 1;
        }
        k as f64
    }

    /// Mean `n * p`.
    fn mean(&self) -> f64 {
        self.n as f64 * self.p
    }

    /// Median, computed as the smallest `k` with `cdf(k) >= 0.5`.
    ///
    /// `floor(n * p)` is only an approximation: for `n = 1`, `p = 0.9` it
    /// gives 0 although `P(X <= 0) = 0.1`.
    fn median(&self) -> f64 {
        self.inv_cdf(0.5)
    }

    /// Mode `floor((n + 1) * p)`, capped at `n`.
    ///
    /// The cap matters for `p == 1`, where the uncapped formula yields
    /// `n + 1`, a value outside the support.
    fn mode(&self) -> f64 {
        let upper = self.n as f64;
        let candidate = ((upper + 1.0) * self.p).floor();
        // Written as a comparison so that a NaN candidate stays NaN.
        if candidate > upper {
            upper
        } else {
            candidate
        }
    }

    /// Variance `n * p * (1 - p)`.
    fn variance(&self) -> f64 {
        self.n as f64 * self.p * (1.0 - self.p)
    }

    /// Skewness `(q - p) / sqrt(n * p * q)` with `q = 1 - p`.
    ///
    /// Not finite when the variance is zero.
    fn skewness(&self) -> f64 {
        let q = 1.0 - self.p;
        (q - self.p) / (self.n as f64 * self.p * q).sqrt()
    }

    /// Excess kurtosis `(1 - 6 * p * q) / (n * p * q)` with `q = 1 - p`.
    ///
    /// Not finite when the variance is zero.
    fn kurtosis(&self) -> f64 {
        let q = 1.0 - self.p;
        (1.0 - 6.0 * self.p * q) / (self.n as f64 * self.p * q)
    }

    /// Characteristic function `(1 - p + p * e^{it})^n`.
    fn characteristic_function(&self, t: f64) -> num_complex::Complex64 {
        let z = num_complex::Complex64::new(1.0 - self.p, 0.0)
            + num_complex::Complex64::new(0.0, t).exp().scale(self.p);
        // `powu` takes the `u32` exponent as-is; casting to `i32` for `powi`
        // would turn `n > i32::MAX` into a negative exponent.
        z.powu(self.n)
    }

    /// Moment generating function `(1 - p + p * e^t)^n`.
    fn moment_generating_function(&self, t: f64) -> f64 {
        let base = 1.0 - self.p + self.p * t.exp();
        if self.n <= i32::MAX as u32 {
            base.powi(self.n as i32)
        } else {
            // The exponent does not fit `powi`'s `i32`; use the float power
            // instead of letting the cast wrap to a negative exponent.
            base.powf(self.n as f64)
        }
    }
}
// Invokes the crate's `py_distribution_int!` macro for `SimdBinomial` under the
// name `PyBinomial`, passing the signature `(n, p, seed=None)` and the typed
// parameters `n: u32`, `p: f64`.
// NOTE(review): presumably this generates the Python-facing wrapper class —
// confirm against the macro definition, which is not visible here.
py_distribution_int!(PyBinomial, SimdBinomial,
sig: (n, p, seed=None),
params: (n: u32, p: f64)
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::DistributionExt;

    /// Returns `(mean, variance)` of `samples`, where the variance divides by
    /// the sample count (population form).
    fn moments(samples: &[u32]) -> (f64, f64) {
        let count = samples.len() as f64;
        let mean = samples.iter().map(|&v| f64::from(v)).sum::<f64>() / count;
        let squared_deviations = samples.iter().map(|&v| {
            let delta = f64::from(v) - mean;
            delta * delta
        });
        let var = squared_deviations.sum::<f64>() / count;
        (mean, var)
    }

    /// Draws `len` samples through the sampler's own generator.
    fn draw(dist: &SimdBinomial<u32>, len: usize) -> Vec<u32> {
        let mut samples = vec![0u32; len];
        dist.fill_slice_fast(&mut samples);
        samples
    }

    // n * p = 6: exercised with a small number of trials.
    #[test]
    fn small_n_path_matches_population_moments() {
        let dist = SimdBinomial::<u32>::with_seed(20, 0.3, 42);
        let (mean, var) = moments(&draw(&dist, 50_000));
        let (expected_mean, expected_var) = (dist.mean(), dist.variance());
        assert!(
            (mean - expected_mean).abs() < 0.05,
            "mean drift: got {mean}, expected {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() < 0.2,
            "variance drift: got {var}, expected {expected_var}"
        );
    }

    // n * p = 25 keeps a large-n sampler on the waiting-time path.
    #[test]
    fn large_n_small_p_wait_path_matches_population() {
        let dist = SimdBinomial::<u32>::with_seed(5_000, 0.005, 7);
        let (mean, var) = moments(&draw(&dist, 50_000));
        let (expected_mean, expected_var) = (dist.mean(), dist.variance());
        assert!(
            (mean - expected_mean).abs() < 0.5,
            "wait-method mean drift: got {mean}, expected {expected_mean}"
        );
        assert!(
            (var / expected_var - 1.0).abs() < 0.1,
            "wait-method variance drift: got {var}, expected {expected_var}"
        );
    }

    // p > 0.5 with n * (1 - p) = 10: the complement is sampled and mirrored.
    #[test]
    fn large_n_p_close_to_one_uses_complement() {
        let trials = 5_000u32;
        let dist = SimdBinomial::<u32>::with_seed(trials, 0.998, 11);
        let samples = draw(&dist, 20_000);
        let (mean, var) = moments(&samples);
        assert!(samples.iter().all(|&s| s <= trials));
        assert!((mean - dist.mean()).abs() < 0.5);
        assert!((var / dist.variance() - 1.0).abs() < 0.15);
    }

    // Zero trials can only ever produce zero successes.
    #[test]
    fn n_zero_returns_zeros() {
        let dist = SimdBinomial::<u32>::new(0, 0.5);
        let samples = draw(&dist, 32);
        assert!(samples.iter().all(|&s| s == 0));
    }
}