use axum::{
body::Body,
http::{header::HeaderValue, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use super::rate_limit::{RateLimitConfig, RateLimiter};
#[cfg(feature = "redis")]
use super::redis_rate_limit::RedisRateLimiter;
/// Rate-limiting backend chosen at startup.
///
/// Every variant stores the [`RateLimitConfig`] it was built with alongside its
/// limiter. `Clone` is needed because the middleware receives the backend by
/// value through an `axum::Extension` extractor.
#[derive(Clone)]
pub enum RateLimitBackend {
    /// Counts requests with the local [`RateLimiter`] (reported as `"in-memory"`).
    InMemory {
        limiter: RateLimiter,
        config: RateLimitConfig,
    },
    /// Counts requests through a [`RedisRateLimiter`]; only compiled with the
    /// `redis` cargo feature.
    #[cfg(feature = "redis")]
    Redis {
        // NOTE(review): boxed presumably to keep the enum's size close to the
        // in-memory variant — confirm against RedisRateLimiter's size.
        limiter: Box<RedisRateLimiter>,
        config: RateLimitConfig,
    },
}
impl RateLimitBackend {
    /// Builds a backend that counts requests with a local [`RateLimiter`].
    ///
    /// The limiter gets its own copy of `config`; the original is kept so that
    /// [`Self::config`] can hand it out.
    pub fn in_memory(config: RateLimitConfig) -> Self {
        Self::InMemory {
            limiter: RateLimiter::new(config.clone()),
            config,
        }
    }

    /// Builds a backend that counts requests through Redis at `redis_url`.
    ///
    /// # Errors
    ///
    /// Returns the `redis::RedisError` produced while creating the
    /// [`RedisRateLimiter`].
    #[cfg(feature = "redis")]
    pub async fn redis(
        redis_url: &str,
        config: RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let redis_limiter =
            RedisRateLimiter::new(redis_url, config.max_requests, config.window).await?;
        Ok(Self::Redis {
            limiter: Box::new(redis_limiter),
            config,
        })
    }

    /// Returns the configuration this backend was constructed with.
    pub fn config(&self) -> &RateLimitConfig {
        match self {
            Self::InMemory { config, .. } => config,
            #[cfg(feature = "redis")]
            Self::Redis { config, .. } => config,
        }
    }

    /// Asks the underlying limiter whether a request from `client_key` may pass.
    ///
    /// When limiting is disabled in the config this returns `true` without
    /// touching the limiter.
    pub async fn check_rate_limit(&self, client_key: &str) -> bool {
        if !self.config().enabled {
            return true;
        }
        match self {
            Self::InMemory { limiter, .. } => limiter.check_rate_limit(client_key).await,
            #[cfg(feature = "redis")]
            Self::Redis { limiter, .. } => limiter.check_rate_limit(client_key).await.allowed,
        }
    }

    /// Reports how many requests `client_key` has left, per the limiter.
    ///
    /// When limiting is disabled this is simply the configured `max_requests`.
    pub async fn remaining(&self, client_key: &str) -> u32 {
        let settings = self.config();
        if !settings.enabled {
            return settings.max_requests;
        }
        match self {
            Self::InMemory { limiter, .. } => limiter.remaining(client_key).await,
            #[cfg(feature = "redis")]
            Self::Redis { limiter, .. } => limiter.remaining(client_key).await,
        }
    }

    /// Runs the in-memory limiter's `cleanup_expired` pass; the Redis variant
    /// does nothing here.
    pub async fn cleanup_expired(&self) {
        match self {
            Self::InMemory { limiter, .. } => {
                limiter.cleanup_expired().await;
            }
            #[cfg(feature = "redis")]
            Self::Redis { .. } => {}
        }
    }

    /// Short, static label identifying the active backend.
    pub fn backend_name(&self) -> &'static str {
        match self {
            Self::InMemory { .. } => "in-memory",
            #[cfg(feature = "redis")]
            Self::Redis { .. } => "redis",
        }
    }
}
/// Axum middleware that enforces the [`RateLimitBackend`] limit per client key.
///
/// Expects the backend to be installed as an `axum::Extension` layer.
///
/// * Requests pass straight through when limiting is disabled or when the path
///   starts with any entry of `exempt_paths`. This is a plain string-prefix
///   match, so an exempt `/health` also covers `/healthz`.
/// * Allowed requests reach the inner service, and the response gets
///   `X-RateLimit-Limit` and `X-RateLimit-Remaining` headers. Any same-named
///   header the handler set is overwritten.
/// * Rejected requests get `429 Too Many Requests` with `X-RateLimit-Limit`,
///   `X-RateLimit-Remaining: 0` and `Retry-After`. `Retry-After` is always the
///   full window length in whole seconds, which may overstate the actual wait.
pub async fn backend_rate_limit_middleware(
    axum::Extension(backend): axum::Extension<RateLimitBackend>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let config = backend.config();
    let is_exempt =
        |path: &str| config.exempt_paths.iter().any(|prefix| path.starts_with(prefix));

    // Disabled limiting and exempt paths bypass the limiter entirely.
    if !config.enabled || is_exempt(request.uri().path()) {
        return next.run(request).await;
    }

    let max_requests = config.max_requests;
    let window_secs = config.window.as_secs();
    let client_key = extract_client_key(&request);

    if !backend.check_rate_limit(&client_key).await {
        let rejection_headers = [
            ("X-RateLimit-Limit", max_requests.to_string()),
            ("X-RateLimit-Remaining", "0".to_string()),
            ("Retry-After", window_secs.to_string()),
        ];
        let message =
            format!("Rate limit exceeded. Max {max_requests} requests per {window_secs} seconds.");
        return (StatusCode::TOO_MANY_REQUESTS, rejection_headers, message).into_response();
    }

    // Read before handing the request to the inner service.
    let remaining = backend.remaining(&client_key).await;
    let mut response = next.run(request).await;
    let response_headers = response.headers_mut();
    response_headers.insert("X-RateLimit-Limit", HeaderValue::from(max_requests));
    response_headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
    response
}
/// Derives the key a request is counted under for rate limiting.
///
/// Sources, in order of preference:
/// 1. the first (left-most) non-empty entry of `X-Forwarded-For`;
/// 2. a non-empty `X-Real-IP`;
/// 3. the literal `"unknown"`, so every request lacking both headers shares a
///    single bucket.
///
/// Header values that are not visible ASCII are ignored, and surrounding
/// whitespace is trimmed. A blank entry falls through to the next source instead
/// of producing an empty key that would become its own shared bucket.
///
/// SECURITY: both headers are chosen by the client unless a trusted reverse proxy
/// overwrites them. If this service is reachable without such a proxy, a caller
/// can send a different `X-Forwarded-For` value on every request and never be
/// limited.
fn extract_client_key(request: &Request<Body>) -> String {
    /// Returns header `name` as `&str`, or `None` if it is absent or not visible ASCII.
    fn header_str<'a>(request: &'a Request<Body>, name: &str) -> Option<&'a str> {
        request.headers().get(name).and_then(|value| value.to_str().ok())
    }

    let forwarded = header_str(request, "X-Forwarded-For")
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    let real_ip = header_str(request, "X-Real-IP")
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    forwarded.or(real_ip).unwrap_or("unknown").to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, middleware, routing::get, Router};
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router exposing `/api/test` and `/health` behind the rate-limit middleware,
    /// using an in-memory backend built from `config`.
    fn test_router_with_backend(config: RateLimitConfig) -> Router {
        let backend = RateLimitBackend::in_memory(config);
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(backend_rate_limit_middleware))
            .layer(axum::Extension(backend))
    }

    /// GET request to `uri`, attributed to `client` via `X-Forwarded-For` when given.
    fn request_to(uri: &str, client: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match client {
            Some(ip) => builder.header("X-Forwarded-For", ip),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Sends `request` through a clone of `router`; clones share the same backend.
    async fn send(router: &Router, request: Request<Body>) -> Response {
        router.clone().oneshot(request).await.unwrap()
    }

    /// Reads header `name` from `response` as a string, if present and ASCII.
    fn header<'a>(response: &'a Response, name: &str) -> Option<&'a str> {
        response.headers().get(name).and_then(|value| value.to_str().ok())
    }

    #[tokio::test]
    async fn test_backend_rate_limit_disabled() {
        // `RateLimitConfig::default()` is expected to leave limiting switched off.
        let router = test_router_with_backend(RateLimitConfig::default());
        let response = send(&router, request_to("/api/test", None)).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_allows_under_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(5, 60));
        for _ in 0..3 {
            let response = send(&router, request_to("/api/test", Some("192.168.1.1"))).await;
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(header(&response, "X-RateLimit-Limit"), Some("5"));
            assert!(header(&response, "X-RateLimit-Remaining").is_some());
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_blocks_over_limit() {
        let router = test_router_with_backend(RateLimitConfig::new(2, 60));
        for i in 0..3 {
            let response = send(&router, request_to("/api/test", Some("192.168.1.100"))).await;
            if i < 2 {
                assert_eq!(response.status(), StatusCode::OK);
            } else {
                assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(header(&response, "X-RateLimit-Limit"), Some("2"));
                assert_eq!(header(&response, "X-RateLimit-Remaining"), Some("0"));
                assert!(header(&response, "Retry-After").is_some());
            }
        }
    }

    #[tokio::test]
    async fn test_backend_rate_limit_is_per_client() {
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));
        let first = send(&router, request_to("/api/test", Some("10.1.1.1"))).await;
        assert_eq!(first.status(), StatusCode::OK);
        let repeat = send(&router, request_to("/api/test", Some("10.1.1.1"))).await;
        assert_eq!(repeat.status(), StatusCode::TOO_MANY_REQUESTS);
        // A different client key has its own, untouched budget.
        let other = send(&router, request_to("/api/test", Some("10.2.2.2"))).await;
        assert_eq!(other.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_backend_rate_limit_exempt_path() {
        let router = test_router_with_backend(RateLimitConfig::new(1, 60));
        let client = Some("192.168.1.200");
        let first = send(&router, request_to("/api/test", client)).await;
        assert_eq!(first.status(), StatusCode::OK);
        // The non-exempt route is now exhausted for this client...
        let limited = send(&router, request_to("/api/test", client)).await;
        assert_eq!(limited.status(), StatusCode::TOO_MANY_REQUESTS);
        // ...but `/health` (expected in the default exempt list) still passes.
        let health = send(&router, request_to("/health", client)).await;
        assert_eq!(health.status(), StatusCode::OK);
    }

    #[test]
    fn test_extract_client_key_uses_first_forwarded_entry() {
        let request = Request::builder()
            .header("X-Forwarded-For", "203.0.113.5, 10.0.0.1")
            .header("X-Real-IP", "10.0.0.9")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_client_key(&request), "203.0.113.5");
    }

    #[test]
    fn test_extract_client_key_fallbacks() {
        let real_ip_only = Request::builder()
            .header("X-Real-IP", "10.0.0.7")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_client_key(&real_ip_only), "10.0.0.7");

        let anonymous = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_client_key(&anonymous), "unknown");
    }

    #[test]
    fn test_backend_name_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::default());
        assert_eq!(backend.backend_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_backend_cleanup_in_memory() {
        let backend = RateLimitBackend::in_memory(RateLimitConfig::new(10, 1));
        backend.cleanup_expired().await;
        // The backend must remain usable after a cleanup pass.
        assert!(backend.check_rate_limit("cleanup-client").await);
    }
}