use moonpool_core::RandomProvider;
use rand::distr::{Distribution, StandardUniform, uniform::SampleUniform};
use std::ops::Range;
use crate::sim::rng::{set_sim_seed, sim_random, sim_random_range};
/// Random provider used inside the simulation.
///
/// The struct carries no data of its own: every method forwards to the
/// free functions in `crate::sim::rng`, so the actual generator state lives
/// there rather than in this value.
///
/// The single field is private, which means code outside this module can
/// only obtain an instance through [`SimRandomProvider::new`].
#[derive(Debug, Clone)]
pub struct SimRandomProvider {
    /// Zero-sized private marker; it exists only to keep the struct from
    /// being built with a literal outside this module.
    _marker: std::marker::PhantomData<()>,
}
impl SimRandomProvider {
    /// Creates a provider and seeds the simulation RNG with `seed`.
    ///
    /// Note that this has a side effect beyond returning a value: it calls
    /// `set_sim_seed(seed)`, so constructing a second provider re-seeds the
    /// generator that existing providers also draw from.
    pub fn new(seed: u64) -> Self {
        let provider = Self {
            _marker: std::marker::PhantomData,
        };
        // Seeding is independent of building the (stateless) value above.
        set_sim_seed(seed);
        provider
    }
}
/// Forwards every [`RandomProvider`] request to the simulation RNG helpers
/// (`sim_random` / `sim_random_range`) imported from `crate::sim::rng`.
impl RandomProvider for SimRandomProvider {
    /// Draws one value of type `T` from the simulation RNG.
    fn random<T>(&self) -> T
    where
        StandardUniform: Distribution<T>,
    {
        sim_random::<T>()
    }

    /// Draws one value from the half-open `range` using the simulation RNG.
    ///
    /// Behavior for an empty range is decided by `sim_random_range` and is
    /// not visible from this file.
    fn random_range<T>(&self, range: Range<T>) -> T
    where
        T: SampleUniform + PartialOrd,
    {
        sim_random_range(range)
    }

    /// Draws a single `f64` sample from the simulation RNG.
    ///
    /// The tests in this file expect the result to lie in `[0.0, 1.0)`.
    fn random_ratio(&self) -> f64 {
        self.random::<f64>()
    }

    /// Returns `true` when a fresh `f64` sample is strictly below
    /// `probability`.
    ///
    /// # Panics
    ///
    /// Only in builds with debug assertions enabled: panics if `probability`
    /// is outside `0.0..=1.0` (NaN included). Without debug assertions the
    /// check is compiled out and the comparison is performed as-is.
    fn random_bool(&self, probability: f64) -> bool {
        // Validate before sampling so a rejected call never consumes a draw.
        debug_assert!(
            (0.0..=1.0).contains(&probability),
            "Probability must be between 0.0 and 1.0, got {}",
            probability
        );
        let roll = self.random_ratio();
        roll < probability
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two providers created with the same seed must yield the same sequence.
    #[test]
    fn test_deterministic_randomness() {
        let provider1 = SimRandomProvider::new(42);
        let value1_1: f64 = provider1.random();
        let value1_2: u32 = provider1.random();

        // Creating the second provider re-seeds the RNG, restarting the sequence.
        let provider2 = SimRandomProvider::new(42);
        let value2_1: f64 = provider2.random();
        let value2_2: u32 = provider2.random();

        assert_eq!(value1_1, value2_1);
        assert_eq!(value1_2, value2_2);
    }

    /// `random_range` stays inside the half-open range for ints and floats.
    #[test]
    fn test_random_range() {
        let provider = SimRandomProvider::new(123);

        for _ in 0..100 {
            let value = provider.random_range(10..20);
            assert!(value >= 10);
            assert!(value < 20);
        }

        for _ in 0..100 {
            let value = provider.random_range(0.0..1.0);
            assert!(value >= 0.0);
            assert!(value < 1.0);
        }
    }

    /// `random_ratio` yields values in `[0.0, 1.0)`.
    #[test]
    fn test_random_ratio() {
        let provider = SimRandomProvider::new(456);

        for _ in 0..100 {
            let ratio = provider.random_ratio();
            assert!(ratio >= 0.0);
            assert!(ratio < 1.0);
        }
    }

    /// `random_bool` is never true at 0.0, always true at 1.0, and roughly
    /// balanced at 0.5.
    #[test]
    fn test_random_bool() {
        let provider = SimRandomProvider::new(789);

        for _ in 0..10 {
            assert!(!provider.random_bool(0.0));
        }

        for _ in 0..10 {
            assert!(provider.random_bool(1.0));
        }

        // Count directly from the iterator; no intermediate Vec is needed.
        let true_count = (0..100)
            .map(|_| provider.random_bool(0.5))
            .filter(|&hit| hit)
            .count();
        assert!(
            true_count > 30 && true_count < 70,
            "Got {} true values out of 100",
            true_count
        );
    }

    /// An out-of-range probability trips the range check in `random_bool`.
    ///
    /// That check is a `debug_assert!`, which is compiled out when debug
    /// assertions are disabled (e.g. `cargo test --release`). In such builds
    /// the call returns normally and a `should_panic` test would fail, so the
    /// test is ignored there.
    #[test]
    #[cfg_attr(
        not(debug_assertions),
        ignore = "range check is a debug_assert! and is compiled out without debug assertions"
    )]
    #[should_panic(expected = "Probability must be between 0.0 and 1.0")]
    fn test_random_bool_invalid_probability() {
        let provider = SimRandomProvider::new(999);
        let _ = provider.random_bool(1.5);
    }
}