use crate::FloatScalar;
/// Pseudo-random number generator backed by a 256-bit xoshiro256++ state.
///
/// Cloning copies the state, so a clone replays exactly the same sequence
/// as the original from the point where it was cloned.
#[derive(Clone, Debug)]
pub struct Rng {
    /// The four 64-bit state words. An all-zero state is a fixed point of
    /// the update step (it would emit zeros forever), so it must be avoided.
    s: [u64; 4],
}
/// Advances a SplitMix64 counter in place and returns its scrambled output.
///
/// `state` is bumped by the SplitMix64 increment (2^64 divided by the golden
/// ratio). The new counter value then goes through two xor-shift/multiply
/// rounds and a final xor-shift. All arithmetic wraps modulo 2^64.
fn splitmix64(state: &mut u64) -> u64 {
    const INCREMENT: u64 = 0x9e3779b97f4a7c15;
    const MUL_A: u64 = 0xbf58476d1ce4e5b9;
    const MUL_B: u64 = 0x94d049bb133111eb;

    let counter = state.wrapping_add(INCREMENT);
    *state = counter;

    let mixed = (counter ^ (counter >> 30)).wrapping_mul(MUL_A);
    let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MUL_B);
    mixed ^ (mixed >> 31)
}
impl Rng {
    /// Creates a generator seeded from a single 64-bit value.
    ///
    /// The four state words are four successive `splitmix64` outputs started
    /// from `seed`. They come from distinct counter values through a
    /// bijective mix, so they are distinct and the state is never all zero.
    pub fn new(seed: u64) -> Self {
        let mut sm = seed;
        let s0 = splitmix64(&mut sm);
        let s1 = splitmix64(&mut sm);
        let s2 = splitmix64(&mut sm);
        let s3 = splitmix64(&mut sm);
        Self { s: [s0, s1, s2, s3] }
    }

    /// Returns the next 64 pseudo-random bits (one xoshiro256++ step).
    pub fn next_u64(&mut self) -> u64 {
        // The output is derived from the state *before* it is advanced.
        let result = (self.s[0].wrapping_add(self.s[3]))
            .rotate_left(23)
            .wrapping_add(self.s[0]);
        let t = self.s[1] << 17;
        self.s[2] ^= self.s[0];
        self.s[3] ^= self.s[1];
        self.s[1] ^= self.s[2];
        self.s[0] ^= self.s[3];
        self.s[2] ^= t;
        self.s[3] = self.s[3].rotate_left(45);
        result
    }

    /// Returns a uniform `f64` in `[0, 1)`.
    ///
    /// Built from the top 53 bits of [`Rng::next_u64`], so every result is an
    /// exact multiple of 2^-53 and `1.0` is never produced.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 * (1.0_f64 / (1u64 << 53) as f64)
    }

    /// Returns a uniform `f32` in `[0, 1)`.
    ///
    /// Built from the top 24 bits of [`Rng::next_u64`], so every result is an
    /// exact multiple of 2^-24 and `1.0` is never produced.
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 * (1.0_f32 / (1u64 << 24) as f32)
    }

    /// Draws a standard normal `f64` (mean 0, variance 1) via Box–Muller.
    ///
    /// Only the cosine half of each Box–Muller pair is returned; the sine
    /// partner is discarded, so every sample consumes two uniforms.
    pub(crate) fn next_normal_f64(&mut self) -> f64 {
        loop {
            let u1 = self.next_f64();
            let u2 = self.next_f64();
            // u1 == 0 would make ln(u1) = -inf; redraw the pair instead.
            if u1 > 0.0 {
                let r = (-2.0 * u1.ln()).sqrt();
                let theta = 2.0 * core::f64::consts::PI * u2;
                return r * theta.cos();
            }
        }
    }

    /// Returns a uniform sample in `[0, 1)` converted to `T`.
    ///
    /// The sample is drawn as `f64` and converted. When `T` has less
    /// precision than `f64` (e.g. `f32`), values just below 1.0 can round up
    /// to exactly `1.0`. Such draws are rejected and redrawn so the half-open
    /// interval is preserved. For `T = f64` no draw is ever rejected.
    pub(crate) fn next_float<T: FloatScalar>(&mut self) -> T {
        let one = T::one();
        loop {
            let u: T = T::from(self.next_f64()).unwrap();
            if u < one {
                return u;
            }
        }
    }

    /// Draws a standard normal sample converted to `T`.
    pub(crate) fn next_normal<T: FloatScalar>(&mut self) -> T {
        T::from(self.next_normal_f64()).unwrap()
    }

    /// Draws a Gamma(`alpha`, 1) sample (shape `alpha`, unit scale).
    ///
    /// For `alpha >= 1` this runs the Marsaglia–Tsang sampler directly. For
    /// `alpha < 1` it applies the standard boost: draw Gamma(`alpha + 1`) and
    /// multiply by `U^(1/alpha)`, where `U` is uniform in `[0, 1)`. The boost
    /// is applied exactly once, so no recursion is involved.
    ///
    /// A negative `alpha` is not a valid shape and yields NaN. A NaN `alpha`
    /// propagates to a NaN result.
    pub(crate) fn next_gamma<T: FloatScalar>(&mut self, alpha: T) -> T {
        if alpha < T::zero() {
            // Invalid shape: report NaN instead of looping on alpha + 1,
            // which never reaches 1 when alpha + 1 rounds back to alpha.
            return T::from(f64::NAN).unwrap();
        }
        let one = T::one();
        if alpha < one {
            // alpha + 1 >= 1 here, so the core sampler's precondition holds.
            let g = self.gamma_shape_at_least_one(alpha + one);
            let u: T = self.next_float();
            return g * u.powf(one / alpha);
        }
        self.gamma_shape_at_least_one(alpha)
    }

    /// Marsaglia–Tsang squeeze/rejection sampler for Gamma(`alpha`, 1).
    ///
    /// Intended for `alpha >= 1`; callers handle smaller shapes via the boost
    /// in [`Rng::next_gamma`].
    fn gamma_shape_at_least_one<T: FloatScalar>(&mut self, alpha: T) -> T {
        let one = T::one();
        let d = alpha - one / T::from(3.0).unwrap();
        let c = one / (T::from(9.0).unwrap() * d).sqrt();
        // Loop-invariant constants, converted once.
        let squeeze = T::from(0.0331).unwrap();
        let half = T::from(0.5).unwrap();
        loop {
            let x: T = self.next_normal();
            let v = one + c * x;
            // The transformed candidate must be positive before cubing.
            if v <= T::zero() {
                continue;
            }
            let v = v * v * v;
            let u: T = self.next_float();
            let x2 = x * x;
            // Cheap squeeze test accepts most candidates without logarithms.
            if u < one - squeeze * x2 * x2 {
                return d * v;
            }
            // Exact acceptance test.
            if u.ln() < half * x2 + d * (one - v + v.ln()) {
                return d * v;
            }
        }
    }
}