use axum::{
body::Body,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::Next,
response::Response,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Instant;
/// Sliding-window rate limiter keyed by a client identifier string (the `ip`
/// argument of [`check`](Self::check)).
///
/// Each key may pass [`check`](Self::check) at most `max_requests` times within
/// any trailing `window_secs`-second window. Clones share the same counters
/// through an `Arc`, so a single limiter can be cloned into request handlers.
#[derive(Clone)]
pub struct RateLimiter {
    state: Arc<Mutex<LimiterState>>,
    max_requests: usize,
    window_secs: u64,
}

/// Bookkeeping protected by the limiter's mutex.
struct LimiterState {
    /// Accepted-request timestamps per key, oldest first (pushed in call order).
    hits: HashMap<String, Vec<Instant>>,
    /// Last time keys with no hits inside the window were purged from `hits`.
    last_sweep: Instant,
}

impl RateLimiter {
    /// Creates a limiter allowing `max_requests` accepted requests per key in
    /// any trailing window of `window_secs` seconds.
    ///
    /// `max_requests == 0` rejects every request. `window_secs == 0` never
    /// rejects when `max_requests > 0`, because every earlier hit has already
    /// left the (empty) window.
    pub fn new(max_requests: usize, window_secs: u64) -> Self {
        Self {
            state: Arc::new(Mutex::new(LimiterState {
                hits: HashMap::new(),
                last_sweep: Instant::now(),
            })),
            max_requests,
            window_secs,
        }
    }

    /// Preset of 10 requests per 60 seconds per key.
    ///
    /// NOTE(review): the name suggests it is meant for authentication routes —
    /// confirm against the router setup.
    pub fn auth_default() -> Self {
        Self::new(10, 60)
    }

    /// Records a request for `ip` and reports whether it is within the limit.
    ///
    /// Returns `true` and counts the request if fewer than `max_requests`
    /// requests for `ip` were accepted in the trailing window. Otherwise it
    /// returns `false` without counting the attempt, so rejected requests do
    /// not extend the lockout.
    pub fn check(&self, ip: &str) -> bool {
        // The state is only a cache of timestamps; a panic elsewhere while the
        // lock was held should not make every later request panic too.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let now = Instant::now();
        let window = std::time::Duration::from_secs(self.window_secs);

        // Without this sweep every key ever seen would stay in the map forever.
        // Purging at most once per window keeps the O(keys) cost amortised.
        if now.duration_since(state.last_sweep) >= window {
            state.hits.retain(|_, times| {
                matches!(times.last(), Some(&newest) if now.duration_since(newest) < window)
            });
            state.last_sweep = now;
        }

        let times = state.hits.entry(ip.to_string()).or_default();
        times.retain(|&t| now.duration_since(t) < window);
        if times.len() >= self.max_requests {
            return false;
        }
        times.push(now);
        true
    }
}
/// Axum middleware that rejects a request with `429 Too Many Requests` when the
/// [`RateLimiter`] stored in the request extensions refuses the client key.
///
/// The key is the first `X-Forwarded-For` entry when it holds a valid IP
/// address (optionally with a port); otherwise it is the IP of the connecting
/// socket. If no `RateLimiter` extension is present, the request is passed
/// through unchanged.
///
/// # Errors
///
/// Returns `Err(StatusCode::TOO_MANY_REQUESTS)` when the limiter rejects the key.
///
/// NOTE(review): `X-Forwarded-For` is whatever the client sends unless a trusted
/// proxy overwrites it. A caller reaching this service directly can therefore
/// rotate the header to get a fresh bucket per request. Only honour the header
/// behind a proxy that sets it — confirm the deployment.
// NOTE(review): the lint is presumably silenced because this middleware is not
// wired into a router yet — confirm and drop the attribute once it is.
#[allow(dead_code)]
pub async fn rate_limit_middleware(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let limiter = request.extensions().get::<RateLimiter>().cloned();
    if let Some(limiter) = limiter {
        let forwarded_for = request
            .headers()
            .get("x-forwarded-for")
            .and_then(|v| v.to_str().ok());
        let ip = client_ip(forwarded_for, addr.ip());
        if !limiter.check(&ip) {
            tracing::warn!(ip = %ip, "Rate limited");
            return Err(StatusCode::TOO_MANY_REQUESTS);
        }
    }
    Ok(next.run(request).await)
}

/// Picks the rate-limit key from a raw `X-Forwarded-For` value and the peer IP.
///
/// Uses the first comma-separated entry only if it parses as an IP address or
/// an `ip:port` socket address. Empty or malformed entries fall back to `peer`,
/// so they neither share an empty-string bucket nor mint arbitrary keys.
fn client_ip(forwarded_for: Option<&str>, peer: std::net::IpAddr) -> String {
    forwarded_for
        .and_then(|value| value.split(',').next())
        .map(str::trim)
        .and_then(|first| {
            first
                .parse::<std::net::IpAddr>()
                .ok()
                .or_else(|| first.parse::<SocketAddr>().ok().map(|sa| sa.ip()))
        })
        .unwrap_or(peer)
        .to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::new(5, 60);
        for _ in 0..5 {
            assert!(limiter.check("127.0.0.1"));
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::new(3, 60);
        assert!(limiter.check("127.0.0.1"));
        assert!(limiter.check("127.0.0.1"));
        assert!(limiter.check("127.0.0.1"));
        assert!(!limiter.check("127.0.0.1"));
    }

    #[test]
    fn test_rate_limiter_per_ip() {
        let limiter = RateLimiter::new(2, 60);
        assert!(limiter.check("1.1.1.1"));
        assert!(limiter.check("1.1.1.1"));
        assert!(!limiter.check("1.1.1.1"));
        assert!(limiter.check("2.2.2.2"));
        assert!(limiter.check("2.2.2.2"));
        assert!(!limiter.check("2.2.2.2"));
    }

    #[test]
    fn test_rate_limiter_window_expiry() {
        // With a limit of 1, every call after the first succeeds only if the
        // previous hit has already left the (zero-second) window.
        let limiter = RateLimiter::new(1, 0);
        assert!(limiter.check("127.0.0.1"));
        assert!(limiter.check("127.0.0.1"));
        assert!(limiter.check("127.0.0.1"));
    }

    #[test]
    fn test_rate_limiter_zero_max_blocks_everything() {
        let limiter = RateLimiter::new(0, 60);
        assert!(!limiter.check("127.0.0.1"));
        assert!(!limiter.check("127.0.0.1"));
    }

    #[test]
    fn test_rate_limiter_clones_share_state() {
        let limiter = RateLimiter::new(1, 60);
        let clone = limiter.clone();
        assert!(limiter.check("127.0.0.1"));
        assert!(!clone.check("127.0.0.1"));
    }

    #[test]
    fn test_rate_limiter_auth_default_limit() {
        let limiter = RateLimiter::auth_default();
        for _ in 0..10 {
            assert!(limiter.check("127.0.0.1"));
        }
        assert!(!limiter.check("127.0.0.1"));
    }
}