use {
rand::Rng,
std::{
num::NonZero,
ops::{Bound, RangeBounds},
},
};
#[derive(Debug)]
pub struct UniformU64Sampler {
range_end: NonZero<u64>,
zone: u64,
}
impl UniformU64Sampler {
#[allow(clippy::arithmetic_side_effects)]
pub fn new_like_instance_sample(range_end: NonZero<u64>) -> Self {
let ints_to_reject = (u64::MAX - range_end.get() + 1) % range_end.get();
let zone = u64::MAX - ints_to_reject;
Self { range_end, zone }
}
pub fn new_like_trait_sample(range_end: NonZero<u64>) -> Self {
let zone = (range_end.get() << range_end.leading_zeros()).wrapping_sub(1);
Self { range_end, zone }
}
pub fn sample(&self, rng: &mut impl Rng) -> u64 {
loop {
let (hi, lo) = Self::wmul(rng.random(), self.range_end);
if lo <= self.zone {
return hi;
}
}
}
#[allow(clippy::arithmetic_side_effects)]
fn wmul(x: u64, y: NonZero<u64>) -> (u64, u64) {
let tmp = (x as u128) * (y.get() as u128);
((tmp >> 64) as u64, tmp as u64)
}
}
pub fn random_u64_range(rng: &mut impl Rng, range: impl RangeBounds<u64>) -> u64 {
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(start) => *start,
Bound::Excluded(&u64::MAX) => panic!("Cannot generate number in empty range (max..)"),
Bound::Excluded(start) => start.wrapping_add(1),
};
let last = match range.end_bound() {
Bound::Unbounded | Bound::Included(&u64::MAX) if start == 0 => return rng.random(),
Bound::Unbounded => u64::MAX,
Bound::Included(last) => *last,
Bound::Excluded(0) => panic!("Cannot generate number in empty range (..0)"),
Bound::Excluded(end) => end.wrapping_sub(1),
};
let zero_range_end = last
.checked_sub(start)
.expect("Range must not be empty")
.wrapping_add(1);
let zero_range_end = NonZero::new(zero_range_end).unwrap();
let sampler = UniformU64Sampler::new_like_trait_sample(zero_range_end);
sampler.sample(rng).wrapping_add(start)
}
#[cfg(test)]
mod tests {
use {
super::*,
rand::SeedableRng as _,
rand_chacha::ChaChaRng,
sha2::{Digest, Sha256},
std::{array, ops::Range},
test_case::test_case,
};
const CHACHA_SEED: [u8; 32] = [16; 32];
#[test]
fn test_uniform_sample_like_instance_sample_example() {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let sampler_compat =
UniformU64Sampler::new_like_instance_sample(NonZero::new(294_533).unwrap());
let values: [u64; 10] = array::from_fn(|_| sampler_compat.sample(&mut rng_compat));
assert_eq!(
values,
[
280405, 7507, 84194, 272634, 52124, 190984, 8676, 230277, 223574, 126007
]
);
}
#[test_case(10, "5p3DrG89DLrJotbMuZJUDHMrqcfoQiDQpa3FbNDbnND6")]
#[test_case(2_729, "J65XfpzkppuxWEyvgVdDmGBsJSex6BuevYuDSRtaDNAK")]
#[test_case(4_098, "EaRABrcPBw5LLNBmjn5ay5YN5yTPgv3GnQJLbpUarRkJ")]
#[test_case(504_302_479, "HniJPVe7zir8XxHmtUuxzhPVvWjMcJxVhhj8QSs1PZTC")]
#[test_case(1_000_346_000_000, "BghMy7yLe6BVzMbc4zNvAB4sz3ZSPYKJS5oLGvXYVzw2")]
fn test_uniform_sampler_like_instance_sample_compat(range_end: u64, expected_hash: &str) {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let sampler_compat =
UniformU64Sampler::new_like_instance_sample(NonZero::new(range_end).unwrap());
let mut hash = Sha256::new();
(0..600_000).for_each(|_| {
let compat = sampler_compat.sample(&mut rng_compat);
hash.update(compat.to_le_bytes());
});
assert_eq!(bs58::encode(hash.finalize()).into_string(), expected_hash);
}
#[test]
fn test_uniform_sample_like_trait_sample_example() {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let sampler_compat =
UniformU64Sampler::new_like_trait_sample(NonZero::new(294_533).unwrap());
let values: [u64; 10] = array::from_fn(|_| sampler_compat.sample(&mut rng_compat));
assert_eq!(
values,
[
272634, 52124, 8676, 230277, 223574, 137788, 212533, 213080, 187008, 209168
]
);
}
#[test_case(10, "2tsA6RuXekKvMMvAqjkMMipjYySpw8mrrwwoBeoKohD1")]
#[test_case(1_000, "58SG5gnD5wR6ngrjxuTv7m1SEXiXejZvuJPUCsyrmqWx")]
#[test_case(4098, "7w4TY1oaeEeqHmPw3iobD8WVq1BsL5eo7ZuNdvDkoSQo")]
#[test_case(10_000_000_000, "ENe5A82Wq2nYsH17q9WkKAudXdLE9FVGRbySuuCpaR5k")]
fn test_uniform_sampler_like_trait_sample_single(range_end: u64, expected_hash: &str) {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let sampler_compat =
UniformU64Sampler::new_like_trait_sample(NonZero::new(range_end).unwrap());
let mut hash = Sha256::new();
(0..1_000).for_each(|_| {
let compat = sampler_compat.sample(&mut rng_compat);
hash.update(compat.to_le_bytes());
});
assert_eq!(bs58::encode(hash.finalize()).into_string(), expected_hash);
}
#[test_case(0..100, &[95, 2, 28, 92, 17, 78, 72, 72, 21, 65])]
#[test_case(4..=100, &[96, 6, 31, 21, 77, 45, 49, 74, 24, 62])]
#[test_case(4..101, &[96, 6, 31, 21, 77, 45, 49, 74, 24, 62])]
#[test_case(77..=77, &[77, 77, 77, 77, 77, 77, 77, 77, 77, 77])]
#[test_case((u64::MAX - 1)..,
&[18446744073709551614, 18446744073709551614, 18446744073709551615, 18446744073709551614,
18446744073709551615, 18446744073709551615, 18446744073709551614, 18446744073709551615,
18446744073709551615, 18446744073709551615])]
#[test_case((u64::MAX - 1)..=u64::MAX,
&[18446744073709551614, 18446744073709551614, 18446744073709551615, 18446744073709551614,
18446744073709551615, 18446744073709551615, 18446744073709551614, 18446744073709551615,
18446744073709551615, 18446744073709551615])]
#[test_case(1_000..3_463, &[1704, 2597, 1072, 2152, 2777, 1526, 2619, 3318, 1824, 2749])]
#[test_case(.., &[17561962876765024395, 470213581521912695, 5273148429961811548, 17075204616871931816,
3264552321098222474, 11961477208003522799, 543394791199312464, 14422393090876740590,
14002568482199435789, 7891933133259177431])]
#[test_case(0..=u64::MAX,
&[17561962876765024395, 470213581521912695, 5273148429961811548, 17075204616871931816,
3264552321098222474, 11961477208003522799, 543394791199312464, 14422393090876740590,
14002568482199435789, 7891933133259177431])]
fn test_random_range_example(range: impl RangeBounds<u64> + Clone, expected: &[u64]) {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let result: Vec<u64> = (0..10)
.map(|_| random_u64_range(&mut rng_compat, range.clone()))
.collect();
assert_eq!(result, expected);
}
#[test_case(0..100, "9qrc9Atiwy7f5vYs6UK4fk4WMtRqjLF8Zd8WsG7MFsYX")]
#[test_case(1_000..9_999, "6qRfSYk55bPZty3qc3jLaSxASXgDRLaw8AWnXKEhiW6M")]
#[test_case(4..4098, "7rKu65HwouY8sr3o7jZrdRPCtGwVXjZDTXsjcKg3VVr")]
#[test_case(1_000..20_000_000_000, "J5vwi9DrgXrTNinMQmDw1zSq6ogeUencayoqEn4r64gH")]
fn test_random_range_reproducibility(range: Range<u64>, expected_hash: &str) {
let mut rng_compat = ChaChaRng::from_seed(CHACHA_SEED);
let mut hash = Sha256::new();
(0..10_000).for_each(|_| {
let compat = random_u64_range(&mut rng_compat, range.clone());
hash.update(compat.to_le_bytes());
});
assert_eq!(bs58::encode(hash.finalize()).into_string(), expected_hash);
}
#[test]
#[should_panic]
fn test_random_range_panic_empty() {
let mut rng = ChaChaRng::from_seed(CHACHA_SEED);
let _ = random_u64_range(&mut rng, 100..100);
}
}