use axum::{
body::Body,
extract::{ConnectInfo, State},
http::{HeaderName, HeaderValue, Request, StatusCode},
middleware::Next,
response::Response,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::warn;
/// Settings used to construct a rate limiter.
///
/// `PartialEq`/`Eq` are derived so whole configurations can be compared directly instead of
/// field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests allowed per minute.
    ///
    /// `GlobalRateLimiter::new` substitutes 100 when this is zero.
    pub requests_per_minute: u32,
    /// Maximum burst size handed to the quota.
    ///
    /// `GlobalRateLimiter::new` substitutes 200 when this is zero.
    pub burst: u32,
    /// Whether limits are meant to apply per client IP.
    ///
    /// NOTE(review): nothing in this file reads this flag; confirm where it is honoured.
    pub per_ip: bool,
    /// Whether limits are meant to apply per endpoint.
    ///
    /// NOTE(review): nothing in this file reads this flag; confirm where it is honoured.
    pub per_endpoint: bool,
}
impl Default for RateLimitConfig {
    /// Returns the stock configuration: 100 requests per minute, a burst of 200,
    /// per-IP limiting switched on and per-endpoint limiting switched off.
    fn default() -> Self {
        let (requests_per_minute, burst) = (100, 200);
        let (per_ip, per_endpoint) = (true, false);
        Self {
            requests_per_minute,
            burst,
            per_ip,
            per_endpoint,
        }
    }
}
/// Snapshot of the quota figures reported to clients in the rate-limit response headers.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitQuota {
    /// Number of requests allowed per one-minute window.
    pub limit: u32,
    /// Requests left in the current window at the moment the snapshot was taken.
    pub remaining: u32,
    /// Unix timestamp, in whole seconds, at which the current window ends.
    pub reset: u64,
}
/// A single, process-wide rate limiter plus the bookkeeping used to report quota figures.
///
/// Admission decisions come from the `governor` limiter; the window start and remaining
/// counter are tracked separately and only feed the figures returned by `get_quota_info`.
pub struct GlobalRateLimiter {
/// Un-keyed `governor` limiter that decides whether a request is admitted.
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
/// Configuration this limiter was built from; its `requests_per_minute` is the reported limit.
config: RateLimitConfig,
/// Wall-clock start of the current 60-second reporting window.
window_start: Arc<Mutex<SystemTime>>,
/// Requests left in the current reporting window; refilled when the window rolls over.
remaining_counter: Arc<Mutex<u32>>,
}
impl GlobalRateLimiter {
    /// Builds a limiter from `config`.
    ///
    /// A zero `requests_per_minute` or `burst` cannot be expressed as a quota, so they fall
    /// back to 100 and 200 respectively. The effective values are written back into the stored
    /// configuration so that [`get_quota_info`](Self::get_quota_info) reports the limit that is
    /// actually enforced instead of the zero that was passed in.
    pub fn new(mut config: RateLimitConfig) -> Self {
        let requests_per_minute = NonZeroU32::new(config.requests_per_minute)
            .unwrap_or(NonZeroU32::new(100).expect("constant 100 is non-zero"));
        let burst = NonZeroU32::new(config.burst)
            .unwrap_or(NonZeroU32::new(200).expect("constant 200 is non-zero"));

        // Keep the reported figures in step with what the limiter enforces.
        config.requests_per_minute = requests_per_minute.get();
        config.burst = burst.get();

        let quota = Quota::per_minute(requests_per_minute).allow_burst(burst);
        Self {
            limiter: Arc::new(RateLimiter::direct(quota)),
            window_start: Arc::new(Mutex::new(SystemTime::now())),
            remaining_counter: Arc::new(Mutex::new(config.requests_per_minute)),
            config,
        }
    }

    /// Returns `true` when the underlying limiter admits one more request right now.
    pub fn check_rate_limit(&self) -> bool {
        self.limiter.check().is_ok()
    }

    /// Returns the current quota figures and consumes one unit of the reporting counter.
    ///
    /// The counter is tracked per 60-second wall-clock window and is independent of the
    /// limiter consulted by [`check_rate_limit`](Self::check_rate_limit). Every call counts
    /// as one request: the returned `remaining` is the value *before* this call's decrement.
    /// A poisoned lock is recovered rather than propagated, so this never panics on poisoning.
    pub fn get_quota_info(&self) -> RateLimitQuota {
        const WINDOW: Duration = Duration::from_secs(60);

        let now = SystemTime::now();
        let mut window_start = self
            .window_start
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let mut remaining = self
            .remaining_counter
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // A clock that stepped backwards yields an error here; treat it as "no time elapsed".
        let elapsed = now.duration_since(*window_start).unwrap_or(Duration::ZERO);
        if elapsed >= WINDOW {
            *window_start = now;
            *remaining = self.config.requests_per_minute;
        }

        let reported = *remaining;
        *remaining = reported.saturating_sub(1);

        let window_start_secs = window_start
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs();

        RateLimitQuota {
            limit: self.config.requests_per_minute,
            remaining: reported,
            reset: window_start_secs + WINDOW.as_secs(),
        }
    }
}
pub async fn rate_limit_middleware(
State(state): State<crate::HttpServerState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request<Body>,
next: Next,
) -> Response {
let quota_info = if let Some(limiter) = &state.rate_limiter {
if !limiter.check_rate_limit() {
warn!("Rate limit exceeded for IP: {}", addr.ip());
let mut response = Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::from("Too Many Requests"))
.unwrap_or_else(|_| Response::new(Body::from("Too Many Requests")));
let retry_after = HeaderValue::from_static("60");
response
.headers_mut()
.insert(HeaderName::from_static("retry-after"), retry_after);
let quota = limiter.get_quota_info();
if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
response
.headers_mut()
.insert(HeaderName::from_static("x-rate-limit-limit"), limit_value);
}
if let Ok(remaining_value) = HeaderValue::from_str("0") {
response
.headers_mut()
.insert(HeaderName::from_static("x-rate-limit-remaining"), remaining_value);
}
if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
response
.headers_mut()
.insert(HeaderName::from_static("x-rate-limit-reset"), reset_value);
}
return response;
}
Some(limiter.get_quota_info())
} else {
tracing::debug!("No rate limiter configured, allowing request");
None
};
let mut response = next.run(req).await;
if let Some(quota) = quota_info {
let limit_name = HeaderName::from_static("x-rate-limit-limit");
if let Ok(limit_value) = HeaderValue::from_str("a.limit.to_string()) {
response.headers_mut().insert(limit_name, limit_value);
}
let remaining_name = HeaderName::from_static("x-rate-limit-remaining");
if let Ok(remaining_value) = HeaderValue::from_str("a.remaining.to_string()) {
response.headers_mut().insert(remaining_name, remaining_value);
}
let reset_name = HeaderName::from_static("x-rate-limit-reset");
if let Ok(reset_value) = HeaderValue::from_str("a.reset.to_string()) {
response.headers_mut().insert(reset_name, reset_value);
}
}
response
}
/// Unit tests for the configuration and quota value types and for `GlobalRateLimiter`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with the given rates and both scoping flags switched off.
    fn config(requests_per_minute: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute,
            burst,
            per_ip: false,
            per_endpoint: false,
        }
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.burst, 200);
        assert!(config.per_ip);
        assert!(!config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_custom() {
        let config = RateLimitConfig {
            requests_per_minute: 50,
            burst: 100,
            per_ip: false,
            per_endpoint: true,
        };
        assert_eq!(config.requests_per_minute, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.per_ip);
        assert!(config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_clone() {
        let config = RateLimitConfig {
            requests_per_minute: 75,
            burst: 150,
            per_ip: true,
            per_endpoint: true,
        };
        let cloned = config.clone();
        assert_eq!(cloned.requests_per_minute, config.requests_per_minute);
        assert_eq!(cloned.burst, config.burst);
        assert_eq!(cloned.per_ip, config.per_ip);
        assert_eq!(cloned.per_endpoint, config.per_endpoint);
    }

    #[test]
    fn test_rate_limit_config_debug() {
        let config = RateLimitConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("requests_per_minute"));
        assert!(debug_str.contains("burst"));
        assert!(debug_str.contains("per_ip"));
        assert!(debug_str.contains("per_endpoint"));
    }

    #[test]
    fn test_rate_limit_quota_creation() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };
        assert_eq!(quota.limit, 100);
        assert_eq!(quota.remaining, 50);
        assert_eq!(quota.reset, 1234567890);
    }

    #[test]
    fn test_rate_limit_quota_clone() {
        let quota = RateLimitQuota {
            limit: 200,
            remaining: 175,
            reset: 9876543210,
        };
        let cloned = quota.clone();
        assert_eq!(cloned.limit, quota.limit);
        assert_eq!(cloned.remaining, quota.remaining);
        assert_eq!(cloned.reset, quota.reset);
    }

    #[test]
    fn test_rate_limit_quota_debug() {
        let quota = RateLimitQuota {
            limit: 100,
            remaining: 50,
            reset: 1234567890,
        };
        let debug_str = format!("{:?}", quota);
        assert!(debug_str.contains("limit"));
        assert!(debug_str.contains("remaining"));
        assert!(debug_str.contains("reset"));
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_with_custom_config() {
        let limiter = GlobalRateLimiter::new(config(60, 10));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let limiter = GlobalRateLimiter::new(config(10, 5));
        for _ in 0..5 {
            assert!(limiter.check_rate_limit(), "Burst request should be allowed");
        }
    }

    #[test]
    fn test_rate_limiter_multiple_requests() {
        let limiter = GlobalRateLimiter::new(config(1000, 100));
        for i in 0..50 {
            assert!(limiter.check_rate_limit(), "Request {} should be allowed", i);
        }
    }

    #[test]
    fn test_get_quota_info() {
        let limiter = GlobalRateLimiter::new(config(100, 50));
        let quota = limiter.get_quota_info();
        assert_eq!(quota.limit, 100);
        assert!(quota.remaining > 0);
        assert!(quota.reset > 0);
    }

    #[test]
    fn test_quota_info_limit_matches_config() {
        let limiter = GlobalRateLimiter::new(config(500, 100));
        let quota = limiter.get_quota_info();
        assert_eq!(quota.limit, 500);
    }

    #[test]
    fn test_quota_decrements_remaining() {
        let limiter = GlobalRateLimiter::new(config(100, 50));
        let first_quota = limiter.get_quota_info();
        let second_quota = limiter.get_quota_info();
        // Two back-to-back calls fall in the same window, so exactly one unit is consumed.
        assert_eq!(
            second_quota.remaining,
            first_quota.remaining - 1,
            "Each call should consume one unit of the remaining quota"
        );
    }

    #[test]
    fn test_quota_reset_timestamp_is_future() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        let quota = limiter.get_quota_info();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs();
        assert!(quota.reset >= now, "Reset timestamp should be >= current time");
        assert!(quota.reset <= now + 120, "Reset timestamp should be within 2 minutes");
    }

    #[test]
    fn test_rate_limiter_high_burst() {
        let limiter = GlobalRateLimiter::new(config(10, 1000));
        for _ in 0..100 {
            assert!(limiter.check_rate_limit());
        }
    }

    #[test]
    fn test_rate_limiter_low_limit() {
        let limiter = GlobalRateLimiter::new(config(1, 1));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_config_with_zero_values_handled() {
        // Zero rates must not panic; the limiter falls back to non-zero defaults.
        let limiter = GlobalRateLimiter::new(config(0, 0));
        assert!(limiter.check_rate_limit());
    }

    #[test]
    fn test_multiple_quota_calls_same_limiter() {
        let limiter = GlobalRateLimiter::new(RateLimitConfig::default());
        let quotas: Vec<RateLimitQuota> = (0..5).map(|_| limiter.get_quota_info()).collect();
        for quota in &quotas {
            assert_eq!(quota.limit, 100);
        }
        let first_reset = quotas[0].reset;
        for quota in &quotas {
            assert!(
                quota.reset.abs_diff(first_reset) <= 1,
                "Reset timestamps should be within 1 second of each other"
            );
        }
    }

    #[test]
    fn test_quota_remaining_never_negative() {
        let limiter = GlobalRateLimiter::new(config(5, 5));
        // Ask for far more snapshots than the limit allows; the counter must bottom out at
        // zero instead of wrapping around to a huge value.
        for _ in 0..20 {
            let quota = limiter.get_quota_info();
            assert!(
                quota.remaining <= quota.limit,
                "Remaining should never exceed the limit"
            );
        }
    }
}