use crate::traits::{MergeError, Sketch};
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Minimal xorshift pseudo-random generator with 64 bits of state.
///
/// Deterministic for a given seed; not suitable for cryptographic use.
#[derive(Clone, Debug)]
struct Xorshift64 {
    /// Current generator state. Kept non-zero, because zero maps to itself under the
    /// shift/xor steps in `next`.
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`, substituting a fixed non-zero constant when `seed` is 0.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0x853c49e6748fea9b,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state with the (13, 7, 17) shift triple and returns the new state.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Returns a value drawn uniformly from `0..bound`.
    ///
    /// Raw outputs below `2^64 mod bound` are discarded so that the accepted range is an exact
    /// multiple of `bound`, which removes modulo bias.
    ///
    /// # Panics
    ///
    /// Panics if `bound` is 0 (remainder by zero).
    fn next_bounded(&mut self, bound: usize) -> usize {
        let limit = bound as u64;
        // `limit.wrapping_neg()` is `2^64 - limit`, so this is `2^64 mod limit`.
        let reject_below = limit.wrapping_neg() % limit;
        loop {
            let candidate = self.next();
            if candidate < reject_below {
                continue;
            }
            break (candidate % limit) as usize;
        }
    }
}
/// Fixed-capacity uniform sample of a stream, maintained with reservoir sampling.
///
/// The first `capacity` items offered are stored as-is. After that, the `n`-th item offered
/// replaces a stored item with probability `capacity / n`, so memory use does not grow with the
/// length of the stream.
#[derive(Clone, Debug)]
pub struct ReservoirSampler<T: Clone + core::fmt::Debug> {
    /// Maximum number of items kept; `with_seed` asserts it is positive.
    capacity: usize,
    /// Items currently kept; never longer than `capacity`.
    reservoir: Vec<T>,
    /// Number of items offered so far.
    count: u64,
    /// Generator that decides which stored item, if any, a new item replaces.
    rng: Xorshift64,
}

impl<T: Clone + core::fmt::Debug> ReservoirSampler<T> {
    /// Creates a sampler holding at most `capacity` items, using a fixed default seed.
    ///
    /// Because the seed is fixed, two samplers built with `new` and fed the same items end up
    /// with the same sample. Use [`ReservoirSampler::with_seed`] to vary it.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0.
    pub fn new(capacity: usize) -> Self {
        Self::with_seed(capacity, 0x12345678)
    }

    /// Creates a sampler holding at most `capacity` items, seeded with `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0.
    pub fn with_seed(capacity: usize, seed: u64) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self {
            capacity,
            reservoir: Vec::with_capacity(capacity),
            count: 0,
            rng: Xorshift64::new(seed),
        }
    }

    /// Offers `item` to the sampler (Algorithm R).
    ///
    /// While the reservoir has room the item is always stored. Afterwards an index `j` is drawn
    /// uniformly from `0..items_seen` and the item overwrites slot `j` only when `j < capacity`.
    pub fn add(&mut self, item: T) {
        self.count += 1;
        if self.reservoir.len() < self.capacity {
            self.reservoir.push(item);
            return;
        }
        // The draw is done on the full 64-bit count. Narrowing the count to `usize` first would
        // wrap on 32-bit targets after 2^32 items, skewing the replacement probability and
        // dividing by zero at exact multiples of 2^32.
        let j = self.draw_below(self.count);
        if j < self.capacity as u64 {
            // `j < capacity`, so it fits in `usize`; the reservoir is full, so it is in bounds.
            self.reservoir[j as usize] = item;
        }
    }

    /// Returns the current sample.
    pub fn sample(&self) -> &[T] {
        &self.reservoir
    }

    /// Returns the current sample for in-place modification of the stored items.
    pub fn sample_mut(&mut self) -> &mut [T] {
        &mut self.reservoir
    }

    /// Consumes the sampler and returns the sampled items.
    pub fn into_sample(self) -> Vec<T> {
        self.reservoir
    }

    /// Maximum number of items the sampler keeps.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of items currently in the sample.
    pub fn len(&self) -> usize {
        self.reservoir.len()
    }

    /// Returns `true` when the sample holds no items.
    pub fn is_empty(&self) -> bool {
        self.reservoir.is_empty()
    }

    /// Returns `true` once the sample holds `capacity` items.
    pub fn is_full(&self) -> bool {
        self.reservoir.len() >= self.capacity
    }

    /// Total number of items offered so far, whether or not they were kept.
    pub fn items_seen(&self) -> u64 {
        self.count
    }

    /// Ratio `capacity / items_seen`, capped at 1.0; returns 0.0 before any item has been seen.
    pub fn sampling_probability(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            ((self.capacity as f64) / (self.count as f64)).min(1.0)
        }
    }

    /// Draws a value uniformly from `0..bound` using the sampler's generator.
    ///
    /// Raw outputs below `2^64 mod bound` are rejected to avoid modulo bias. `bound` must be
    /// non-zero; `add` only calls this after incrementing the count.
    fn draw_below(&mut self, bound: u64) -> u64 {
        // `bound.wrapping_neg()` is `2^64 - bound`, so this is `2^64 mod bound`.
        let threshold = bound.wrapping_neg() % bound;
        loop {
            let r = self.rng.next();
            if r >= threshold {
                return r % bound;
            }
        }
    }
}
impl<T: Clone + core::fmt::Debug> Sketch for ReservoirSampler<T> {
type Item = T;
fn update(&mut self, item: &Self::Item) {
self.add(item.clone());
}
fn merge(&mut self, other: &Self) -> Result<(), MergeError> {
if self.capacity != other.capacity {
return Err(MergeError::IncompatibleConfig {
expected: format!("capacity={}", self.capacity),
found: format!("capacity={}", other.capacity),
});
}
if other.count == 0 {
return Ok(());
}
if self.count == 0 {
self.reservoir = other.reservoir.clone();
self.count = other.count;
return Ok(());
}
let total_count = self.count + other.count;
let self_len = self.reservoir.len();
let other_len = other.reservoir.len();
if self_len + other_len <= self.capacity {
self.reservoir.extend(other.reservoir.iter().cloned());
} else {
let output_len = self.capacity.min(self_len + other_len);
let mut new_reservoir = Vec::with_capacity(output_len);
for _ in 0..output_len {
let r = self.rng.next_bounded(total_count as usize) as u64;
if r < self.count {
let idx = self.rng.next_bounded(self_len);
new_reservoir.push(self.reservoir[idx].clone());
} else {
let idx = self.rng.next_bounded(other_len);
new_reservoir.push(other.reservoir[idx].clone());
}
}
self.reservoir = new_reservoir;
}
self.count = total_count;
Ok(())
}
fn clear(&mut self) {
self.reservoir.clear();
self.count = 0;
}
fn size_bytes(&self) -> usize {
self.reservoir.capacity() * core::mem::size_of::<T>() + 32
}
fn count(&self) -> u64 {
self.count
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default-seeded sampler of the given capacity and feeds it the values `0..n`.
    fn filled(capacity: usize, n: i32) -> ReservoirSampler<i32> {
        let mut sampler = ReservoirSampler::<i32>::new(capacity);
        (0..n).for_each(|value| sampler.add(value));
        sampler
    }

    /// Offering more items than the capacity keeps the sample at capacity and counts every item.
    #[test]
    fn test_basic() {
        let sampler = filled(5, 10);
        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.items_seen(), 10);
    }

    /// With fewer items than the capacity, every item offered is retained.
    #[test]
    fn test_underfilled() {
        let sampler = filled(10, 5);
        assert_eq!(sampler.len(), 5);
        assert!(!sampler.is_full());
        let sample = sampler.sample();
        for expected in 0..5 {
            assert!(sample.contains(&expected));
        }
    }

    /// Two samplers with the same seed and the same input produce the same sample.
    #[test]
    fn test_reproducibility() {
        let run = |seed| {
            let mut sampler = ReservoirSampler::<i32>::with_seed(5, seed);
            for value in 0..100 {
                sampler.add(value);
            }
            sampler
        };
        let first = run(42);
        let second = run(42);
        assert_eq!(first.sample(), second.sample());
    }

    /// Over many differently seeded runs, each of 10 items is the single survivor about equally
    /// often (within 10% of the expected frequency).
    #[test]
    fn test_uniformity() {
        let mut counts = [0usize; 10];
        let iterations = 10000;
        for round in 0..iterations {
            // Spread consecutive round numbers over the 64-bit seed space.
            let seed = (round as u64)
                .wrapping_mul(0x9e3779b97f4a7c15)
                .wrapping_add(0x853c49e6748fea9b);
            let mut sampler = ReservoirSampler::<usize>::with_seed(1, seed);
            for value in 0..10 {
                sampler.add(value);
            }
            let survivor = sampler.sample()[0];
            counts[survivor] += 1;
        }
        let expected = iterations / 10;
        for (i, &count) in counts.iter().enumerate() {
            let deviation = (count as i64 - expected as i64).abs() as f64 / expected as f64;
            assert!(
                deviation < 0.1,
                "Item {} appeared {} times (expected ~{})",
                i,
                count,
                expected
            );
        }
    }

    /// `clear` empties the sample and resets the item count.
    #[test]
    fn test_clear() {
        let mut sampler = filled(5, 10);
        sampler.clear();
        assert!(sampler.is_empty());
        assert_eq!(sampler.count(), 0);
    }

    /// `into_sample` hands back the stored items, here exactly the three that were added.
    #[test]
    fn test_into_sample() {
        let mut sampler = ReservoirSampler::<String>::new(3);
        for label in ["a", "b", "c"].iter() {
            sampler.add(label.to_string());
        }
        let sample = sampler.into_sample();
        assert_eq!(sample.len(), 3);
    }
}