use rand::SeedableRng;
use rand::rngs::StdRng;
use std::sync::{Arc, Mutex};
/// Factory for random number generators.
///
/// Abstracting construction behind a trait lets callers swap the
/// entropy-backed [`DefaultRngProvider`] for a deterministic provider such as
/// [`TestRngProvider`]. The `Sync + Send` supertraits allow a provider to be
/// shared between threads (the global manager in this module stores one in a
/// `static`).
pub trait RngProvider: Sync + Send {
    /// Concrete generator type handed out by this provider.
    type Rng: Send + Clone + rand::RngCore;

    /// Builds a generator.
    ///
    /// `Some(seed)` asks for a generator derived from `seed`. `None` leaves the
    /// choice of initial state to the provider.
    fn create_rng(&self, seed: Option<u64>) -> Self::Rng;

    /// Builds a generator without an explicit seed.
    ///
    /// Equivalent to `self.create_rng(None)`.
    fn create_random_rng(&self) -> Self::Rng {
        let unseeded: Option<u64> = None;
        self.create_rng(unseeded)
    }
}
/// Production [`RngProvider`] backed by [`StdRng`].
///
/// An explicit seed yields a reproducible generator via
/// `SeedableRng::seed_from_u64`. Without one, the generator is seeded with
/// `SeedableRng::from_entropy`.
#[derive(Debug, Clone)]
pub struct DefaultRngProvider;

impl RngProvider for DefaultRngProvider {
    type Rng = StdRng;

    fn create_rng(&self, seed: Option<u64>) -> Self::Rng {
        // `None` -> entropy-seeded generator; `Some(s)` -> deterministic one.
        seed.map_or_else(StdRng::from_entropy, StdRng::seed_from_u64)
    }
}
#[derive(Debug)]
pub struct RngManager<P: RngProvider> {
provider: P,
seed: Option<u64>,
thread_rng: Arc<Mutex<Option<P::Rng>>>,
}
impl<P: RngProvider> RngManager<P> {
pub fn new(provider: P) -> Self {
Self {
provider,
seed: None,
thread_rng: Arc::new(Mutex::new(None)),
}
}
pub fn with_seed(provider: P, seed: u64) -> Self {
Self {
provider,
seed: Some(seed),
thread_rng: Arc::new(Mutex::new(None)),
}
}
pub fn get_rng(&self) -> P::Rng {
self.provider.create_rng(self.seed)
}
pub fn create_seeded_rng(&self, seed: u64) -> P::Rng {
self.provider.create_rng(Some(seed))
}
pub fn seed(&self) -> Option<u64> {
self.seed
}
pub fn set_seed(&mut self, seed: Option<u64>) {
self.seed = seed;
if let Ok(mut cached) = self.thread_rng.lock() {
*cached = None;
}
}
}
impl<P: RngProvider + Clone> Clone for RngManager<P> {
fn clone(&self) -> Self {
Self {
provider: self.provider.clone(),
seed: self.seed,
thread_rng: Arc::new(Mutex::new(None)),
}
}
}
use std::sync::OnceLock;
static GLOBAL_RNG_MANAGER: OnceLock<RngManager<DefaultRngProvider>> = OnceLock::new();
pub fn global_rng_manager() -> &'static RngManager<DefaultRngProvider> {
GLOBAL_RNG_MANAGER.get_or_init(|| RngManager::new(DefaultRngProvider))
}
pub fn set_global_seed(_seed: u64) {
}
pub fn create_rng() -> StdRng {
global_rng_manager().get_rng()
}
pub fn create_seeded_rng(seed: u64) -> StdRng {
global_rng_manager().create_seeded_rng(seed)
}
/// Fully deterministic [`RngProvider`] for tests.
///
/// Every generator is a [`StdRng`] seeded via `seed_from_u64`. When no explicit
/// seed is given, `base_seed` is used, so unseeded requests still reproduce.
#[derive(Debug, Clone)]
pub struct TestRngProvider {
    /// Fallback seed used when `create_rng` receives `None`.
    base_seed: u64,
}

impl TestRngProvider {
    /// Creates a provider whose unseeded generators all start from `base_seed`.
    pub fn new(base_seed: u64) -> Self {
        TestRngProvider { base_seed }
    }
}

impl RngProvider for TestRngProvider {
    type Rng = StdRng;

    fn create_rng(&self, seed: Option<u64>) -> Self::Rng {
        StdRng::seed_from_u64(seed.unwrap_or(self.base_seed))
    }
}
/// Thin newtype around a generator (or any value) `R`.
///
/// The wrapper owns its inner value. The accessors below expose it by shared
/// reference, by mutable reference, or by value.
#[derive(Debug)]
pub struct RngWrapper<R> {
    inner: R,
}

impl<R> RngWrapper<R> {
    /// Wraps `rng`.
    pub fn new(rng: R) -> Self {
        Self { inner: rng }
    }

    /// Mutable access to the wrapped value; changes made through it are
    /// visible to the wrapper.
    pub fn inner_mut(&mut self) -> &mut R {
        &mut self.inner
    }

    /// Shared access to the wrapped value.
    pub fn inner(&self) -> &R {
        &self.inner
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> R {
        self.inner
    }
}
/// Forwards every `RngCore` method to the wrapped generator unchanged, so a
/// wrapper produces exactly the stream its inner generator would.
impl<R: rand::RngCore> rand::RngCore for RngWrapper<R> {
    fn next_u32(&mut self) -> u32 {
        R::next_u32(&mut self.inner)
    }

    fn next_u64(&mut self) -> u64 {
        R::next_u64(&mut self.inner)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        R::fill_bytes(&mut self.inner, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        R::try_fill_bytes(&mut self.inner, dest)
    }
}
impl<R: rand::RngCore + Clone> Clone for RngWrapper<R> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, RngCore};

    /// Seeded generators from the default provider are reproducible and match
    /// a directly seeded `StdRng`. Unseeded ones are only smoke-tested because
    /// their output depends on entropy.
    #[test]
    fn test_default_rng_provider() {
        let provider = DefaultRngProvider;
        let mut rng1 = provider.create_rng(None);
        let mut rng2 = provider.create_rng(None);
        let _val1: u32 = rng1.r#gen();
        let _val2: u32 = rng2.r#gen();
        let mut rng3 = provider.create_rng(Some(12345));
        let mut rng4 = provider.create_rng(Some(12345));
        let val3: u32 = rng3.r#gen();
        let val4: u32 = rng4.r#gen();
        assert_eq!(val3, val4);
        let expected: u32 = StdRng::seed_from_u64(12345).r#gen();
        assert_eq!(val3, expected);
    }

    #[test]
    fn test_rng_manager() {
        let provider = DefaultRngProvider;
        let manager = RngManager::new(provider);
        let mut rng = manager.get_rng();
        let _value: u32 = rng.r#gen();
        let mut seeded_rng = manager.create_seeded_rng(42);
        let seeded_value: u32 = seeded_rng.r#gen();
        let mut seeded_rng2 = manager.create_seeded_rng(42);
        let seeded_value2: u32 = seeded_rng2.r#gen();
        assert_eq!(seeded_value, seeded_value2);
    }

    #[test]
    fn test_rng_manager_with_seed() {
        let provider = DefaultRngProvider;
        let manager = RngManager::with_seed(provider, 999);
        assert_eq!(manager.seed(), Some(999));
        let mut rng1 = manager.get_rng();
        let mut rng2 = manager.get_rng();
        let val1: u32 = rng1.r#gen();
        let val2: u32 = rng2.r#gen();
        assert_eq!(val1, val2);
    }

    /// Unseeded requests fall back to the base seed, through both
    /// `create_rng(None)` and the trait's `create_random_rng` default.
    #[test]
    fn test_test_rng_provider() {
        let provider = TestRngProvider::new(777);
        let mut rng1 = provider.create_rng(None);
        let mut rng2 = provider.create_rng(None);
        let val1: u32 = rng1.r#gen();
        let val2: u32 = rng2.r#gen();
        assert_eq!(val1, val2);
        let mut fallback = provider.create_random_rng();
        let fallback_val: u32 = fallback.r#gen();
        assert_eq!(fallback_val, val1);
        let mut rng3 = provider.create_rng(Some(888));
        let val3: u32 = rng3.r#gen();
        assert_ne!(val1, val3);
    }

    /// The wrapper must forward every `RngCore` method, so it has to replay the
    /// exact stream of an identically seeded bare generator.
    #[test]
    fn test_rng_wrapper() {
        let mut reference = StdRng::seed_from_u64(12345);
        let mut wrapper = RngWrapper::new(StdRng::seed_from_u64(12345));

        assert_eq!(wrapper.next_u32(), reference.next_u32());
        assert_eq!(wrapper.next_u64(), reference.next_u64());

        let mut wrapped_bytes = [0u8; 10];
        let mut reference_bytes = [0u8; 10];
        wrapper.fill_bytes(&mut wrapped_bytes);
        reference.fill_bytes(&mut reference_bytes);
        assert_eq!(wrapped_bytes, reference_bytes);

        wrapper
            .try_fill_bytes(&mut wrapped_bytes)
            .expect("seeded StdRng should not fail to fill bytes");
        reference
            .try_fill_bytes(&mut reference_bytes)
            .expect("seeded StdRng should not fail to fill bytes");
        assert_eq!(wrapped_bytes, reference_bytes);

        // `inner()` exposes the current state without advancing it.
        let mut snapshot = wrapper.inner().clone();
        let mut reference_snapshot = reference.clone();
        assert_eq!(snapshot.next_u64(), reference_snapshot.next_u64());

        // Drawing through `inner_mut()` advances the wrapper's own state.
        assert_eq!(wrapper.inner_mut().next_u32(), reference.next_u32());
        assert_eq!(wrapper.next_u32(), reference.next_u32());

        // A cloned wrapper continues from the same state as the original.
        let mut cloned = wrapper.clone();
        assert_eq!(cloned.next_u64(), wrapper.next_u64());
    }

    #[test]
    fn test_global_rng_functions() {
        let mut rng = create_rng();
        let _value: u32 = rng.r#gen();
        let mut seeded_rng1 = create_seeded_rng(555);
        let mut seeded_rng2 = create_seeded_rng(555);
        let val1: u32 = seeded_rng1.r#gen();
        let val2: u32 = seeded_rng2.r#gen();
        assert_eq!(val1, val2);
    }

    #[test]
    fn test_rng_manager_clone() {
        let provider = DefaultRngProvider;
        let manager1 = RngManager::with_seed(provider, 123);
        let manager2 = manager1.clone();
        assert_eq!(manager1.seed(), manager2.seed());
        let mut rng1 = manager1.get_rng();
        let mut rng2 = manager2.get_rng();
        let val1: u32 = rng1.r#gen();
        let val2: u32 = rng2.r#gen();
        assert_eq!(val1, val2);
    }

    /// Setting a seed makes `get_rng` reproducible; clearing it restores `None`.
    #[test]
    fn test_rng_manager_set_seed() {
        let provider = DefaultRngProvider;
        let mut manager = RngManager::new(provider);
        assert_eq!(manager.seed(), None);
        manager.set_seed(Some(456));
        assert_eq!(manager.seed(), Some(456));
        let first: u64 = manager.get_rng().r#gen();
        let second: u64 = manager.get_rng().r#gen();
        assert_eq!(first, second);
        manager.set_seed(None);
        assert_eq!(manager.seed(), None);
    }
}