use crate::{error::MathError, integer::Z, rational::Q};
use rand::rngs::ThreadRng;
use rand_distr::{Binomial, Distribution};
/// Sampler for the binomial distribution `Bin(n, p)`, i.e. the number of successes
/// among `n` independent trials that each succeed with probability `p`.
///
/// Samples are drawn via [`rand_distr::Binomial`] using the thread-local random
/// number generator [`ThreadRng`]. Construct it via [`BinomialSampler::init`] and
/// draw values via [`BinomialSampler::sample`].
#[derive(Debug)]
pub struct BinomialSampler {
    /// The binomial distribution with a fixed number of trials and success probability.
    distr: Binomial,
    /// The random number generator used to draw samples from `distr`.
    rng: ThreadRng,
}
impl BinomialSampler {
    /// Initializes a sampler for the binomial distribution `Bin(n, p)`.
    ///
    /// Parameters:
    /// - `n`: the number of trials; must be non-negative and fit into an [`i64`]
    /// - `p`: the success probability of a single trial; must lie strictly between 0 and 1
    ///
    /// Returns a [`BinomialSampler`] drawing from `Bin(n, p)` with the thread-local
    /// random number generator, or a [`MathError`] if a parameter is invalid.
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] of type [`InvalidInterval`](MathError::InvalidInterval)
    ///   if `p <= 0` or `p >= 1`, or if the `f64` approximation of `p` is rejected
    ///   by [`Binomial::new`].
    /// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
    ///   if `n < 0`.
    /// - Returns a [`MathError`] if `n` does not fit into an [`i64`].
    pub fn init(n: impl Into<Z>, p: impl Into<Q>) -> Result<Self, MathError> {
        let n = n.into();
        let p = p.into();

        // p = 0 and p = 1 are rejected on purpose: they make every sample identical.
        if p <= Q::ZERO || p >= Q::ONE {
            return Err(MathError::InvalidInterval(format!(
                "p (the probability of success for binomial sampling) must be chosen strictly \
                between 0 and 1, as p = 0 and p = 1 yield a constant distribution and any other \
                value is no valid probability. Currently it is {p}."
            )));
        }
        if n < Z::ZERO {
            return Err(MathError::InvalidIntegerInput(format!(
                "n (the number of trials for binomial sampling) must be no smaller than 0. Currently it is {n}."
            )));
        }

        // `n` is known to be non-negative here, so `unsigned_abs` returns its exact value
        // as a `u64` without relying on an `as` cast.
        let n = i64::try_from(n)?.unsigned_abs();
        let p = f64::from(&p);

        // Report a rejected probability as an error instead of panicking: this function
        // already returns `Result`, and the `f64` approximation of `p` is produced by a
        // conversion outside of this function.
        let distr = Binomial::new(n, p).map_err(|err| {
            MathError::InvalidInterval(format!(
                "p (the probability of success for binomial sampling) could not be used as a \
                probability after conversion to f64: {err}. Currently it is {p}."
            ))
        })?;

        Ok(Self {
            distr,
            rng: rand::rng(),
        })
    }

    /// Draws one sample from the binomial distribution this sampler was initialized with.
    ///
    /// Returns the number of successes as a [`Z`], which lies in `[0, n]`.
    pub fn sample(&mut self) -> Z {
        Z::from(self.distr.sample(&mut self.rng))
    }
}
#[cfg(test)]
mod test_binomial_sampler {
    use super::{BinomialSampler, Q, Z};

    /// Ensures that every sample lies within `[0, n]`.
    #[test]
    fn keeps_range() {
        let n = 16;
        let p = 0.5;
        let mut bin_sampler = BinomialSampler::init(n, p).unwrap();

        for _ in 0..16 {
            let sample = bin_sampler.sample();
            assert!(sample >= Z::ZERO);
            assert!(sample <= n);
        }
    }

    /// Ensures that `n = 0` is accepted and always yields zero successes.
    #[test]
    fn zero_trials() {
        let mut bin_sampler = BinomialSampler::init(0, 0.5).unwrap();

        for _ in 0..16 {
            assert_eq!(Z::ZERO, bin_sampler.sample());
        }
    }

    /// Ensures that the empirical distribution of `Bin(2, 0.5)` looks roughly like
    /// the expected 1/4, 1/2, 1/4 split.
    #[test]
    fn distribution() {
        let n = 2;
        let p = 0.5;
        let mut bin_sampler = BinomialSampler::init(n, p).unwrap();

        let mut counts = [0; 3];
        for _ in 0..1000 {
            let sample = u64::try_from(bin_sampler.sample()).unwrap() as usize;
            counts[sample] += 1;
        }

        let expl_text = String::from("This test can fail with probability close to 0.
It fails if the sampled occurrences do not look like a typical binomial random distribution.
If this happens, rerun the tests several times and check whether this issue comes up again.");
        assert!(counts[0] > 100, "{expl_text}");
        assert!(counts[0] < 400, "{expl_text}");
        assert!(counts[1] > 300, "{expl_text}");
        assert!(counts[1] < 700, "{expl_text}");
        assert!(counts[2] > 100, "{expl_text}");
        assert!(counts[2] < 400, "{expl_text}");
    }

    /// Ensures that a negative number of trials is rejected.
    #[test]
    fn invalid_n() {
        let p = 0.5;

        assert!(BinomialSampler::init(&Z::MINUS_ONE, p).is_err());
        assert!(BinomialSampler::init(Z::from(i64::MIN), p).is_err());
    }

    /// Ensures that probabilities outside the open interval `(0, 1)` are rejected.
    #[test]
    fn invalid_p() {
        let n = 2;

        assert!(BinomialSampler::init(n, &Q::MINUS_ONE).is_err());
        assert!(BinomialSampler::init(n, &Q::ZERO).is_err());
        assert!(BinomialSampler::init(n, &Q::ONE).is_err());
        assert!(BinomialSampler::init(n, Q::from(5)).is_err());
    }
}