use rand::Rng;
use std::fmt;
#[derive(Clone)]
pub struct ReservoirSampling<T, R>
where
R: Rng,
{
k: usize,
rng: R,
reservoir: Vec<T>,
i: usize,
skip_until: usize,
}
impl<T, R> ReservoirSampling<T, R>
where
R: Rng,
{
pub fn new(k: usize, rng: R) -> Self {
assert!(k > 0, "k must be greater than 0");
Self {
k,
rng,
reservoir: vec![],
i: 0,
skip_until: 0,
}
}
pub fn k(&self) -> usize {
self.k
}
pub fn reservoir(&self) -> &Vec<T> {
&self.reservoir
}
pub fn i(&self) -> usize {
self.i
}
pub fn add(&mut self, obj: T) {
let t = self.k * 4;
if self.i < self.k {
self.reservoir.push(obj)
} else if self.i < t {
let j: usize = self.rng.gen_range(0..self.i);
if j < self.k {
self.reservoir[j] = obj;
}
} else if self.i >= self.skip_until {
let j: usize = self.rng.gen_range(0..self.k);
self.reservoir[j] = obj;
let p = (self.k as f64) / ((self.i + 1) as f64);
let u = 1f64 - self.rng.gen_range((0.)..(1.)); let g = (u.ln() / (1. - p).ln()).floor() as usize;
self.skip_until = self.i + g;
}
self.i += 1;
}
pub fn is_empty(&self) -> bool {
self.i == 0
}
pub fn clear(&mut self) {
self.i = 0;
self.skip_until = 0;
self.reservoir.clear();
}
}
/// Shows only the configuration (`k`). The sampled items and the RNG are
/// omitted, so neither `T` nor `R` has to implement `Debug`.
impl<T, R> fmt::Debug for ReservoirSampling<T, R>
where
    R: Rng,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `debug_struct` renders `{:?}` as `ReservoirSampling { k: <k> }` and,
        // unlike a hand-written `write!`, honours `{:#?}` pretty-printing.
        f.debug_struct("ReservoirSampling")
            .field("k", &self.k)
            .finish()
    }
}
/// Streams every item of an iterator into the sampler, exactly as if each one
/// were passed to [`ReservoirSampling::add`] in iteration order.
impl<T, R> Extend<T> for ReservoirSampling<T, R>
where
    R: Rng,
{
    fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {
        iter.into_iter().for_each(|item| self.add(item));
    }
}
#[cfg(test)]
mod tests {
    use super::ReservoirSampling;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;

    /// Sampler over `u64` with a fixed-seed RNG so every run is reproducible.
    fn sampler(k: usize) -> ReservoirSampling<u64, ChaChaRng> {
        ReservoirSampling::new(k, ChaChaRng::from_seed([0; 32]))
    }

    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn new_panics_k0() {
        sampler(0);
    }

    #[test]
    fn getter() {
        let rs = sampler(10);
        assert_eq!(rs.k(), 10);
        assert_eq!(rs.i(), 0);
        assert_eq!(rs.reservoir(), &vec![]);
    }

    #[test]
    fn empty() {
        let rs = sampler(10);
        assert!(rs.is_empty());
        assert!(rs.reservoir().is_empty());
    }

    #[test]
    fn add_k() {
        let mut rs = sampler(10);
        for i in 0..10 {
            rs.add(i);
        }
        assert_eq!(rs.i(), 10);
        assert_eq!(rs.reservoir(), &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn add_parts() {
        // 10_000 items of which 7_500 are 1: a uniform sample of 100 holds 75
        // ones on average (hypergeometric, standard deviation ~4.3). The
        // tolerance is over 4 standard deviations, so the test checks the
        // statistics rather than one particular random stream.
        let mut rs = sampler(100);
        for _ in 0..1500 {
            rs.add(0);
        }
        for _ in 0..7500 {
            rs.add(1);
        }
        for _ in 0..1000 {
            rs.add(0);
        }
        assert_eq!(rs.i(), 10_000);
        assert_eq!(rs.reservoir().len(), 100);
        let ones = rs.reservoir().iter().sum::<u64>() as i64;
        assert!((ones - 75).abs() < 20, "sample held {} ones", ones);
    }

    #[test]
    fn debug() {
        let rs = sampler(10);
        assert_eq!(format!("{:?}", rs), "ReservoirSampling { k: 10 }");
    }

    #[test]
    fn clear() {
        let mut rs = sampler(10);
        for i in 0..10 {
            rs.add(i);
        }
        rs.clear();
        for i in 10..20 {
            rs.add(i);
        }
        assert_eq!(rs.i(), 10);
        assert_eq!(
            rs.reservoir(),
            &vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
        );
    }

    #[test]
    fn clear_after_long_stream() {
        let mut rs = sampler(10);
        rs.extend(0..1000);
        rs.clear();
        assert!(rs.is_empty());
        assert!(rs.reservoir().is_empty());
        rs.extend(1000..1010);
        assert_eq!(rs.i(), 10);
        assert_eq!(rs.reservoir(), &(1000..1010).collect::<Vec<u64>>());
    }

    #[test]
    fn extend() {
        let mut rs = sampler(10);
        rs.extend(0..10);
        assert_eq!(rs.i(), 10);
        assert_eq!(rs.reservoir(), &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn clone() {
        let mut rs1 = sampler(10);
        for i in 0..10 {
            rs1.add(i);
        }
        let mut rs2 = rs1.clone();
        for _ in 0..1000 {
            rs2.add(99);
        }
        // Feeding the clone must not affect the original.
        assert_eq!(rs1.i(), 10);
        assert_eq!(rs1.reservoir(), &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        // Each of the clone's first 10 items survives the next 1000 with
        // probability 10/1010, so almost surely no more than a few remain.
        assert_eq!(rs2.i(), 1010);
        assert_eq!(rs2.reservoir().len(), 10);
        assert!(rs2.reservoir().iter().all(|&x| x == 99 || x < 10));
        assert!(rs2.reservoir().iter().filter(|&&x| x == 99).count() >= 7);
    }

    #[test]
    fn sample_is_distinct_subset_of_stream() {
        let mut rs = sampler(7);
        rs.extend(0..10_000);
        assert_eq!(rs.i(), 10_000);
        let mut sample = rs.reservoir().clone();
        sample.sort_unstable();
        sample.dedup();
        // Each stream item occupies at most one slot, so there are no duplicates.
        assert_eq!(sample.len(), 7);
        assert!(sample.iter().all(|&x| x < 10_000));
    }

    #[test]
    fn k1() {
        let mut rs = sampler(1);
        for _ in 0..10 {
            rs.add(1);
        }
        assert_eq!(rs.reservoir(), &vec![1]);
    }
}