use crate::rand::{Seed, Result, PrimitiveType, RandError, RandErrKind, Source};
use std::ops::Shl;
use std::fmt::Debug;
/// Mersenne Twister pseudo-random number generator, parameterised over its
/// word type `T` and over every constant of the algorithm.
///
/// The short letters in the field docs (`w`, `n`, `m`, ...) are the parameter
/// names taken by [`MersenneTwisterRand::new`], which fills these fields.
#[derive(Clone)]
pub struct MersenneTwisterRand<T> {
    /// `w`: number of bits of each state word that are actually used.
    word_size: T,
    /// `n`: number of words in the state vector.
    state_size: T,
    /// `m`: offset of the "far" word combined into each word during a twist.
    shift_size: T,
    /// `r`: number of low bits selected by the lower mask during a twist.
    mask_bits: T,
    /// `a`: value XOR-ed in during a twist when the combined word is odd.
    xor_mask: T,
    /// `u`: shift amount of the first tempering step.
    tempering_u: T,
    /// `d`: bit mask of the first tempering step.
    tempering_d: T,
    /// `s`: shift amount of the second tempering step.
    tempering_s: T,
    /// `b`: bit mask of the second tempering step.
    tempering_b: T,
    /// `t`: shift amount of the third tempering step.
    tempering_t: T,
    /// `c`: bit mask of the third tempering step.
    tempering_c: T,
    /// `l`: shift amount of the final tempering step.
    shift_l: T,
    /// `f`: multiplier used while expanding the seed into the state vector.
    init_multiplier: T,
    /// Index of the next state word to hand out; a value of `state_size` or
    /// more means the state must be regenerated first.
    idx: T,
    /// Seed together with its initialisation status: `Err(seed)` while the
    /// state has not yet been built from `seed`, `Ok(seed)` once it has.
    seed: std::result::Result<T, T>,
    /// The state words; empty until the seed has been expanded.
    state: Vec<T>,
}
impl<T> MersenneTwisterRand<T>
    where T: PrimitiveType + PartialOrd<T> + Shl<Output = T> + Default + Debug + Copy {
    /// Creates a Mersenne Twister generator from a seed source and a full
    /// parameter set.
    ///
    /// The seed is only fetched and stored here; the state vector is left
    /// empty and is expanded from the seed by `check_init` on first use.
    ///
    /// # Parameters
    ///
    /// * `sd` - seed source; `sd.seed()` is called once.
    /// * `w` - word size in bits, at most `T::bits_len()`.
    /// * `n` - number of words in the state vector.
    /// * `m` - twist offset, `m <= n`.
    /// * `r` - number of bits covered by the lower mask, `r < w`.
    /// * `a` - twist XOR value.
    /// * `u_d`, `s_b`, `t_c` - `(shift, mask)` pairs of the three masked
    ///   tempering steps; every shift must be `< w`.
    /// * `l` - shift of the final tempering step, `l < w`.
    /// * `f` - multiplier used when expanding the seed.
    ///
    /// `a`, `d`, `b`, `c` and `f` must not exceed `w.mask()`.
    ///
    /// # Errors
    ///
    /// Returns a `RandErrKind::InvalidRngPara` error when
    /// * any of `w`, `m`, `r`, `a`, `l`, `f` equals `T::default()`, or one of
    ///   the pairs `u_d`, `s_b`, `t_c` has both members equal to it;
    /// * one of the range relations listed above is violated.
    ///
    /// An error returned by `sd.seed()` is passed through unchanged.
    pub fn new<Sd: Seed<T>>(sd: &Sd, w: T, n: T, m: T, r: T, a: T,
                            u_d: (T, T), s_b: (T, T), t_c: (T, T), l: T, f: T) -> Result<Self> {
        let zero = T::default();
        let zero_pair = (zero, zero);

        if w == zero || m == zero || r == zero || a == zero ||
            u_d == zero_pair || s_b == zero_pair || t_c == zero_pair ||
            l == zero || f == zero {
            return Err(RandError::new(RandErrKind::InvalidRngPara,
                                      format!("all parameters cannot be the {:?}", zero)));
        }
        if m > n {
            return Err(RandError::new(RandErrKind::InvalidRngPara,
                                      "m must be satisfy the relation: m <= n"));
        }
        if w > T::bits_len() {
            // The bound being reported is the one on `w`, so name `w` in the relation.
            return Err(RandError::new(RandErrKind::InvalidRngPara,
                                      format!("w must be satisfy the relation: w <= {:?}", T::bits_len())));
        }

        // `w` is known to fit the type from here on, so the word mask can be
        // computed once and shared by all the value checks.
        let word_mask = w.mask();
        let range_checks = [
            (r >= w, "r must be less than w"),
            (u_d.0 >= w, "u must be less than w"),
            (s_b.0 >= w, "s must be less than w"),
            (t_c.0 >= w, "t must be less than w"),
            (l >= w, "l must be less than w"),
            (a > word_mask, "a must be less than or equal to ((1<<w) - 1)"),
            (u_d.1 > word_mask, "d must be less than or equal to ((1<<w) - 1)"),
            (s_b.1 > word_mask, "b must be less than or equal to ((1<<w) - 1)"),
            (t_c.1 > word_mask, "c must be less than or equal to ((1<<w) - 1)"),
            (f > word_mask, "f must be less than or equal to ((1<<w) - 1)"),
        ];
        // Report the first violated relation, in the order listed above.
        if let Some(&(_, msg)) = range_checks.iter().find(|check| check.0) {
            return Err(RandError::new(RandErrKind::InvalidRngPara, msg));
        }

        sd.seed().map(|seed| Self {
            word_size: w,
            state_size: n,
            shift_size: m,
            mask_bits: r,
            xor_mask: a,
            tempering_u: u_d.0,
            tempering_d: u_d.1,
            tempering_s: s_b.0,
            tempering_b: s_b.1,
            tempering_t: t_c.0,
            tempering_c: t_c.1,
            shift_l: l,
            init_multiplier: f,
            // `Err` marks the seed as not yet expanded into `state`.
            seed: Err(seed),
            // Start past the end so the first draw regenerates the state.
            idx: n,
            state: Vec::new(),
        })
    }
}
/// Generates, for one concrete unsigned word type, the private state handling
/// of `MersenneTwisterRand` (`check_init`, `gen_rand`) and its `Source`
/// implementation.
macro_rules! mtr_impl {
    ($Type0: ty) => {
        impl MersenneTwisterRand<$Type0> {
            /// Expands a pending seed (`self.seed == Err(seed)`) into the full
            /// state vector and marks it as consumed (`Ok(seed)`).
            ///
            /// Does nothing when the state has already been initialised.
            fn check_init(&mut self) {
                if let Err(seed) = self.seed {
                    let n = self.state_size as usize;
                    let word_mask = self.word_size.mask();

                    self.state.clear();
                    // The vector is empty at this point, so this makes room
                    // for all `n` words in a single allocation.
                    self.state.reserve(n);

                    let mut last = seed & word_mask;
                    self.state.push(last);
                    for i in 1..(n as $Type0) {
                        // x[i] = f * (x[i-1] ^ (x[i-1] >> (w - 2))) + i, kept to w bits.
                        last ^= last >> (self.word_size - 2);
                        last = last.wrapping_mul(self.init_multiplier).wrapping_add(i) & word_mask;
                        self.state.push(last);
                    }

                    // Force a twist before the first word is handed out.
                    self.idx = self.state_size;
                    self.seed = Ok(seed);
                }
            }

            /// Regenerates the whole state vector in place (the "twist") and
            /// rewinds the read index to zero.
            fn gen_rand(&mut self) {
                let upper_mask: $Type0 = (!(<$Type0>::default())) << self.mask_bits;
                let lower_mask: $Type0 = !upper_mask;
                let xor_mask = self.xor_mask;
                let m = self.shift_size as usize;
                // Number of words whose "far" partner `i + m` is still inside the vector.
                let k = (self.state_size - self.shift_size) as usize;
                let last = (self.state_size - 1) as usize;

                // Combines the high bits of `cur` with the low bits of `next`
                // and folds the result into the `far` word.
                let twist = |cur: $Type0, next: $Type0, far: $Type0| -> $Type0 {
                    let y = (cur & upper_mask) | (next & lower_mask);
                    far ^ (y >> 1) ^ (if (y & 0x1) > 0 { xor_mask } else { 0 })
                };

                for i in 0..k {
                    self.state[i] = twist(self.state[i], self.state[i + 1], self.state[i + m]);
                }
                // Here `i + m` would run off the end, so the partner wraps
                // around to the already regenerated words at the front.
                for i in k..last {
                    self.state[i] = twist(self.state[i], self.state[i + 1], self.state[i - k]);
                }
                // The final word pairs with the first one.
                self.state[last] = twist(self.state[last], self.state[0], self.state[m - 1]);

                self.idx = 0;
            }
        }

        impl Source<$Type0> for MersenneTwisterRand<$Type0> {
            /// Returns the next tempered word, initialising or regenerating
            /// the state first when necessary. Always returns `Ok`.
            fn gen(&mut self) -> Result<$Type0> {
                self.check_init();
                if self.idx >= self.state_size {
                    self.gen_rand();
                }

                let mut z = self.state[self.idx as usize];
                // Tempering: the `u` and `l` steps shift right, the `s` and
                // `t` steps shift left. Bits pushed above the word size by a
                // left shift are removed by the `b` / `c` masks.
                z ^= (z >> self.tempering_u) & self.tempering_d;
                z ^= (z << self.tempering_s) & self.tempering_b;
                z ^= (z << self.tempering_t) & self.tempering_c;
                z ^= z >> self.shift_l;

                self.idx += 1;
                Ok(z)
            }

            /// Fetches a new seed from `sd` and rebuilds the state from it.
            ///
            /// If `sd.seed()` fails, its error is returned and the generator
            /// is left untouched.
            fn reset<Sd: Seed<$Type0>>(&mut self, sd: &Sd) -> Result<()> {
                sd.seed().map(|seed| {
                    self.seed = Err(seed);
                    self.check_init();
                })
            }
        }
    };
}
// Instantiate `check_init`, `gen_rand` and the `Source` implementation of
// `MersenneTwisterRand` for each supported word type.
mtr_impl!(u32);
mtr_impl!(usize);
mtr_impl!(u64);
/// Builds a 32-bit Mersenne Twister with the MT19937 parameter set.
///
/// `$Sd` is any expression yielding a seed source; it is passed by reference
/// to `MersenneTwisterRand::new`. The macro evaluates to the `Result`
/// returned by that constructor.
///
/// Parameters used: `w = 32`, `n = 624`, `m = 397`, `r = 31`,
/// `a = 0x9908b0df`, `(u, d) = (11, 0xffffffff)`, `(s, b) = (7, 0x9d2c5680)`,
/// `(t, c) = (15, 0xefc60000)`, `l = 18`, `f = 1812433253`.
#[macro_export]
macro_rules! mt19937 {
    ($Sd: expr) => {
        // `$crate` keeps the path valid inside this crate and when the
        // dependency is renamed by the caller.
        $crate::rand::MersenneTwisterRand::new(&$Sd, 32u32, 624u32, 397u32, 31u32, 0x9908b0dfu32,
            (11u32, 0xffffffffu32), (7u32, 0x9d2c5680u32), (15u32, 0xefc60000u32),
            18u32, 1812433253u32)
    };
}
/// Builds a 64-bit Mersenne Twister with the MT19937-64 parameter set.
///
/// `$Sd` is any expression yielding a seed source; it is passed by reference
/// to `MersenneTwisterRand::new`. The macro evaluates to the `Result`
/// returned by that constructor.
///
/// Parameters used: `w = 64`, `n = 312`, `m = 156`, `r = 31`,
/// `a = 0xb5026f5aa96619e9`, `(u, d) = (29, 0x5555555555555555)`,
/// `(s, b) = (17, 0x71d67fffeda60000)`, `(t, c) = (37, 0xfff7eee000000000)`,
/// `l = 43`, `f = 6364136223846793005`.
#[macro_export]
macro_rules! mt19937_64 {
    ($Sd: expr) => {
        // `$crate` keeps the path valid inside this crate and when the
        // dependency is renamed by the caller.
        $crate::rand::MersenneTwisterRand::new(&$Sd, 64u64, 312u64, 156u64, 31u64, 0xb5026f5aa96619e9u64,
            (29u64, 0x5555555555555555u64), (17u64, 0x71d67fffeda60000u64), (37u64, 0xfff7eee000000000u64),
            43u64, 6364136223846793005u64)
    };
}