use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::time::Instant;
/// Default cap on events admitted per one-second window.
pub const DEFAULT_MAX_EVENTS_PER_SECOND: u64 = 1000;
/// Minimum time between "dropped events" summary messages on stderr.
const DROPPED_SUMMARY_INTERVAL: Duration = Duration::from_secs(10);
/// Shared state behind the `Arc` in [`RateLimiter`].
///
/// All counters are atomics so `should_allow` can be called from many
/// threads without locking; rollover and reporting use best-effort CAS,
/// so counts can be slightly off under contention (see `should_allow`).
struct RateLimiterInner {
    /// Maximum events admitted per one-second window.
    max_events_per_second: u64,
    /// Events observed in the current window (zeroed on window rollover).
    event_count: AtomicU64,
    /// Events dropped since the last summary was printed.
    dropped_count: AtomicU64,
    /// Events dropped over the limiter's whole lifetime (never reset).
    total_dropped: AtomicU64,
    /// Start of the current window, in nanoseconds since `start_instant`.
    window_start_nanos: AtomicU64,
    /// Time of the last printed summary, in nanoseconds since `start_instant`.
    last_summary_nanos: AtomicU64,
    /// Reference point for the `*_nanos` fields above.
    start_instant: Instant,
}
/// A cheaply cloneable, thread-safe fixed-window rate limiter.
///
/// Clones share the same underlying counters via `Arc`, so rate limiting
/// applies across all clones collectively.
#[derive(Clone)]
pub struct RateLimiter {
    inner: Arc<RateLimiterInner>,
}
impl RateLimiter {
    /// Creates a limiter admitting up to `max_events_per_second` events
    /// per one-second window.
    pub fn new(max_events_per_second: u64) -> Self {
        let now = Instant::now();
        Self {
            inner: Arc::new(RateLimiterInner {
                max_events_per_second,
                event_count: AtomicU64::new(0),
                dropped_count: AtomicU64::new(0),
                total_dropped: AtomicU64::new(0),
                window_start_nanos: AtomicU64::new(0),
                last_summary_nanos: AtomicU64::new(0),
                start_instant: now,
            }),
        }
    }
    /// Creates a limiter with [`DEFAULT_MAX_EVENTS_PER_SECOND`].
    pub fn with_defaults() -> Self {
        Self::new(DEFAULT_MAX_EVENTS_PER_SECOND)
    }
    /// Nanoseconds elapsed since this limiter was constructed.
    /// The `as u64` narrowing is safe for roughly 584 years of uptime.
    fn elapsed_nanos(&self) -> u64 {
        self.inner.start_instant.elapsed().as_nanos() as u64
    }
    /// Records one event and reports whether it fits in the current
    /// window's budget. Returns `false` (caller should drop the event)
    /// once `max_events_per_second` events have been counted this window.
    ///
    /// Lock-free: window rollover uses a CAS so exactly one caller zeroes
    /// the counter. Increments racing with that reset can be clobbered,
    /// so admission is approximate under heavy contention (by design).
    pub fn should_allow(&self) -> bool {
        let now_nanos = self.elapsed_nanos();
        let window_start = self.inner.window_start_nanos.load(Ordering::Relaxed);
        const ONE_SECOND_NANOS: u64 = 1_000_000_000;
        // Roll the window forward once at least one second has passed.
        if now_nanos >= window_start + ONE_SECOND_NANOS {
            // Only the CAS winner zeroes the counter; losers observed a
            // concurrent rollover and just proceed in the fresh window.
            if self
                .inner
                .window_start_nanos
                .compare_exchange(window_start, now_nanos, Ordering::SeqCst, Ordering::Relaxed)
                .is_ok()
            {
                self.inner.event_count.store(0, Ordering::Relaxed);
            }
        }
        // fetch_add returns the pre-increment count for this event.
        let count = self.inner.event_count.fetch_add(1, Ordering::Relaxed);
        if count < self.inner.max_events_per_second {
            self.maybe_log_summary();
            true
        } else {
            self.inner.dropped_count.fetch_add(1, Ordering::Relaxed);
            self.inner.total_dropped.fetch_add(1, Ordering::Relaxed);
            self.maybe_log_summary();
            false
        }
    }
    /// Prints a dropped-events summary to stderr, at most once per
    /// [`DROPPED_SUMMARY_INTERVAL`] and only when something was dropped.
    ///
    /// NOTE(review): `last_summary_nanos` starts at 0, so the very first
    /// summary cannot appear until one full interval after construction.
    /// The CAS makes one caller the designated reporter; it swaps
    /// `dropped_count` back to zero so each drop is reported only once.
    fn maybe_log_summary(&self) {
        let dropped = self.inner.dropped_count.load(Ordering::Relaxed);
        if dropped == 0 {
            return;
        }
        let now_nanos = self.elapsed_nanos();
        let last_summary = self.inner.last_summary_nanos.load(Ordering::Relaxed);
        let interval_nanos = DROPPED_SUMMARY_INTERVAL.as_nanos() as u64;
        if now_nanos >= last_summary + interval_nanos {
            if self
                .inner
                .last_summary_nanos
                .compare_exchange(last_summary, now_nanos, Ordering::SeqCst, Ordering::Relaxed)
                .is_ok()
            {
                // Re-read via swap: take the counter and reset it atomically.
                let dropped = self.inner.dropped_count.swap(0, Ordering::Relaxed);
                let total = self.inner.total_dropped.load(Ordering::Relaxed);
                if dropped > 0 {
                    eprintln!(
                        "[RATE LIMIT] Dropped {} log events in last {}s (total dropped: {}). \
                        Max rate: {}/sec. Consider investigating log spam.",
                        dropped,
                        DROPPED_SUMMARY_INTERVAL.as_secs(),
                        total,
                        self.inner.max_events_per_second
                    );
                }
            }
        }
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let limiter = RateLimiter::new(100);
        // Every call up to the cap is admitted.
        assert!((0..100).all(|_| limiter.should_allow()));
    }

    #[test]
    fn test_rate_limiter_drops_over_limit() {
        let limiter = RateLimiter::new(10);
        // First 10 events pass; the next 10 are rejected and counted.
        assert!((0..10).all(|_| limiter.should_allow()));
        assert!((0..10).all(|_| !limiter.should_allow()));
        assert_eq!(limiter.inner.dropped_count.load(Ordering::Relaxed), 10);
        assert_eq!(limiter.inner.total_dropped.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn test_rate_limiter_resets_after_window() {
        let limiter = RateLimiter::new(5);
        assert!((0..5).all(|_| limiter.should_allow()));
        assert!(!limiter.should_allow());
        // Simulate a window rollover by zeroing the counter directly.
        limiter.inner.event_count.store(0, Ordering::Relaxed);
        assert!((0..5).all(|_| limiter.should_allow()));
        assert!(!limiter.should_allow());
    }

    #[test]
    fn test_default_limit() {
        let limiter = RateLimiter::with_defaults();
        assert_eq!(
            limiter.inner.max_events_per_second,
            DEFAULT_MAX_EVENTS_PER_SECOND
        );
    }

    #[test]
    fn test_rate_limiter_is_cloneable() {
        let original = RateLimiter::new(10);
        let cloned = original.clone();
        assert!((0..10).all(|_| original.should_allow()));
        // Clones share state via Arc, so the clone is already over budget.
        assert!(!cloned.should_allow());
    }
}