use rand::{Rng, SeedableRng, rngs::SmallRng};
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Seeded, cloneable random-number source for simulations.
///
/// All clones share a single underlying generator behind an `Arc<Mutex<_>>`,
/// so a value drawn through any clone advances the stream seen by every other
/// clone.
#[derive(Clone)]
pub struct SimulationRng {
    /// Seed the generator was constructed from; reported by `seed()` and by
    /// the `Debug` output.
    seed: u64,
    /// Shared generator state; every clone locks this same mutex.
    inner: Arc<Mutex<SmallRng>>,
}
impl SimulationRng {
    /// Creates a generator seeded from `seed`.
    ///
    /// Two generators built from the same seed produce identical streams.
    pub fn new(seed: u64) -> Self {
        Self {
            inner: Arc::new(Mutex::new(SmallRng::seed_from_u64(seed))),
            seed,
        }
    }

    /// Returns the seed passed to [`SimulationRng::new`].
    ///
    /// This is fixed at construction and does not change as values are drawn.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Returns `true` with probability `probability`.
    ///
    /// # Panics
    ///
    /// Panics if `probability` is outside `[0.0, 1.0]` (see `Rng::random_bool`).
    pub fn gen_bool(&self, probability: f64) -> bool {
        self.lock().random_bool(probability)
    }

    /// Draws a uniformly distributed `u64`.
    pub fn gen_u64(&self) -> u64 {
        self.lock().random()
    }

    /// Draws a uniformly distributed `u32`.
    pub fn gen_u32(&self) -> u32 {
        self.lock().random()
    }

    /// Draws a `usize` uniformly from the half-open `range`.
    ///
    /// # Panics
    ///
    /// Panics if `range` is empty.
    pub fn gen_range(&self, range: std::ops::Range<usize>) -> usize {
        self.lock().random_range(range)
    }

    /// Draws a `u64` uniformly from the half-open `range`.
    ///
    /// # Panics
    ///
    /// Panics if `range` is empty.
    pub fn gen_range_u64(&self, range: std::ops::Range<u64>) -> u64 {
        self.lock().random_range(range)
    }

    /// Draws an `f64` from `rand`'s standard distribution for floats
    /// (uniform over `[0.0, 1.0)`).
    pub fn gen_f64(&self) -> f64 {
        self.lock().random()
    }

    /// Draws a duration uniformly from the half-open `range`, at nanosecond
    /// resolution.
    ///
    /// Bounds longer than `u64::MAX` nanoseconds (about 584 years) are clamped
    /// to `u64::MAX` nanoseconds.
    ///
    /// # Panics
    ///
    /// Panics if the (clamped) range is empty.
    pub fn gen_duration(&self, range: std::ops::Range<Duration>) -> Duration {
        // `as_nanos` returns u128. A plain `as u64` cast keeps only the low 64
        // bits, so a very large `end` (e.g. `Duration::MAX`) could wrap below
        // `start` and yield an empty or meaningless range. Saturate instead.
        // Bounds that already fit in u64 are unchanged, so existing seeded
        // streams are preserved.
        let to_nanos = |d: Duration| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX);
        let start_nanos = to_nanos(range.start);
        let end_nanos = to_nanos(range.end);
        let nanos = self.lock().random_range(start_nanos..end_nanos);
        Duration::from_nanos(nanos)
    }

    /// Returns a uniformly chosen element of `slice`, or `None` if it is empty.
    ///
    /// Draws one index from the generator when `slice` is non-empty and draws
    /// nothing otherwise.
    pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> {
        if slice.is_empty() {
            return None;
        }
        let idx = self.gen_range(0..slice.len());
        slice.get(idx)
    }

    /// Shuffles `slice` in place with a Fisher–Yates pass.
    ///
    /// The lock is held for the whole pass, so draws from other clones cannot
    /// interleave with the shuffle's draws.
    pub fn shuffle<T>(&self, slice: &mut [T]) {
        let mut rng = self.lock();
        for i in (1..slice.len()).rev() {
            let j = rng.random_range(0..=i);
            slice.swap(i, j);
        }
    }

    /// Creates an independent generator seeded by drawing one `u64` from this
    /// one.
    ///
    /// This advances this generator's stream, so the child's seed depends on
    /// how many values were drawn before the call.
    pub fn child(&self) -> Self {
        let child_seed = self.gen_u64();
        Self::new(child_seed)
    }

    /// Creates a generator whose seed is derived only from this generator's
    /// construction seed and `index`.
    ///
    /// This does not draw from or advance this generator. The same `index`
    /// always yields the same child, regardless of call order. Within one
    /// parent, distinct indices yield distinct seeds. Children of *different*
    /// parents may still collide, because the derivation is a plain
    /// multiply-add.
    pub fn child_with_index(&self, index: u64) -> Self {
        let derived_seed = self
            .seed
            .wrapping_mul(0x517cc1b727220a95)
            .wrapping_add(index);
        Self::new(derived_seed)
    }

    /// Locks and returns the shared generator for direct use.
    ///
    /// Poisoning is ignored. A panic while the guard is held (for example
    /// `random_range` on an empty range) would otherwise poison the mutex and
    /// make every later draw, on every clone, panic too. The generator state
    /// is plain data with no invariant such a panic could break, so the guard
    /// is recovered instead.
    pub fn lock(&self) -> std::sync::MutexGuard<'_, SmallRng> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl std::fmt::Debug for SimulationRng {
    /// Formats as `SimulationRng { seed: <seed> }`.
    ///
    /// Only the construction seed is shown; the shared generator state behind
    /// the mutex is not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("SimulationRng");
        builder.field("seed", &self.seed);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Same seed must reproduce the exact same stream.
    #[test]
    fn test_determinism_same_seed() {
        let rng1 = SimulationRng::new(42);
        let rng2 = SimulationRng::new(42);
        for _ in 0..100 {
            assert_eq!(rng1.gen_u64(), rng2.gen_u64());
        }
    }

    /// Different seeds should produce (almost entirely) different streams.
    #[test]
    fn test_determinism_different_seeds() {
        let rng1 = SimulationRng::new(42);
        let rng2 = SimulationRng::new(43);
        let same_count = (0..100)
            .filter(|_| rng1.gen_u64() == rng2.gen_u64())
            .count();
        assert!(same_count < 10);
    }

    /// Probabilities 0 and 1 are deterministic outcomes.
    #[test]
    fn test_gen_bool() {
        let rng = SimulationRng::new(42);
        for _ in 0..100 {
            assert!(!rng.gen_bool(0.0));
        }
        for _ in 0..100 {
            assert!(rng.gen_bool(1.0));
        }
    }

    #[test]
    fn test_gen_range() {
        let rng = SimulationRng::new(42);
        for _ in 0..100 {
            let val = rng.gen_range(10..20);
            assert!((10..20).contains(&val));
        }
    }

    #[test]
    fn test_gen_duration() {
        let rng = SimulationRng::new(42);
        let min = Duration::from_millis(10);
        let max = Duration::from_millis(100);
        for _ in 0..100 {
            let val = rng.gen_duration(min..max);
            assert!(val >= min && val < max);
        }
    }

    #[test]
    fn test_choose() {
        let rng = SimulationRng::new(42);
        let items = vec![1, 2, 3, 4, 5];
        for _ in 0..100 {
            let picked = rng.choose(&items);
            assert!(picked.is_some_and(|v| items.contains(v)));
        }
        let empty: Vec<i32> = vec![];
        assert!(rng.choose(&empty).is_none());
    }

    /// Same seed gives the same shuffle, and the result is a permutation.
    #[test]
    fn test_shuffle_determinism() {
        let rng1 = SimulationRng::new(42);
        let rng2 = SimulationRng::new(42);
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut items1 = original.clone();
        let mut items2 = original.clone();
        rng1.shuffle(&mut items1);
        rng2.shuffle(&mut items2);
        assert_eq!(items1, items2);

        let mut sorted = items1.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, original);
    }

    #[test]
    fn test_child_determinism() {
        let parent1 = SimulationRng::new(42);
        let parent2 = SimulationRng::new(42);
        let child1 = parent1.child();
        let child2 = parent2.child();
        for _ in 0..100 {
            assert_eq!(child1.gen_u64(), child2.gen_u64());
        }
    }

    /// `child` seeds the child with the parent's next `u64` and advances the parent.
    #[test]
    fn test_child_seed_is_drawn_from_parent() {
        let parent = SimulationRng::new(42);
        let reference = SimulationRng::new(42);
        let child = parent.child();
        assert_eq!(child.seed(), reference.gen_u64());
        assert_eq!(parent.gen_u64(), reference.gen_u64());
    }

    #[test]
    fn test_child_with_index_determinism() {
        let parent = SimulationRng::new(42);
        let child1 = parent.child_with_index(5);
        let child2 = parent.child_with_index(5);
        for _ in 0..10 {
            assert_eq!(child1.gen_u64(), child2.gen_u64());
        }
        // Distinct indices must produce distinct streams; the previous
        // `val3 != 0 || val4 != 0` check did not test this at all.
        let child3 = parent.child_with_index(6);
        let child4 = parent.child_with_index(7);
        assert_ne!(child3.seed(), child4.seed());
        assert_ne!(child3.gen_u64(), child4.gen_u64());
    }

    /// Deriving indexed children must not consume values from the parent.
    #[test]
    fn test_child_with_index_does_not_advance_parent() {
        let parent = SimulationRng::new(42);
        let fresh = SimulationRng::new(42);
        for i in 0..5 {
            let _ = parent.child_with_index(i);
        }
        assert_eq!(parent.gen_u64(), fresh.gen_u64());
    }

    #[test]
    fn test_multiple_children_determinism() {
        let parent1 = SimulationRng::new(42);
        let parent2 = SimulationRng::new(42);
        let children1: Vec<_> = (0..10).map(|i| parent1.child_with_index(i)).collect();
        let children2: Vec<_> = (0..10).map(|i| parent2.child_with_index(i)).collect();
        for (c1, c2) in children1.iter().zip(children2.iter()) {
            for _ in 0..20 {
                assert_eq!(c1.gen_u64(), c2.gen_u64());
            }
        }
    }

    /// Clones share one underlying stream rather than copying its state.
    #[test]
    fn test_clones_share_stream() {
        let rng = SimulationRng::new(42);
        let clone = rng.clone();
        let reference = SimulationRng::new(42);
        assert_eq!(rng.gen_u64(), reference.gen_u64());
        assert_eq!(clone.gen_u64(), reference.gen_u64());
        assert_eq!(clone.seed(), rng.seed());
    }
}