use alloc::collections::BinaryHeap;
use alloc::format;
use alloc::vec::Vec;
use core::cmp::Ordering;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use rand::{Rng, RngExt};
use crate::error::{RcfError, RcfResult};
/// Outcome of offering a point to [`ReservoirSampler::accept`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SamplerOp {
    /// The point was stored in a free slot; nothing was evicted.
    Inserted,
    /// The point was stored after evicting the entry whose point index is carried here.
    Replaced(usize),
    /// The point was not stored: it was dropped by the warm-up gate, or the reservoir was
    /// full and its weight was not smaller than the current largest weight.
    Rejected,
}
/// One reservoir slot: a sampled point index together with its random priority.
///
/// Equality and ordering look only at `weight`; `point_idx` is payload. Comparisons use
/// [`f64::total_cmp`], so the `Eq`/`Ord` laws hold for every bit pattern, NaN included,
/// and a stray NaN weight can never break the `BinaryHeap` invariant. As a consequence of
/// the IEEE 754 total order, `-0.0` sorts before `+0.0`.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct WeightedEntry {
    /// Sampling priority; ordering (and therefore the heap's top) is decided by this alone.
    weight: f64,
    /// Index of the sampled point.
    point_idx: usize,
}

impl PartialEq for WeightedEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when `a.cmp(b) == Equal`,
        // as the `Eq`/`Ord` contracts require (and so NaN == NaN, keeping `Eq` reflexive).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for WeightedEntry {}

impl PartialOrd for WeightedEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for WeightedEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order. `partial_cmp(..).unwrap_or(Equal)` would
        // treat NaN as equal to every weight, which is not transitive.
        self.weight.total_cmp(&other.weight)
    }
}
/// Bounded, time-decayed weighted reservoir of point indices.
///
/// Each admitted point receives a random weight (computed in `accept`). The reservoir keeps
/// at most `capacity` entries. When it is full, the largest-weight entry is evicted if a
/// smaller weight arrives.
///
/// NOTE(review): the serde `Deserialize` derive builds this struct field-by-field, so
/// deserialized values bypass the validation performed in `with_initial_accept_fraction`.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ReservoirSampler {
    /// Max-heap of retained entries; its top is the entry with the largest weight.
    heap: BinaryHeap<WeightedEntry>,
    /// Maximum number of retained entries (validated to be > 0 by the constructor).
    capacity: usize,
    /// Number of offers passed to `accept`, including rejected ones; scales the time decay.
    entries_seen: u64,
    /// Per-offer weight decay (validated finite and >= 0); 0.0 means no recency bias.
    time_decay: f64,
    /// Warm-up fraction in (0.0, 1.0]; falls back to 1.0 (gate disabled) when the field is
    /// missing from serialized data.
    #[cfg_attr(feature = "serde", serde(default = "default_initial_accept_fraction"))]
    initial_accept_fraction: f64,
}
/// Default warm-up fraction for a [`ReservoirSampler`].
///
/// Used by `ReservoirSampler::new` and as the serde default when `initial_accept_fraction`
/// is absent. A value of `1.0` makes the warm-up admission probability always `1.0`,
/// i.e. the gate is disabled.
#[must_use]
pub fn default_initial_accept_fraction() -> f64 {
    1.0_f64
}
impl ReservoirSampler {
    /// Creates a sampler holding at most `capacity` points, with per-offer weight decay
    /// `time_decay` and the warm-up admission gate disabled.
    ///
    /// Equivalent to [`Self::with_initial_accept_fraction`] with
    /// [`default_initial_accept_fraction`] (`1.0`).
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] if `capacity` is zero or `time_decay` is not a
    /// finite, non-negative number.
    pub fn new(capacity: usize, time_decay: f64) -> RcfResult<Self> {
        Self::with_initial_accept_fraction(capacity, time_decay, default_initial_accept_fraction())
    }

    /// Creates a sampler with an explicit warm-up fraction.
    ///
    /// While fewer than `max(initial_accept_fraction * capacity, 1)` offers have been seen,
    /// each offer passes the admission gate with a linearly increasing probability (see
    /// [`Self::admit_probability`]). A fraction of `1.0` disables the gate.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] if `capacity` is zero, if `time_decay` is not
    /// finite and `>= 0`, or if `initial_accept_fraction` is not finite and in `(0.0, 1.0]`.
    pub fn with_initial_accept_fraction(
        capacity: usize,
        time_decay: f64,
        initial_accept_fraction: f64,
    ) -> RcfResult<Self> {
        if capacity == 0 {
            return Err(RcfError::InvalidConfig(
                "ReservoirSampler capacity must be > 0".into(),
            ));
        }
        if !time_decay.is_finite() || time_decay < 0.0 {
            return Err(RcfError::InvalidConfig(
                format!("ReservoirSampler time_decay must be finite and >= 0, got {time_decay}")
                    .into(),
            ));
        }
        if !initial_accept_fraction.is_finite()
            || initial_accept_fraction <= 0.0
            || initial_accept_fraction > 1.0
        {
            return Err(RcfError::InvalidConfig(format!(
                "ReservoirSampler initial_accept_fraction must be in (0.0, 1.0], got {initial_accept_fraction}"
            ).into()));
        }
        Ok(Self {
            heap: BinaryHeap::with_capacity(capacity),
            capacity,
            entries_seen: 0,
            time_decay,
            initial_accept_fraction,
        })
    }

    /// Maximum number of entries the reservoir retains.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of entries currently retained (never exceeds [`Self::capacity`]).
    #[must_use]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// `true` when no entries are retained.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Number of offers made through [`Self::accept`] since creation or the last
    /// [`Self::reset`], including rejected offers.
    #[must_use]
    pub fn entries_seen(&self) -> u64 {
        self.entries_seen
    }

    /// Per-offer weight decay configured at construction.
    #[must_use]
    pub fn time_decay(&self) -> f64 {
        self.time_decay
    }

    /// Warm-up fraction configured at construction, in `(0.0, 1.0]`.
    #[must_use]
    pub fn initial_accept_fraction(&self) -> f64 {
        self.initial_accept_fraction
    }

    /// Probability that the next [`Self::accept`] call passes the warm-up gate.
    ///
    /// With `threshold = max(initial_accept_fraction * capacity, 1.0)`, this is
    /// `min((entries_seen + 1) / threshold, 1.0)` while `entries_seen < threshold`, and
    /// `1.0` afterwards. It is always `1.0` when `initial_accept_fraction >= 1.0`.
    #[must_use]
    pub fn admit_probability(&self) -> f64 {
        if self.initial_accept_fraction >= 1.0 {
            return 1.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let threshold = (self.initial_accept_fraction * self.capacity as f64).max(1.0);
        #[allow(clippy::cast_precision_loss)]
        let seen = self.entries_seen as f64;
        if seen < threshold {
            ((seen + 1.0) / threshold).min(1.0)
        } else {
            1.0
        }
    }

    /// Removes every entry and resets [`Self::entries_seen`] to zero.
    ///
    /// Capacity, time decay and warm-up fraction are kept.
    pub fn reset(&mut self) {
        self.heap.clear();
        self.entries_seen = 0;
    }

    /// Iterates over the point indices currently retained, in unspecified order.
    pub fn iter_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.heap.iter().map(|entry| entry.point_idx)
    }

    /// `true` if `point_idx` is currently retained (linear scan over the entries).
    #[must_use]
    pub fn contains(&self, point_idx: usize) -> bool {
        self.iter_indices().any(|idx| idx == point_idx)
    }

    /// Removes every entry whose point index equals `point_idx`.
    ///
    /// Returns `true` if at least one entry was removed. [`Self::entries_seen`] is not
    /// changed. Runs in `O(len)`: the heap's buffer is filtered in place and re-heapified
    /// once, reusing the existing allocation.
    pub fn remove(&mut self, point_idx: usize) -> bool {
        let before = self.heap.len();
        // Take the heap's backing buffer without copying; `BinaryHeap::default()` does
        // not allocate.
        let mut entries: Vec<WeightedEntry> = core::mem::take(&mut self.heap).into_vec();
        entries.retain(|entry| entry.point_idx != point_idx);
        let removed = entries.len() < before;
        // `From<Vec<T>>` heapifies in O(n), cheaper than re-pushing each survivor.
        self.heap = BinaryHeap::from(entries);
        removed
    }

    /// Offers `point_idx` to the reservoir and reports what happened.
    ///
    /// Every call increments [`Self::entries_seen`], even when the point is rejected.
    /// The point first passes the warm-up gate with probability
    /// [`Self::admit_probability`] (evaluated before the increment). It is then given the
    /// weight `ln(-ln u) - entries_seen * time_decay`, where `u` is drawn from `rng`.
    /// A draw of `0.0` is clamped to `f64::MIN_POSITIVE` so the logarithms stay finite.
    ///
    /// - Below capacity, the point is always stored ([`SamplerOp::Inserted`]).
    /// - Once full, it replaces the current largest-weight entry only if its weight is
    ///   strictly smaller ([`SamplerOp::Replaced`] with the evicted index).
    /// - Otherwise the result is [`SamplerOp::Rejected`].
    ///
    /// A positive `time_decay` gives later offers smaller weights, biasing the retained
    /// sample toward recent points.
    pub fn accept<R: Rng + ?Sized>(&mut self, point_idx: usize, rng: &mut R) -> SamplerOp {
        let admit_prob = self.admit_probability();
        self.entries_seen = self.entries_seen.saturating_add(1);
        if admit_prob < 1.0 {
            let roll: f64 = rng.random();
            if roll >= admit_prob {
                return SamplerOp::Rejected;
            }
        }
        let mut u: f64 = rng.random();
        if u <= 0.0 {
            // ln(0) would be -inf and ln(-ln 0) would be +inf; clamp to keep weights finite.
            u = f64::MIN_POSITIVE;
        }
        #[allow(clippy::cast_precision_loss)]
        let decay = self.entries_seen as f64 * self.time_decay;
        let weight = (-u.ln()).ln() - decay;
        if self.heap.len() < self.capacity {
            self.heap.push(WeightedEntry { weight, point_idx });
            return SamplerOp::Inserted;
        }
        // Reaching here means len >= capacity > 0 (capacity is validated non-zero), so the
        // heap cannot be empty.
        let max_weight = self.heap.peek().expect("heap is non-empty").weight;
        if weight < max_weight {
            let evicted = self.heap.pop().expect("heap is non-empty").point_idx;
            self.heap.push(WeightedEntry { weight, point_idx });
            SamplerOp::Replaced(evicted)
        } else {
            SamplerOp::Rejected
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation
)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use std::collections::HashSet;

    /// Deterministic RNG so every statistical test is reproducible.
    fn fresh_rng(seed: u64) -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(seed)
    }

    #[test]
    fn new_rejects_zero_capacity() {
        assert!(matches!(
            ReservoirSampler::new(0, 0.0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }

    #[test]
    fn new_rejects_negative_time_decay() {
        assert!(matches!(
            ReservoirSampler::new(8, -0.001).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }

    #[test]
    fn new_rejects_non_finite_time_decay() {
        assert!(ReservoirSampler::new(8, f64::NAN).is_err());
        assert!(ReservoirSampler::new(8, f64::INFINITY).is_err());
    }

    #[test]
    fn new_initial_state() {
        let s = ReservoirSampler::new(4, 0.05).unwrap();
        assert_eq!(s.capacity(), 4);
        assert_eq!(s.len(), 0);
        assert!(s.is_empty());
        assert_eq!(s.entries_seen(), 0);
        assert_eq!(s.time_decay(), 0.05);
    }

    #[test]
    fn accept_fills_capacity_with_inserts() {
        let mut s = ReservoirSampler::new(3, 0.0).unwrap();
        let mut rng = fresh_rng(1);
        assert_eq!(s.accept(10, &mut rng), SamplerOp::Inserted);
        assert_eq!(s.accept(11, &mut rng), SamplerOp::Inserted);
        assert_eq!(s.accept(12, &mut rng), SamplerOp::Inserted);
        assert_eq!(s.len(), 3);
        assert_eq!(s.entries_seen(), 3);
    }

    #[test]
    fn accept_after_capacity_returns_replaced_or_rejected() {
        let mut s = ReservoirSampler::new(2, 0.0).unwrap();
        let mut rng = fresh_rng(7);
        s.accept(10, &mut rng);
        s.accept(11, &mut rng);
        for i in 12..200 {
            let op = s.accept(i, &mut rng);
            assert!(
                matches!(op, SamplerOp::Replaced(_) | SamplerOp::Rejected),
                "post-capacity op should be Replaced or Rejected"
            );
            assert_eq!(s.len(), 2, "capacity invariant violated");
        }
    }

    #[test]
    fn replaced_evicts_existing_index() {
        let mut s = ReservoirSampler::new(2, 0.0).unwrap();
        let mut rng = fresh_rng(13);
        s.accept(10, &mut rng);
        s.accept(11, &mut rng);
        let mut evicted_set: HashSet<usize> = HashSet::new();
        for i in 12..200 {
            if let SamplerOp::Replaced(evicted) = s.accept(i, &mut rng) {
                evicted_set.insert(evicted);
                assert!(!s.contains(evicted));
            }
        }
        assert!(!evicted_set.is_empty(), "expected at least one Replaced");
    }

    #[test]
    fn no_duplicate_indices_in_reservoir() {
        let mut s = ReservoirSampler::new(50, 0.0).unwrap();
        let mut rng = fresh_rng(2026);
        for i in 0..10_000 {
            s.accept(i, &mut rng);
        }
        let indices: Vec<usize> = s.iter_indices().collect();
        let unique: HashSet<usize> = indices.iter().copied().collect();
        assert_eq!(indices.len(), unique.len());
        assert!(indices.len() <= s.capacity());
    }

    #[test]
    fn reset_clears_state() {
        let mut s = ReservoirSampler::new(4, 0.0).unwrap();
        let mut rng = fresh_rng(0);
        for i in 0..10 {
            s.accept(i, &mut rng);
        }
        assert!(s.entries_seen() > 0);
        s.reset();
        assert_eq!(s.entries_seen(), 0);
        assert!(s.is_empty());
    }

    #[test]
    fn deterministic_under_fixed_seed() {
        fn run(seed: u64) -> Vec<usize> {
            let mut s = ReservoirSampler::new(8, 0.0).unwrap();
            let mut rng = fresh_rng(seed);
            for i in 0..100 {
                s.accept(i, &mut rng);
            }
            let mut idxs: Vec<usize> = s.iter_indices().collect();
            idxs.sort_unstable();
            idxs
        }
        assert_eq!(run(2026), run(2026));
        assert_ne!(run(2026), run(7));
    }

    #[test]
    fn uniform_distribution_with_zero_decay() {
        const CAP: usize = 32;
        const N: usize = 1024;
        const TRIALS: usize = 256;
        let mut counts = vec![0_u32; N];
        for trial in 0..TRIALS {
            let mut s = ReservoirSampler::new(CAP, 0.0).unwrap();
            let mut rng = fresh_rng(trial as u64 + 1);
            for i in 0..N {
                s.accept(i, &mut rng);
            }
            for idx in s.iter_indices() {
                counts[idx] += 1;
            }
        }
        let expected = (TRIALS * CAP) as f64 / N as f64;
        let total: u32 = counts.iter().sum();
        assert_eq!(total as usize, TRIALS * CAP);
        let max_count = *counts.iter().max().unwrap();
        assert!(
            f64::from(max_count) <= expected * 4.0,
            "uniform sample biased: max_count={max_count} expected={expected}"
        );
        let nonzero = counts.iter().filter(|&&c| c > 0).count();
        assert!(
            nonzero as f64 >= 0.80 * N as f64,
            "uniform sample too sparse: {nonzero}/{N} buckets non-zero"
        );
    }

    #[test]
    fn recency_bias_with_positive_decay() {
        const CAP: usize = 32;
        const N: usize = 2048;
        const TRIALS: usize = 256;
        const LAMBDA: f64 = 0.01;
        let mut recent_count = 0_u32;
        for trial in 0..TRIALS {
            let mut s = ReservoirSampler::new(CAP, LAMBDA).unwrap();
            let mut rng = fresh_rng(trial as u64 + 100);
            for i in 0..N {
                s.accept(i, &mut rng);
            }
            for idx in s.iter_indices() {
                if idx >= (N - N / 10) {
                    recent_count += 1;
                }
            }
        }
        let total: u32 = (TRIALS * CAP) as u32;
        let recent_share = f64::from(recent_count) / f64::from(total);
        assert!(
            recent_share > 0.25,
            "expected recency bias, got share={recent_share}"
        );
    }

    #[test]
    fn uniform_baseline_without_decay() {
        const CAP: usize = 32;
        const N: usize = 2048;
        const TRIALS: usize = 256;
        let mut recent_count = 0_u32;
        for trial in 0..TRIALS {
            let mut s = ReservoirSampler::new(CAP, 0.0).unwrap();
            let mut rng = fresh_rng(trial as u64 + 500);
            for i in 0..N {
                s.accept(i, &mut rng);
            }
            for idx in s.iter_indices() {
                if idx >= (N - N / 10) {
                    recent_count += 1;
                }
            }
        }
        let total = (TRIALS * CAP) as u32;
        let share = f64::from(recent_count) / f64::from(total);
        assert!(
            (0.06..0.15).contains(&share),
            "uniform baseline drifted: share={share}"
        );
    }

    #[test]
    fn iter_indices_matches_len() {
        let mut s = ReservoirSampler::new(5, 0.0).unwrap();
        let mut rng = fresh_rng(1);
        for i in 0..100 {
            s.accept(i, &mut rng);
        }
        assert_eq!(s.iter_indices().count(), s.len());
        assert_eq!(s.len(), 5);
    }

    #[test]
    fn new_defaults_initial_accept_fraction_to_one() {
        let s = ReservoirSampler::new(8, 0.0).unwrap();
        assert_eq!(s.initial_accept_fraction(), 1.0);
        assert_eq!(s.admit_probability(), 1.0);
    }

    #[test]
    fn with_initial_accept_fraction_rejects_non_finite() {
        assert!(ReservoirSampler::with_initial_accept_fraction(8, 0.0, f64::NAN).is_err());
        assert!(ReservoirSampler::with_initial_accept_fraction(8, 0.0, f64::INFINITY).is_err());
    }

    #[test]
    fn with_initial_accept_fraction_rejects_out_of_range() {
        assert!(ReservoirSampler::with_initial_accept_fraction(8, 0.0, 0.0).is_err());
        assert!(ReservoirSampler::with_initial_accept_fraction(8, 0.0, -0.1).is_err());
        assert!(ReservoirSampler::with_initial_accept_fraction(8, 0.0, 1.01).is_err());
    }

    #[test]
    fn with_initial_accept_fraction_accepts_one() {
        let s = ReservoirSampler::with_initial_accept_fraction(8, 0.0, 1.0).unwrap();
        assert_eq!(s.admit_probability(), 1.0);
    }

    #[test]
    fn admit_probability_ramps_linearly_during_warmup() {
        let mut s = ReservoirSampler::with_initial_accept_fraction(64, 0.0, 0.125).unwrap();
        assert!((s.admit_probability() - 1.0 / 8.0).abs() < 1e-12);
        s.entries_seen = 4;
        assert!((s.admit_probability() - 5.0 / 8.0).abs() < 1e-12);
        s.entries_seen = 7;
        assert!((s.admit_probability() - 8.0 / 8.0).abs() < 1e-12);
        s.entries_seen = 100;
        assert_eq!(s.admit_probability(), 1.0);
    }

    #[test]
    fn warmup_gate_rejects_early_offers_more_often_than_late() {
        const TRIALS: usize = 512;
        const EARLY: usize = 7;
        const OFFERS: usize = 32;
        const CAPACITY: usize = 64;
        let mut early_inserts = 0_u32;
        let mut late_inserts = 0_u32;
        for trial in 0..TRIALS {
            let mut s =
                ReservoirSampler::with_initial_accept_fraction(CAPACITY, 0.0, 0.125).unwrap();
            let mut rng = fresh_rng(trial as u64 + 10_000);
            for i in 0..EARLY {
                if matches!(s.accept(i, &mut rng), SamplerOp::Inserted) {
                    early_inserts += 1;
                }
            }
            // The gate is fully open from the 8th offer on (threshold = 0.125 * 64 = 8) and
            // EARLY + OFFERS < CAPACITY, so every late offer must land in a free slot.
            for i in EARLY..EARLY + OFFERS {
                if matches!(s.accept(i, &mut rng), SamplerOp::Inserted) {
                    late_inserts += 1;
                }
            }
        }
        // Early admit probabilities are 1/8, 2/8, ..., 7/8, which sum to 3.5 per trial.
        let expected = 3.5 * TRIALS as f64;
        let observed = f64::from(early_inserts);
        assert!(
            observed > expected * 0.6 && observed < expected * 1.4,
            "warmup admit count {observed} outside tolerance of expected {expected}"
        );
        assert_eq!(late_inserts as usize, TRIALS * OFFERS);
        let early_rate = observed / (TRIALS * EARLY) as f64;
        let late_rate = f64::from(late_inserts) / (TRIALS * OFFERS) as f64;
        assert!(
            early_rate < late_rate,
            "early admit rate {early_rate} should be below late rate {late_rate}"
        );
    }

    #[test]
    fn warmup_disabled_admits_every_offer_while_below_capacity() {
        let mut s = ReservoirSampler::with_initial_accept_fraction(32, 0.0, 1.0).unwrap();
        let mut rng = fresh_rng(99);
        for i in 0..32 {
            assert_eq!(s.accept(i, &mut rng), SamplerOp::Inserted);
        }
    }

    #[test]
    fn remove_evicts_matching_entry() {
        let mut s = ReservoirSampler::new(4, 0.0).unwrap();
        let mut rng = fresh_rng(7);
        for i in 10..14 {
            s.accept(i, &mut rng);
        }
        assert_eq!(s.len(), 4);
        assert!(s.remove(11));
        assert_eq!(s.len(), 3);
        assert!(!s.contains(11));
        assert!(s.contains(10));
        assert!(s.contains(12));
        assert!(s.contains(13));
    }

    #[test]
    fn remove_missing_is_noop() {
        let mut s = ReservoirSampler::new(4, 0.0).unwrap();
        let mut rng = fresh_rng(7);
        for i in 10..14 {
            s.accept(i, &mut rng);
        }
        assert!(!s.remove(9999));
        assert_eq!(s.len(), 4);
    }

    #[test]
    fn remove_on_empty_sampler_is_noop() {
        let mut s = ReservoirSampler::new(4, 0.0).unwrap();
        assert!(!s.remove(0));
        assert!(s.is_empty());
        assert_eq!(s.entries_seen(), 0);
    }

    #[test]
    fn remove_keeps_entries_seen_stable() {
        let mut s = ReservoirSampler::new(4, 0.1).unwrap();
        let mut rng = fresh_rng(1);
        for i in 10..20 {
            s.accept(i, &mut rng);
        }
        let seen_before = s.entries_seen();
        s.remove(15);
        assert_eq!(s.entries_seen(), seen_before);
    }

    #[test]
    fn remove_preserves_heap_invariant() {
        let mut s = ReservoirSampler::new(64, 0.01).unwrap();
        let mut rng = fresh_rng(2026);
        for i in 0..200 {
            s.accept(i, &mut rng);
        }
        let pivot = s.iter_indices().next().unwrap();
        assert!(s.remove(pivot));
        let op = s.accept(9999, &mut rng);
        assert!(matches!(
            op,
            SamplerOp::Inserted | SamplerOp::Replaced(_) | SamplerOp::Rejected
        ));
        let idxs: Vec<usize> = s.iter_indices().collect();
        let unique: HashSet<usize> = idxs.iter().copied().collect();
        assert_eq!(idxs.len(), unique.len());
    }

    #[test]
    fn remove_keeps_largest_weight_on_top() {
        let mut s = ReservoirSampler::new(16, 0.0).unwrap();
        let mut rng = fresh_rng(31);
        for i in 0..16 {
            s.accept(i, &mut rng);
        }
        let top = s.heap.peek().unwrap().point_idx;
        assert!(s.remove(top));
        assert_eq!(s.len(), 15);
        let max_weight = s
            .heap
            .iter()
            .map(|entry| entry.weight)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_eq!(s.heap.peek().unwrap().weight, max_weight);
    }

    #[test]
    fn weighted_entry_ord_is_total_on_finite_weights() {
        let a = WeightedEntry {
            weight: 1.0,
            point_idx: 0,
        };
        let b = WeightedEntry {
            weight: 2.0,
            point_idx: 1,
        };
        assert!(a < b);
        assert!(b > a);
        assert_eq!(a.cmp(&a), Ordering::Equal);
    }
}