use crate::backend::RandomBackend;
/// A random-number generator front-end that delegates all randomness to a
/// pluggable [`RandomBackend`] implementation.
pub struct Rng<B: RandomBackend> {
// The underlying randomness source; every sampling method draws from it.
backend: B,
}
impl<B: RandomBackend> Rng<B> {
/// Creates a new generator driven by `backend`.
#[inline]
pub fn new(backend: B) -> Self {
Self { backend }
}
/// Returns the next raw `u64` sample from the backend.
#[inline]
#[must_use]
pub fn next_u64(&mut self) -> u64 {
self.backend.next_u64()
}
/// Returns the backend's next `f64` sample.
// NOTE(review): callers (and the unit tests) assume the backend yields
// values in [0, 1) — confirm against the `RandomBackend` contract.
#[inline]
#[must_use]
pub fn next_f64(&mut self) -> f64 {
self.backend.next_f64()
}
/// Returns the next `u32` sample from the backend.
#[inline]
#[must_use]
pub fn next_u32(&mut self) -> u32 {
self.backend.next_u32()
}
/// Returns an `f32` uniformly distributed in the half-open interval `[0, 1)`.
#[inline]
#[must_use]
pub fn next_f32(&mut self) -> f32 {
// Keep the 24 most significant bits of a u64 sample (an f32 mantissa
// holds 24 bits), then scale by 2^-24 to land in [0, 1).
let bits = (self.backend.next_u64() >> 40) as u32;
(bits as f32) * (1.0 / ((1u32 << 24) as f32))
}
/// Returns a uniformly distributed boolean.
#[inline]
#[must_use]
pub fn next_bool(&mut self) -> bool {
(self.backend.next_u64() & 1) != 0
}
/// Returns a uniformly distributed `u64` in the half-open range `[min, max)`.
///
/// Uses rejection sampling to avoid modulo bias.
///
/// # Errors
///
/// Returns [`crate::AporiaError::InvalidRangeU64`] when `min >= max`.
#[inline]
pub fn gen_range(&mut self, min: u64, max: u64) -> core::result::Result<u64, crate::AporiaError> {
if min >= max {
return Err(crate::AporiaError::InvalidRangeU64 { min, max });
}
let range = max - min;
// `zone` is a multiple of `range`; samples in [0, zone) map to the
// target range without bias, anything above is rejected and redrawn.
let zone = u64::MAX - (u64::MAX % range);
loop {
let v = self.next_u64();
if v < zone {
break Ok(min + (v % range));
}
}
}
/// Returns a uniformly distributed `f64` in the half-open range `[min, max)`.
///
/// # Errors
///
/// Returns [`crate::AporiaError::InvalidRangeF64`] when `min >= max`.
#[inline]
pub fn gen_range_f64(&mut self, min: f64, max: f64) -> core::result::Result<f64, crate::AporiaError> {
if min >= max {
return Err(crate::AporiaError::InvalidRangeF64 { min, max });
}
let rand = self.next_f64();
let value = min + (rand * (max - min));
// Floating-point rounding can push `min + rand * (max - min)` up to
// exactly `max` even when `rand < 1.0`; fold that boundary case back
// to `min` so the documented `[min, max)` contract always holds.
Ok(if value < max { value } else { min })
}
/// Fills `buf` entirely with random bytes from the backend.
#[inline]
pub fn fill_bytes(&mut self, buf: &mut [u8]) {
self.backend.fill_bytes(buf)
}
}
impl<B> Clone for Rng<B>
where
B: RandomBackend + Clone,
{
fn clone(&self) -> Self {
Self {
backend: self.backend.clone(),
}
}
}
/// Debug-formats as `Rng { backend: ... }`, deferring to the backend's
/// own `Debug` implementation for the inner field.
impl<B> core::fmt::Debug for Rng<B>
where
B: RandomBackend + core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let mut builder = f.debug_struct("Rng");
builder.field("backend", &self.backend);
builder.finish()
}
}
/// An unbounded iterator of `u64` samples drawn from a mutably
/// borrowed [`Rng`]; see [`Rng::iter_u64`].
#[derive(Debug)]
pub struct U64Iter<'a, B: RandomBackend> {
// Exclusive borrow of the generator for the iterator's lifetime.
rng: &'a mut Rng<B>,
}
impl<B: RandomBackend> Iterator for U64Iter<'_, B> {
type Item = u64;
/// Always yields another sample; this iterator never terminates.
fn next(&mut self) -> Option<Self::Item> {
Some(self.rng.next_u64())
}
/// The stream is unbounded, so advertise the maximal lower bound and no
/// upper bound, matching std's infinite iterators (e.g. `core::iter::Repeat`).
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::MAX, None)
}
}
/// An unbounded iterator of `f64` samples drawn from a mutably
/// borrowed [`Rng`]; see [`Rng::iter_f64`].
#[derive(Debug)]
pub struct F64Iter<'a, B: RandomBackend> {
// Exclusive borrow of the generator for the iterator's lifetime.
rng: &'a mut Rng<B>,
}
impl<B: RandomBackend> Iterator for F64Iter<'_, B> {
type Item = f64;
/// Always yields another sample; this iterator never terminates.
fn next(&mut self) -> Option<Self::Item> {
Some(self.rng.next_f64())
}
/// The stream is unbounded, so advertise the maximal lower bound and no
/// upper bound, matching std's infinite iterators (e.g. `core::iter::Repeat`).
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::MAX, None)
}
}
impl<B: RandomBackend> Rng<B> {
/// Returns an unbounded iterator of `u64` samples that mutably
/// borrows this generator until the iterator is dropped.
// `#[must_use]`: constructing an iterator and discarding it is always a bug.
#[inline]
#[must_use]
pub fn iter_u64(&mut self) -> U64Iter<'_, B> {
U64Iter { rng: self }
}
/// Returns an unbounded iterator of `f64` samples that mutably
/// borrows this generator until the iterator is dropped.
#[inline]
#[must_use]
pub fn iter_f64(&mut self) -> F64Iter<'_, B> {
F64Iter { rng: self }
}
}
impl<'a, B: RandomBackend> IntoIterator for &'a mut Rng<B> {
type Item = u64;
type IntoIter = U64Iter<'a, B>;
fn into_iter(self) -> Self::IntoIter {
self.iter_u64()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::{XorShift, SplitMix64};
/// `next_f64` must stay inside the half-open unit interval.
#[test]
fn next_f64_in_unit_interval() {
let mut rng = Rng::new(XorShift::new(1));
for _ in 0..1000 {
let sample = rng.next_f64();
assert!((0.0..1.0).contains(&sample));
}
}
/// Integer ranges are half-open: `[min, max)`.
#[test]
fn gen_range_integral_bounds() {
let mut rng = Rng::new(SplitMix64::new(123));
for _ in 0..1000 {
let sample = rng.gen_range(10, 20).unwrap();
assert!((10..20).contains(&sample));
}
}
/// Float ranges are half-open: `[min, max)`.
#[test]
fn gen_range_f64_bounds() {
let mut rng = Rng::new(SplitMix64::new(123));
for _ in 0..1000 {
let sample = rng.gen_range_f64(0.5, 1.5).unwrap();
assert!((0.5..1.5).contains(&sample));
}
}
/// Smoke test for the narrower primitive accessors.
#[test]
fn next_u32_and_f32_and_bool_work() {
let mut rng = Rng::new(XorShift::new(7));
let _word = rng.next_u32();
let unit = rng.next_f32();
assert!((0.0..1.0).contains(&unit));
let _flag = rng.next_bool();
}
/// The iterator adapters and `IntoIterator` all draw from the generator.
#[test]
fn iter_helpers_and_into_iter() {
let mut rng = Rng::new(XorShift::new(1));
let mut stream = rng.iter_u64();
let first = stream.next().unwrap();
let second = stream.next().unwrap();
assert_ne!(first, second);
let unit = rng.iter_f64().next().unwrap();
assert!((0.0..1.0).contains(&unit));
let mut rng2 = Rng::new(XorShift::new(2));
let mut seen = 0usize;
for _ in &mut rng2 {
seen += 1;
if seen > 4 {
break;
}
}
assert!(seen > 0);
}
}