use rand::{RngExt, SeedableRng, rngs::StdRng};
/// Rows sampled per centroid before clamping.
const KMEANS_SAMPLE_NCENT_MULT: usize = 64;
/// Smallest sample size ever returned by [`default_kmeans_sample_size`].
const KMEANS_SAMPLE_SIZE_FLOOR: usize = 100_000;
/// Largest sample size ever returned by [`default_kmeans_sample_size`].
const KMEANS_SAMPLE_SIZE_CAP: usize = 500_000;

/// Default number of rows to sample for k-means training with `n_cent` centroids.
///
/// Returns `64 * n_cent`, bounded below by 100 000 and above by 500 000. The
/// multiplication saturates, so huge `n_cent` values simply hit the cap.
pub fn default_kmeans_sample_size(n_cent: usize) -> usize {
    let proportional = n_cent.saturating_mul(KMEANS_SAMPLE_NCENT_MULT);
    proportional
        .max(KMEANS_SAMPLE_SIZE_FLOOR)
        .min(KMEANS_SAMPLE_SIZE_CAP)
}
/// Fixed-capacity uniform random sample of `dim`-wide `f32` rows drawn from a
/// stream of unknown length, maintained with Li's "Algorithm L".
///
/// Rows are stored contiguously in one `Vec<f32>`: row `r` occupies
/// `buf[r * dim..(r + 1) * dim]`. The first `sample_size` rows are copied in
/// verbatim. After that, the reservoir precomputes the stream index of the next
/// row to keep, so rows that are skipped consume no randomness.
pub(crate) struct Reservoir {
    /// Maximum number of rows kept (`k` in Algorithm L).
    sample_size: usize,
    /// Number of `f32` values per row.
    dim: usize,
    /// Seeded RNG; the same seed and input stream reproduce the same sample.
    rng: StdRng,
    /// Sampled rows, concatenated; length is always `n_rows() * dim`.
    buf: Vec<f32>,
    /// Total number of rows passed to `update`, whether kept or not.
    n_seen: u64,
    /// Algorithm L's running weight `W`. It starts at 1.0, so the first
    /// multiplication by `u^(1/k)` yields the algorithm's initial `W`.
    w: f64,
    /// 0-based stream index of the next row that replaces a slot
    /// (`u64::MAX` until the reservoir has filled).
    next_replace_at: u64,
}

impl Reservoir {
    /// Creates an empty reservoir that keeps at most `sample_size` rows of
    /// `dim` floats, seeding its RNG with `seed`.
    ///
    /// Capacity for the full `sample_size * dim` buffer is reserved up front.
    ///
    /// # Panics
    ///
    /// Panics if `sample_size` or `dim` is zero, or if `sample_size * dim`
    /// overflows `usize`.
    pub fn new(sample_size: usize, dim: usize, seed: u64) -> Self {
        assert!(sample_size > 0, "Reservoir: sample_size must be > 0");
        assert!(dim > 0, "Reservoir: dim must be > 0");
        // An unchecked product would wrap silently in release builds.
        let capacity = sample_size
            .checked_mul(dim)
            .expect("Reservoir: sample_size * dim overflows usize");
        Self {
            sample_size,
            dim,
            rng: StdRng::seed_from_u64(seed),
            buf: Vec::with_capacity(capacity),
            n_seen: 0,
            w: 1.0,
            next_replace_at: u64::MAX,
        }
    }

    /// Offers one row to the reservoir.
    ///
    /// While fewer than `sample_size` rows have been seen, the row is appended.
    /// After that, it overwrites a uniformly chosen slot only if its stream
    /// index equals the precomputed `next_replace_at`. Otherwise it is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `vec.len() != dim`.
    pub fn update(&mut self, vec: &[f32]) {
        assert_eq!(
            vec.len(),
            self.dim,
            "Reservoir::update: vec.len() {} != dim {}",
            vec.len(),
            self.dim
        );
        let k = self.sample_size as u64;
        let i = self.n_seen;
        self.n_seen += 1;
        if i < k {
            self.buf.extend_from_slice(vec);
            if self.n_seen == k {
                // The reservoir just became full: draw the initial W and the first skip.
                self.schedule_next_replacement(i);
            }
            return;
        }
        if i == self.next_replace_at {
            let slot = self.rng.random_range(0..self.sample_size);
            self.buf[slot * self.dim..(slot + 1) * self.dim].copy_from_slice(vec);
            self.schedule_next_replacement(i);
        }
    }

    /// Multiplies `W` by `u^(1/k)` and schedules the next replacement.
    ///
    /// The next replacement happens at index `i + 1 + skip`, where `skip` is
    /// drawn by [`Self::skip`] with the updated `W`.
    fn schedule_next_replacement(&mut self, i: u64) {
        let k = self.sample_size as f64;
        self.w *= (Self::nonzero_uniform(&mut self.rng).ln() / k).exp();
        let skip = Self::skip(&mut self.rng, self.w);
        // `skip` may be u64::MAX; saturate rather than overflow, which would
        // panic in debug builds and wrap in release.
        self.next_replace_at = i.saturating_add(1).saturating_add(skip);
    }

    /// Total number of rows passed to [`Self::update`].
    #[cfg(test)]
    pub(crate) fn n_seen(&self) -> u64 {
        self.n_seen
    }

    /// The sampled rows, concatenated; length is `n_rows() * dim`.
    ///
    /// Rows are in insertion order only until the reservoir first fills.
    pub fn sample(&self) -> &[f32] {
        &self.buf
    }

    /// Consumes the reservoir and returns the concatenated sampled rows.
    #[cfg(test)]
    pub(crate) fn into_sample(self) -> Vec<f32> {
        self.buf
    }

    /// Number of rows currently held: `min(rows seen, sample_size)`.
    pub fn n_rows(&self) -> usize {
        // Derived from the buffer rather than `n_seen as usize`, which would
        // truncate on 32-bit targets once more than u32::MAX rows are seen.
        self.buf.len() / self.dim
    }

    /// Draws how many rows to pass over before the next replacement.
    ///
    /// The result is `floor(ln(u) / ln(1 - w))`: a geometric count of failures
    /// before the first success, where each row succeeds with probability `w`.
    fn skip(rng: &mut StdRng, w: f64) -> u64 {
        let u = Self::nonzero_uniform(rng);
        // `ln_1p(-w)` equals ln(1 - w) but keeps full precision for tiny `w`,
        // where `1.0 - w` rounds towards (or exactly to) 1.0.
        let denom = (-w).ln_1p();
        if denom == 0.0 {
            // w == 0: acceptance is impossible, so never replace again.
            return u64::MAX;
        }
        if !denom.is_finite() {
            // w >= 1 gives ln(0) = -inf (w > 1 or NaN gives NaN): replace the next row.
            return 0;
        }
        // Float-to-int `as` saturates, so an astronomically long skip clamps to u64::MAX.
        (u.ln() / denom).floor().max(0.0) as u64
    }

    /// Uniform draw in (0, 1]. `random::<f64>()` lies in [0, 1), so
    /// `1 - x` excludes 0 and keeps `ln` finite.
    fn nonzero_uniform(rng: &mut StdRng) -> f64 {
        1.0 - rng.random::<f64>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `n` rows of width `dim` into a fresh reservoir of
    /// `reservoir_size` rows. Row `i` is `[i, 0, 0, ...]`, so each row's
    /// stream index can be read back from its first column.
    fn run(reservoir_size: usize, dim: usize, n: u64, seed: u64) -> Reservoir {
        let mut r = Reservoir::new(reservoir_size, dim, seed);
        for i in 0..n {
            let mut row = vec![0.0f32; dim];
            row[0] = i as f32;
            r.update(&row);
        }
        r
    }

    #[test]
    fn fill_phase_appends_exactly_n_seen_rows() {
        let dim = 4;
        let mut r = Reservoir::new(10, dim, 1);
        for i in 0..5 {
            let mut row = vec![0.0f32; dim];
            row[0] = i as f32;
            r.update(&row);
            assert_eq!(r.n_rows(), (i + 1) as usize);
            assert_eq!(r.sample().len(), (i + 1) as usize * dim);
        }
    }

    #[test]
    fn at_fill_boundary_buffer_holds_first_k_rows_in_order() {
        let dim = 3;
        let r = run(5, dim, 5, 7);
        let s = r.sample();
        assert_eq!(s.len(), 5 * dim);
        for i in 0..5 {
            assert_eq!(
                s[i * dim],
                i as f32,
                "fill phase didn't preserve insertion order"
            );
        }
    }

    #[test]
    fn n_seen_counts_every_update_regardless_of_acceptance() {
        let r = run(10, 2, 1000, 42);
        assert_eq!(r.n_seen(), 1000);
        assert_eq!(r.n_rows(), 10);
    }

    #[test]
    fn n_rows_matches_sample_len_in_every_phase() {
        let dim = 3;
        let mut r = Reservoir::new(10, dim, 5);
        for i in 0..100u32 {
            let mut row = vec![0.0f32; dim];
            row[0] = i as f32;
            r.update(&row);
            assert_eq!(r.n_rows() * dim, r.sample().len());
            assert_eq!(r.n_rows(), (i as usize + 1).min(10));
        }
    }

    #[test]
    fn skip_is_zero_when_w_is_one() {
        // With acceptance probability 1 the very next row must be replaced.
        let mut rng = StdRng::seed_from_u64(3);
        for _ in 0..100 {
            assert_eq!(Reservoir::skip(&mut rng, 1.0), 0);
        }
    }

    #[test]
    fn determinism_same_seed_same_reservoir() {
        let a = run(50, 4, 10_000, 12345);
        let b = run(50, 4, 10_000, 12345);
        assert_eq!(a.sample(), b.sample());
    }

    #[test]
    fn different_seeds_yield_different_reservoirs() {
        let a = run(50, 4, 10_000, 1);
        let b = run(50, 4, 10_000, 2);
        assert_ne!(
            a.sample(),
            b.sample(),
            "two seeds yielded identical reservoirs"
        );
    }

    #[test]
    fn distribution_is_approximately_uniform_across_seeds() {
        let n = 1000usize;
        let sample_size = 100usize;
        let trials = 200usize;
        let dim = 1;
        let mut counts = vec![0u64; n];
        for trial in 0..trials {
            let r = run(sample_size, dim, n as u64, trial as u64 + 1);
            let s = r.sample();
            assert_eq!(s.len(), sample_size * dim);
            for row in 0..sample_size {
                let idx = s[row * dim] as usize;
                assert!(idx < n, "reservoir held out-of-range item {idx}");
                counts[idx] += 1;
            }
        }
        let total: u64 = counts.iter().sum();
        let expected_total = (trials * sample_size) as u64;
        assert_eq!(total, expected_total, "expected exact total");
        // `total` equals `expected_total`, so `mean` is the true average and
        // max >= mean >= min; no absolute value is needed.
        let mean = expected_total as f64 / n as f64;
        let max = *counts.iter().max().expect("counts non-empty") as f64;
        let min = *counts.iter().min().expect("counts non-empty") as f64;
        assert!(
            max - mean < 20.0 && mean - min < 20.0,
            "non-uniform sampling: mean={mean:.2}, min={min}, max={max} \
             (trials={trials}, n={n}, sample_size={sample_size})"
        );
    }

    #[test]
    fn handles_n_smaller_than_sample_size() {
        let dim = 2;
        let mut r = Reservoir::new(100, dim, 999);
        for i in 0..5u32 {
            let mut row = vec![0.0f32; dim];
            row[0] = i as f32;
            r.update(&row);
        }
        assert_eq!(r.n_seen(), 5);
        assert_eq!(r.n_rows(), 5);
        let s = r.sample();
        assert_eq!(s.len(), 5 * dim);
        for i in 0..5 {
            assert_eq!(s[i * dim], i as f32);
        }
    }

    #[test]
    fn handles_n_equal_to_sample_size() {
        let r = run(50, 3, 50, 7);
        assert_eq!(r.n_seen(), 50);
        assert_eq!(r.n_rows(), 50);
        let s = r.sample();
        for i in 0..50 {
            assert_eq!(s[i * 3], i as f32, "expected pure fill phase");
        }
    }

    #[test]
    fn into_sample_consumes_reservoir() {
        let r = run(10, 4, 10_000, 1);
        let owned = r.into_sample();
        assert_eq!(owned.len(), 10 * 4);
    }

    #[test]
    fn default_sample_size_clamps() {
        assert_eq!(default_kmeans_sample_size(0), 100_000);
        assert_eq!(default_kmeans_sample_size(64), 100_000);
        assert_eq!(default_kmeans_sample_size(1_000), 100_000);
        assert_eq!(default_kmeans_sample_size(2_000), 128_000);
        assert_eq!(default_kmeans_sample_size(4_096), 4_096 * 64);
        assert_eq!(default_kmeans_sample_size(8_192), 500_000);
        assert_eq!(default_kmeans_sample_size(16_384), 500_000);
        assert_eq!(default_kmeans_sample_size(usize::MAX), 500_000);
    }

    #[test]
    fn every_item_present_when_stream_shorter_than_sample() {
        let dim = 1;
        let r = run(100, dim, 30, 1);
        let s = r.sample();
        let mut indices: Vec<u32> = s.chunks(dim).map(|row| row[0] as u32).collect();
        indices.sort();
        assert_eq!(indices, (0..30).collect::<Vec<_>>());
    }
}