use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use hyper::body::Incoming;
use hyper::{Request, Response};
use crate::context::RequestContext;
use crate::error::Error;
use crate::response::{BoxBody, IntoResponse};
use super::{BoxFuture, Middleware, Next};
/// Shared, thread-safe callback that maps an incoming request to the string
/// used as its rate-limit key.
type KeyExtractorFn = Arc<dyn Fn(&Request<Incoming>) -> String + Send + Sync>;
/// Stale buckets are purged once every this many rate-limit checks.
const CLEANUP_INTERVAL: u64 = 1000;
/// A bucket whose last refill is at least this old is removed during cleanup.
const STALE_AFTER: Duration = Duration::from_secs(600);
/// Token-bucket state tracked for a single rate-limit key.
#[derive(Debug)]
struct TokenBucket {
// Tokens currently available; fractional amounts accumulate between checks.
tokens: f64,
// Instant at which `tokens` was last brought up to date.
last_refill: Instant,
}
/// Strategy used to derive the rate-limit key from a request.
#[derive(Clone)]
pub enum KeyExtractor {
/// Key on the address found in the `x-forwarded-for` / `x-real-ip` headers.
Ip,
/// Key on whatever string the supplied callback returns.
Custom(KeyExtractorFn),
}
impl std::fmt::Debug for KeyExtractor {
    /// Prints only the variant name; the closure held by `Custom` cannot be
    /// formatted, so it is rendered as `(...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            KeyExtractor::Ip => "KeyExtractor::Ip",
            KeyExtractor::Custom(_) => "KeyExtractor::Custom(...)",
        };
        f.write_str(label)
    }
}
impl KeyExtractor {
    /// Produces the rate-limit key for `req` using the selected strategy.
    fn extract(&self, req: &Request<Incoming>) -> String {
        match self {
            KeyExtractor::Ip => Self::extract_ip(req),
            KeyExtractor::Custom(f) => f(req),
        }
    }

    /// Derives a client identifier from proxy headers.
    ///
    /// Lookup order:
    /// 1. the first non-blank entry of the comma-separated `x-forwarded-for` list,
    /// 2. the `x-real-ip` header,
    /// 3. the literal `"unknown"`.
    ///
    /// Blank values are skipped so that an empty header cannot collapse
    /// unrelated clients into a bucket keyed by the empty string.
    ///
    /// NOTE(review): both headers are read from the request as-is; they are
    /// only trustworthy when a proxy in front of this service overwrites
    /// them — confirm the deployment guarantees that.
    fn extract_ip(req: &Request<Incoming>) -> String {
        let forwarded = Self::header_text(req, "x-forwarded-for").and_then(|list| {
            list.split(',')
                .map(str::trim)
                .find(|entry| !entry.is_empty())
        });
        forwarded
            .or_else(|| Self::header_text(req, "x-real-ip"))
            .unwrap_or("unknown")
            .to_string()
    }

    /// Returns header `name` as trimmed text, or `None` when the header is
    /// absent, cannot be converted to a string, or is blank after trimming.
    fn header_text<'r>(req: &'r Request<Incoming>, name: &str) -> Option<&'r str> {
        req.headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|text| !text.is_empty())
    }
}
/// Settings for the token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
/// Sustained refill rate, in tokens (requests) per second.
pub requests_per_second: f64,
/// Bucket capacity; also the number of tokens a new bucket starts with.
pub burst: u32,
/// How the per-request bucket key is derived.
pub key_extractor: KeyExtractor,
}
impl RateLimitConfig {
    /// Builds a configuration with the given sustained rate and burst size.
    ///
    /// Requests are keyed with [`KeyExtractor::Ip`] until overridden through
    /// [`RateLimitConfig::with_key_extractor`].
    pub fn new(requests_per_second: f64, burst: u32) -> Self {
        let key_extractor = KeyExtractor::Ip;
        Self {
            requests_per_second,
            burst,
            key_extractor,
        }
    }

    /// Allows `requests` per minute: the sustained rate is `requests / 60`
    /// per second and the burst size equals `requests`.
    pub fn per_minute(requests: u32) -> Self {
        let per_second = f64::from(requests) / 60.0;
        Self::new(per_second, requests)
    }

    /// Returns the configuration with its key extractor replaced by `extractor`.
    pub fn with_key_extractor(self, extractor: KeyExtractor) -> Self {
        Self {
            key_extractor: extractor,
            ..self
        }
    }
}
/// Middleware that enforces a per-key token-bucket rate limit.
#[derive(Debug)]
pub struct RateLimitMiddleware {
// Rate, burst size and key-extraction strategy.
config: RateLimitConfig,
// Per-key bucket state; the `Arc` is shared by every clone of the middleware.
buckets: Arc<DashMap<String, TokenBucket>>,
// Number of rate-limit checks performed so far; used to schedule cleanup.
request_count: Arc<AtomicU64>,
}
impl Clone for RateLimitMiddleware {
    /// Copies the configuration and hands out new handles to the same bucket
    /// map and request counter, so all clones enforce one shared limit.
    fn clone(&self) -> Self {
        let Self {
            config,
            buckets,
            request_count,
        } = self;
        Self {
            config: config.clone(),
            buckets: Arc::clone(buckets),
            request_count: Arc::clone(request_count),
        }
    }
}
impl RateLimitMiddleware {
    /// Creates a limiter with an empty bucket map and a zeroed request counter.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Arc::new(DashMap::new()),
            request_count: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Removes every bucket whose last refill is `STALE_AFTER` or more in the past.
    fn cleanup_stale_buckets(&self) {
        let now = Instant::now();
        self.buckets
            .retain(|_, bucket| now.duration_since(bucket.last_refill) < STALE_AFTER);
    }

    /// Charges one token to the bucket for `key`.
    ///
    /// Returns `None` when the request is allowed. When the bucket is empty,
    /// returns `Some(seconds)`: the whole number of seconds (rounded up) until
    /// one full token will be available at the configured rate.
    ///
    /// A configured rate that is not greater than zero (zero, negative or NaN)
    /// is treated as "no refill": the initial burst can be spent, after which
    /// every request is rejected and the returned wait saturates at `u64::MAX`.
    ///
    /// Every `CLEANUP_INTERVAL`-th call also purges stale buckets before the
    /// check is performed.
    fn check_rate_limit(&self, key: &str) -> Option<u64> {
        let count = self.request_count.fetch_add(1, Ordering::Relaxed);
        if count > 0 && count % CLEANUP_INTERVAL == 0 {
            self.cleanup_stale_buckets();
        }

        let capacity = f64::from(self.config.burst);
        // A negative rate would drain tokens as time passes, and a NaN rate
        // would reset the bucket to `capacity` on every call (`f64::min`
        // ignores NaN). The comparison is false for NaN, so both collapse to 0.
        let rate = if self.config.requests_per_second > 0.0 {
            self.config.requests_per_second
        } else {
            0.0
        };

        let now = Instant::now();
        let mut bucket = self
            .buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket {
                tokens: capacity,
                last_refill: now,
            });

        // Credit the tokens earned since the last check, capped at capacity.
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(capacity);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            return None;
        }

        // Float-to-integer `as` casts saturate, so a zero rate (infinite wait)
        // yields `u64::MAX` rather than wrapping.
        let seconds_until_ready = (1.0 - bucket.tokens) / rate;
        Some(seconds_until_ready.ceil() as u64)
    }
}
impl Middleware for RateLimitMiddleware {
fn handle<'a>(
&'a self,
req: Request<Incoming>,
ctx: &'a RequestContext,
next: Next<'a>,
) -> BoxFuture<'a, Response<BoxBody>> {
Box::pin(async move {
let key = self.config.key_extractor.extract(&req);
if let Some(retry_after) = self.check_rate_limit(&key) {
let mut response = Error::rate_limited("too many requests")
.with_trace_id(&ctx.trace_id)
.into_response();
response
.headers_mut()
.insert("retry-after", retry_after.to_string().parse().unwrap());
return response;
}
next.run(req).await
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a limiter with the given sustained rate and burst size.
    fn make_limiter(rate: f64, burst: u32) -> RateLimitMiddleware {
        RateLimitMiddleware::new(RateLimitConfig::new(rate, burst))
    }

    /// Rewinds the refill timestamp of `key`'s bucket by `age`, if it exists.
    fn age_bucket(limiter: &RateLimitMiddleware, key: &str, age: Duration) {
        if let Some(mut bucket) = limiter.buckets.get_mut(key) {
            bucket.last_refill = Instant::now() - age;
        }
    }

    #[test]
    fn test_config_per_minute() {
        let RateLimitConfig {
            requests_per_second,
            burst,
            ..
        } = RateLimitConfig::per_minute(60);
        assert!((requests_per_second - 1.0).abs() < f64::EPSILON);
        assert_eq!(burst, 60);
    }

    #[test]
    fn test_config_per_minute_100() {
        let RateLimitConfig {
            requests_per_second,
            burst,
            ..
        } = RateLimitConfig::per_minute(100);
        assert!((requests_per_second - (100.0 / 60.0)).abs() < f64::EPSILON);
        assert_eq!(burst, 100);
    }

    #[test]
    fn test_config_new() {
        let RateLimitConfig {
            requests_per_second,
            burst,
            ..
        } = RateLimitConfig::new(10.0, 50);
        assert!((requests_per_second - 10.0).abs() < f64::EPSILON);
        assert_eq!(burst, 50);
    }

    #[test]
    fn test_default_key_extractor_is_ip() {
        let extractor = RateLimitConfig::per_minute(100).key_extractor;
        assert!(matches!(extractor, KeyExtractor::Ip));
    }

    #[test]
    fn test_middleware_allows_burst() {
        // One token per second with room for five: the first five calls drain
        // the bucket and the sixth must be rejected.
        let limiter = make_limiter(1.0, 5);
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("test-key").is_none());
        }
        assert!(limiter.check_rate_limit("test-key").is_some());
    }

    #[test]
    fn test_middleware_returns_retry_after() {
        // With a single-token bucket refilling at 1/s, the second call is
        // rejected with a one-second wait.
        let limiter = make_limiter(1.0, 1);
        assert_eq!(limiter.check_rate_limit("test-key"), None);
        assert_eq!(limiter.check_rate_limit("test-key"), Some(1));
    }

    #[test]
    fn test_middleware_separate_keys() {
        let limiter = make_limiter(1.0, 1);
        for user in ["user-1", "user-2", "user-3"] {
            assert!(limiter.check_rate_limit(user).is_none());
        }
        assert!(limiter.check_rate_limit("user-1").is_some());
    }

    #[test]
    fn test_middleware_clone_shares_state() {
        let original = make_limiter(1.0, 2);
        let copy = original.clone();
        assert!(original.check_rate_limit("shared-key").is_none());
        assert!(copy.check_rate_limit("shared-key").is_none());
        assert!(original.check_rate_limit("shared-key").is_some());
        assert!(copy.check_rate_limit("shared-key").is_some());
    }

    #[test]
    fn test_cleanup_removes_stale_buckets() {
        let limiter = make_limiter(1.0, 5);
        for key in ["key-1", "key-2", "key-3"] {
            limiter.check_rate_limit(key);
        }
        assert_eq!(limiter.buckets.len(), 3);

        // 700 s is past the stale threshold, so only "key-1" should be removed.
        age_bucket(&limiter, "key-1", Duration::from_secs(700));
        limiter.cleanup_stale_buckets();

        assert_eq!(limiter.buckets.len(), 2);
        assert!(limiter.buckets.get("key-1").is_none());
        assert!(limiter.buckets.get("key-2").is_some());
        assert!(limiter.buckets.get("key-3").is_some());
    }

    #[test]
    fn test_cleanup_triggered_periodically() {
        let limiter = make_limiter(1000.0, 1000);
        limiter.check_rate_limit("stale-key");
        age_bucket(&limiter, "stale-key", Duration::from_secs(700));

        // Enough further checks to reach the next cleanup boundary.
        for i in 0..CLEANUP_INTERVAL {
            limiter.check_rate_limit(&format!("key-{}", i));
        }
        assert!(limiter.buckets.get("stale-key").is_none());
    }
}