#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Fixed-capacity weighted reservoir stored as an array-backed binary max-heap.
///
/// Every retained entry pairs a weight with a caller-supplied point index. The
/// heap keeps the entry with the largest weight at slot 0, which is the entry
/// that gets replaced when a new candidate is admitted into a full reservoir.
///
/// Insertion is a two-step protocol: [`Sampler::accept`] decides whether a
/// candidate is admitted and stages its weight, then [`Sampler::add_point`]
/// commits the staged candidate under the index the caller wants stored.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Sampler {
    /// Maximum number of entries the reservoir can hold.
    capacity: usize,
    /// Heap-ordered weights; only the first `size` slots are live.
    weights: Vec<f64>,
    /// Point index stored alongside each weight (same slot layout as `weights`).
    point_indices: Vec<usize>,
    /// Number of live entries, always `<= capacity`.
    size: usize,
    /// Weight staged by the last admitting `accept` call.
    pending_weight: f64,
    /// Point index staged by the last admitting `accept` call;
    /// `usize::MAX` means "nothing staged".
    pending_point: usize,
}

/// Outcome of [`Sampler::accept`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AcceptResult {
    /// `true` when the candidate was admitted and must be committed with
    /// [`Sampler::add_point`].
    pub accepted: bool,
    /// Point index of the entry that the following `add_point` call will
    /// overwrite, or `None` when the candidate fills a free slot (or was
    /// rejected).
    pub evicted: Option<usize>,
}

impl Sampler {
    /// Creates an empty sampler able to hold `capacity` entries.
    ///
    /// Storage is allocated up front. A capacity of zero is allowed; such a
    /// sampler rejects every candidate.
    pub fn new(capacity: usize) -> Self {
        Sampler {
            capacity,
            weights: vec![0.0f64; capacity],
            point_indices: vec![usize::MAX; capacity],
            size: 0,
            pending_weight: 0.0,
            pending_point: usize::MAX,
        }
    }

    /// Swaps two heap slots, keeping weights and point indices paired.
    fn swap(&mut self, a: usize, b: usize) {
        self.weights.swap(a, b);
        self.point_indices.swap(a, b);
    }

    /// Moves the entry at slot `i` towards the root while it outweighs its parent.
    fn sift_up(&mut self, mut i: usize) {
        while i > 0 {
            let parent = (i - 1) / 2;
            if self.weights[i] > self.weights[parent] {
                self.swap(i, parent);
                i = parent;
            } else {
                break;
            }
        }
    }

    /// Moves the entry at slot `i` towards the leaves while a child outweighs it.
    fn sift_down(&mut self, mut i: usize) {
        loop {
            let left = 2 * i + 1;
            let right = 2 * i + 2;
            let mut largest = i;
            if left < self.size && self.weights[left] > self.weights[largest] {
                largest = left;
            }
            if right < self.size && self.weights[right] > self.weights[largest] {
                largest = right;
            }
            if largest == i {
                break;
            }
            self.swap(i, largest);
            i = largest;
        }
    }

    /// Largest weight currently held, or `f64::MAX` when the reservoir is empty.
    fn max_weight(&self) -> f64 {
        if self.size == 0 {
            f64::MAX
        } else {
            self.weights[0]
        }
    }

    /// Decides whether a candidate is admitted and stages it for `add_point`.
    ///
    /// A candidate is admitted when `is_initial` is set, when a free slot
    /// remains, or when `weight` is strictly smaller than the largest weight
    /// currently held. A sampler with zero capacity rejects everything.
    ///
    /// `point_idx` only marks that a candidate is staged (so `usize::MAX` is
    /// reserved); the index actually stored is the one later passed to
    /// [`Sampler::add_point`].
    ///
    /// # Returns
    ///
    /// An [`AcceptResult`] whose `evicted` field names the max-weight entry
    /// whenever the reservoir is already full, because that is the slot the
    /// following `add_point` overwrites. The entry is still present until
    /// `add_point` runs.
    pub fn accept(&mut self, is_initial: bool, weight: f64, point_idx: usize) -> AcceptResult {
        let rejected = AcceptResult {
            accepted: false,
            evicted: None,
        };

        // No slot exists to fill or replace, so nothing can ever be admitted.
        if self.capacity == 0 {
            return rejected;
        }

        let full = self.size >= self.capacity;
        let admitted = is_initial || !full || weight < self.max_weight();
        if !admitted {
            return rejected;
        }

        self.pending_weight = weight;
        self.pending_point = point_idx;

        // When full, `add_point` replaces the root regardless of why the
        // candidate was admitted, so the displaced entry is always reported.
        let evicted = if full {
            Some(self.point_indices[0])
        } else {
            None
        };
        AcceptResult {
            accepted: true,
            evicted,
        }
    }

    /// Commits the candidate staged by the last admitting [`Sampler::accept`]
    /// call, storing it under `tree_point_idx`.
    ///
    /// While a free slot remains the entry is appended and sifted up;
    /// otherwise it overwrites the root (the max-weight entry) and is sifted
    /// down. Calling this without a staged candidate is a caller error and is
    /// caught by a debug assertion.
    pub fn add_point(&mut self, tree_point_idx: usize) {
        debug_assert_ne!(
            self.pending_point,
            usize::MAX,
            "add_point called without a preceding admitting accept"
        );

        // `accept` never admits into a zero-capacity sampler; avoid indexing
        // the empty storage if a caller gets here anyway in a release build.
        if self.capacity == 0 {
            return;
        }

        if self.size < self.capacity {
            let i = self.size;
            self.weights[i] = self.pending_weight;
            self.point_indices[i] = tree_point_idx;
            self.size += 1;
            self.sift_up(i);
        } else {
            self.weights[0] = self.pending_weight;
            self.point_indices[0] = tree_point_idx;
            self.sift_down(0);
        }
        self.pending_point = usize::MAX;
    }

    /// Returns `true` when no free slot remains.
    pub fn is_full(&self) -> bool {
        self.size == self.capacity
    }

    /// Fraction of the capacity currently in use, in `[0.0, 1.0]`.
    ///
    /// A zero-capacity sampler reports `1.0`, matching [`Sampler::is_full`].
    pub fn fill_fraction(&self) -> f64 {
        if self.capacity == 0 {
            return 1.0;
        }
        self.size as f64 / self.capacity as f64
    }

    /// Point indices of the live entries, in heap order (not sorted).
    pub fn points(&self) -> &[usize] {
        &self.point_indices[..self.size]
    }

    /// Maximum number of entries the sampler can hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
/// Natural logarithm of `value`, via the standard library's `f64::ln`.
///
/// Compiled only when the `std` feature is enabled.
#[cfg(feature = "std")]
fn ln_f64(value: f64) -> f64 {
    f64::ln(value)
}

/// Natural logarithm of `value`, via `libm::log`.
///
/// Compiled only when the `std` feature is disabled.
#[cfg(not(feature = "std"))]
fn ln_f64(value: f64) -> f64 {
    libm::log(value)
}
/// Computes `ln(-ln(u)) - time_decay * entries_seen`.
///
/// `u` is first clamped to `[f64::EPSILON, 1.0 - f64::EPSILON]` so that both
/// logarithms receive a strictly positive argument. With a positive
/// `time_decay`, a larger `entries_seen` yields a smaller result.
///
/// `entries_seen` is converted to `f64` before the multiplication.
pub fn reservoir_weight(u: f64, time_decay: f64, entries_seen: u64) -> f64 {
    let lower = f64::EPSILON;
    let upper = 1.0 - f64::EPSILON;
    let bounded = u.clamp(lower, upper);

    let log_log_term = ln_f64(-ln_f64(bounded));
    let decay_term = time_decay * entries_seen as f64;
    log_log_term - decay_term
}
#[cfg(test)]
mod tests {
    use rstest::*;

    use super::*;

    /// Admitting exactly `capacity` candidates fills the sampler without
    /// displacing anything along the way.
    #[rstest]
    #[case::cap_1(1)]
    #[case::cap_4(4)]
    #[case::cap_8(8)]
    #[case::cap_16(16)]
    #[case::cap_32(32)]
    fn sampler_fills_to_capacity(#[case] capacity: usize) {
        let mut s = Sampler::new(capacity);
        assert_eq!(s.capacity(), capacity);
        assert!(s.points().is_empty());

        for i in 0..capacity {
            let w = reservoir_weight(0.5, 0.0, i as u64);
            let result = s.accept(true, w, i);
            assert!(result.accepted);
            // A free slot is still available, so nothing is displaced.
            assert!(result.evicted.is_none());
            s.add_point(i);
            assert_eq!(s.points().len(), i + 1);
        }

        assert!(s.is_full());
        assert_eq!(s.points().len(), capacity);
    }

    /// A lighter candidate offered to a full sampler displaces the entry
    /// carrying the largest weight.
    #[rstest]
    #[case::cap_2(2)]
    #[case::cap_4(4)]
    #[case::cap_8(8)]
    fn sampler_evicts_max_weight(#[case] capacity: usize) {
        let mut s = Sampler::new(capacity);
        // Weights grow with `i`, so point `capacity - 1` ends up heaviest.
        for i in 0..capacity {
            s.accept(true, 100.0 + i as f64, i);
            s.add_point(i);
        }
        assert!(s.is_full());

        let result = s.accept(false, -100.0f64, capacity);
        assert!(result.accepted);
        assert_eq!(result.evicted, Some(capacity - 1));

        s.add_point(capacity);
        assert!(s.is_full());
        assert_eq!(s.points().len(), capacity);
        assert!(s.points().contains(&capacity));
        assert!(!s.points().contains(&(capacity - 1)));
    }

    /// A candidate heavier than everything held is rejected and leaves the
    /// reservoir untouched.
    #[test]
    fn sampler_rejects_higher_weight_when_full() {
        let mut s = Sampler::new(2);
        let w = -10.0f64;
        s.accept(true, w, 0);
        s.add_point(0);
        s.accept(true, w - 1.0, 1);
        s.add_point(1);

        let result = s.accept(false, 999.0f64, 2);
        assert!(!result.accepted);
        assert!(result.evicted.is_none());

        assert_eq!(s.points().len(), 2);
        assert!(s.points().contains(&0));
        assert!(s.points().contains(&1));
    }
}