use libm::{cos, log, sin, sqrt};
use rand_chacha::rand_core::{RngCore, SeedableRng};
use rand_chacha::ChaCha20Rng;
/// Deterministic stream of standard-normal `f64` samples.
///
/// Uniform variates come from a [`ChaCha20Rng`] seeded once at construction.
/// They are turned into Gaussian pairs with the Box–Muller transform, so the
/// whole sequence is a pure function of the seed.
pub(crate) struct ChaChaGaussianStream {
    /// Second Box–Muller output of the most recent pair. It is handed out, and
    /// cleared, on the next call to `next_f64`.
    spare: Option<f64>,
    /// Seeded ChaCha20 generator that supplies the raw 64-bit words.
    rng: ChaCha20Rng,
}
impl ChaChaGaussianStream {
    /// Multiplier mapping a 53-bit integer onto `[0, 1)`: exactly `2^-53`.
    const UNIFORM_SCALE: f64 = 1.0 / 9_007_199_254_740_992.0;

    /// Builds a stream whose entire output is determined by `seed`.
    pub fn new(seed: u64) -> Self {
        let rng = ChaCha20Rng::seed_from_u64(seed);
        Self { rng, spare: None }
    }

    /// Returns the next standard-normal (mean 0, variance 1) sample.
    ///
    /// Samples are generated two at a time by the Box–Muller transform. The
    /// cosine branch is returned right away. The sine branch is buffered and
    /// returned by the following call without consuming any more randomness.
    pub fn next_f64(&mut self) -> f64 {
        if let Some(buffered) = self.spare.take() {
            return buffered;
        }

        let (radius, angle) = loop {
            let u1 = self.next_uniform();
            let u2 = self.next_uniform();
            // `log(0)` would be -inf, so a zero first draw is discarded together
            // with its partner and a fresh pair is drawn.
            if u1 > 0.0 {
                break (sqrt(-2.0 * log(u1)), core::f64::consts::TAU * u2);
            }
        };

        self.spare = Some(radius * sin(angle));
        radius * cos(angle)
    }

    /// Draws a uniform `f64` in `[0, 1)` on an evenly spaced grid of step `2^-53`.
    fn next_uniform(&mut self) -> f64 {
        // Keep only the top 53 bits: the full precision of an f64 significand.
        let bits = self.rng.next_u64() >> 11;
        // `bits < 2^53`, so this conversion is exact; the lint is a false positive here.
        #[allow(clippy::cast_precision_loss)]
        let whole = bits as f64;
        whole * Self::UNIFORM_SCALE
    }
}
#[cfg(test)]
mod tests {
    use super::ChaChaGaussianStream;

    /// How many samples the seed-comparison tests draw from each stream.
    const DRAWS: usize = 64;

    /// Identical seeds must yield bit-identical sample sequences.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn same_seed_produces_identical_stream() {
        let mut left = ChaChaGaussianStream::new(42);
        let mut right = ChaChaGaussianStream::new(42);
        (0..DRAWS).for_each(|_| {
            assert_eq!(left.next_f64().to_bits(), right.next_f64().to_bits());
        });
    }

    /// Neighbouring seeds must produce at least one differing sample.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn different_seeds_diverge() {
        let mut left = ChaChaGaussianStream::new(42);
        let mut right = ChaChaGaussianStream::new(43);
        let mismatches = (0..DRAWS)
            .filter(|_| left.next_f64().to_bits() != right.next_f64().to_bits())
            .count();
        assert!(mismatches > 0, "different seeds must not collide bit-for-bit");
    }

    /// The sample mean and variance should be close to those of N(0, 1).
    #[test]
    #[cfg_attr(miri, ignore)]
    fn samples_have_reasonable_spread() {
        let mut stream = ChaChaGaussianStream::new(0);
        let count: u32 = 2048;
        let (sum, sum_sq) = (0..count).fold((0.0f64, 0.0f64), |(acc, acc_sq), _| {
            let x = stream.next_f64();
            (acc + x, acc_sq + x * x)
        });
        let total = f64::from(count);
        let mean = sum / total;
        let var = sum_sq / total - mean * mean;
        assert!(mean.abs() < 0.1, "sample mean too far from 0: {mean}");
        assert!(
            (var - 1.0).abs() < 0.2,
            "sample variance too far from 1: {var}"
        );
    }
}