#[cfg(feature = "cache")]
use deadpool_redis::Pool as RedisPool;
#[cfg(feature = "cache")]
use std::ops::DerefMut;
use std::sync::Arc;
use axum::{
body::Body,
extract::{Request, State},
middleware::Next,
response::Response,
};
#[cfg(feature = "cache")]
use axum::http::{header::HeaderValue, HeaderName};
use crate::{config::RateLimitConfig, error::Error};
#[cfg(feature = "cache")]
use crate::middleware::Claims;
use super::route_matcher::CompiledRoutePatterns;
#[cfg(feature = "cache")]
use super::route_matcher::normalize_path;
#[cfg(feature = "cache")]
use tracing::{debug, warn};
/// Shared state for the rate-limiting middleware.
///
/// The value is `Clone` so it can be handed to axum as `State`; the compiled
/// route patterns sit behind an `Arc`, so cloning does not recompile them.
#[derive(Clone)]
pub struct RateLimit {
/// Rate-limit settings (default per-user / per-client limits, window length,
/// per-route overrides). Only read when the `cache` feature is enabled.
#[cfg_attr(not(feature = "cache"), allow(dead_code))]
config: RateLimitConfig,
/// Per-route patterns compiled once from `config.routes` in `new`.
#[cfg_attr(not(feature = "cache"), allow(dead_code))]
route_patterns: Arc<CompiledRoutePatterns>,
/// Redis connection pool used to store the counters. Only present when the
/// `cache` feature is enabled; `new` always stores `Some`.
#[cfg(feature = "cache")]
redis_pool: Option<RedisPool>,
}
/// Outcome of a rate-limit check that was allowed through.
///
/// Used by `add_rate_limit_headers` to build the `x-ratelimit-*` response headers.
#[cfg(feature = "cache")]
struct RateLimitResult {
/// Maximum number of requests allowed in the current window.
limit: u32,
/// Number of requests counted so far in the current window (including this one).
count: u32,
/// Seconds until the current window's counter expires.
reset_secs: u64,
}
impl RateLimit {
/// Builds a Redis-backed rate limiter.
///
/// The per-route entries in `config.routes` are compiled once here and shared
/// through an `Arc`, so cloning the limiter is cheap.
#[cfg(feature = "cache")]
pub fn new(config: RateLimitConfig, redis_pool: RedisPool) -> Self {
    let compiled = Arc::new(CompiledRoutePatterns::compile(&config.routes));
    Self {
        route_patterns: compiled,
        redis_pool: Some(redis_pool),
        config,
    }
}
/// Builds a rate limiter for builds without the `cache` feature.
///
/// Route patterns are still compiled from `config.routes`, but without a
/// Redis pool the middleware passes every request straight through.
#[cfg(not(feature = "cache"))]
pub fn new(config: RateLimitConfig) -> Self {
    let compiled = Arc::new(CompiledRoutePatterns::compile(&config.routes));
    Self {
        route_patterns: compiled,
        config,
    }
}
/// Axum middleware entry point.
///
/// With the `cache` feature enabled, the request is checked against the
/// applicable limit before it is forwarded; an allowed request's response is
/// decorated with the `x-ratelimit-*` headers. Any error from the check
/// (including `Error::RateLimitExceeded`) is returned without calling `next`.
///
/// Without the `cache` feature the request is forwarded unchanged.
pub async fn middleware(
    #[cfg_attr(not(feature = "cache"), allow(unused_variables))] State(rate_limit): State<Self>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, Error> {
    #[cfg(feature = "cache")]
    let response = {
        // Claims are cloned out of the extensions so the request can be moved
        // into `next.run` afterwards.
        let caller = request.extensions().get::<Claims>().cloned();
        let http_method = request.method().as_str();
        let uri_path = request.uri().path();

        let outcome = rate_limit
            .check_rate_limit_with_route(http_method, uri_path, caller.as_ref())
            .await?;

        let mut downstream = next.run(request).await;
        Self::add_rate_limit_headers(&mut downstream, &outcome);
        downstream
    };

    #[cfg(not(feature = "cache"))]
    let response = next.run(request).await;

    Ok(response)
}
/// Picks the limit and counter key that apply to a request, then counts it.
///
/// Resolution order:
/// 1. A matching per-route configuration wins. Its counter is keyed per user
///    when the route is `per_user` and claims are present, otherwise it is a
///    single counter shared by all callers of that path.
/// 2. Otherwise, authenticated callers are counted per subject, with the
///    limit chosen by whether the claims identify a user or a client.
/// 3. Otherwise (no route match, no claims) nothing is counted; a warning is
///    logged and a zero-count result is returned.
///
/// # Errors
///
/// Propagates whatever `check_and_increment` returns, including
/// `Error::RateLimitExceeded`.
#[cfg(feature = "cache")]
async fn check_rate_limit_with_route(
    &self,
    method: &str,
    path: &str,
    claims: Option<&Claims>,
) -> Result<RateLimitResult, Error> {
    let normalized_path = normalize_path(path);
    let window = self.config.window_secs;

    if let Some(route_config) = self.route_patterns.match_route(method, &normalized_path) {
        debug!(
            "Using per-route rate limit for {} {}: {} rpm",
            method, normalized_path, route_config.requests_per_minute
        );

        // Per-user keys need an identity; anonymous callers share the global counter.
        let key = match claims {
            Some(claims) if route_config.per_user => {
                format!("route:{}:user:{}", normalized_path, claims.sub)
            }
            _ => format!("route:{}:global", normalized_path),
        };

        return self
            .check_and_increment(&key, route_config.requests_per_minute, window)
            .await;
    }

    match claims {
        Some(claims) => {
            let (key, limit) = if claims.is_user() {
                (
                    format!("ratelimit:user:{}", claims.sub),
                    self.config.per_user_rpm,
                )
            } else if claims.is_client() {
                (
                    format!("ratelimit:client:{}", claims.sub),
                    self.config.per_client_rpm,
                )
            } else {
                // Unrecognised token kind: fall back to the per-user limit.
                (
                    format!("ratelimit:unknown:{}", claims.sub),
                    self.config.per_user_rpm,
                )
            };
            self.check_and_increment(&key, limit, window).await
        }
        None => {
            warn!("Rate limit middleware called without JWT claims and no route-specific limit");
            Ok(RateLimitResult {
                limit: self.config.per_user_rpm,
                count: 0,
                reset_secs: window,
            })
        }
    }
}
/// Increments the fixed-window counter stored at `key` and enforces `limit`.
///
/// The counter is a Redis integer that is created by the first `INCR` of a
/// window and given a `window_secs` expiry. Because `INCR` and `EXPIRE` are
/// separate commands, a failure between them would leave a counter that never
/// expires; to recover from that, the expiry is re-applied whenever `TTL`
/// reports that the key has none.
///
/// # Parameters
///
/// * `key` - Redis key of the counter.
/// * `limit` - Maximum number of requests allowed per window.
/// * `window_secs` - Window length in seconds.
///
/// # Returns
///
/// The limit, the post-increment count and the seconds until the counter
/// expires (falling back to `window_secs` when the TTL is unknown).
///
/// # Errors
///
/// * `Error::Internal` if no Redis pool is configured.
/// * `Error::Redis` if a connection cannot be obtained, or if `INCR` or the
///   initial `EXPIRE` fails.
/// * `Error::RateLimitExceeded` if the post-increment count exceeds `limit`.
#[cfg(feature = "cache")]
async fn check_and_increment(
    &self,
    key: &str,
    limit: u32,
    window_secs: u64,
) -> Result<RateLimitResult, Error> {
    let redis_pool = self
        .redis_pool
        .as_ref()
        .ok_or_else(|| Error::Internal("Redis pool not configured".to_string()))?;
    let mut conn = redis_pool.get().await.map_err(|e| {
        let redis_err = redis::RedisError::from((
            redis::ErrorKind::IoError,
            "Failed to get Redis connection",
            e.to_string(),
        ));
        Error::Redis(Box::new(redis_err))
    })?;

    // EXPIRE takes a signed integer; clamp so an oversized window cannot wrap
    // into a negative value (the cast below is lossless after the `min`).
    let expire_secs = window_secs.min(i64::MAX as u64) as i64;

    let count: u32 = redis::cmd("INCR")
        .arg(key)
        .query_async(conn.deref_mut())
        .await?;

    // The first hit of a window creates the key, so it starts the expiry clock.
    if count == 1 {
        let _: () = redis::cmd("EXPIRE")
            .arg(key)
            .arg(expire_secs)
            .query_async(conn.deref_mut())
            .await?;
    }

    // The TTL only feeds the reset header, so a failure here is not fatal.
    let ttl_reply: Result<i64, redis::RedisError> = redis::cmd("TTL")
        .arg(key)
        .query_async(conn.deref_mut())
        .await;
    let ttl = match ttl_reply {
        Ok(secs) => Some(secs),
        Err(e) => {
            warn!("Failed to read TTL for rate limit key {}: {}", key, e);
            None
        }
    };

    let reset_secs = match ttl {
        // -1 means the key exists without an expiry: an earlier EXPIRE never
        // ran. Re-apply it so the counter cannot stay stuck forever.
        Some(-1) => {
            warn!(
                "Rate limit key {} has no expiry; re-applying {}s window",
                key, window_secs
            );
            let healed: Result<(), redis::RedisError> = redis::cmd("EXPIRE")
                .arg(key)
                .arg(expire_secs)
                .query_async(conn.deref_mut())
                .await;
            if let Err(e) = healed {
                warn!("Failed to re-apply expiry for rate limit key {}: {}", key, e);
            }
            window_secs
        }
        Some(secs) if secs > 0 => secs.unsigned_abs(),
        // Unknown TTL, or the key expired between INCR and TTL.
        _ => window_secs,
    };

    if count > limit {
        warn!(
            "Rate limit exceeded for {}: {} requests (limit: {})",
            key, count, limit
        );
        return Err(Error::RateLimitExceeded);
    }

    Ok(RateLimitResult {
        limit,
        count,
        reset_secs,
    })
}
/// Writes the `x-ratelimit-limit`, `x-ratelimit-remaining` and
/// `x-ratelimit-reset` headers onto `response`.
///
/// * `x-ratelimit-remaining` is `limit - count`, floored at zero.
/// * `x-ratelimit-reset` is a Unix timestamp (seconds): the current time plus
///   `result.reset_secs`. It is `0` if the system clock reads earlier than the
///   Unix epoch.
///
/// Existing headers with the same names are replaced.
#[cfg(feature = "cache")]
fn add_rate_limit_headers(response: &mut Response, result: &RateLimitResult) {
    let remaining = result.limit.saturating_sub(result.count);

    // Saturate rather than overflow if `reset_secs` is absurdly large.
    let reset_timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_secs().saturating_add(result.reset_secs))
        .unwrap_or(0);

    // Integer -> HeaderValue conversions are infallible, so no string
    // round-trip or error branch is needed.
    let headers = response.headers_mut();
    headers.insert(
        HeaderName::from_static("x-ratelimit-limit"),
        HeaderValue::from(result.limit),
    );
    headers.insert(
        HeaderName::from_static("x-ratelimit-remaining"),
        HeaderValue::from(remaining),
    );
    headers.insert(
        HeaderName::from_static("x-ratelimit-reset"),
        HeaderValue::from(reset_timestamp),
    );
}
}
#[cfg(test)]
mod tests {
    // `RateLimit::new` needs a live Redis pool when the `cache` feature is
    // enabled, so these tests only have bodies in builds without it.
    #[cfg(not(feature = "cache"))]
    use super::{RateLimit, RateLimitConfig};
    #[cfg(not(feature = "cache"))]
    use crate::config::RouteRateLimitConfig;
    #[cfg(not(feature = "cache"))]
    use std::collections::HashMap;

    /// Returns the configuration shared by the tests, with the given per-route overrides.
    #[cfg(not(feature = "cache"))]
    fn config_with_routes(routes: HashMap<String, RouteRateLimitConfig>) -> RateLimitConfig {
        RateLimitConfig {
            per_user_rpm: 200,
            per_client_rpm: 1000,
            window_secs: 60,
            routes,
            auto_apply: true,
            trust_forwarded_headers: false,
        }
    }

    /// Construction succeeds with no per-route overrides.
    #[test]
    fn test_rate_limit_creation() {
        #[cfg(not(feature = "cache"))]
        {
            let _rate_limit = RateLimit::new(config_with_routes(HashMap::new()));
        }
    }

    /// A configured route ends up in the compiled pattern set.
    #[test]
    fn test_rate_limit_with_routes() {
        #[cfg(not(feature = "cache"))]
        {
            let heavy_route = RouteRateLimitConfig {
                requests_per_minute: 10,
                burst_size: 2,
                per_user: true,
            };
            let routes = HashMap::from([("/api/v1/heavy".to_string(), heavy_route)]);

            let rate_limit = RateLimit::new(config_with_routes(routes));
            assert!(!rate_limit.route_patterns.is_empty());
        }
    }
}