use rand::rngs::SmallRng;
use rand::distr::{Distribution, Uniform};
use rand::{RngExt, SeedableRng};
use rand::seq::SliceRandom;
#[derive(Clone)]
pub struct Rng {
inner: SmallRng,
}
impl Rng {
pub fn seed(seed: u64) -> Self {
Self { inner: SmallRng::seed_from_u64(seed) }
}
pub fn from_entropy() -> Self {
Self { inner: rand::make_rng() }
}
pub fn usize(&mut self, n: usize) -> usize {
assert!(n > 0, "Rng::usize(0) is undefined");
Uniform::new(0, n).unwrap().sample(&mut self.inner)
}
pub fn f32(&mut self) -> f32 {
self.inner.random()
}
pub fn f64(&mut self) -> f64 {
self.inner.random()
}
pub fn shuffle<T>(&mut self, slice: &mut [T]) {
slice.shuffle(&mut self.inner);
}
pub fn bernoulli(&mut self, p: f64) -> bool {
self.f64() < p
}
pub fn range(&mut self, low: i64, high: i64) -> i64 {
assert!(low < high, "Rng::range requires low < high, got {low} >= {high}");
Uniform::new(low, high).unwrap().sample(&mut self.inner)
}
pub fn normal(&mut self, mean: f64, std: f64) -> f64 {
let u1: f64 = 1.0 - self.inner.random::<f64>(); let u2: f64 = self.inner.random::<f64>();
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
mean + std * z
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deterministic_same_seed() {
let mut a = Rng::seed(42);
let mut b = Rng::seed(42);
let va: Vec<f64> = (0..100).map(|_| a.f64()).collect();
let vb: Vec<f64> = (0..100).map(|_| b.f64()).collect();
assert_eq!(va, vb);
}
#[test]
fn different_seeds_differ() {
let mut a = Rng::seed(1);
let mut b = Rng::seed(2);
let va: Vec<f64> = (0..20).map(|_| a.f64()).collect();
let vb: Vec<f64> = (0..20).map(|_| b.f64()).collect();
assert_ne!(va, vb);
}
#[test]
fn usize_in_range() {
let mut rng = Rng::seed(0);
for _ in 0..1000 {
let v = rng.usize(10);
assert!(v < 10);
}
}
#[test]
#[should_panic(expected = "usize(0) is undefined")]
fn usize_zero_panics() {
Rng::seed(0).usize(0);
}
#[test]
fn f32_in_unit_interval() {
let mut rng = Rng::seed(0);
for _ in 0..1000 {
let v = rng.f32();
assert!((0.0..1.0).contains(&v));
}
}
#[test]
fn f64_in_unit_interval() {
let mut rng = Rng::seed(0);
for _ in 0..1000 {
let v = rng.f64();
assert!((0.0..1.0).contains(&v));
}
}
#[test]
fn shuffle_preserves_elements() {
let mut rng = Rng::seed(42);
let mut data = vec![1, 2, 3, 4, 5];
rng.shuffle(&mut data);
data.sort();
assert_eq!(data, vec![1, 2, 3, 4, 5]);
}
#[test]
fn shuffle_deterministic() {
let mut a = Rng::seed(42);
let mut b = Rng::seed(42);
let mut da = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut db = da.clone();
a.shuffle(&mut da);
b.shuffle(&mut db);
assert_eq!(da, db);
}
#[test]
fn bernoulli_extremes() {
let mut rng = Rng::seed(0);
for _ in 0..100 {
assert!(!rng.bernoulli(0.0));
}
for _ in 0..100 {
assert!(rng.bernoulli(1.0));
}
}
#[test]
fn bernoulli_roughly_half() {
let mut rng = Rng::seed(42);
let n = 10_000;
let hits = (0..n).filter(|_| rng.bernoulli(0.5)).count();
let ratio = hits as f64 / n as f64;
assert!((0.45..0.55).contains(&ratio), "bernoulli(0.5) ratio = {ratio}");
}
#[test]
fn range_bounds() {
let mut rng = Rng::seed(0);
for _ in 0..1000 {
let v = rng.range(-5, 5);
assert!((-5..5).contains(&v));
}
}
#[test]
#[should_panic(expected = "low < high")]
fn range_empty_panics() {
Rng::seed(0).range(5, 5);
}
#[test]
fn normal_statistical() {
let mut rng = Rng::seed(42);
let n = 50_000;
let samples: Vec<f64> = (0..n).map(|_| rng.normal(3.0, 0.5)).collect();
let mean = samples.iter().sum::<f64>() / n as f64;
let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
let std = var.sqrt();
assert!((2.95..3.05).contains(&mean), "normal mean = {mean}");
assert!((0.47..0.53).contains(&std), "normal std = {std}");
}
#[test]
fn clone_preserves_state() {
let mut a = Rng::seed(42);
for _ in 0..10 { a.f64(); }
let mut b = a.clone();
let va: Vec<f64> = (0..50).map(|_| a.f64()).collect();
let vb: Vec<f64> = (0..50).map(|_| b.f64()).collect();
assert_eq!(va, vb);
}
#[test]
fn from_entropy_works() {
let mut rng = Rng::from_entropy();
let v = rng.f64();
assert!((0.0..1.0).contains(&v));
}
}