use super::constants::{numeric, rate};
use crate::types::TraceIdLike;
use numeric::{KNUTH_FACTOR, MAX_UINT_64BITS};
use std::fmt;
/// Sampler that keeps or drops traces according to a fixed sample rate.
///
/// The decision made by `sample` depends only on the sampler's fields and the
/// trace ID, so a given sampler always returns the same result for the same ID.
#[derive(Clone)]
pub struct RateSampler {
/// Configured rate, clamped by `new` to
/// `[rate::MIN_SAMPLE_RATE, rate::MAX_SAMPLE_RATE]`.
sample_rate: f64,
/// Inclusive upper bound on the hashed trace ID for a trace to be sampled.
/// Computed from `sample_rate` in `new`; only consulted by `sample` when the
/// rate lies strictly between the minimum and maximum.
sampling_id_threshold: u64,
}
// Hand-written `Debug`: only `sample_rate` is printed. `sampling_id_threshold`
// is computed from the rate in `RateSampler::new` and is left out of the output.
impl fmt::Debug for RateSampler {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("RateSampler");
        builder.field("sample_rate", &self.sample_rate);
        builder.finish()
    }
}
impl RateSampler {
    /// Converts a sample rate into the inclusive upper bound that a hashed
    /// trace ID must not exceed in order to be sampled.
    ///
    /// Rates at or above `rate::MAX_SAMPLE_RATE` map to `MAX_UINT_64BITS`;
    /// any other rate is scaled by `MAX_UINT_64BITS`.
    fn calculate_threshold(sample_rate: f64) -> u64 {
        if sample_rate >= rate::MAX_SAMPLE_RATE {
            MAX_UINT_64BITS
        } else {
            // Float-to-integer `as` casts saturate (and map NaN to 0), so the
            // conversion cannot wrap around.
            (sample_rate * (MAX_UINT_64BITS as f64)) as u64
        }
    }

    /// Creates a sampler for `sample_rate`.
    ///
    /// The rate is clamped to
    /// `[rate::MIN_SAMPLE_RATE, rate::MAX_SAMPLE_RATE]`. A NaN rate is treated
    /// as `rate::MIN_SAMPLE_RATE`, so the resulting sampler drops every trace
    /// and `sample_rate()` never reports NaN.
    pub fn new(sample_rate: f64) -> Self {
        // `f64::clamp` returns NaN unchanged, and NaN fails both range checks
        // in `sample`; normalise it here so the stored rate is always a real
        // number inside the valid range.
        let clamped_rate = if sample_rate.is_nan() {
            rate::MIN_SAMPLE_RATE
        } else {
            sample_rate.clamp(rate::MIN_SAMPLE_RATE, rate::MAX_SAMPLE_RATE)
        };
        let sampling_id_threshold = Self::calculate_threshold(clamped_rate);
        RateSampler {
            sample_rate: clamped_rate,
            sampling_id_threshold,
        }
    }

    /// Returns the effective (clamped) sample rate.
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Decides whether the trace identified by `trace_id` should be sampled.
    ///
    /// Returns `false` when the rate is at or below `rate::MIN_SAMPLE_RATE`
    /// and `true` when it is at or above `rate::MAX_SAMPLE_RATE`, without
    /// inspecting the ID. Otherwise the low 64 bits of the ID are multiplied
    /// (wrapping) by `KNUTH_FACTOR` and the trace is sampled when the product
    /// does not exceed the sampler's threshold.
    pub fn sample<T: TraceIdLike>(&self, trace_id: &T) -> bool {
        if self.sample_rate <= rate::MIN_SAMPLE_RATE {
            return false;
        }
        if self.sample_rate >= rate::MAX_SAMPLE_RATE {
            return true;
        }
        // Deliberate truncation: only the low 64 bits of the ID are hashed.
        let trace_id_64bits = trace_id.to_u128() as u64;
        let hashed_id = trace_id_64bits.wrapping_mul(KNUTH_FACTOR);
        hashed_id <= self.sampling_id_threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 128-bit trace ID used to drive the sampler in tests.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct TestTraceId {
        bytes: [u8; 16],
    }

    impl TestTraceId {
        fn from_bytes(bytes: [u8; 16]) -> Self {
            TestTraceId { bytes }
        }

        fn to_bytes(&self) -> [u8; 16] {
            self.bytes
        }
    }

    impl TraceIdLike for TestTraceId {
        fn to_u128(&self) -> u128 {
            u128::from_be_bytes(self.bytes)
        }
    }

    /// Builds a trace ID whose upper 64 bits are zero and whose lower 64 bits
    /// are `low` (big-endian layout).
    fn trace_id_with_low_bits(low: u64) -> TestTraceId {
        let mut raw = [0u8; 16];
        raw[8..].copy_from_slice(&low.to_be_bytes());
        TestTraceId::from_bytes(raw)
    }

    /// Recomputes the hash from the raw bytes, independently of the sampler
    /// and of the `TraceIdLike` implementation.
    fn expected_hash(id: &TestTraceId) -> u64 {
        let low_bits = u128::from_be_bytes(id.to_bytes()) as u64;
        low_bits.wrapping_mul(KNUTH_FACTOR)
    }

    #[test]
    fn check_debug_impl() {
        let rendered = format!("{:?}", RateSampler::new(0.5));
        assert!(rendered.contains("RateSampler"));
        assert!(rendered.contains("sample_rate: 0.5"));
    }

    #[test]
    fn test_rate_sampler_new() {
        // Lower bound: nothing can fall under a zero threshold except hash 0.
        let zero = RateSampler::new(0.0);
        assert_eq!(zero.sample_rate, 0.0);
        assert_eq!(zero.sampling_id_threshold, 0);

        // In-range rates scale the threshold proportionally.
        for expected_rate in [0.25, 0.5] {
            let sampler = RateSampler::new(expected_rate);
            assert_eq!(sampler.sample_rate, expected_rate);
            assert_eq!(
                sampler.sampling_id_threshold,
                (expected_rate * (MAX_UINT_64BITS as f64)) as u64
            );
        }

        // Upper bound: the threshold is pinned to the maximum.
        let one = RateSampler::new(1.0);
        assert_eq!(one.sample_rate, 1.0);
        assert_eq!(one.sampling_id_threshold, MAX_UINT_64BITS);

        // Out-of-range inputs are clamped.
        assert_eq!(RateSampler::new(-0.1).sample_rate, 0.0);
        assert_eq!(RateSampler::new(1.1).sample_rate, 1.0);
    }

    #[test]
    fn test_rate_sampler_should_sample() {
        // Rate 0.0 drops everything.
        let never = RateSampler::new(0.0);
        assert!(
            !never.sample(&trace_id_with_low_bits(1)),
            "sampler_zero should return false"
        );

        // Rate 1.0 keeps everything.
        let always = RateSampler::new(1.0);
        assert!(
            always.sample(&trace_id_with_low_bits(2)),
            "sampler_one should return true"
        );

        let half = RateSampler::new(0.5);
        let threshold = half.sampling_id_threshold;

        // An all-zero ID hashes to zero, which is never above the threshold.
        let kept_id = trace_id_with_low_bits(0);
        let sample_hash = expected_hash(&kept_id);
        assert!(sample_hash <= threshold);
        assert!(
            half.sample(&kept_id),
            "sampler_half should sample trace_id_sample"
        );

        // Low bits of all ones hash above the half-rate threshold.
        let dropped_id = trace_id_with_low_bits(u64::MAX);
        let drop_hash = expected_hash(&dropped_id);
        assert!(
            drop_hash > threshold,
            "Drop hash {drop_hash} should be > threshold {threshold}",
        );
        assert!(
            !half.sample(&dropped_id),
            "sampler_half should drop trace_id_drop"
        );
    }

    #[test]
    fn test_half_rate_sampling() {
        let half = RateSampler::new(0.5);
        let zero_id = TestTraceId::from_bytes([0u8; 16]);
        assert!(
            half.sample(&zero_id),
            "Sampler with 0.5 rate should sample trace ID 0"
        );
    }
}