use dashmap::DashMap;
use parking_lot::Mutex;
use std::time::{Duration, Instant};
use tracing::{debug, warn};
/// Configuration for a rate limiter: steady-state rate, burst capacity,
/// per-client vs. global enforcement, and idle-bucket retention.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Steady-state refill rate, in tokens (requests) per second.
    pub requests_per_second: f64,
    /// Maximum number of requests a single client may burst.
    pub burst_size: u32,
    /// When true, each client id gets its own token bucket.
    pub per_client: bool,
    /// Optional request cap shared across all clients.
    pub global_limit: Option<u32>,
    /// How long an idle per-client bucket is retained before cleanup.
    pub idle_timeout: Duration,
}

impl RateLimiterConfig {
    /// Creates a config with per-client limiting enabled, no global
    /// limit, and a 5-minute idle timeout.
    pub fn new(requests_per_second: f64, burst_size: u32) -> Self {
        Self {
            requests_per_second,
            burst_size,
            per_client: true,
            global_limit: None,
            idle_timeout: Duration::from_secs(300),
        }
    }

    /// Enables or disables per-client bucketing.
    #[must_use]
    pub fn with_per_client(self, per_client: bool) -> Self {
        Self { per_client, ..self }
    }

    /// Sets a global request cap shared by all clients.
    #[must_use]
    pub fn with_global_limit(self, limit: u32) -> Self {
        Self {
            global_limit: Some(limit),
            ..self
        }
    }

    /// Sets how long idle client buckets are retained.
    #[must_use]
    pub fn with_idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: timeout,
            ..self
        }
    }
}

impl Default for RateLimiterConfig {
    /// 100 requests per second with a burst of 50.
    fn default() -> Self {
        Self::new(100.0, 50)
    }
}
/// Error returned when a request exceeds a rate limit.
///
/// Both variants carry a suggested retry delay in milliseconds.
#[derive(Debug, Clone)]
pub enum RateLimitError {
    /// The shared global budget is exhausted.
    GlobalLimitExceeded { retry_after_ms: u64 },
    /// A specific client's budget is exhausted.
    ClientLimitExceeded {
        client_id: String,
        retry_after_ms: u64,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::GlobalLimitExceeded { retry_after_ms } => write!(
                f,
                "Global rate limit exceeded, retry after {}ms",
                retry_after_ms
            ),
            Self::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            } => write!(
                f,
                "Rate limit exceeded for client '{}', retry after {}ms",
                client_id, retry_after_ms
            ),
        }
    }
}

impl std::error::Error for RateLimitError {}
/// A token-bucket primitive: tokens refill continuously at `refill_rate`
/// per second up to `max_tokens`, and each permitted request consumes one.
#[derive(Debug)]
pub struct TokenBucket {
    // Current token count; fractional between refills.
    tokens: f64,
    // Bucket capacity.
    max_tokens: f64,
    // Tokens added per second.
    refill_rate: f64,
    // Last time `tokens` was brought up to date; also doubles as a
    // "last touched" timestamp (see `last_access`).
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens earned since the last refill, capped at capacity.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Attempts to consume one token; returns whether it succeeded.
    pub fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whole tokens currently available (refills first).
    pub fn remaining(&mut self) -> u32 {
        self.refill();
        self.tokens.floor().max(0.0) as u32
    }

    /// Milliseconds until at least one token will be available.
    ///
    /// Returns 0 when a token is already available and `u64::MAX` when the
    /// bucket can never refill (`refill_rate <= 0`).
    pub fn retry_after_ms(&self) -> u64 {
        if self.refill_rate <= 0.0 {
            // A non-refilling bucket can still be satisfiable if a token
            // is already banked.
            return if self.tokens >= 1.0 { 0 } else { u64::MAX };
        }
        // Bug fix: account for tokens earned since `last_refill` without
        // mutating state. The previous version used the stale `tokens`
        // field, overestimating the wait when time had passed since the
        // last refill.
        let elapsed = self.last_refill.elapsed().as_secs_f64();
        let available = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        if available >= 1.0 {
            return 0;
        }
        let seconds = (1.0 - available) / self.refill_rate;
        (seconds * 1000.0).ceil() as u64
    }

    /// Refills the bucket to capacity and resets the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.max_tokens;
        self.last_refill = Instant::now();
    }

    /// Timestamp of the last refill/reset; since `refill` runs on every
    /// acquire/remaining call, this serves as a last-use marker for idle
    /// cleanup.
    pub fn last_access(&self) -> Instant {
        self.last_refill
    }
}
/// Thread-safe rate limiter combining optional per-client token buckets
/// with an optional global bucket.
pub struct RateLimiter {
    // Limits and timeouts used when creating buckets.
    config: RateLimiterConfig,
    // Bucket enforcing the optional global limit across all clients.
    global_bucket: Mutex<TokenBucket>,
    // One bucket per client id, created lazily on first request.
    client_buckets: DashMap<String, Mutex<TokenBucket>>,
}
impl RateLimiter {
    /// Creates a limiter from `config`.
    pub fn new(config: RateLimiterConfig) -> Self {
        // The global bucket's capacity and refill rate share one value:
        // `global_limit` when set, otherwise 2x requests_per_second. The
        // previous version computed the identical expression twice and
        // needlessly cloned `config` (only Copy fields are read above).
        // The fallback is effectively inert: `check_rate_limit` consults
        // the global bucket only when `global_limit` is Some.
        let global_capacity = config
            .global_limit
            .map(f64::from)
            .unwrap_or_else(|| config.requests_per_second * 2.0);
        Self {
            global_bucket: Mutex::new(TokenBucket::new(global_capacity, global_capacity)),
            client_buckets: DashMap::new(),
            config,
        }
    }

    /// Checks whether `client_id` may make a request, consuming one token
    /// from each applicable bucket on success.
    ///
    /// The per-client bucket (when `per_client` is set) is checked first,
    /// then the global bucket (when `global_limit` is set).
    ///
    /// NOTE(review): when both limits are active, a per-client token is
    /// consumed even if the global check then rejects the request; the
    /// test suite relies on this ordering, so it is preserved.
    ///
    /// # Errors
    ///
    /// [`RateLimitError::ClientLimitExceeded`] or
    /// [`RateLimitError::GlobalLimitExceeded`], each carrying a suggested
    /// retry delay in milliseconds.
    pub fn check_rate_limit(&self, client_id: &str) -> Result<(), RateLimitError> {
        if self.config.per_client {
            // Lazily create this client's bucket, sized by `burst_size`
            // and refilled at `requests_per_second`.
            let bucket = self
                .client_buckets
                .entry(client_id.to_string())
                .or_insert_with(|| {
                    Mutex::new(TokenBucket::new(
                        f64::from(self.config.burst_size),
                        self.config.requests_per_second,
                    ))
                });
            let mut bucket_guard = bucket.lock();
            if !bucket_guard.try_acquire() {
                let retry_after_ms = bucket_guard.retry_after_ms();
                debug!(
                    client_id = %client_id,
                    retry_after_ms = retry_after_ms,
                    "Per-client rate limit exceeded"
                );
                return Err(RateLimitError::ClientLimitExceeded {
                    client_id: client_id.to_string(),
                    retry_after_ms,
                });
            }
        }
        if self.config.global_limit.is_some() {
            let mut global = self.global_bucket.lock();
            if !global.try_acquire() {
                let retry_after_ms = global.retry_after_ms();
                warn!(
                    client_id = %client_id,
                    retry_after_ms = retry_after_ms,
                    "Global rate limit exceeded"
                );
                return Err(RateLimitError::GlobalLimitExceeded { retry_after_ms });
            }
        }
        Ok(())
    }

    /// Convenience wrapper: true when `check_rate_limit` succeeds.
    pub fn try_acquire(&self, client_id: &str) -> bool {
        self.check_rate_limit(client_id).is_ok()
    }

    /// Whole tokens currently available to `client_id`.
    ///
    /// In per-client mode an untracked client reports a full bucket
    /// (`burst_size`); otherwise the global bucket's count is returned.
    pub fn remaining_tokens(&self, client_id: &str) -> u32 {
        if self.config.per_client {
            if let Some(bucket) = self.client_buckets.get(client_id) {
                return bucket.lock().remaining();
            }
            return self.config.burst_size;
        }
        self.global_bucket.lock().remaining()
    }

    /// Removes client buckets idle for longer than `config.idle_timeout`
    /// and returns how many were removed.
    pub fn cleanup_expired_buckets(&self) -> usize {
        let now = Instant::now();
        let timeout = self.config.idle_timeout;
        let mut removed = 0;
        // Pass 1: scan for expired candidates without removing anything,
        // so each entry is only locked briefly.
        let expired_keys: Vec<String> = self
            .client_buckets
            .iter()
            .filter_map(|entry| {
                let bucket = entry.value().lock();
                if now.duration_since(bucket.last_access()) > timeout {
                    Some(entry.key().clone())
                } else {
                    None
                }
            })
            .collect();
        // Pass 2: remove, then re-check expiry. A request may have touched
        // the bucket between the scan and the remove; if so, put the
        // bucket back rather than dropping live state.
        for key in &expired_keys {
            if let Some((_k, bucket)) = self.client_buckets.remove(key) {
                let guard = bucket.lock();
                if now.duration_since(guard.last_access()) > timeout {
                    removed += 1;
                    debug!(client_id = %key, "Cleaned up expired rate limiter bucket");
                } else {
                    drop(guard);
                    self.client_buckets.insert(key.clone(), bucket);
                }
            }
        }
        if removed > 0 {
            debug!(count = removed, "Cleaned up expired rate limiter buckets");
        }
        removed
    }

    /// Refills `client_id`'s bucket to capacity, if it is tracked.
    pub fn reset(&self, client_id: &str) {
        if let Some(bucket) = self.client_buckets.get(client_id) {
            bucket.lock().reset();
        }
    }

    /// Refills the global bucket and forgets all client buckets.
    pub fn reset_all(&self) {
        self.global_bucket.lock().reset();
        self.client_buckets.clear();
    }

    /// Number of client buckets currently tracked.
    pub fn tracked_client_count(&self) -> usize {
        self.client_buckets.len()
    }

    /// The configuration this limiter was built with.
    pub fn config(&self) -> &RateLimiterConfig {
        &self.config
    }
}
/// Manual `Debug`: emits only summary data (the config and the number of
/// tracked clients) rather than the bucket contents.
impl std::fmt::Debug for RateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RateLimiter");
        builder.field("config", &self.config);
        builder.field("tracked_clients", &self.client_buckets.len());
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // A full bucket yields exactly `max_tokens` acquisitions, then fails.
    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(5.0, 10.0);
        for _ in 0..5 {
            assert!(
                bucket.try_acquire(),
                "Should acquire token from full bucket"
            );
        }
        assert!(!bucket.try_acquire(), "Should fail when bucket is depleted");
    }

    // At 100 tokens/s, 25ms is enough to earn back at least one token.
    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(3.0, 100.0);
        for _ in 0..3 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire(), "Bucket should be empty");
        thread::sleep(Duration::from_millis(25));
        assert!(
            bucket.try_acquire(),
            "Should have refilled at least one token after 25ms at 100/s"
        );
    }

    // `remaining` reports whole tokens and tracks acquisitions.
    #[test]
    fn test_token_bucket_remaining() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert_eq!(bucket.remaining(), 10);
        assert!(bucket.try_acquire());
        assert_eq!(bucket.remaining(), 9);
    }

    // With a deficit of ~1 token at 10/s, retry should be ~100ms
    // (upper bound 110 allows for a little scheduling slack).
    #[test]
    fn test_token_bucket_retry_after() {
        let mut bucket = TokenBucket::new(1.0, 10.0);
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
        let retry = bucket.retry_after_ms();
        assert!(
            retry <= 110,
            "retry_after_ms should be approximately 100ms, got {}",
            retry
        );
        assert!(retry > 0, "retry_after_ms should be > 0 when depleted");
    }

    // `reset` returns a depleted bucket to full capacity.
    #[test]
    fn test_token_bucket_reset() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire());
        bucket.reset();
        assert_eq!(bucket.remaining(), 5);
        assert!(bucket.try_acquire());
    }

    // Exhausting client A's bucket must not affect client B.
    #[test]
    fn test_per_client_isolation() {
        let config = RateLimiterConfig::new(1000.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);
        for _ in 0..3 {
            assert!(limiter.check_rate_limit("client-a").is_ok());
        }
        assert!(
            limiter.check_rate_limit("client-a").is_err(),
            "Client A should be rate limited"
        );
        assert!(
            limiter.check_rate_limit("client-b").is_ok(),
            "Client B should not be affected by Client A's limit"
        );
    }

    // With per-client disabled, the global limit of 3 spans all clients
    // and the error variant must be GlobalLimitExceeded.
    #[test]
    fn test_global_limit() {
        let config = RateLimiterConfig::new(1000.0, 10)
            .with_per_client(false)
            .with_global_limit(3);
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("client-a").is_ok());
        assert!(limiter.check_rate_limit("client-b").is_ok());
        assert!(limiter.check_rate_limit("client-c").is_ok());
        let result = limiter.check_rate_limit("client-d");
        assert!(result.is_err(), "Global limit should be enforced");
        match result {
            Err(RateLimitError::GlobalLimitExceeded { retry_after_ms }) => {
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected GlobalLimitExceeded, got {:?}", other),
        }
    }

    // A burst of 25 requests against burst_size 20 admits exactly 20
    // (rps of 10 cannot refill a whole token within the loop).
    #[test]
    fn test_burst_handling() {
        let config = RateLimiterConfig::new(10.0, 20).with_per_client(true);
        let limiter = RateLimiter::new(config);
        let mut allowed = 0;
        for _ in 0..25 {
            if limiter.check_rate_limit("burst-client").is_ok() {
                allowed += 1;
            }
        }
        assert_eq!(
            allowed, 20,
            "Should allow exactly burst_size requests in a burst"
        );
    }

    // Buckets idle past the 50ms timeout are removed by cleanup.
    #[test]
    fn test_cleanup_expired() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(50));
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("client-1").is_ok());
        assert!(limiter.check_rate_limit("client-2").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);
        thread::sleep(Duration::from_millis(80));
        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 2, "Both idle clients should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    // A recently-used bucket survives cleanup; only the idle one goes.
    #[test]
    fn test_cleanup_keeps_active() {
        let config = RateLimiterConfig::new(100.0, 5)
            .with_per_client(true)
            .with_idle_timeout(Duration::from_millis(100));
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("active-client").is_ok());
        thread::sleep(Duration::from_millis(30));
        assert!(limiter.check_rate_limit("active-client").is_ok());
        assert!(limiter.check_rate_limit("idle-client").is_ok());
        thread::sleep(Duration::from_millis(120));
        assert!(limiter.check_rate_limit("active-client").is_ok());
        let removed = limiter.cleanup_expired_buckets();
        assert_eq!(removed, 1, "Only idle client should be cleaned up");
        assert_eq!(limiter.tracked_client_count(), 1);
    }

    // Display output must include the key fields of each error variant.
    #[test]
    fn test_rate_limit_error_display() {
        let global_err = RateLimitError::GlobalLimitExceeded { retry_after_ms: 42 };
        let msg = format!("{}", global_err);
        assert!(msg.contains("Global rate limit exceeded"));
        assert!(msg.contains("42ms"));
        let client_err = RateLimitError::ClientLimitExceeded {
            client_id: "test-client".to_string(),
            retry_after_ms: 100,
        };
        let msg = format!("{}", client_err);
        assert!(msg.contains("test-client"));
        assert!(msg.contains("100ms"));
    }

    // The per-client error carries the offending client id and a
    // positive retry hint.
    #[test]
    fn test_rate_limit_error_details() {
        let config = RateLimiterConfig::new(10.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("err-client").is_ok());
        assert!(limiter.check_rate_limit("err-client").is_ok());
        let result = limiter.check_rate_limit("err-client");
        match result {
            Err(RateLimitError::ClientLimitExceeded {
                client_id,
                retry_after_ms,
            }) => {
                assert_eq!(client_id, "err-client");
                assert!(retry_after_ms > 0);
            }
            other => panic!("Expected ClientLimitExceeded, got {:?}", other),
        }
    }

    // 8 threads x 50 requests, each with its own client id and a burst
    // of 100: every request should pass, exercising DashMap concurrency.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        let config = RateLimiterConfig::new(1000.0, 100).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));
        let mut handles = Vec::new();
        for i in 0..8 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let client_id = format!("thread-client-{}", i);
                let mut allowed = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&client_id).is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }
        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }
        assert_eq!(
            total_allowed, 400,
            "All requests should be allowed (50 per thread * 8 threads)"
        );
    }

    // 4 threads hammer one shared client; rps of 0.001 makes refill
    // negligible, so exactly burst_size tokens are handed out in total.
    #[test]
    fn test_concurrent_same_client() {
        use std::sync::Arc;
        let config = RateLimiterConfig::new(0.001, 50).with_per_client(true);
        let limiter = Arc::new(RateLimiter::new(config));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let mut allowed = 0u32;
                for _ in 0..20 {
                    if limiter.check_rate_limit("shared-client").is_ok() {
                        allowed += 1;
                    }
                }
                allowed
            });
            handles.push(handle);
        }
        let mut total_allowed = 0u32;
        for handle in handles {
            let count = handle.join().expect("Thread panicked");
            total_allowed += count;
        }
        assert_eq!(
            total_allowed, 50,
            "Total allowed should equal burst_size for shared client"
        );
    }

    // `try_acquire` is a bool-returning shorthand for check_rate_limit.
    #[test]
    fn test_try_acquire_convenience() {
        let config = RateLimiterConfig::new(100.0, 2).with_per_client(true);
        let limiter = RateLimiter::new(config);
        assert!(limiter.try_acquire("client-x"));
        assert!(limiter.try_acquire("client-x"));
        assert!(!limiter.try_acquire("client-x"));
    }

    // An untracked client reports a full bucket; tracked clients report
    // their live balance.
    #[test]
    fn test_remaining_tokens() {
        let config = RateLimiterConfig::new(100.0, 5).with_per_client(true);
        let limiter = RateLimiter::new(config);
        assert_eq!(limiter.remaining_tokens("new-client"), 5);
        assert!(limiter.check_rate_limit("new-client").is_ok());
        assert_eq!(limiter.remaining_tokens("new-client"), 4);
    }

    // Resetting one client restores its full burst allowance.
    #[test]
    fn test_reset_client() {
        let config = RateLimiterConfig::new(100.0, 3).with_per_client(true);
        let limiter = RateLimiter::new(config);
        for _ in 0..3 {
            assert!(limiter.check_rate_limit("reset-client").is_ok());
        }
        assert!(limiter.check_rate_limit("reset-client").is_err());
        limiter.reset("reset-client");
        assert!(
            limiter.check_rate_limit("reset-client").is_ok(),
            "Should be able to make requests after reset"
        );
    }

    // `reset_all` drops every tracked client bucket.
    #[test]
    fn test_reset_all() {
        let config = RateLimiterConfig::new(100.0, 2)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("b").is_ok());
        assert_eq!(limiter.tracked_client_count(), 2);
        limiter.reset_all();
        assert_eq!(limiter.tracked_client_count(), 0);
    }

    // Default config values as documented on `RateLimiterConfig::default`.
    #[test]
    fn test_config_default() {
        let config = RateLimiterConfig::default();
        assert!((config.requests_per_second - 100.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 50);
        assert!(config.per_client);
        assert!(config.global_limit.is_none());
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
    }

    // All builder methods set their field and chain.
    #[test]
    fn test_config_builder_pattern() {
        let config = RateLimiterConfig::new(200.0, 100)
            .with_per_client(false)
            .with_global_limit(500)
            .with_idle_timeout(Duration::from_secs(60));
        assert!((config.requests_per_second - 200.0).abs() < f64::EPSILON);
        assert_eq!(config.burst_size, 100);
        assert!(!config.per_client);
        assert_eq!(config.global_limit, Some(500));
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
    }

    // The manual Debug impl emits the struct name and summary field.
    #[test]
    fn test_debug_impl() {
        let config = RateLimiterConfig::new(50.0, 10);
        let limiter = RateLimiter::new(config);
        let debug_str = format!("{:?}", limiter);
        assert!(debug_str.contains("RateLimiter"));
        assert!(debug_str.contains("tracked_clients"));
    }

    // Per-client (3) and global (5) limits interact: client A hits its
    // own cap first; client B's third request then trips the global cap
    // (A consumed 3 global tokens, B consumed 2). Note this relies on
    // the per-client check running before the global check.
    #[test]
    fn test_global_and_per_client_combined() {
        let config = RateLimiterConfig::new(0.001, 3)
            .with_per_client(true)
            .with_global_limit(5);
        let limiter = RateLimiter::new(config);
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(limiter.check_rate_limit("a").is_ok());
        assert!(
            limiter.check_rate_limit("a").is_err(),
            "Client A should hit per-client limit"
        );
        assert!(limiter.check_rate_limit("b").is_ok());
        assert!(limiter.check_rate_limit("b").is_ok());
        let result = limiter.check_rate_limit("b");
        assert!(result.is_err(), "Should hit global limit");
        assert!(
            matches!(result, Err(RateLimitError::GlobalLimitExceeded { .. })),
            "Error should be GlobalLimitExceeded"
        );
    }

    // A bucket that can never refill reports an effectively infinite wait.
    #[test]
    fn test_zero_refill_rate_retry_after() {
        let bucket = TokenBucket::new(0.0, 0.0);
        assert_eq!(bucket.retry_after_ms(), u64::MAX);
    }
}