use std::cell::RefCell;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::time::SystemTime;
/// Number of 32-bit words in the MT19937 state vector (the `n` parameter).
const MERSENNE_STATE_N: usize = 624;
/// Distance to the word XORed into each regenerated state word (the `m` parameter).
const MERSENNE_STATE_M: usize = 397;
/// Twist-matrix coefficient XORed in when the low bit of the second twist input is set.
const MATRIX_A: u32 = 0x9908_B0DF;
/// Selects only the most significant bit of a state word.
const UMASK: u32 = 1 << 31;
/// Selects the 31 least significant bits of a state word.
const LMASK: u32 = !UMASK;

/// 32-bit Mersenne Twister (MT19937) engine.
///
/// Its `f32` uniform output is checked bit-for-bit against `torch.rand` values in this
/// file's tests.
#[derive(Clone)]
struct Mt19937 {
    /// Current state vector; regenerated in place once all of its words have been used.
    state: [u32; MERSENNE_STATE_N],
    /// Index of the next state word to temper and return.
    next: usize,
    /// Countdown of draws; when it reaches zero the state is regenerated.
    /// Starts at 1 so the very first draw triggers generation.
    left: i32,
    /// Seed exactly as supplied (only its low 32 bits entered `state`).
    seed: u64,
}

impl Mt19937 {
    /// Initializes the engine with the MT19937 linear recurrence
    /// `x[i] = 1812433253 * (x[i-1] ^ (x[i-1] >> 30)) + i` (mod 2^32),
    /// starting from the low 32 bits of `seed`.
    fn new(seed: u64) -> Self {
        let mut state = [0u32; MERSENNE_STATE_N];
        // Truncation to 32 bits is intentional: the recurrence works on 32-bit words.
        let mut word = seed as u32;
        state[0] = word;
        for (index, slot) in state.iter_mut().enumerate().skip(1) {
            word = 1_812_433_253u32
                .wrapping_mul(word ^ (word >> 30))
                .wrapping_add(index as u32);
            *slot = word;
        }
        Self {
            state,
            next: 0,
            left: 1,
            seed,
        }
    }

    /// Combines the top bit of `upper` with the low 31 bits of `lower`.
    fn mix_bits(upper: u32, lower: u32) -> u32 {
        (upper & UMASK) | (lower & LMASK)
    }

    /// One twist step: shifts the mixed word right and conditionally XORs in `MATRIX_A`
    /// depending on the low bit of `v`.
    fn twist(u: u32, v: u32) -> u32 {
        let matrix = if v & 1 == 0 { 0 } else { MATRIX_A };
        (Self::mix_bits(u, v) >> 1) ^ matrix
    }

    /// Regenerates all `MERSENNE_STATE_N` words in place and rewinds the read cursor.
    fn next_state(&mut self) {
        const N: usize = MERSENNE_STATE_N;
        const M: usize = MERSENNE_STATE_M;
        // Words are rewritten in ascending order, so the wrapped-around indices
        // (`(i + M) % N` once i >= N - M, and `(i + 1) % N == 0` at the end) read the
        // freshly regenerated words, exactly as the three-phase reference loop does.
        for i in 0..N {
            let successor = self.state[(i + 1) % N];
            let partner = self.state[(i + M) % N];
            self.state[i] = partner ^ Self::twist(self.state[i], successor);
        }
        self.left = N as i32;
        self.next = 0;
    }

    /// The standard MT19937 output tempering transform.
    fn temper(mut y: u32) -> u32 {
        y ^= y >> 11;
        y ^= (y << 7) & 0x9D2C_5680;
        y ^= (y << 15) & 0xEFC6_0000;
        y ^ (y >> 18)
    }

    /// Returns the next tempered 32-bit output, regenerating the state when exhausted.
    fn random_u32(&mut self) -> u32 {
        self.left -= 1;
        if self.left == 0 {
            self.next_state();
        }
        let word = self.state[self.next];
        self.next += 1;
        Self::temper(word)
    }

    /// Returns two consecutive 32-bit outputs joined as `(first << 32) | second`.
    fn random_u64(&mut self) -> u64 {
        let high = u64::from(self.random_u32());
        let low = u64::from(self.random_u32());
        (high << 32) | low
    }
}
/// Seedable MT19937-based random number generator.
///
/// Each Box–Muller normal draw produces two samples; the spare one is cached per precision
/// (`f32` and `f64` separately) and returned by the next normal draw of that precision.
#[derive(Clone)]
pub struct Generator {
    /// Underlying 32-bit Mersenne Twister engine.
    engine: Mt19937,
    /// Unused `f32` normal sample left over from the previous Box–Muller draw.
    next_float_normal: Option<f32>,
    /// Unused `f64` normal sample left over from the previous Box–Muller draw.
    next_double_normal: Option<f64>,
}

/// Prints the seed and whether each normal-sample cache is populated; the engine's
/// state vector is not printed.
impl std::fmt::Debug for Generator {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("Generator");
        out.field("seed", &self.engine.seed);
        out.field("has_cached_f32_normal", &self.next_float_normal.is_some());
        out.field("has_cached_f64_normal", &self.next_double_normal.is_some());
        out.finish()
    }
}
impl Generator {
    /// Creates a generator from `seed` with both normal-sample caches empty.
    ///
    /// [`Generator::seed`] reports the full 64-bit value, but the engine is initialized
    /// from its low 32 bits only, so seeds that differ only in their upper half produce
    /// identical streams.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            engine: Mt19937::new(seed),
            next_float_normal: None,
            next_double_normal: None,
        }
    }

    /// Creates a generator whose seed is derived from non-deterministic sources.
    ///
    /// The seed hashes the current wall-clock time, the current thread id and the process
    /// id with a hasher built from a randomly keyed
    /// [`RandomState`](std::collections::hash_map::RandomState). The random keys matter:
    /// `DefaultHasher::new()` always uses the same keys, and thread ids are only unique
    /// within one process, so without them two processes started within the clock's
    /// resolution could end up with the same seed.
    ///
    /// A zero hash is replaced by a fixed non-zero constant, so the seed is never 0.
    #[must_use]
    pub fn seed_from_entropy() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::BuildHasher;

        let mut hasher: DefaultHasher = RandomState::new().build_hasher();
        SystemTime::now().hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        std::process::id().hash(&mut hasher);
        let mut seed = hasher.finish();
        if seed == 0 {
            seed = 0xdead_beef_cafe;
        }
        Self::new(seed)
    }

    /// Reseeds the engine in place and discards any cached normal samples, so the
    /// generator then produces the same stream as `Generator::new(seed)`.
    pub fn manual_seed(&mut self, seed: u64) {
        self.engine = Mt19937::new(seed);
        self.next_float_normal = None;
        self.next_double_normal = None;
    }

    /// Returns the seed most recently passed to [`Generator::new`] or
    /// [`Generator::manual_seed`] (or chosen by [`Generator::seed_from_entropy`]).
    #[must_use]
    pub fn seed(&self) -> u64 {
        self.engine.seed
    }

    /// Returns the next raw 32-bit output of the engine.
    pub fn random_u32(&mut self) -> u32 {
        self.engine.random_u32()
    }

    /// Returns two consecutive 32-bit outputs combined as `(first << 32) | second`.
    pub fn random_u64(&mut self) -> u64 {
        self.engine.random_u64()
    }

    /// Returns a uniform `f32` in `[0, 1)`.
    ///
    /// Uses the low 24 bits (the `f32` mantissa precision) of one 32-bit draw, so every
    /// result is an exact multiple of `2^-24`.
    pub fn next_uniform_f32(&mut self) -> f32 {
        const MASK: u32 = (1u32 << 24) - 1;
        const DIVISOR: f32 = 1.0f32 / ((1u32 << 24) as f32);
        let v = self.engine.random_u32() & MASK;
        (v as f32) * DIVISOR
    }

    /// Returns a uniform `f64` in `[0, 1)`.
    ///
    /// Uses the low 53 bits (the `f64` mantissa precision) of one 64-bit draw (two
    /// 32-bit engine outputs), so every result is an exact multiple of `2^-53`.
    pub fn next_uniform_f64(&mut self) -> f64 {
        const MASK: u64 = (1u64 << 53) - 1;
        const DIVISOR: f64 = 1.0f64 / ((1u64 << 53) as f64);
        let v = self.engine.random_u64() & MASK;
        (v as f64) * DIVISOR
    }

    /// Returns a standard-normal `f32` sample using the Box–Muller transform.
    ///
    /// Each transform yields two samples: the cosine one is returned and the sine one is
    /// cached and handed out by the next call, which then consumes no engine output.
    pub fn next_normal_f32(&mut self) -> f32 {
        if let Some(cached) = self.next_float_normal.take() {
            return cached;
        }
        let u1 = self.next_uniform_f32();
        let u2 = self.next_uniform_f32();
        // ln_1p(-u2) == ln(1 - u2); u2 < 1, so the log (and therefore r) is finite.
        let r = (-2.0f32 * (-u2).ln_1p()).sqrt();
        let theta = 2.0f32 * std::f32::consts::PI * u1;
        let (sin_t, cos_t) = theta.sin_cos();
        self.next_float_normal = Some(r * sin_t);
        r * cos_t
    }

    /// Returns a standard-normal `f64` sample using the Box–Muller transform.
    ///
    /// Same scheme as [`Generator::next_normal_f32`], with its own separate cache.
    pub fn next_normal_f64(&mut self) -> f64 {
        if let Some(cached) = self.next_double_normal.take() {
            return cached;
        }
        let u1 = self.next_uniform_f64();
        let u2 = self.next_uniform_f64();
        // ln_1p(-u2) == ln(1 - u2); u2 < 1, so the log (and therefore r) is finite.
        let r = (-2.0f64 * (-u2).ln_1p()).sqrt();
        let theta = 2.0f64 * std::f64::consts::PI * u1;
        let (sin_t, cos_t) = theta.sin_cos();
        self.next_double_normal = Some(r * sin_t);
        r * cos_t
    }
}
impl Default for Generator {
fn default() -> Self {
Self::seed_from_entropy()
}
}
thread_local! {
    /// This thread's default generator, seeded from entropy when the thread first touches it.
    /// Reseed it with [`manual_seed`]; borrow it with [`with_thread_rng`].
    static THREAD_RNG: RefCell<Generator> = {
        let generator = Generator::seed_from_entropy();
        RefCell::new(generator)
    };
}
/// Reseeds the calling thread's default generator and, if a GPU backend is available,
/// forwards the same seed to it. Generators of other threads are not touched.
///
/// GPU seeding is best-effort: the return value of `manual_seed_gpu` is discarded, so a
/// failure there neither prevents nor undoes the CPU reseed, which happens first.
///
/// # Panics
///
/// Panics if this thread's generator is already borrowed, e.g. when called from inside a
/// [`with_thread_rng`] closure.
pub fn manual_seed(seed: u64) {
    THREAD_RNG.with(|cell| cell.borrow_mut().manual_seed(seed));

    if let Some(gpu) = crate::gpu_dispatch::gpu_backend() {
        // Intentionally ignored: see the best-effort note above.
        let _ = gpu.manual_seed_gpu(seed);
    }
}
/// Runs `f` with mutable access to this thread's default [`Generator`] and returns its result.
///
/// # Panics
///
/// Panics if called re-entrantly on the same thread (from inside `f`, or while
/// [`manual_seed`] holds the generator), because the generator is already mutably borrowed.
/// Like any thread-local access, it may also panic if the thread-local has been destroyed.
pub fn with_thread_rng<R>(f: impl FnOnce(&mut Generator) -> R) -> R {
    THREAD_RNG.with(|cell| {
        let mut generator = cell.borrow_mut();
        f(&mut generator)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bit patterns of `torch.manual_seed(42); torch.rand(10)` (float32), compared
    /// bit-for-bit so that rounding differences cannot slip through.
    const TORCH_RAND_SEED_42_F32_BITS: [u32; 10] = [
        0x3f61_dc66,
        0x3f6a_3db3,
        0x3ec4_06b8,
        0x3f75_950e,
        0x3ec7_e8d4,
        0x3f19_d447,
        0x3e83_5d78,
        0x3f4b_2c14,
        0x3f70_d666,
        0x3e08_61e4,
    ];

    #[test]
    fn mt19937_seed_42_matches_torch_rand_f32() {
        let mut g = Generator::new(42);
        for (i, &expected_bits) in TORCH_RAND_SEED_42_F32_BITS.iter().enumerate() {
            let got = g.next_uniform_f32();
            assert_eq!(
                got.to_bits(),
                expected_bits,
                "i={i}: got=0x{:08x} ({got:.17}), expected=0x{expected_bits:08x}",
                got.to_bits()
            );
        }
    }

    #[test]
    fn mt19937_seed_5489_matches_std_mt19937_reference() {
        // 5489 is the default seed of C++ `std::mt19937`: its first output is 3499211612,
        // and the C++ standard requires its 10000th output to be 4123659995. This
        // exercises many state regenerations.
        let mut g = Generator::new(5489);
        assert_eq!(g.random_u32(), 3_499_211_612);

        let mut g = Generator::new(5489);
        let mut last = 0;
        for _ in 0..10_000 {
            last = g.random_u32();
        }
        assert_eq!(last, 4_123_659_995);
    }

    #[test]
    fn manual_seed_resets_thread_local() {
        manual_seed(42);
        let a: Vec<u32> = (0..5)
            .map(|_| with_thread_rng(|g| g.random_u32()))
            .collect();
        manual_seed(42);
        let b: Vec<u32> = (0..5)
            .map(|_| with_thread_rng(|g| g.random_u32()))
            .collect();
        assert_eq!(a, b);
    }

    #[test]
    fn manual_seed_distinct_seeds_distinct_streams() {
        manual_seed(42);
        let a = with_thread_rng(|g| g.random_u32());
        manual_seed(43);
        let b = with_thread_rng(|g| g.random_u32());
        assert_ne!(a, b);
    }

    #[test]
    fn generator_clone_preserves_stream() {
        let mut g = Generator::new(12345);
        let _ = g.random_u32();
        let mut g2 = g.clone();
        assert_eq!(g.random_u32(), g2.random_u32());
        assert_eq!(g.random_u32(), g2.random_u32());
    }

    #[test]
    fn normal_box_muller_cache_used() {
        let mut g = Generator::new(42);
        let n1 = g.next_normal_f32();
        let n2 = g.next_normal_f32();
        assert!(n1.is_finite() && n2.is_finite());
        assert!(g.next_float_normal.is_none(), "cache must be drained");
    }

    #[test]
    fn normal_f64_cache_returns_companion_sample() {
        let mut g = Generator::new(42);
        let first = g.next_normal_f64();
        let cached = g.next_double_normal;
        assert!(cached.is_some(), "first draw must cache its companion sample");
        let second = g.next_normal_f64();
        assert!(first.is_finite() && second.is_finite());
        assert_eq!(cached.map(f64::to_bits), Some(second.to_bits()));
        assert!(g.next_double_normal.is_none(), "cache must be drained");
    }

    #[test]
    fn generator_manual_seed_discards_cached_normals() {
        let mut fresh = Generator::new(42);
        let expected = fresh.next_normal_f32();

        let mut g = Generator::new(42);
        let _ = g.next_normal_f32();
        let _ = g.next_normal_f64();
        assert!(g.next_float_normal.is_some() && g.next_double_normal.is_some());
        g.manual_seed(42);
        assert!(g.next_float_normal.is_none() && g.next_double_normal.is_none());
        assert_eq!(g.next_normal_f32().to_bits(), expected.to_bits());
    }

    #[test]
    fn uniform_samples_lie_in_unit_interval() {
        let mut g = Generator::new(2024);
        for _ in 0..1_000 {
            let x = g.next_uniform_f32();
            assert!((0.0..1.0).contains(&x), "f32 sample out of range: {x}");
            let y = g.next_uniform_f64();
            assert!((0.0..1.0).contains(&y), "f64 sample out of range: {y}");
        }
    }

    #[test]
    fn seed_is_reported_but_only_low_32_bits_drive_the_stream() {
        let wide = (1u64 << 32) | 5;
        let mut a = Generator::new(wide);
        let mut b = Generator::new(5);
        assert_eq!(a.seed(), wide);
        assert_eq!(b.seed(), 5);
        assert_eq!(a.random_u32(), b.random_u32());
    }

    #[test]
    fn random_u64_concatenates_two_u32_in_order() {
        let mut g = Generator::new(7);
        let mut g2 = Generator::new(7);
        let hi = g2.random_u32();
        let lo = g2.random_u32();
        let expected = ((hi as u64) << 32) | (lo as u64);
        assert_eq!(g.random_u64(), expected);
    }
}