Skip to main content

ferray_random/bitgen/
mt19937.rs

1// ferray-random: MT19937-64 BitGenerator implementation
2//
// 64-bit Mersenne Twister (MT19937-64), following Nishimura (2000) and the
// Matsumoto & Nishimura reference code. Period 2^19937 - 1.
// This is the counterpart to NumPy's `MT19937` BitGenerator, but it is NOT
// bit-compatible with it. NumPy wraps the 32-bit Mersenne Twister, whereas
// this implementation uses the 64-bit variant because
// `BitGenerator::next_u64` is the trait contract. Both belong to the same
// generator family, with the same period and comparable statistical quality,
// but the same seed produces different output streams.
9
10#![allow(clippy::unreadable_literal)]
11
12use super::BitGenerator;
13
/// Degree of recurrence `n`: number of 64-bit words in the generator state.
const NN: usize = 312;
/// Middle-word offset `m` used by the twist recurrence (half of `NN`).
const MM: usize = NN / 2;
/// Twist matrix coefficient `a` from the MT19937-64 parameter set.
const MATRIX_A: u64 = 0xB502_6F5A_A966_19E9;
/// Mask selecting the most significant 33 bits of a state word.
const UM: u64 = !LM;
/// Mask selecting the least significant 31 bits of a state word.
const LM: u64 = (1 << 31) - 1;
19
20/// MT19937-64 Mersenne Twister BitGenerator.
21///
22/// Period 2^19937 - 1. The `jump` operation is not implemented for the
23/// 64-bit Mersenne Twister in this crate; the standard polynomial-jump
24/// table is large and not required for the typical use case.
25pub struct MT19937 {
26    mt: [u64; NN],
27    index: usize,
28}
29
30impl MT19937 {
31    fn fill_state(&mut self) {
32        for i in 0..NN - MM {
33            let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
34            self.mt[i] = self.mt[i + MM] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
35        }
36        for i in NN - MM..NN - 1 {
37            let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
38            self.mt[i] = self.mt[i + MM - NN] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
39        }
40        let x = (self.mt[NN - 1] & UM) | (self.mt[0] & LM);
41        self.mt[NN - 1] = self.mt[MM - 1] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
42        self.index = 0;
43    }
44}
45
46impl BitGenerator for MT19937 {
47    fn state_bytes(&self) -> Result<Vec<u8>, ferray_core::FerrayError> {
48        // Layout: NN × u64 followed by index as u64. Total
49        // 312 × 8 + 8 = 2504 bytes.
50        let mut out = Vec::with_capacity(NN * 8 + 8);
51        for &w in &self.mt {
52            out.extend_from_slice(&w.to_le_bytes());
53        }
54        out.extend_from_slice(&(self.index as u64).to_le_bytes());
55        Ok(out)
56    }
57
58    fn set_state_bytes(&mut self, bytes: &[u8]) -> Result<(), ferray_core::FerrayError> {
59        let expected = NN * 8 + 8;
60        if bytes.len() != expected {
61            return Err(ferray_core::FerrayError::invalid_value(format!(
62                "MT19937 state must be {expected} bytes, got {}",
63                bytes.len()
64            )));
65        }
66        let mut mt = [0u64; NN];
67        for (i, chunk) in bytes[..NN * 8].chunks_exact(8).enumerate() {
68            mt[i] = u64::from_le_bytes(chunk.try_into().unwrap());
69        }
70        let index = u64::from_le_bytes(bytes[NN * 8..].try_into().unwrap());
71        if index > NN as u64 {
72            return Err(ferray_core::FerrayError::invalid_value(format!(
73                "MT19937 index must be in [0, {NN}], got {index}"
74            )));
75        }
76        self.mt = mt;
77        self.index = index as usize;
78        Ok(())
79    }
80
81    fn next_u64(&mut self) -> u64 {
82        if self.index >= NN {
83            self.fill_state();
84        }
85        let mut x = self.mt[self.index];
86        self.index += 1;
87        // Tempering.
88        x ^= (x >> 29) & 0x5555_5555_5555_5555;
89        x ^= (x << 17) & 0x71D6_7FFF_EDA6_0000;
90        x ^= (x << 37) & 0xFFF7_EEE0_0000_0000;
91        x ^= x >> 43;
92        x
93    }
94
95    fn seed_from_u64(seed: u64) -> Self {
96        let mut mt = [0u64; NN];
97        mt[0] = seed;
98        for i in 1..NN {
99            mt[i] = 6364136223846793005u64
100                .wrapping_mul(mt[i - 1] ^ (mt[i - 1] >> 62))
101                .wrapping_add(i as u64);
102        }
103        Self { mt, index: NN }
104    }
105
106    fn jump(&mut self) -> Option<()> {
107        // Polynomial-jump for MT19937-64 isn't included here; users that
108        // need stream parallelism should pick `Philox` (which has
109        // `stream`) or seed multiple `MT19937`s with disjoint seeds via
110        // SeedSequence.
111        None
112    }
113
114    fn stream(_seed: u64, _stream_id: u64) -> Option<Self> {
115        None
116    }
117}
118
119impl Clone for MT19937 {
120    fn clone(&self) -> Self {
121        Self {
122            mt: self.mt,
123            index: self.index,
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deterministic_output() {
        let mut a = MT19937::seed_from_u64(123);
        let mut b = MT19937::seed_from_u64(123);
        for _ in 0..512 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn different_seeds_differ() {
        let mut a = MT19937::seed_from_u64(1);
        let mut b = MT19937::seed_from_u64(2);
        let mut differ = false;
        for _ in 0..100 {
            if a.next_u64() != b.next_u64() {
                differ = true;
                break;
            }
        }
        assert!(differ);
    }

    #[test]
    fn output_covers_full_range() {
        let mut rng = MT19937::seed_from_u64(0xdead_beef);
        let mut high = false;
        let mut low = false;
        for _ in 0..10_000 {
            let v = rng.next_u64();
            if v > u64::MAX / 2 {
                high = true;
            } else {
                low = true;
            }
            if high && low {
                break;
            }
        }
        assert!(high && low);
    }

    #[test]
    fn uniform_f64_in_unit_interval() {
        let mut rng = MT19937::seed_from_u64(42);
        for _ in 0..1000 {
            let x = rng.next_f64();
            assert!((0.0..1.0).contains(&x));
        }
    }

    /// Known-answer test. The C++ standard ([rand.predef]) requires the
    /// 10000th output of a default-constructed `std::mt19937_64` (seed 5489)
    /// to be 9981545732273789042. That engine uses the same MT19937-64
    /// parameters and seeding recurrence as this generator. The tests above
    /// would all still pass with a wrong twist or tempering constant; this
    /// one would not.
    #[test]
    fn matches_reference_10000th_output() {
        let mut rng = MT19937::seed_from_u64(5489);
        for _ in 1..10_000 {
            rng.next_u64();
        }
        assert_eq!(rng.next_u64(), 9_981_545_732_273_789_042);
    }

    /// A saved state must resume the exact same stream, including mid-block
    /// cursors and subsequent refills.
    #[test]
    fn state_round_trip_resumes_stream() {
        let mut original = MT19937::seed_from_u64(7);
        // Advance past one refill, so the cursor is mid-block (not 0 or NN).
        for _ in 0..400 {
            original.next_u64();
        }
        let saved = original.state_bytes().expect("state_bytes should not fail");
        assert_eq!(saved.len(), NN * 8 + 8);

        let mut restored = MT19937::seed_from_u64(0);
        restored
            .set_state_bytes(&saved)
            .expect("bytes from state_bytes must be accepted");
        for _ in 0..1_000 {
            assert_eq!(restored.next_u64(), original.next_u64());
        }
    }

    /// Malformed states are rejected without modifying the generator, and the
    /// boundary cursor value `NN` (a freshly seeded state) is accepted.
    #[test]
    fn set_state_rejects_malformed_input() {
        let mut rng = MT19937::seed_from_u64(3);
        let before = rng.state_bytes().expect("state_bytes should not fail");

        // Wrong length.
        assert!(rng.set_state_bytes(&before[..before.len() - 1]).is_err());

        // Cursor beyond NN.
        let mut bad = before.clone();
        let cursor_at = bad.len() - 8;
        bad[cursor_at..].copy_from_slice(&(NN as u64 + 1).to_le_bytes());
        assert!(rng.set_state_bytes(&bad).is_err());

        // Rejected inputs must leave the generator untouched.
        assert_eq!(
            rng.state_bytes().expect("state_bytes should not fail"),
            before
        );

        // A freshly seeded state has cursor == NN, which is valid.
        assert!(rng.set_state_bytes(&before).is_ok());
    }
}