use rand::SeedableRng;
use rand::{
RngExt,
distr::{Distribution, StandardUniform, uniform::SampleUniform},
};
use rand_chacha::ChaCha8Rng;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
// Per-thread simulation RNG state. `thread_local!` gives each thread its own
// copy, so draws made on one thread never advance another thread's stream.
thread_local! {
/// RNG behind every `sim_random*` helper; starts out seeded with 0.
static SIM_RNG: RefCell<ChaCha8Rng> = RefCell::new(ChaCha8Rng::seed_from_u64(0));
/// Seed most recently applied to `SIM_RNG` (by `set_sim_seed`, a fired
/// breakpoint, or `reset_sim_rng`).
/// NOTE(review): only ever copied in/out, so a `Cell<u64>` would suffice.
static CURRENT_SEED: RefCell<u64> = const { RefCell::new(0) };
/// Counted draws in the current segment: incremented by `pre_sample` before
/// each draw, zeroed by `reset_sim_rng` / `reset_rng_call_count`, and restarted
/// at 1 when a breakpoint reseeds. `set_sim_seed` does not touch it.
static RNG_CALL_COUNT: Cell<u64> = const { Cell::new(0) };
/// Pending `(target_count, new_seed)` reseed breakpoints, consumed front-to-back
/// by `check_rng_breakpoint`.
static RNG_BREAKPOINTS: RefCell<VecDeque<(u64, u64)>> = const { RefCell::new(VecDeque::new()) };
}
/// Bookkeeping that runs before every RNG draw: counts the draw, then applies
/// any reseed breakpoint that the new count has passed.
fn pre_sample() {
    let calls_so_far = RNG_CALL_COUNT.get();
    RNG_CALL_COUNT.set(calls_so_far + 1);
    check_rng_breakpoint();
}
/// Applies every scheduled reseed whose threshold the current call count has
/// passed.
///
/// Breakpoints are consumed front-to-back. A breakpoint `(target_count, new_seed)`
/// fires once the call count exceeds `target_count`. When it fires:
/// - the RNG is reseeded with `new_seed`,
/// - the recorded seed is updated,
/// - the call count restarts at 1, because the in-flight draw is the first draw
///   of the new segment.
///
/// The next breakpoint is compared against that restarted count, so each
/// `target_count` is measured from the previous reseed.
fn check_rng_breakpoint() {
    RNG_BREAKPOINTS.with_borrow_mut(|pending| loop {
        let Some(&(threshold, seed)) = pending.front() else {
            break;
        };
        if RNG_CALL_COUNT.get() <= threshold {
            break;
        }
        pending.pop_front();
        SIM_RNG.set(ChaCha8Rng::seed_from_u64(seed));
        CURRENT_SEED.set(seed);
        RNG_CALL_COUNT.set(1);
    });
}
/// Draws one value of type `T` from the [`StandardUniform`] distribution using
/// this thread's simulation RNG.
///
/// The draw is counted toward [`rng_call_count`]. Any reseed breakpoint that the
/// new count passes is applied *before* sampling, so the returned value may come
/// from the freshly reseeded stream.
pub fn sim_random<T>() -> T
where
    StandardUniform: Distribution<T>,
{
    pre_sample();
    SIM_RNG.with_borrow_mut(|generator| generator.sample(StandardUniform))
}
/// Draws a value uniformly from the half-open `range` using this thread's
/// simulation RNG.
///
/// As with [`sim_random`], the draw is counted and any due breakpoint reseed
/// is applied before sampling.
///
/// # Panics
///
/// Panics if `range` is empty (per `rand`'s `random_range` contract). The draw
/// has already been counted when that happens. For ranges with
/// `start >= end`, [`sim_random_range_or_default`] returns `range.start`
/// instead of panicking.
pub fn sim_random_range<T>(range: std::ops::Range<T>) -> T
where
    T: SampleUniform + PartialOrd,
{
    pre_sample();
    SIM_RNG.with_borrow_mut(move |generator| generator.random_range(range))
}
/// Like [`sim_random_range`], but returns `range.start` instead of panicking
/// when `range` is empty.
///
/// A range counts as empty unless `range.start < range.end`, which is the same
/// test `rand` applies before sampling. This also covers incomparable bounds,
/// such as a floating-point `NaN`, which a `start >= end` check would miss.
/// An empty range consumes no randomness and is not counted toward
/// [`rng_call_count`].
pub fn sim_random_range_or_default<T>(range: std::ops::Range<T>) -> T
where
    T: SampleUniform + PartialOrd,
{
    // Test `start < end` rather than `start >= end`. For NaN (or any pair a
    // partial order cannot compare) both comparisons are false, and the sampler
    // would panic on such a range.
    if range.start < range.end {
        sim_random_range(range)
    } else {
        range.start
    }
}
/// Reseeds this thread's simulation RNG with `seed` and records `seed` as the
/// current seed.
///
/// The call count and pending breakpoints are left untouched. Use
/// [`reset_rng_call_count`] and [`clear_rng_breakpoints`], or [`reset_sim_rng`],
/// to clear those.
pub fn set_sim_seed(seed: u64) {
    let reseeded = ChaCha8Rng::seed_from_u64(seed);
    SIM_RNG.set(reseeded);
    CURRENT_SEED.set(seed);
}
pub fn sim_random_f64() -> f64 {
pre_sample();
SIM_RNG.with(|rng| rng.borrow_mut().sample(StandardUniform))
}
/// Returns the seed most recently applied on this thread.
///
/// The seed is set by [`set_sim_seed`] or by a fired breakpoint, and
/// [`reset_sim_rng`] restores it to 0.
pub fn current_sim_seed() -> u64 {
    CURRENT_SEED.with_borrow(|seed| *seed)
}
pub fn reset_sim_rng() {
SIM_RNG.with(|rng| {
*rng.borrow_mut() = ChaCha8Rng::seed_from_u64(0);
});
CURRENT_SEED.with(|current| {
*current.borrow_mut() = 0;
});
RNG_CALL_COUNT.with(|c| c.set(0));
RNG_BREAKPOINTS.with(|bp| bp.borrow_mut().clear());
}
/// Number of counted draws since the last reset, or since the most recent
/// breakpoint reseed.
///
/// A breakpoint reseed restarts the count at 1, because it counts the draw that
/// triggered it.
pub fn rng_call_count() -> u64 {
    RNG_CALL_COUNT.get()
}
/// Zeroes the draw counter without touching the RNG, the recorded seed, or any
/// pending breakpoints.
pub fn reset_rng_call_count() {
    RNG_CALL_COUNT.set(0);
}
/// Replaces any pending reseed breakpoints with `breakpoints`.
///
/// Each entry is `(target_count, new_seed)`, and entries are processed in the
/// given order. When a draw pushes the call count past `target_count`:
/// - the RNG is reseeded with `new_seed` before that draw samples,
/// - the count restarts at 1.
///
/// Each later `target_count` is therefore measured from the previous reseed.
pub fn set_rng_breakpoints(breakpoints: Vec<(u64, u64)>) {
    let queue: VecDeque<(u64, u64)> = breakpoints.into();
    RNG_BREAKPOINTS.set(queue);
}
/// Discards every pending reseed breakpoint on this thread.
pub fn clear_rng_breakpoints() {
    RNG_BREAKPOINTS.with_borrow_mut(|pending| pending.clear());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Puts this thread's simulation RNG into a known state: fully reset (call
    /// count and breakpoints cleared), then seeded with `seed`.
    ///
    /// libtest normally runs each test on its own thread, so thread-local state
    /// starts fresh anyway. Resetting explicitly keeps these tests correct even
    /// if they ever share a thread.
    fn fresh_seed(seed: u64) {
        reset_sim_rng();
        set_sim_seed(seed);
    }

    #[test]
    fn test_deterministic_randomness() {
        fresh_seed(42);
        let value1: f64 = sim_random();
        let value2: u32 = sim_random();
        let value3: bool = sim_random();
        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<u32>());
        assert_eq!(value3, sim_random::<bool>());
    }

    #[test]
    fn test_different_seeds_produce_different_values() {
        fresh_seed(1);
        let value1_seed1: f64 = sim_random();
        let value2_seed1: f64 = sim_random();
        set_sim_seed(2);
        let value1_seed2: f64 = sim_random();
        let value2_seed2: f64 = sim_random();
        assert_ne!(value1_seed1, value1_seed2);
        assert_ne!(value2_seed1, value2_seed2);
    }

    #[test]
    fn test_sim_random_range() {
        fresh_seed(42);
        for _ in 0..100 {
            let value = sim_random_range(10..20);
            assert!((10..20).contains(&value));
        }
        for _ in 0..100 {
            let value = sim_random_range(0.0..1.0);
            assert!((0.0..1.0).contains(&value));
        }
    }

    #[test]
    fn test_range_determinism() {
        fresh_seed(123);
        let value1 = sim_random_range(100..1000);
        let value2 = sim_random_range(0.0..10.0);
        set_sim_seed(123);
        assert_eq!(value1, sim_random_range(100..1000));
        assert_eq!(value2, sim_random_range(0.0..10.0));
    }

    #[test]
    fn test_reset_clears_state() {
        fresh_seed(0);
        let seed0_first: f64 = sim_random();
        set_sim_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        reset_sim_rng();
        // Without re-seeding, the stream must restart exactly as seed 0.
        assert_eq!(seed0_first, sim_random::<f64>());
        assert_eq!(current_sim_seed(), 0);
        assert_eq!(rng_call_count(), 1);
    }

    #[test]
    fn test_sequence_persistence_within_thread() {
        fresh_seed(42);
        let value1: f64 = sim_random();
        let value2: f64 = sim_random();
        let value3: f64 = sim_random();
        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<f64>());
        assert_eq!(value3, sim_random::<f64>());
    }

    #[test]
    fn test_multiple_resets_and_seeds() {
        for seed in [1, 42, 12345] {
            fresh_seed(seed);
            let first: f64 = sim_random();
            fresh_seed(seed);
            assert_eq!(first, sim_random::<f64>());
        }
    }

    #[test]
    fn test_current_sim_seed() {
        fresh_seed(12345);
        assert_eq!(current_sim_seed(), 12345);
        set_sim_seed(98765);
        assert_eq!(current_sim_seed(), 98765);
        reset_sim_rng();
        assert_eq!(current_sim_seed(), 0);
    }

    #[test]
    fn test_call_counting() {
        fresh_seed(42);
        assert_eq!(rng_call_count(), 0);
        let _: f64 = sim_random();
        assert_eq!(rng_call_count(), 1);
        let _: u32 = sim_random();
        assert_eq!(rng_call_count(), 2);
        let _ = sim_random_range(0..100);
        assert_eq!(rng_call_count(), 3);
        let _ = sim_random_f64();
        assert_eq!(rng_call_count(), 4);
        let _ = sim_random_range_or_default(0..100);
        assert_eq!(rng_call_count(), 5);
        // An empty range is not a draw and must not be counted.
        let _ = sim_random_range_or_default(100..100);
        assert_eq!(rng_call_count(), 5);
    }

    #[test]
    fn test_range_or_default_empty_ranges() {
        fresh_seed(1);
        assert_eq!(sim_random_range_or_default(7..7), 7);
        assert_eq!(sim_random_range_or_default(10..3), 10);
        assert_eq!(sim_random_range_or_default(2.5..2.5), 2.5);
        assert_eq!(rng_call_count(), 0);
    }

    #[test]
    fn test_breakpoint_reseed() {
        fresh_seed(100);
        let old_values: Vec<f64> = (0..5).map(|_| sim_random()).collect();
        fresh_seed(200);
        let new_seed_first: f64 = sim_random();
        fresh_seed(100);
        set_rng_breakpoints(vec![(5, 200)]);
        for (i, expected) in old_values.iter().enumerate() {
            let actual: f64 = sim_random();
            assert_eq!(*expected, actual, "Mismatch at call {}", i + 1);
        }
        let after_breakpoint: f64 = sim_random();
        assert_eq!(after_breakpoint, new_seed_first);
        assert_eq!(rng_call_count(), 1);
        assert_eq!(current_sim_seed(), 200);
    }

    #[test]
    fn test_chained_breakpoints() {
        fresh_seed(10);
        set_rng_breakpoints(vec![(3, 20), (2, 30)]);
        for _ in 0..3 {
            let _: f64 = sim_random();
        }
        assert_eq!(current_sim_seed(), 10);
        // The fourth draw passes the first threshold.
        let _: f64 = sim_random();
        assert_eq!(current_sim_seed(), 20);
        assert_eq!(rng_call_count(), 1);
        // The second threshold is measured from the first reseed.
        let _: f64 = sim_random();
        assert_eq!(current_sim_seed(), 20);
        let _: f64 = sim_random();
        assert_eq!(current_sim_seed(), 30);
        assert_eq!(rng_call_count(), 1);
    }

    #[test]
    fn test_clear_rng_breakpoints() {
        fresh_seed(5);
        set_rng_breakpoints(vec![(1, 6)]);
        clear_rng_breakpoints();
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        assert_eq!(current_sim_seed(), 5);
        assert_eq!(rng_call_count(), 2);
    }

    #[test]
    fn test_replay_determinism() {
        fresh_seed(42);
        for _ in 0..3 {
            let _: f64 = sim_random();
        }
        let fork_count = rng_call_count();
        set_sim_seed(99);
        reset_rng_call_count();
        let post_fork_1: f64 = sim_random();
        let post_fork_2: f64 = sim_random();

        fresh_seed(42);
        set_rng_breakpoints(vec![(fork_count, 99)]);
        for _ in 0..3 {
            let _: f64 = sim_random();
        }
        let replay_1: f64 = sim_random();
        let replay_2: f64 = sim_random();
        assert_eq!(post_fork_1, replay_1);
        assert_eq!(post_fork_2, replay_2);
    }

    #[test]
    fn test_reset_clears_everything_including_breakpoints() {
        fresh_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        // A `(0, _)` breakpoint would fire on the very next draw, so the seed
        // assertion below only passes if the reset really discarded it.
        set_rng_breakpoints(vec![(0, 99)]);
        assert_eq!(rng_call_count(), 2);
        reset_sim_rng();
        assert_eq!(rng_call_count(), 0);
        assert_eq!(current_sim_seed(), 0);
        set_sim_seed(42);
        let _: f64 = sim_random();
        assert_eq!(rng_call_count(), 1);
        assert_eq!(current_sim_seed(), 42);
    }
}