use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Limits applied independently to every key tracked by a [`RateLimiter`].
///
/// A request is admitted only if it can take a token from the key's token bucket
/// (`burst_size`, `requests_per_second`) and it also fits under both sliding-window
/// limits (`requests_per_minute`, `requests_per_hour`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Token-bucket refill rate, in tokens per second.
    pub requests_per_second: u32,
    /// Maximum admitted requests within any 60-second window.
    pub requests_per_minute: u32,
    /// Maximum admitted requests within any 3600-second window.
    pub requests_per_hour: u32,
    /// Token-bucket capacity: how many requests can be admitted back-to-back.
    pub burst_size: u32,
    /// Maximum timestamps retained per key for the minute window (oldest dropped first).
    ///
    /// If this is smaller than `requests_per_minute`, the stored count can never reach
    /// the limit, so the minute limit is never triggered. Defaults to
    /// [`default_minute_capacity`] when missing from serialized input.
    #[serde(default = "default_minute_capacity")]
    pub minute_window_capacity: usize,
    /// Maximum timestamps retained per key for the hour window (oldest dropped first).
    ///
    /// If this is smaller than `requests_per_hour`, the hour limit is never triggered.
    /// Defaults to [`default_hour_capacity`] when missing from serialized input.
    #[serde(default = "default_hour_capacity")]
    pub hour_window_capacity: usize,
}

/// Default for [`RateLimitConfig::minute_window_capacity`], shared by serde and `Default`.
fn default_minute_capacity() -> usize {
    1000
}

/// Default for [`RateLimitConfig::hour_window_capacity`], shared by serde and `Default`.
fn default_hour_capacity() -> usize {
    10000
}

impl Default for RateLimitConfig {
    /// 10 req/s with a burst of 20, capped at 100 req/min and 1000 req/h.
    fn default() -> Self {
        Self {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 20,
            // Reuse the serde default functions so deserialized and `Default`
            // configs cannot drift apart.
            minute_window_capacity: default_minute_capacity(),
            hour_window_capacity: default_hour_capacity(),
        }
    }
}
/// Token bucket holding up to `max_tokens`, refilled continuously at `refill_rate`
/// tokens per second of elapsed time.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens available as of `last_update`.
    tokens: f64,
    /// Capacity; refills never raise `tokens` above this.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_update: Instant::now(),
        }
    }

    /// Refills the bucket for the time since the last call, then takes `tokens` if that
    /// many are available.
    ///
    /// Returns `true` if the tokens were taken. On `false`, the refill has still been
    /// applied, but nothing was taken.
    fn try_take(&mut self, tokens: f64) -> bool {
        // Read the clock once so that the interval between measuring elapsed time and
        // stamping `last_update` is not silently dropped from the refill.
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_update = now;

        if self.tokens < tokens {
            return false;
        }
        self.tokens -= tokens;
        true
    }
}
/// Per-key rate limiter combining a token bucket with minute and hour sliding windows.
///
/// Per-key state is created lazily by [`RateLimiter::check`] and stored in `DashMap`s,
/// so all methods take `&self` and the limiter can be shared across tasks
/// (e.g. behind an `Arc`).
pub struct RateLimiter {
    /// Limits applied to every key.
    config: RateLimitConfig,
    /// Token bucket for each key.
    buckets: DashMap<String, TokenBucket>,
    /// Minute/hour request log for each key.
    counters: DashMap<String, SlidingWindowCounter>,
}
/// Per-key log of admitted-request timestamps for the minute and hour windows.
///
/// Timestamps come from `Instant::now()`, which never goes backwards, and are only
/// appended. Each log is therefore sorted oldest-first, which lets expiry be located
/// with a binary search.
#[derive(Debug)]
struct SlidingWindowCounter {
    /// Timestamps for the 60-second window, oldest first.
    minute_requests: Vec<Instant>,
    /// Timestamps for the 3600-second window, oldest first.
    hour_requests: Vec<Instant>,
    /// Maximum entries kept in `minute_requests`; the oldest are dropped beyond this.
    minute_capacity: usize,
    /// Maximum entries kept in `hour_requests`; the oldest are dropped beyond this.
    hour_capacity: usize,
}

impl SlidingWindowCounter {
    const MINUTE_WINDOW: Duration = Duration::from_secs(60);
    const HOUR_WINDOW: Duration = Duration::from_secs(3600);

    /// Creates an empty counter with the given per-window retention caps.
    fn with_capacity(minute_capacity: usize, hour_capacity: usize) -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            minute_capacity,
            hour_capacity,
        }
    }

    /// Length of the leading run of `log` entries that are at least `window` old at `now`.
    ///
    /// Requires `log` to be sorted oldest-first (see the type-level docs).
    fn expired_len(log: &[Instant], now: Instant, window: Duration) -> usize {
        log.partition_point(|&t| now.duration_since(t) >= window)
    }

    /// Removes entries outside `window`, then the oldest entries beyond `capacity`.
    fn prune(log: &mut Vec<Instant>, now: Instant, window: Duration, capacity: usize) {
        let expired = Self::expired_len(log.as_slice(), now, window);
        // Both cuts remove a prefix, so a single drain of the larger one does both.
        let cut = expired.max(log.len().saturating_sub(capacity));
        log.drain(..cut);
    }

    /// Records a request at the current instant and prunes both logs.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.minute_requests.push(now);
        self.hour_requests.push(now);
        Self::prune(
            &mut self.minute_requests,
            now,
            Self::MINUTE_WINDOW,
            self.minute_capacity,
        );
        Self::prune(
            &mut self.hour_requests,
            now,
            Self::HOUR_WINDOW,
            self.hour_capacity,
        );
    }

    /// Requests recorded within the last 60 seconds.
    ///
    /// The log is pruned only by `add_request`. A caller that stops adding once a limit
    /// is reached (as `RateLimiter::check` does) would otherwise see a stale count and
    /// stay blocked forever, so aged-out entries are excluded here as well.
    fn minute_count(&self) -> usize {
        let now = Instant::now();
        self.minute_requests.len()
            - Self::expired_len(&self.minute_requests, now, Self::MINUTE_WINDOW)
    }

    /// Requests recorded within the last 3600 seconds (see [`Self::minute_count`]).
    fn hour_count(&self) -> usize {
        let now = Instant::now();
        self.hour_requests.len() - Self::expired_len(&self.hour_requests, now, Self::HOUR_WINDOW)
    }
}
impl RateLimiter {
    /// Creates a limiter using [`RateLimitConfig::default`].
    pub fn new() -> Self {
        Self::with_config(RateLimitConfig::default())
    }

    /// Creates a limiter that applies `config` to every key.
    pub fn with_config(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: DashMap::new(),
            counters: DashMap::new(),
        }
    }

    /// Decides whether a request for `key` is allowed, recording it if so.
    ///
    /// The request must first take one token from the key's bucket (capacity
    /// `burst_size`, refilled at `requests_per_second`). It must then fit under both the
    /// per-minute and per-hour limits. Returns `Ok(false)` when any limit is exhausted.
    ///
    /// A request rejected by the minute/hour limits has still consumed its bucket token.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub async fn check(&self, key: &str) -> anyhow::Result<bool> {
        // Keep the bucket guard in its own scope so its shard lock is released before
        // the counters map is touched.
        let bucket_result = {
            let mut bucket = self.buckets.entry(key.to_string()).or_insert_with(|| {
                TokenBucket::new(
                    self.config.burst_size as f64,
                    self.config.requests_per_second as f64,
                )
            });
            bucket.try_take(1.0)
        };
        if !bucket_result {
            return Ok(false);
        }

        let window_result = {
            let minute_cap = self.config.minute_window_capacity;
            let hour_cap = self.config.hour_window_capacity;
            let mut counter = self
                .counters
                .entry(key.to_string())
                .or_insert_with(|| SlidingWindowCounter::with_capacity(minute_cap, hour_cap));
            let minute_exceeded =
                counter.minute_count() >= self.config.requests_per_minute as usize;
            let hour_exceeded = counter.hour_count() >= self.config.requests_per_hour as usize;
            if minute_exceeded || hour_exceeded {
                false
            } else {
                counter.add_request();
                true
            }
        };
        Ok(window_result)
    }

    /// Forgets all state for `key`, restoring its full allowance.
    pub fn reset(&self, key: &str) {
        self.buckets.remove(key);
        self.counters.remove(key);
    }

    /// Returns a snapshot of the remaining allowance for `key` without consuming any.
    ///
    /// Keys with no recorded state report the full configured allowance.
    pub fn get_status(&self, key: &str) -> RateLimitStatus {
        let tokens_remaining = self
            .buckets
            .get(key)
            .map(|bucket| {
                // Project the refill that `try_take` would apply right now. Otherwise an
                // idle key would keep reporting the token level of its last request.
                let elapsed = bucket.last_update.elapsed().as_secs_f64();
                let available =
                    (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.max_tokens);
                // Float-to-int `as` truncates toward zero and saturates at the bounds.
                available as u32
            })
            .unwrap_or(self.config.burst_size);

        // One lookup for both windows keeps the two counts consistent with each other.
        let (minute_used, hour_used) = self
            .counters
            .get(key)
            .map(|counter| (counter.minute_count(), counter.hour_count()))
            .unwrap_or((0, 0));

        RateLimitStatus {
            tokens_remaining,
            minute_remaining: Self::remaining(self.config.requests_per_minute, minute_used),
            hour_remaining: Self::remaining(self.config.requests_per_hour, hour_used),
        }
    }

    /// Computes `limit - used`, clamped at zero; a `used` beyond `u32::MAX` counts as `u32::MAX`.
    fn remaining(limit: u32, used: usize) -> u32 {
        limit.saturating_sub(u32::try_from(used).unwrap_or(u32::MAX))
    }

    /// Evicts idle per-key state.
    ///
    /// Token buckets not updated within `max_age` are removed. Each window counter is
    /// pruned of timestamps older than its window, and removed once both of its logs
    /// are empty.
    pub fn cleanup_expired(&self, max_age: Duration) {
        const MINUTE: Duration = Duration::from_secs(60);
        const HOUR: Duration = Duration::from_secs(3600);

        let now = Instant::now();
        self.buckets
            .retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
        // The logs are otherwise pruned only when a request is admitted. Without pruning
        // here they never become empty, and idle keys' entries would never be released.
        self.counters.retain(|_, counter| {
            counter
                .minute_requests
                .retain(|t| now.duration_since(*t) < MINUTE);
            counter
                .hour_requests
                .retain(|t| now.duration_since(*t) < HOUR);
            !counter.minute_requests.is_empty() || !counter.hour_requests.is_empty()
        });
    }

    /// Number of keys that currently have a token bucket.
    pub fn active_keys(&self) -> usize {
        self.buckets.len()
    }
}
/// Snapshot of a key's remaining allowance, as returned by [`RateLimiter::get_status`].
#[derive(Debug, Serialize, Deserialize)]
pub struct RateLimitStatus {
    /// Whole tokens in the key's bucket (fractional part discarded); `burst_size` for an
    /// unknown key.
    pub tokens_remaining: u32,
    /// `requests_per_minute` minus the key's recorded minute-window count.
    pub minute_remaining: u32,
    /// `requests_per_hour` minus the key's recorded hour-window count.
    pub hour_remaining: u32,
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `RateLimiter`.
//
// Many of these depend on wall-clock timing. Some assume a tight loop of `check` calls
// finishes before a meaningful refill occurs; others sleep to let tokens refill. Their
// bounds are kept loose to tolerate scheduling jitter.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // With the default config (burst 20), ten back-to-back requests all fit.
    #[tokio::test]
    async fn test_basic_rate_limit() {
        let limiter = RateLimiter::new();
        for _ in 0..10 {
            assert!(limiter.check("test_key").await.unwrap());
        }
    }

    // burst_size and requests_per_minute are both 2, so a third immediate request is
    // rejected.
    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 2,
            requests_per_hour: 3,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        assert!(limiter.check("test_key").await.unwrap());
        assert!(limiter.check("test_key").await.unwrap());
        assert!(!limiter.check("test_key").await.unwrap());
    }

    // 100 tasks race on one key with burst 50. The exact split depends on scheduling,
    // so only "some succeed" is asserted.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let config = RateLimitConfig {
            requests_per_second: 100,
            requests_per_minute: 1000,
            requests_per_hour: 10000,
            burst_size: 50,
            ..Default::default()
        };
        let limiter = Arc::new(RateLimiter::with_config(config));
        let mut tasks = vec![];
        for _ in 0..100 {
            let limiter_clone = Arc::clone(&limiter);
            tasks.push(tokio::spawn(async move {
                limiter_clone.check("concurrent_key").await.unwrap()
            }));
        }
        let results: Vec<bool> = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();
        let success_count = results.iter().filter(|&&r| r).count();
        let fail_count = results.iter().filter(|&&r| !r).count();
        assert!(success_count > 0, "At least some requests should succeed");
        println!("Success: {}, Fail: {}", success_count, fail_count);
    }

    // Burst 10 at 5 req/s: roughly 10 of 20 rapid requests pass. The upper bound allows
    // one token to be refilled while the loop runs.
    #[tokio::test]
    async fn test_burst_handling() {
        let config = RateLimitConfig {
            requests_per_second: 5,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("burst_key").await.unwrap() {
                success_count += 1;
            }
        }
        assert!(
            success_count <= 11,
            "Burst should be limited, but got {} successes",
            success_count
        );
        assert!(
            success_count >= 8,
            "At least burst_size requests should succeed, but got {}",
            success_count
        );
    }

    // Drain a burst of 5, then sleep 150 ms at 10 tokens/s (about 1.5 tokens refilled).
    #[tokio::test]
    async fn test_token_refill_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        for _ in 0..5 {
            assert!(limiter.check("refill_key").await.unwrap());
        }
        assert!(!limiter.check("refill_key").await.unwrap());
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
        assert!(
            limiter.check("refill_key").await.unwrap(),
            "Token should be refilled after waiting"
        );
    }

    // Each key gets its own bucket and window, so exhausting key1 leaves key2 untouched.
    #[tokio::test]
    async fn test_different_keys_isolated() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        assert!(limiter.check("key1").await.unwrap());
        assert!(!limiter.check("key1").await.unwrap());
        assert!(limiter.check("key2").await.unwrap());
        assert!(!limiter.check("key2").await.unwrap());
    }

    // `reset` removes both the bucket and the window counter, restoring the allowance.
    #[test]
    fn test_reset_functionality() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
            assert!(!limiter.check("reset_key").await.unwrap());
        });
        limiter.reset("reset_key");
        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
        });
    }

    #[test]
    fn test_status_reporting() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 20,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            for _ in 0..5 {
                limiter.check("status_key").await.unwrap();
            }
        });
        let status = limiter.get_status("status_key");
        assert!(status.tokens_remaining < 20, "Tokens should be consumed");
        assert!(
            status.minute_remaining < 100,
            "Minute count should increase"
        );
    }

    // A zero max_age evicts every bucket. `active_keys` counts buckets only, so it drops
    // to 0 even though the window counters may remain.
    #[test]
    fn test_cleanup_expired() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
        });
        assert!(limiter.active_keys() >= 2);
        limiter.cleanup_expired(Duration::from_secs(0));
        assert_eq!(limiter.active_keys(), 0);
    }

    #[test]
    fn test_active_keys_count() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
            limiter.check("key3").await.unwrap();
        });
        assert_eq!(limiter.active_keys(), 3);
    }

    // requests_per_second = 0: the bucket never refills, so only the initial burst passes.
    #[test]
    fn test_zero_rate_limit() {
        let config = RateLimitConfig {
            requests_per_second: 0,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let first = limiter.check("key").await.unwrap();
            assert!(first, "First request with burst_size=2 should succeed");
            let second = limiter.check("key").await.unwrap();
            assert!(second, "Second request with burst_size=2 should succeed");
            let third = limiter.check("key").await.unwrap();
            assert!(!third, "Third request should be rate limited (no refill)");
        });
    }

    #[test]
    fn test_very_small_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("key").await.unwrap());
            assert!(!limiter.check("key").await.unwrap());
        });
    }

    // 500 requests against a burst of 1000. The default minute_window_capacity (1000)
    // is below requests_per_minute here, but 500 requests never reach it.
    #[test]
    fn test_large_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1000,
            requests_per_minute: 100000,
            requests_per_hour: 1000000,
            burst_size: 1000,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut success_count = 0;
            for _ in 0..500 {
                if limiter.check("key").await.unwrap() {
                    success_count += 1;
                }
            }
            assert!(
                success_count >= 400,
                "Should allow most requests with large burst"
            );
        });
    }

    // Keys are opaque strings: empty, punctuated, non-ASCII and very long keys all work.
    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("").await.unwrap());
        });
    }

    #[test]
    fn test_special_characters_in_key() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let special_keys = vec![
                "key:with:colons",
                "key-with-dashes",
                "key_with_underscores",
                "key.with.dots",
                "key/with/slashes",
            ];
            for key in special_keys {
                assert!(
                    limiter.check(key).await.unwrap(),
                    "Key '{}' should work",
                    key
                );
            }
        });
    }

    #[test]
    fn test_unicode_key() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("用户_123").await.unwrap());
            assert!(limiter.check("🔑_key").await.unwrap());
        });
    }

    #[test]
    fn test_very_long_key() {
        let limiter = RateLimiter::new();
        let long_key = "a".repeat(10000);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check(&long_key).await.unwrap());
        });
    }

    #[test]
    fn test_reset_nonexistent_key() {
        let limiter = RateLimiter::new();
        limiter.reset("nonexistent_key");
        assert_eq!(limiter.active_keys(), 0);
    }

    // A key with no recorded state reports the full configured allowance.
    #[test]
    fn test_status_nonexistent_key() {
        let limiter = RateLimiter::new();
        let config = RateLimitConfig::default();
        let status = limiter.get_status("nonexistent");
        assert_eq!(status.tokens_remaining, config.burst_size);
        assert_eq!(status.minute_remaining, config.requests_per_minute);
        assert_eq!(status.hour_remaining, config.requests_per_hour);
    }

    // Burst 5 at 10 req/s: about 5 of 20 rapid requests pass; the bound allows slight refill.
    #[tokio::test]
    async fn test_rapid_requests() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("rapid").await.unwrap() {
                success_count += 1;
            }
        }
        assert!(
            success_count <= 7,
            "Expected ~5 successful requests, got {}",
            success_count
        );
    }

    // NOTE(review): the name is misleading. This passes the largest possible max_age
    // (`u64::MAX` seconds), not a negative one, so nothing should be evicted.
    #[test]
    fn test_cleanup_with_negative_duration() {
        let limiter = RateLimiter::new();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key").await.unwrap();
        });
        limiter.cleanup_expired(Duration::from_secs(u64::MAX));
        assert!(limiter.active_keys() >= 1);
    }

    #[tokio::test]
    async fn test_status_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        for _ in 0..3 {
            limiter.check("status_test").await.unwrap();
        }
        let status = limiter.get_status("status_test");
        assert!(status.tokens_remaining < 10);
        assert!(status.tokens_remaining > 0);
    }

    #[test]
    fn test_config_default_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_second, 10);
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.requests_per_hour, 1000);
        assert_eq!(config.burst_size, 20);
    }

    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.requests_per_second, config.requests_per_second);
    }

    // At 100 tokens/s, a 15 ms sleep refills about 1.5 tokens after the burst of 10 is drained.
    #[tokio::test]
    async fn test_token_refill_boundary() {
        let config = RateLimitConfig {
            requests_per_second: 100,
            requests_per_minute: 10000,
            requests_per_hour: 100000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);
        for _ in 0..10 {
            limiter.check("refill_boundary").await.unwrap();
        }
        assert!(!limiter.check("refill_boundary").await.unwrap());
        tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;
        assert!(limiter.check("refill_boundary").await.unwrap());
    }
}