use std::ops::{Range, RangeInclusive};
/// Small, fast, seedable pseudo-random number generator (xoshiro256**).
///
/// Deterministic for a given seed, which makes it suitable for reproducible
/// simulations and tests. It is not a cryptographically secure generator.
#[derive(Debug, Clone)]
pub(crate) struct FastRng {
    /// The 256-bit xoshiro state, stored as four 64-bit words.
    s: [u64; 4],
}
impl FastRng {
    /// Creates a generator whose state is derived deterministically from `seed`.
    ///
    /// The four state words are produced by successive SplitMix64 steps
    /// starting from `seed`. This spreads small or similar seeds across the
    /// whole 256-bit state.
    pub fn new(seed: u64) -> Self {
        let mut state = seed;
        let mut s = [0u64; 4];
        for slot in &mut s {
            // SplitMix64: advance by the golden-ratio increment, then mix.
            state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            *slot = z ^ (z >> 31);
        }
        Self { s }
    }

    /// Advances the state by one xoshiro256** step and returns the 64-bit output.
    fn next_u64(&mut self) -> u64 {
        // The "**" scrambler: multiply, rotate, multiply, applied to s[1].
        let result = (self.s[1].wrapping_mul(5)).rotate_left(7).wrapping_mul(9);
        let t = self.s[1] << 17;
        self.s[2] ^= self.s[0];
        self.s[3] ^= self.s[1];
        self.s[1] ^= self.s[2];
        self.s[0] ^= self.s[3];
        self.s[2] ^= t;
        self.s[3] = self.s[3].rotate_left(45);
        result
    }

    /// Returns a uniform sample from `[0, 1)`.
    pub fn f64(&mut self) -> f64 {
        // The top 53 bits exactly fill an f64 mantissa; dividing by 2^53 maps
        // them onto [0, 1) without rounding.
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Returns a standard normal sample (mean 0, variance 1) using the
    /// Box–Muller transform. Only the cosine branch is used, so each call
    /// consumes two uniforms.
    pub fn normal(&mut self) -> f64 {
        // `f64()` can return exactly 0.0 and ln(0) = -inf, so clamp u1 away from zero.
        let u1 = self.f64().max(1e-300);
        let u2 = self.f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }

    /// Returns a value from `range` (for example `5..10` or `0..=3`).
    ///
    /// The draw is reduced with `%`. Its bias is at most `len / 2^64`, which
    /// is negligible for any practical range.
    ///
    /// # Panics
    ///
    /// Panics with `"empty range"` if `range` reports a length of zero.
    pub fn usize(&mut self, range: impl UsizeRange) -> usize {
        let (start, len) = range.start_and_len();
        // A real assert, not debug_assert: otherwise release builds would hit
        // an opaque "remainder with a divisor of zero" panic instead.
        assert!(len > 0, "empty range");
        // Reduce in u64 so a given seed yields the same values on 32- and
        // 64-bit targets. Casting the raw output to usize first would drop the
        // high half on 32-bit. `len as u64` is lossless, and the remainder is
        // < len, so it fits back into usize.
        start + (self.next_u64() % len as u64) as usize
    }

    /// Shuffles `indices` in place using the Fisher–Yates algorithm.
    ///
    /// Slices of length 0 or 1 are left unchanged and consume no randomness.
    pub fn shuffle(&mut self, indices: &mut [usize]) {
        for i in (1..indices.len()).rev() {
            // j is drawn from 0..=i. Reduce in u64 for cross-platform
            // determinism, matching `usize`.
            let j = (self.next_u64() % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }
    }
}
/// A range of `usize` values that `FastRng::usize` can sample from.
///
/// An implementor describes itself as a `(start, len)` pair: the sampled value
/// lies in `start..start + len`. A `len` of zero means the range is empty.
pub(crate) trait UsizeRange {
    /// Consumes the range and returns its first value together with the
    /// number of values it contains.
    fn start_and_len(self) -> (usize, usize);
}
impl UsizeRange for Range<usize> {
    /// Returns `(start, end - start)`.
    ///
    /// A reversed range such as `10..5` contains no values, so it reports a
    /// length of 0. Plain subtraction would underflow instead: a panic in
    /// debug builds, and a huge bogus length in release builds that lets
    /// callers produce out-of-range values.
    fn start_and_len(self) -> (usize, usize) {
        (self.start, self.end.saturating_sub(self.start))
    }
}
impl UsizeRange for RangeInclusive<usize> {
    /// Returns `(start, end - start + 1)`, or a length of 0 if the range is empty.
    ///
    /// # Panics
    ///
    /// Panics if the range covers every `usize` (`0..=usize::MAX`). Its length,
    /// `usize::MAX + 1`, cannot be represented.
    fn start_and_len(self) -> (usize, usize) {
        // `is_empty` covers reversed bounds (`5..=3`) and also a range that has
        // already been iterated to exhaustion; `start()`/`end()` alone miss the
        // latter.
        if self.is_empty() {
            return (*self.start(), 0);
        }
        let (start, end) = self.into_inner();
        // Non-empty implies start <= end, so only the `+ 1` can overflow.
        let len = (end - start)
            .checked_add(1)
            .expect("inclusive range spans every usize; its length does not fit in usize");
        (start, len)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uniform_in_unit_interval() {
        let mut rng = FastRng::new(42);
        for _ in 0..1000 {
            let v = rng.f64();
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn usize_range() {
        let mut rng = FastRng::new(7);
        // Track which values were produced, to catch a generator stuck on a subset.
        let mut seen = [false; 5];
        for _ in 0..200 {
            let v = rng.usize(5..10);
            assert!((5..10).contains(&v));
            seen[v - 5] = true;
        }
        assert!(seen.iter().all(|&hit| hit), "every value in 5..10 should appear");
    }

    #[test]
    fn usize_range_inclusive() {
        let mut rng = FastRng::new(7);
        let mut seen = [false; 4];
        for _ in 0..200 {
            let v = rng.usize(0..=3);
            assert!(v <= 3);
            seen[v] = true;
        }
        assert!(seen.iter().all(|&hit| hit), "every value in 0..=3 should appear");
    }

    #[test]
    #[should_panic(expected = "empty range")]
    fn usize_empty_range_panics() {
        let mut rng = FastRng::new(0);
        let _ = rng.usize(3..3);
    }

    #[test]
    fn same_seed_same_sequence() {
        let mut a = FastRng::new(123);
        let mut b = FastRng::new(123);
        for _ in 0..100 {
            assert_eq!(a.f64().to_bits(), b.f64().to_bits());
        }
    }

    #[test]
    fn different_seeds_diverge() {
        let mut a = FastRng::new(1);
        let mut b = FastRng::new(2);
        let xs: Vec<u64> = (0..4).map(|_| a.f64().to_bits()).collect();
        let ys: Vec<u64> = (0..4).map(|_| b.f64().to_bits()).collect();
        assert_ne!(xs, ys);
    }

    #[test]
    fn shuffle_is_permutation() {
        let mut rng = FastRng::new(9);
        let original: Vec<usize> = (0..50).collect();
        let mut v = original.clone();
        rng.shuffle(&mut v);
        // With 50 elements, returning the identity order is astronomically unlikely.
        assert_ne!(v, original);
        let mut sorted = v.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, original);
    }

    #[test]
    fn shuffle_tiny_slices() {
        let mut rng = FastRng::new(1);
        let mut empty: [usize; 0] = [];
        rng.shuffle(&mut empty);
        let mut one = [7usize];
        rng.shuffle(&mut one);
        assert_eq!(one, [7]);
    }

    #[test]
    fn normal_distribution() {
        let mut rng = FastRng::new(42);
        let n = 10_000;
        let samples: Vec<f64> = (0..n).map(|_| rng.normal()).collect();
        let mean: f64 = samples.iter().sum::<f64>() / n as f64;
        let var: f64 = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.1, "mean should be ~0, got {mean}");
        assert!((var - 1.0).abs() < 0.1, "variance should be ~1, got {var}");
    }
}