#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math;
/// Fixed-capacity store of sampled point indices together with their weights.
///
/// `weights` and `point_indices` are parallel arrays: slot `i` of each describes the
/// same entry, and only the first `size` slots are live. The live prefix is maintained
/// as an implicit binary max-heap ordered by weight (children of slot `i` live at
/// `2 * i + 1` and `2 * i + 2`), so slot 0 holds the entry with the largest weight.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(super) struct Sampler {
/// Maximum number of entries the sampler may hold; also the length of both vectors.
capacity: usize,
/// Heap-ordered weights; slots at or beyond `size` hold the `0.0` placeholder.
weights: Vec<f64>,
/// Point index stored alongside each weight; unused slots hold `usize::MAX`.
point_indices: Vec<usize>,
/// Number of live entries currently stored (`size <= capacity`).
size: usize,
}
/// Outcome of asking the sampler whether a candidate should be admitted.
///
/// Produced by `Sampler::accept`, which only inspects the sampler; the caller is
/// expected to follow up with `Sampler::add_point` when `accepted` is `true`.
#[derive(Debug)]
pub(super) struct AcceptResult {
/// `true` when the candidate should be added to the sampler.
pub(super) accepted: bool,
/// Point index stored at the heap root that the candidate would replace, or `None`
/// when admitting the candidate displaces nothing (or the candidate is rejected).
pub(super) evicted: Option<usize>,
}
impl Sampler {
    /// Creates an empty sampler able to hold up to `capacity` points.
    ///
    /// Both backing vectors are allocated up front with `capacity` slots. Unused slots
    /// carry placeholder values (`0.0` and `usize::MAX`) until they are filled.
    ///
    /// A `capacity` of zero is permitted: such a sampler is always full, rejects every
    /// candidate and ignores `add_point`.
    pub(super) fn new(capacity: usize) -> Self {
        Sampler {
            capacity,
            weights: vec![0.0f64; capacity],
            point_indices: vec![usize::MAX; capacity],
            size: 0,
        }
    }

    /// Swaps the entries in slots `a` and `b`, keeping the parallel arrays in sync.
    fn swap(&mut self, a: usize, b: usize) {
        self.weights.swap(a, b);
        self.point_indices.swap(a, b);
    }

    /// Moves the entry at slot `i` towards the root while it outweighs its parent.
    fn sift_up(&mut self, mut i: usize) {
        while i > 0 {
            let parent = (i - 1) / 2;
            // Strict comparison: equal weights stay where they are.
            if self.weights[i] > self.weights[parent] {
                self.swap(i, parent);
                i = parent;
            } else {
                break;
            }
        }
    }

    /// Moves the entry at slot `i` towards the leaves while a child outweighs it.
    ///
    /// Only slots below `size` are considered, so placeholder slots never take part.
    fn sift_down(&mut self, mut i: usize) {
        loop {
            let left = 2 * i + 1;
            let right = 2 * i + 2;
            let mut largest = i;
            if left < self.size && self.weights[left] > self.weights[largest] {
                largest = left;
            }
            if right < self.size && self.weights[right] > self.weights[largest] {
                largest = right;
            }
            if largest == i {
                break;
            }
            self.swap(i, largest);
            i = largest;
        }
    }

    /// Returns the largest stored weight, or `f64::MAX` when the sampler is empty.
    fn max_weight(&self) -> f64 {
        if self.size == 0 {
            f64::MAX
        } else {
            self.weights[0]
        }
    }

    /// Decides whether a candidate with the given `weight` should be admitted.
    ///
    /// This method never mutates the sampler; call [`Sampler::add_point`] afterwards
    /// to actually store an accepted candidate.
    ///
    /// * While the sampler still has free slots, the decision is delegated entirely to
    ///   the caller through `is_initial`, and nothing is evicted.
    /// * Once full, the candidate is accepted only if `weight` is strictly smaller than
    ///   the largest stored weight; `evicted` then names the point at the heap root.
    /// * A zero-capacity sampler rejects every candidate.
    pub(super) fn accept(&self, is_initial: bool, weight: f64) -> AcceptResult {
        if self.size < self.capacity {
            return AcceptResult {
                accepted: is_initial,
                evicted: None,
            };
        }
        // `size == 0` here means `capacity == 0`: there is no room and no root to
        // evict, so reject instead of indexing into the empty backing vectors.
        if self.size > 0 && weight < self.max_weight() {
            AcceptResult {
                accepted: true,
                evicted: Some(self.point_indices[0]),
            }
        } else {
            AcceptResult {
                accepted: false,
                evicted: None,
            }
        }
    }

    /// Stores `tree_point_idx` with the given `weight`.
    ///
    /// With free slots available the entry is appended and sifted up. When the sampler
    /// is full the root (largest weight) is overwritten unconditionally and the new
    /// entry is sifted down, so callers should consult [`Sampler::accept`] first.
    /// On a zero-capacity sampler this is a no-op.
    pub(super) fn add_point(&mut self, tree_point_idx: usize, weight: f64) {
        if self.size < self.capacity {
            let i = self.size;
            self.weights[i] = weight;
            self.point_indices[i] = tree_point_idx;
            self.size += 1;
            self.sift_up(i);
        } else if self.size > 0 {
            self.weights[0] = weight;
            self.point_indices[0] = tree_point_idx;
            self.sift_down(0);
        }
        // Otherwise the sampler has zero capacity and there is nowhere to store the point.
    }

    /// Returns `true` when every slot is occupied (always true for zero capacity).
    pub(super) fn is_full(&self) -> bool {
        self.size == self.capacity
    }

    /// Returns the occupied share of the sampler as a value in `[0.0, 1.0]`.
    ///
    /// A zero-capacity sampler reports `1.0`, matching [`Sampler::is_full`], rather
    /// than the NaN that `0.0 / 0.0` would produce.
    pub(super) fn fill_fraction(&self) -> f64 {
        if self.capacity == 0 {
            return 1.0;
        }
        self.size as f64 / self.capacity as f64
    }

    /// Returns the stored point indices in heap order (test-only helper).
    #[cfg(test)]
    pub(super) fn points(&self) -> &[usize] {
        &self.point_indices[..self.size]
    }
}
/// Computes the sampling weight for an entry from a uniform draw `u`.
///
/// `u` is first clamped into `[f64::EPSILON, 1.0 - f64::EPSILON]` so that, assuming
/// `math::ln` is the natural logarithm, both logarithms receive a strictly positive
/// argument. The result is `ln(-ln(u)) - time_decay * entries_seen`; with a positive
/// `time_decay`, entries seen later therefore receive smaller weights.
pub(super) fn reservoir_weight(u: f64, time_decay: f64, entries_seen: u64) -> f64 {
    let lower_bound = f64::EPSILON;
    let upper_bound = 1.0 - f64::EPSILON;
    let bounded = u.clamp(lower_bound, upper_bound);

    let log_log_term = math::ln(-math::ln(bounded));
    let decay_offset = time_decay * entries_seen as f64;

    log_log_term - decay_offset
}
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    /// Approving every warm-up candidate fills the sampler exactly to capacity.
    #[rstest]
    #[case::cap_1(1)]
    #[case::cap_4(4)]
    #[case::cap_8(8)]
    #[case::cap_16(16)]
    #[case::cap_32(32)]
    fn sampler_fills_to_capacity(#[case] capacity: usize) {
        let mut s = Sampler::new(capacity);
        for i in 0..capacity as u64 {
            let w = reservoir_weight(0.5, 0.0, i);
            let result = s.accept(true, w);
            assert!(result.accepted);
            s.add_point(i as usize, w);
        }
        assert!(s.is_full());
        assert_eq!(s.points().len(), capacity);
    }

    /// A full sampler admits a lighter candidate and evicts the heaviest stored point.
    #[rstest]
    #[case::cap_2(2)]
    #[case::cap_4(4)]
    #[case::cap_8(8)]
    fn sampler_evicts_max_weight(#[case] capacity: usize) {
        let mut s = Sampler::new(capacity);
        for i in 0..capacity {
            let weight = 100.0 + i as f64;
            s.accept(true, weight);
            s.add_point(i, weight);
        }
        assert!(s.is_full());
        let result = s.accept(false, -100.0f64);
        assert!(result.accepted);
        // Weights increase with the index, so the last inserted point is the heaviest.
        assert_eq!(result.evicted, Some(capacity - 1));
        s.add_point(capacity, -100.0);
        // The replacement must keep the sampler full, drop the evicted point and
        // retain the newcomer.
        assert_eq!(s.points().len(), capacity);
        assert!(!s.points().contains(&(capacity - 1)));
        assert!(s.points().contains(&capacity));
    }

    /// A full sampler rejects a candidate heavier than everything it stores.
    #[test]
    fn sampler_rejects_higher_weight_when_full() {
        let mut s = Sampler::new(2);
        let w = -10.0f64;
        s.accept(true, w);
        s.add_point(0, w);
        s.accept(true, w - 1.0);
        s.add_point(1, w - 1.0);
        let result = s.accept(false, 999.0f64);
        assert!(!result.accepted);
        assert!(result.evicted.is_none());
    }

    /// During warm-up the caller's `is_initial` flag alone decides admission.
    #[test]
    fn sampler_rejects_unapproved_warmup_candidate() {
        let s = Sampler::new(2);
        let result = s.accept(false, -100.0);
        assert!(!result.accepted);
        assert!(result.evicted.is_none());
        assert!(s.points().is_empty());
    }

    /// `accept` is read-only: a rejection leaves the stored points untouched.
    #[test]
    fn rejected_accept_does_not_change_sampler_state() {
        let mut s = Sampler::new(1);
        let accepted_weight = -10.0;
        s.accept(true, accepted_weight);
        s.add_point(0, accepted_weight);
        let result = s.accept(false, 999.0);
        assert!(!result.accepted);
        assert_eq!(s.points(), &[0]);
    }

    #[cfg(feature = "std")]
    mod proptest_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// The fill fraction stays within `[0, 1]` for any partial fill.
            #[test]
            fn fill_fraction_in_unit_interval(capacity in 1usize..=64, n_adds in 0usize..=64) {
                let n = n_adds.min(capacity);
                let mut s = Sampler::new(capacity);
                for i in 0..n as u64 {
                    let w = reservoir_weight(0.5, 0.0, i);
                    s.accept(true, w);
                    s.add_point(i as usize, w);
                }
                let frac = s.fill_fraction();
                prop_assert!((0.0..=1.0).contains(&frac), "fill_fraction={frac}");
            }

            /// Following the accept/add protocol never stores more than `capacity` points.
            #[test]
            fn sampler_never_exceeds_capacity(capacity in 1usize..=32, n_adds in 1usize..=128) {
                let mut s = Sampler::new(capacity);
                for i in 0..n_adds as u64 {
                    let w = reservoir_weight(0.5, 0.0, i);
                    let result = s.accept(!s.is_full(), w);
                    if result.accepted {
                        s.add_point(i as usize, w);
                    }
                }
                prop_assert!(
                    s.points().len() <= capacity,
                    "len={} > capacity={capacity}",
                    s.points().len()
                );
            }

            /// Weights are finite for draws and decay parameters in the sampled ranges.
            #[test]
            fn reservoir_weight_finite(
                u in 0.001f64..0.999,
                time_decay in 0.0f64..1.0,
                entries_seen in 0u64..10_000,
            ) {
                let w = reservoir_weight(u, time_decay, entries_seen);
                prop_assert!(w.is_finite(), "reservoir_weight={w} is not finite");
            }
        }
    }
}