use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use super::structured_logging::LogLevel;
/// Outcome of a sampling evaluation: keep the log event or discard it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SamplingDecision {
    /// Emit the log event.
    Record,
    /// Discard the log event.
    Drop,
}
/// Metadata about a single log event that samplers use to make a decision.
/// Built fluently via the `with_*` methods on this type.
#[derive(Debug, Clone)]
pub struct SamplingContext {
    /// Severity of the event.
    pub level: LogLevel,
    /// Logical operation name, if known.
    pub operation: Option<String>,
    /// Request correlation id, if known.
    pub request_id: Option<String>,
    /// Whether the event represents an error condition.
    pub is_error: bool,
    /// Priority on a 0-10 scale (clamped to 10 by `with_priority`; `new` defaults to 5).
    pub priority: u8,
    /// Distributed-trace id, if the event belongs to a trace.
    pub trace_id: Option<String>,
    /// Originating user, if known.
    pub user_id: Option<String>,
    /// Free-form key/value labels.
    pub tags: HashMap<String, String>,
}
impl SamplingContext {
    /// Create a context for an event at `level` with defaults: no
    /// identifiers, not an error, mid-scale priority 5, empty tags.
    pub fn new(level: LogLevel) -> Self {
        Self {
            level,
            operation: None,
            request_id: None,
            is_error: false,
            priority: 5,
            trace_id: None,
            user_id: None,
            tags: HashMap::new(),
        }
    }
    /// Set the logical operation name (accepts `&str` or `String`).
    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
        self.operation = Some(operation.into());
        self
    }
    /// Set the request correlation id (accepts `&str` or `String`).
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        self.request_id = Some(request_id.into());
        self
    }
    /// Mark (or unmark) the event as an error.
    pub fn with_error(mut self, is_error: bool) -> Self {
        self.is_error = is_error;
        self
    }
    /// Set the priority, clamped to the 0-10 scale.
    pub fn with_priority(mut self, priority: u8) -> Self {
        self.priority = priority.min(10);
        self
    }
    /// Set the distributed-trace id (accepts `&str` or `String`).
    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
        self.trace_id = Some(trace_id.into());
        self
    }
    /// Set the originating user id (accepts `&str` or `String`).
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }
    /// Add a key/value tag, replacing any previous value for the key.
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.tags.insert(key.into(), value.into());
        self
    }
}
/// A log-sampling strategy. `should_sample` sits on the logging hot path,
/// so implementations in this module use non-blocking `try_read`/`try_write`
/// lock access rather than awaiting.
pub trait LogSampler: Send + Sync {
    /// Decide whether the event described by `context` is recorded or dropped.
    fn should_sample(&self, context: &SamplingContext) -> SamplingDecision;
    /// Clear any internal state (windows, buffers, counters). Default: no-op.
    fn reset(&self) {}
    /// Snapshot of this sampler's counters. Default: empty statistics.
    fn get_statistics(&self) -> SamplerStatistics {
        SamplerStatistics::default()
    }
}
/// Aggregate counters describing a sampler's behavior over time.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SamplerStatistics {
    /// Total decisions made.
    pub total_evaluated: u64,
    /// Decisions that were `Record`.
    pub total_recorded: u64,
    /// Decisions that were `Drop`.
    pub total_dropped: u64,
    /// `total_recorded / total_evaluated`; refreshed by `calculate_rate`.
    pub sampling_rate: f64,
    /// Timestamp of the last reset, if any.
    /// NOTE(review): unit is unspecified here — presumably epoch seconds or
    /// millis; nothing in this file ever writes it. Confirm against callers.
    pub last_reset: Option<u64>,
}
impl SamplerStatistics {
    /// Recompute `sampling_rate` as `total_recorded / total_evaluated`,
    /// yielding 0.0 when nothing has been evaluated (avoids division by zero).
    pub fn calculate_rate(&mut self) {
        self.sampling_rate = if self.total_evaluated == 0 {
            0.0
        } else {
            self.total_recorded as f64 / self.total_evaluated as f64
        };
    }
}
/// Sampler that records every event unconditionally.
#[derive(Debug, Clone)]
pub struct AlwaysSampler;
impl LogSampler for AlwaysSampler {
    fn should_sample(&self, _context: &SamplingContext) -> SamplingDecision {
        SamplingDecision::Record
    }
}
/// Sampler that drops every event unconditionally.
#[derive(Debug, Clone)]
pub struct NeverSampler;
impl LogSampler for NeverSampler {
    fn should_sample(&self, _context: &SamplingContext) -> SamplingDecision {
        SamplingDecision::Drop
    }
}
/// Sampler that records each event independently with fixed probability `rate`.
#[derive(Debug)]
pub struct ProbabilisticSampler {
    // Probability of recording, clamped to [0.0, 1.0] by `new`.
    rate: f64,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl ProbabilisticSampler {
    /// Create a sampler recording with probability `rate` (clamped to [0, 1]).
    pub fn new(rate: f64) -> Self {
        Self {
            rate: rate.clamp(0.0, 1.0),
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
}
impl LogSampler for ProbabilisticSampler {
    /// Record with probability `self.rate`, independent of the context.
    fn should_sample(&self, _context: &SamplingContext) -> SamplingDecision {
        let decision = if fastrand::f64() < self.rate {
            SamplingDecision::Record
        } else {
            SamplingDecision::Drop
        };
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults. Best-effort:
        // skipped under lock contention so the hot path never blocks.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            match decision {
                SamplingDecision::Record => stats.total_recorded += 1,
                SamplingDecision::Drop => stats.total_dropped += 1,
            }
            stats.calculate_rate();
        }
        decision
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// Sampler that records at most `max_logs_per_window` events per sliding
/// window of `window_duration`, dropping the rest.
#[derive(Debug)]
pub struct RateLimitedSampler {
    // Maximum number of recorded events per window.
    max_logs_per_window: usize,
    // Length of the sliding window.
    window_duration: Duration,
    // Timestamps of events recorded within the current window (oldest first).
    log_timestamps: Arc<RwLock<VecDeque<Instant>>>,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl RateLimitedSampler {
    /// Create a limiter allowing `max_logs_per_window` records per `window_duration`.
    pub fn new(max_logs_per_window: usize, window_duration: Duration) -> Self {
        Self {
            max_logs_per_window,
            window_duration,
            log_timestamps: Arc::new(RwLock::new(VecDeque::new())),
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
}
impl LogSampler for RateLimitedSampler {
    /// Record while the sliding window has capacity, otherwise drop.
    /// Fails closed (drops) if the timestamp lock is contended.
    fn should_sample(&self, _context: &SamplingContext) -> SamplingDecision {
        let decision = if let Ok(mut timestamps) = self.log_timestamps.try_write() {
            let now = Instant::now();
            // Expire timestamps that have slid out of the window.
            while let Some(&front) = timestamps.front() {
                if now.duration_since(front) > self.window_duration {
                    timestamps.pop_front();
                } else {
                    break;
                }
            }
            if timestamps.len() < self.max_logs_per_window {
                timestamps.push_back(now);
                SamplingDecision::Record
            } else {
                SamplingDecision::Drop
            }
        } else {
            // Lock contention: drop rather than block the logging hot path.
            SamplingDecision::Drop
        };
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            match decision {
                SamplingDecision::Record => stats.total_recorded += 1,
                SamplingDecision::Drop => stats.total_dropped += 1,
            }
            stats.calculate_rate();
        }
        decision
    }
    /// Clear the window and the counters.
    fn reset(&self) {
        if let Ok(mut timestamps) = self.log_timestamps.try_write() {
            timestamps.clear();
        }
        if let Ok(mut stats) = self.stats.try_write() {
            *stats = SamplerStatistics::default();
        }
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// Sampler that drops everything below `min_priority` and records everything
/// at or above it, optionally applying a per-priority probability override.
#[derive(Debug)]
pub struct PriorityBasedSampler {
    // Minimum priority (0-10) required for an event to be considered.
    min_priority: u8,
    // Optional per-priority record probabilities, keyed by priority level.
    priority_rates: HashMap<u8, f64>,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl PriorityBasedSampler {
    /// Create a sampler with the given minimum priority (clamped to 10).
    pub fn new(min_priority: u8) -> Self {
        Self {
            min_priority: min_priority.min(10),
            priority_rates: HashMap::new(),
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
    /// Override the record probability for one priority level.
    ///
    /// Fix: clamp the priority key to 10, matching `new` and
    /// `SamplingContext::with_priority` (which clamps context priorities to
    /// 10). Previously a rate registered at priority > 10 could never match
    /// any context and was silently ignored.
    pub fn with_priority_rate(mut self, priority: u8, rate: f64) -> Self {
        self.priority_rates.insert(priority.min(10), rate.clamp(0.0, 1.0));
        self
    }
}
impl LogSampler for PriorityBasedSampler {
    /// Drop below `min_priority`; otherwise record, unless a per-priority
    /// rate is registered, in which case record with that probability.
    fn should_sample(&self, context: &SamplingContext) -> SamplingDecision {
        let decision = if context.priority < self.min_priority {
            SamplingDecision::Drop
        } else if let Some(&rate) = self.priority_rates.get(&context.priority) {
            if fastrand::f64() < rate {
                SamplingDecision::Record
            } else {
                SamplingDecision::Drop
            }
        } else {
            // Priority passes the threshold and has no rate override.
            SamplingDecision::Record
        };
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            match decision {
                SamplingDecision::Record => stats.total_recorded += 1,
                SamplingDecision::Drop => stats.total_dropped += 1,
            }
            stats.calculate_rate();
        }
        decision
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// Sampler that favors problems: errors can bypass sampling entirely, and
/// `Error`/`Warn` levels are always recorded, while other levels are sampled
/// at `normal_rate`.
#[derive(Debug)]
pub struct ErrorAwareSampler {
    // Probability applied to non-warning, non-error levels; clamped to [0, 1].
    normal_rate: f64,
    // When true, any context flagged `is_error` is recorded unconditionally.
    always_sample_errors: bool,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl ErrorAwareSampler {
    /// Create a sampler with the given normal-path rate and error policy.
    pub fn new(normal_rate: f64, always_sample_errors: bool) -> Self {
        Self {
            normal_rate: normal_rate.clamp(0.0, 1.0),
            always_sample_errors,
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
}
impl LogSampler for ErrorAwareSampler {
    /// Record flagged errors unconditionally (when configured), always record
    /// `Error`/`Warn` levels, and sample everything else at `normal_rate`.
    fn should_sample(&self, context: &SamplingContext) -> SamplingDecision {
        let decision = if self.always_sample_errors && context.is_error {
            SamplingDecision::Record
        } else {
            // Warnings and errors are always kept (rate 1.0); other levels
            // use the configured normal-path probability.
            let rate = match context.level {
                LogLevel::Error | LogLevel::Warn => 1.0,
                _ => self.normal_rate,
            };
            if fastrand::f64() < rate {
                SamplingDecision::Record
            } else {
                SamplingDecision::Drop
            }
        };
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            match decision {
                SamplingDecision::Record => stats.total_recorded += 1,
                SamplingDecision::Drop => stats.total_dropped += 1,
            }
            stats.calculate_rate();
        }
        decision
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// Tail-based sampler: events are buffered per trace id and provisionally
/// recorded; the keep/drop decision for a whole trace is made afterwards
/// via `decide_trace` (record the trace iff any of its events was an error).
#[derive(Debug)]
pub struct TailSampler {
    // Maximum number of distinct traces kept in the buffer at once.
    buffer_size: usize,
    // Buffered contexts grouped by trace id.
    buffer: Arc<RwLock<HashMap<String, Vec<SamplingContext>>>>,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl TailSampler {
    /// Create a tail sampler retaining at most `buffer_size` traces.
    pub fn new(buffer_size: usize) -> Self {
        Self {
            buffer_size,
            buffer: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
    /// Decide the fate of a buffered trace: one decision per buffered event,
    /// all `Record` if any event in the trace was an error, all `Drop`
    /// otherwise. Returns an empty vec for an unknown trace id.
    pub async fn decide_trace(&self, trace_id: &str) -> Vec<SamplingDecision> {
        let buffer = self.buffer.read().await;
        if let Some(contexts) = buffer.get(trace_id) {
            let has_error = contexts.iter().any(|ctx| ctx.is_error);
            let decision = if has_error {
                SamplingDecision::Record
            } else {
                SamplingDecision::Drop
            };
            vec![decision; contexts.len()]
        } else {
            vec![]
        }
    }
}
impl LogSampler for TailSampler {
    /// Buffer the context under its trace id (best-effort) and provisionally
    /// record it; the real decision happens later in `decide_trace`. Events
    /// without a trace id are recorded without buffering.
    fn should_sample(&self, context: &SamplingContext) -> SamplingDecision {
        if let Some(trace_id) = &context.trace_id {
            // Non-blocking: buffering is skipped when the lock is contended.
            if let Ok(mut buffer) = self.buffer.try_write() {
                buffer
                    .entry(trace_id.clone())
                    .or_insert_with(Vec::new)
                    .push(context.clone());
                // Capacity control: evict one trace when over budget.
                // NOTE(review): HashMap iteration order is arbitrary, so this
                // evicts a *random* trace, not the oldest. Evicting the oldest
                // would require tracking insertion order (a struct change).
                if buffer.len() > self.buffer_size {
                    if let Some(victim) = buffer.keys().next().cloned() {
                        buffer.remove(&victim);
                    }
                }
            }
        }
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults. The head
        // decision here is always Record.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            stats.total_recorded += 1;
            stats.calculate_rate();
        }
        SamplingDecision::Record
    }
    /// Clear all buffered traces and the counters.
    fn reset(&self) {
        if let Ok(mut buffer) = self.buffer.try_write() {
            buffer.clear();
        }
        if let Ok(mut stats) = self.stats.try_write() {
            *stats = SamplerStatistics::default();
        }
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// Sampler whose record probability adapts toward a target log throughput.
/// `adjust_rate` must be called periodically by the owner; `should_sample`
/// itself does not adjust the rate.
#[derive(Debug)]
pub struct AdaptiveSampler {
    // Initial rate; `reset` restores it. Clamped to [0, 1].
    base_rate: f64,
    // Rate currently in effect, adjusted by `adjust_rate`.
    current_rate: Arc<RwLock<f64>>,
    // Desired log throughput the controller steers toward.
    target_logs_per_second: f64,
    // Timestamps of recent log events within roughly the last second.
    recent_logs: Arc<RwLock<VecDeque<Instant>>>,
    // Shared counters exposed through `get_statistics`.
    stats: Arc<RwLock<SamplerStatistics>>,
}
impl AdaptiveSampler {
    /// Create a sampler starting at `base_rate` (clamped to [0, 1]) that
    /// steers toward `target_logs_per_second`.
    pub fn new(base_rate: f64, target_logs_per_second: f64) -> Self {
        Self {
            base_rate: base_rate.clamp(0.0, 1.0),
            current_rate: Arc::new(RwLock::new(base_rate.clamp(0.0, 1.0))),
            target_logs_per_second,
            recent_logs: Arc::new(RwLock::new(VecDeque::new())),
            stats: Arc::new(RwLock::new(SamplerStatistics::default())),
        }
    }
    /// Re-estimate throughput over the last second and nudge the rate:
    /// multiply by 0.9 (floor 0.01) when above target, by 1.1 (cap 1.0) when
    /// below 80% of target. The 80%-100% band is a dead zone to avoid
    /// oscillation.
    pub async fn adjust_rate(&self) {
        let mut logs = self.recent_logs.write().await;
        let now = Instant::now();
        // Expire entries older than one second.
        while let Some(&front) = logs.front() {
            if now.duration_since(front) > Duration::from_secs(1) {
                logs.pop_front();
            } else {
                break;
            }
        }
        let current_logs_per_second = logs.len() as f64;
        let mut current_rate = self.current_rate.write().await;
        if current_logs_per_second > self.target_logs_per_second {
            *current_rate = (*current_rate * 0.9).max(0.01);
        } else if current_logs_per_second < self.target_logs_per_second * 0.8 {
            *current_rate = (*current_rate * 1.1).min(1.0);
        }
    }
    /// Read the rate currently in effect.
    pub async fn get_current_rate(&self) -> f64 {
        *self.current_rate.read().await
    }
}
impl LogSampler for AdaptiveSampler {
    /// Record with the current adaptive probability.
    ///
    /// Fixes vs. the original:
    /// - Only *recorded* events are pushed into `recent_logs`. Previously
    ///   every attempt was pushed, so `adjust_rate` regulated attempt rate,
    ///   not emitted rate, and output could never converge on
    ///   `target_logs_per_second`.
    /// - A recording decision no longer requires winning the `recent_logs`
    ///   lock first (contention there only skips the throughput bookkeeping).
    /// - Statistics are now actually maintained.
    fn should_sample(&self, _context: &SamplingContext) -> SamplingDecision {
        let decision = match self.current_rate.try_read() {
            Ok(rate) if fastrand::f64() < *rate => SamplingDecision::Record,
            // Rate lost the coin flip, or the lock was contended: fail closed.
            _ => SamplingDecision::Drop,
        };
        if decision == SamplingDecision::Record {
            // Count this emission toward the measured throughput; best-effort.
            if let Ok(mut logs) = self.recent_logs.try_write() {
                logs.push_back(Instant::now());
            }
        }
        // Fix: keep the counters up to date — previously `stats` was never
        // written, so `get_statistics` always returned defaults.
        if let Ok(mut stats) = self.stats.try_write() {
            stats.total_evaluated += 1;
            match decision {
                SamplingDecision::Record => stats.total_recorded += 1,
                SamplingDecision::Drop => stats.total_dropped += 1,
            }
            stats.calculate_rate();
        }
        decision
    }
    /// Restore the base rate and clear throughput history and counters.
    fn reset(&self) {
        if let Ok(mut logs) = self.recent_logs.try_write() {
            logs.clear();
        }
        if let Ok(mut current_rate) = self.current_rate.try_write() {
            *current_rate = self.base_rate;
        }
        if let Ok(mut stats) = self.stats.try_write() {
            *stats = SamplerStatistics::default();
        }
    }
    /// Snapshot of the counters; empty defaults if the lock is contended.
    fn get_statistics(&self) -> SamplerStatistics {
        match self.stats.try_read() {
            Ok(stats) => stats.clone(),
            Err(_) => SamplerStatistics::default(),
        }
    }
}
/// How a `CompositeSampler` combines its child samplers' decisions.
#[derive(Debug, Clone, Copy)]
pub enum CompositeStrategy {
    /// Record only if every child records (vacuously true with no children).
    All,
    /// Record if at least one child records.
    Any,
    /// Record on the first child that records.
    /// NOTE(review): as implemented this is behaviorally identical to `Any`;
    /// the existing tests pin that behavior.
    FirstMatch,
}
/// A sampler that delegates to an ordered list of child samplers and merges
/// their decisions according to a `CompositeStrategy`.
pub struct CompositeSampler {
    // Merge strategy applied over the children.
    strategy: CompositeStrategy,
    // Children, evaluated in insertion order.
    samplers: Vec<Arc<dyn LogSampler>>,
}
impl CompositeSampler {
    /// Create an empty composite with the given merge strategy.
    pub fn new(strategy: CompositeStrategy) -> Self {
        Self {
            strategy,
            samplers: Vec::new(),
        }
    }
    /// Append a child sampler (builder style).
    pub fn add_sampler(mut self, sampler: Arc<dyn LogSampler>) -> Self {
        self.samplers.push(sampler);
        self
    }
}
impl LogSampler for CompositeSampler {
    /// Merge the children's decisions according to the configured strategy.
    /// Children are consulted in order and evaluation short-circuits as soon
    /// as the outcome is determined.
    fn should_sample(&self, context: &SamplingContext) -> SamplingDecision {
        let record = match self.strategy {
            // Every child must agree to record; an empty composite records.
            CompositeStrategy::All => self
                .samplers
                .iter()
                .all(|s| s.should_sample(context) == SamplingDecision::Record),
            // One recording child suffices; an empty composite drops.
            // NOTE(review): FirstMatch is implemented identically to Any in
            // the original code — behavior preserved exactly.
            CompositeStrategy::Any | CompositeStrategy::FirstMatch => self
                .samplers
                .iter()
                .any(|s| s.should_sample(context) == SamplingDecision::Record),
        };
        if record {
            SamplingDecision::Record
        } else {
            SamplingDecision::Drop
        }
    }
    /// Propagate the reset to every child sampler.
    fn reset(&self) {
        self.samplers.iter().for_each(|s| s.reset());
    }
}
/// Registry of named samplers with a fallback default, giving callers one
/// entry point for sampling decisions across the application.
pub struct LogSamplingManager {
    // Used when no named sampler is requested or the name is unknown.
    default_sampler: Arc<dyn LogSampler>,
    // Samplers addressable by name via `should_sample_with`.
    named_samplers: HashMap<String, Arc<dyn LogSampler>>,
}
impl LogSamplingManager {
    /// Create a manager that falls back to `default_sampler`.
    pub fn new(default_sampler: Arc<dyn LogSampler>) -> Self {
        Self {
            default_sampler,
            named_samplers: HashMap::new(),
        }
    }
    /// Register (or replace) a sampler under `name`.
    pub fn register_sampler(&mut self, name: String, sampler: Arc<dyn LogSampler>) {
        self.named_samplers.insert(name, sampler);
    }
    /// Look up a registered sampler by name.
    pub fn get_sampler(&self, name: &str) -> Option<&Arc<dyn LogSampler>> {
        self.named_samplers.get(name)
    }
    /// Evaluate `context` with the default sampler.
    pub fn should_sample(&self, context: &SamplingContext) -> SamplingDecision {
        self.default_sampler.should_sample(context)
    }
    /// Evaluate `context` with the named sampler, falling back to the
    /// default sampler when `name` is not registered.
    pub fn should_sample_with(&self, name: &str, context: &SamplingContext) -> SamplingDecision {
        self.named_samplers
            .get(name)
            .unwrap_or(&self.default_sampler)
            .should_sample(context)
    }
    /// Reset the default sampler and every registered sampler.
    pub fn reset_all(&self) {
        self.default_sampler.reset();
        self.named_samplers.values().for_each(|s| s.reset());
    }
}
impl Default for LogSamplingManager {
    /// Default manager records everything (default sampler is `AlwaysSampler`,
    /// no named samplers registered).
    fn default() -> Self {
        Self::new(Arc::new(AlwaysSampler))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for each sampler. Probabilistic tests use 100 trials with
    //! wide acceptance bands (30-70 at p=0.5), so statistical flakiness is
    //! extremely unlikely but not strictly impossible.
    use super::*;
    #[test]
    fn test_always_sampler() {
        let sampler = AlwaysSampler;
        let context = SamplingContext::new(LogLevel::Info);
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
    }
    #[test]
    fn test_never_sampler() {
        let sampler = NeverSampler;
        let context = SamplingContext::new(LogLevel::Info);
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Drop);
    }
    #[test]
    fn test_probabilistic_sampler() {
        // At rate 0.5, expect roughly half of 100 trials recorded.
        let sampler = ProbabilisticSampler::new(0.5);
        let context = SamplingContext::new(LogLevel::Info);
        let mut records = 0;
        let mut _drops = 0;
        for _ in 0..100 {
            match sampler.should_sample(&context) {
                SamplingDecision::Record => records += 1,
                SamplingDecision::Drop => _drops += 1,
            }
        }
        assert!(records > 30 && records < 70);
        assert!(_drops > 30 && _drops < 70);
    }
    #[test]
    fn test_probabilistic_sampler_zero_rate() {
        // Rate 0.0 must never record.
        let sampler = ProbabilisticSampler::new(0.0);
        let context = SamplingContext::new(LogLevel::Info);
        for _ in 0..10 {
            assert_eq!(sampler.should_sample(&context), SamplingDecision::Drop);
        }
    }
    #[test]
    fn test_probabilistic_sampler_full_rate() {
        // Rate 1.0 must always record.
        let sampler = ProbabilisticSampler::new(1.0);
        let context = SamplingContext::new(LogLevel::Info);
        for _ in 0..10 {
            assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
        }
    }
    #[test]
    fn test_rate_limited_sampler() {
        // First 5 events fit in the window; the next 3 are rate-limited.
        let sampler = RateLimitedSampler::new(5, Duration::from_secs(1));
        let context = SamplingContext::new(LogLevel::Info);
        for _ in 0..5 {
            assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
        }
        for _ in 0..3 {
            assert_eq!(sampler.should_sample(&context), SamplingDecision::Drop);
        }
    }
    #[test]
    fn test_priority_based_sampler() {
        // min_priority 7: priority 8 passes, priority 5 is dropped.
        let sampler = PriorityBasedSampler::new(7);
        let high_priority = SamplingContext::new(LogLevel::Info).with_priority(8);
        let low_priority = SamplingContext::new(LogLevel::Info).with_priority(5);
        assert_eq!(
            sampler.should_sample(&high_priority),
            SamplingDecision::Record
        );
        assert_eq!(sampler.should_sample(&low_priority), SamplingDecision::Drop);
    }
    #[test]
    fn test_priority_based_sampler_with_rates() {
        // Priority 7 has a 0.5 rate override: expect roughly half recorded.
        let sampler = PriorityBasedSampler::new(5).with_priority_rate(7, 0.5);
        let context = SamplingContext::new(LogLevel::Info).with_priority(7);
        let mut records = 0;
        let mut _drops = 0;
        for _ in 0..100 {
            match sampler.should_sample(&context) {
                SamplingDecision::Record => records += 1,
                SamplingDecision::Drop => _drops += 1,
            }
        }
        assert!(records > 30 && records < 70);
    }
    #[test]
    fn test_error_aware_sampler() {
        // Errors bypass sampling; Info-level events sample at 0.1.
        let sampler = ErrorAwareSampler::new(0.1, true);
        let error_context = SamplingContext::new(LogLevel::Error).with_error(true);
        let normal_context = SamplingContext::new(LogLevel::Info);
        assert_eq!(
            sampler.should_sample(&error_context),
            SamplingDecision::Record
        );
        let mut records = 0;
        for _ in 0..100 {
            if sampler.should_sample(&normal_context) == SamplingDecision::Record {
                records += 1;
            }
        }
        assert!(records > 0 && records < 30);
    }
    #[test]
    fn test_error_aware_sampler_log_levels() {
        // Error/Warn levels are always recorded even without the error flag.
        let sampler = ErrorAwareSampler::new(0.1, false);
        let error_context = SamplingContext::new(LogLevel::Error);
        let warn_context = SamplingContext::new(LogLevel::Warn);
        assert_eq!(
            sampler.should_sample(&error_context),
            SamplingDecision::Record
        );
        assert_eq!(
            sampler.should_sample(&warn_context),
            SamplingDecision::Record
        );
    }
    #[test]
    fn test_tail_sampler() {
        // Head decision is always Record; tail decision comes later.
        let sampler = TailSampler::new(100);
        let context = SamplingContext::new(LogLevel::Info)
            .with_trace_id("trace-123".to_string())
            .with_error(false);
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
    }
    #[test]
    fn test_adaptive_sampler() {
        // At base rate 0.5, at least one of 100 trials should record.
        let sampler = AdaptiveSampler::new(0.5, 100.0);
        let context = SamplingContext::new(LogLevel::Info);
        let mut records = 0;
        for _ in 0..100 {
            if sampler.should_sample(&context) == SamplingDecision::Record {
                records += 1;
            }
        }
        assert!(records > 0);
    }
    #[test]
    fn test_composite_sampler_all() {
        // All: Always AND p=0.5 behaves like p=0.5.
        let sampler = CompositeSampler::new(CompositeStrategy::All)
            .add_sampler(Arc::new(AlwaysSampler))
            .add_sampler(Arc::new(ProbabilisticSampler::new(0.5)));
        let context = SamplingContext::new(LogLevel::Info);
        let mut records = 0;
        for _ in 0..100 {
            if sampler.should_sample(&context) == SamplingDecision::Record {
                records += 1;
            }
        }
        assert!(records > 30 && records < 70);
    }
    #[test]
    fn test_composite_sampler_any() {
        // Any: Always OR Never records.
        let sampler = CompositeSampler::new(CompositeStrategy::Any)
            .add_sampler(Arc::new(AlwaysSampler))
            .add_sampler(Arc::new(NeverSampler));
        let context = SamplingContext::new(LogLevel::Info);
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
    }
    #[test]
    fn test_composite_sampler_first_match() {
        // FirstMatch: records on the first child that records (here the
        // second child), pinning that FirstMatch behaves like Any.
        let sampler = CompositeSampler::new(CompositeStrategy::FirstMatch)
            .add_sampler(Arc::new(NeverSampler))
            .add_sampler(Arc::new(AlwaysSampler));
        let context = SamplingContext::new(LogLevel::Info);
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
    }
    #[test]
    fn test_sampling_context_builder() {
        // Every builder method sets its field as expected.
        let context = SamplingContext::new(LogLevel::Info)
            .with_operation("query".to_string())
            .with_request_id("req-123".to_string())
            .with_error(true)
            .with_priority(8)
            .with_trace_id("trace-456".to_string())
            .with_user_id("user-789".to_string())
            .with_tag("env".to_string(), "prod".to_string());
        assert_eq!(context.operation, Some("query".to_string()));
        assert_eq!(context.request_id, Some("req-123".to_string()));
        assert!(context.is_error);
        assert_eq!(context.priority, 8);
        assert_eq!(context.trace_id, Some("trace-456".to_string()));
        assert_eq!(context.user_id, Some("user-789".to_string()));
        assert_eq!(context.tags.get("env"), Some(&"prod".to_string()));
    }
    #[test]
    fn test_log_sampling_manager() {
        // Named lookup works; unknown names fall back to the default sampler.
        let mut manager = LogSamplingManager::new(Arc::new(AlwaysSampler));
        manager.register_sampler("never".to_string(), Arc::new(NeverSampler));
        manager.register_sampler(
            "probabilistic".to_string(),
            Arc::new(ProbabilisticSampler::new(0.5)),
        );
        let context = SamplingContext::new(LogLevel::Info);
        assert_eq!(manager.should_sample(&context), SamplingDecision::Record);
        assert_eq!(
            manager.should_sample_with("never", &context),
            SamplingDecision::Drop
        );
        assert_eq!(
            manager.should_sample_with("unknown", &context),
            SamplingDecision::Record
        );
    }
    #[test]
    fn test_rate_limited_sampler_reset() {
        // After reset the window is empty again and recording resumes.
        let sampler = RateLimitedSampler::new(3, Duration::from_secs(1));
        let context = SamplingContext::new(LogLevel::Info);
        for _ in 0..3 {
            sampler.should_sample(&context);
        }
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Drop);
        sampler.reset();
        assert_eq!(sampler.should_sample(&context), SamplingDecision::Record);
    }
    #[test]
    fn test_sampler_statistics() {
        // 75 recorded out of 100 evaluated yields a 0.75 rate.
        let mut stats = SamplerStatistics {
            total_evaluated: 100,
            total_recorded: 75,
            total_dropped: 25,
            ..Default::default()
        };
        stats.calculate_rate();
        assert_eq!(stats.sampling_rate, 0.75);
    }
    #[test]
    fn test_priority_clamping() {
        // Priorities above 10 are clamped down to 10.
        let context = SamplingContext::new(LogLevel::Info).with_priority(15);
        assert_eq!(context.priority, 10);
    }
    #[test]
    fn test_probabilistic_rate_clamping() {
        // Rates outside [0, 1] are clamped by the constructor.
        let sampler_high = ProbabilisticSampler::new(1.5);
        let sampler_low = ProbabilisticSampler::new(-0.5);
        assert!(sampler_high.rate <= 1.0);
        assert!(sampler_low.rate >= 0.0);
    }
}