use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Sliding-window rate limiter that counts requests per `(RateLimitKey, endpoint)` pair.
///
/// Cloning is cheap: the tracking state lives behind an `Arc`, so every clone shares
/// (and updates) the same request history.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    // Limits used by every check; the limiter's methods only ever read it.
    config: RateLimitConfig,
    // Shared request history, guarded by an async read/write lock.
    state: Arc<RwLock<RateLimitState>>,
}
/// Configuration for a [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Limit applied to any endpoint that has no entry in `endpoint_limits`.
    pub default_limit: EndpointLimit,
    /// Per-endpoint overrides, keyed by the endpoint name passed to the limiter.
    pub endpoint_limits: HashMap<String, EndpointLimit>,
    /// When `false`, `RateLimiter::check` always succeeds and `RateLimiter::record` does nothing.
    pub enabled: bool,
    /// Minimum time between sweeps that discard idle tracking entries.
    pub cleanup_interval: Duration,
}
/// Request budget for one endpoint over a sliding time window.
#[derive(Debug, Clone)]
pub struct EndpointLimit {
    /// Nominal number of requests per `window`; this is what `RateLimitInfo::limit` reports.
    pub requests: u32,
    /// Length of the sliding window that requests are counted over.
    pub window: Duration,
    /// Extra requests tolerated on top of `requests`: a key is rejected once it already
    /// has `requests + burst` requests inside `window`.
    pub burst: u32,
}
/// Identity that requests are counted against (a client IP, user id, session, ...).
///
/// Two keys share a budget only if both fields are equal.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct RateLimitKey {
    /// Kind of identity; the constructors use `"ip"`, `"user"`, `"api_key"`, `"session"`
    /// and `"composite"`. It is also passed to the rate-limited metric.
    pub key_type: String,
    /// The identity itself, e.g. the IP address string.
    pub value: String,
}
/// Details returned by `RateLimiter::check` when a request is rejected.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// `window` minus the age of the earliest tracked request still inside the window.
    pub retry_after: Duration,
    /// Requests counted inside the current window. Because burst allowance is not part
    /// of `limit`, this can be larger than `limit`.
    pub current_count: u32,
    /// The endpoint's nominal `requests` limit (burst excluded).
    pub limit: u32,
    /// Length of the sliding window.
    pub window: Duration,
}
/// Mutable tracking state shared by all clones of a [`RateLimiter`].
#[derive(Debug, Default)]
struct RateLimitState {
    // Request history per (key, endpoint name) pair.
    entries: HashMap<(RateLimitKey, String), RequestTracker>,
    // When the last idle-entry sweep ran; `None` means never, so the next sweep check runs one.
    last_cleanup: Option<Instant>,
}
/// Request timestamps counted for a single (key, endpoint) pair.
#[derive(Debug, Clone)]
struct RequestTracker {
    // Appended for every allowed or recorded request; entries that fall outside the
    // endpoint's window are dropped lazily when the pair is checked.
    timestamps: Vec<Instant>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
state: Arc::new(RwLock::new(RateLimitState::default())),
}
}
pub fn for_auth() -> Self {
let config = RateLimitConfig::builder()
.default_limit(100, Duration::from_secs(60))
.endpoint_limit(
"login",
EndpointLimit {
requests: 5,
window: Duration::from_secs(60),
burst: 2,
},
)
.endpoint_limit(
"token",
EndpointLimit {
requests: 10,
window: Duration::from_secs(60),
burst: 3,
},
)
.endpoint_limit(
"refresh",
EndpointLimit {
requests: 20,
window: Duration::from_secs(60),
burst: 5,
},
)
.endpoint_limit(
"authorize",
EndpointLimit {
requests: 10,
window: Duration::from_secs(60),
burst: 3,
},
)
.endpoint_limit(
"revoke",
EndpointLimit {
requests: 10,
window: Duration::from_secs(60),
burst: 2,
},
)
.build();
Self::new(config)
}
pub fn disabled() -> Self {
Self::new(RateLimitConfig {
default_limit: EndpointLimit {
requests: u32::MAX,
window: Duration::from_secs(1),
burst: 0,
},
endpoint_limits: HashMap::new(),
enabled: false,
cleanup_interval: Duration::from_secs(3600),
})
}
pub async fn check(&self, key: &RateLimitKey, endpoint: &str) -> Result<(), RateLimitInfo> {
if !self.config.enabled {
return Ok(());
}
let limit = self
.config
.endpoint_limits
.get(endpoint)
.unwrap_or(&self.config.default_limit);
let now = Instant::now();
let mut state = self.state.write().await;
self.maybe_cleanup(&mut state, now);
let entry_key = (key.clone(), endpoint.to_string());
let tracker = state
.entries
.entry(entry_key)
.or_insert_with(|| RequestTracker {
timestamps: Vec::new(),
});
let window_start = now - limit.window;
tracker.timestamps.retain(|&t| t > window_start);
let current_count = tracker.timestamps.len() as u32;
let effective_limit = limit.requests + limit.burst;
if current_count >= effective_limit {
let oldest = tracker.timestamps.first().copied().unwrap_or(now);
let retry_after = limit.window - (now - oldest);
crate::auth_metrics::record_rate_limited(endpoint, &key.key_type);
return Err(RateLimitInfo {
retry_after,
current_count,
limit: limit.requests,
window: limit.window,
});
}
tracker.timestamps.push(now);
Ok(())
}
pub async fn record(&self, key: &RateLimitKey, endpoint: &str) {
if !self.config.enabled {
return;
}
let now = Instant::now();
let mut state = self.state.write().await;
let entry_key = (key.clone(), endpoint.to_string());
let tracker = state
.entries
.entry(entry_key)
.or_insert_with(|| RequestTracker {
timestamps: Vec::new(),
});
tracker.timestamps.push(now);
}
pub async fn get_usage(&self, key: &RateLimitKey, endpoint: &str) -> Option<(u32, u32)> {
let limit = self
.config
.endpoint_limits
.get(endpoint)
.unwrap_or(&self.config.default_limit);
let now = Instant::now();
let state = self.state.read().await;
let entry_key = (key.clone(), endpoint.to_string());
state.entries.get(&entry_key).map(|tracker| {
let window_start = now - limit.window;
let current = tracker
.timestamps
.iter()
.filter(|&&t| t > window_start)
.count() as u32;
(current, limit.requests)
})
}
pub async fn reset(&self, key: &RateLimitKey) {
let mut state = self.state.write().await;
state.entries.retain(|(k, _), _| k != key);
}
pub async fn reset_all(&self) {
let mut state = self.state.write().await;
state.entries.clear();
}
fn maybe_cleanup(&self, state: &mut RateLimitState, now: Instant) {
let should_cleanup = state
.last_cleanup
.map(|t| now - t > self.config.cleanup_interval)
.unwrap_or(true);
if should_cleanup {
let max_window = self
.config
.endpoint_limits
.values()
.map(|l| l.window)
.max()
.unwrap_or(self.config.default_limit.window);
let cutoff = now - max_window * 2;
state.entries.retain(|_, tracker| {
tracker
.timestamps
.last()
.map(|&t| t > cutoff)
.unwrap_or(false)
});
state.last_cleanup = Some(now);
}
}
}
impl RateLimitConfig {
pub fn builder() -> RateLimitConfigBuilder {
RateLimitConfigBuilder::default()
}
}
/// Builder for [`RateLimitConfig`].
///
/// Unless [`RateLimitConfigBuilder::enabled`] is called, the built config is enabled
/// exactly when a default or per-endpoint limit was configured.
#[derive(Debug, Default)]
pub struct RateLimitConfigBuilder {
    default_limit: Option<EndpointLimit>,
    endpoint_limits: HashMap<String, EndpointLimit>,
    // Explicit on/off switch. `None` means "enabled iff any limit was configured". It is
    // kept separate so an explicit `enabled(false)` is not undone by later limit setters.
    enabled: Option<bool>,
    cleanup_interval: Option<Duration>,
}

impl RateLimitConfigBuilder {
    /// Sets the fallback limit to `requests` per `window`, with a burst allowance of
    /// `requests / 10` (integer division).
    pub fn default_limit(mut self, requests: u32, window: Duration) -> Self {
        self.default_limit = Some(EndpointLimit {
            requests,
            window,
            // 10% burst allowance, rounded down.
            burst: requests / 10,
        });
        self
    }

    /// Adds (or replaces) the limit for `endpoint`.
    pub fn endpoint_limit(mut self, endpoint: impl Into<String>, limit: EndpointLimit) -> Self {
        self.endpoint_limits.insert(endpoint.into(), limit);
        self
    }

    /// Sets the minimum time between idle-entry sweeps (default: 300 seconds).
    pub fn cleanup_interval(mut self, interval: Duration) -> Self {
        self.cleanup_interval = Some(interval);
        self
    }

    /// Explicitly enables or disables limiting. This overrides the automatic enabling
    /// implied by configuring limits, regardless of call order.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = Some(enabled);
        self
    }

    /// Produces the configuration.
    ///
    /// Defaults: 100 requests per 60 seconds with burst 10 when no default limit was set;
    /// 300-second cleanup interval.
    pub fn build(self) -> RateLimitConfig {
        let any_limit_configured = self.default_limit.is_some() || !self.endpoint_limits.is_empty();
        RateLimitConfig {
            default_limit: self.default_limit.unwrap_or(EndpointLimit {
                requests: 100,
                window: Duration::from_secs(60),
                burst: 10,
            }),
            endpoint_limits: self.endpoint_limits,
            enabled: self.enabled.unwrap_or(any_limit_configured),
            cleanup_interval: self.cleanup_interval.unwrap_or(Duration::from_secs(300)),
        }
    }
}
impl RateLimitKey {
    /// Key for a client IP address (`key_type` `"ip"`).
    pub fn ip(ip: impl Into<String>) -> Self {
        Self {
            key_type: "ip".to_string(),
            value: ip.into(),
        }
    }

    /// Key for a user id (`key_type` `"user"`).
    pub fn user(user_id: impl Into<String>) -> Self {
        Self {
            key_type: "user".to_string(),
            value: user_id.into(),
        }
    }

    /// Key for an API key identified by its prefix (`key_type` `"api_key"`).
    pub fn api_key_prefix(prefix: impl Into<String>) -> Self {
        Self {
            key_type: "api_key".to_string(),
            value: prefix.into(),
        }
    }

    /// Key for a session id (`key_type` `"session"`).
    pub fn session(session_id: impl Into<String>) -> Self {
        Self {
            key_type: "session".to_string(),
            value: session_id.into(),
        }
    }

    /// Key built from several `(name, value)` components (`key_type` `"composite"`).
    ///
    /// Components are rendered as `name:value` and joined with `|`; for example
    /// `[("ip", "10.0.0.1"), ("endpoint", "/login")]` becomes `ip:10.0.0.1|endpoint:/login`.
    /// Backslash and `|` are escaped with `\` in names and values, and `:` is escaped in
    /// names. This guarantees that different component lists never map to the same key,
    /// so one identity cannot share or exhaust another's budget. Values may still contain
    /// raw `:` (IPv6 addresses are rendered unchanged).
    pub fn composite(components: Vec<(&str, &str)>) -> Self {
        let value = components
            .into_iter()
            .map(|(k, v)| format!("{}:{}", Self::escape(k, true), Self::escape(v, false)))
            .collect::<Vec<_>>()
            .join("|");
        Self {
            key_type: "composite".to_string(),
            value,
        }
    }

    /// Prefixes `\`, `|` and, when `escape_colon` is set, `:` with a backslash.
    fn escape(raw: &str, escape_colon: bool) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if ch == '\\' || ch == '|' || (escape_colon && ch == ':') {
                out.push('\\');
            }
            out.push(ch);
        }
        out
    }
}
impl std::fmt::Display for RateLimitInfo {
    /// One-line, human-readable description of the rejection: count versus nominal
    /// limit, window length, and how long to wait before retrying.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            retry_after,
            current_count,
            limit,
            window,
        } = self;
        write!(
            f,
            "Rate limited: {current_count}/{limit} requests in {window:?}, retry after {retry_after:?}"
        )
    }
}

/// Lets a rejection be returned wherever a `std::error::Error` is expected.
impl std::error::Error for RateLimitInfo {}
// Unit tests covering allow/deny decisions, burst handling, per-key isolation,
// resets, the auth preset, key construction, and usage reporting.
#[cfg(test)]
mod tests {
    use super::*;

    // `default_limit(5, ..)` derives burst 5 / 10 = 0, so exactly five requests fit.
    #[tokio::test]
    async fn test_rate_limiter_allows_under_limit() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(5, Duration::from_secs(60))
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        for _ in 0..5 {
            assert!(limiter.check(&key, "test").await.is_ok());
        }
    }

    // With requests = 2 and burst = 0 the third check is rejected. `info.limit` reports
    // `requests`, and `current_count` the requests already in the window.
    #[tokio::test]
    async fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests: 2,
                        window: Duration::from_secs(60),
                        burst: 0,
                    },
                )
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_ok());
        let result = limiter.check(&key, "test").await;
        assert!(result.is_err());
        let info = result.unwrap_err();
        assert_eq!(info.current_count, 2);
        assert_eq!(info.limit, 2);
    }

    // Burst adds to the budget: requests + burst = 4 allowed, the fifth is rejected.
    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests: 2,
                        window: Duration::from_secs(60),
                        burst: 2,
                    },
                )
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        for i in 0..4 {
            assert!(
                limiter.check(&key, "test").await.is_ok(),
                "Request {} should be allowed",
                i
            );
        }
        assert!(limiter.check(&key, "test").await.is_err());
    }

    // A disabled limiter short-circuits every check.
    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::disabled();
        let key = RateLimitKey::ip("192.168.1.1");
        for _ in 0..1000 {
            assert!(limiter.check(&key, "test").await.is_ok());
        }
    }

    // Each key has its own budget for the same endpoint.
    #[tokio::test]
    async fn test_rate_limiter_different_keys() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests: 1,
                        window: Duration::from_secs(60),
                        burst: 0,
                    },
                )
                .build(),
        );
        let key1 = RateLimitKey::ip("192.168.1.1");
        let key2 = RateLimitKey::ip("192.168.1.2");
        assert!(limiter.check(&key1, "test").await.is_ok());
        assert!(limiter.check(&key2, "test").await.is_ok());
        assert!(limiter.check(&key1, "test").await.is_err());
        assert!(limiter.check(&key2, "test").await.is_err());
    }

    // `reset` clears the key's history, so an exhausted key is allowed again.
    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests: 1,
                        window: Duration::from_secs(60),
                        burst: 0,
                    },
                )
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_err());
        limiter.reset(&key).await;
        assert!(limiter.check(&key, "test").await.is_ok());
    }

    // The auth preset allows login 5 requests + 2 burst = 7, then rejects.
    #[tokio::test]
    async fn test_rate_limiter_for_auth() {
        let limiter = RateLimiter::for_auth();
        let key = RateLimitKey::ip("192.168.1.1");
        for i in 0..7 {
            assert!(
                limiter.check(&key, "login").await.is_ok(),
                "Login request {} should be allowed",
                i
            );
        }
        assert!(limiter.check(&key, "login").await.is_err());
    }

    // Constructors set `key_type`; composite keys join `name:value` pairs with `|`.
    #[test]
    fn test_rate_limit_key_creation() {
        let ip_key = RateLimitKey::ip("10.0.0.1");
        assert_eq!(ip_key.key_type, "ip");
        assert_eq!(ip_key.value, "10.0.0.1");
        let user_key = RateLimitKey::user("user123");
        assert_eq!(user_key.key_type, "user");
        let composite = RateLimitKey::composite(vec![("ip", "10.0.0.1"), ("endpoint", "/login")]);
        assert_eq!(composite.key_type, "composite");
        assert_eq!(composite.value, "ip:10.0.0.1|endpoint:/login");
    }

    // No usage is reported before any request; afterwards it is (in-window count, requests).
    #[tokio::test]
    async fn test_get_usage() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(10, Duration::from_secs(60))
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.get_usage(&key, "test").await.is_none());
        limiter.check(&key, "test").await.ok();
        limiter.check(&key, "test").await.ok();
        limiter.check(&key, "test").await.ok();
        let usage = limiter.get_usage(&key, "test").await;
        assert_eq!(usage, Some((3, 10)));
    }
}