use std::slice::from_raw_parts_mut;
/// MT19937: the 32-bit Mersenne Twister pseudo-random number generator.
///
/// For a given seed it produces the same `u32` sequence as the reference
/// implementation (`init_genrand` / `genrand_int32`) and C++ `std::mt19937`.
/// Not suitable for cryptographic use.
///
/// `#[repr(C)]` fixes the field layout; the `extern "C"` functions in this
/// file hand the generator to foreign code as `*mut Mt19937`.
#[repr(C)]
pub struct Mt19937 {
    /// The `N`-word internal state vector.
    mt: [u32; N],
    /// Index of the next state word to temper; `>= N` means a twist is due.
    mti: usize,
}

/// Number of 32-bit words of state.
const N: usize = 624;
/// Offset of the "far" word mixed into each recurrence step.
const M: usize = 397;
/// Twist matrix coefficients, XORed in when the shifted-out bit is 1.
const MATRIX_A: u32 = 0x9908B0DF;
/// Selects the most significant bit of a state word.
const UPPER_MASK: u32 = 0x80000000;
/// Selects the low 31 bits of a state word.
const LOWER_MASK: u32 = 0x7FFFFFFF;
/// 2^-32. This is an exact power of two, so multiplying by it adds no rounding.
const UNIT_SCALE: f32 = 1.0 / 4_294_967_296.0;
/// Largest `f32` strictly below 1.0 (`1 - 2^-24`).
const BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;

impl Mt19937 {
    /// Creates a generator seeded with `seed`, using the reference `init_genrand` scheme.
    pub fn new(seed: u32) -> Self {
        let mut mt = [0u32; N];
        mt[0] = seed;
        for i in 1..N {
            let prev = mt[i - 1];
            // Reference initialisation recurrence; wrapping ops give the
            // intended mod-2^32 arithmetic.
            mt[i] = (1812433253u32)
                .wrapping_mul(prev ^ (prev >> 30))
                .wrapping_add(i as u32);
        }
        // mti == N forces a twist before the first output.
        Self { mt, mti: N }
    }

    /// Returns the next 32-bit output, regenerating the state every `N` calls.
    #[inline]
    pub fn nextu(&mut self) -> u32 {
        if self.mti >= N {
            self.twist();
        }
        let mut y = self.mt[self.mti];
        self.mti += 1;
        // Tempering transform from the reference implementation.
        y ^= y >> 11;
        y ^= (y << 7) & 0x9D2C5680;
        y ^= (y << 15) & 0xEFC60000;
        y ^= y >> 18;
        y
    }

    /// Regenerates all `N` state words in place and resets `mti` to 0.
    ///
    /// The work is split into three phases so no index needs a modulo:
    /// - the first `N - M` words read ahead by `M`;
    /// - the remaining words wrap the far index back to the start;
    /// - the last word pairs with `mt[0]`.
    ///
    /// Words are updated in order, so later steps see earlier results.
    fn twist(&mut self) {
        for i in 0..N - M {
            self.mt[i] = Self::recur(self.mt[i], self.mt[i + 1], self.mt[i + M]);
        }
        for i in N - M..N - 1 {
            self.mt[i] = Self::recur(self.mt[i], self.mt[i + 1], self.mt[i + M - N]);
        }
        self.mt[N - 1] = Self::recur(self.mt[N - 1], self.mt[0], self.mt[M - 1]);
        self.mti = 0;
    }

    /// Performs one step of the MT recurrence.
    ///
    /// It joins the top bit of `cur` with the low 31 bits of `next`, applies
    /// the twist matrix, and XORs the result with `far`.
    #[inline]
    fn recur(cur: u32, next: u32, far: u32) -> u32 {
        let x = (cur & UPPER_MASK) | (next & LOWER_MASK);
        let x_a = if x & 1 != 0 { (x >> 1) ^ MATRIX_A } else { x >> 1 };
        far ^ x_a
    }

    /// Returns an `f32` in `[0, 1)`.
    #[inline]
    pub fn nextf(&mut self) -> f32 {
        Self::unit_f32(self.nextu())
    }

    /// Returns an integer in the inclusive range `[min, max]`.
    ///
    /// The full `i32` range is supported. If `min > max` the bounds are
    /// swapped rather than overflowing.
    #[inline]
    pub fn randi(&mut self, min: i32, max: i32) -> i32 {
        Self::range_i32(self.nextu(), min, max)
    }

    /// Returns an `f32` in `[min, max)` when `min < max`.
    ///
    /// If `min >= max` no clamping is applied. The result is then
    /// `min + u * (max - min)` for some `u` in `[0, 1)`. NaN inputs propagate.
    #[inline]
    pub fn randf(&mut self, min: f32, max: f32) -> f32 {
        Self::range_f32(self.nextu(), min, max)
    }

    /// Maps 32 random bits onto `[0, 1)`.
    ///
    /// `bits as f32` keeps only 24 significant bits. The top 128 inputs
    /// (`>= 2^32 - 128`) therefore round up to exactly 2^32 and would yield
    /// 1.0, so they are clamped to the largest `f32` below one. Every other
    /// input maps exactly to `bits as f32 * 2^-32`.
    #[inline]
    fn unit_f32(bits: u32) -> f32 {
        (bits as f32 * UNIT_SCALE).min(BELOW_ONE)
    }

    /// Maps 32 random bits onto `[min, max]` by multiply-shift.
    ///
    /// Each output value receives either `floor` or `ceil` of `2^32 / span`
    /// inputs, so the mapping is very slightly non-uniform when `span` is not
    /// a power of two.
    #[inline]
    fn range_i32(bits: u32, min: i32, max: i32) -> i32 {
        let (lo, hi) = if min <= max { (min, max) } else { (max, min) };
        // span <= 2^32 and bits < 2^32, so the product below fits in u64.
        let span = (i64::from(hi) - i64::from(lo) + 1) as u64;
        let offset = (u64::from(bits) * span) >> 32;
        // offset < span, so lo + offset <= hi and the result fits in i32.
        (i64::from(lo) + offset as i64) as i32
    }

    /// Maps 32 random bits onto `[min, max)` when `min < max`.
    #[inline]
    fn range_f32(bits: u32, min: f32, max: f32) -> f32 {
        let v = min + Self::unit_f32(bits) * (max - min);
        // Even with a unit value below one, rounding in the final add can land
        // exactly on `max`. Fold that case to the nearest float below `max`.
        if min < max && v >= max {
            Self::next_below(max)
        } else {
            v
        }
    }

    /// Returns the largest `f32` strictly below `x`.
    ///
    /// The caller guarantees `x` is not NaN and not negative infinity.
    #[inline]
    fn next_below(x: f32) -> f32 {
        if x > 0.0 {
            f32::from_bits(x.to_bits() - 1)
        } else if x < 0.0 {
            f32::from_bits(x.to_bits() + 1)
        } else {
            // Both +0.0 and -0.0: the next value down is the smallest negative subnormal.
            -f32::from_bits(1)
        }
    }
}
/// Allocates a new generator seeded with `seed` and returns an owning pointer.
///
/// The returned pointer must be released exactly once with `mt19937_free`.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_new(seed: u32) -> *mut Mt19937 {
    let boxed = Box::new(Mt19937::new(seed));
    Box::into_raw(boxed)
}
/// Destroys a generator previously returned by `mt19937_new`.
///
/// Passing null is a no-op. A non-null `ptr` must not be used or freed again
/// after this call.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_free(ptr: *mut Mt19937) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: a non-null `ptr` came from `Box::into_raw` in `mt19937_new`, and
    // the caller hands ownership back exactly once.
    let owned = unsafe { Box::from_raw(ptr) };
    drop(owned);
}
/// Fills `out[0..count]` with successive 32-bit outputs of the generator.
///
/// Does nothing if `ptr` or `out` is null, or if `count` is zero.
///
/// # Safety
/// `ptr` must come from `mt19937_new` and must not be accessed elsewhere
/// during the call. `out` must be valid for writes of `count` aligned `u32`s
/// and must not overlap the generator.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_next_u32s(ptr: *mut Mt19937, out: *mut u32, count: usize) {
    // `from_raw_parts_mut` requires a non-null pointer even for length 0, and
    // a C caller may legitimately pass NULL with a count of 0.
    if ptr.is_null() || out.is_null() || count == 0 {
        return;
    }
    // SAFETY: both pointers are non-null. Validity, alignment, exclusivity and
    // the length of `out` are the caller's obligations documented above.
    let (rng, buffer) = unsafe { (&mut *ptr, from_raw_parts_mut(out, count)) };
    for v in buffer {
        *v = rng.nextu();
    }
}
/// Fills `out[0..count]` with successive `[0, 1)` floats from the generator.
///
/// Does nothing if `ptr` or `out` is null, or if `count` is zero.
///
/// # Safety
/// `ptr` must come from `mt19937_new` and must not be accessed elsewhere
/// during the call. `out` must be valid for writes of `count` aligned `f32`s
/// and must not overlap the generator.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_next_f32s(ptr: *mut Mt19937, out: *mut f32, count: usize) {
    // `from_raw_parts_mut` requires a non-null pointer even for length 0, and
    // a C caller may legitimately pass NULL with a count of 0.
    if ptr.is_null() || out.is_null() || count == 0 {
        return;
    }
    // SAFETY: both pointers are non-null. Validity, alignment, exclusivity and
    // the length of `out` are the caller's obligations documented above.
    let (rng, buffer) = unsafe { (&mut *ptr, from_raw_parts_mut(out, count)) };
    for v in buffer {
        *v = rng.nextf();
    }
}
/// Fills `out[0..count]` with integers drawn via `Mt19937::randi(min, max)`.
///
/// Does nothing if `ptr` or `out` is null, or if `count` is zero.
///
/// # Safety
/// `ptr` must come from `mt19937_new` and must not be accessed elsewhere
/// during the call. `out` must be valid for writes of `count` aligned `i32`s
/// and must not overlap the generator.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_rand_i32s(
    ptr: *mut Mt19937,
    out: *mut i32,
    count: usize,
    min: i32,
    max: i32,
) {
    // `from_raw_parts_mut` requires a non-null pointer even for length 0, and
    // a C caller may legitimately pass NULL with a count of 0.
    if ptr.is_null() || out.is_null() || count == 0 {
        return;
    }
    // SAFETY: both pointers are non-null. Validity, alignment, exclusivity and
    // the length of `out` are the caller's obligations documented above.
    let (rng, buffer) = unsafe { (&mut *ptr, from_raw_parts_mut(out, count)) };
    for v in buffer {
        *v = rng.randi(min, max);
    }
}
/// Fills `out[0..count]` with floats drawn via `Mt19937::randf(min, max)`.
///
/// Does nothing if `ptr` or `out` is null, or if `count` is zero.
///
/// # Safety
/// `ptr` must come from `mt19937_new` and must not be accessed elsewhere
/// during the call. `out` must be valid for writes of `count` aligned `f32`s
/// and must not overlap the generator.
#[unsafe(no_mangle)]
pub extern "C" fn mt19937_rand_f32s(
    ptr: *mut Mt19937,
    out: *mut f32,
    count: usize,
    min: f32,
    max: f32,
) {
    // `from_raw_parts_mut` requires a non-null pointer even for length 0, and
    // a C caller may legitimately pass NULL with a count of 0.
    if ptr.is_null() || out.is_null() || count == 0 {
        return;
    }
    // SAFETY: both pointers are non-null. Validity, alignment, exclusivity and
    // the length of `out` are the caller's obligations documented above.
    let (rng, buffer) = unsafe { (&mut *ptr, from_raw_parts_mut(out, count)) };
    for v in buffer {
        *v = rng.randf(min, max);
    }
}
#[cfg(test)]
mod tests {
    use super::Mt19937;

    /// First output for seed 1 matches the reference generator.
    #[test]
    fn test_mt19937() {
        let mut rng = Mt19937::new(1);
        assert_eq!(rng.nextu(), 1791095845);
    }

    /// C++ [rand.predef]: the 10000th output of an mt19937 seeded with 5489
    /// is 4123659995. This spans many twists, so it checks the full recurrence.
    #[test]
    fn test_mt19937_10000th_output() {
        let mut rng = Mt19937::new(5489);
        let last = (0..10_000).map(|_| rng.nextu()).last();
        assert_eq!(last, Some(4123659995));
    }

    /// Every sampling helper stays inside its documented range over many draws.
    #[test]
    fn test_mt19937_methods() {
        let mut rng = Mt19937::new(12345);
        let _ = rng.nextu();
        for _ in 0..1_000 {
            let f = rng.nextf();
            assert!((0.0..1.0).contains(&f));
            let i = rng.randi(10, 20);
            assert!((10..=20).contains(&i));
            let rf = rng.randf(10.0, 20.0);
            assert!((10.0..20.0).contains(&rf));
        }
    }
}