use std::cell::UnsafeCell;
use rand::{Rng, SeedableRng, rngs::SmallRng};
use super::HeadSampler;
thread_local! {
    // Per-thread pseudo-random generator, created on first access on each thread
    // and seeded from `rand::rng()`.
    //
    // It lives in an `UnsafeCell` because thread-local access only hands out a
    // shared reference, while drawing a random value needs mutable access to the
    // generator state.
    static RNG: UnsafeCell<SmallRng> = {
        let mut seed_source = rand::rng();
        UnsafeCell::new(SmallRng::from_rng(&mut seed_source))
    };
}
/// Sampler that makes its decision at random with a fixed probability.
///
/// Construct it with [`ProbabilisticSampler::new`], which validates the
/// probability once so later draws never see an out-of-range value.
#[derive(Debug)]
#[must_use]
pub struct ProbabilisticSampler {
    /// Chance of a positive sampling decision; `new` guarantees it lies in
    /// `[0.0, 1.0]`.
    probability: f64,
}

impl ProbabilisticSampler {
    /// Creates a sampler that answers "sample" with the given `probability`.
    ///
    /// `0.0` means never, `1.0` means always.
    ///
    /// # Panics
    ///
    /// Panics if `probability` is not within `[0.0, 1.0]`. NaN is rejected as
    /// well, because it is not contained in any range.
    pub fn new(probability: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&probability),
            "sampling probability must be within [0.0, 1.0], got {probability}"
        );
        Self { probability }
    }
}
impl HeadSampler for ProbabilisticSampler {
    /// Decides whether `span` should be sampled.
    ///
    /// * A span with a parent always yields `true`; the random draw is made
    ///   only for spans without a parent.
    /// * A span without a parent yields `true` with probability
    ///   `self.probability`, drawn from the thread-local `RNG`.
    /// * If the thread-local `RNG` cannot be accessed (`try_with` reports an
    ///   error), the result is `false`.
    #[inline]
    fn should_sample<S>(&self, span: &tracing_subscriber::registry::SpanRef<S>) -> bool
    where
        S: for<'a> tracing_subscriber::registry::LookupSpan<'a>,
    {
        // Only spans without a parent are subject to the random draw.
        if span.parent().is_some() {
            return true;
        }

        let p = self.probability;
        let drawn = RNG.try_with(|cell| {
            // SAFETY: `RNG` is thread-local, so no other thread can reach this
            // cell, and the pointer returned by `UnsafeCell::get` is never null.
            // The mutable reference lives only for this single draw.
            // NOTE(review): this assumes nothing else on the current thread
            // holds a reference into `RNG` during the call — confirm no other
            // user of `RNG` exists in this module.
            let rng = unsafe { &mut *cell.get() };
            rng.random_bool(p)
        });

        // Best effort: when the thread-local is unavailable, do not sample.
        drawn.unwrap_or(false)
    }
}