use axum::extract::ConnectInfo;
use axum::http::{HeaderValue, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Mutex};
use std::time::Instant;
/// Per-client-IP token-bucket rate limiter.
///
/// Cloning is cheap: clones share the same bucket map through an `Arc<Mutex<..>>`, so every
/// clone enforces the same limits against the same state.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Token bucket per client IP; a single mutex guards the whole map.
    state: Arc<Mutex<HashMap<IpAddr, Bucket>>>,
    /// Bucket capacity (the burst size); also the number of tokens refilled per window.
    max_requests: f64,
    /// Length, in seconds, of the window over which `max_requests` tokens are refilled.
    window_secs: f64,
    /// Soft cap on tracked clients: once the map grows to about this size, `check` prunes stale
    /// buckets. Active buckets are never evicted, so the map can grow past it.
    max_entries: usize,
    /// Peer addresses whose `X-Forwarded-For` header is honoured when identifying the client.
    trusted_proxies: Vec<IpAddr>,
}
/// Token-bucket state for a single client IP.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; each allowed request consumes one. Refills are capped at
    /// the limiter's `max_requests`.
    tokens: f64,
    /// When `tokens` was last refilled; also compared against the window to detect stale buckets.
    last_check: Instant,
}
impl RateLimiter {
    /// Creates a limiter allowing bursts of up to `max_requests` requests per client IP, refilled
    /// continuously at `max_requests / window_secs` tokens per second.
    ///
    /// Defaults: a soft cap of 10 000 tracked clients before stale buckets are pruned (override
    /// with [`RateLimiter::with_max_entries`]) and no trusted proxies (override with
    /// [`RateLimiter::with_trusted_proxies`]).
    ///
    /// With `window_secs == 0` and a non-zero `max_requests` the refill rate is infinite, so every
    /// request is allowed.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            state: Arc::new(Mutex::new(HashMap::new())),
            max_requests: f64::from(max_requests),
            window_secs: window_secs as f64,
            max_entries: 10_000,
            trusted_proxies: Vec::new(),
        }
    }

    /// Sets the soft cap on tracked clients; see [`RateLimiter::check`] for how it is enforced.
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = max_entries;
        self
    }

    /// Sets the peer addresses whose `X-Forwarded-For` header is trusted to carry the client IP.
    pub fn with_trusted_proxies(mut self, proxies: Vec<IpAddr>) -> Self {
        self.trusted_proxies = proxies;
        self
    }

    /// 100 requests per 60 seconds per client IP.
    pub fn default_limit() -> Self {
        Self::new(100, 60)
    }

    /// 10 requests per 60 seconds per client IP.
    pub fn strict() -> Self {
        Self::new(10, 60)
    }

    /// Records one request from `ip` and reports whether it is within the limit.
    ///
    /// Refills `ip`'s bucket for the time elapsed since its last request, then consumes one token
    /// if at least one is available. A previously unseen `ip` starts with a full bucket.
    ///
    /// When `ip` is new and the map already tracks `max_entries` clients, buckets idle for two
    /// windows are evicted first. Active buckets are never evicted, so `max_entries` is a soft
    /// cap. A poisoned mutex is logged and recovered rather than propagated as a panic.
    pub fn check(&self, ip: IpAddr) -> bool {
        let mut state = self.state.lock().unwrap_or_else(|poisoned| {
            tracing::error!("rate limiter mutex poisoned, recovering");
            poisoned.into_inner()
        });
        let now = Instant::now();
        // The eviction scan is O(n) under the global lock, so only run it when a new client is
        // about to grow the map. Scanning on every request once the cap is reached (which stays
        // true while all tracked clients are active) would make each request pay for the whole map.
        if state.len() >= self.max_entries && !state.contains_key(&ip) {
            self.evict_stale(&mut state, now);
        }
        let bucket = state.entry(ip).or_insert_with(|| Bucket {
            tokens: self.max_requests,
            last_check: now,
        });
        replenish_tokens(bucket, now, self.max_requests, self.window_secs);
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Removes buckets idle for at least two windows and logs how many were dropped.
    ///
    /// A bucket idle for a full window has already refilled to `max_requests`, which is exactly
    /// the state a fresh bucket starts in, so evicting it never changes an allow/deny decision.
    fn evict_stale(&self, state: &mut HashMap<IpAddr, Bucket>, now: Instant) {
        // `from_secs_f64` panics when the value overflows `Duration` (windows above ~2^63 s);
        // saturate instead so that such huge windows simply never evict anything.
        let stale_threshold = std::time::Duration::try_from_secs_f64(self.window_secs * 2.0)
            .unwrap_or(std::time::Duration::MAX);
        let before = state.len();
        state.retain(|_, bucket| now.duration_since(bucket.last_check) < stale_threshold);
        let evicted = before - state.len();
        if evicted > 0 {
            tracing::info!(
                evicted,
                remaining = state.len(),
                "rate limiter eviction pass"
            );
        }
    }

    /// Returns the refill window in whole seconds (also used as the `Retry-After` value).
    pub fn window_secs(&self) -> u64 {
        self.window_secs as u64
    }
}
/// Refills `bucket` for the time elapsed since its last refill and stamps it with `now`.
///
/// Tokens accrue at `max / window` per second and are never allowed to exceed `max`.
fn replenish_tokens(bucket: &mut Bucket, now: Instant, max: f64, window: f64) {
    let per_second = max / window;
    let idle_secs = now.duration_since(bucket.last_check).as_secs_f64();
    let earned = idle_secs * per_second;
    bucket.last_check = now;
    bucket.tokens = f64::min(bucket.tokens + earned, max);
}
/// Axum middleware that rejects requests from clients that have exhausted their token bucket.
///
/// The client is identified by [`extract_client_ip`] using the limiter's trusted proxy list.
/// Allowed requests are forwarded to `next`; rejected ones get a `429` whose `Retry-After`
/// header carries the limiter's window length in seconds.
pub async fn rate_limit_middleware(
    axum::extract::State(limiter): axum::extract::State<RateLimiter>,
    connect_info: Option<ConnectInfo<SocketAddr>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(&request, connect_info, &limiter.trusted_proxies);
    if !limiter.check(client_ip) {
        return build_rate_limited_response(limiter.window_secs());
    }
    next.run(request).await
}
/// Determines which client address a request is rate-limited under.
///
/// - With a peer address from `connect_info`: if the peer is one of `trusted_proxies`, the
///   client address is read from `X-Forwarded-For` (see [`forwarded_for_ip`]). It falls back to
///   the peer address when the header yields nothing usable. Otherwise the peer address is used
///   and the header is ignored.
/// - Without a peer address: `X-Forwarded-For` is consulted only if `trusted_proxies` is
///   non-empty.
/// - Failing all of the above, the request is attributed to `127.0.0.1`, so all such requests
///   share a single bucket.
fn extract_client_ip(
    request: &Request<axum::body::Body>,
    connect_info: Option<ConnectInfo<SocketAddr>>,
    trusted_proxies: &[IpAddr],
) -> IpAddr {
    let localhost = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
    match connect_info {
        Some(ConnectInfo(addr)) => {
            let direct_ip = addr.ip();
            if trusted_proxies.contains(&direct_ip) {
                forwarded_for_ip(request, trusted_proxies).unwrap_or(direct_ip)
            } else {
                direct_ip
            }
        }
        // NOTE(review): without the peer address the header's origin cannot be verified, so a
        // client connecting to this server directly could still choose its own key on this path.
        // It presumably exists for deployments that do not provide `ConnectInfo` — confirm.
        None if !trusted_proxies.is_empty() => {
            forwarded_for_ip(request, trusted_proxies).unwrap_or(localhost)
        }
        None => localhost,
    }
}

/// Extracts the client address from `X-Forwarded-For` using the "rightmost untrusted" rule.
///
/// Each proxy appends the address of the peer it received the request from. Entries to the
/// right were therefore written by proxies closer to this server, while leftmost entries are
/// whatever the original client chose to send. Walking from the right and skipping
/// `trusted_proxies` yields the first address recorded by a trusted hop. Taking the leftmost
/// entry instead would let any client pick its own rate-limit key.
///
/// Multiple `X-Forwarded-For` header lines are read in order as one comma-separated list.
/// Empty list elements are skipped.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - a header value is not visible ASCII;
/// - the first untrusted entry does not parse as an IP address;
/// - every entry is a trusted proxy.
fn forwarded_for_ip(
    request: &Request<axum::body::Body>,
    trusted_proxies: &[IpAddr],
) -> Option<IpAddr> {
    let mut hops = Vec::new();
    for value in request.headers().get_all("x-forwarded-for") {
        let text = value.to_str().ok()?;
        hops.extend(text.split(',').map(str::trim).filter(|hop| !hop.is_empty()));
    }
    for hop in hops.into_iter().rev() {
        // An unparseable hop was written by a trusted proxy, so stop here. Skipping it would
        // move into the client-controlled part of the list.
        let ip: IpAddr = hop.parse().ok()?;
        if !trusted_proxies.contains(&ip) {
            return Some(ip);
        }
    }
    None
}
/// Builds the `429 Too Many Requests` response, advertising `retry_after` seconds in the
/// `Retry-After` header.
fn build_rate_limited_response(retry_after: u64) -> Response {
    // An integer always renders as a valid header value, so the conversion cannot fail.
    let retry_header = HeaderValue::from(retry_after);
    let mut response = (StatusCode::TOO_MANY_REQUESTS, "Too many requests").into_response();
    response
        .headers_mut()
        .insert(axum::http::header::RETRY_AFTER, retry_header);
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::time::Duration;

    /// Shorthand for an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Builds an empty request, optionally carrying an `X-Forwarded-For` header.
    fn request_with_forwarded_for(value: Option<&str>) -> Request<axum::body::Body> {
        let mut builder = Request::builder();
        if let Some(value) = value {
            builder = builder.header("x-forwarded-for", value);
        }
        builder.body(axum::body::Body::empty()).expect("build request")
    }

    /// Moves every tracked bucket's `last_check` into the past by `by`, simulating idle time.
    fn age_all_buckets(limiter: &RateLimiter, by: Duration) {
        let mut state = limiter.state.lock().expect("lock");
        for bucket in state.values_mut() {
            bucket.last_check -= by;
        }
    }

    #[test]
    fn allows_requests_within_limit() {
        let limiter = RateLimiter::strict();
        let ip = v4(10, 0, 0, 1);
        for i in 0..10 {
            assert!(limiter.check(ip), "request {i} should be allowed");
        }
    }

    #[test]
    fn rejects_request_exceeding_limit() {
        let limiter = RateLimiter::strict();
        let ip = v4(10, 0, 0, 2);
        for _ in 0..10 {
            assert!(limiter.check(ip));
        }
        assert!(!limiter.check(ip), "11th request should be rejected");
    }

    #[test]
    fn different_ips_have_independent_limits() {
        let limiter = RateLimiter::strict();
        let (ip_a, ip_b) = (v4(10, 0, 0, 3), v4(10, 0, 0, 4));
        for _ in 0..10 {
            limiter.check(ip_a);
        }
        assert!(!limiter.check(ip_a), "IP A should be rate limited");
        assert!(limiter.check(ip_b), "IP B should still be allowed");
    }

    #[test]
    fn tokens_replenish_after_time_elapses() {
        let limiter = RateLimiter::new(2, 1);
        let ip = v4(10, 0, 0, 5);
        assert!(limiter.check(ip));
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip), "bucket should be empty");
        age_all_buckets(&limiter, Duration::from_secs(2));
        assert!(
            limiter.check(ip),
            "tokens should have replenished after window elapsed"
        );
    }

    #[test]
    fn default_limit_is_100_per_60s() {
        let limiter = RateLimiter::default_limit();
        assert_eq!(limiter.max_requests as u32, 100);
        assert_eq!(limiter.window_secs(), 60);
    }

    #[test]
    fn strict_limit_is_10_per_60s() {
        let limiter = RateLimiter::strict();
        assert_eq!(limiter.max_requests as u32, 10);
        assert_eq!(limiter.window_secs(), 60);
    }

    #[test]
    fn extract_client_ip_uses_connect_info_when_not_trusted() {
        let addr = SocketAddr::new(v4(192, 168, 1, 1), 8080);
        let request = request_with_forwarded_for(Some("10.0.0.1"));
        let ip = extract_client_ip(&request, Some(ConnectInfo(addr)), &[]);
        assert_eq!(ip, v4(192, 168, 1, 1));
    }

    #[test]
    fn extract_client_ip_defaults_to_localhost() {
        let request = request_with_forwarded_for(None);
        let ip = extract_client_ip(&request, None, &[]);
        assert_eq!(ip, IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    #[test]
    fn untrusted_connection_ignores_forwarded_for() {
        // A trusted list IS configured, but the peer is not on it, so its header must be ignored.
        let proxy_addr = SocketAddr::new(v4(172, 16, 0, 1), 3000);
        let request = request_with_forwarded_for(Some("10.0.0.99"));
        let trusted = [v4(172, 16, 0, 2)];
        let ip = extract_client_ip(&request, Some(ConnectInfo(proxy_addr)), &trusted);
        assert_eq!(ip, v4(172, 16, 0, 1));
    }

    #[test]
    fn trusted_proxy_uses_forwarded_for() {
        let proxy_ip = v4(172, 16, 0, 1);
        let proxy_addr = SocketAddr::new(proxy_ip, 3000);
        let request = request_with_forwarded_for(Some("10.0.0.99, 172.16.0.1"));
        let trusted = vec![proxy_ip];
        let ip = extract_client_ip(&request, Some(ConnectInfo(proxy_addr)), &trusted);
        assert_eq!(ip, v4(10, 0, 0, 99));
    }

    #[test]
    fn empty_trusted_proxies_always_uses_connect_info() {
        let addr = SocketAddr::new(v4(192, 168, 1, 50), 9000);
        let request = request_with_forwarded_for(Some("10.0.0.1"));
        let ip = extract_client_ip(&request, Some(ConnectInfo(addr)), &[]);
        assert_eq!(ip, v4(192, 168, 1, 50));
    }

    #[test]
    fn missing_connect_info_with_trusted_proxies_uses_forwarded_for() {
        let request = request_with_forwarded_for(Some("10.0.0.7"));
        let ip = extract_client_ip(&request, None, &[v4(172, 16, 0, 1)]);
        assert_eq!(ip, v4(10, 0, 0, 7));
    }

    #[test]
    fn rate_limited_response_sets_status_and_retry_after() {
        let response = build_rate_limited_response(60);
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        let retry_after = response
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .and_then(|value| value.to_str().ok());
        assert_eq!(retry_after, Some("60"));
    }

    #[test]
    fn evicts_stale_entries_when_over_capacity() {
        let limiter = RateLimiter::new(5, 1).with_max_entries(3);
        for last_octet in 0..5u8 {
            limiter.check(v4(10, 0, 1, last_octet));
        }
        assert_eq!(
            limiter.state.lock().expect("lock").len(),
            5,
            "should have 5 entries before eviction"
        );
        age_all_buckets(&limiter, Duration::from_secs(3));
        limiter.check(v4(10, 0, 2, 1));
        let state = limiter.state.lock().expect("lock");
        assert_eq!(state.len(), 1, "stale entries should have been evicted");
    }

    #[test]
    fn does_not_evict_active_entries() {
        let limiter = RateLimiter::new(5, 1).with_max_entries(2);
        for last_octet in 0..3u8 {
            limiter.check(v4(10, 0, 3, last_octet));
        }
        limiter.check(v4(10, 0, 3, 99));
        let state = limiter.state.lock().expect("lock");
        assert_eq!(state.len(), 4, "active entries should not be evicted");
    }
}