use axum::{
body::Body,
http::{header::HeaderValue, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Configuration for [`RateLimiter`]: a fixed-window request budget per client.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// When `false`, every request is allowed and nothing is recorded.
    pub enabled: bool,
    /// Maximum number of requests a single client may make per `window`.
    pub max_requests: u32,
    /// Length of the fixed counting window.
    pub window: Duration,
    /// Request paths that bypass rate limiting. Matched by string prefix, so
    /// `"/health"` also exempts paths such as `"/health/deep"` or `"/healthz"`.
    pub exempt_paths: Vec<String>,
}

impl Default for RateLimitConfig {
    /// A *disabled* limiter (100 requests / 60 s if enabled) that exempts the
    /// `/health`, `/ready` and `/live` probe paths.
    fn default() -> Self {
        Self {
            enabled: false,
            max_requests: 100,
            window: Duration::from_secs(60),
            exempt_paths: vec![
                "/health".to_string(),
                "/ready".to_string(),
                "/live".to_string(),
            ],
        }
    }
}

impl RateLimitConfig {
    /// Creates an *enabled* config allowing `max_requests` per `window_secs`
    /// seconds, keeping the default exempt paths.
    ///
    /// A `window_secs` of 0 makes every request start a fresh window, which
    /// effectively disables the limit.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            enabled: true,
            max_requests,
            window: Duration::from_secs(window_secs),
            // Single source of truth for the default exempt paths.
            ..Self::default()
        }
    }

    /// Appends `paths` to the exempt list (the defaults are kept, not replaced).
    #[must_use]
    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
        self.exempt_paths.extend(paths);
        self
    }
}
/// Per-client request counter for the current fixed window.
#[derive(Clone)]
struct RequestRecord {
    /// Instant at which the current window began.
    window_start: Instant,
    /// Number of requests counted in the current window.
    count: u32,
}
#[derive(Clone)]
pub struct RateLimiter {
config: RateLimitConfig,
records: Arc<RwLock<HashMap<String, RequestRecord>>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
records: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn check_rate_limit(&self, key: &str) -> bool {
if !self.config.enabled {
return true;
}
let mut records = self.records.write().await;
let now = Instant::now();
match records.get_mut(key) {
Some(record) => {
if now.duration_since(record.window_start) >= self.config.window {
record.count = 1;
record.window_start = now;
true
} else if record.count < self.config.max_requests {
record.count += 1;
true
} else {
false
}
}
None => {
records.insert(
key.to_string(),
RequestRecord {
count: 1,
window_start: now,
},
);
true
}
}
}
pub async fn remaining(&self, key: &str) -> u32 {
if !self.config.enabled {
return self.config.max_requests;
}
let records = self.records.read().await;
match records.get(key) {
Some(record) => {
let now = Instant::now();
if now.duration_since(record.window_start) >= self.config.window {
self.config.max_requests
} else {
self.config.max_requests.saturating_sub(record.count)
}
}
None => self.config.max_requests,
}
}
pub async fn cleanup_expired(&self) {
let mut records = self.records.write().await;
let now = Instant::now();
records.retain(|_, record| now.duration_since(record.window_start) < self.config.window);
}
}
pub async fn rate_limit_middleware(
axum::Extension(limiter): axum::Extension<RateLimiter>,
request: Request<Body>,
next: Next,
) -> Response {
let path = request.uri().path();
if limiter
.config
.exempt_paths
.iter()
.any(|p| path.starts_with(p))
{
return next.run(request).await;
}
let client_key = extract_client_key(&request);
if limiter.check_rate_limit(&client_key).await {
let remaining = limiter.remaining(&client_key).await;
let mut response = next.run(request).await;
let headers = response.headers_mut();
if let Ok(val) = HeaderValue::try_from(limiter.config.max_requests.to_string()) {
headers.insert("X-RateLimit-Limit", val);
}
if let Ok(val) = HeaderValue::try_from(remaining.to_string()) {
headers.insert("X-RateLimit-Remaining", val);
}
response
} else {
let window_secs = limiter.config.window.as_secs();
(
StatusCode::TOO_MANY_REQUESTS,
[
("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
("X-RateLimit-Remaining", "0".to_string()),
("Retry-After", window_secs.to_string()),
],
format!(
"Rate limit exceeded. Max {} requests per {} seconds.",
limiter.config.max_requests, window_secs
),
)
.into_response()
}
}
fn extract_client_key(request: &Request<Body>) -> String {
if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
if let Ok(s) = forwarded.to_str() {
if let Some(ip) = s.split(',').next() {
return ip.trim().to_string();
}
}
}
if let Some(real_ip) = request.headers().get("X-Real-IP") {
if let Ok(s) = real_ip.to_str() {
return s.to_string();
}
}
"unknown".to_string()
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is the idiomatic failure mode in tests
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with one rate-limited route (`/api/test`) and one exempt route (`/health`).
    fn test_router(config: RateLimitConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(RateLimiter::new(config)))
    }

    /// Empty-bodied request to `uri`, optionally tagged with an `X-Forwarded-For` client IP.
    fn get_request(uri: &str, client_ip: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(ip) = client_ip {
            builder = builder.header("X-Forwarded-For", ip);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let router = test_router(RateLimitConfig::default());
        let response = router.oneshot(get_request("/api/test", None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limit_allows_under_limit() {
        let router = test_router(RateLimitConfig::new(5, 60));
        for _ in 0..3 {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.1")))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let router = test_router(RateLimitConfig::new(2, 60));
        let expected = [
            StatusCode::OK,
            StatusCode::OK,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for (attempt, want) in expected.iter().enumerate() {
            let response = router
                .clone()
                .oneshot(get_request("/api/test", Some("192.168.1.100")))
                .await
                .unwrap();
            assert_eq!(response.status(), *want, "attempt {attempt}");
        }
    }

    #[tokio::test]
    async fn test_rate_limit_exempt_path() {
        let router = test_router(RateLimitConfig::new(1, 60));
        // Spend the client's only allowed request on a limited route...
        let _ = router
            .clone()
            .oneshot(get_request("/api/test", Some("192.168.1.200")))
            .await
            .unwrap();
        // ...then the exempt route must still succeed for the same client.
        let response = router
            .oneshot(get_request("/health", Some("192.168.1.200")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(RateLimitConfig::new(10, 1));
        limiter.check_rate_limit("test-client").await;
        assert!(limiter.records.read().await.contains_key("test-client"));

        // Outlive the 1-second window so the record becomes eligible for removal.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        limiter.cleanup_expired().await;
        assert!(!limiter.records.read().await.contains_key("test-client"));
    }
}