use rand::Rng;
/// Tunable inputs for [`SamplingState`].
///
/// Derives the common traits so configurations can be logged, duplicated and
/// compared (e.g. in tests) without hand-written boilerplate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SamplingParams {
    /// Active-VU count up to which every request is collected. Above it the
    /// sample rate becomes `vu_threshold / active_vus`. A value of `0` keeps
    /// the sample rate at `1.0` regardless of the VU count.
    pub vu_threshold: usize,
    /// Capacity of the result reservoir: while the caller holds fewer results
    /// than this, `SamplingState::reservoir_slot` answers `Push`.
    pub reservoir_size: usize,
}
impl Default for SamplingParams {
fn default() -> Self {
Self {
vu_threshold: 50,
reservoir_size: 100_000,
}
}
}
/// What the caller should do with a freshly collected result, as decided by
/// `SamplingState::reservoir_slot`.
///
/// Derives `Debug`/`PartialEq` so decisions can be logged and compared with
/// `assert_eq!`, and `Copy` because the value is at most a tag plus an index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReservoirAction {
    /// The reservoir still has room: append the result.
    Push,
    /// The reservoir is full: overwrite the entry at this index, which is
    /// always smaller than the configured reservoir size.
    Replace(usize),
    /// The reservoir is full and this result was not selected: drop it.
    Discard,
}
/// Mutable bookkeeping for adaptive request sampling and reservoir selection.
///
/// `Debug` is derived so the state can be inspected in logs and test failures;
/// this relies on `rand::rngs::ThreadRng` implementing `Debug`.
#[derive(Debug)]
pub struct SamplingState {
    /// Active-VU count up to which the sample rate stays at `1.0`; `0` means
    /// the rate is never lowered.
    vu_threshold: usize,
    /// Reservoir capacity used by `reservoir_slot`.
    reservoir_size: usize,
    /// Current collection probability, recomputed by `set_active_vus`.
    sample_rate: f64,
    /// Lowest `sample_rate` produced by any `set_active_vus` call so far
    /// (starts at `1.0`).
    min_sample_rate: f64,
    /// Number of `record_request` calls.
    total_requests: usize,
    /// Number of `record_request` calls made with `success == false`.
    total_failures: usize,
    /// Number of `reservoir_slot` calls; the upper bound of the random draw
    /// made once the reservoir is full.
    total_seen_for_reservoir: usize,
    /// Random source for `should_collect` and `reservoir_slot`.
    rng: rand::rngs::ThreadRng,
}
impl SamplingState {
pub fn new(params: SamplingParams) -> Self {
Self {
vu_threshold: params.vu_threshold,
reservoir_size: params.reservoir_size,
sample_rate: 1.0,
min_sample_rate: 1.0,
total_requests: 0,
total_failures: 0,
total_seen_for_reservoir: 0,
rng: rand::rng(),
}
}
pub fn set_active_vus(&mut self, vus: usize) {
self.sample_rate = if self.vu_threshold == 0 || vus <= self.vu_threshold {
1.0
} else {
self.vu_threshold as f64 / vus as f64
};
self.min_sample_rate = self.min_sample_rate.min(self.sample_rate);
}
pub fn record_request(&mut self, success: bool) {
self.total_requests += 1;
if !success {
self.total_failures += 1;
}
}
pub fn should_collect(&mut self) -> bool {
self.sample_rate >= 1.0 || self.rng.random::<f64>() < self.sample_rate
}
pub fn reservoir_slot(&mut self, results_len: usize) -> ReservoirAction {
self.total_seen_for_reservoir += 1;
if results_len < self.reservoir_size {
ReservoirAction::Push
} else {
let j = self.rng.random_range(0..self.total_seen_for_reservoir);
if j < self.reservoir_size {
ReservoirAction::Replace(j)
} else {
ReservoirAction::Discard
}
}
}
pub fn total_requests(&self) -> usize {
self.total_requests
}
pub fn total_failures(&self) -> usize {
self.total_failures
}
pub fn sample_rate(&self) -> f64 {
self.sample_rate
}
pub fn min_sample_rate(&self) -> f64 {
self.min_sample_rate
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // State built from `SamplingParams::default()`.
    fn fresh_state() -> SamplingState {
        SamplingState::new(SamplingParams::default())
    }

    // State with a zero VU threshold and the given reservoir capacity.
    fn state_with_reservoir(reservoir_size: usize) -> SamplingState {
        SamplingState::new(SamplingParams {
            vu_threshold: 0,
            reservoir_size,
        })
    }

    // Calls `reservoir_slot` once for each length in `0..capacity`, mimicking
    // a caller that fills the reservoir from empty.
    fn fill_reservoir(state: &mut SamplingState, capacity: usize) {
        for len in 0..capacity {
            state.reservoir_slot(len);
        }
    }

    // True when `actual` is within `f64::EPSILON` of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn rate_is_1_below_threshold() {
        let mut state = fresh_state();
        state.set_active_vus(49);
        assert_eq!(state.sample_rate(), 1.0);
    }

    #[test]
    fn rate_is_1_at_threshold() {
        let mut state = fresh_state();
        state.set_active_vus(50);
        assert_eq!(state.sample_rate(), 1.0);
    }

    #[test]
    fn rate_drops_above_threshold() {
        let mut state = fresh_state();
        state.set_active_vus(100);
        assert!(close_to(state.sample_rate(), 0.5));
    }

    #[test]
    fn rate_scales_proportionally() {
        let params = SamplingParams {
            vu_threshold: 50,
            reservoir_size: 100_000,
        };
        let mut state = SamplingState::new(params);
        state.set_active_vus(200);
        assert!(close_to(state.sample_rate(), 0.25));
    }

    #[test]
    fn zero_threshold_always_collects() {
        let mut state = state_with_reservoir(100_000);
        state.set_active_vus(10_000);
        assert_eq!(state.sample_rate(), 1.0);
        assert!((0..100).all(|_| state.should_collect()));
    }

    #[test]
    fn min_sample_rate_tracks_lowest_observed() {
        let mut state = fresh_state();
        // 100 VUs -> 0.5, 200 VUs -> 0.25, 50 VUs -> back to 1.0.
        for vus in [100, 200, 50] {
            state.set_active_vus(vus);
        }
        assert!(close_to(state.min_sample_rate(), 0.25));
    }

    #[test]
    fn min_sample_rate_starts_at_1() {
        let state = fresh_state();
        assert_eq!(state.min_sample_rate(), 1.0);
    }

    #[test]
    fn record_request_increments_total() {
        let mut state = fresh_state();
        for _ in 0..2 {
            state.record_request(true);
        }
        assert_eq!(state.total_requests(), 2);
    }

    #[test]
    fn record_request_tracks_failures() {
        let mut state = fresh_state();
        for outcome in [true, false, false] {
            state.record_request(outcome);
        }
        assert_eq!(state.total_requests(), 3);
        assert_eq!(state.total_failures(), 2);
    }

    #[test]
    fn record_request_success_does_not_increment_failures() {
        let mut state = fresh_state();
        state.record_request(true);
        assert_eq!(state.total_failures(), 0);
    }

    #[test]
    fn should_collect_always_true_at_full_rate() {
        let mut state = fresh_state();
        // 10 VUs is below the default threshold, so the rate stays at 1.0.
        state.set_active_vus(10);
        assert!((0..1000).all(|_| state.should_collect()));
    }

    #[test]
    fn should_collect_probabilistic_at_half_rate() {
        let mut state = fresh_state();
        // 100 VUs against a threshold of 50 gives a rate of 0.5.
        state.set_active_vus(100);

        let mut collected: usize = 0;
        for _ in 0..10_000 {
            if state.should_collect() {
                collected += 1;
            }
        }

        // Strictly between 4000 and 6000.
        assert!(
            (4_001..6_000).contains(&collected),
            "expected ~5000 collected, got {collected}"
        );
    }

    #[test]
    fn reservoir_pushes_while_not_full() {
        let mut state = state_with_reservoir(5);
        for i in 0..5 {
            assert!(
                matches!(state.reservoir_slot(i), ReservoirAction::Push),
                "expected Push at results_len={i}"
            );
        }
    }

    #[test]
    fn reservoir_never_pushes_when_full() {
        let mut state = state_with_reservoir(5);
        fill_reservoir(&mut state, 5);

        for _ in 0..100 {
            assert!(
                !matches!(state.reservoir_slot(5), ReservoirAction::Push),
                "Push when reservoir is full"
            );
        }
    }

    #[test]
    fn reservoir_replace_index_is_in_bounds() {
        let mut state = state_with_reservoir(5);
        fill_reservoir(&mut state, 5);

        for _ in 0..200 {
            match state.reservoir_slot(5) {
                ReservoirAction::Replace(idx) => assert!(
                    idx < 5,
                    "Replace index {idx} out of bounds for reservoir_size=5"
                ),
                ReservoirAction::Push | ReservoirAction::Discard => {}
            }
        }
    }

    #[test]
    fn reservoir_discard_rate_decreases_over_time() {
        let mut state = state_with_reservoir(10);
        fill_reservoir(&mut state, 10);

        let (mut replaces, mut discards) = (0usize, 0usize);
        for _ in 0..1000 {
            let tally = match state.reservoir_slot(10) {
                ReservoirAction::Replace(_) => &mut replaces,
                ReservoirAction::Discard => &mut discards,
                ReservoirAction::Push => panic!("unexpected Push"),
            };
            *tally += 1;
        }

        assert!(
            discards > replaces,
            "expected more discards than replaces at high total_seen; replaces={replaces}, discards={discards}"
        );
    }

    #[test]
    fn is_sampling_reflects_history() {
        let mut state = fresh_state();

        // Below the threshold: nothing has been sampled down yet.
        state.set_active_vus(10);
        assert_eq!(state.min_sample_rate(), 1.0);

        // Above the threshold: the minimum drops to 0.5.
        state.set_active_vus(100);
        assert!(close_to(state.min_sample_rate(), 0.5));

        // Back below the threshold: the minimum keeps the lowest value seen.
        state.set_active_vus(10);
        assert!(close_to(state.min_sample_rate(), 0.5));
    }
}