use num_modular::{Montgomery, Reducer};
use num_traits::{PrimInt, WrappingAdd, WrappingSub};
use crate::builder::RandomSequenceBuilder;
/// An index-addressable sequence whose values are defined by a [`RandomSequenceBuilder`].
///
/// Each value is computed on demand from an index (see the inherent `n_internal`), so the
/// sequence supports forward iteration (`next`), backward iteration (`prev`) and random
/// access (`n`) without storing any previously produced values.
#[derive(Debug, Clone)]
pub struct RandomSequence<T>
where
T: PrimInt + WrappingAdd + WrappingSub,
Montgomery<T>: Reducer<T>,
{
/// Builder configuration whose `permute_qpr` and `intermediate_xor` determine the values.
pub config: RandomSequenceBuilder<T>,
/// Absolute index that corresponds to position 0 of the sequence (used by `n` and `index`).
pub(crate) start_index: T,
/// Absolute index of the element the next forward step will yield; moves with `next`/`prev`.
pub(crate) current_index: T,
/// Offset (wrapping) added between the two `permute_qpr` passes when computing a value.
pub(crate) intermediate_offset: T,
}
impl<T> RandomSequence<T>
where
    T: PrimInt + WrappingAdd + WrappingSub,
    Montgomery<T>: Reducer<T>,
{
    /// Returns the element under the cursor, then moves the cursor forward by one.
    ///
    /// The cursor arithmetic wraps, so stepping past `T::MAX` continues from `T::MIN`.
    #[inline]
    pub fn next(&mut self) -> T {
        let position = self.current_index;
        let value = self.n_internal(position);
        self.current_index = position.wrapping_add(&T::one());
        value
    }

    /// Moves the cursor back by one, then returns the element under it.
    ///
    /// Because `next` reads before advancing and `prev` retreats before reading, calling
    /// `prev` immediately after `next` yields the same element again.
    #[inline]
    pub fn prev(&mut self) -> T {
        let position = self.current_index.wrapping_sub(&T::one());
        self.current_index = position;
        self.n_internal(position)
    }

    /// Returns the element at `index`, counted (with wrapping) from the start of the sequence.
    ///
    /// This is pure random access: the iteration cursor is neither read nor moved.
    #[inline]
    pub fn n(&self, index: T) -> T {
        self.n_internal(index.wrapping_add(&self.start_index))
    }

    /// Maps an absolute index to its value.
    ///
    /// Runs the builder's `permute_qpr` twice. Between the two passes the intermediate
    /// result gets `intermediate_offset` added (wrapping) and is XOR-ed with
    /// `config.intermediate_xor`.
    #[inline(always)]
    fn n_internal(&self, index: T) -> T {
        let first_pass = self.config.permute_qpr(index);
        let mixed =
            first_pass.wrapping_add(&self.intermediate_offset) ^ self.config.intermediate_xor;
        self.config.permute_qpr(mixed)
    }

    /// Returns the cursor's position relative to the start of the sequence (wrapping).
    ///
    /// This is the index `i` such that the next forward step yields `self.n(i)`.
    #[inline]
    pub fn index(&self) -> T {
        let cursor = self.current_index;
        cursor.wrapping_sub(&self.start_index)
    }
}
impl<T> Iterator for RandomSequence<T>
where
T: PrimInt + WrappingAdd + WrappingSub,
Montgomery<T>: Reducer<T>,
{
type Item = T;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
Some(self.next())
}
}
impl<T> DoubleEndedIterator for RandomSequence<T>
where
T: PrimInt + WrappingAdd + WrappingSub,
Montgomery<T>: Reducer<T>,
{
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
Some(self.prev())
}
}
impl<T> From<RandomSequenceBuilder<T>> for RandomSequence<T>
where
    T: PrimInt + WrappingAdd + WrappingSub,
    Montgomery<T>: Reducer<T>,
{
    /// Converts a builder into its sequence by delegating to the builder's `into_iter`.
    fn from(builder: RandomSequenceBuilder<T>) -> Self {
        builder.into_iter()
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::rngs::OsRng;
    use statrs::distribution::{ChiSquared, ContinuousCDF};

    use super::*;

    /// Compile-time assertion helper: only type-checks when `T: Send`.
    fn is_send<T: Send>() {}
    /// Compile-time assertion helper: only type-checks when `T: Sync`.
    fn is_sync<T: Sync>() {}

    /// Generates a test that, for a zero-seeded builder, checks the following:
    /// - forward iteration agrees with random access via `n(i)`;
    /// - reverse iteration from the start agrees with `n(MAX - i)` (wrapping backwards);
    /// - the first `$check` values are pairwise distinct;
    /// - `RandomSequence<$type>` is `Send + Sync`.
    macro_rules! test_sequence {
        ($name:ident, $type:ident, $check:literal) => {
            #[test]
            fn $name() {
                let config = RandomSequenceBuilder::<$type>::new(0, 0);
                let sequence = config.into_iter();
                for (i, num) in std::iter::zip(0..10, sequence.clone()) {
                    assert_eq!(sequence.n(i as $type), num);
                }
                for (i, num) in std::iter::zip(0..10, sequence.clone().rev()) {
                    assert_eq!(sequence.n($type::MAX.wrapping_sub(i as $type)), num);
                }
                let nums: HashSet<$type> = config.into_iter().take($check).collect();
                assert_eq!(nums.len(), $check);
                is_send::<RandomSequence<$type>>();
                is_sync::<RandomSequence<$type>>();
            }
        };
    }

    test_sequence!(test_u8_sequence, u8, 256);
    test_sequence!(test_u16_sequence, u16, 65536);
    test_sequence!(test_u32_sequence, u32, 100_000);
    test_sequence!(test_u64_sequence, u64, 100_000);

    /// Generates a chi-squared goodness-of-fit test for uniformity.
    ///
    /// The first `$check` values of a randomly seeded sequence are sorted into `BUCKETS`
    /// equal-width buckets that together cover the whole `$type` range. The observed
    /// counts are compared against a uniform expectation of `$check / BUCKETS` per bucket.
    macro_rules! test_distribution {
        ($name:ident, $type:ident, $check:literal) => {
            #[test]
            fn $name() {
                const BUCKETS: usize = 100;
                // Under the null hypothesis the p-value is uniform on [0, 1]. The threshold
                // is therefore also the rate at which a correct implementation fails a
                // randomly seeded run, so keep it small to avoid a flaky suite.
                const SIGNIFICANCE: f64 = 0.001;

                let config = RandomSequenceBuilder::<$type>::rand(&mut OsRng);

                // Exact integer bucketing: floor(value * BUCKETS / 2^BITS) always lies in
                // 0..BUCKETS and every bucket spans the same share of the range. `MAX`
                // therefore lands in the last regular bucket rather than in a separate
                // bucket with a near-zero expectation.
                let mut observed = [0usize; BUCKETS];
                config
                    .into_iter()
                    .take($check)
                    .map(|value| ((u128::from(value) * BUCKETS as u128) >> $type::BITS) as usize)
                    .for_each(|bucket| observed[bucket] += 1);

                let expected = $check as f64 / BUCKETS as f64;
                let chi_squared: f64 = observed
                    .iter()
                    .map(|&count| {
                        let diff = count as f64 - expected;
                        diff * diff / expected
                    })
                    .sum();

                // `BUCKETS` categories with a fixed total give `BUCKETS - 1` degrees of freedom.
                let chi_dist = ChiSquared::new((BUCKETS - 1) as f64)
                    .expect("degrees of freedom must be positive");
                let p_value = 1.0 - chi_dist.cdf(chi_squared);
                assert!(p_value > SIGNIFICANCE, "Unexpectedly rejected the null hypothesis with high probability. stat: {}, p: {}", chi_squared, p_value);
            }
        };
    }

    test_distribution!(test_u8_distribution, u8, 256);
    test_distribution!(test_u16_distribution, u16, 65536);
    test_distribution!(test_u32_distribution, u32, 100_000);
    test_distribution!(test_u64_distribution, u64, 100_000);
}