use sha2::{Digest, Sha256};
/// Deterministic pseudo-random number generator built on SHA-256 in counter mode.
///
/// Every output is derived from `SHA-256(seed || counter)`, where `counter` is a
/// big-endian `u64` bumped before each draw. Two generators created from the same
/// seed therefore yield identical streams.
pub struct SeedRng {
    /// Seed bytes hashed into every draw.
    seed: [u8; 32],
    /// Number of draws since construction or the last [`SeedRng::reseed`].
    /// Incremented with wrapping arithmetic, so the stream repeats after 2^64 draws.
    counter: u64,
}

impl SeedRng {
    /// Creates a generator from `seed`, positioned at the start of its stream.
    #[must_use]
    pub fn new(seed: [u8; 32]) -> Self {
        Self { seed, counter: 0 }
    }

    /// Returns a copy of the current seed.
    #[must_use]
    pub fn seed_bytes(&self) -> [u8; 32] {
        self.seed
    }

    /// Replaces the seed and rewinds the counter, so the following draws match
    /// those of `SeedRng::new(new_seed)`.
    pub fn reseed(&mut self, new_seed: [u8; 32]) {
        self.seed = new_seed;
        self.counter = 0;
    }

    /// Draws the next `u64`: the first 8 bytes of `SHA-256(seed || counter)`,
    /// read as big-endian, after advancing the counter.
    pub fn next_u64(&mut self) -> u64 {
        self.counter = self.counter.wrapping_add(1);
        let mut hasher = Sha256::new();
        hasher.update(self.seed);
        hasher.update(self.counter.to_be_bytes());
        let digest = hasher.finalize();
        // A SHA-256 digest is 32 bytes, so an 8-byte prefix always exists.
        u64::from_be_bytes(digest[..8].try_into().expect("digest length"))
    }

    /// Returns a uniformly distributed value in `low..high` (upper bound exclusive).
    ///
    /// If `high <= low`, returns `low` without consuming a draw.
    pub fn range(&mut self, low: u64, high: u64) -> u64 {
        if high <= low {
            return low;
        }
        let span = high - low;
        // `threshold` equals 2^64 mod span. Draws below it would make small
        // residues more likely than large ones (modulo bias), so they are
        // rejected and redrawn. Any accepted draw maps to the same value the
        // plain `% span` reduction would give; the expected number of extra
        // draws is below one even in the worst case.
        let threshold = span.wrapping_neg() % span;
        loop {
            let draw = self.next_u64();
            if draw >= threshold {
                // `draw % span < span`, so `low + ...` stays below `high`
                // and cannot overflow.
                return low + draw % span;
            }
        }
    }

    /// Picks a uniformly random element of `items`, or `None` if it is empty.
    ///
    /// An empty slice consumes no draw.
    pub fn choose<'a, T>(&mut self, items: &'a [T]) -> Option<&'a T> {
        if items.is_empty() {
            return None;
        }
        // `range` returns a value strictly below `items.len()`, so casting it
        // back to `usize` cannot truncate and the lookup always hits.
        let idx = self.range(0, items.len() as u64) as usize;
        items.get(idx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Identical seeds must give identical streams.
    #[test]
    fn stream_is_deterministic() {
        let mut a = SeedRng::new([9u8; 32]);
        let mut b = SeedRng::new([9u8; 32]);
        assert_eq!(a.next_u64(), b.next_u64());
        assert_eq!(a.next_u64(), b.next_u64());
    }

    // Distinct seeds should diverge immediately.
    #[test]
    fn different_seeds_diverge() {
        let mut a = SeedRng::new([1u8; 32]);
        let mut b = SeedRng::new([2u8; 32]);
        assert_ne!(a.next_u64(), b.next_u64());
    }

    // Reseeding must rewind to the start of the new seed's stream.
    #[test]
    fn reseed_restarts_stream() {
        let mut fresh = SeedRng::new([3u8; 32]);
        let mut reused = SeedRng::new([7u8; 32]);
        reused.next_u64();
        reused.next_u64();
        reused.reseed([3u8; 32]);
        assert_eq!(reused.seed_bytes(), [3u8; 32]);
        assert_eq!(fresh.next_u64(), reused.next_u64());
        assert_eq!(fresh.next_u64(), reused.next_u64());
    }

    // Results stay in the half-open interval; degenerate ranges return `low`
    // without advancing the stream.
    #[test]
    fn range_stays_within_bounds() {
        let mut rng = SeedRng::new([4u8; 32]);
        for _ in 0..1000 {
            let v = rng.range(10, 17);
            assert!((10..17).contains(&v));
        }

        let mut a = SeedRng::new([6u8; 32]);
        let mut b = SeedRng::new([6u8; 32]);
        assert_eq!(a.range(5, 5), 5);
        assert_eq!(a.range(9, 2), 9);
        assert_eq!(a.next_u64(), b.next_u64());
    }

    // Empty slices yield `None`; non-empty slices yield one of their elements.
    #[test]
    fn choose_handles_empty_and_nonempty() {
        let mut rng = SeedRng::new([5u8; 32]);
        let empty: [u32; 0] = [];
        assert_eq!(rng.choose(&empty), None);

        let items = [1u32, 2, 3];
        for _ in 0..100 {
            let picked = rng.choose(&items).expect("non-empty slice");
            assert!(items.contains(picked));
        }
    }
}