/// Deterministic pseudo-random generator built on the SplitMix64 mixing step.
///
/// The single `u64` field is the running state. Two generators created from
/// the same seed produce the same sequence of outputs from every method here.
#[derive(Debug, Clone)]
pub struct SplitMix64(u64);

impl SplitMix64 {
    /// Creates a generator whose initial state is exactly `seed`.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self(seed)
    }

    /// Advances the state by a fixed increment and returns the mixed result.
    ///
    /// All arithmetic wraps, so this never panics.
    pub fn next_u64(&mut self) -> u64 {
        const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
        const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
        const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

        self.0 = self.0.wrapping_add(INCREMENT);
        let raw = self.0;
        let first_pass = (raw ^ (raw >> 30)).wrapping_mul(MULTIPLIER_A);
        let second_pass = (first_pass ^ (first_pass >> 27)).wrapping_mul(MULTIPLIER_B);
        second_pass ^ (second_pass >> 31)
    }

    /// Returns an index in `0..bound`, consuming one 64-bit output.
    ///
    /// The output is reduced with a plain modulo, so when `bound` does not
    /// divide 2^64 the smallest indices are very slightly more likely.
    ///
    /// # Panics
    ///
    /// Panics if `bound` is zero; the state is not advanced in that case.
    // The remainder is strictly below `bound`, which is itself a `usize`, so
    // narrowing back from `u64` cannot lose bits.
    #[allow(clippy::cast_possible_truncation)]
    pub fn gen_index(&mut self, bound: usize) -> usize {
        assert!(bound > 0, "gen_index bound must be positive");
        let modulus = bound as u64;
        (self.next_u64() % modulus) as usize
    }

    /// Returns a uniform `f64` in `[0, 1)`, consuming one 64-bit output.
    ///
    /// The top 53 bits of the output are kept and scaled by 2^-53, so every
    /// result is an exact multiple of 2^-53.
    pub fn gen_unit(&mut self) -> f64 {
        // Reciprocal of 2^53; a power of two, so the division is exact.
        const SCALE: f64 = 1.0_f64 / 9_007_199_254_740_992.0_f64;

        let mantissa = self.next_u64() >> 11;
        // At most 53 significant bits remain, which an f64 represents exactly.
        #[allow(clippy::cast_precision_loss)]
        let as_float = mantissa as f64;
        as_float * SCALE
    }

    /// Samples a geometric variate (failures before the first success) by
    /// inversion, returned as an `f64`.
    ///
    /// * `p >= 1.0` returns `0.0` without consuming any randomness.
    /// * Otherwise one uniform draw `u` is consumed; if `1.0 - u` is exactly
    ///   `1.0` the result is `0.0`.
    /// * If `ln(1.0 - p)` evaluates to zero (for example when `p` is so small
    ///   that `1.0 - p` rounds to `1.0`) the result is `f64::INFINITY`.
    /// * In all other cases the result is `floor(ln(1 - u) / ln(1 - p))`.
    ///
    /// `p` is expected to lie in `(0, 1]`; this is checked with
    /// `debug_assert!` only, so release builds do not validate it.
    pub fn gen_geom(&mut self, p: f64) -> f64 {
        debug_assert!(p > 0.0 && p <= 1.0, "gen_geom requires p in (0, 1]");
        if p >= 1.0 {
            return 0.0;
        }

        let complement = 1.0 - self.gen_unit();
        let log_failure = (1.0 - p).ln();

        // Both comparisons are deliberately exact: they pick out the precise
        // values for which the quotient below would be 0/x or x/0.
        #[allow(clippy::float_cmp)]
        let (draw_is_one, log_is_zero) = (complement == 1.0, log_failure == 0.0);

        if draw_is_one {
            0.0
        } else if log_is_zero {
            f64::INFINITY
        } else {
            (complement.ln() / log_failure).floor()
        }
    }

    /// Samples a standard normal variate with the Box-Muller transform.
    ///
    /// Two uniform draws are consumed per call and only the cosine branch is
    /// returned. The radius uses `1.0 - u1`, which lies in `(0, 1]`, so the
    /// logarithm never receives zero.
    pub fn gen_normal(&mut self) -> f64 {
        let radial_draw = self.gen_unit();
        let angular_draw = self.gen_unit();
        let radius = (-2.0 * (1.0 - radial_draw).ln()).sqrt();
        let angle = std::f64::consts::TAU * angular_draw;
        radius * angle.cos()
    }

    /// Samples a gamma variate with the given `shape` and unit scale using
    /// the Marsaglia-Tsang rejection method.
    ///
    /// For `shape < 1.0` a sample for `shape + 1.0` is drawn recursively and
    /// multiplied by `u^(1 / shape)` with a fresh uniform `u`.
    ///
    /// `shape` is expected to be positive; this is checked with
    /// `debug_assert!` only, so release builds do not validate it.
    pub fn gen_gamma(&mut self, shape: f64) -> f64 {
        debug_assert!(shape > 0.0, "gen_gamma requires shape > 0");
        if shape < 1.0 {
            let boosted = self.gen_gamma(shape + 1.0);
            let draw = self.gen_unit();
            return boosted * draw.powf(1.0 / shape);
        }

        let offset = shape - 1.0 / 3.0;
        let spread = 1.0 / (9.0 * offset).sqrt();
        loop {
            let normal = self.gen_normal();
            let base = 1.0 + spread * normal;
            if base <= 0.0 {
                // The cube below must be positive; draw a new normal instead.
                continue;
            }
            let cube = base * base * base;
            let draw = self.gen_unit();
            let normal_sq = normal * normal;

            // Cheap polynomial test first; the logarithmic test is only
            // evaluated when the cheap one does not accept.
            let squeeze_accepts = draw < 1.0 - 0.0331 * normal_sq * normal_sq;
            if squeeze_accepts
                || draw.ln() < 0.5 * normal_sq + offset * (1.0 - cube + cube.ln())
            {
                return offset * cube;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::SplitMix64;

    /// Generators built from the same seed must emit identical streams.
    #[test]
    fn determinism_same_seed_same_stream() {
        let mut first = SplitMix64::new(42);
        let mut second = SplitMix64::new(42);
        (0..1024).for_each(|_| assert_eq!(first.next_u64(), second.next_u64()));
    }

    /// Different seeds must already disagree on the very first output.
    #[test]
    fn different_seeds_differ() {
        let first_output = |seed: u64| SplitMix64::new(seed).next_u64();
        assert_ne!(first_output(1), first_output(2));
    }

    /// `gen_index` never returns a value at or above its bound.
    #[test]
    fn gen_index_in_range() {
        let mut rng = SplitMix64::new(0x1234);
        let mut remaining = 10_000;
        while remaining > 0 {
            let v = rng.gen_index(100);
            assert!(v < 100);
            remaining -= 1;
        }
    }

    /// `gen_unit` stays inside the half-open interval `[0, 1)`.
    #[test]
    fn gen_unit_in_unit_interval() {
        let mut rng = SplitMix64::new(99);
        let unit_interval = 0.0..1.0;
        for _ in 0..10_000 {
            let u = rng.gen_unit();
            assert!(unit_interval.contains(&u), "u={u}");
        }
    }

    /// With success probability 1 every geometric sample is exactly zero.
    // Exact float equality is the point of this test.
    #[allow(clippy::float_cmp)]
    #[test]
    fn gen_geom_p1_returns_zero() {
        let mut rng = SplitMix64::new(7);
        for _ in 0..32 {
            let v = rng.gen_geom(1.0);
            assert_eq!(v, 0.0);
        }
    }

    /// Empirical mean and (population) variance of `gen_normal` are close to
    /// 0 and 1 respectively over 100 000 samples.
    #[test]
    fn gen_normal_mean_and_variance() {
        let n: i32 = 100_000;
        let nf = f64::from(n);
        let mut rng = SplitMix64::new(42);

        let mut samples: Vec<f64> = Vec::new();
        for _ in 0..n {
            samples.push(rng.gen_normal());
        }

        let total = samples.iter().sum::<f64>();
        let mean = total / nf;

        let squared_deviations = samples.iter().map(|&x| (x - mean) * (x - mean));
        let variance = squared_deviations.sum::<f64>() / nf;

        assert!(mean.abs() < 0.02, "normal mean should be ~0, got {mean}");
        assert!(
            (variance - 1.0).abs() < 0.05,
            "normal variance should be ~1, got {variance}"
        );
    }

    /// The empirical mean of `gen_gamma(shape)` is within 5 % of `shape` for
    /// shapes both below and above 1.
    #[test]
    fn gen_gamma_mean_matches_shape() {
        let shapes = [0.5, 1.0, 2.0, 5.0];
        for shape in shapes.iter().copied() {
            let n: i32 = 50_000;
            // Each shape restarts from the same seed.
            let mut rng = SplitMix64::new(99);
            let sum: f64 = (0..n).map(|_| rng.gen_gamma(shape)).sum();
            let mean = sum / f64::from(n);
            let rel_err = (mean - shape).abs() / shape;
            assert!(
                rel_err < 0.05,
                "gamma({shape}) mean={mean}, expected ~{shape}, rel_err={rel_err}"
            );
        }
    }

    /// Samples drawn through the `shape < 1` path are strictly positive.
    #[test]
    fn gen_gamma_all_positive() {
        let mut rng = SplitMix64::new(777);
        for _ in 0..10_000 {
            let v = rng.gen_gamma(0.3);
            assert!(v > 0.0, "gamma sample must be positive, got {v}");
        }
    }

    /// The empirical mean of `gen_geom(p)` is within 3 % of `(1 - p) / p`.
    #[test]
    fn gen_geom_mean_matches_distribution() {
        let p = 0.1;
        let expected_mean = (1.0 - p) / p;
        let n: i32 = 100_000;

        let mut rng = SplitMix64::new(123_456);
        let sum: f64 = (0..n).map(|_| rng.gen_geom(p)).sum();
        let mean = sum / f64::from(n);

        let relative_error = (mean - expected_mean).abs() / expected_mean;
        assert!(
            relative_error < 0.03,
            "mean={mean}, expected≈{expected_mean}"
        );
    }
}