use alloc::vec::Vec;
use core::convert::TryFrom;
use rand_core::{CryptoRng, Rng, SeedableRng};
/// Error returned by [`DeterministicRandomizer`] operations when a request
/// cannot be satisfied, e.g. asking for more elements than the input holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RandomizerError;

impl core::fmt::Display for RandomizerError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("randomizer request could not be satisfied")
    }
}
pub struct DeterministicRandomizer<R>
where R: SeedableRng
{
prng: R,
seed: <R as SeedableRng>::Seed,
}
impl<R> DeterministicRandomizer<R>
where
    R: CryptoRng + Rng + SeedableRng,
    <R as SeedableRng>::Seed: Clone,
{
    /// Creates a randomizer whose generator is initialised from `seed`.
    ///
    /// The seed is retained so that [`reset`](Self::reset) can later restore
    /// this initial state.
    pub fn new(seed: <R as SeedableRng>::Seed) -> Self {
        Self {
            prng: R::from_seed(seed.clone()),
            seed,
        }
    }

    /// Re-initialises the generator from `seed` and records it as the seed
    /// that [`reset`](Self::reset) will return to.
    pub fn reseed(&mut self, seed: <R as SeedableRng>::Seed) {
        self.prng = R::from_seed(seed.clone());
        self.seed = seed;
    }

    /// Rewinds the generator to the state it had right after the most recent
    /// [`new`](Self::new) or [`reseed`](Self::reseed).
    pub fn reset(&mut self) {
        self.prng = R::from_seed(self.seed.clone());
    }

    /// Returns `k` elements of `data` chosen at random without replacement.
    ///
    /// `data` itself is not modified; the selection is performed on a copy.
    ///
    /// # Errors
    ///
    /// Returns [`RandomizerError`] if `k > data.len()`.
    pub fn sample<T>(&mut self, data: &[T], k: usize) -> Result<Vec<T>, RandomizerError>
    where
        T: Clone,
    {
        let mut pool = data.to_vec();
        self.partial_shuffle(&mut pool, k)?;
        // The partial shuffle leaves the chosen elements in the last `k`
        // slots; discard the untouched prefix in place rather than copying
        // the tail into a second allocation.
        pool.drain(..data.len() - k);
        Ok(pool)
    }

    /// Shuffles `data` in place with a full Fisher–Yates pass.
    ///
    /// # Errors
    ///
    /// The `Result` mirrors [`sample`](Self::sample); with the whole slice
    /// being shuffled no error path is reachable.
    pub fn shuffle<T>(&mut self, data: &mut [T]) -> Result<(), RandomizerError> {
        let len = data.len();
        self.partial_shuffle(data, len)
    }

    /// Partial Fisher–Yates: for each index `i` from `len - 1` down to
    /// `len - n`, swaps `data[i]` with `data[j]` for a random `j` in `0..=i`.
    ///
    /// Afterwards `data[len - n..]` holds `n` elements chosen at random without
    /// replacement from the whole slice.
    ///
    /// # Errors
    ///
    /// Returns [`RandomizerError`] if `n > data.len()`.
    fn partial_shuffle<T>(&mut self, data: &mut [T], n: usize) -> Result<(), RandomizerError> {
        let len = data.len();
        if n > len {
            return Err(RandomizerError);
        }
        for i in (len - n..len).rev() {
            // Neither conversion fails in practice: `i + 1` fits in `u64`
            // whenever `usize` is no wider than 64 bits, and the drawn
            // `j < i + 1 <= len` always fits back into `usize`.
            let bound = u64::try_from(i + 1).map_err(|_| RandomizerError)?;
            let j = usize::try_from(self.next_bounded_u64(bound)?)
                .map_err(|_| RandomizerError)?;
            data.swap(i, j);
        }
        Ok(())
    }

    /// Returns a random value in `0..upper`.
    ///
    /// 128 bits of generator output are reduced modulo `upper`. The source
    /// range is at least 2^64 times larger than any `u64` bound, so the
    /// modulo bias is at most about 2^-64.
    ///
    /// # Errors
    ///
    /// Returns [`RandomizerError`] if `upper == 0`, because the range `0..0`
    /// is empty. No randomness is consumed in that case.
    pub fn next_bounded_u64(&mut self, upper: u64) -> Result<u64, RandomizerError> {
        if upper == 0 {
            return Err(RandomizerError);
        }
        // Draw order matters for reproducibility: the first word is the high half.
        let high = u128::from(self.prng.next_u64());
        let low = u128::from(self.prng.next_u64());
        let x = (high << 64) | low;
        // `x % upper < upper <= u64::MAX`, so this conversion always succeeds.
        u64::try_from(x % u128::from(upper)).map_err(|_| RandomizerError)
    }
}
#[cfg(test)]
mod test {
    //! Golden-value tests: every case runs from a fixed seed and checks the
    //! exact output, so any change in how randomness is consumed shows up here.

    use super::DeterministicRandomizer;
    use rand_chacha::ChaCha12Rng;

    type R = DeterministicRandomizer<ChaCha12Rng>;

    /// Seed shared by every test case.
    const SEED: [u8; 32] = [1u8; 32];

    #[test]
    fn test_shuffle() {
        let mut randomizer = R::new(SEED);
        let mut items = [0u8, 1, 2, 3];
        randomizer.shuffle(&mut items).unwrap();
        assert_eq!(items, [0u8, 2, 3, 1]);
    }

    #[test]
    fn test_bounded_u64() {
        let mut randomizer = R::new(SEED);
        for &expected in &[573u64, 786] {
            assert_eq!(randomizer.next_bounded_u64(1000).unwrap(), expected);
        }
    }

    #[test]
    fn test_sample() {
        let population = [0u8, 1, 2, 3];
        let picked = R::new(SEED).sample(&population, 3).unwrap();
        assert_eq!(picked, [2u8, 3, 1]);
    }
}