use crate::middleware::Middleware;
use crate::{Request, Response};
use http::StatusCode;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, trace, warn};
use super::Next;
/// Configuration for [`RateLimitingMiddleware`].
#[derive(Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per key within `window` (before any burst allowance).
    pub max_requests: usize,
    /// Length of the sliding window that requests are counted over.
    pub window: Duration,
    /// Custom function deriving the rate-limit key from a request. When `None`, the middleware
    /// derives an `ip:`-prefixed key from the `x-forwarded-for` / `x-real-ip` headers.
    pub key_extractor: Option<Arc<dyn Fn(&Request) -> String + Send + Sync>>,
    /// Message placed in the JSON body of rejected (429) responses.
    pub error_message: String,
    /// Whether the middleware publishes rate-limit information (a `RateLimitInfo` request
    /// extension and `x-ratelimit-*` response headers).
    pub include_headers: bool,
    /// When true, the effective limit is `max_requests * burst_multiplier` (truncated).
    pub allow_burst: bool,
    /// Multiplier applied to `max_requests` when `allow_burst` is set.
    pub burst_multiplier: f32,
}
impl std::fmt::Debug for RateLimitConfig {
    /// Formats every field; the key extractor closure is not `Debug`, so it is shown as
    /// `Some("<function>")` when configured and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report whether an extractor exists instead of always claiming one does.
        let key_extractor = self.key_extractor.as_ref().map(|_| "<function>");
        f.debug_struct("RateLimitConfig")
            .field("max_requests", &self.max_requests)
            .field("window", &self.window)
            .field("key_extractor", &key_extractor)
            .field("error_message", &self.error_message)
            .field("include_headers", &self.include_headers)
            .field("allow_burst", &self.allow_burst)
            .field("burst_multiplier", &self.burst_multiplier)
            .finish()
    }
}
impl Default for RateLimitConfig {
    /// 100 requests per 60-second window, no custom key extractor, headers enabled and
    /// bursting disabled.
    ///
    /// `burst_multiplier` is pre-set to 1.5 but has no effect until `allow_burst` is enabled.
    fn default() -> Self {
        let error_message = String::from("Rate limit exceeded. Please try again later.");
        Self {
            error_message,
            window: Duration::from_secs(60),
            max_requests: 100,
            include_headers: true,
            allow_burst: false,
            burst_multiplier: 1.5,
            key_extractor: None,
        }
    }
}
impl RateLimitConfig {
    /// Creates a config admitting `max_requests` per `window`; every other setting takes its
    /// [`Default`] value.
    #[must_use]
    pub fn new(max_requests: usize, window: Duration) -> Self {
        Self {
            max_requests,
            window,
            ..Default::default()
        }
    }

    /// Creates a config admitting `max_requests` per 60-second window.
    #[must_use]
    pub fn per_minute(max_requests: usize) -> Self {
        Self::new(max_requests, Duration::from_secs(60))
    }

    /// Creates a config admitting `max_requests` per 1-second window.
    #[must_use]
    pub fn per_second(max_requests: usize) -> Self {
        Self::new(max_requests, Duration::from_secs(1))
    }

    /// Creates a config admitting `max_requests` per 3600-second window.
    #[must_use]
    pub fn per_hour(max_requests: usize) -> Self {
        Self::new(max_requests, Duration::from_secs(3600))
    }

    /// Uses `extractor` to derive each request's rate-limit key instead of the default
    /// header-based IP key.
    #[must_use]
    pub fn with_key_extractor<F>(mut self, extractor: F) -> Self
    where
        F: Fn(&Request) -> String + Send + Sync + 'static,
    {
        self.key_extractor = Some(Arc::new(extractor));
        self
    }

    /// Sets the message returned in the body of rejected requests.
    #[must_use]
    pub fn with_error_message(mut self, message: impl Into<String>) -> Self {
        self.error_message = message.into();
        self
    }

    /// Enables or disables publishing rate-limit info (request extension and response headers).
    #[must_use]
    pub fn with_headers(mut self, include: bool) -> Self {
        self.include_headers = include;
        self
    }

    /// Enables bursting, so the effective limit becomes `max_requests * multiplier`.
    ///
    /// `multiplier` is clamped to at least `1.0`. `f32::max` returns the non-NaN operand, so a
    /// NaN multiplier also becomes `1.0`.
    #[must_use]
    pub fn with_burst(mut self, multiplier: f32) -> Self {
        self.allow_burst = true;
        self.burst_multiplier = multiplier.max(1.0);
        self
    }

    /// Disables bursting. The stored `burst_multiplier` is left untouched.
    #[must_use]
    pub fn without_burst(mut self) -> Self {
        self.allow_burst = false;
        self
    }
}
/// Sliding-window request log for a single rate-limit key.
#[derive(Debug, Clone)]
struct RequestWindow {
    /// Arrival times of admitted requests, appended in call order by `is_allowed`.
    timestamps: Vec<Instant>,
    /// When `cleanup_expired` last ran. The middleware's background task evicts entries whose
    /// value is older than three windows.
    last_cleanup: Instant,
    /// Count of admitted requests over the lifetime of this entry (never decremented).
    total_requests: u64,
}
impl RequestWindow {
fn new() -> Self {
Self {
timestamps: Vec::new(),
last_cleanup: Instant::now(),
total_requests: 0,
}
}
fn cleanup_expired(&mut self, window: Duration) {
let now = Instant::now();
let cutoff = now.checked_sub(window).unwrap_or(now);
self.timestamps.retain(|×tamp| timestamp > cutoff);
self.last_cleanup = now;
trace!(
"Cleaned up expired timestamps, {} remaining",
self.timestamps.len()
);
}
fn is_allowed(
&mut self,
max_requests: usize,
window: Duration,
allow_burst: bool,
burst_multiplier: f32,
) -> bool {
self.cleanup_expired(window);
let current_count = self.timestamps.len();
let effective_limit = if allow_burst {
((max_requests as f32) * burst_multiplier) as usize
} else {
max_requests
};
if current_count >= effective_limit {
debug!(
"Rate limit exceeded: {} >= {}",
current_count, effective_limit
);
false
} else {
let now = Instant::now();
self.timestamps.push(now);
self.total_requests += 1;
trace!("Request allowed: {}/{}", current_count + 1, effective_limit);
true
}
}
fn remaining_requests(&self, max_requests: usize) -> usize {
max_requests.saturating_sub(self.timestamps.len())
}
fn reset_time(&self, window: Duration) -> Option<Instant> {
self.timestamps.first().map(|&first| first + window)
}
fn total_requests(&self) -> u64 {
self.total_requests
}
}
/// Middleware that limits requests per key using an in-memory sliding window of timestamps.
pub struct RateLimitingMiddleware {
    /// Limits and behavior settings.
    config: RateLimitConfig,
    /// Per-key request windows, shared with the background cleanup task.
    windows: Arc<RwLock<HashMap<String, RequestWindow>>>,
    /// Handle of the background cleanup task; aborted when the middleware is dropped.
    _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
impl RateLimitingMiddleware {
    /// Builds the middleware and starts its background cleanup task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the cleanup task is started with
    /// `tokio::spawn`.
    pub fn new(config: RateLimitConfig) -> Self {
        let windows = Arc::new(RwLock::new(HashMap::new()));
        let cleanup_handle = Self::start_cleanup_task(Arc::clone(&windows), config.window);
        Self {
            config,
            windows,
            _cleanup_handle: Some(cleanup_handle),
        }
    }

    /// Middleware admitting `max_requests` per key per minute. Panics like [`Self::new`].
    pub fn per_minute(max_requests: usize) -> Self {
        Self::new(RateLimitConfig::per_minute(max_requests))
    }

    /// Middleware admitting `max_requests` per key per second. Panics like [`Self::new`].
    pub fn per_second(max_requests: usize) -> Self {
        Self::new(RateLimitConfig::per_second(max_requests))
    }

    /// Middleware admitting `max_requests` per key per hour. Panics like [`Self::new`].
    pub fn per_hour(max_requests: usize) -> Self {
        Self::new(RateLimitConfig::per_hour(max_requests))
    }

    /// Spawns a task that, every 300 seconds, removes keys whose window has not been touched for
    /// three window lengths, so idle clients do not accumulate in the map.
    fn start_cleanup_task(
        windows: Arc<RwLock<HashMap<String, RequestWindow>>>,
        window_duration: Duration,
    ) -> tokio::task::JoinHandle<()> {
        // Saturate so an extremely long window cannot overflow-panic the sweeper.
        let cleanup_threshold = window_duration.saturating_mul(3);
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(300));
            loop {
                interval.tick().await;
                let mut map = windows.write().await;
                let now = Instant::now();
                let initial_size = map.len();
                map.retain(|_, window| now.duration_since(window.last_cleanup) < cleanup_threshold);
                let removed = initial_size - map.len();
                if removed > 0 {
                    debug!("Cleaned up {} stale rate limit entries", removed);
                }
            }
        })
    }

    /// Derives the rate-limit key for `req`: the configured extractor if present, otherwise an
    /// `ip:`-prefixed key from proxy headers.
    fn extract_key(&self, req: &Request) -> String {
        if let Some(ref extractor) = self.config.key_extractor {
            extractor(req)
        } else {
            self.extract_ip_key(req)
        }
    }

    /// Builds `ip:<addr>` from the first `x-forwarded-for` entry, else `x-real-ip`, else
    /// `ip:unknown`. All requests lacking both headers therefore share the single `ip:unknown`
    /// bucket.
    ///
    /// NOTE(review): these headers are client-supplied unless a trusted proxy overwrites them.
    /// Without such a proxy, a client can choose its own key and evade the limit; configure a
    /// `key_extractor` in that case.
    fn extract_ip_key(&self, req: &Request) -> String {
        if let Some(forwarded) = req.header("x-forwarded-for") {
            if let Some(ip) = forwarded.split(',').next() {
                let clean_ip = ip.trim();
                if !clean_ip.is_empty() {
                    return format!("ip:{}", clean_ip);
                }
            }
        }
        if let Some(real_ip) = req.header("x-real-ip") {
            let clean_ip = real_ip.trim();
            if !clean_ip.is_empty() {
                return format!("ip:{}", clean_ip);
            }
        }
        "ip:unknown".to_string()
    }

    /// Returns a snapshot of `key`'s state, or `None` if the key has no entry.
    ///
    /// Only timestamps still inside the window are counted. Expired ones may linger until the
    /// key's next request, because this method takes a read lock and does not mutate.
    pub async fn get_stats(&self, key: &str) -> Option<RateLimitStats> {
        let windows = self.windows.read().await;
        let window = windows.get(key)?;
        // Timestamps are appended chronologically, so the live ones form a suffix of the vector.
        // If the cutoff is not representable, nothing can have expired.
        let live = match Instant::now().checked_sub(self.config.window) {
            Some(cutoff) => {
                let first_live = window.timestamps.partition_point(|&t| t <= cutoff);
                &window.timestamps[first_live..]
            }
            None => &window.timestamps[..],
        };
        Some(RateLimitStats {
            current_requests: live.len(),
            total_requests: window.total_requests(),
            remaining_requests: self.config.max_requests.saturating_sub(live.len()),
            reset_time: live.first().map(|&first| first + self.config.window),
        })
    }

    /// Forgets every key's request history.
    pub async fn clear_all(&self) {
        let mut windows = self.windows.write().await;
        windows.clear();
        debug!("Cleared all rate limit data");
    }
}
/// Snapshot of one key's rate-limit state, returned by `RateLimitingMiddleware::get_stats`.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Requests currently counted against the limit.
    pub current_requests: usize,
    /// Requests admitted for this key for as long as its entry has existed.
    pub total_requests: u64,
    /// `max_requests` minus `current_requests`, floored at zero (burst allowance not included).
    pub remaining_requests: usize,
    /// When the oldest counted request leaves the window; `None` if none are counted.
    pub reset_time: Option<Instant>,
}
/// Rate-limit outcome for a single request, published by the middleware when
/// `include_headers` is enabled.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Configured `max_requests` (burst allowance not included).
    pub limit: usize,
    /// Requests left in the current window; `0` when the limit was exceeded.
    pub remaining: usize,
    /// Whether this request was rejected.
    pub exceeded: bool,
    /// When the oldest counted request leaves the window, if any are counted.
    pub reset_time: Option<Instant>,
    /// Rate-limit key the request was counted under.
    pub key: String,
}
#[async_trait::async_trait]
impl Middleware for RateLimitingMiddleware {
async fn handle(&self, mut req: Request, next: Next) -> Response {
let key = self.extract_key(&req);
debug!("Checking rate limit for key: {}", key);
let mut windows = self.windows.write().await;
let window = windows
.entry(key.clone())
.or_insert_with(RequestWindow::new);
let is_allowed = window.is_allowed(
self.config.max_requests,
self.config.window,
self.config.allow_burst,
self.config.burst_multiplier,
);
let remaining = window.remaining_requests(self.config.max_requests);
let reset_time = window.reset_time(self.config.window);
if !is_allowed {
warn!("Rate limit exceeded for key: {}", key);
if self.config.include_headers {
req.insert_extension(RateLimitInfo {
limit: self.config.max_requests,
remaining: 0,
exceeded: true,
reset_time,
key: key.clone(),
});
}
return Response::json(RateLimitError {
message: self.config.error_message.clone(),
limit: self.config.max_requests,
window: self.config.window,
key,
reset_time,
})
.with_status(StatusCode::TOO_MANY_REQUESTS);
}
if self.config.include_headers {
req.insert_extension(RateLimitInfo {
limit: self.config.max_requests,
remaining,
exceeded: false,
reset_time,
key: key.clone(),
});
}
trace!("Rate limit check passed for key: {}", key);
let mut res = next.run(req.clone()).await;
if !self.config.include_headers {
return res;
}
if let Some(rate_limit_info) = req.get_extension::<RateLimitInfo>() {
if let Ok(limit_value) = rate_limit_info.limit.to_string().parse() {
res.headers.insert("x-ratelimit-limit", limit_value);
}
if let Ok(remaining_value) = rate_limit_info.remaining.to_string().parse() {
res.headers.insert("x-ratelimit-remaining", remaining_value);
}
if let Some(reset_time) = rate_limit_info.reset_time {
if let Ok(now) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)
{
let reset_duration = reset_time.duration_since(std::time::Instant::now());
let reset_timestamp = now.as_secs() + reset_duration.as_secs();
if let Ok(reset_value) = reset_timestamp.to_string().parse() {
res.headers.insert("x-ratelimit-reset", reset_value);
}
}
}
if let Some(reset_time) = rate_limit_info.reset_time {
let now = std::time::Instant::now();
if reset_time > now {
let seconds_until_reset = reset_time.duration_since(now).as_secs();
if let Ok(reset_after_value) = seconds_until_reset.to_string().parse() {
res.headers
.insert("x-ratelimit-reset-after", reset_after_value);
}
}
}
if rate_limit_info.exceeded && res.status == http::StatusCode::TOO_MANY_REQUESTS {
let retry_after = self.config.window.as_secs();
if let Ok(retry_value) = retry_after.to_string().parse() {
res.headers.insert("retry-after", retry_value);
}
}
#[cfg(debug_assertions)]
{
if let Ok(key_value) = rate_limit_info.key.parse() {
res.headers.insert("x-ratelimit-key", key_value);
}
}
tracing::trace!(
"Added rate limit headers: limit={}, remaining={}, exceeded={}",
rate_limit_info.limit,
rate_limit_info.remaining,
rate_limit_info.exceeded
);
}
res
}
}
/// JSON body of the response sent when a key exceeds its rate limit.
#[derive(Debug, Serialize)]
pub struct RateLimitError {
    /// Human-readable message taken from `RateLimitConfig::error_message`.
    pub message: String,
    /// Configured `max_requests`.
    pub limit: usize,
    /// Configured window length.
    pub window: Duration,
    /// Rate-limit key that exceeded its limit.
    pub key: String,
    /// Serialized by `serialize_instant` as whole seconds remaining until reset; omitted from
    /// the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(serialize_with = "serialize_instant")]
    pub reset_time: Option<Instant>,
}
/// Serializes an optional `Instant` as the whole number of seconds from now until it, saturating
/// at 0 for instants already in the past; `None` is serialized as a none value.
fn serialize_instant<S>(
    instant: &Option<Instant>,
    serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if let Some(deadline) = instant {
        let secs_left = deadline.saturating_duration_since(Instant::now()).as_secs();
        serializer.serialize_u64(secs_left)
    } else {
        serializer.serialize_none()
    }
}
impl std::fmt::Display for RateLimitError {
    /// Renders the message followed by the limit, window (in `Debug` form) and key in parentheses.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            message,
            limit,
            window,
            key,
            ..
        } = self;
        write!(
            f,
            "{} (limit: {} requests per {:?}, key: {})",
            message, limit, window, key
        )
    }
}
impl Drop for RateLimitingMiddleware {
    /// Aborts the background cleanup task so it does not keep running after the middleware
    /// is gone.
    fn drop(&mut self) {
        if let Some(cleanup_task) = &self._cleanup_handle {
            cleanup_task.abort();
        }
    }
}