ferray_random/bitgen/
mt19937.rs1#![allow(clippy::unreadable_literal)]
11
12use super::BitGenerator;
13
/// Number of 64-bit words in the generator state.
const NN: usize = 312;
/// Distance (in words) to the partner word used when regenerating the state.
const MM: usize = NN / 2;
/// Value XORed into a regenerated word when the combined word's lowest bit is set.
const MATRIX_A: u64 = 0xB502_6F5A_A966_19E9;
/// Mask selecting the lower 31 bits of a state word (`0x7FFF_FFFF`).
const LM: u64 = (1 << 31) - 1;
/// Mask selecting the upper 33 bits of a state word (the complement of `LM`).
const UM: u64 = !LM;
19
20pub struct MT19937 {
26 mt: [u64; NN],
27 index: usize,
28}
29
30impl MT19937 {
31 fn fill_state(&mut self) {
32 for i in 0..NN - MM {
33 let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
34 self.mt[i] = self.mt[i + MM] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
35 }
36 for i in NN - MM..NN - 1 {
37 let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
38 self.mt[i] = self.mt[i + MM - NN] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
39 }
40 let x = (self.mt[NN - 1] & UM) | (self.mt[0] & LM);
41 self.mt[NN - 1] = self.mt[MM - 1] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
42 self.index = 0;
43 }
44}
45
46impl BitGenerator for MT19937 {
47 fn state_bytes(&self) -> Result<Vec<u8>, ferray_core::FerrayError> {
48 let mut out = Vec::with_capacity(NN * 8 + 8);
51 for &w in &self.mt {
52 out.extend_from_slice(&w.to_le_bytes());
53 }
54 out.extend_from_slice(&(self.index as u64).to_le_bytes());
55 Ok(out)
56 }
57
58 fn set_state_bytes(&mut self, bytes: &[u8]) -> Result<(), ferray_core::FerrayError> {
59 let expected = NN * 8 + 8;
60 if bytes.len() != expected {
61 return Err(ferray_core::FerrayError::invalid_value(format!(
62 "MT19937 state must be {expected} bytes, got {}",
63 bytes.len()
64 )));
65 }
66 let mut mt = [0u64; NN];
67 for (i, chunk) in bytes[..NN * 8].chunks_exact(8).enumerate() {
68 mt[i] = u64::from_le_bytes(chunk.try_into().unwrap());
69 }
70 let index = u64::from_le_bytes(bytes[NN * 8..].try_into().unwrap());
71 if index > NN as u64 {
72 return Err(ferray_core::FerrayError::invalid_value(format!(
73 "MT19937 index must be in [0, {NN}], got {index}"
74 )));
75 }
76 self.mt = mt;
77 self.index = index as usize;
78 Ok(())
79 }
80
81 fn next_u64(&mut self) -> u64 {
82 if self.index >= NN {
83 self.fill_state();
84 }
85 let mut x = self.mt[self.index];
86 self.index += 1;
87 x ^= (x >> 29) & 0x5555_5555_5555_5555;
89 x ^= (x << 17) & 0x71D6_7FFF_EDA6_0000;
90 x ^= (x << 37) & 0xFFF7_EEE0_0000_0000;
91 x ^= x >> 43;
92 x
93 }
94
95 fn seed_from_u64(seed: u64) -> Self {
96 let mut mt = [0u64; NN];
97 mt[0] = seed;
98 for i in 1..NN {
99 mt[i] = 6364136223846793005u64
100 .wrapping_mul(mt[i - 1] ^ (mt[i - 1] >> 62))
101 .wrapping_add(i as u64);
102 }
103 Self { mt, index: NN }
104 }
105
106 fn jump(&mut self) -> Option<()> {
107 None
112 }
113
114 fn stream(_seed: u64, _stream_id: u64) -> Option<Self> {
115 None
116 }
117}
118
119impl Clone for MT19937 {
120 fn clone(&self) -> Self {
121 Self {
122 mt: self.mt,
123 index: self.index,
124 }
125 }
126}
127
128#[cfg(test)]
129mod tests {
130 use super::*;
131
132 #[test]
133 fn deterministic_output() {
134 let mut a = MT19937::seed_from_u64(123);
135 let mut b = MT19937::seed_from_u64(123);
136 for _ in 0..512 {
137 assert_eq!(a.next_u64(), b.next_u64());
138 }
139 }
140
141 #[test]
142 fn different_seeds_differ() {
143 let mut a = MT19937::seed_from_u64(1);
144 let mut b = MT19937::seed_from_u64(2);
145 let mut differ = false;
146 for _ in 0..100 {
147 if a.next_u64() != b.next_u64() {
148 differ = true;
149 break;
150 }
151 }
152 assert!(differ);
153 }
154
155 #[test]
156 fn output_covers_full_range() {
157 let mut rng = MT19937::seed_from_u64(0xdead_beef);
158 let mut high = false;
159 let mut low = false;
160 for _ in 0..10_000 {
161 let v = rng.next_u64();
162 if v > u64::MAX / 2 {
163 high = true;
164 } else {
165 low = true;
166 }
167 if high && low {
168 break;
169 }
170 }
171 assert!(high && low);
172 }
173
174 #[test]
175 fn uniform_f64_in_unit_interval() {
176 let mut rng = MT19937::seed_from_u64(42);
177 for _ in 0..1000 {
178 let x = rng.next_f64();
179 assert!((0.0..1.0).contains(&x));
180 }
181 }
182}