use std::time::Duration;
use std::time::Instant;
use dashmap::DashMap;
use tracing::Subscriber;
use tracing::callsite::Identifier;
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context;
/// Per-callsite bookkeeping kept by `RateLimitLayer`.
///
/// The `Default` value describes a callsite that has never been seen.
#[derive(Default)]
struct RateLimitState {
    /// Events from this callsite that were rejected since it was last allowed.
    suppressed_count: u64,
    /// When this callsite last had an event allowed through; `None` until the first one.
    last_logged: Option<Instant>,
}
/// A `tracing_subscriber` layer that throttles repeated WARN/ERROR events
/// coming from targets under a configurable prefix.
pub(crate) struct RateLimitLayer {
    /// Minimum time that must pass between two allowed events from one callsite.
    threshold: Duration,
    /// Only events whose target starts with this prefix are subject to throttling.
    target_prefix: &'static str,
    /// Throttling state for every callsite seen so far, keyed by callsite identity.
    states: DashMap<Identifier, RateLimitState>,
}
impl RateLimitLayer {
    /// Creates a layer that throttles events whose target starts with
    /// `target_prefix`, allowing at most one event per callsite every `threshold`.
    pub(crate) fn new(target_prefix: &'static str, threshold: Duration) -> Self {
        Self {
            target_prefix,
            states: DashMap::new(),
            threshold,
        }
    }

    /// Sum of the suppressed counts across every tracked callsite (test helper).
    #[cfg(test)]
    fn suppressed_count(&self) -> u64 {
        self.states.iter().map(|e| e.suppressed_count).sum()
    }

    /// Records one occurrence of `callsite` and returns whether it may be logged.
    ///
    /// The first occurrence of a callsite is always allowed. Later occurrences are
    /// allowed only once at least `threshold` has elapsed since the last *allowed*
    /// one. Allowing an event resets that callsite's suppressed count; rejecting it
    /// increments the count.
    fn is_allowed(&self, callsite: Identifier) -> bool {
        // `entry` takes the key by value and `callsite` is not used afterwards,
        // so it is moved in rather than cloned.
        let mut entry = self.states.entry(callsite).or_default();
        let state = entry.value_mut();
        // Read the clock only while holding the entry guard. Otherwise a thread
        // racing on the same callsite could store a `last_logged` later than our
        // `now`. The saturating subtraction makes that case explicit: it yields
        // zero instead of relying on `duration_since`'s version-dependent behaviour.
        let now = Instant::now();
        let allowed = match state.last_logged {
            None => true,
            Some(last) => now.saturating_duration_since(last) >= self.threshold,
        };
        if allowed {
            state.last_logged = Some(now);
            state.suppressed_count = 0;
        } else {
            state.suppressed_count += 1;
        }
        allowed
    }
}
impl<S> Layer<S> for RateLimitLayer
where
    S: Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
{
    /// Filters events before they are recorded.
    ///
    /// An event is exempt (always enabled) unless its target begins with
    /// `target_prefix` *and* it is WARN or more severe. Only non-exempt events
    /// are subject to the per-callsite rate limit.
    fn event_enabled(&self, event: &tracing::Event<'_>, _ctx: Context<'_, S>) -> bool {
        let meta = event.metadata();
        // `tracing::Level` orders more verbose levels as greater, so anything
        // above WARN (INFO, DEBUG, TRACE) is never throttled.
        let exempt = !meta.target().starts_with(self.target_prefix)
            || *meta.level() > tracing::Level::WARN;
        exempt || self.is_allowed(meta.callsite())
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use tracing::Subscriber;
    use tracing_subscriber::Layer;
    use tracing_subscriber::layer::Context;
    use tracing_subscriber::layer::SubscriberExt;

    use super::*;

    /// Test layer that counts every event it receives.
    struct CountingLayer {
        count: Arc<AtomicUsize>,
    }

    impl<S: Subscriber> Layer<S> for CountingLayer {
        fn on_event(&self, _event: &tracing::Event<'_>, _ctx: Context<'_, S>) {
            self.count.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Runs `emit` under a subscriber that stacks a `RateLimitLayer` (prefix
    /// `"opentelemetry"`, the given `threshold`) beneath a `CountingLayer`, and
    /// returns how many events reached the counting layer.
    fn delivered_events(threshold: Duration, emit: impl FnOnce()) -> usize {
        let count = Arc::new(AtomicUsize::new(0));
        let subscriber = tracing_subscriber::registry()
            .with(RateLimitLayer::new("opentelemetry", threshold))
            .with(CountingLayer {
                count: Arc::clone(&count),
            });
        tracing::subscriber::with_default(subscriber, emit);
        count.load(Ordering::Relaxed)
    }

    // A hand-built callsite so `is_allowed` can be driven directly, without a dispatcher.
    static TEST_CALLSITE: tracing_core::callsite::DefaultCallsite =
        tracing_core::callsite::DefaultCallsite::new(&TEST_META);
    static TEST_META: tracing_core::Metadata<'static> = tracing_core::metadata! {
        name: "test",
        target: "opentelemetry::test",
        level: tracing_core::Level::WARN,
        fields: &[],
        callsite: &TEST_CALLSITE,
        kind: tracing_core::metadata::Kind::EVENT,
    };

    #[test]
    fn test_rate_limiting_suppresses_rapid_messages() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..10 {
                tracing::warn!(target: "opentelemetry::trace::exporter", "export failed");
            }
        });
        assert_eq!(delivered, 1);
    }

    #[test]
    fn test_different_callsites_not_suppressed() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            tracing::warn!(target: "opentelemetry::trace::exporter", "trace error");
            tracing::warn!(target: "opentelemetry::metrics::exporter", "metric error");
            tracing::warn!(target: "opentelemetry::other", "other error");
        });
        assert_eq!(delivered, 3);
    }

    #[test]
    fn test_non_otel_messages_not_affected() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..10 {
                tracing::warn!(target: "apollo_router", "normal message");
            }
        });
        assert_eq!(delivered, 10);
    }

    #[test]
    fn test_messages_allowed_after_threshold() {
        // One callsite, so every call shares a single rate-limit slot.
        fn emit_otel_warning() {
            tracing::warn!(target: "opentelemetry::trace", "message");
        }
        let delivered = delivered_events(Duration::from_millis(50), || {
            emit_otel_warning();
            emit_otel_warning();
            std::thread::sleep(Duration::from_millis(60));
            emit_otel_warning();
        });
        assert_eq!(delivered, 2);
    }

    #[test]
    fn test_debug_level_not_rate_limited() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..5 {
                tracing::debug!(target: "opentelemetry::trace", "debug message");
            }
        });
        assert_eq!(delivered, 5);
    }

    #[test]
    fn test_warn_level_is_rate_limited() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..10 {
                tracing::warn!(target: "opentelemetry::trace::exporter", "export warning");
            }
        });
        assert_eq!(delivered, 1);
    }

    #[test]
    fn test_error_level_is_rate_limited() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..10 {
                tracing::error!(target: "opentelemetry::trace::exporter", "export error");
            }
        });
        assert_eq!(delivered, 1);
    }

    #[test]
    fn test_info_level_not_rate_limited() {
        let delivered = delivered_events(Duration::from_millis(100), || {
            for _ in 0..5 {
                tracing::info!(target: "opentelemetry::trace", "info message");
            }
        });
        assert_eq!(delivered, 5);
    }

    #[test]
    fn test_suppression_count_tracked() {
        let rate_limiter = RateLimitLayer::new("opentelemetry", Duration::from_millis(100));
        let callsite = TEST_META.callsite();
        assert!(rate_limiter.is_allowed(callsite.clone()));
        assert_eq!(rate_limiter.suppressed_count(), 0);
        for i in 1..=9 {
            assert!(!rate_limiter.is_allowed(callsite.clone()));
            assert_eq!(rate_limiter.suppressed_count(), i);
        }
    }

    #[test]
    fn test_suppression_count_resets_once_allowed_again() {
        let rate_limiter = RateLimitLayer::new("opentelemetry", Duration::from_millis(50));
        let callsite = TEST_META.callsite();
        assert!(rate_limiter.is_allowed(callsite.clone()));
        assert!(!rate_limiter.is_allowed(callsite.clone()));
        assert_eq!(rate_limiter.suppressed_count(), 1);
        std::thread::sleep(Duration::from_millis(60));
        assert!(rate_limiter.is_allowed(callsite));
        assert_eq!(rate_limiter.suppressed_count(), 0);
    }
}