use rand::distributions::Distribution;
use rand::Rng;
/// `exp(-1/2)`: the success probability of the Bernoulli trials that drive
/// Karney's sampling algorithms G and N.
const EXP_MINUS_HALF: f64 = 0.606_530_659_712_633_4;

/// Discrete normal (Gaussian) distribution over the integers.
///
/// A sample `i` is drawn with probability proportional to
/// `exp(-(i - mu)^2 / (2 * sigma^2))`, using Karney's rejection algorithm D
/// ("Sampling exactly from the normal distribution", ACM TOMS 42(1), 2016).
/// Karney's algorithm is exact for rational parameters; this implementation
/// works in `f64`, so it is only as exact as floating-point arithmetic allows.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DiscreteNormal {
    mu: f64,
    sigma: f64,
}

impl DiscreteNormal {
    /// Creates a discrete normal distribution with centre `mu` and scale `sigma`.
    ///
    /// # Panics
    ///
    /// Panics if `mu` is not finite, or if `sigma` is not finite and strictly
    /// positive. Such parameters would otherwise make sampling panic
    /// (empty range for `j`) or loop forever (NaN arithmetic).
    pub fn new(mu: f64, sigma: f64) -> Self {
        assert!(mu.is_finite(), "DiscreteNormal: mu must be finite, got {}", mu);
        assert!(
            sigma.is_finite() && sigma > 0.0,
            "DiscreteNormal: sigma must be finite and positive, got {}",
            sigma
        );
        Self { mu, sigma }
    }

    /// Returns `true` with probability `exp(-1/2)`.
    fn bernoulli_exp_minus_half<R: Rng + ?Sized>(rng: &mut R) -> bool {
        rng.gen::<f64>() < EXP_MINUS_HALF
    }

    /// Karney's algorithm G: returns `k >= 0` with probability
    /// `exp(-k/2) * (1 - exp(-1/2))`.
    ///
    /// It counts the successful `exp(-1/2)` trials before the first failure.
    fn algorithm_g<R: Rng + ?Sized>(rng: &mut R) -> u32 {
        let mut k = 0;
        while Self::bernoulli_exp_minus_half(rng) {
            k += 1;
        }
        k
    }

    /// Karney's algorithm N: returns `k >= 0` with probability proportional
    /// to `exp(-k^2 / 2)`.
    ///
    /// A candidate from algorithm G is accepted with probability
    /// `exp(-k(k-1)/2)`, i.e. when `k(k-1)` consecutive `exp(-1/2)` trials all
    /// succeed; otherwise a fresh candidate is drawn.
    fn algorithm_n<R: Rng + ?Sized>(rng: &mut R) -> u32 {
        loop {
            let k = Self::algorithm_g(rng);
            // k(k-1) is 0 for k in {0, 1}, so those candidates are always accepted.
            let trials = u64::from(k) * u64::from(k.saturating_sub(1));
            if (0..trials).all(|_| Self::bernoulli_exp_minus_half(&mut *rng)) {
                return k;
            }
        }
    }
}

impl Distribution<i32> for DiscreteNormal {
    /// Draws one sample using Karney's algorithm D.
    ///
    /// Every random draw comes from `rng`, so a seeded generator gives
    /// reproducible output. Results whose magnitude exceeds the `i32` range
    /// saturate to `i32::MIN` / `i32::MAX`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> i32 {
        // D4 draws j uniformly from [0, ceil(sigma)).
        // `new` guarantees sigma > 0, so this range is never empty.
        let j_dist = rand::distributions::Uniform::new(0_u64, self.sigma.ceil() as u64);
        loop {
            // D1: k >= 0 with probability proportional to exp(-k^2 / 2).
            let k = Self::algorithm_n(rng);
            // D2: a uniformly random sign.
            let sign: f64 = if rng.gen::<bool>() { 1.0 } else { -1.0 };
            // D3: i0 is the first integer at or right of k*sigma + sign*mu.
            // x0 is its distance from that point, in units of sigma.
            let left = f64::from(k) * self.sigma + sign * self.mu;
            let i0 = left.ceil();
            let x0 = (i0 - left) / self.sigma;
            // D4 + D5: pick an integer offset j and reject anything that
            // falls outside the interval [k, k + 1) of standardised deviations.
            let j = j_dist.sample(rng);
            let x = x0 + j as f64 / self.sigma;
            if x >= 1.0 {
                continue;
            }
            // D6: i = mu (x == 0, k == 0) is reachable from both signs.
            // Reject one sign so that point is not double-counted.
            if x == 0.0 && k == 0 && sign < 0.0 {
                continue;
            }
            // D7: accept with probability exp(-x(2k + x) / 2).
            // When x == 0 that probability is exactly 1, so skip the draw.
            let accept_probability = (-x * (2.0 * f64::from(k) + x) / 2.0).exp();
            if x == 0.0 || rng.gen::<f64>() <= accept_probability {
                // i0 can be negative (e.g. sign * mu < -k * sigma), so combine in
                // f64 instead of going through an unsigned cast. `as i32` saturates.
                return (sign * (i0 + j as f64)) as i32;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Number of samples drawn per goodness-of-fit check.
    const SAMPLES: usize = 20_000;

    /// Standard-normal quantile for the acceptance threshold.
    /// P(Z > 5) is about 3e-7, so a correct sampler practically never fails.
    const Z_THRESHOLD: f64 = 5.0;

    /// Exact probability mass function of the discrete normal on `lo..=hi`,
    /// normalised over that range. The mass outside mu +/- 10 sigma is
    /// negligible for the ranges used below.
    fn discrete_normal_pmf(mu: f64, sigma: f64, lo: i32, hi: i32) -> Vec<f64> {
        let weights: Vec<f64> = (lo..=hi)
            .map(|i| {
                let z = (f64::from(i) - mu) / sigma;
                (-0.5 * z * z).exp()
            })
            .collect();
        let total: f64 = weights.iter().sum();
        weights.into_iter().map(|w| w / total).collect()
    }

    /// Approximate upper critical value of the chi-square distribution with
    /// `df` degrees of freedom at standard-normal quantile `z`, using the
    /// Wilson–Hilferty approximation.
    fn chi_square_critical(df: usize, z: f64) -> f64 {
        let df = df as f64;
        let c = 2.0 / (9.0 * df);
        df * (1.0 - c + z * c.sqrt()).powi(3)
    }

    /// Draws `n` seeded samples and returns Pearson's chi-square statistic
    /// against the exact pmf, together with its degrees of freedom.
    fn chi_square_statistic(mu: f64, sigma: f64, n: usize, seed: u64) -> (f64, usize) {
        let dist = DiscreteNormal::new(mu, sigma);
        let mut rng = StdRng::seed_from_u64(seed);

        let lo = (mu - 10.0 * sigma).floor() as i32;
        let hi = (mu + 10.0 * sigma).ceil() as i32;
        let mut counts = vec![0_u64; (hi - lo + 1) as usize];
        for _ in 0..n {
            // Anything beyond mu +/- 10 sigma is folded into the end values.
            let v = dist.sample(&mut rng).clamp(lo, hi);
            counts[(v - lo) as usize] += 1;
        }

        // Merge adjacent values into bins with expected count >= 5 so the
        // chi-square approximation is valid. The lower tail lands in the
        // first bin; the leftover upper tail is folded into the last bin.
        let pmf = discrete_normal_pmf(mu, sigma, lo, hi);
        let mut bins: Vec<(f64, f64)> = Vec::new(); // (observed, expected)
        let (mut observed, mut expected) = (0.0_f64, 0.0_f64);
        for (&count, &p) in counts.iter().zip(&pmf) {
            observed += count as f64;
            expected += p * n as f64;
            if expected >= 5.0 {
                bins.push((observed, expected));
                observed = 0.0;
                expected = 0.0;
            }
        }
        assert!(bins.len() >= 2, "too few samples to form chi-square bins");
        if let Some(last) = bins.last_mut() {
            last.0 += observed;
            last.1 += expected;
        }

        let statistic: f64 = bins.iter().map(|&(o, e)| (o - e).powi(2) / e).sum();
        (statistic, bins.len() - 1)
    }

    /// Asserts that seeded samples from `DiscreteNormal::new(mu, sigma)` pass a
    /// chi-square goodness-of-fit test against the exact discrete-normal pmf.
    fn assert_matches_discrete_normal(mu: f64, sigma: f64, seed: u64) {
        let (statistic, df) = chi_square_statistic(mu, sigma, SAMPLES, seed);
        let critical = chi_square_critical(df, Z_THRESHOLD);
        assert!(
            statistic < critical,
            "mu={}, sigma={}: chi-square {:.2} exceeds critical value {:.2} ({} df)",
            mu,
            sigma,
            statistic,
            critical,
            df
        );
    }

    #[test]
    fn test_chisq() {
        assert_matches_discrete_normal(0.0, 3.0, 1);
    }

    #[test]
    fn test_chisq_positive_fractional_mean() {
        // sign * mu < 0 makes i0 negative inside the sampler.
        assert_matches_discrete_normal(2.5, 1.5, 2);
    }

    #[test]
    fn test_chisq_negative_mean_small_sigma() {
        // sigma < 1 means ceil(sigma) == 1, so the offset j is always 0.
        assert_matches_discrete_normal(-7.25, 0.8, 3);
    }
}