use core::ops::RangeInclusive;
use num::{BigUint, ToPrimitive};
/// A source of arbitrary-precision unsigned integers.
///
/// The implementations in this file either sample a value from the requested
/// range (`RandRng`, `JavaRandom`) or ignore the range and replay canned
/// values (`RngMock`).
pub trait Rng {
/// Produces the next value for the inclusive `range`.
///
/// Takes `&mut self` because producing a value advances the generator state.
fn generate(&mut self, range: RangeInclusive<BigUint>) -> BigUint;
}
/// Lets a boxed generator be used wherever an [`Rng`] is expected.
///
/// The `?Sized` bound admits unsized `T`, so boxed trait objects are covered
/// as well as boxed concrete generators.
impl<T: Rng + ?Sized> Rng for Box<T> {
    fn generate(&mut self, range: RangeInclusive<BigUint>) -> BigUint {
        // Deref through the box and delegate to the inner generator.
        (**self).generate(range)
    }
}
/// Deterministic [`Rng`] that replays a fixed array of `N` values.
///
/// Its `generate` implementation ignores the requested range, returns the
/// element at the front of the array and rotates the array left by one, so the
/// values repeat cyclically. With `N == 0` it returns `BigUint::default()`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RngMock<const N: usize>(pub [BigUint; N]);
impl<const N: usize> Rng for RngMock<N> {
    /// Returns the value at the front of the array and moves it to the back.
    ///
    /// `_range` is ignored: the mock hands out its canned values verbatim, in
    /// a cycle. An empty mock (`N == 0`) yields `BigUint::default()`.
    fn generate(&mut self, _range: RangeInclusive<BigUint>) -> BigUint {
        match self.0.first().cloned() {
            Some(head) => {
                // Rotate so the next call sees the following element first.
                self.0.rotate_left(1);
                head
            }
            None => BigUint::default(),
        }
    }
}
/// Adapter that wraps a generator `R`; it implements this module's [`Rng`]
/// whenever `R: rand::Rng`.
///
/// Only compiled when the `rand` feature is enabled.
#[cfg(feature = "rand")]
pub struct RandRng<R>(pub R);
#[cfg(feature = "rand")]
impl<R: rand::Rng> Rng for RandRng<R> {
    /// Samples a value uniformly from the inclusive `range` by rejection
    /// sampling.
    ///
    /// Candidates are drawn with exactly as many bits as the range width
    /// needs, so each attempt is accepted with probability greater than 1/2.
    ///
    /// # Panics
    ///
    /// `end - start` is an unsigned big-integer subtraction, so an empty range
    /// (`end < start`) is expected to panic inside `num`.
    fn generate(&mut self, range: RangeInclusive<BigUint>) -> BigUint {
        use num::One;

        let (start, end) = (range.start(), range.end());
        // Number of distinct values in the inclusive range.
        let width = end - start + BigUint::one();
        let width_bits = width.bits() as usize;

        // One scratch buffer reused across attempts instead of reallocating.
        let mut bytes = vec![0u8; width_bits.div_ceil(8)];
        // The buffer may hold up to 7 bits more than `width` needs. Clearing
        // them in the most significant (last, little-endian) byte keeps every
        // candidate below 2^width_bits, which bounds the rejection rate.
        let surplus_bits = bytes.len() * 8 - width_bits;
        let top_mask = 0xFFu8 >> surplus_bits;

        loop {
            self.0.fill_bytes(&mut bytes);
            if let Some(top) = bytes.last_mut() {
                *top &= top_mask;
            }
            let candidate = BigUint::from_bytes_le(&bytes);
            if candidate < width {
                return start + candidate;
            }
        }
    }
}
/// Linear congruential generator that follows the algorithm of Java's
/// `java.util.Random`: 48 bits of state, multiplier `0x5DEECE66D`,
/// increment `0xB`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct JavaRandom {
    /// Current generator state; only the low 48 bits are ever set.
    seed: u64,
}

impl JavaRandom {
    /// LCG multiplier, also used to scramble the initial seed.
    const MULT: u64 = 0x5_DEEC_E66D;
    /// LCG increment.
    const ADD: u64 = 0xB;
    /// Keeps the state within 48 bits.
    const MASK: u64 = (1u64 << 48) - 1;

    /// Creates a generator from `seed`, scrambling it with the multiplier and
    /// truncating it to 48 bits.
    #[must_use]
    pub const fn new(seed: i64) -> Self {
        Self {
            seed: ((seed as u64) ^ Self::MULT) & Self::MASK,
        }
    }

    /// Advances the state and returns its top `bits` bits.
    ///
    /// Callers in this file pass 31 or 32. With 32 the result is the raw bit
    /// pattern reinterpreted as `i32`, so it may be negative.
    const fn next(&mut self, bits: i32) -> i32 {
        self.seed = (self.seed.wrapping_mul(Self::MULT).wrapping_add(Self::ADD)) & Self::MASK;
        ((self.seed >> (48 - bits)) & ((1u64 << bits) - 1)) as i32
    }

    /// Returns the next pseudo-random `i32`, covering the full signed range.
    pub const fn next_int(&mut self) -> i32 {
        self.next(32)
    }

    /// Returns the next pseudo-random value in `0..bound`.
    ///
    /// # Panics
    ///
    /// Panics if `bound` is not positive.
    pub fn next_int_bound(&mut self, bound: i32) -> i32 {
        assert!(bound > 0, "bound must be positive");
        let m = bound - 1;
        if (bound & m) == 0 {
            // Power of two: scale the 31-bit draw down to its high-order bits.
            ((bound as i64 * self.next(31) as i64) >> 31) as i32
        } else {
            loop {
                let bits = self.next(31);
                let val = bits % bound;
                // Java rejects a draw when `bits - val + m` wraps negative,
                // i.e. when the draw falls in the final, incomplete block of
                // `bound` values. In Rust that addition would panic under
                // overflow checks, so the overflow is detected explicitly.
                if (bits - val).checked_add(m).is_some() {
                    return val;
                }
            }
        }
    }
}

impl Default for JavaRandom {
    /// Equivalent to `JavaRandom::new(0)`.
    fn default() -> Self {
        Self::new(0)
    }
}
impl Rng for JavaRandom {
    /// Returns a pseudo-random value from the inclusive `range`.
    ///
    /// When both the range start and the range width fit in an `i32`, a
    /// single `next_int_bound` draw is used. Otherwise enough 32-bit words are
    /// drawn to cover the width and the assembled number is reduced modulo
    /// the width, so the large-range path carries a modulo bias.
    ///
    /// # Panics
    ///
    /// `end - start` is an unsigned big-integer subtraction, so an empty range
    /// (`end < start`) is expected to panic inside `num`.
    fn generate(&mut self, range: RangeInclusive<BigUint>) -> BigUint {
        use num::One;

        let (start, end) = (range.start(), range.end());
        // Number of distinct values in the inclusive range.
        let width = end - start + BigUint::one();

        // Fast path. The requirement that `start` also fits in an `i32` is
        // kept from the original so the same ranges take this path as before.
        if let (Some(_), Some(width_i32)) = (start.to_i32(), width.to_i32().filter(|&w| w > 0)) {
            let offset = self.next_int_bound(width_i32);
            // `next_int_bound` never returns a negative value.
            return start + BigUint::from(offset as u32);
        }

        let width_bits = width.bits() as usize;
        let mut result = BigUint::ZERO;
        let mut bits_generated = 0usize;
        while bits_generated < width_bits {
            // Use the raw 32-bit pattern; taking the absolute value would
            // fold the sign and leave the top bit of each word almost always
            // clear.
            let word = self.next_int() as u32;
            // Each word goes directly above the bits already produced.
            result |= BigUint::from(word) << bits_generated;
            bits_generated += 32;
        }

        start + result % width
    }
}
#[cfg(all(test, feature = "rand"))]
mod tests {
    use super::*;
    use num::BigUint;
    use rand::{SeedableRng, rngs::StdRng};

    /// Values drawn from a range far above `u64::MAX` stay inside the range
    /// and the generator does not keep returning the same value.
    #[test]
    fn test_rand_rng_big_range() {
        let start = BigUint::parse_bytes(b"10000000000000000000000000000000000000000000000000", 10)
            .unwrap();
        let end = BigUint::parse_bytes(b"10000000000000000000000000000000000000000000000099", 10)
            .unwrap();
        let mut rng = RandRng(StdRng::seed_from_u64(42));
        let range = start.clone()..=end.clone();
        let n1 = rng.generate(range.clone());
        let n2 = rng.generate(range.clone());
        let n3 = rng.generate(range);
        assert!(n1 >= start && n1 <= end, "n1 out of range");
        assert!(n2 >= start && n2 <= end, "n2 out of range");
        assert!(n3 >= start && n3 <= end, "n3 out of range");
        // Only 100 values are possible, so pairwise distinctness would depend
        // on the exact byte stream of the seeded generator. Reject only the
        // degenerate case where every draw is identical.
        assert!(
            !(n1 == n2 && n2 == n3),
            "random numbers are not unique"
        );
    }

    /// `JavaRandom` reproduces a known output and is deterministic per seed.
    #[test]
    fn test_java_random_consistency() {
        // Reference value: `new java.util.Random(42).nextInt()`.
        assert_eq!(JavaRandom::new(42).next_int(), -1_170_105_035);

        // The same seed must yield the same sequence.
        let mut first = JavaRandom::new(123456789);
        let mut second = JavaRandom::new(123456789);
        for _ in 0..3 {
            assert_eq!(first.next_int(), second.next_int());
        }

        // Check both bounds against the same drawn value.
        let mut rng2 = JavaRandom::new(123456789);
        for _ in 0..4 {
            let value = rng2.next_int_bound(100);
            assert!((0..100).contains(&value), "nextInt(100) out of range: {value}");
        }
    }
}