use crate::{error::MathError, integer::Z};
use flint_sys::fmpz::{fmpz_addmul_ui, fmpz_set_ui};
use rand::{Rng, rngs::ThreadRng};
/// Precomputed state for repeatedly sampling integers from `[0, interval_size)`.
///
/// An instance is created by `init`, which derives all helper values from the
/// interval size once, so that every call to `sample` only has to draw random
/// 32-bit words and combine them.
pub struct UniformIntegerSampler {
/// Exclusive upper bound of the sampled values; `init` rejects values smaller than `1`.
interval_size: Z,
/// The constant `2^32`, stored as `u64`; used as the multiplier that shifts
/// the accumulated value by one 32-bit word before the next word is added.
two_pow_32: u64,
/// Number of full 32-bit words drawn per candidate, i.e. `bit_size / 32`,
/// where `bit_size` is `(interval_size - 1).bits()`.
nr_iterations: u32,
/// `2^(bit_size % 32)`; the first random word of a candidate is reduced modulo
/// this value so that it contributes only the remaining `bit_size % 32` bits.
upper_modulo: u32,
/// Random number generator handle obtained from `rand::rng()` in `init`.
rng: ThreadRng,
}
impl UniformIntegerSampler {
    /// Sets up a sampler for the interval `[0, interval_size)`.
    ///
    /// All values that only depend on the interval size are computed here once,
    /// so that they can be reused by every subsequent call to [`Self::sample`].
    ///
    /// Parameters:
    /// - `interval_size`: the number of integers in the interval to sample from;
    ///   it has to be at least `1`
    ///
    /// Returns the initialized [`UniformIntegerSampler`].
    ///
    /// # Errors
    /// - Returns a [`MathError::InvalidInterval`] if `interval_size` is smaller than `1`.
    pub fn init(interval_size: &Z) -> Result<Self, MathError> {
        if interval_size < &Z::ONE {
            return Err(MathError::InvalidInterval(format!(
                "An invalid interval size {interval_size} was provided."
            )));
        }

        // Bit length of the largest value that may be returned, `interval_size - 1`
        // (presumably `bits()` reports the number of bits of its operand).
        let bit_size = (interval_size - Z::ONE).bits() as u32;

        Ok(Self {
            interval_size: interval_size.clone(),
            // 2^32, the factor that moves the accumulator up by one 32-bit word.
            two_pow_32: u64::from(u32::MAX) + 1,
            // Number of complete 32-bit words contained in `bit_size`.
            nr_iterations: bit_size / 32,
            // 2^(bit_size mod 32); the shift amount is always in `0..32`.
            upper_modulo: 1_u32 << (bit_size % 32),
            rng: rand::rng(),
        })
    }

    /// Draws one integer from `[0, interval_size)` by rejection sampling.
    ///
    /// Candidates are generated by [`Self::sample_bits_uniform`] and discarded
    /// until one of them is smaller than the interval size. For an interval of
    /// size `1`, no randomness is consumed and `0` is returned directly.
    ///
    /// Returns the sampled value as a [`Z`].
    pub fn sample(&mut self) -> Z {
        // The only element of an interval of size one is zero.
        if self.interval_size.is_one() {
            return Z::ZERO;
        }

        loop {
            let candidate = self.sample_bits_uniform();
            if candidate < self.interval_size {
                return candidate;
            }
        }
    }

    /// Assembles a random non-negative integer from `bit_size` random bits,
    /// where `bit_size = 32 * nr_iterations + log2(upper_modulo)`.
    ///
    /// The most significant part consists of one random `u32` reduced modulo
    /// `upper_modulo`; afterwards `nr_iterations` full random 32-bit words are
    /// appended, each time multiplying the accumulated value by `2^32` first.
    /// The result may be greater than or equal to the interval size.
    ///
    /// Returns the assembled value as a [`Z`].
    pub fn sample_bits_uniform(&mut self) -> Z {
        let leading_bits = self.rng.next_u32() % self.upper_modulo;
        let mut accumulated = Z::from(leading_bits);

        for _ in 0..self.nr_iterations {
            let word = u64::from(self.rng.next_u32());
            let mut shifted = Z::default();
            // SAFETY: both pointers are derived from references to the `value`
            // fields of the live, distinct `Z` instances `shifted` and
            // `accumulated` (assumes `Z::default()` yields an initialized `fmpz`).
            // The two calls compute `shifted = word + accumulated * 2^32`.
            unsafe {
                fmpz_set_ui(&mut shifted.value, word);
                fmpz_addmul_ui(&mut shifted.value, &accumulated.value, self.two_pow_32);
            }
            accumulated = shifted;
        }

        accumulated
    }
}
#[cfg(test)]
mod test_uis {
    use super::{UniformIntegerSampler, Z};
    use std::collections::HashSet;

    /// Checks that samples for small interval sizes stay inside `[0, interval_size)`.
    #[test]
    fn small_interval() {
        let bound_2 = Z::from(2);
        let bound_7 = Z::from(7);

        let mut sampler_2 = UniformIntegerSampler::init(&bound_2).unwrap();
        let mut sampler_7 = UniformIntegerSampler::init(&bound_7).unwrap();

        for _ in 0..3 {
            let drawn_2 = sampler_2.sample();
            let drawn_7 = sampler_7.sample();

            assert!(drawn_2 >= Z::ZERO);
            assert!(drawn_2 < bound_2);
            assert!(drawn_7 >= Z::ZERO);
            assert!(drawn_7 < bound_7);
        }
    }

    /// Checks that samples for interval sizes around and beyond 64 bits stay
    /// inside `[0, interval_size)`.
    #[test]
    fn large_interval() {
        let bound_small = Z::from(u64::MAX);
        let bound_large = Z::from(u64::MAX) * 2 + 1;

        let mut sampler_small = UniformIntegerSampler::init(&bound_small).unwrap();
        let mut sampler_large = UniformIntegerSampler::init(&bound_large).unwrap();

        for _ in 0..u8::MAX {
            let drawn_small = sampler_small.sample();
            let drawn_large = sampler_large.sample();

            assert!(drawn_small >= Z::ZERO);
            assert!(drawn_small < bound_small);
            assert!(drawn_large >= Z::ZERO);
            assert!(drawn_large < bound_large);
        }
    }

    /// Checks that every value of the interval is hit when drawing
    /// `2^interval_size` samples. As the sampling is random, a value may be
    /// missed with low probability.
    #[test]
    fn entire_interval() {
        for interval_size in [6_u32, 7, 16] {
            let bound = Z::from(interval_size);
            let mut sampler = UniformIntegerSampler::init(&bound).unwrap();

            let nr_draws = 2_u32.pow(interval_size);
            let distinct: HashSet<Z> = (0..nr_draws).map(|_| sampler.sample()).collect();

            assert_eq!(
                interval_size,
                distinct.len() as u32,
                "This test may fail with low probability."
            );
        }
    }

    /// Checks that interval sizes smaller than `1` are rejected.
    #[test]
    fn invalid_interval() {
        let zero_sized = UniformIntegerSampler::init(&Z::ZERO);
        let negative_sized = UniformIntegerSampler::init(&Z::MINUS_ONE);

        assert!(zero_sized.is_err());
        assert!(negative_sized.is_err());
    }

    /// Checks that `sample_bits_uniform` never produces more bits than needed,
    /// using interval sizes that are powers of two (`2^3`, `2^8`, `2^32`), for
    /// which every assembled candidate has to be smaller than the interval size.
    #[test]
    fn sample_bits_uniform_necessary_nr_bytes() {
        let bound_8 = Z::from(8);
        let bound_256 = Z::from(256);
        let bound_2_pow_32 = Z::from(u32::MAX) + Z::ONE;

        let mut sampler_8 = UniformIntegerSampler::init(&bound_8).unwrap();
        let mut sampler_256 = UniformIntegerSampler::init(&bound_256).unwrap();
        let mut sampler_2_pow_32 = UniformIntegerSampler::init(&bound_2_pow_32).unwrap();

        for _ in 0..u8::MAX {
            let drawn_8 = sampler_8.sample_bits_uniform();
            let drawn_256 = sampler_256.sample_bits_uniform();
            let drawn_2_pow_32 = sampler_2_pow_32.sample_bits_uniform();

            assert!(drawn_8 >= Z::ZERO);
            assert!(drawn_8 < bound_8);
            assert!(drawn_256 >= Z::ZERO);
            assert!(drawn_256 < bound_256);
            assert!(drawn_2_pow_32 >= Z::ZERO);
            assert!(drawn_2_pow_32 < bound_2_pow_32);
        }
    }
}