use rand::prelude::*;
/// Fixed-capacity uniform reservoir sampler implementing Li's Algorithm L.
///
/// Once the reservoir holds `k` items, the sampler draws how many upcoming
/// items to skip before the next replacement, so randomness is consumed only
/// when the reservoir changes rather than once per offered item.
#[derive(Debug, Clone)]
pub struct ReservoirSampler<T> {
    /// Maximum number of retained items.
    k: usize,
    /// Number of items offered so far, including discarded ones.
    seen: usize,
    /// Retained items; never more than `k`.
    samples: Vec<T>,
    /// Items still to discard before the next replacement.
    skip_counter: usize,
    /// Algorithm L's running weight `W` in `[0, 1]`; `0.0` until the reservoir first fills.
    w: f64,
}

impl<T> ReservoirSampler<T> {
    /// Creates an empty sampler that keeps at most `k` items.
    pub fn new(k: usize) -> Self {
        Self {
            k,
            seen: 0,
            samples: Vec::with_capacity(k),
            skip_counter: 0,
            w: 0.0,
        }
    }

    /// Offers `item` to the sampler using the thread-local RNG.
    #[inline]
    pub fn add(&mut self, item: T) {
        let mut rng = rand::rng();
        self.add_with_rng(item, &mut rng);
    }

    /// Offers `item` to the sampler, drawing randomness from `rng`.
    ///
    /// Every call increments [`seen`](Self::seen). With `k == 0` the item is
    /// then dropped without touching `rng`.
    #[inline]
    pub fn add_with_rng<R: Rng + ?Sized>(&mut self, item: T, rng: &mut R) {
        self.seen += 1;
        if self.k == 0 {
            return;
        }
        if self.samples.len() < self.k {
            self.samples.push(item);
            if self.samples.len() == self.k {
                // Reservoir just filled: initialise W and schedule the first replacement.
                self.w = Self::w_factor(self.k, rng);
                self.update_skip(rng);
            }
            return;
        }
        if self.skip_counter > 0 {
            self.skip_counter -= 1;
            return;
        }
        let replace_idx = rng.random_range(0..self.k);
        self.samples[replace_idx] = item;
        self.w *= Self::w_factor(self.k, rng);
        self.update_skip(rng);
    }

    /// Draws `u` from `[f64::MIN_POSITIVE, 1)` so that `u.ln()` is always finite.
    fn positive_unit<R: Rng + ?Sized>(rng: &mut R) -> f64 {
        rng.random::<f64>().max(f64::MIN_POSITIVE)
    }

    /// Draws `u^(1/k)`, Algorithm L's multiplicative update for `W`.
    fn w_factor<R: Rng + ?Sized>(k: usize, rng: &mut R) -> f64 {
        (Self::positive_unit(rng).ln() / k as f64).exp()
    }

    /// Sets `skip_counter` to `floor(ln(u) / ln(1 - W))`.
    fn update_skip<R: Rng + ?Sized>(&mut self, rng: &mut R) {
        let num = Self::positive_unit(rng).ln();
        // `ln_1p(-W)` stays accurate for tiny `W`. `(1.0 - W).ln()` rounds to
        // 0.0 once W falls below about 2^-54. That makes the quotient -inf, and
        // the saturating cast turns it into a skip of 0, so every later item
        // would replace a sample. With ln_1p, W == 0 yields +inf (skip forever)
        // and W == 1 yields -inf, i.e. a quotient of 0 (replace the next item).
        let denom = (-self.w).ln_1p();
        // Float-to-int `as` saturates, so an effectively infinite skip becomes usize::MAX.
        self.skip_counter = (num / denom).floor() as usize;
    }

    /// Returns the current sample (fewer than `k` items until `k` have been offered).
    pub fn samples(&self) -> &[T] {
        &self.samples
    }

    /// Returns how many items have been offered so far.
    pub fn seen(&self) -> usize {
        self.seen
    }
}
/// Uniform reservoir sampler implementing Vitter's Algorithm R.
///
/// After the first `k` items, every new item draws one index in `0..seen`
/// and overwrites that slot only when the index lands inside the reservoir.
#[derive(Debug, Clone)]
pub struct ReservoirSamplerR<T> {
    /// Maximum number of retained items.
    k: usize,
    /// Number of items offered so far.
    seen: usize,
    /// Retained items; never more than `k`.
    samples: Vec<T>,
}

impl<T> ReservoirSamplerR<T> {
    /// Creates an empty sampler retaining at most `k` items.
    pub fn new(k: usize) -> Self {
        let samples = Vec::with_capacity(k);
        Self { k, seen: 0, samples }
    }

    /// Offers `item` using the thread-local RNG.
    #[inline]
    pub fn add(&mut self, item: T) {
        self.add_with_rng(item, &mut rand::rng());
    }

    /// Offers `item`, drawing randomness from `rng`.
    ///
    /// The item always counts toward [`seen`](Self::seen). While filling, or
    /// when `k == 0`, no randomness is consumed.
    #[inline]
    pub fn add_with_rng<R: Rng + ?Sized>(&mut self, item: T, rng: &mut R) {
        self.seen += 1;
        let still_filling = self.samples.len() < self.k;
        if still_filling {
            self.samples.push(item);
        } else if self.k != 0 {
            // Once full, `samples.len() == k`, so `get_mut` hits exactly when `slot < k`.
            let slot = rng.random_range(0..self.seen);
            if let Some(dst) = self.samples.get_mut(slot) {
                *dst = item;
            }
        }
    }

    /// Returns the current sample.
    pub fn samples(&self) -> &[T] {
        self.samples.as_slice()
    }

    /// Returns how many items have been offered so far.
    pub fn seen(&self) -> usize {
        self.seen
    }
}
/// Reasons a weighted sampler rejects an item's weight.
#[derive(Debug, Clone, PartialEq)]
pub enum WeightedReservoirError {
    /// The weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f64),
    /// The weight was zero or negative; carries the offending value.
    NonPositiveWeight(f64),
}

impl std::fmt::Display for WeightedReservoirError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            WeightedReservoirError::NonFiniteWeight(w) => {
                write!(out, "weight must be finite (got {w})")
            }
            WeightedReservoirError::NonPositiveWeight(w) => {
                write!(out, "weight must be > 0 (got {w})")
            }
        }
    }
}

impl std::error::Error for WeightedReservoirError {}
/// Weighted reservoir sampler implementing Efraimidis–Spirakis A-Res.
///
/// Each accepted item receives the key `u^(1/weight)`, with `u` uniform on
/// `[f64::MIN_POSITIVE, 1)`. The `k` items with the largest keys are retained,
/// so heavier items are more likely to survive.
#[derive(Debug, Clone)]
pub struct WeightedReservoirSampler<T> {
    /// Maximum number of retained items.
    k: usize,
    /// Number of items with a valid weight offered so far.
    seen: usize,
    /// Retained items; `items[i]` has key `keys[i]`.
    items: Vec<T>,
    /// A-Res keys of the retained items, each in `[0, 1]`.
    keys: Vec<f64>,
    /// Index of the smallest entry in `keys` (`0` while empty).
    min_idx: usize,
}

impl<T> WeightedReservoirSampler<T> {
    /// Creates an empty sampler that keeps at most `k` items.
    pub fn new(k: usize) -> Self {
        Self {
            k,
            seen: 0,
            items: Vec::with_capacity(k),
            keys: Vec::with_capacity(k),
            min_idx: 0,
        }
    }

    /// Offers `item` with `weight` using the thread-local RNG.
    ///
    /// # Errors
    ///
    /// Same as [`add_with_rng`](Self::add_with_rng).
    #[inline]
    pub fn add(&mut self, item: T, weight: f64) -> Result<(), WeightedReservoirError> {
        let mut rng = rand::rng();
        self.add_with_rng(item, weight, &mut rng)
    }

    /// Offers `item` with `weight`, drawing randomness from `rng`.
    ///
    /// # Errors
    ///
    /// Returns [`WeightedReservoirError::NonFiniteWeight`] for NaN or infinite
    /// weights and [`WeightedReservoirError::NonPositiveWeight`] for weights
    /// `<= 0`. The weight is checked before any state changes, even when
    /// `k == 0`. A rejected item is therefore not counted in [`seen`](Self::seen),
    /// and the sampler is left untouched.
    #[inline]
    pub fn add_with_rng<R: Rng + ?Sized>(
        &mut self,
        item: T,
        weight: f64,
        rng: &mut R,
    ) -> Result<(), WeightedReservoirError> {
        Self::validate_weight(weight)?;
        self.seen += 1;
        if self.k == 0 {
            return Ok(());
        }
        let u = rng.random::<f64>().max(f64::MIN_POSITIVE);
        let key = (u.ln() / weight).exp();
        if self.items.len() < self.k {
            // Strict `<` keeps the earliest of equal minima, matching `refresh_min_idx`.
            if self.keys.is_empty() || key < self.keys[self.min_idx] {
                self.min_idx = self.items.len();
            }
            self.items.push(item);
            self.keys.push(key);
        } else if key > self.keys[self.min_idx] {
            self.items[self.min_idx] = item;
            self.keys[self.min_idx] = key;
            self.refresh_min_idx();
        }
        Ok(())
    }

    /// Rejects non-finite and non-positive weights.
    fn validate_weight(weight: f64) -> Result<(), WeightedReservoirError> {
        if !weight.is_finite() {
            Err(WeightedReservoirError::NonFiniteWeight(weight))
        } else if weight <= 0.0 {
            Err(WeightedReservoirError::NonPositiveWeight(weight))
        } else {
            Ok(())
        }
    }

    /// Recomputes `min_idx` with an O(k) scan after a replacement.
    fn refresh_min_idx(&mut self) {
        // `min_by` returns the first of several equal minima.
        self.min_idx = self
            .keys
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map_or(0, |(i, _)| i);
    }

    /// Returns the retained items.
    pub fn samples(&self) -> &[T] {
        &self.items
    }

    /// Returns the A-Res keys, index-aligned with [`samples`](Self::samples).
    pub fn keys(&self) -> &[f64] {
        &self.keys
    }

    /// Returns how many items with a valid weight have been offered.
    pub fn seen(&self) -> usize {
        self.seen
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// RNG that always yields zero bits, so uniform float draws come out as
    /// 0.0. Used to exercise the `ln(0)` guards.
    #[derive(Debug, Clone, Default)]
    struct ZeroRng;

    impl rand::RngCore for ZeroRng {
        fn next_u32(&mut self) -> u32 {
            0
        }
        fn next_u64(&mut self) -> u64 {
            0
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            dest.fill(0);
        }
    }

    /// Pearson chi-squared statistic of `counts` against a constant `expected` count.
    fn chi_squared(counts: &[usize], expected: f64) -> f64 {
        counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - expected;
                (diff * diff) / expected
            })
            .sum()
    }

    /// Asserts that `counts` looks uniform. `counts` records how often each of
    /// `counts.len()` items was kept across `trials` runs of a size-`k` sampler.
    ///
    /// The threshold is deliberately loose. With `n - 1` degrees of freedom the
    /// statistic's mean is `n - 1`, so for `n = 100` a limit of 250 trips only
    /// on gross bias.
    fn assert_roughly_uniform(counts: &[usize], trials: usize, k: usize) {
        let n = counts.len();
        let expected = trials as f64 * (k as f64 / n as f64);
        let chi2 = chi_squared(counts, expected);
        assert!(
            chi2 < 250.0,
            "chi2 too large (chi2={chi2:.2}, expected~{}). counts={counts:?}",
            n - 1
        );
    }

    #[test]
    fn reservoir_keeps_k_items() {
        let mut s = ReservoirSampler::new(5);
        for i in 0..100 {
            s.add(i);
        }
        assert_eq!(s.samples().len(), 5);
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn reservoir_distribution_uniform() {
        let n = 100;
        let k = 10;
        let trials = 10_000;
        let mut counts = vec![0usize; n];
        for t in 0..trials {
            let mut s = ReservoirSampler::new(k);
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            for i in 0..n {
                s.add_with_rng(i, &mut rng);
            }
            for &item in s.samples() {
                counts[item] += 1;
            }
        }
        assert_roughly_uniform(&counts, trials, k);
    }

    #[test]
    fn reservoir_r_keeps_k_items() {
        let mut s = ReservoirSamplerR::new(5);
        for i in 0..100 {
            s.add(i);
        }
        assert_eq!(s.samples().len(), 5);
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn reservoir_r_distribution_uniform() {
        let n = 100;
        let k = 10;
        let trials = 5_000;
        let mut counts = vec![0usize; n];
        for t in 0..trials {
            let mut s = ReservoirSamplerR::new(k);
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            for i in 0..n {
                s.add_with_rng(i, &mut rng);
            }
            for &item in s.samples() {
                counts[item] += 1;
            }
        }
        assert_roughly_uniform(&counts, trials, k);
    }

    #[test]
    fn weighted_reservoir_keeps_k_items() {
        let mut s = WeightedReservoirSampler::new(5);
        for i in 0..100 {
            s.add(i, 1.0).expect("weight ok");
        }
        assert_eq!(s.samples().len(), 5);
        assert_eq!(s.seen(), 100);
        assert_eq!(s.keys().len(), 5);
    }

    #[test]
    fn weighted_reservoir_rejects_bad_weights() {
        let mut s = WeightedReservoirSampler::new(2);
        let err = s.add(1, 0.0).expect_err("zero weight rejected");
        assert_eq!(err, WeightedReservoirError::NonPositiveWeight(0.0));
        let err = s.add(2, f64::NAN).expect_err("nan weight rejected");
        assert!(matches!(err, WeightedReservoirError::NonFiniteWeight(w) if !w.is_finite()));
    }

    #[test]
    fn weighted_reservoir_error_messages() {
        assert_eq!(
            WeightedReservoirError::NonPositiveWeight(-1.0).to_string(),
            "weight must be > 0 (got -1)"
        );
        assert_eq!(
            WeightedReservoirError::NonFiniteWeight(f64::INFINITY).to_string(),
            "weight must be finite (got inf)"
        );
    }

    #[test]
    fn weighted_reservoir_biases_toward_large_weights() {
        let n_trials = 2_000;
        let mut counts = [0usize; 3];
        for t in 0..n_trials {
            let mut s = WeightedReservoirSampler::new(1);
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            s.add_with_rng(0, 100.0, &mut rng).expect("weight ok");
            s.add_with_rng(1, 1.0, &mut rng).expect("weight ok");
            s.add_with_rng(2, 1.0, &mut rng).expect("weight ok");
            let sample = s.samples()[0];
            counts[sample] += 1;
        }
        assert!(counts[0] > counts[1]);
        assert!(counts[0] > counts[2]);
    }

    #[test]
    fn reservoir_algorithm_l_handles_zero_rng_draws() {
        let mut s = ReservoirSampler::new(5);
        let mut rng = ZeroRng;
        for i in 0..100 {
            s.add_with_rng(i, &mut rng);
        }
        assert_eq!(s.samples().len(), 5);
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn reservoir_k_zero_discards_everything() {
        let mut s = ReservoirSampler::new(0);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        for i in 0..100 {
            s.add_with_rng(i, &mut rng);
        }
        assert!(s.samples().is_empty());
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn reservoir_k_larger_than_stream_returns_all() {
        let mut s = ReservoirSampler::new(50);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        for i in 0..10 {
            s.add_with_rng(i, &mut rng);
        }
        assert_eq!(s.samples().len(), 10);
        let mut sorted: Vec<_> = s.samples().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn reservoir_r_k_zero_discards_everything() {
        let mut s = ReservoirSamplerR::new(0);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        for i in 0..100 {
            s.add_with_rng(i, &mut rng);
        }
        assert!(s.samples().is_empty());
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn reservoir_r_k_larger_than_stream_returns_all() {
        let mut s = ReservoirSamplerR::new(50);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        for i in 0..10 {
            s.add_with_rng(i, &mut rng);
        }
        assert_eq!(s.samples().len(), 10);
        let mut sorted: Vec<_> = s.samples().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn weighted_reservoir_k_zero_discards_everything() {
        let mut s = WeightedReservoirSampler::new(0);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        for i in 0..100 {
            s.add_with_rng(i, 1.0, &mut rng).unwrap();
        }
        assert!(s.samples().is_empty());
        assert_eq!(s.seen(), 100);
    }

    #[test]
    fn weighted_reservoir_negative_weight_error() {
        let mut s = WeightedReservoirSampler::new(5);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let err = s.add_with_rng(1, -1.0, &mut rng).unwrap_err();
        assert_eq!(err, WeightedReservoirError::NonPositiveWeight(-1.0));
    }

    #[test]
    fn weighted_reservoir_infinity_weight_error() {
        let mut s = WeightedReservoirSampler::new(5);
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let err = s.add_with_rng(1, f64::INFINITY, &mut rng).unwrap_err();
        assert!(matches!(err, WeightedReservoirError::NonFiniteWeight(w) if w.is_infinite()));
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;

        proptest! {
            #[test]
            fn prop_weighted_reservoir_min_idx_correct(
                seed in 0u64..10_000,
                k in 1usize..=10,
                n in 1usize..=50,
                weights in proptest::collection::vec(0.01f64..100.0f64, 1..=50),
            ) {
                let n = n.min(weights.len());
                let mut sampler = WeightedReservoirSampler::new(k);
                let mut rng = ChaCha8Rng::seed_from_u64(seed);
                for i in 0..n {
                    sampler.add_with_rng(i, weights[i], &mut rng).unwrap();
                    if !sampler.keys.is_empty() {
                        let actual_min_idx = sampler.keys.iter()
                            .enumerate()
                            .min_by(|(_, a), (_, b)| a.total_cmp(b))
                            .unwrap()
                            .0;
                        prop_assert!(
                            sampler.min_idx == actual_min_idx,
                            "After adding item {}: cached min_idx={} but actual min is at {}. keys={:?}",
                            i, sampler.min_idx, actual_min_idx, sampler.keys
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn algorithm_l_and_r_agree_on_size() {
        let k = 10;
        let n = 200;
        let mut s_l = ReservoirSampler::new(k);
        let mut s_r = ReservoirSamplerR::new(k);
        let mut rng_l = ChaCha8Rng::seed_from_u64(99);
        let mut rng_r = ChaCha8Rng::seed_from_u64(99);
        for i in 0..n {
            s_l.add_with_rng(i, &mut rng_l);
            s_r.add_with_rng(i, &mut rng_r);
        }
        assert_eq!(s_l.samples().len(), k);
        assert_eq!(s_r.samples().len(), k);
        assert_eq!(s_l.seen(), n);
        assert_eq!(s_r.seen(), n);
    }
}