/// Small, deterministic pseudo-random number generator.
///
/// The state transition and output scrambler in [`Rng::next_u64`] follow the
/// xoshiro256** algorithm; [`Rng::seed_from_u64`] expands a single `u64` seed
/// into the four state words using SplitMix64. Equal seeds yield equal streams.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Rng {
    /// The four 64-bit words of generator state.
    s: [u64; 4],
}

impl Rng {
    /// Builds a generator whose state is the first four outputs of a
    /// SplitMix64 sequence started at `seed`.
    pub(crate) fn seed_from_u64(seed: u64) -> Self {
        let mut s = seed;
        let mut out = [0u64; 4];
        for slot in &mut out {
            s = s.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = s;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            *slot = z ^ (z >> 31);
        }
        Self { s: out }
    }

    /// Returns the next pseudo-random 64-bit value and advances the state.
    pub(crate) fn next_u64(&mut self) -> u64 {
        // The "**" output scrambler reads `s[1]` before the state is updated.
        let result = self.s[1].wrapping_mul(5).rotate_left(7).wrapping_mul(9);
        let t = self.s[1] << 17;
        self.s[2] ^= self.s[0];
        self.s[3] ^= self.s[1];
        self.s[1] ^= self.s[2];
        self.s[0] ^= self.s[3];
        self.s[2] ^= t;
        self.s[3] = self.s[3].rotate_left(45);
        result
    }

    /// Returns a pseudo-random `f64` in `[0, 1)`.
    ///
    /// Takes the top 53 bits of [`Rng::next_u64`], so every result is an
    /// exact multiple of `2^-53`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        // 2^53: dividing a 53-bit integer by this maps it into [0, 1).
        const SCALE: f64 = 9_007_199_254_740_992.0;
        let bits = self.next_u64() >> 11;
        // `bits < 2^53` fits exactly in an f64 significand, so no precision is lost.
        #[allow(clippy::cast_precision_loss)]
        let f = bits as f64;
        f / SCALE
    }

    /// Returns a standard-normal sample (mean 0, standard deviation 1).
    ///
    /// Uses the cosine branch of the Box–Muller transform; the paired sine
    /// sample is discarded, so each call consumes at least two draws.
    pub(crate) fn next_normal(&mut self) -> f64 {
        // `ln(0)` is `-inf`, so only an exact zero must be redrawn. Smaller
        // positive draws such as 2^-53 and 2^-52 are valid inputs; rejecting
        // them (e.g. with an `EPSILON` threshold) would clip the far tail.
        let mut u1 = self.next_f64();
        while u1 <= 0.0 {
            u1 = self.next_f64();
        }
        let u2 = self.next_f64();
        let mag = (-2.0 * u1.ln()).sqrt();
        mag * (2.0 * core::f64::consts::PI * u2).cos()
    }

    /// Returns `lo + (hi - lo) * u` where `u` comes from [`Rng::next_f64`].
    ///
    /// For `lo < hi` this is nominally in `[lo, hi)`, though floating-point
    /// rounding can occasionally produce exactly `hi`. The bounds are not
    /// validated.
    pub(crate) fn next_uniform(&mut self, lo: f64, hi: f64) -> f64 {
        lo + (hi - lo) * self.next_f64()
    }
}
#[cfg(test)]
mod tests {
    use super::Rng;

    /// Generators built from the same seed must emit the same 32 values.
    #[test]
    fn deterministic_same_seed() {
        let (mut first, mut second) = (Rng::seed_from_u64(42), Rng::seed_from_u64(42));
        (0..32).for_each(|_| assert_eq!(first.next_u64(), second.next_u64()));
    }

    /// Seeds 1 and 2 must already disagree on their first output.
    #[test]
    fn different_seeds_diverge() {
        let from_one = Rng::seed_from_u64(1).next_u64();
        let from_two = Rng::seed_from_u64(2).next_u64();
        assert_ne!(from_one, from_two);
    }

    /// Every `next_f64` draw must land in the half-open unit interval.
    #[test]
    fn f64_in_unit_range() {
        let mut rng = Rng::seed_from_u64(7);
        for v in std::iter::repeat_with(|| rng.next_f64()).take(1000) {
            assert!((0.0..1.0).contains(&v), "f64 out of [0,1): {v}");
        }
    }

    /// Sample mean and standard deviation of 10 000 normal draws should be
    /// within 0.1 of 0 and 1 respectively.
    #[test]
    fn normal_distribution_roughly_centered() {
        let mut rng = Rng::seed_from_u64(13);
        let count: i32 = 10_000;
        let denom = f64::from(count);
        let draws: Vec<f64> = (0..count).map(|_| rng.next_normal()).collect();

        let mean: f64 = draws.iter().sum::<f64>() / denom;
        let sq_dev_total: f64 = draws.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
        let stdev = (sq_dev_total / denom).sqrt();

        assert!(mean.abs() < 0.1, "mean too far from 0: {mean}");
        assert!((stdev - 1.0).abs() < 0.1, "stdev too far from 1: {}", stdev);
    }
}