use core::ops::{Add, BitAnd, BitOr, BitXor, Mul, Not, Shl, Shr, Sub};
use wrapping_arithmetic::wrappit;
/// Minimal unsigned-integer abstraction used by the generic LCG helpers in this module.
///
/// Implementors supply comparison, bitwise and shift operators (used to walk the
/// bits of a step count) together with explicit wrapping arithmetic, so generic
/// code can compute modulo `2^bits` without overflow panics.
///
/// Only `Ord` is listed among the comparison traits: it already implies `Eq`,
/// `PartialEq` and `PartialOrd`, so `T: Int` still provides all four.
pub trait Int:
    Copy
    + Ord
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Not<Output = Self>
    + BitAnd<Output = Self>
    + BitOr<Output = Self>
    + BitXor<Output = Self>
    + Shl<usize, Output = Self>
    + Shr<usize, Output = Self>
{
    /// The additive identity, `0`.
    fn zero() -> Self;
    /// The multiplicative identity, `1`.
    fn one() -> Self;
    /// `self + other`, wrapping around at the type's boundary instead of overflowing.
    fn wrapping_add(self, other: Self) -> Self;
    /// `self - other`, wrapping around at the type's boundary instead of overflowing.
    fn wrapping_sub(self, other: Self) -> Self;
    /// `self * other`, wrapping around at the type's boundary instead of overflowing.
    fn wrapping_mul(self, other: Self) -> Self;
}
/// Implements [`Int`] for each listed primitive type.
///
/// Every method forwards to the primitive's own inherent function of the same
/// name. The `<$prim>::` path resolves to the inherent function, not back to the
/// trait method, so there is no recursion.
macro_rules! impl_int {
    ( $($prim:ty),* ) => {
        $(
            impl Int for $prim {
                #[inline]
                fn zero() -> Self {
                    0
                }

                #[inline]
                fn one() -> Self {
                    1
                }

                #[inline]
                fn wrapping_add(self, other: Self) -> Self {
                    <$prim>::wrapping_add(self, other)
                }

                #[inline]
                fn wrapping_sub(self, other: Self) -> Self {
                    <$prim>::wrapping_sub(self, other)
                }

                #[inline]
                fn wrapping_mul(self, other: Self) -> Self {
                    <$prim>::wrapping_mul(self, other)
                }
            }
        )*
    };
}

// All unsigned fixed-width integer types get an `Int` implementation.
impl_int! { u8, u16, u32, u64, u128 }
/// Computes the affine map that advances the LCG `x -> x * m + p` by `n` steps.
///
/// Returns `(mult, add)` such that `n` applications of the generator to any `x`
/// give `x * mult + add`, with all arithmetic wrapping in `T`.
///
/// Uses binary decomposition of `n`. `(stride_m, stride_p)` always describes a
/// jump of `2^k` steps, and it is folded into the accumulated map whenever bit `k`
/// of `n` is set. The loop runs once per bit of `n`.
#[wrappit]
pub fn get_jump<T: Int>(m: T, p: T, n: T) -> (T, T) {
    // Identity map: zero steps.
    let (mut acc_m, mut acc_p) = (T::one(), T::zero());
    // Map for a single step: 2^0.
    let (mut stride_m, mut stride_p) = (m, p);
    let mut remaining = n;

    while remaining > T::zero() {
        let bit_is_set = remaining & T::one() == T::one();
        if bit_is_set {
            // Compose: apply the accumulated map first, then the current stride.
            acc_m = acc_m * stride_m;
            acc_p = acc_p * stride_m + stride_p;
        }
        // Double the stride. Applying x*M + P twice yields x*M^2 + (M + 1)*P.
        // `stride_p` must be updated while `stride_m` still holds M.
        stride_p = (stride_m + T::one()) * stride_p;
        stride_m = stride_m * stride_m;
        remaining = remaining >> 1;
    }

    (acc_m, acc_p)
}
/// Counts how many steps of the LCG `x -> x * m + p` lead from `origin` to `state`.
///
/// All arithmetic wraps in `T`, i.e. it is modulo `2^bits`.
///
/// The distance is built one bit at a time, least significant first. In round
/// `k`, `(jump_m, jump_p)` describes a jump of `2^k` steps. If bit `k` of the
/// current position differs from bit `k` of `state`, that jump is taken and `2^k`
/// is added to the count.
///
/// The method relies on the generator having full period (for a power-of-two
/// modulus: `p` odd and `m % 4 == 1`). Under that condition a `2^k`-step jump
/// leaves bits below `k` unchanged and flips bit `k`. Bits already matched
/// therefore stay matched, and the loop ends after at most `bits` rounds.
///
/// The returned count is the unique `n` in `T` such that advancing `origin` by
/// `n` steps yields `state`. It is `0` when `origin == state`.
///
/// # Panics
///
/// Panics if `state` cannot be reached from `origin`. This can only happen when
/// `(m, p)` does not give a full-period generator.
#[wrappit]
pub fn get_iterations<T: Int>(m: T, p: T, origin: T, state: T) -> T {
    let mut jump_m = m;
    let mut jump_p = p;
    let mut ordinal = T::zero();
    let mut bit = T::one();
    let mut address = origin;
    while address != state {
        // Once `bit` has shifted out of the type, every bit of the distance has
        // already been decided. If the positions still disagree, `state` lies on a
        // different cycle than `origin`. Without this guard the loop would spin
        // forever, because `bit & x` is then always zero and `address` never moves.
        assert!(
            bit != T::zero(),
            "get_iterations: state is not reachable from origin; the LCG (m, p) does not have full period"
        );
        if (bit & address) != (bit & state) {
            address = address * jump_m + jump_p;
            ordinal = ordinal + bit;
        }
        // Double the jump from 2^k to 2^(k+1) steps: x*M + P applied twice is
        // x*M^2 + (M + 1)*P. `jump_p` must be updated before `jump_m` changes.
        jump_p = (jump_m + T::one()) * jump_p;
        jump_m *= jump_m;
        bit = bit << 1;
    }
    ordinal
}
/// Returns the LCG state reached from `origin` after `iterations` steps of
/// `x -> x * m + p`, with all arithmetic wrapping in `T`.
///
/// Takes time proportional to the bit length of `iterations`, not its value. The
/// pair `(step_m, step_p)` describes a jump of `2^k` steps. It is applied to the
/// running state whenever bit `k` of the count is set, then doubled for the next
/// bit.
#[wrappit]
pub fn get_state<T: Int>(m: T, p: T, origin: T, iterations: T) -> T {
    let zero = T::zero();
    let one = T::one();

    let mut current = origin;
    let (mut step_m, mut step_p) = (m, p);
    let mut todo = iterations;

    while todo > zero {
        if todo & one == one {
            current = current * step_m + step_p;
        }
        // Square the jump. The new additive part uses the old multiplier, so
        // `step_p` is updated first.
        step_p = (step_m + one) * step_p;
        step_m = step_m * step_m;
        todo = todo >> 1;
    }

    current
}
#[cfg(test)]
mod tests {
    use super::super::*;
    use super::*;

    /// Advances the test's own 128-bit pseudo-random sequence and returns the new value.
    fn next_random(seed: &mut u128) -> u128 {
        let next = seed.wrapping_mul(LCG_M128_1).wrapping_add(0xffff);
        *seed = next;
        next
    }

    #[test]
    pub fn run_tests() {
        const ROUNDS: usize = 1 << 12;
        let multipliers = [LCG_M128_1, LCG_M128_2, LCG_M128_3];
        let mut seed: u128 = 0;

        for _ in 0..ROUNDS {
            // The value is always below 3, so the cast cannot truncate.
            let m = multipliers[(next_random(&mut seed) % 3) as usize];
            // An odd increment keeps the generator full-period.
            let p = next_random(&mut seed) | 1;
            let origin = next_random(&mut seed);

            // A single step agrees with the recurrence, in both directions.
            let one_step = origin.wrapping_mul(m).wrapping_add(p);
            assert_eq!(one_step, get_state(m, p, origin, 1));
            assert_eq!(1, get_iterations(m, p, origin, one_step));

            // Distance to an arbitrary state round-trips through get_state and get_jump.
            let target = next_random(&mut seed);
            let distance = get_iterations(m, p, origin, target);
            assert_eq!(target, get_state(m, p, origin, distance));
            let (m_total, p_total) = get_jump(m, p, distance);
            assert_eq!(origin.wrapping_mul(m_total).wrapping_add(p_total), target);

            // An arbitrary step count round-trips through get_iterations.
            let steps = next_random(&mut seed);
            let landed = get_state(m, p, origin, steps);
            assert_eq!(steps, get_iterations(m, p, origin, landed));

            // A partial walk plus the remaining distance covers the full walk.
            // `partial` uses only bits of `steps`, so the subtraction cannot underflow.
            let partial = steps & next_random(&mut seed);
            let midway = get_state(m, p, origin, partial);
            assert_eq!(steps - partial, get_iterations(m, p, midway, landed));
        }
    }
}