#[cfg(feature = "redis")]
use deadpool_redis::Pool as RedisPool;
use axum::{
extract::{ConnectInfo, Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::config::RateLimitConfig;
/// Counter state kept for a single rate-limit key in the in-memory store.
#[derive(Debug, Clone)]
struct RateLimitEntry {
/// Number of requests counted so far in the current window.
count: u32,
/// Instant at which the current counting window began.
window_start: Instant,
}
/// Shared map from rate-limit key to its [`RateLimitEntry`], guarded by an async read/write lock.
type InMemoryStore = Arc<RwLock<HashMap<String, RateLimitEntry>>>;
/// Request rate limiter backed by Redis (when the `redis` feature is compiled in and a pool is
/// supplied) and by a process-local in-memory map.
///
/// Clones share the same in-memory map because it is held in an `Arc`.
#[derive(Clone)]
pub struct RateLimit {
/// Limits, window length and on/off switches read by the limiter.
config: RateLimitConfig,
/// Optional Redis connection pool; the field only exists with the `redis` feature.
#[cfg(feature = "redis")]
redis_pool: Option<RedisPool>,
/// Process-local per-key counters.
in_memory_store: InMemoryStore,
}
impl RateLimit {
/// Builds a rate limiter from `config` and an optional Redis connection pool.
///
/// The in-memory store starts out empty. This constructor is compiled only when the `redis`
/// feature is enabled.
#[cfg(feature = "redis")]
#[must_use]
pub fn new(config: RateLimitConfig, redis_pool: Option<RedisPool>) -> Self {
    let in_memory_store = Arc::new(RwLock::new(HashMap::new()));
    Self {
        config,
        redis_pool,
        in_memory_store,
    }
}
/// Builds a rate limiter from `config` when the `redis` feature is disabled.
///
/// The second argument is ignored; it exists so call sites keep the same two-argument shape as
/// the Redis-enabled constructor. The in-memory store starts out empty.
#[cfg(not(feature = "redis"))]
#[must_use]
pub fn new(config: RateLimitConfig, _redis_pool: Option<()>) -> Self {
    let in_memory_store = Arc::new(RwLock::new(HashMap::new()));
    Self {
        config,
        in_memory_store,
    }
}
/// Axum middleware that enforces the configured rate limit before running the next handler.
///
/// When limiting is disabled in the config the request is forwarded untouched. Otherwise the
/// caller is identified by an `i64` request extension (used as the user id) and/or the peer IP
/// from `ConnectInfo<SocketAddr>`, a key and limit are picked by `determine_key_and_limit`, and
/// the request is forwarded only if `check_rate_limit` accepts it.
///
/// # Errors
///
/// Returns the [`RateLimitError`] reported by the rate-limit check.
pub async fn middleware(
    State(rate_limit): State<Self>,
    request: Request,
    next: Next,
) -> Result<Response, RateLimitError> {
    if rate_limit.config.enabled {
        let caller_id: Option<i64> = request.extensions().get::<i64>().copied();
        let peer_ip = request
            .extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map(|info| info.0.ip().to_string());
        let route = request.uri().path();

        let (bucket_key, max_requests) =
            rate_limit.determine_key_and_limit(caller_id, peer_ip.as_deref(), route);

        debug!(
            key = %bucket_key,
            limit = max_requests,
            path = %route,
            user_id = ?caller_id,
            "Checking rate limit"
        );

        // `?` rejects the request here; the handler below is never reached in that case.
        rate_limit.check_rate_limit(&bucket_key, max_requests).await?;
    }

    Ok(next.run(request).await)
}
/// Chooses the counter key and the request limit for a request.
///
/// A path that starts with any entry of `config.strict_routes` is "strict" and always gets
/// `per_route_rpm`. Otherwise authenticated callers get `per_user_rpm` and everyone else gets
/// `per_ip_rpm`. In both cases the user id is preferred over the IP address for the key, and a
/// fixed `...unknown` key is used when neither is available.
fn determine_key_and_limit(
    &self,
    user_id: Option<i64>,
    ip_addr: Option<&str>,
    path: &str,
) -> (String, u32) {
    let cfg = &self.config;
    let strict = cfg
        .strict_routes
        .iter()
        .any(|prefix| path.starts_with(prefix));

    match (strict, user_id, ip_addr) {
        // Strict routes: one shared limit, keyed by whoever we can identify.
        (true, Some(uid), _) => (format!("ratelimit:route:user:{uid}"), cfg.per_route_rpm),
        (true, None, Some(ip)) => (format!("ratelimit:route:ip:{ip}"), cfg.per_route_rpm),
        (true, None, None) => ("ratelimit:route:unknown".to_string(), cfg.per_route_rpm),
        // Regular routes: the limit depends on whether the caller is authenticated.
        (false, Some(uid), _) => (format!("ratelimit:user:{uid}"), cfg.per_user_rpm),
        (false, None, Some(ip)) => (format!("ratelimit:ip:{ip}"), cfg.per_ip_rpm),
        (false, None, None) => ("ratelimit:unknown".to_string(), cfg.per_ip_rpm),
    }
}
/// Counts one request against `key` and rejects it when `limit` is exceeded.
///
/// With the `redis` feature compiled in, `config.redis_enabled` set and a pool present, Redis
/// is consulted first and its verdict is final: both "allowed" and "limit exceeded" are
/// returned as-is. The in-memory counter is used only when Redis is not in play or when the
/// Redis check fails with a backend error.
///
/// # Errors
///
/// Returns [`RateLimitError::Exceeded`] when the caller is over the limit.
// NOTE(review): the config also carries a `failure_mode` that this function does not consult;
// confirm whether a Redis outage should honor it instead of always falling back to memory.
async fn check_rate_limit(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
    #[cfg(feature = "redis")]
    if self.config.redis_enabled {
        if let Some(ref redis_pool) = self.redis_pool {
            match self.check_rate_limit_redis(redis_pool, key, limit).await {
                Ok(()) => return Ok(()),
                // Redis answered and the caller is over the limit. This must be propagated:
                // falling through to the (mostly empty) in-memory counter would let the
                // request pass and defeat the limit.
                Err(exceeded @ RateLimitError::Exceeded { .. }) => return Err(exceeded),
                Err(e @ RateLimitError::Backend(_)) => {
                    warn!(
                        error = %e,
                        key = %key,
                        "Redis rate limit check failed, falling back to in-memory"
                    );
                }
            }
        }
    }
    self.check_rate_limit_memory(key, limit).await
}
/// Counts one request against `key` in Redis and enforces `limit`.
///
/// The counter increment and its expiry are applied by one server-side script, so a key can
/// never be left without a TTL. A key that has somehow lost its TTL is given a fresh one on
/// the next request instead of counting forever.
///
/// # Errors
///
/// Returns [`RateLimitError::Backend`] when no connection can be obtained or the script fails,
/// and [`RateLimitError::Exceeded`] when the incremented count is greater than `limit`.
#[cfg(feature = "redis")]
async fn check_rate_limit_redis(
    &self,
    redis_pool: &RedisPool,
    key: &str,
    limit: u32,
) -> Result<(), RateLimitError> {
    // INCR followed by a separate EXPIRE is not atomic: if the second command never runs
    // (error, crash, cancelled request) the key lives forever and the caller stays blocked.
    // Running both inside one script closes that gap; the TTL check repairs stale keys.
    const INCR_WITH_EXPIRY: &str = "\
local count = redis.call('INCR', KEYS[1])
if count == 1 or redis.call('TTL', KEYS[1]) == -1 then
  redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return count";

    let mut conn = redis_pool.get().await.map_err(|e| {
        RateLimitError::Backend(format!("Failed to get Redis connection: {e}"))
    })?;

    let expire_secs = i64::try_from(self.config.window_secs).unwrap_or(i64::MAX);
    let count: u32 = redis::cmd("EVAL")
        .arg(INCR_WITH_EXPIRY)
        .arg(1) // number of KEYS passed to the script
        .arg(key)
        .arg(expire_secs)
        .query_async(&mut *conn)
        .await
        .map_err(|e| RateLimitError::Backend(format!("Redis rate limit script failed: {e}")))?;

    if count > limit {
        warn!(
            key = %key,
            count = count,
            limit = limit,
            window_secs = self.config.window_secs,
            "Rate limit exceeded"
        );
        return Err(RateLimitError::Exceeded {
            limit,
            window: Duration::from_secs(self.config.window_secs),
        });
    }

    debug!(
        key = %key,
        count = count,
        limit = limit,
        "Rate limit check passed (Redis)"
    );
    Ok(())
}
/// Counts one request against `key` in the process-local store and enforces `limit`.
///
/// Each key has a fixed window of `config.window_secs` seconds. The first request of a window
/// sets the count to 1; later requests in the same window increment it.
///
/// # Errors
///
/// Returns [`RateLimitError::Exceeded`] when the count for the current window is greater than
/// `limit`.
async fn check_rate_limit_memory(&self, key: &str, limit: u32) -> Result<(), RateLimitError> {
    let window_duration = Duration::from_secs(self.config.window_secs);

    let count = {
        let mut store = self.in_memory_store.write().await;
        // Read the clock only once the lock is held. A timestamp taken before waiting for
        // the lock could be older than a `window_start` written by a task that got in first.
        let now = Instant::now();
        let entry = store.entry(key.to_string()).or_insert_with(|| RateLimitEntry {
            count: 0,
            window_start: now,
        });
        if now.saturating_duration_since(entry.window_start) >= window_duration {
            entry.count = 1;
            entry.window_start = now;
        } else {
            // Saturate instead of overflowing: wrapping back to 0 would lift the limit.
            entry.count = entry.count.saturating_add(1);
        }
        entry.count
        // The write guard is released here, before any logging.
    };

    if count > limit {
        warn!(
            key = %key,
            count = count,
            limit = limit,
            window_secs = self.config.window_secs,
            "Rate limit exceeded"
        );
        return Err(RateLimitError::Exceeded {
            limit,
            window: window_duration,
        });
    }

    debug!(
        key = %key,
        count = count,
        limit = limit,
        "Rate limit check passed (in-memory)"
    );
    Ok(())
}
/// Drops every in-memory entry whose window is at least `config.window_secs` seconds old.
///
/// Returns the number of entries removed.
pub async fn cleanup_expired(&self) -> usize {
    let window_duration = Duration::from_secs(self.config.window_secs);

    let removed = {
        let mut store = self.in_memory_store.write().await;
        // Read the clock under the lock so it cannot be older than any stored `window_start`;
        // the saturating subtraction additionally rules out a panic on a reversed pair.
        let now = Instant::now();
        let before_count = store.len();
        store.retain(|_, entry| {
            now.saturating_duration_since(entry.window_start) < window_duration
        });
        before_count - store.len()
    };

    if removed > 0 {
        debug!(removed = removed, "Cleaned up expired rate limit entries");
    }
    removed
}
}
/// Errors produced by the rate limiter.
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
/// The caller made more requests than allowed within one window.
#[error("Rate limit exceeded: {limit} requests per {window:?}")]
Exceeded {
/// Maximum number of requests allowed per window.
limit: u32,
/// Length of the rate-limit window.
window: Duration,
},
/// The rate-limit backend could not be used; the string describes the failure.
#[error("Rate limit backend error: {0}")]
Backend(String),
}
/// Maps rate-limit errors onto HTTP responses.
impl IntoResponse for RateLimitError {
    /// `Exceeded` becomes `429 Too Many Requests` with `Retry-After` (the window length in
    /// seconds) and `X-RateLimit-Limit` headers. `Backend` is logged and becomes a
    /// `500 Internal Server Error` with a generic body that does not expose the cause.
    fn into_response(self) -> Response {
        match self {
            Self::Backend(msg) => {
                warn!(error = %msg, "Rate limit backend error");
                let body = "Rate limiting temporarily unavailable";
                (StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
            }
            Self::Exceeded { limit, window } => {
                let window_secs = window.as_secs();
                let headers = [
                    ("Retry-After", window_secs.to_string()),
                    ("X-RateLimit-Limit", limit.to_string()),
                ];
                let body = format!(
                    "Rate limit exceeded. Maximum {limit} requests per {window_secs} seconds."
                );
                (StatusCode::TOO_MANY_REQUESTS, headers, body).into_response()
            }
        }
    }
}
/// Unit tests for key selection, the in-memory limiter, cleanup and error formatting.
///
/// Window expiry is simulated by moving stored window starts into the past, so no test has to
/// sleep in real time.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RateLimitFailureMode;

    /// Builds an enabled, in-memory-only configuration with the given window length.
    fn memory_config(window_secs: u64) -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            per_user_rpm: 5,
            per_ip_rpm: 3,
            per_route_rpm: 2,
            window_secs,
            redis_enabled: false,
            failure_mode: RateLimitFailureMode::Closed,
            strict_routes: vec![],
        }
    }

    /// Shifts every stored window start `age` into the past, as if that much time had elapsed.
    async fn age_entries(rate_limit: &RateLimit, age: Duration) {
        let mut store = rate_limit.in_memory_store.write().await;
        for entry in store.values_mut() {
            entry.window_start = entry
                .window_start
                .checked_sub(age)
                .expect("monotonic clock must be able to represent an earlier instant");
        }
    }

    /// Returns the number of keys currently tracked in memory.
    async fn store_len(rate_limit: &RateLimit) -> usize {
        rate_limit.in_memory_store.read().await.len()
    }

    #[test]
    fn test_rate_limit_creation() {
        let config = RateLimitConfig::default();
        let rate_limit = RateLimit::new(config, None);
        assert!(rate_limit.config.enabled);
        assert_eq!(rate_limit.config.per_user_rpm, 120);
        assert_eq!(rate_limit.config.per_ip_rpm, 60);
        assert_eq!(rate_limit.config.per_route_rpm, 30);
    }

    #[test]
    fn test_determine_key_and_limit_authenticated() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);
        let (key, limit) =
            rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:user:123");
        assert_eq!(limit, 120);
    }

    #[test]
    fn test_determine_key_and_limit_anonymous() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);
        let (key, limit) = rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/posts");
        assert_eq!(key, "ratelimit:ip:192.168.1.1");
        assert_eq!(limit, 60);
    }

    #[test]
    fn test_determine_key_and_limit_unknown_caller() {
        // Neither a user id nor an IP: the shared fallback key with the per-IP limit is used.
        let rate_limit = RateLimit::new(memory_config(60), None);
        let (key, limit) = rate_limit.determine_key_and_limit(None, None, "/posts");
        assert_eq!(key, "ratelimit:unknown");
        assert_eq!(limit, 3);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_authenticated() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);
        let (key, limit) =
            rate_limit.determine_key_and_limit(Some(123), Some("192.168.1.1"), "/login");
        assert_eq!(key, "ratelimit:route:user:123");
        assert_eq!(limit, 30);
    }

    #[test]
    fn test_determine_key_and_limit_strict_route_anonymous() {
        let rate_limit = RateLimit::new(RateLimitConfig::default(), None);
        let (key, limit) =
            rate_limit.determine_key_and_limit(None, Some("192.168.1.1"), "/register");
        assert_eq!(key, "ratelimit:route:ip:192.168.1.1");
        assert_eq!(limit, 30);
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_within_limit() {
        let rate_limit = RateLimit::new(memory_config(60), None);
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 5).await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_exceeded() {
        let rate_limit = RateLimit::new(memory_config(60), None);
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(matches!(result, Err(RateLimitError::Exceeded { .. })));
    }

    #[tokio::test]
    async fn test_in_memory_rate_limit_window_reset() {
        let rate_limit = RateLimit::new(memory_config(1), None);
        for _ in 0..3 {
            let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
            assert!(result.is_ok());
        }
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_err());

        // Pretend two seconds passed: the one-second window is over and the count restarts.
        age_entries(&rate_limit, Duration::from_secs(2)).await;
        let result = rate_limit.check_rate_limit_memory("test_key", 3).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let rate_limit = RateLimit::new(memory_config(1), None);
        for i in 0..5 {
            let key = format!("test_key_{i}");
            let _ = rate_limit.check_rate_limit_memory(&key, 10).await;
        }
        assert_eq!(store_len(&rate_limit).await, 5);

        // Pretend two seconds passed so every one-second window has expired.
        age_entries(&rate_limit, Duration::from_secs(2)).await;
        let removed = rate_limit.cleanup_expired().await;
        assert_eq!(removed, 5);
        assert_eq!(store_len(&rate_limit).await, 0);
    }

    #[test]
    fn test_rate_limit_error_display() {
        let error = RateLimitError::Exceeded {
            limit: 100,
            window: Duration::from_secs(60),
        };
        assert!(error.to_string().contains("100"));
        assert!(error.to_string().contains("60"));

        let error = RateLimitError::Backend("Redis connection failed".to_string());
        assert!(error.to_string().contains("Redis connection failed"));
    }
}