use std::time::Duration;
use crate::FixedQ32;
/// A source of pseudo-random values.
///
/// Implementors provide [`Rng::next_u64`] and [`Rng::fork`]; every other method is a default
/// built purely on `next_u64`, so a deterministic `next_u64` makes all of them deterministic.
pub trait Rng: Send {
    /// Returns the next 64 raw pseudo-random bits.
    fn next_u64(&mut self) -> u64;

    /// Returns the low 32 bits of the next [`Rng::next_u64`] value.
    // The `as` cast deliberately truncates to the low half.
    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
    fn next_u32(&mut self) -> u32 {
        self.next_u64() as u32
    }

    /// Returns whether the lowest bit of the next [`Rng::next_u64`] value is set.
    // NOTE(review): for xorshift-style generators the lowest bit is the weakest one; switching
    // to the top bit would change every recorded stream, so it is left as is.
    fn next_bool(&mut self) -> bool {
        self.next_u64() & 1 == 1
    }

    /// Returns a [`FixedQ32`] whose raw bits are the high 32 bits of the next
    /// [`Rng::next_u64`] value.
    ///
    /// The raw bits always lie in `0..2^32`. Presumably `FixedQ32` has 32 fractional bits,
    /// which would put the result in `[0, 1)` — TODO confirm against `FixedQ32::from_bits`.
    // `as` is lossless here: after `>> 32` the value is below 2^32 and always fits in `i64`.
    #[allow(clippy::as_conversions)]
    fn next_fixed(&mut self) -> FixedQ32 {
        let frac_bits = (self.next_u64() >> 32) as i64;
        FixedQ32::from_bits(frac_bits)
    }

    /// Returns a uniformly distributed value in `0..bound`.
    ///
    /// A `bound` of `0` or `1` returns `0` without drawing any randomness.
    ///
    /// Uses rejection sampling: only the lowest `2^64 - (2^64 mod bound)` raw values, a whole
    /// multiple of `bound`, are accepted. The surplus at the top of the range is redrawn, so
    /// every residue has exactly the same number of preimages.
    fn next_u64_bounded(&mut self, bound: u64) -> u64 {
        if bound <= 1 {
            return 0;
        }
        // (2^64 - bound) mod bound == 2^64 mod bound: the number of leftover raw values that
        // would otherwise give the smallest residues one extra preimage each. It is 0 when
        // `bound` divides 2^64, in which case every raw value is accepted.
        let surplus = bound.wrapping_neg() % bound;
        let limit = u64::MAX - surplus;
        loop {
            let value = self.next_u64();
            if value <= limit {
                return value % bound;
            }
        }
    }

    /// Picks a uniformly random element of `slice`.
    ///
    /// Returns `None` for an empty slice without drawing any randomness.
    fn choose<'a, T>(&mut self, slice: &'a [T]) -> Option<&'a T> {
        if slice.is_empty() {
            return None;
        }
        let len = u64::try_from(slice.len()).ok()?;
        let idx = self.next_u64_bounded(len);
        slice.get(usize::try_from(idx).ok()?)
    }

    /// Returns a duration drawn uniformly, at nanosecond resolution, from `[min, max)`.
    ///
    /// Both bounds are first saturated to `u64::MAX` nanoseconds. If `max <= min` after that,
    /// `min` is returned unchanged and no randomness is drawn.
    fn duration_between(&mut self, min: Duration, max: Duration) -> Duration {
        let min_ns = u64::try_from(min.as_nanos()).unwrap_or(u64::MAX);
        let max_ns = u64::try_from(max.as_nanos()).unwrap_or(u64::MAX);
        if max_ns <= min_ns {
            return min;
        }
        // `min_ns + offset < max_ns <= u64::MAX`, so the addition cannot overflow.
        let offset = self.next_u64_bounded(max_ns - min_ns);
        Duration::from_nanos(min_ns + offset)
    }

    /// Splits off a new generator derived from `self`; `self` may be advanced in the process.
    fn fork(&mut self) -> Self
    where
        Self: Sized;
}
/// Deterministic xorshift64 pseudo-random generator whose entire state is one `u64` word.
#[derive(Clone, Debug)]
pub struct SeededRng {
    /// Current xorshift state. An all-zero state would never change again, so it must be
    /// kept non-zero.
    state: u64,
}
impl SeededRng {
    /// Maps the forbidden all-zero state to `1`; every other word passes through unchanged.
    const fn non_zero(word: u64) -> u64 {
        match word {
            0 => 1,
            other => other,
        }
    }

    /// Creates a generator seeded with `seed`.
    ///
    /// A seed of `0` is stored as `1`, because a xorshift state of zero never leaves zero.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: Self::non_zero(seed),
        }
    }

    /// Returns the raw internal state, which can later be handed back to
    /// [`SeededRng::restore`].
    #[must_use]
    pub fn state(&self) -> u64 {
        self.state
    }

    /// Overwrites the internal state, typically with a value previously read via
    /// [`SeededRng::state`].
    ///
    /// As with [`SeededRng::new`], a value of `0` is stored as `1`.
    pub fn restore(&mut self, state: u64) {
        self.state = Self::non_zero(state);
    }
}
impl Rng for SeededRng {
    /// Advances the xorshift64 state with shift triple (13, 7, 17) and returns the new state.
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Derives a child generator and advances `self` by exactly two steps.
    ///
    /// The child's seed is the parent's next output scrambled by the SplitMix64 finalizer.
    ///
    /// Seeding the child with the raw output `s1` would be wrong. Because the next xorshift
    /// state is a pure function of the current one, such a child would emit `s2, s3, ...`:
    /// exactly the parent's own stream, one step behind. Scrambling moves the child to an
    /// unrelated point of the state sequence instead.
    fn fork(&mut self) -> Self {
        /// SplitMix64 output finalizer: a bijective 64-bit mixing function.
        fn mix(mut z: u64) -> u64 {
            z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        let raw = self.next_u64();
        // Second step keeps the parent's post-fork stream identical to what it has always been.
        self.next_u64();
        // `new` maps the (single) raw value whose mix is 0 to the valid state 1.
        Self::new(mix(raw))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_seeded_rng_reproducible() {
        let (mut left, mut right) = (SeededRng::new(12345), SeededRng::new(12345));
        (0..100).for_each(|_| assert_eq!(left.next_u64(), right.next_u64()));
    }

    #[test]
    fn test_seeded_rng_different_seeds() {
        let first = SeededRng::new(12345).next_u64();
        let second = SeededRng::new(54321).next_u64();
        assert_ne!(first, second);
    }

    #[test]
    fn test_seeded_rng_zero_seed() {
        // A zero seed must not leave the generator stuck at zero.
        assert_ne!(SeededRng::new(0).next_u64(), 0);
    }

    #[test]
    fn test_rng_choose() {
        let mut rng = SeededRng::new(42);
        let items = [1, 2, 3, 4, 5];
        let picked = rng.choose(&items).expect("must choose an element");
        assert!(items.contains(picked));
    }

    #[test]
    fn test_rng_choose_empty() {
        let items: [i32; 0] = [];
        assert!(SeededRng::new(42).choose(&items).is_none());
    }

    #[test]
    fn test_rng_duration_between() {
        let mut rng = SeededRng::new(42);
        let (min, max) = (Duration::from_millis(100), Duration::from_millis(200));
        for _ in 0..100 {
            let sampled = rng.duration_between(min, max);
            assert!((min..max).contains(&sampled));
        }
    }

    #[test]
    fn test_rng_fork() {
        let mut parent = SeededRng::new(42);
        let mut child = parent.fork();
        assert_ne!(parent.next_u64(), child.next_u64());

        // Forking identically seeded parents must yield identical children.
        let first_fork = SeededRng::new(42).fork().next_u64();
        let second_fork = SeededRng::new(42).fork().next_u64();
        assert_eq!(first_fork, second_fork);
    }

    #[test]
    fn test_next_fixed_in_range() {
        let mut rng = SeededRng::new(42);
        for sample in (0..100).map(|_| rng.next_fixed()) {
            assert!(sample >= FixedQ32::zero());
            assert!(sample < FixedQ32::one());
        }
    }

    #[test]
    fn test_next_u64_bounded_reproducible() {
        let mut left = SeededRng::new(777);
        let mut right = SeededRng::new(777);
        let left_draws: Vec<u64> = (0..256).map(|_| left.next_u64_bounded(37)).collect();
        let right_draws: Vec<u64> = (0..256).map(|_| right.next_u64_bounded(37)).collect();
        assert_eq!(left_draws, right_draws);
    }

    #[test]
    // The `as f64` casts are exact here: every count is far below 2^53.
    #[allow(clippy::as_conversions)]
    fn test_next_u64_bounded_distribution_sanity() {
        const BOUND: u64 = 8;
        const DRAWS: usize = 80_000;
        let mut rng = SeededRng::new(123456);
        let mut buckets = [0usize; 8];
        for _ in 0..DRAWS {
            let slot = usize::try_from(rng.next_u64_bounded(BOUND)).expect("index fits");
            buckets[slot] += 1;
        }
        let expected = DRAWS as f64 / BOUND as f64;
        for count in buckets {
            let relative_error = (count as f64 - expected).abs() / expected;
            assert!(
                relative_error < 0.08,
                "bucket count {count} too far from expected {expected}"
            );
        }
    }
}