use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::body::Body;
use axum::extract::{ConnectInfo, Request};
use axum::http::{header, HeaderValue, Response, StatusCode};
use axum::middleware::Next;
use axum::Router;
/// Strategy used to decide which bucket a request is charged against.
///
/// Equality and hashing are derived so strategies can be compared in tests
/// and used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum KeyBy {
    /// One bucket per client IP address.
    Ip,
    /// One bucket per distinct value of the named request header.
    Header(&'static str),
    /// A single bucket shared by every request.
    Global,
}
/// Token-bucket rate limiter configuration together with its bucket storage.
///
/// The bucket map lives behind an `Arc`, so clones of this value share the
/// same buckets.
#[derive(Clone)]
pub struct RateLimitLayer {
// Maximum number of tokens a bucket can hold; new buckets start full.
capacity: u32,
// Time over which `capacity` tokens are replenished.
refill_period: Duration,
// How the bucket key is derived from a request.
key_by: KeyBy,
// Buckets indexed by the key string extracted from each request.
store: Arc<tokio::sync::Mutex<HashMap<String, Bucket>>>,
}
/// State of a single token bucket.
#[derive(Clone, Copy, Debug)]
struct Bucket {
// Tokens currently available; fractional because refill is continuous.
tokens: f64,
// Instant at which `tokens` was last brought up to date.
last_refill: Instant,
}
impl RateLimitLayer {
    /// Creates a limiter with one bucket per client IP address.
    ///
    /// Each bucket holds at most `capacity` tokens and replenishes
    /// `capacity` tokens every `refill_period`.
    #[must_use]
    pub fn per_ip(capacity: u32, refill_period: Duration) -> Self {
        Self::new(capacity, refill_period, KeyBy::Ip)
    }

    /// Creates a limiter with one bucket per distinct value of `header`.
    ///
    /// Each bucket holds at most `capacity` tokens and replenishes
    /// `capacity` tokens every `refill_period`.
    #[must_use]
    pub fn per_header(header: &'static str, capacity: u32, refill_period: Duration) -> Self {
        Self::new(capacity, refill_period, KeyBy::Header(header))
    }

    /// Creates a limiter in which every request draws from one shared bucket.
    #[must_use]
    pub fn global(capacity: u32, refill_period: Duration) -> Self {
        Self::new(capacity, refill_period, KeyBy::Global)
    }

    /// Shared constructor: stores the configuration and starts with no buckets.
    fn new(capacity: u32, refill_period: Duration, key_by: KeyBy) -> Self {
        Self {
            capacity,
            refill_period,
            key_by,
            store: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Refill speed in tokens per second.
    ///
    /// A zero `refill_period` yields `f64::MAX` instead of dividing by zero.
    fn rate_per_sec(&self) -> f64 {
        if self.refill_period.is_zero() {
            return f64::MAX;
        }
        f64::from(self.capacity) / self.refill_period.as_secs_f64()
    }

    /// Tries to spend one token from the bucket identified by `key`.
    ///
    /// Returns `Ok((remaining, 0))`, where `remaining` is the number of whole
    /// tokens left after the spend, or `Err(retry_after_secs)` (always at
    /// least 1) when the bucket holds less than one token.
    async fn take(&self, key: &str) -> Result<(u32, u64), u64> {
        let cap = f64::from(self.capacity);
        let rate = self.rate_per_sec();
        let mut store = self.store.lock().await;
        // Read the clock only once the lock is held. A timestamp taken before
        // waiting for the lock can be older than the `last_refill` written by
        // the task that held it, which would move `last_refill` backwards and
        // credit the same interval twice on the next call.
        let now = Instant::now();

        // A bucket that has refilled to capacity behaves exactly like a
        // missing one (new buckets start full), so dropping it is invisible
        // to callers. Sweeping only when a new key is about to outgrow the
        // table keeps the map from accumulating every key ever seen.
        if store.len() >= store.capacity() && !store.contains_key(key) {
            store.retain(|_, bucket| {
                let idle = now.duration_since(bucket.last_refill).as_secs_f64();
                bucket.tokens + idle * rate < cap
            });
            // If most buckets survived, grow now so that the next sweep is at
            // least `live` insertions away and the cost stays amortised.
            let live = store.len();
            if live.saturating_mul(2) > store.capacity() {
                store.reserve(live);
            }
        }

        let bucket = store.entry(key.to_owned()).or_insert(Bucket {
            tokens: cap,
            last_refill: now,
        });
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(cap);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            // `tokens` stays within `0.0..=cap` and `cap` came from a `u32`,
            // so the cast cannot lose information.
            Ok((bucket.tokens.floor() as u32, 0))
        } else {
            let need = 1.0 - bucket.tokens;
            let retry = if rate > 0.0 {
                // Float-to-integer `as` casts saturate, so a huge quotient
                // becomes `u64::MAX` rather than wrapping.
                (need / rate).ceil() as u64
            } else {
                // A zero rate (capacity 0) never refills.
                u64::MAX
            };
            Err(retry.max(1))
        }
    }

    /// Derives the bucket key for `req` according to the configured strategy.
    ///
    /// * `KeyBy::Ip` uses the IP from the `ConnectInfo<SocketAddr>` request
    ///   extension, or `"<no-ip>"` when the extension is absent.
    /// * `KeyBy::Header` uses the header value, or `"<no-header>"` when the
    ///   header is missing or `to_str` rejects it.
    /// * `KeyBy::Global` always yields `"<global>"`.
    fn extract_key(&self, req: &Request<Body>) -> String {
        match &self.key_by {
            KeyBy::Ip => req
                .extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|ci| ci.ip().to_string())
                .unwrap_or_else(|| "<no-ip>".to_owned()),
            KeyBy::Header(name) => req
                .headers()
                .get(*name)
                .and_then(|v| v.to_str().ok())
                .map(str::to_owned)
                .unwrap_or_else(|| "<no-header>".to_owned()),
            KeyBy::Global => "<global>".to_owned(),
        }
    }
}
/// Extension trait that adds a `rate_limit` builder method to a router.
pub trait RateLimitRouterExt {
/// Returns `self` with rate limiting configured by `layer` applied.
#[must_use]
fn rate_limit(self, layer: RateLimitLayer) -> Self;
}
impl<S: Clone + Send + Sync + 'static> RateLimitRouterExt for Router<S> {
    /// Installs the limiter on this router as `from_fn` middleware.
    ///
    /// The configuration is placed behind an `Arc`; each request clones that
    /// handle and passes it to `handle` together with the request and the
    /// rest of the middleware chain.
    fn rate_limit(self, layer: RateLimitLayer) -> Self {
        let shared = Arc::new(layer);
        let middleware = axum::middleware::from_fn(move |request: Request<Body>, next: Next| {
            handle(Arc::clone(&shared), request, next)
        });
        self.layer(middleware)
    }
}
/// Middleware body: charges one token to the request's bucket and either
/// forwards the request or answers `429 Too Many Requests`.
///
/// Forwarded responses get `x-ratelimit-limit` and `x-ratelimit-remaining`
/// headers. Rejections carry the same two headers plus `Retry-After`, and a
/// JSON body labelled with `Content-Type: application/json`.
async fn handle(cfg: Arc<RateLimitLayer>, req: Request<Body>, next: Next) -> Response<Body> {
    let key = cfg.extract_key(&req);
    // Integer-to-`HeaderValue` conversions are infallible, so no header on
    // either path needs an `unwrap`.
    let limit = HeaderValue::from(cfg.capacity);
    match cfg.take(&key).await {
        Ok((remaining, _)) => {
            let mut response = next.run(req).await;
            let headers = response.headers_mut();
            let _ = headers.insert("x-ratelimit-limit", limit);
            let _ = headers.insert("x-ratelimit-remaining", HeaderValue::from(remaining));
            response
        }
        Err(retry_secs) => {
            let body = format!(
                r#"{{"error":"rate limit exceeded","retry_after":{retry_secs}}}"#
            );
            let mut response = Response::new(Body::from(body));
            *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
            let headers = response.headers_mut();
            // The body is JSON; say so, so clients do not have to guess.
            let _ = headers.insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/json"),
            );
            let _ = headers.insert(header::RETRY_AFTER, HeaderValue::from(retry_secs));
            let _ = headers.insert("x-ratelimit-limit", limit);
            let _ = headers.insert("x-ratelimit-remaining", HeaderValue::from_static("0"));
            response
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh bucket admits exactly `capacity` requests back to back.
    #[tokio::test]
    async fn first_n_requests_succeed_under_capacity() {
        let l = RateLimitLayer::global(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert!(l.take("k").await.is_ok());
        }
    }

    /// The reported remaining count steps down by one per admitted request.
    #[tokio::test]
    async fn remaining_counts_down_to_zero() {
        let l = RateLimitLayer::global(3, Duration::from_secs(60));
        assert_eq!(l.take("k").await, Ok((2, 0)));
        assert_eq!(l.take("k").await, Ok((1, 0)));
        assert_eq!(l.take("k").await, Ok((0, 0)));
    }

    /// Once the bucket is empty the next request is rejected with a positive
    /// retry hint that never exceeds the time needed to refill one token.
    #[tokio::test]
    async fn n_plus_one_request_is_rejected() {
        let l = RateLimitLayer::global(2, Duration::from_secs(60));
        assert!(l.take("k").await.is_ok());
        assert!(l.take("k").await.is_ok());
        let result = l.take("k").await;
        assert!(result.is_err());
        let retry_after = result.unwrap_err();
        // 2 tokens per 60 s means one token takes at most 30 s.
        assert!((1..=30).contains(&retry_after));
    }

    /// Draining one key's bucket leaves other keys untouched.
    #[tokio::test]
    async fn separate_keys_have_independent_buckets() {
        let l = RateLimitLayer::global(1, Duration::from_secs(60));
        assert!(l.take("alice").await.is_ok());
        assert!(l.take("alice").await.is_err());
        assert!(l.take("bob").await.is_ok());
    }

    /// Waiting longer than one token's refill time makes a request pass again.
    ///
    /// At 5 tokens per second a token takes 200 ms, so the exhaustion check
    /// tolerates 200 ms of scheduling delay and the 300 ms sleep leaves a
    /// 100 ms margin before the final assertion.
    #[tokio::test]
    async fn refill_replenishes_tokens_over_time() {
        let l = RateLimitLayer::global(5, Duration::from_secs(1));
        for _ in 0..5 {
            assert!(l.take("k").await.is_ok());
        }
        assert!(l.take("k").await.is_err());
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            l.take("k").await.is_ok(),
            "should have refilled at least 1 token"
        );
    }
}