use rand::Rng;
use std::iter::FusedIterator;
/// Iterator adaptor that draws a random subset of the items of an inner
/// iterator, yielding the chosen items in their original order.
///
/// The adaptor only implements [`Iterator`] when the inner iterator is an
/// [`ExactSizeIterator`] and the random source implements `Rng`.
#[derive(Clone, Debug)]
pub struct RandSubsample<I, R> {
    /// Number of items that still have to be emitted.
    num_samples: usize,
    /// Inner iterator the items are drawn from.
    iter: I,
    /// Random source used to decide which items are kept.
    rng: R,
}
impl<I, R> RandSubsample<I, R> {
pub const fn new(iter: I, num_samples: usize, rng: R) -> Self {
Self {
num_samples,
iter,
rng,
}
}
}
impl<I, R> Iterator for RandSubsample<I, R>
where
    I: ExactSizeIterator,
    R: Rng,
{
    type Item = I::Item;

    /// Emits the next selected item, or `None` once `num_samples` reaches
    /// zero or the inner iterator has nothing left.
    ///
    /// Candidates are examined front to back. A candidate is accepted with
    /// probability `num_samples / remaining` (capped at 1), where `remaining`
    /// counts the inner items not yet passed over. Rejected candidates are
    /// skipped in a single `nth` call.
    fn next(&mut self) -> Option<Self::Item> {
        if self.num_samples == 0 {
            return None;
        }

        let available = self.iter.len();
        let wanted = self.num_samples as f64;
        let mut skipped = 0;
        loop {
            // Once `wanted >= available - skipped` the probability is 1.0, so
            // the loop always stops before `skipped` can exceed `available`.
            // An empty inner iterator gives `inf`, which is capped to 1.0.
            let accept_prob = (wanted / (available - skipped) as f64).min(1.0);
            if self.rng.gen_bool(accept_prob) {
                break;
            }
            skipped += 1;
        }

        self.num_samples -= 1;
        self.iter.nth(skipped)
    }

    /// Exact number of items still to be emitted: the smaller of the
    /// outstanding sample count and the inner iterator's length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = usize::min(self.iter.len(), self.num_samples);
        (exact, Some(exact))
    }

    /// Returns the number of items left without drawing any of them.
    fn count(self) -> usize {
        ExactSizeIterator::len(&self)
    }
}
/// `size_hint` reports identical lower and upper bounds, so the default
/// `len` implementation can be used unchanged.
impl<I: ExactSizeIterator, R: Rng> ExactSizeIterator for RandSubsample<I, R> {}
/// The adaptor is fused only when the inner iterator is: after the inner
/// iterator runs dry, `next` keeps forwarding to it for as long as
/// `num_samples` is non-zero.
impl<I: ExactSizeIterator + FusedIterator, R: Rng> FusedIterator for RandSubsample<I, R> {}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, SeedableRng};
    use std::ops::RangeInclusive;

    /// Size of the population every test samples from.
    const POP_SIZE: usize = 50;
    /// Number of items requested per subsample.
    const NUM_SAMPLES: usize = 10;
    /// Repetitions used by the length/count checks.
    const NUM_ITERS_CHECK: u64 = 100;
    /// Repetitions used by the distribution check.
    const NUM_ITERS_STATS: u64 = 1000;

    /// `len` reports the requested number of samples before iteration starts.
    #[test]
    fn rand_subsample_len() {
        let mut rng = StdRng::seed_from_u64(87);
        for _ in 0..NUM_ITERS_CHECK {
            let sampler = RandSubsample::new(0..POP_SIZE, NUM_SAMPLES, &mut rng);
            assert_eq!(sampler.len(), NUM_SAMPLES);
        }
    }

    /// The overridden `count` agrees with the requested number of samples.
    #[test]
    fn rand_subsample_count() {
        let mut rng = StdRng::seed_from_u64(87);
        for _ in 0..NUM_ITERS_CHECK {
            let sampler = RandSubsample::new(0..POP_SIZE, NUM_SAMPLES, &mut rng);
            assert_eq!(sampler.count(), NUM_SAMPLES);
        }
    }

    /// Actually walking the iterator produces the requested number of items.
    #[test]
    fn rand_subsample_manual_count() {
        let mut rng = StdRng::seed_from_u64(87);
        for _ in 0..NUM_ITERS_CHECK {
            let sampler = RandSubsample::new(0..POP_SIZE, NUM_SAMPLES, &mut rng);
            let tally: usize = sampler.map(|_| 1).sum();
            assert_eq!(tally, NUM_SAMPLES);
        }
    }

    /// Every population member is picked about `NUM_SAMPLES / POP_SIZE` of the
    /// time: each pick count must fall inside the interval computed by
    /// `bernoulli_confidence_interval`.
    #[test]
    fn rand_subsample_unif_dist() {
        let mut rng = StdRng::seed_from_u64(87);
        let mut counts = vec![0; POP_SIZE];
        for _ in 0..NUM_ITERS_STATS {
            for picked in RandSubsample::new(0..POP_SIZE, NUM_SAMPLES, &mut rng) {
                counts[picked] += 1;
            }
        }
        let pick_rate = NUM_SAMPLES as f64 / POP_SIZE as f64;
        let ci = bernoulli_confidence_interval(pick_rate, NUM_ITERS_STATS);
        assert!(counts.iter().all(|c| ci.contains(c)));
    }

    /// Requesting exactly the population size yields every item in order.
    #[test]
    fn rand_subsample_sample_all() {
        assert!(RandSubsample::new(0..10, 10, StdRng::seed_from_u64(87)).eq(0..10));
    }

    /// Requesting more than the population size still yields every item once.
    #[test]
    fn rand_subsample_sample_more() {
        assert!(RandSubsample::new(0..10, 20, StdRng::seed_from_u64(87)).eq(0..10));
    }

    /// Range of success counts accepted for `n` trials with success
    /// probability `p`: the mean `n * p` plus or minus 4.4 standard
    /// deviations, with both ends rounded to the nearest integer.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    fn bernoulli_confidence_interval(p: f64, n: u64) -> RangeInclusive<u64> {
        let z_score = 4.4;
        let trials = n as f64;
        let mean = trials * p;
        let half_width = z_score * (p * (1.0 - p) * trials).sqrt();
        let lowest = (mean - half_width).round() as u64;
        let highest = (mean + half_width).round() as u64;
        lowest..=highest
    }
}