use crate::middleware::Next;
use crate::request::Request;
use crate::response::Response;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Settings that control how requests are rate limited.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum number of requests a single key may make within one `window`.
    pub max_requests: u32,
    /// Length of the counting window.
    pub window: Duration,
    /// Human-readable message returned to clients that exceed the limit.
    pub message: String,
    /// Path prefixes that are exempt from rate limiting.
    pub skip_paths: Vec<String>,
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
max_requests: 100,
window: Duration::from_secs(60),
message: "Too many requests. Please try again later.".to_string(),
skip_paths: vec![],
}
}
}
impl RateLimiterConfig {
    /// Creates a configuration allowing `max_requests` per window of
    /// `window_secs` seconds; all other fields take their default values.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        let window = Duration::from_secs(window_secs);
        Self {
            max_requests,
            window,
            ..Self::default()
        }
    }

    /// Replaces the message sent to clients that exceed the limit and
    /// returns the updated configuration (builder style).
    pub fn message(mut self, msg: &str) -> Self {
        self.message = String::from(msg);
        self
    }

    /// Replaces the list of skipped path prefixes with `paths` and returns
    /// the updated configuration (builder style).
    pub fn skip(mut self, paths: Vec<&str>) -> Self {
        self.skip_paths = paths.into_iter().map(String::from).collect();
        self
    }
}
/// Per-key bookkeeping for the current rate-limit window.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Number of requests counted since `window_start`.
    count: u32,
    /// Instant at which the current window began.
    window_start: Instant,
}
/// Counts requests per key and decides whether each one is within the
/// configured limit.
///
/// Cloning is shallow with respect to the counters: every clone shares the
/// same `entries` map through the `Arc`, while `config` is copied.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Limits and messages applied by this limiter.
    config: RateLimiterConfig,
    /// Shared per-key counters, guarded by a read/write lock.
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
}
impl RateLimiter {
    /// Creates a limiter with an empty counter table governed by `config`.
    pub fn new(config: RateLimiterConfig) -> Self {
        Self {
            config,
            entries: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records one request for `key` and reports whether it is within the limit.
    ///
    /// Each key has its own window that begins at the key's first request and
    /// is restarted once `config.window` has elapsed. The request is counted
    /// before the limit is evaluated, so the call that pushes the count above
    /// `config.max_requests` is itself reported as `Exceeded`.
    ///
    /// For `Exceeded`, `retry_after` is the time left in the current window
    /// rounded *up* to whole seconds, so it is never `0` while the key is
    /// still blocked (including for sub-second windows).
    ///
    /// NOTE(review): nothing in this impl removes entries, so the table gains
    /// one entry per distinct key — consider periodic eviction if the key
    /// space is unbounded.
    pub fn check(&self, key: &str) -> RateLimitResult {
        let limit = self.config.max_requests;
        let window = self.config.window;

        // Take the lock before reading the clock so `now` is never earlier
        // than a `window_start` written by another thread.
        let mut entries = self.entries.write();
        let now = Instant::now();

        let entry = entries.entry(key.to_string()).or_insert(RateLimitEntry {
            count: 0,
            window_start: now,
        });

        // Start a fresh window once the previous one has fully elapsed.
        if now.duration_since(entry.window_start) >= window {
            entry.count = 0;
            entry.window_start = now;
        }

        // Saturate instead of overflowing: a wrapped counter would panic in
        // debug builds and silently lift the limit in release builds.
        entry.count = entry.count.saturating_add(1);

        if entry.count > limit {
            let elapsed = now.duration_since(entry.window_start);
            let wait = window.checked_sub(elapsed).unwrap_or_default();
            // Round any fractional second up so clients are not told to
            // retry before the window has actually ended.
            let wait_secs = wait
                .as_secs()
                .saturating_add(u64::from(wait.subsec_nanos() > 0));
            // Clamp before narrowing so very long windows cannot truncate.
            let retry_after = wait_secs.min(u64::from(u32::MAX)) as u32;

            RateLimitResult::Exceeded {
                retry_after,
                limit,
                remaining: 0,
            }
        } else {
            RateLimitResult::Allowed {
                limit,
                // `count <= limit` holds in this branch, so this cannot underflow.
                remaining: limit - entry.count,
            }
        }
    }

    /// Returns the configuration this limiter was built with.
    pub fn config(&self) -> &RateLimiterConfig {
        &self.config
    }
}
/// Outcome of checking a single request against the rate limit.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request is within the limit.
    Allowed {
        /// Maximum number of requests permitted per window.
        limit: u32,
        /// Requests still available in the current window.
        remaining: u32,
    },
    /// The request went over the limit.
    Exceeded {
        /// Seconds the caller should wait before retrying.
        retry_after: u32,
        /// Maximum number of requests permitted per window.
        limit: u32,
        /// Requests still available in the current window.
        remaining: u32,
    },
}
pub fn rate_limiter(
config: RateLimiterConfig,
) -> impl Fn(
Request,
Response,
Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
+ Send
+ Sync
+ Clone
+ 'static {
let limiter = RateLimiter::new(config.clone());
move |req: Request, res: Response, next: Next| {
let limiter = limiter.clone();
let config = config.clone();
Box::pin(async move {
if config.skip_paths.iter().any(|p| req.path().starts_with(p)) {
return next(req, res).await;
}
let key = req.ip().to_string();
match limiter.check(&key) {
RateLimitResult::Allowed { limit, remaining } => {
let res = res
.header("X-RateLimit-Limit", &limit.to_string())
.header("X-RateLimit-Remaining", &remaining.to_string());
next(req, res).await
}
RateLimitResult::Exceeded {
retry_after,
limit,
remaining: _,
} => res
.status(429)
.header("X-RateLimit-Limit", &limit.to_string())
.header("X-RateLimit-Remaining", "0")
.header("Retry-After", &retry_after.to_string())
.json(serde_json::json!({
"error": "Too Many Requests",
"message": config.message,
"retry_after": retry_after
})),
}
})
}
}
/// Convenience wrapper around [`rate_limiter`] for the common case.
///
/// Allows `max_requests` per window of `window_secs` seconds and uses the
/// default message and an empty skip list for everything else.
pub fn simple_rate_limit(
    max_requests: u32,
    window_secs: u64,
) -> impl Fn(
    Request,
    Response,
    Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
       + Send
       + Sync
       + Clone
       + 'static {
    let config = RateLimiterConfig::new(max_requests, window_secs);
    rate_limiter(config)
}