use crate::event::SpanEvent;
/// Batch-size cutoff between the two decision caches used by `apply_sampling`.
///
/// Batches with at most this many events memoise decisions in a fixed-size array of this
/// length; larger batches use a `HashMap` instead.
const SAMPLING_HASHMAP_THRESHOLD: usize = 16;
/// Keeps only the events whose trace is selected at the given sampling `rate`.
///
/// The keep/drop decision is derived from a hash of each event's `trace_id`, so all events
/// sharing a trace id get the same decision. Surviving events keep their relative order.
///
/// * `rate >= 1.0` returns `events` unchanged.
/// * `rate <= 0.0` (or NaN) returns an empty vector without hashing any trace id.
/// * Any other rate filters the batch, memoising the decision per trace-id hash: in a
///   fixed-size array for batches of at most `SAMPLING_HASHMAP_THRESHOLD` events, and in a
///   `HashMap` for larger ones.
pub(super) fn apply_sampling(events: Vec<SpanEvent>, rate: f64) -> Vec<SpanEvent> {
    if rate >= 1.0 {
        return events;
    }
    // `hash_to_decision` rejects every hash for these rates, so there is no point in
    // hashing each trace id just to drop the event afterwards.
    if rate <= 0.0 || rate.is_nan() {
        return Vec::new();
    }

    if events.len() <= SAMPLING_HASHMAP_THRESHOLD {
        // Small batch: a linear scan over a stack-allocated array avoids a heap allocation.
        // Only the first `cache_len` slots hold real entries.
        let mut cache = [(0_u64, false); SAMPLING_HASHMAP_THRESHOLD];
        let mut cache_len = 0_usize;
        events
            .into_iter()
            .filter(|event| {
                let h = hash_trace_id(&event.trace_id);
                if let Some(&(_, decision)) = cache[..cache_len].iter().find(|(k, _)| *k == h) {
                    return decision;
                }
                let decision = hash_to_decision(h, rate);
                // The batch holds at most `SAMPLING_HASHMAP_THRESHOLD` events, so the array
                // cannot overflow; the guard just keeps the index provably in bounds.
                if cache_len < SAMPLING_HASHMAP_THRESHOLD {
                    cache[cache_len] = (h, decision);
                    cache_len += 1;
                }
                decision
            })
            .collect()
    } else {
        let mut cache = std::collections::HashMap::<u64, bool>::with_capacity(events.len() / 4);
        events
            .into_iter()
            .filter(|event| {
                let h = hash_trace_id(&event.trace_id);
                // Entry API: a single map lookup covers both the hit and the miss case.
                *cache.entry(h).or_insert_with(|| hash_to_decision(h, rate))
            })
            .collect()
    }
}
/// Hashes a trace id to a `u64` using 64-bit FNV-1a over the id's UTF-8 bytes.
///
/// The result depends only on the bytes of `trace_id`, so equal ids always hash equally.
#[inline]
fn hash_trace_id(trace_id: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    // FNV-1a step: XOR the byte in first, then multiply by the prime (wrapping on overflow).
    trace_id.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Turns a trace-id hash into a keep (`true`) / drop (`false`) decision for `rate`.
///
/// Rates of `1.0` or more always keep, rates of `0.0` or less always drop. In between, the
/// hash is scaled against `u64::MAX` and kept when the resulting fraction is strictly below
/// `rate`. A NaN rate fails every comparison and therefore drops.
#[inline]
// Converting `u64` to `f64` rounds large values; the fraction is only compared to `rate`.
#[allow(clippy::cast_precision_loss)]
fn hash_to_decision(hash: u64, rate: f64) -> bool {
    if rate >= 1.0 {
        true
    } else if rate <= 0.0 {
        false
    } else {
        let fraction = hash as f64 / u64::MAX as f64;
        fraction < rate
    }
}
/// Test-only shortcut: hashes `trace_id` and returns the sampling decision for `rate`.
#[cfg(test)]
fn should_sample(trace_id: &str, rate: f64) -> bool {
    let trace_hash = hash_trace_id(trace_id);
    hash_to_decision(trace_hash, rate)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::event::{EventSource, EventType, SpanEvent};

    /// Builds a fixed SQL span event; only `trace_id` varies between calls.
    fn make_event(trace_id: &str) -> SpanEvent {
        SpanEvent {
            timestamp: "2025-07-10T14:32:01.123Z".to_string(),
            trace_id: trace_id.to_string(),
            span_id: "s1".to_string(),
            parent_span_id: None,
            service: Arc::from("test"),
            cloud_region: None,
            event_type: EventType::Sql,
            operation: "SELECT".to_string(),
            target: "SELECT 1".to_string(),
            duration_us: 100,
            source: EventSource {
                endpoint: "GET /test".to_string(),
                method: "Test::test".to_string(),
            },
            status_code: None,
            response_size_bytes: None,
            code_function: None,
            code_filepath: None,
            code_lineno: None,
            code_namespace: None,
            instrumentation_scopes: Vec::new(),
        }
    }

    /// The same trace id and rate must always produce the same decision.
    #[test]
    fn should_sample_deterministic() {
        let r1 = should_sample("trace-abc-123", 0.5);
        let r2 = should_sample("trace-abc-123", 0.5);
        assert_eq!(r1, r2);
    }

    #[test]
    fn should_sample_rate_zero_drops_all() {
        assert!(!should_sample("any-trace", 0.0));
        assert!(!should_sample("another-trace", 0.0));
    }

    #[test]
    fn should_sample_rate_one_keeps_all() {
        assert!(should_sample("any-trace", 1.0));
        assert!(should_sample("another-trace", 1.0));
    }

    /// Loose distribution check: a 0.5 rate should keep roughly half of 1000 distinct ids.
    #[test]
    fn should_sample_rate_half_splits() {
        let sampled = (0..1000)
            .filter(|i| should_sample(&format!("trace-{i}"), 0.5))
            .count();
        assert!(
            (300..=700).contains(&sampled),
            "expected ~500 sampled, got {sampled}"
        );
    }

    #[test]
    fn apply_sampling_full_rate_returns_all() {
        let events = vec![make_event("t1"), make_event("t2"), make_event("t3")];
        let sampled = apply_sampling(events, 1.0);
        assert_eq!(sampled.len(), 3);
    }

    #[test]
    fn apply_sampling_zero_rate_drops_all() {
        let events = vec![make_event("t1"), make_event("t2")];
        let sampled = apply_sampling(events, 0.0);
        assert!(sampled.is_empty());
    }

    #[test]
    fn apply_sampling_same_trace_id_cached_decision() {
        // Rate 1.0 returns early, before any cache is consulted.
        let events = vec![
            make_event("same-trace"),
            make_event("same-trace"),
            make_event("same-trace"),
            make_event("same-trace"),
        ];
        let sampled = apply_sampling(events, 1.0);
        assert_eq!(
            sampled.len(),
            4,
            "rate 1.0 must keep every event regardless of trace_id"
        );

        // Three events at a partial rate go through the small-batch (array) cache.
        let events2 = vec![
            make_event("cached-trace"),
            make_event("cached-trace"),
            make_event("cached-trace"),
        ];
        let sampled2 = apply_sampling(events2, 0.5);
        assert!(
            sampled2.is_empty() || sampled2.len() == 3,
            "all events for the same trace_id must share the cached \
             decision, got {} of 3 kept (expected 0 or 3)",
            sampled2.len()
        );
    }

    #[test]
    fn apply_sampling_mixed_trace_ids_with_partial_rate() {
        let events: Vec<_> = (0..100)
            .map(|i| make_event(&format!("trace-{i}")))
            .collect();
        let sampled = apply_sampling(events, 0.5);
        assert!(
            (10..=90).contains(&sampled.len()),
            "expected ~50 sampled, got {}",
            sampled.len()
        );
    }

    /// Small batch with repeated ids: the kept events must be exactly those that
    /// `should_sample` selects, in their original order.
    #[test]
    fn apply_sampling_small_batch_agrees_with_should_sample() {
        let ids: Vec<String> = (0..12).map(|i| format!("trace-{}", i % 4)).collect();
        let events: Vec<_> = ids.iter().map(|id| make_event(id)).collect();

        let kept: Vec<String> = apply_sampling(events, 0.5)
            .iter()
            .map(|event| event.trace_id.clone())
            .collect();
        let expected: Vec<String> = ids
            .into_iter()
            .filter(|id| should_sample(id, 0.5))
            .collect();

        assert_eq!(kept, expected);
    }

    /// Large batch with repeated ids: 40 events over 8 trace ids exceeds the small-batch
    /// cutoff, so the `HashMap`-backed cache is exercised. Each trace appears 5 times and
    /// must be kept or dropped as a whole.
    #[test]
    fn apply_sampling_large_batch_keeps_traces_whole() {
        let events: Vec<_> = (0..40)
            .map(|i| make_event(&format!("trace-{}", i % 8)))
            .collect();
        let sampled = apply_sampling(events, 0.5);

        let kept_traces = (0..8)
            .filter(|i| should_sample(&format!("trace-{i}"), 0.5))
            .count();
        assert_eq!(sampled.len(), kept_traces * 5);
        assert!(sampled
            .iter()
            .all(|event| should_sample(&event.trace_id, 0.5)));
    }
}