use axum::{
extract::{ConnectInfo, Request},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use crate::api::models::ApiError;
use super::auth::UserContext;
/// Tuning parameters for the token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests replenished per `window_duration`; buckets refill at
    /// `max_requests / window_duration` tokens per second.
    pub max_requests: u32,
    /// Period over which `max_requests` tokens are replenished.
    pub window_duration: Duration,
    /// Size of each caller's bucket: the most requests that can be made back-to-back.
    pub burst_capacity: u32,
}

impl Default for RateLimitConfig {
    /// Identical to [`RateLimitConfig::authenticated`]: 100 requests per 60 s, burst of 10.
    fn default() -> Self {
        Self::authenticated()
    }
}

impl RateLimitConfig {
    /// Window length, in seconds, shared by every built-in preset.
    const PRESET_WINDOW_SECS: u64 = 60;

    /// Builds a config from a request count, a window length in whole seconds,
    /// and a burst size.
    pub fn new(max_requests: u32, window_secs: u64, burst_capacity: u32) -> Self {
        let window_duration = Duration::from_secs(window_secs);
        Self {
            max_requests,
            window_duration,
            burst_capacity,
        }
    }

    /// Preset for unauthenticated traffic: 30 requests per 60 s, burst of 5.
    pub fn public() -> Self {
        Self::new(30, Self::PRESET_WINDOW_SECS, 5)
    }

    /// Preset for authenticated traffic: 100 requests per 60 s, burst of 10.
    pub fn authenticated() -> Self {
        Self::new(100, Self::PRESET_WINDOW_SECS, 10)
    }

    /// Preset for expensive endpoints: 10 requests per 60 s, burst of 2.
    pub fn heavy_operations() -> Self {
        Self::new(10, Self::PRESET_WINDOW_SECS, 2)
    }
}
/// Outcome of charging one request against a caller's rate-limit bucket.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// The configured sustained limit (`RateLimitConfig::max_requests`).
    pub limit: u32,
    /// Whole tokens left in the caller's bucket after this request.
    pub remaining: u32,
    /// Seconds from now until the current reporting window ends. This is a
    /// relative delay, not a Unix timestamp.
    pub reset_at: u64,
    /// Seconds until the next token becomes available; `Some` only when the
    /// request was rejected.
    pub retry_after: Option<u64>,
}
/// Per-key token bucket.
///
/// The bucket starts full with `capacity` tokens and refills continuously at
/// `refill_rate` tokens per second, never exceeding `capacity`. Each admitted
/// request consumes one token.
///
/// `window_start` plays no part in admission decisions; it only anchors the
/// reset time reported to clients.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available; fractional while the bucket refills.
    tokens: f64,
    /// Upper bound on `tokens` (the burst size).
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Last time `tokens` was topped up. Every token accessor refills first, so
    /// this doubles as the "last used" timestamp for idle-bucket cleanup.
    last_refill: Instant,
    /// Start of the current reporting window.
    window_start: Instant,
}

impl TokenBucket {
    /// Creates a full bucket whose reporting window starts now.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let now = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: now,
            window_start: now,
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * self.refill_rate;
        self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
        self.last_refill = now;
    }

    /// Refills, then takes one token if available. Returns whether a token was taken.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Refills, then returns the number of whole tokens available.
    fn remaining(&mut self) -> u32 {
        self.refill();
        self.tokens.floor() as u32
    }

    /// Refills, then returns the seconds until at least one token is available:
    /// `0.0` if one already is, `f64::INFINITY` if the bucket never refills.
    fn time_until_token(&mut self) -> f64 {
        self.refill();
        if self.tokens >= 1.0 {
            0.0
        } else if self.refill_rate > 0.0 {
            (1.0 - self.tokens) / self.refill_rate
        } else {
            // A zero (or NaN) rate never produces another token; make that explicit
            // instead of relying on division-by-zero / NaN propagation.
            f64::INFINITY
        }
    }

    /// Starts a new reporting window once `window` has elapsed since the current
    /// one began. Windows are anchored to the first check after expiry, not to
    /// fixed clock boundaries.
    fn reset_window(&mut self, window: Duration) {
        let now = Instant::now();
        if now.duration_since(self.window_start) >= window {
            self.window_start = now;
        }
    }

    /// Whole seconds, rounded up, until the current reporting window of length
    /// `window` ends; `0` once it has ended.
    fn seconds_until_window_end(&self, window: Duration) -> u64 {
        // `window - elapsed` rather than `window_start + window - now`, so a huge
        // configured window cannot overflow `Instant` arithmetic.
        let left = window.saturating_sub(self.window_start.elapsed());
        let whole = left.as_secs();
        if left.subsec_nanos() > 0 {
            whole.saturating_add(1)
        } else {
            whole
        }
    }
}

/// Keyed token-bucket rate limiter.
///
/// Each key gets its own bucket holding up to `burst_capacity` tokens and
/// refilling at `max_requests / window_duration` tokens per second. Clones share
/// the same bucket map through an `Arc`.
///
/// Buckets are only removed by [`RateLimitMiddleware::cleanup_old_buckets`] or
/// [`RateLimitMiddleware::cleanup_idle_buckets`]; unless the application calls one
/// of them periodically, the map grows with every distinct key seen.
#[derive(Clone)]
pub struct RateLimitMiddleware {
    config: RateLimitConfig,
    /// Buckets keyed by `"user:<id>"`, `"ip:<addr>"` or `"unknown"`.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}

impl RateLimitMiddleware {
    /// How long a bucket may go unused before `cleanup_old_buckets` evicts it.
    const IDLE_BUCKET_TTL: Duration = Duration::from_secs(3600);

    /// Creates a limiter with an empty bucket map.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Creates a limiter using [`RateLimitConfig::default`].
    pub fn default_config() -> Self {
        Self::new(RateLimitConfig::default())
    }

    /// Derives the bucket key for `req`.
    ///
    /// Prefers the authenticated user (`"user:<id>"`) from a [`UserContext`]
    /// request extension, falls back to the peer IP (`"ip:<addr>"`) from a
    /// `ConnectInfo<SocketAddr>` extension, and otherwise returns `"unknown"`.
    ///
    /// NOTE(review): every request that reaches the `"unknown"` fallback shares a
    /// single bucket, and behind a reverse proxy the peer IP is the proxy's.
    /// Confirm the deployment provides a per-client key.
    fn get_rate_limit_key(&self, req: &Request) -> String {
        if let Some(user_ctx) = req.extensions().get::<UserContext>() {
            return format!("user:{}", user_ctx.user_id);
        }
        if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
            return format!("ip:{}", addr.ip());
        }
        "unknown".to_string()
    }

    /// Locks the bucket map. If a previous holder panicked, logs `poison_warning`
    /// and continues with the inner map instead of propagating the panic.
    fn lock_buckets(
        &self,
        poison_warning: &str,
    ) -> std::sync::MutexGuard<'_, HashMap<String, TokenBucket>> {
        self.buckets.lock().unwrap_or_else(|poisoned| {
            eprintln!("{poison_warning}");
            poisoned.into_inner()
        })
    }

    /// Charges one request against `key`'s bucket, creating a full bucket on first use.
    ///
    /// Returns `Ok(info)` when the request is admitted and `Err(info)` when it is
    /// rejected; on rejection `info.retry_after` holds the seconds (rounded up)
    /// until the next token. `info.reset_at` is the seconds (rounded up) until
    /// the current `window_duration`-long reporting window ends.
    fn check_rate_limit(&self, key: &str) -> Result<RateLimitInfo, RateLimitInfo> {
        let window = self.config.window_duration;
        let capacity = f64::from(self.config.burst_capacity);
        // Sustained rate: `max_requests` tokens spread evenly across one window.
        let refill_rate = f64::from(self.config.max_requests) / window.as_secs_f64();

        let mut buckets = self.lock_buckets("WARNING: Rate limiter lock poisoned, recovering...");
        let bucket = buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket::new(capacity, refill_rate));

        // Roll the reporting window using the configured length so `reset_at`
        // agrees with `window_duration`.
        bucket.reset_window(window);
        let allowed = bucket.try_consume();
        let remaining = bucket.remaining();
        let reset_at = bucket.seconds_until_window_end(window);
        let retry_after = if allowed {
            None
        } else {
            // `as` saturates, so an infinite wait is reported as `u64::MAX`.
            Some(bucket.time_until_token().ceil() as u64)
        };

        let info = RateLimitInfo {
            limit: self.config.max_requests,
            remaining,
            reset_at,
            retry_after,
        };
        if allowed {
            Ok(info)
        } else {
            Err(info)
        }
    }

    /// Evicts buckets that have not been used for an hour.
    pub fn cleanup_old_buckets(&self) {
        self.cleanup_idle_buckets(Self::IDLE_BUCKET_TTL);
    }

    /// Evicts buckets that have not been used for at least `max_idle`.
    pub fn cleanup_idle_buckets(&self, max_idle: Duration) {
        let mut buckets =
            self.lock_buckets("WARNING: Rate limiter lock poisoned during cleanup, recovering...");
        let now = Instant::now();
        buckets.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
    }
}
pub async fn rate_limit_middleware(
rate_limiter: Arc<RateLimitMiddleware>,
req: Request,
next: Next,
) -> Result<Response, Response> {
let key = rate_limiter.get_rate_limit_key(&req);
match rate_limiter.check_rate_limit(&key) {
Ok(info) => {
let mut response = next.run(req).await;
let headers = response.headers_mut();
if let Ok(value) = info.limit.to_string().parse() {
headers.insert("X-RateLimit-Limit", value);
}
if let Ok(value) = info.remaining.to_string().parse() {
headers.insert("X-RateLimit-Remaining", value);
}
if let Ok(value) = info.reset_at.to_string().parse() {
headers.insert("X-RateLimit-Reset", value);
}
Ok(response)
}
Err(info) => {
let retry_after = info.retry_after.unwrap_or(60);
let error = ApiError::new(
StatusCode::TOO_MANY_REQUESTS,
"RateLimitExceeded",
"Rate limit exceeded",
);
let mut response = error.into_response();
let headers = response.headers_mut();
if let Ok(value) = info.limit.to_string().parse() {
headers.insert("X-RateLimit-Limit", value);
}
if let Ok(value) = "0".parse() {
headers.insert("X-RateLimit-Remaining", value);
}
if let Ok(value) = info.reset_at.to_string().parse() {
headers.insert("X-RateLimit-Reset", value);
}
if let Ok(value) = retry_after.to_string().parse() {
headers.insert("Retry-After", value);
}
Err(response)
}
}
}
#[cfg(test)]
// Panicking on unexpected values is exactly what a failing test should do.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::default();
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window_duration, Duration::from_secs(60));
        let public = RateLimitConfig::public();
        assert_eq!(public.max_requests, 30);
        let auth = RateLimitConfig::authenticated();
        assert_eq!(auth.max_requests, 100);
        let heavy = RateLimitConfig::heavy_operations();
        assert_eq!(heavy.max_requests, 10);
    }

    #[test]
    fn test_default_config_matches_authenticated() {
        let default = RateLimitConfig::default();
        let auth = RateLimitConfig::authenticated();
        assert_eq!(default.max_requests, auth.max_requests);
        assert_eq!(default.window_duration, auth.window_duration);
        assert_eq!(default.burst_capacity, auth.burst_capacity);
    }

    #[test]
    fn test_token_bucket_creation() {
        let bucket = TokenBucket::new(10.0, 1.0);
        assert_eq!(bucket.capacity, 10.0);
        assert_eq!(bucket.tokens, 10.0);
        assert_eq!(bucket.refill_rate, 1.0);
    }

    #[test]
    fn test_token_bucket_consume() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        for _ in 0..10 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume());
    }

    #[test]
    fn test_token_bucket_remaining() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert_eq!(bucket.remaining(), 10);
        bucket.try_consume();
        assert_eq!(bucket.remaining(), 9);
        for _ in 0..9 {
            bucket.try_consume();
        }
        assert_eq!(bucket.remaining(), 0);
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(10.0, 10.0);
        for _ in 0..10 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume());
        bucket.last_refill = Instant::now() - Duration::from_millis(500);
        bucket.refill();
        assert!(bucket.remaining() >= 4 && bucket.remaining() <= 5);
    }

    #[test]
    fn test_time_until_token() {
        // Capacity 1, refilling 2 tokens/s: one token every 0.5 s.
        let mut bucket = TokenBucket::new(1.0, 2.0);
        assert_eq!(bucket.time_until_token(), 0.0);
        assert!(bucket.try_consume());
        let wait = bucket.time_until_token();
        assert!(wait > 0.0 && wait <= 0.5, "unexpected wait {wait}");
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let config = RateLimitConfig::default();
        let middleware = RateLimitMiddleware::new(config);
        assert_eq!(middleware.config.max_requests, 100);
    }

    #[test]
    fn test_rate_limit_check() {
        let config = RateLimitConfig::new(10, 60, 5);
        let middleware = RateLimitMiddleware::new(config);
        for i in 0..5 {
            let result = middleware.check_rate_limit("test-key");
            assert!(result.is_ok(), "Request {} should succeed", i + 1);
            let info = result.unwrap();
            assert_eq!(info.limit, 10);
            assert!(info.retry_after.is_none());
        }
        let result = middleware.check_rate_limit("test-key");
        assert!(result.is_err(), "Request 6 should be rate limited");
        let info = result.unwrap_err();
        assert_eq!(info.remaining, 0);
        assert!(info.retry_after.is_some());
    }

    #[test]
    fn test_rejection_reports_retry_after_within_window() {
        // One token per 60 s with a burst of 1: the second request must wait ~60 s.
        let middleware = RateLimitMiddleware::new(RateLimitConfig::new(1, 60, 1));
        assert!(middleware.check_rate_limit("k").is_ok());
        let info = middleware.check_rate_limit("k").unwrap_err();
        assert_eq!(info.remaining, 0);
        let retry = info.retry_after.expect("rejected requests carry retry_after");
        assert!((1..=60).contains(&retry), "unexpected retry_after {retry}");
        assert!(info.reset_at <= 60, "unexpected reset_at {}", info.reset_at);
    }

    #[test]
    fn test_rate_limit_different_keys() {
        let config = RateLimitConfig::new(2, 60, 2);
        let middleware = RateLimitMiddleware::new(config);
        assert!(middleware.check_rate_limit("key1").is_ok());
        assert!(middleware.check_rate_limit("key2").is_ok());
        assert!(middleware.check_rate_limit("key1").is_ok());
        assert!(middleware.check_rate_limit("key2").is_ok());
        assert!(middleware.check_rate_limit("key1").is_err());
        assert!(middleware.check_rate_limit("key2").is_err());
    }

    #[test]
    fn test_cleanup_old_buckets() {
        let config = RateLimitConfig::default();
        let middleware = RateLimitMiddleware::new(config);
        middleware.check_rate_limit("key1").ok();
        middleware.check_rate_limit("key2").ok();
        middleware.check_rate_limit("key3").ok();
        {
            let buckets = middleware.buckets.lock().unwrap();
            assert_eq!(buckets.len(), 3);
        }
        middleware.cleanup_old_buckets();
        {
            let buckets = middleware.buckets.lock().unwrap();
            assert_eq!(buckets.len(), 3);
        }
    }

    #[test]
    fn test_cleanup_evicts_idle_buckets() {
        let middleware = RateLimitMiddleware::new(RateLimitConfig::default());
        middleware.check_rate_limit("stale").ok();
        middleware.check_rate_limit("fresh").ok();
        // `Instant - Duration` can underflow on platforms whose monotonic clock
        // starts at boot; skip rather than panic if the machine booted recently.
        let Some(two_hours_ago) = Instant::now().checked_sub(Duration::from_secs(7200)) else {
            return;
        };
        middleware
            .buckets
            .lock()
            .unwrap()
            .get_mut("stale")
            .unwrap()
            .last_refill = two_hours_ago;
        middleware.cleanup_old_buckets();
        let buckets = middleware.buckets.lock().unwrap();
        assert!(!buckets.contains_key("stale"));
        assert!(buckets.contains_key("fresh"));
    }

    #[test]
    fn test_rate_limit_info() {
        let info = RateLimitInfo {
            limit: 100,
            remaining: 50,
            reset_at: 1234567890,
            retry_after: None,
        };
        assert_eq!(info.limit, 100);
        assert_eq!(info.remaining, 50);
        assert_eq!(info.reset_at, 1234567890);
        assert!(info.retry_after.is_none());
        let limited_info = RateLimitInfo {
            limit: 100,
            remaining: 0,
            reset_at: 1234567890,
            retry_after: Some(30),
        };
        assert_eq!(limited_info.remaining, 0);
        assert_eq!(limited_info.retry_after, Some(30));
    }
}