Skip to main content

ferray_random/bitgen/
mt19937.rs

// ferray-random: MT19937-64 BitGenerator implementation
//
// 64-bit Mersenne Twister (MT19937-64), following the Matsumoto & Nishimura
// reference implementation `mt19937-64.c` (T. Nishimura, "Tables of 64-bit
// Mersenne Twisters", 2000). Period 2^19937 - 1.
//
// This fills the same role as NumPy's `MT19937` BitGenerator but is NOT
// stream-compatible with it: NumPy wraps the 32-bit Mersenne Twister, while
// this implementation uses the 64-bit variant because `BitGenerator::next_u64`
// is the trait contract. Both generators belong to the same family with
// comparable statistical properties, but the same seed yields different
// output sequences.
9
10#![allow(clippy::unreadable_literal)]
11
12use super::BitGenerator;
13
/// Number of 64-bit words in the generator state.
const NN: usize = 312;
/// Offset of the partner word combined into each state update.
const MM: usize = 156;
/// Twist coefficient, XORed into a new state word when the combined word is odd.
const MATRIX_A: u64 = 0xB502_6F5A_A966_19E9;
/// Mask selecting the least significant 31 bits of a state word.
const LM: u64 = (1 << 31) - 1;
/// Mask selecting the most significant 33 bits (the complement of `LM`).
const UM: u64 = !LM;
19
20/// MT19937-64 Mersenne Twister BitGenerator.
21///
22/// Period 2^19937 - 1. The `jump` operation is not implemented for the
23/// 64-bit Mersenne Twister in this crate; the standard polynomial-jump
24/// table is large and not required for the typical use case.
25pub struct MT19937 {
26    mt: [u64; NN],
27    index: usize,
28}
29
30impl MT19937 {
31    fn fill_state(&mut self) {
32        for i in 0..NN - MM {
33            let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
34            self.mt[i] = self.mt[i + MM] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
35        }
36        for i in NN - MM..NN - 1 {
37            let x = (self.mt[i] & UM) | (self.mt[i + 1] & LM);
38            self.mt[i] = self.mt[i + MM - NN] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
39        }
40        let x = (self.mt[NN - 1] & UM) | (self.mt[0] & LM);
41        self.mt[NN - 1] = self.mt[MM - 1] ^ (x >> 1) ^ if x & 1 == 0 { 0 } else { MATRIX_A };
42        self.index = 0;
43    }
44}
45
46impl BitGenerator for MT19937 {
47    fn next_u64(&mut self) -> u64 {
48        if self.index >= NN {
49            self.fill_state();
50        }
51        let mut x = self.mt[self.index];
52        self.index += 1;
53        // Tempering.
54        x ^= (x >> 29) & 0x5555_5555_5555_5555;
55        x ^= (x << 17) & 0x71D6_7FFF_EDA6_0000;
56        x ^= (x << 37) & 0xFFF7_EEE0_0000_0000;
57        x ^= x >> 43;
58        x
59    }
60
61    fn seed_from_u64(seed: u64) -> Self {
62        let mut mt = [0u64; NN];
63        mt[0] = seed;
64        for i in 1..NN {
65            mt[i] = 6364136223846793005u64
66                .wrapping_mul(mt[i - 1] ^ (mt[i - 1] >> 62))
67                .wrapping_add(i as u64);
68        }
69        Self { mt, index: NN }
70    }
71
72    fn jump(&mut self) -> Option<()> {
73        // Polynomial-jump for MT19937-64 isn't included here; users that
74        // need stream parallelism should pick `Philox` (which has
75        // `stream`) or seed multiple `MT19937`s with disjoint seeds via
76        // SeedSequence.
77        None
78    }
79
80    fn stream(_seed: u64, _stream_id: u64) -> Option<Self> {
81        None
82    }
83}
84
85impl Clone for MT19937 {
86    fn clone(&self) -> Self {
87        Self {
88            mt: self.mt,
89            index: self.index,
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deterministic_output() {
        let mut a = MT19937::seed_from_u64(123);
        let mut b = MT19937::seed_from_u64(123);
        for _ in 0..512 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn different_seeds_differ() {
        let mut a = MT19937::seed_from_u64(1);
        let mut b = MT19937::seed_from_u64(2);
        let mut differ = false;
        for _ in 0..100 {
            if a.next_u64() != b.next_u64() {
                differ = true;
                break;
            }
        }
        assert!(differ);
    }

    /// Known-answer test against the reference generator.
    ///
    /// The C++ standard ([rand.predef]) requires the 10000th output of a
    /// default-constructed `std::mt19937_64` to be 9981545732273789042. That
    /// engine uses seed 5489 and the same parameters, seeding recurrence and
    /// tempering as this generator. The other tests only check determinism
    /// and coarse range; this one pins the actual output stream.
    #[test]
    fn matches_reference_known_answer() {
        let mut rng = MT19937::seed_from_u64(5489);
        let mut last = 0u64;
        for _ in 0..10_000 {
            last = rng.next_u64();
        }
        assert_eq!(last, 9_981_545_732_273_789_042);
    }

    /// A clone taken mid-block (past one refill) must continue the same stream,
    /// including across the next refill.
    #[test]
    fn clone_continues_identical_stream() {
        let mut a = MT19937::seed_from_u64(7);
        for _ in 0..500 {
            a.next_u64();
        }
        let mut b = a.clone();
        for _ in 0..400 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn jump_and_stream_are_unsupported() {
        let mut rng = MT19937::seed_from_u64(3);
        assert!(rng.jump().is_none());
        assert!(MT19937::stream(3, 1).is_none());
    }

    #[test]
    fn output_covers_full_range() {
        let mut rng = MT19937::seed_from_u64(0xdead_beef);
        let mut high = false;
        let mut low = false;
        for _ in 0..10_000 {
            let v = rng.next_u64();
            if v > u64::MAX / 2 {
                high = true;
            } else {
                low = true;
            }
            if high && low {
                break;
            }
        }
        assert!(high && low);
    }

    #[test]
    fn uniform_f64_in_unit_interval() {
        let mut rng = MT19937::seed_from_u64(42);
        for _ in 0..1000 {
            let x = rng.next_f64();
            assert!((0.0..1.0).contains(&x));
        }
    }
}
144            let x = rng.next_f64();
145            assert!((0.0..1.0).contains(&x));
146        }
147    }
148}