// src/mt.rs
//
// Copyright (c) 2015,2017 rust-mersenne-twister developers
// Copyright (c) 2020 Ryan Lopopolo <rjl@hyperbo.la>
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE> or <http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT> or <http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
use core::convert::TryFrom;
use core::fmt;
use core::mem::size_of;
use core::num::Wrapping;
use crate::RecoverRngError;
#[cfg(feature = "rand-traits")]
mod rand;
/// Number of 32-bit words in the generator state (the length of `state`).
const N: usize = 624;
/// Distance, in words, between a state word being regenerated and the state
/// word that is XORed into it by `fill_next_state`.
const M: usize = 397;
/// Mask selecting the least significant bit of a state word.
const ONE: Wrapping<u32> = Wrapping(1);
/// Value XORed into a regenerated word when the low bit of the combined word
/// is set; `fill_next_state` applies it as `(x & ONE) * MATRIX_A`.
const MATRIX_A: Wrapping<u32> = Wrapping(0x9908_b0df);
/// Mask selecting the most significant bit of a state word.
const UPPER_MASK: Wrapping<u32> = Wrapping(0x8000_0000);
/// Mask selecting the least significant 31 bits of a state word.
const LOWER_MASK: Wrapping<u32> = Wrapping(0x7fff_ffff);
/// The 32-bit flavor of the Mersenne Twister pseudorandom number
/// generator.
///
/// The official name of this RNG is `MT19937`. It natively outputs `u32`.
///
/// # Size
///
/// `Mt19937GenRand32` requires approximately 2.5 kilobytes of internal state.
///
/// You may wish to store an `Mt19937GenRand32` on the heap in a [`Box`] to make
/// it easier to embed in another struct.
///
/// `Mt19937GenRand32` is also the same size as
/// [`Mt19937GenRand64`](crate::Mt19937GenRand64).
///
/// ```
/// # use core::mem;
/// # use rand_mt::{Mt19937GenRand32, Mt19937GenRand64};
/// assert_eq!(2504, mem::size_of::<Mt19937GenRand32>());
/// assert_eq!(mem::size_of::<Mt19937GenRand64>(), mem::size_of::<Mt19937GenRand32>());
/// ```
#[cfg_attr(feature = "std", doc = "[`Box`]: std::boxed::Box")]
#[cfg_attr(
not(feature = "std"),
doc = "[`Box`]: https://doc.rust-lang.org/std/boxed/struct.Box.html"
)]
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[allow(clippy::module_name_repetitions)]
pub struct Mt19937GenRand32 {
// Position in `state` of the next word to be tempered and returned by
// `next_u32`. A value of `N` (or more) makes `next_u32` regenerate the whole
// state before producing output.
idx: usize,
// The untempered state words; outputs are `temper(state[idx])`.
state: [Wrapping<u32>; N],
}
impl fmt::Debug for Mt19937GenRand32 {
    /// Write a fixed placeholder instead of the generator's fields.
    ///
    /// Neither `idx` nor `state` is read here, so the formatted text is the
    /// same for every generator and reveals nothing about the seed or about
    /// the outputs produced so far.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "Mt19937GenRand32 {}";
        f.write_str(PLACEHOLDER)
    }
}
impl Default for Mt19937GenRand32 {
/// Return a new `Mt19937GenRand32` with the default seed.
///
/// Equivalent to calling [`Mt19937GenRand32::new_unseeded`], which seeds the
/// generator with [`Mt19937GenRand32::DEFAULT_SEED`].
#[inline]
fn default() -> Self {
// Delegate so that the default seed is chosen in exactly one place.
Self::new_unseeded()
}
}
impl From<[u8; 4]> for Mt19937GenRand32 {
    /// Construct a Mersenne Twister RNG from a 4-byte seed.
    ///
    /// The bytes are decoded as a little endian `u32`, which is then used as
    /// the seed.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rand_mt::Mt19937GenRand32;
    /// // Default MT seed
    /// let seed = 5489_u32.to_le_bytes();
    /// let mut mt = Mt19937GenRand32::from(seed);
    /// assert_ne!(mt.next_u32(), mt.next_u32());
    /// ```
    ///
    /// This constructor is equivalent to passing a little endian encoded `u32`.
    ///
    /// ```
    /// # use rand_mt::Mt19937GenRand32;
    /// // Default MT seed
    /// let seed = 5489_u32.to_le_bytes();
    /// let mt1 = Mt19937GenRand32::from(seed);
    /// let mt2 = Mt19937GenRand32::new(5489_u32);
    /// assert_eq!(mt1, mt2);
    /// ```
    #[inline]
    fn from(seed: [u8; 4]) -> Self {
        // Decode first, then seed exactly as `new` would.
        let seed = u32::from_le_bytes(seed);
        Self::new(seed)
    }
}
impl From<u32> for Mt19937GenRand32 {
/// Construct a Mersenne Twister RNG from a `u32` seed.
///
/// This function is equivalent to [`new`]: the returned generator compares
/// equal to one built by calling `new` with the same seed.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// // Default MT seed
/// let seed = 5489_u32;
/// let mt1 = Mt19937GenRand32::from(seed);
/// let mt2 = Mt19937GenRand32::new(seed);
/// assert_eq!(mt1, mt2);
///
/// // Non-default MT seed
/// let seed = 9927_u32;
/// let mt1 = Mt19937GenRand32::from(seed);
/// let mt2 = Mt19937GenRand32::new(seed);
/// assert_eq!(mt1, mt2);
/// ```
///
/// [`new`]: Self::new
#[inline]
fn from(seed: u32) -> Self {
Self::new(seed)
}
}
impl From<[u32; N]> for Mt19937GenRand32 {
    /// Rebuild a Mersenne Twister from its 624 most recent outputs.
    ///
    /// Each sample is passed through `untemper` to obtain the state word it
    /// was produced from. The returned RNG will produce identical output to
    /// the RNG that supplied the samples.
    #[inline]
    fn from(key: [u32; N]) -> Self {
        let mut state = [Wrapping(0); N];
        for (word, &sample) in state.iter_mut().zip(key.iter()) {
            *word = Wrapping(untemper(sample));
        }
        // `idx == N` makes the first `next_u32` call regenerate the state
        // before returning anything.
        Self { idx: N, state }
    }
}
impl TryFrom<&[u32]> for Mt19937GenRand32 {
type Error = RecoverRngError;
/// Attempt to recover the internal state of a Mersenne Twister using the
/// past 624 samples.
///
/// This conversion takes a history of samples from a RNG and returns a
/// RNG that will produce identical output to the RNG that supplied the
/// samples.
///
/// This conversion is implemented with [`Mt19937GenRand32::recover`].
///
/// # Errors
///
/// If `key` has fewer than 624 elements, an error is returned because there
/// is not enough data to fully initialize the RNG.
///
/// If `key` has more than 624 elements, an error is returned because the
/// recovered RNG will not produce identical output to the RNG that supplied
/// the samples.
#[inline]
fn try_from(key: &[u32]) -> Result<Self, Self::Error> {
// `recover` takes samples by value, so copy them out of the slice.
Self::recover(key.iter().copied())
}
}
impl Mt19937GenRand32 {
/// Default seed used by [`Mt19937GenRand32::new_unseeded`].
///
/// The `Default` implementation uses the same seed because it delegates to
/// `new_unseeded`.
pub const DEFAULT_SEED: u32 = 5489_u32;
/// Create a new Mersenne Twister random number generator from a `u32`
/// seed.
///
/// The state is derived from `seed` by [`reseed`](Self::reseed).
///
/// # Examples
///
/// ## Constructing with a `u32` seed
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let seed = 123_456_789_u32;
/// let mt1 = Mt19937GenRand32::new(seed);
/// let mt2 = Mt19937GenRand32::from(seed.to_le_bytes());
/// assert_eq!(mt1, mt2);
/// ```
///
/// ## Constructing with default seed
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let mt1 = Mt19937GenRand32::new(Mt19937GenRand32::DEFAULT_SEED);
/// let mt2 = Mt19937GenRand32::new_unseeded();
/// assert_eq!(mt1, mt2);
/// ```
#[inline]
#[must_use]
pub fn new(seed: u32) -> Self {
    // `reseed` overwrites `idx` and every state word, so these placeholder
    // values are never observable.
    let mut rng = Self {
        state: [Wrapping(0); N],
        idx: 0,
    };
    rng.reseed(seed);
    rng
}
/// Create a new Mersenne Twister random number generator from a key made
/// of `u32`s.
///
/// The state is derived from `key` by
/// [`reseed_with_key`](Self::reseed_with_key).
///
/// Key can have any length. Keys shorter than 624 elements are repeated
/// cyclically; if the key has more than 624 elements, only the first 624
/// are consumed.
#[inline]
#[must_use]
pub fn new_with_key<I>(key: I) -> Self
where
    I: IntoIterator<Item = u32>,
    I::IntoIter: Clone,
{
    // `reseed_with_key` overwrites `idx` and every state word, so these
    // placeholder values are never observable.
    let mut rng = Self {
        state: [Wrapping(0); N],
        idx: 0,
    };
    rng.reseed_with_key(key);
    rng
}
/// Create a new Mersenne Twister random number generator using the default
/// fixed seed.
///
/// The seed is [`Mt19937GenRand32::DEFAULT_SEED`], so every generator
/// created this way starts in the same state.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// // Default MT seed
/// let seed = 5489_u32;
/// let mt = Mt19937GenRand32::new(seed);
/// let unseeded = Mt19937GenRand32::new_unseeded();
/// assert_eq!(mt, unseeded);
/// ```
#[inline]
#[must_use]
pub fn new_unseeded() -> Self {
Self::new(Self::DEFAULT_SEED)
}
/// Generate next `u64` output.
///
/// Two `u32`s are drawn from the RNG. The first becomes the upper 32 bits
/// of the result and the second becomes the lower 32 bits.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let mut mt = Mt19937GenRand32::new_unseeded();
/// assert_ne!(mt.next_u64(), mt.next_u64());
/// ```
#[inline]
pub fn next_u64(&mut self) -> u64 {
    // Bind each draw to a name so the order of the two draws is explicit.
    let high = u64::from(self.next_u32());
    let low = u64::from(self.next_u32());
    (high << 32) | low
}
/// Generate next `u32` output.
///
/// `u32` is the native output of the generator. This function advances the
/// RNG step counter by one.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let mut mt = Mt19937GenRand32::new_unseeded();
/// assert_ne!(mt.next_u32(), mt.next_u32());
/// ```
#[inline]
pub fn next_u32(&mut self) -> u32 {
// Failing this check indicates that, somehow, the structure
// was not initialized.
debug_assert!(self.idx != 0);
if self.idx >= N {
fill_next_state(self);
}
let Wrapping(x) = self.state[self.idx];
self.idx += 1;
temper(x)
}
/// Fill a buffer with bytes generated from the RNG.
///
/// `dest` is filled four bytes at a time with the little endian encoding of
/// successive `u32` outputs (the native output unit of the RNG).
///
/// If `dest.len()` is not a multiple of 4, one extra `u32` is generated for
/// the trailing bytes and its unused high-order bytes are discarded.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let mut mt = Mt19937GenRand32::new_unseeded();
/// let mut buf = [0; 32];
/// mt.fill_bytes(&mut buf);
/// assert_ne!([0; 32], buf);
/// let mut buf = [0; 31];
/// mt.fill_bytes(&mut buf);
/// assert_ne!([0; 31], buf);
/// ```
#[inline]
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
    const WORD: usize = size_of::<u32>();
    let mut words = dest.chunks_exact_mut(WORD);
    for slot in &mut words {
        let bytes: [u8; WORD] = self.next_u32().to_le_bytes();
        slot.copy_from_slice(&bytes);
    }
    let tail = words.into_remainder();
    if !tail.is_empty() {
        // The tail is shorter than one word, so only a prefix of the final
        // output is copied.
        let bytes = self.next_u32().to_le_bytes();
        tail.copy_from_slice(&bytes[..tail.len()]);
    }
}
/// Attempt to recover the internal state of a Mersenne Twister using the
/// past 624 samples.
///
/// This conversion takes a history of samples from a RNG and returns a
/// RNG that will produce identical output to the RNG that supplied the
/// samples.
///
/// This constructor is also available as a [`TryFrom`] implementation for
/// `&[u32]`.
///
/// # Errors
///
/// If `key` has fewer than 624 elements, an error is returned because there
/// is not enough data to fully initialize the RNG.
///
/// If `key` has more than 624 elements, an error is returned because the
/// recovered RNG will not produce identical output to the RNG that supplied
/// the samples.
#[inline]
pub fn recover<I>(key: I) -> Result<Self, RecoverRngError>
where
    I: IntoIterator<Item = u32>,
{
    let mut mt = Self {
        idx: N,
        state: [Wrapping(0); N],
    };
    let mut slots = mt.state.iter_mut();
    for sample in key {
        match slots.next() {
            Some(slot) => *slot = Wrapping(untemper(sample)),
            // Every state word is already filled: one sample too many.
            None => return Err(RecoverRngError::TooManySamples(N)),
        }
    }
    // A state word that is still unfilled means the input ran out early.
    let filled_completely = slots.next().is_none();
    if filled_completely {
        Ok(mt)
    } else {
        Err(RecoverRngError::TooFewSamples(N))
    }
}
/// Reseed a Mersenne Twister from a single `u32`.
///
/// Every state word is overwritten and the output position is reset, so the
/// generator behaves as if it had been freshly created with
/// [`new`](Self::new) from the same seed.
///
/// # Examples
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// // Default MT seed
/// let mut mt = Mt19937GenRand32::new_unseeded();
/// let first = mt.next_u32();
/// mt.fill_bytes(&mut [0; 512]);
/// // Default MT seed
/// mt.reseed(5489_u32);
/// assert_eq!(first, mt.next_u32());
/// ```
#[inline]
#[allow(clippy::cast_possible_truncation)]
pub fn reseed(&mut self, seed: u32) {
    // Force a state regeneration on the next output.
    self.idx = N;
    let mut prev = Wrapping(seed);
    self.state[0] = prev;
    for i in 1..N {
        // `i` is below `N` (624), so the cast to `u32` cannot truncate.
        prev = Wrapping(1_812_433_253) * (prev ^ (prev >> 30)) + Wrapping(i as u32);
        self.state[i] = prev;
    }
}
/// Reseed a Mersenne Twister from an iterator of `u32`s.
///
/// The state is first initialized with [`reseed`](Self::reseed) using a
/// fixed seed. The key is then mixed into it over 624 steps, and a final
/// mixing pass of 623 steps follows.
///
/// Key can have any length. A key shorter than 624 elements is repeated
/// cyclically. If the key has more than 624 elements, only the first 624
/// are consumed. An empty key skips the key-mixing pass entirely.
#[inline]
#[allow(clippy::cast_possible_truncation)]
pub fn reseed_with_key<I>(&mut self, key: I)
where
    I: IntoIterator<Item = u32>,
    I::IntoIter: Clone,
{
    self.reseed(19_650_218_u32);
    let state = &mut self.state;
    let mut i = 1_usize;
    // Pass 1: fold the key into the state. `j` is the position of `piece`
    // within the key and restarts at zero on every cycle through the key.
    for (j, piece) in key.into_iter().enumerate().cycle().take(N) {
        let prev = state[i - 1];
        let mixed = (prev ^ (prev >> 30)) * Wrapping(1_664_525);
        state[i] = (state[i] ^ mixed) + Wrapping(piece) + Wrapping(j as u32);
        i += 1;
        if i >= N {
            // Wrap around: word 0 takes the last word's value and mixing
            // continues from word 1.
            state[0] = state[N - 1];
            i = 1;
        }
    }
    // Pass 2: `i` stays below `N` (624), so the cast cannot truncate.
    for _ in 0..N - 1 {
        let prev = state[i - 1];
        let mixed = (prev ^ (prev >> 30)) * Wrapping(1_566_083_941);
        state[i] = (state[i] ^ mixed) - Wrapping(i as u32);
        i += 1;
        if i >= N {
            state[0] = state[N - 1];
            i = 1;
        }
    }
    // Set the top bit of word 0, which guarantees a non-zero state.
    state[0] = Wrapping(1 << 31);
}
}
/// Scramble a raw state word into an output value.
///
/// Applies four XOR-shift steps in sequence: right by 11, left by 7 under a
/// mask, left by 15 under a mask, and right by 18. `untemper` undoes them.
#[inline]
fn temper(x: u32) -> u32 {
    let x = x ^ (x >> 11);
    let x = x ^ ((x << 7) & 0x9d2c_5680);
    let x = x ^ ((x << 15) & 0xefc6_0000);
    x ^ (x >> 18)
}
/// Recover the raw state word from a tempered output value.
///
/// The four steps of `temper` are undone in reverse order. A masked left
/// shift is undone in several partial steps, each restoring the bits selected
/// by one sub-mask; the right shift by 11 needs two steps because 11 is less
/// than half the word width.
#[inline]
fn untemper(x: u32) -> u32 {
    // Sub-masks that together undo `x ^= (x << 15) & 0xefc6_0000`.
    const UNDO_SHIFT_15: [u32; 2] = [0x2fc6_0000, 0xc000_0000];
    // Sub-masks that together undo `x ^= (x << 7) & 0x9d2c_5680`.
    const UNDO_SHIFT_7: [u32; 4] = [0x0000_1680, 0x000c_4000, 0x0d20_0000, 0x9000_0000];

    // Undo `x ^= x >> 18`.
    let mut y = x ^ (x >> 18);
    for &mask in &UNDO_SHIFT_15 {
        y ^= (y << 15) & mask;
    }
    for &mask in &UNDO_SHIFT_7 {
        y ^= (y << 7) & mask;
    }
    // Undo `x ^= x >> 11`.
    y ^= y >> 11;
    y ^= y >> 22;
    y
}
#[inline]
fn fill_next_state(rng: &mut Mt19937GenRand32) {
for i in 0..N - M {
let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
rng.state[i] = rng.state[i + M] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
}
for i in N - M..N - 1 {
let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
rng.state[i] = rng.state[i + M - N] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
}
let x = (rng.state[N - 1] & UPPER_MASK) | (rng.state[0] & LOWER_MASK);
rng.state[N - 1] = rng.state[M - 1] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
rng.idx = 0;
}
#[cfg(test)]
mod tests {
use core::convert::TryFrom;
use core::iter;
use core::num::Wrapping;
use super::{Mt19937GenRand32, N};
use crate::vectors::mt::{STATE_SEEDED_BY_SLICE, STATE_SEEDED_BY_U32, TEST_OUTPUT};
use crate::RecoverRngError;
// Seeding via `new` and via `From<[u8; 4]>` must produce the same state, and
// that state must match the `STATE_SEEDED_BY_U32` vector.
#[test]
fn seeded_state_from_u32_seed() {
let mt = Mt19937GenRand32::new(0x1234_5678_u32);
let mt_from_seed = Mt19937GenRand32::from(0x1234_5678_u32.to_le_bytes());
assert_eq!(mt.state, mt_from_seed.state);
for (&Wrapping(x), &y) in mt.state.iter().zip(STATE_SEEDED_BY_U32.iter()) {
assert_eq!(x, y);
}
for (&Wrapping(x), &y) in mt_from_seed.state.iter().zip(STATE_SEEDED_BY_U32.iter()) {
assert_eq!(x, y);
}
}
// Seeding with a four-word key must match the `STATE_SEEDED_BY_SLICE` vector.
#[test]
fn seeded_state_from_u32_slice_key() {
let key = [0x123_u32, 0x234_u32, 0x345_u32, 0x456_u32];
let mt = Mt19937GenRand32::new_with_key(key.iter().copied());
for (&Wrapping(x), &y) in mt.state.iter().zip(STATE_SEEDED_BY_SLICE.iter()) {
assert_eq!(x, y);
}
}
// An empty key must not hang or panic; the resulting state is not inspected.
#[test]
fn seed_with_empty_iter_returns() {
let _ = Mt19937GenRand32::new_with_key(iter::empty());
}
// Outputs of a key-seeded generator must match the `TEST_OUTPUT` vector.
#[test]
fn output_from_u32_slice_key() {
let key = [0x123_u32, 0x234_u32, 0x345_u32, 0x456_u32];
let mut mt = Mt19937GenRand32::new_with_key(key.iter().copied());
for &x in TEST_OUTPUT.iter() {
assert_eq!(x, mt.next_u32());
}
}
// `untemper(temper(x)) == x` for randomly drawn `x`, read in both byte orders.
#[test]
fn temper_untemper_is_identity() {
let mut buf = [0; 4];
for _ in 0..10_000 {
getrandom::getrandom(&mut buf).unwrap();
let x = u32::from_le_bytes(buf);
assert_eq!(x, super::untemper(super::temper(x)));
let x = u32::from_be_bytes(buf);
assert_eq!(x, super::untemper(super::temper(x)));
}
}
// `temper(untemper(x)) == x` for randomly drawn `x`, read in both byte orders.
#[test]
fn untemper_temper_is_identity() {
let mut buf = [0; 4];
for _ in 0..10_000 {
getrandom::getrandom(&mut buf).unwrap();
let x = u32::from_le_bytes(buf);
assert_eq!(x, super::temper(super::untemper(x)));
let x = u32::from_be_bytes(buf);
assert_eq!(x, super::temper(super::untemper(x)));
}
}
// A generator rebuilt with `From<[u32; N]>` from 624 consecutive outputs must
// track the original for the following 1248 outputs, for random seeds and for
// every skip count from 0 to 255.
#[test]
fn recovery_via_from() {
let mut buf = [0; 4];
for _ in 0..100 {
getrandom::getrandom(&mut buf).unwrap();
let seed = u32::from_le_bytes(buf);
for skip in 0..256 {
let mut orig_mt = Mt19937GenRand32::new(seed);
// skip some samples so the RNG is in an intermediate state
for _ in 0..skip {
orig_mt.next_u32();
}
let mut samples = [0; 624];
for sample in samples.iter_mut() {
*sample = orig_mt.next_u32();
}
let mut recovered_mt = Mt19937GenRand32::from(samples);
for _ in 0..624 * 2 {
assert_eq!(orig_mt.next_u32(), recovered_mt.next_u32());
}
}
}
}
// Same scenario as `recovery_via_from`, but rebuilding through `recover`.
#[test]
fn recovery_via_recover() {
let mut buf = [0; 4];
for _ in 0..100 {
getrandom::getrandom(&mut buf).unwrap();
let seed = u32::from_le_bytes(buf);
for skip in 0..256 {
let mut orig_mt = Mt19937GenRand32::new(seed);
// skip some samples so the RNG is in an intermediate state
for _ in 0..skip {
orig_mt.next_u32();
}
let mut samples = [0; 624];
for sample in samples.iter_mut() {
*sample = orig_mt.next_u32();
}
let mut recovered_mt = Mt19937GenRand32::recover(samples.iter().copied()).unwrap();
for _ in 0..624 * 2 {
assert_eq!(orig_mt.next_u32(), recovered_mt.next_u32());
}
}
}
}
// `TryFrom<&[u32]>` accepts exactly 624 samples and reports too-short and
// too-long inputs with the matching error variant.
#[test]
fn recover_required_exact_sample_length_via_from() {
assert_eq!(
Mt19937GenRand32::try_from(&[0; 0][..]),
Err(RecoverRngError::TooFewSamples(N))
);
assert_eq!(
Mt19937GenRand32::try_from(&[0; 1][..]),
Err(RecoverRngError::TooFewSamples(N))
);
assert_eq!(
Mt19937GenRand32::try_from(&[0; 623][..]),
Err(RecoverRngError::TooFewSamples(N))
);
Mt19937GenRand32::try_from(&[0; 624][..]).unwrap();
assert_eq!(
Mt19937GenRand32::try_from(&[0; 625][..]),
Err(RecoverRngError::TooManySamples(N))
);
assert_eq!(
Mt19937GenRand32::try_from(&[0; 1000][..]),
Err(RecoverRngError::TooManySamples(N))
);
}
// Same length checks as above, calling `recover` directly with an iterator.
#[test]
fn recover_required_exact_sample_length_via_recover() {
assert_eq!(
Mt19937GenRand32::recover([0; 0].iter().copied()),
Err(RecoverRngError::TooFewSamples(N))
);
assert_eq!(
Mt19937GenRand32::recover([0; 1].iter().copied()),
Err(RecoverRngError::TooFewSamples(N))
);
assert_eq!(
Mt19937GenRand32::recover([0; 623].iter().copied()),
Err(RecoverRngError::TooFewSamples(N))
);
Mt19937GenRand32::recover([0; 624].iter().copied()).unwrap();
assert_eq!(
Mt19937GenRand32::recover([0; 625].iter().copied()),
Err(RecoverRngError::TooManySamples(N))
);
assert_eq!(
Mt19937GenRand32::recover([0; 1000].iter().copied()),
Err(RecoverRngError::TooManySamples(N))
);
}
// The `Debug` output is a fixed placeholder and must not contain the seed.
#[test]
#[cfg(feature = "std")]
fn fmt_debug_does_not_leak_seed() {
use core::fmt::Write as _;
use std::string::String;
let random = Mt19937GenRand32::new(874);
let mut buf = String::new();
write!(&mut buf, "{:?}", random).unwrap();
assert!(!buf.contains("874"));
assert_eq!(buf, "Mt19937GenRand32 {}");
let random = Mt19937GenRand32::new(123_456);
let mut buf = String::new();
write!(&mut buf, "{:?}", random).unwrap();
assert!(!buf.contains("123456"));
assert_eq!(buf, "Mt19937GenRand32 {}");
}
}