use axum::{
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Extension, Json,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use super::trusted_proxy::extract_client_ip_from_headers;
use crate::redis::RedisPool;
/// JWT payload — only the claims this module needs for rate-limit keying.
/// `sub` is the subject (user id) used to build the `user:<sub>` key;
/// `exp` is the expiry timestamp, checked by jsonwebtoken's default
/// `Validation` during decode.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String, exp: usize,
}
/// A per-key (user or IP) in-memory limiter together with the last time it
/// was touched, so the periodic cleanup task can evict stale entries.
struct UserRateLimiterEntry {
// Direct (non-keyed) governor limiter holding this key's token bucket.
limiter: RateLimiter<NotKeyed, InMemoryState, DefaultClock>,
// Updated on every check; compared against `stale_threshold` for eviction.
last_access: Instant,
}
/// Shared rate-limiter state, cloned into each request via axum `Extension`.
/// All fields are cheap to clone (Arcs, small values, Option handles).
#[derive(Clone)]
pub struct RateLimiterState {
// Process-wide limiter applied before any per-key check (in-memory mode).
global_limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
// Lazily-populated map of per-key limiters, pruned by the cleanup task.
user_limiters: Arc<RwLock<HashMap<String, UserRateLimiterEntry>>>,
// Quota used when creating a new per-key limiter.
per_user_quota: Quota,
// Numeric limits kept separately for the Redis fixed-window comparisons.
per_user_limit: u32,
global_limit: u32,
// Entries idle longer than this are evicted; also the cleanup interval.
stale_threshold: Duration,
// HMAC secret for JWT validation; None disables user-based keying.
jwt_secret: Option<String>,
// When Some, counters live in Redis and are shared across instances.
redis: Option<RedisPool>,
}
impl RateLimiterState {
    /// Build a purely in-memory limiter. The per-user budget comes from the
    /// `RATE_LIMIT_PER_USER` env var (default: 100 requests/minute).
    pub fn new(requests_per_minute: u32) -> Self {
        Self::new_internal(requests_per_minute, None)
    }

    /// Build a Redis-backed limiter so counters are shared across instances.
    /// Individual Redis failures degrade to the in-memory limiters.
    pub fn with_redis(requests_per_minute: u32, redis: RedisPool) -> Self {
        Self::new_internal(requests_per_minute, Some(redis))
    }

    fn new_internal(requests_per_minute: u32, redis: Option<RedisPool>) -> Self {
        // Missing or unparseable RATE_LIMIT_PER_USER silently falls back to 100.
        let per_user_limit: u32 = std::env::var("RATE_LIMIT_PER_USER")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(100);
        Self::new_with_config_internal(requests_per_minute, per_user_limit, redis)
    }

    /// Test-only constructor that bypasses environment configuration.
    #[cfg(test)]
    fn new_with_config(
        requests_per_minute: u32,
        per_user_limit: u32,
        redis: Option<RedisPool>,
    ) -> Self {
        Self::new_with_config_internal(requests_per_minute, per_user_limit, redis)
    }

    fn new_with_config_internal(
        requests_per_minute: u32,
        per_user_limit: u32,
        redis: Option<RedisPool>,
    ) -> Self {
        // A zero quota is both meaningless ("reject everything") and invalid
        // for NonZeroU32, so coerce it to a sane default instead of panicking.
        let global_limit = if requests_per_minute == 0 {
            tracing::warn!("requests_per_minute was 0, defaulting to 60");
            60
        } else {
            requests_per_minute
        };
        let global_quota = Quota::per_minute(
            NonZeroU32::new(global_limit).expect("global_limit is non-zero by construction"),
        );
        let global_limiter = Arc::new(RateLimiter::direct(global_quota));
        let per_user_limit = if per_user_limit == 0 {
            tracing::warn!("per_user_limit was 0, defaulting to 100");
            100
        } else {
            per_user_limit
        };
        let per_user_quota = Quota::per_minute(
            NonZeroU32::new(per_user_limit).expect("per_user_limit is non-zero by construction"),
        );
        let cleanup_interval_secs: u64 = std::env::var("RATE_LIMIT_CLEANUP_INTERVAL_SECS")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(300);
        // Without JWT_SECRET, tokens cannot be validated and all requests are
        // keyed by client IP (see get_rate_limit_key).
        let jwt_secret = std::env::var("JWT_SECRET").ok();
        let distributed = redis.is_some();
        tracing::info!(
            "Rate limiter initialized: global={}/min, per_user={}/min, cleanup_interval={}s, distributed={}",
            global_limit,
            per_user_limit,
            cleanup_interval_secs,
            distributed
        );
        let state = Self {
            global_limiter,
            user_limiters: Arc::new(RwLock::new(HashMap::new())),
            per_user_quota,
            per_user_limit,
            global_limit,
            stale_threshold: Duration::from_secs(cleanup_interval_secs),
            jwt_secret,
            redis,
        };
        // Spawn the periodic eviction task for the in-memory map.
        // NOTE(review): in distributed mode the map is still populated on
        // Redis failures but never pruned — consider spawning unconditionally.
        if state.redis.is_none() {
            // Outside a Tokio runtime (e.g. sync unit tests) there is nothing
            // to spawn on; the map then simply never gets pruned.
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                let cleanup_state = state.clone();
                handle.spawn(async move {
                    cleanup_state.cleanup_loop().await;
                });
            }
        }
        state
    }

    /// Whether counters are shared across instances via Redis.
    pub fn is_distributed(&self) -> bool {
        self.redis.is_some()
    }

    /// Consume one token from the in-memory global limiter; `true` = allowed.
    pub fn check(&self) -> bool {
        self.global_limiter.check().is_ok()
    }

    /// Check both the global budget and the budget for `key` for one request.
    /// Uses Redis when configured, otherwise the in-memory limiters.
    pub async fn check_user(&self, key: &str) -> bool {
        if let Some(ref redis) = self.redis {
            return self.check_user_redis(redis, key).await;
        }
        self.check_user_in_memory(key).await
    }

    /// Fixed-window check against Redis: increment a counter with a 60s
    /// expiry and compare it to the configured limit. Each Redis error
    /// degrades that single check to its in-memory equivalent.
    async fn check_user_redis(&self, redis: &RedisPool, key: &str) -> bool {
        let global_key = "ratelimit:global";
        match redis.increment_with_expiry(global_key, 60).await {
            Ok(count) => {
                if count > self.global_limit as i64 {
                    tracing::debug!(
                        key = global_key,
                        count = count,
                        limit = self.global_limit,
                        "Global rate limit exceeded (Redis)"
                    );
                    return false;
                }
            }
            Err(e) => {
                tracing::warn!(
                    "Redis global rate limit check failed: {}, falling back to in-memory",
                    e
                );
                if self.global_limiter.check().is_err() {
                    return false;
                }
            }
        }
        let user_key = format!("ratelimit:{}", key);
        match redis.increment_with_expiry(&user_key, 60).await {
            Ok(count) => {
                if count > self.per_user_limit as i64 {
                    tracing::debug!(
                        key = user_key,
                        count = count,
                        limit = self.per_user_limit,
                        "Per-user rate limit exceeded (Redis)"
                    );
                    return false;
                }
                true
            }
            Err(e) => {
                tracing::warn!(
                    "Redis user rate limit check failed: {}, falling back to in-memory",
                    e
                );
                // BUGFIX: consult only the per-key limiter here. This
                // request's global budget was already accounted for above
                // (via Redis, or via the in-memory fallback if Redis failed);
                // re-running the full in-memory check would charge a second
                // global token for the same request.
                self.check_user_limiter(key).await
            }
        }
    }

    /// In-memory check: global token first, then the per-key limiter.
    async fn check_user_in_memory(&self, key: &str) -> bool {
        if self.global_limiter.check().is_err() {
            return false;
        }
        self.check_user_limiter(key).await
    }

    /// Consume one token from the per-key limiter only (no global check),
    /// lazily creating the limiter the first time a key is seen.
    async fn check_user_limiter(&self, key: &str) -> bool {
        let mut limiters = self.user_limiters.write().await;
        // get_mut + insert (rather than the entry API) avoids allocating an
        // owned key String on the hot path where the entry already exists.
        if let Some(entry) = limiters.get_mut(key) {
            entry.last_access = Instant::now();
            entry.limiter.check().is_ok()
        } else {
            let limiter = RateLimiter::direct(self.per_user_quota);
            let result = limiter.check().is_ok();
            limiters.insert(
                key.to_string(),
                UserRateLimiterEntry {
                    limiter,
                    last_access: Instant::now(),
                },
            );
            result
        }
    }

    /// Extract the authenticated subject (`sub` claim) from an
    /// `Authorization: Bearer <jwt>` header, validated with `jwt_secret`.
    /// Any missing or invalid piece yields `None`, so callers key by IP.
    fn extract_user_id(&self, headers: &HeaderMap) -> Option<String> {
        let jwt_secret = self.jwt_secret.as_ref()?;
        let auth_header = headers.get("Authorization")?.to_str().ok()?;
        let token = auth_header.strip_prefix("Bearer ")?;
        // Validation::default() includes expiry (`exp`) checking.
        let validation = Validation::default();
        let token_data =
            decode::<Claims>(token, &DecodingKey::from_secret(jwt_secret.as_bytes()), &validation)
                .ok()?;
        Some(token_data.claims.sub)
    }

    /// Rate-limit key for a request: `user:<sub>` for authenticated callers,
    /// otherwise `ip:<client-ip>` derived via the trusted-proxy helper.
    pub fn get_rate_limit_key(&self, headers: &HeaderMap) -> String {
        if let Some(user_id) = self.extract_user_id(headers) {
            return format!("user:{}", user_id);
        }
        let ip = extract_client_ip_from_headers(headers);
        format!("ip:{}", ip)
    }

    /// Background task: periodically evict per-key limiters that have been
    /// idle longer than `stale_threshold`. Runs until the task is dropped.
    async fn cleanup_loop(&self) {
        let cleanup_interval = self.stale_threshold;
        loop {
            tokio::time::sleep(cleanup_interval).await;
            self.cleanup_stale_entries().await;
        }
    }

    /// One eviction pass over the per-key limiter map.
    async fn cleanup_stale_entries(&self) {
        let mut limiters = self.user_limiters.write().await;
        let now = Instant::now();
        let initial_count = limiters.len();
        limiters.retain(|_, entry| now.duration_since(entry.last_access) < self.stale_threshold);
        // retain only removes entries, so this subtraction cannot underflow.
        let removed_count = initial_count - limiters.len();
        if removed_count > 0 {
            tracing::debug!(
                "Cleaned up {} stale rate limiter entries, {} remaining",
                removed_count,
                limiters.len()
            );
        }
    }

    /// Number of live per-key limiters (used by tests and metrics).
    pub async fn active_entries_count(&self) -> usize {
        self.user_limiters.read().await.len()
    }
}
/// Axum middleware enforcing the configured rate limits.
///
/// Resolves the caller's rate-limit key (JWT subject, else client IP),
/// consumes one token for it, and either forwards the request down the
/// stack or short-circuits with a 429 response.
pub async fn rate_limit_middleware(
    Extension(limiter): Extension<RateLimiterState>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    let key = limiter.get_rate_limit_key(&headers);
    if limiter.check_user(&key).await {
        return Ok(next.run(request).await);
    }
    tracing::warn!(
        rate_limit_key = %key,
        path = %request.uri().path(),
        authenticated = key.starts_with("user:"),
        "Rate limit exceeded"
    );
    Err(rate_limited_response().into_response())
}
/// Build the standard 429 response: a `Retry-After` header plus a
/// structured JSON error body matching the API's error envelope.
fn rate_limited_response() -> impl IntoResponse {
    let body = Json(json!({
        "error": {
            "code": "RATE_LIMIT_EXCEEDED",
            "message": "Too many requests. Please try again later.",
            "retry_after_seconds": 60
        }
    }));
    let retry_header = [("Retry-After", "60")];
    (StatusCode::TOO_MANY_REQUESTS, retry_header, body)
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly-built limiter should admit the first request.
#[test]
fn test_rate_limiter_creation() {
let limiter = RateLimiterState::new(60);
assert!(limiter.check());
}
// With a global quota of 2/min, the third immediate check is rejected.
#[test]
fn test_rate_limiter_global_limit() {
let limiter = RateLimiterState::new(2);
assert!(limiter.check());
assert!(limiter.check());
assert!(!limiter.check());
}
// Per-user quota of 2/min: user1 is cut off after two requests while
// user2 (a separate bucket) is still admitted.
#[tokio::test]
async fn test_per_user_rate_limiter() {
let limiter = RateLimiterState::new_with_config(1000, 2, None);
assert!(limiter.check_user("user:user1").await);
assert!(limiter.check_user("user:user1").await);
assert!(!limiter.check_user("user:user1").await);
assert!(limiter.check_user("user:user2").await);
}
// Same per-key isolation as above, but keyed by IP instead of user id.
#[tokio::test]
async fn test_ip_rate_limiter() {
let limiter = RateLimiterState::new_with_config(1000, 2, None);
assert!(limiter.check_user("ip:192.168.1.1").await);
assert!(limiter.check_user("ip:192.168.1.1").await);
assert!(!limiter.check_user("ip:192.168.1.1").await);
assert!(limiter.check_user("ip:192.168.1.2").await);
}
// Without a JWT, the key falls back to the forwarded client IP.
#[test]
fn test_get_rate_limit_key_ip_fallback() {
let limiter = RateLimiterState::new(60);
let mut headers = HeaderMap::new();
headers.insert("X-Forwarded-For", "192.168.1.100".parse().unwrap());
let key = limiter.get_rate_limit_key(&headers);
assert_eq!(key, "ip:192.168.1.100");
}
// X-Real-IP is also honored by the client-IP extraction helper.
#[test]
fn test_get_rate_limit_key_x_real_ip() {
let limiter = RateLimiterState::new(60);
let mut headers = HeaderMap::new();
headers.insert("X-Real-IP", "10.0.0.50".parse().unwrap());
let key = limiter.get_rate_limit_key(&headers);
assert_eq!(key, "ip:10.0.0.50");
}
// With a multi-hop X-Forwarded-For list, the first (client) IP is used.
#[test]
fn test_get_rate_limit_key_forwarded_for_multiple() {
let limiter = RateLimiterState::new(60);
let mut headers = HeaderMap::new();
headers.insert(
"X-Forwarded-For",
"203.0.113.195, 70.41.3.18, 150.172.238.178".parse().unwrap(),
);
let key = limiter.get_rate_limit_key(&headers);
assert_eq!(key, "ip:203.0.113.195");
}
// No identifying headers at all: the key degrades to "ip:unknown".
#[test]
fn test_get_rate_limit_key_unknown() {
let limiter = RateLimiterState::new(60);
let headers = HeaderMap::new();
let key = limiter.get_rate_limit_key(&headers);
assert_eq!(key, "ip:unknown");
}
// Each distinct key creates exactly one map entry; repeats reuse it.
#[tokio::test]
async fn test_active_entries_count() {
let limiter = RateLimiterState::new_with_config(1000, 100, None);
assert_eq!(limiter.active_entries_count().await, 0);
limiter.check_user("user:user1").await;
assert_eq!(limiter.active_entries_count().await, 1);
limiter.check_user("user:user2").await;
assert_eq!(limiter.active_entries_count().await, 2);
limiter.check_user("user:user1").await;
assert_eq!(limiter.active_entries_count().await, 2);
}
// The global budget (1/min here) rejects requests even when the
// per-user budget for a fresh key still has tokens.
#[tokio::test]
async fn test_global_limit_takes_precedence() {
let limiter = RateLimiterState::new_with_config(1, 100, None);
assert!(limiter.check_user("user:user1").await);
assert!(!limiter.check_user("user:user2").await);
}
}