use crate::error::{PermutationError, PermutationResult};
#[cfg(feature = "use-rand")]
use rand::prelude::*;
use std::num::NonZeroU32;
/// A permutation of the integers `0..length`, determined entirely by `seed`.
///
/// No lookup table is stored: the struct holds only the seed and the length, and
/// each index is mapped on demand by `HashedPermutation::shuffle`, so memory use is
/// constant regardless of `length`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct HashedPermutation {
    /// Seed selecting which permutation of `0..length` is produced.
    pub seed: u32,
    /// Number of elements being permuted; valid inputs to `shuffle` are `0..length`.
    pub length: NonZeroU32,
}
impl HashedPermutation {
    /// Creates a permutation of `0..length` using a seed drawn from `rand::random`.
    ///
    /// Only available with the `use-rand` feature.
    #[cfg(feature = "use-rand")]
    #[must_use]
    pub fn new(length: NonZeroU32) -> Self {
        let seed = rand::random();
        Self { length, seed }
    }

    /// Creates a permutation of `0..length` selected by `seed`.
    ///
    /// The same `(length, seed)` pair always yields the same permutation.
    #[must_use]
    pub fn new_with_seed(length: NonZeroU32, seed: u32) -> Self {
        Self { length, seed }
    }

    /// Returns the position that `input` is moved to by this permutation.
    ///
    /// Every input in `0..self.length` maps to a distinct output in `0..self.length`.
    /// Calling `shuffle(0)`, `shuffle(1)`, ... therefore visits each index exactly once,
    /// in an order determined by `self.seed`.
    ///
    /// The mixing follows Andrew Kensler's `permute` function from *Correlated
    /// Multi-Jittered Sampling* (Pixar Technical Memo 13-01). It uses an invertible hash
    /// on the smallest power-of-two domain covering `length`. Cycle-walking then keeps
    /// the result inside `0..length`.
    ///
    /// # Errors
    ///
    /// Returns [`PermutationError::ShuffleOutOfRange`] when `input >= self.length`.
    pub fn shuffle(&self, input: u32) -> PermutationResult<u32> {
        let n = self.length.get();
        if input >= n {
            return Err(PermutationError::ShuffleOutOfRange {
                shuffle: input,
                max_shuffle: n,
            });
        }
        let mask = Self::mask_for(self.length);
        // Cycle-walking: `hash_round` permutes `0..=mask`, so re-applying it until the
        // value lands back in `0..n` yields a permutation of `0..n`.
        // The round must run at least once (a do-while). A plain `while i >= n` would
        // never execute for an in-range `input`, leaving it unmixed.
        // Because `mask + 1 < 2 * n` (for n >= 2), at least half of `0..=mask` lies in
        // range, which keeps the walk short.
        let mut i = Self::hash_round(input, self.seed, mask);
        while i >= n {
            i = Self::hash_round(i, self.seed, mask);
        }
        // Widen before adding: `i + seed` can exceed `u32::MAX` for large seeds. Wrapping
        // there would shift some inputs by a different amount and break the bijection.
        let rotated = (u64::from(i) + u64::from(self.seed)) % u64::from(n);
        // `rotated < n <= u32::MAX`, so this narrowing cannot truncate.
        Ok(rotated as u32)
    }

    /// Returns the smallest all-ones bit mask (`2^k - 1`) that is `>= length - 1`.
    fn mask_for(length: NonZeroU32) -> u32 {
        // `length >= 1`, so this subtraction cannot underflow.
        let mut mask = length.get() - 1;
        // Smear the highest set bit into every lower position.
        mask |= mask >> 1;
        mask |= mask >> 2;
        mask |= mask >> 4;
        mask |= mask >> 8;
        mask |= mask >> 16;
        mask
    }

    /// One round of Kensler's hash, producing a value in `0..=mask`.
    ///
    /// On inputs in `0..=mask` this is a bijection. Every step is invertible modulo
    /// `mask + 1`: xor with a constant, multiplication by an odd constant, and xor with
    /// a right-shifted copy of the masked bits. The low bits of each result depend only
    /// on the low bits of its input.
    /// Multiplications wrap on purpose; the algorithm relies on arithmetic modulo 2^32.
    fn hash_round(mut i: u32, seed: u32, mask: u32) -> u32 {
        i ^= seed;
        i = i.wrapping_mul(0xe170_893d);
        i ^= seed >> 16;
        i ^= (i & mask) >> 4;
        i ^= seed >> 8;
        i = i.wrapping_mul(0x0929_eb3f);
        i ^= seed >> 23;
        i ^= (i & mask) >> 1;
        // `1 | ...` forces an odd (hence invertible) seed-dependent multiplier.
        i = i.wrapping_mul(1 | seed >> 27);
        i = i.wrapping_mul(0x6935_fa69);
        i ^= (i & mask) >> 11;
        i = i.wrapping_mul(0x74dc_b303);
        i ^= (i & mask) >> 2;
        i = i.wrapping_mul(0x9e50_1cc3);
        i ^= (i & mask) >> 2;
        i = i.wrapping_mul(0xc860_a3df);
        i &= mask;
        i ^= i >> 5;
        i
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    /// `(length, seed)` pairs exercised by the domain, bijection and determinism tests.
    fn lengths_and_seeds() -> Vec<(NonZeroU32, u32)> {
        [100, 5, 13, 128, 249]
            .iter()
            .map(|&x| (NonZeroU32::new(x).expect("test lengths are non-zero"), x))
            .collect()
    }

    /// Every in-range input succeeds and maps to an in-range output.
    #[test]
    fn test_domain() {
        for (length, seed) in lengths_and_seeds() {
            let perm = HashedPermutation { seed, length };
            for i in 0..length.get() {
                let shuffled = perm
                    .shuffle(i)
                    .expect("in-range input must shuffle successfully");
                assert!(shuffled < length.get());
            }
        }
    }

    /// Distinct inputs map to distinct outputs, and together they cover `0..length`.
    #[test]
    fn test_bijection() {
        for (length, seed) in lengths_and_seeds() {
            let perm = HashedPermutation { seed, length };
            let mut seen = HashSet::new();
            for i in 0..length.get() {
                let shuffled = perm
                    .shuffle(i)
                    .expect("in-range input must shuffle successfully");
                // `insert` returns false when the value is already present, i.e. a collision.
                assert!(
                    seen.insert(shuffled),
                    "input {} collided on output {}",
                    i,
                    shuffled
                );
            }
            let mut outputs: Vec<u32> = seen.into_iter().collect();
            outputs.sort_unstable();
            let expected: Vec<u32> = (0..length.get()).collect();
            assert_eq!(outputs, expected);
        }
    }

    /// The same length and seed always produce the same permutation.
    #[test]
    fn test_deterministic() {
        for (length, seed) in lengths_and_seeds() {
            let first = HashedPermutation::new_with_seed(length, seed);
            let second = HashedPermutation::new_with_seed(length, seed);
            for i in 0..length.get() {
                let a = first.shuffle(i).expect("in-range input must shuffle successfully");
                let b = second.shuffle(i).expect("in-range input must shuffle successfully");
                assert_eq!(a, b);
            }
        }
    }

    /// Inputs at or beyond `length` are rejected with an error.
    #[test]
    fn test_out_of_range() {
        let offsets = [0, 1, 5, 15, 100];
        for &len in &[1, 50, 256, 18] {
            let length = NonZeroU32::new(len).expect("test lengths are non-zero");
            let perm = HashedPermutation { seed: 0, length };
            for &offset in &offsets {
                let input = len + offset;
                assert!(
                    perm.shuffle(input).is_err(),
                    "input {} should be rejected for length {}",
                    input,
                    len
                );
            }
        }
    }
}