use candle_core::{DType, Device, Result, Tensor};
use rand_mt::Mt;
use std::sync::Mutex;
/// Process-wide generator guarded by a mutex.
///
/// Starts as `None`; `set_seed` installs a generator with an explicit seed and
/// `global_uniform` creates one on first use if none has been installed yet.
static GLOBAL_RNG: Mutex<Option<MtBoxMullerRng>> = Mutex::new(None);
/// Random sampler that combines a Mersenne Twister (`rand_mt::Mt`) with the
/// Box–Muller transform to produce uniform and normally distributed `f32` values.
///
/// Cloning copies the complete state, including any cached sample.
#[derive(Debug, Clone)]
pub struct MtBoxMullerRng {
/// Source of raw 32-bit words.
rng: Mt,
/// Second Box–Muller output kept by `sample_scalar` and returned by its next call.
cached_value: Option<f32>,
}
impl MtBoxMullerRng {
    /// Number of values the bulk path transforms together
    /// (8 cosine outputs followed by 8 sine outputs).
    const VECTOR_BLOCK: usize = 16;

    /// Creates a generator seeded with `seed`.
    ///
    /// The underlying generator takes a `u32` seed, so only the low 32 bits of
    /// `seed` are used.
    pub fn new(seed: u64) -> Self {
        Self {
            // `as u32` keeps the low 32 bits and discards the rest.
            rng: Mt::new(seed as u32),
            cached_value: None,
        }
    }

    /// Maps two raw words to a double in `[0, 1)` with 53 bits of resolution.
    ///
    /// `first` becomes the upper 32 bits of the combined word and `second` the
    /// lower 32 bits; the low 53 bits of that word are kept and scaled by 2^-53.
    #[inline]
    fn mt_to_uniform_double(first: u32, second: u32) -> f64 {
        const MASK_53BIT: u64 = 0x001F_FFFF_FFFF_FFFF;
        const DIVISOR: f64 = 9_007_199_254_740_992.0; // 2^53

        let combined = (u64::from(first) << 32) | u64::from(second);
        (combined & MASK_53BIT) as f64 / DIVISOR
    }

    /// Maps one raw word to a float in `[0, 1)` using its low 24 bits.
    #[inline]
    fn mt_to_uniform_float(val: u32) -> f32 {
        const MASK_24BIT: u32 = 0x00FF_FFFF;
        const DIVISOR: f32 = 16_777_216.0; // 2^24

        (val & MASK_24BIT) as f32 / DIVISOR
    }

    /// Draws one uniform `f32` in `[0, 1)`, consuming a single 32-bit word.
    #[inline]
    pub fn uniform(&mut self) -> f32 {
        Self::mt_to_uniform_float(self.rng.next_u32())
    }

    /// Draws one standard-normal sample using the double-precision Box–Muller
    /// transform.
    ///
    /// Each transform yields two samples: the cosine branch is returned
    /// immediately and the sine branch is cached and returned by the next call,
    /// so the generator is advanced (by four 32-bit words) only on every other
    /// call.
    pub fn sample_scalar(&mut self) -> f32 {
        if let Some(cached) = self.cached_value.take() {
            return cached;
        }

        // Argument evaluation order fixes which word lands in which half.
        let u1 = Self::mt_to_uniform_double(self.rng.next_u32(), self.rng.next_u32());
        let u2 = Self::mt_to_uniform_double(self.rng.next_u32(), self.rng.next_u32());

        // `1 - u2` lies in (0, 1], so the logarithm is always finite.
        let r = (-2.0_f64 * (1.0_f64 - u2).ln()).sqrt();
        let theta = 2.0_f64 * std::f64::consts::PI * u1;

        self.cached_value = Some((r * theta.sin()) as f32);
        (r * theta.cos()) as f32
    }

    /// Applies the single-precision Box–Muller transform in place.
    ///
    /// On entry `block` holds uniforms in `[0, 1)`. Element `i` of the first
    /// half supplies the radius and element `i` of the second half the angle;
    /// on exit the first half holds the cosine outputs and the second half the
    /// sine outputs.
    fn box_muller_block(block: &mut [f32]) {
        let (radial, angular) = block.split_at_mut(block.len() / 2);
        for (a, b) in radial.iter_mut().zip(angular.iter_mut()) {
            // Flip [0, 1) to (0, 1] so the logarithm never sees zero.
            let u1 = 1.0_f32 - *a;
            let u2 = *b;
            let r = (-2.0_f32 * u1.ln()).sqrt();
            let theta = 2.0_f32 * std::f32::consts::PI * u2;
            *a = r * theta.cos();
            *b = r * theta.sin();
        }
    }

    /// Produces `count` standard-normal samples using the bulk `f32` path.
    ///
    /// For `count >= 16` the buffer is first filled with `count` uniforms, then
    /// every complete group of 16 is transformed in place. When `count` is not
    /// a multiple of 16, the final 16 slots are refilled with fresh uniforms
    /// and transformed again, overwriting part of the last complete group.
    /// This follows the layout of PyTorch's CPU `normal_fill`, which the parity
    /// tests in this file compare against. The scalar cache is not touched on
    /// this path.
    ///
    /// For `count < 16` there is no full group to work with, so the samples
    /// come from [`Self::sample_scalar`].
    fn sample_vectorized(&mut self, count: usize) -> Vec<f32> {
        if count < Self::VECTOR_BLOCK {
            return (0..count).map(|_| self.sample_scalar()).collect();
        }

        let mut output: Vec<f32> = (0..count).map(|_| self.uniform()).collect();

        for block in output.chunks_exact_mut(Self::VECTOR_BLOCK) {
            Self::box_muller_block(block);
        }

        if count % Self::VECTOR_BLOCK != 0 {
            // Recompute the last 16 values from fresh uniforms.
            let tail = &mut output[count - Self::VECTOR_BLOCK..];
            for slot in tail.iter_mut() {
                *slot = self.uniform();
            }
            Self::box_muller_block(tail);
        }

        output
    }

    /// Builds a tensor of standard-normal samples with the given `shape`.
    ///
    /// Samples are generated on the CPU as `f32`, moved to `device` when it is
    /// not the CPU, and finally converted to `dtype`. Tensors with fewer than
    /// 16 elements use the scalar sampler; larger ones use the bulk path.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor construction, the device transfer, or the
    /// dtype conversion fails.
    pub fn randn(&mut self, shape: &[usize], device: &Device, dtype: DType) -> Result<Tensor> {
        let elem_count: usize = shape.iter().product();
        // `sample_vectorized` itself falls back to the scalar path below 16.
        let data = self.sample_vectorized(elem_count);

        let cpu_tensor = Tensor::from_vec(data, shape, &Device::Cpu)?;
        let tensor = match device {
            Device::Cpu => cpu_tensor,
            _ => cpu_tensor.to_device(device)?,
        };
        tensor.to_dtype(dtype)
    }
}
/// Replaces the process-wide generator with a fresh one seeded by `seed`.
///
/// Any previously installed generator, and its state, is discarded.
///
/// # Panics
///
/// Panics if the global mutex is poisoned.
pub fn set_seed(seed: u64) {
    let fresh = MtBoxMullerRng::new(seed);
    *GLOBAL_RNG.lock().unwrap() = Some(fresh);
}
/// Draws one uniform `f32` in `[0, 1)` from the process-wide generator.
///
/// If no generator has been installed yet, one is created from the current
/// wall-clock time in nanoseconds since the Unix epoch, falling back to the
/// seed `42` when the clock reports a time before the epoch.
///
/// # Panics
///
/// Panics if the global mutex is poisoned.
pub fn global_uniform() -> f32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // Only invoked when the global slot is still empty.
    let clock_seeded = || {
        let seed = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_nanos() as u64,
            Err(_) => 42,
        };
        MtBoxMullerRng::new(seed)
    };

    let mut slot = GLOBAL_RNG.lock().unwrap();
    slot.get_or_insert_with(clock_seeded).uniform()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed emit identical normal samples.
    #[test]
    fn test_determinism() {
        let mut rng1 = MtBoxMullerRng::new(42);
        let mut rng2 = MtBoxMullerRng::new(42);
        for _ in 0..100 {
            let v1 = rng1.sample_scalar();
            let v2 = rng2.sample_scalar();
            assert_eq!(v1, v2);
        }
    }

    /// Uniform draws are reproducible per seed and stay inside `[0, 1)`.
    #[test]
    fn test_uniform_determinism() {
        let mut rng1 = MtBoxMullerRng::new(42);
        let mut rng2 = MtBoxMullerRng::new(42);
        for _ in 0..100 {
            let v1 = rng1.uniform();
            let v2 = rng2.uniform();
            assert_eq!(v1, v2);
            assert!((0.0..1.0).contains(&v1));
        }
    }

    /// Logs the first four scalar-path samples for manual comparison.
    ///
    /// No reference values for the scalar path are recorded in this file, so
    /// the only automated check is that every sample is finite.
    #[test]
    fn test_pytorch_parity_scalar_seed_42() {
        let mut rng = MtBoxMullerRng::new(42);
        let v1 = rng.sample_scalar();
        let v2 = rng.sample_scalar();
        let v3 = rng.sample_scalar();
        let v4 = rng.sample_scalar();
        tracing::debug!(v1, v2, v3, v4, "Rust randn(4) SCALAR path with seed 42");
        for v in [v1, v2, v3, v4] {
            assert!(v.is_finite(), "scalar sample is not finite: {v}");
        }
    }

    /// One full block of 16 yields 16 finite samples.
    #[test]
    fn test_vectorized_basic() {
        let mut rng = MtBoxMullerRng::new(42);
        let values = rng.sample_vectorized(16);
        assert_eq!(values.len(), 16);
        for v in &values {
            assert!(v.is_finite());
        }
    }

    /// The first eight bulk-path samples for seed 42 match the recorded
    /// PyTorch reference values to within `1e-5`.
    #[test]
    fn test_pytorch_parity_vectorized_seed_42() {
        let mut rng = MtBoxMullerRng::new(42);
        let values = rng.sample_vectorized(16);
        tracing::debug!(first_8 = ?&values[..8], "Rust randn(16) VECTORIZED path with seed 42");
        let pytorch_first_8 = [
            1.926_915_f32,
            1.487_284_2,
            0.900_717_2,
            -2.105_521_4,
            0.678_418_46,
            -1.234_545,
            -0.043_067_48,
            -1.604_667,
        ];
        // `zip` stops after the eight reference values.
        for (i, (rust, pytorch)) in values.iter().zip(pytorch_first_8.iter()).enumerate() {
            let diff = (rust - pytorch).abs();
            tracing::debug!(i, rust, pytorch, diff, "Parity check");
            assert!(
                diff < 1e-5,
                "Mismatch at index {i}: rust={rust}, pytorch={pytorch}"
            );
        }
    }

    /// `randn` returns a tensor with exactly the requested dimensions.
    #[test]
    fn test_randn_tensor() {
        let mut rng = MtBoxMullerRng::new(42);
        let tensor = rng
            .randn(&[2, 3, 4], &Device::Cpu, DType::F32)
            .expect("Failed to create tensor");
        assert_eq!(tensor.dims(), &[2, 3, 4]);
    }
}