use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use tokio::sync::RwLock;
/// A single token bucket: holds up to `max_tokens` tokens and regains them
/// continuously at `refill_rate` tokens per second.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available for consumption.
    tokens: f64,
    /// Upper bound on `tokens`; also the initial fill level.
    max_tokens: f64,
    /// Tokens regained per second of elapsed time.
    refill_rate: f64,
    /// Instant of the most recent refill (updated on every `try_consume`).
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`tokens == max_tokens`).
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill,
        }
    }

    /// Refills the bucket, then withdraws `count` tokens if that many are
    /// available.
    ///
    /// Returns `true` when the tokens were withdrawn and `false` when the
    /// bucket held fewer than `count` (in which case nothing is withdrawn).
    fn try_consume(&mut self, count: f64) -> bool {
        self.refill();
        let granted = self.tokens >= count;
        if granted {
            self.tokens -= count;
        }
        granted
    }

    /// Credits the tokens earned since `last_refill`, capped at `max_tokens`,
    /// and moves `last_refill` forward to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = f64::min(self.tokens + earned, self.max_tokens);
        self.last_refill = now;
    }
}
/// Limits applied by the rate limiter in this module.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained rate for non-mutating requests; divided by 60 to obtain the
    /// per-second refill rate of a read bucket.
    pub requests_per_minute: u32,
    /// Capacity of every bucket (read and write alike), i.e. how many
    /// requests may be served back-to-back before refill matters.
    pub burst_size: u32,
    /// Sustained rate for mutating requests; divided by 60 to obtain the
    /// per-second refill rate of a write bucket.
    pub mutating_requests_per_minute: u32,
}

impl Default for RateLimitConfig {
    /// 120 reads/min, 60 writes/min, burst of 20.
    fn default() -> Self {
        const READS_PER_MINUTE: u32 = 120;
        const WRITES_PER_MINUTE: u32 = 60;
        const BURST: u32 = 20;

        Self {
            burst_size: BURST,
            requests_per_minute: READS_PER_MINUTE,
            mutating_requests_per_minute: WRITES_PER_MINUTE,
        }
    }
}
/// Keyed token-bucket rate limiter.
///
/// Cloning is cheap: the bucket map lives behind an `Arc`, so every clone
/// observes and updates the same set of buckets.
#[derive(Clone)]
pub struct RateLimiter {
    /// Limits used when a bucket is first created for a key.
    config: RateLimitConfig,
    /// One bucket per `"<key>:read"` / `"<key>:write"` string.
    buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
}

impl RateLimiter {
    /// Creates a limiter with no buckets; buckets are created lazily by
    /// [`RateLimiter::check`].
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Attempts to take one token from the bucket belonging to `key`.
    ///
    /// Mutating and non-mutating traffic use separate buckets
    /// (`"<key>:write"` and `"<key>:read"`), each refilled at its own
    /// per-minute rate from the config and both sized by `burst_size`.
    /// A bucket is created full the first time its key is seen.
    ///
    /// Returns `true` if the request is allowed, `false` if the bucket is
    /// exhausted.
    pub async fn check(&self, key: &str, is_mutating: bool) -> bool {
        // Everything that does not touch shared state (string formatting,
        // rate arithmetic) happens before the lock is taken, so the global
        // write lock is held only for the map lookup and the token update.
        let (suffix, per_minute) = if is_mutating {
            ("write", self.config.mutating_requests_per_minute)
        } else {
            ("read", self.config.requests_per_minute)
        };
        let bucket_key = format!("{key}:{suffix}");
        let rate = f64::from(per_minute) / 60.0;
        let burst = f64::from(self.config.burst_size);

        let mut buckets = self.buckets.write().await;
        let bucket = buckets
            .entry(bucket_key)
            .or_insert_with(|| TokenBucket::new(burst, rate));
        bucket.try_consume(1.0)
    }

    /// Drops every bucket whose last refill is at least `max_age` old.
    ///
    /// `last_refill` is bumped on every `check` (allowed or not), so this
    /// evicts keys that have been idle for `max_age`.
    ///
    /// NOTE(review): an evicted key restarts with a full bucket on its next
    /// request. If `max_age` is shorter than the time a drained bucket needs
    /// to refill completely, eviction can hand tokens back early — confirm
    /// callers choose `max_age` accordingly.
    pub async fn cleanup(&self, max_age: Duration) {
        let mut buckets = self.buckets.write().await;
        let now = Instant::now();
        buckets.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_age);
    }
}
/// Rejection returned when a caller has exhausted its rate-limit bucket.
pub struct RateLimitExceeded;

impl IntoResponse for RateLimitExceeded {
    /// Renders `429 Too Many Requests` with a JSON error body and a matching
    /// `Retry-After` header.
    fn into_response(self) -> Response {
        // Single source of truth so the header and the JSON body cannot drift
        // apart.
        const RETRY_AFTER_SECONDS: u64 = 60;

        let body = axum::Json(serde_json::json!({
            "error": "Rate limit exceeded",
            "status": 429,
            "retry_after_seconds": RETRY_AFTER_SECONDS
        }));

        (
            StatusCode::TOO_MANY_REQUESTS,
            // Standard header, so generic HTTP clients can back off without
            // parsing the JSON body.
            [(
                axum::http::header::RETRY_AFTER,
                RETRY_AFTER_SECONDS.to_string(),
            )],
            body,
        )
            .into_response()
    }
}
/// Derives the rate-limit key identifying the caller of `request`.
///
/// Resolution order:
/// 1. `principal:<id>` when a `Principal` is stored in the request
///    extensions (presumably by an earlier auth layer — verify against the
///    router setup).
/// 2. `ip:<addr>` using the first non-empty, trimmed entry of the
///    `X-Forwarded-For` header.
/// 3. The literal `unknown` when neither is available, including when the
///    header is not valid visible ASCII or contains only empty entries.
///
/// NOTE(review): `X-Forwarded-For` is taken at face value here. Unless a
/// trusted proxy overwrites the header, clients can choose their own key —
/// confirm the deployment guarantees this.
fn get_rate_limit_key(request: &Request<Body>) -> String {
    if let Some(principal) = request.extensions().get::<super::auth::Principal>() {
        return format!("principal:{}", principal.id);
    }

    // Skip blank entries so a header such as "" or ", 1.2.3.4" never yields
    // the degenerate key "ip:".
    let forwarded_ip = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| {
            value
                .split(',')
                .map(str::trim)
                .find(|entry| !entry.is_empty())
        });

    match forwarded_ip {
        Some(ip) => format!("ip:{ip}"),
        None => "unknown".to_string(),
    }
}
/// Returns `true` for the HTTP methods this module counts against the
/// mutating-request limit: `POST`, `PUT`, `DELETE` and `PATCH`.
fn is_mutating_method(method: &Method) -> bool {
    *method == Method::POST
        || *method == Method::PUT
        || *method == Method::DELETE
        || *method == Method::PATCH
}
/// Axum middleware function that enforces the per-caller rate limit.
///
/// Derives the caller's key and whether the request is mutating, then asks
/// the shared [`RateLimiter`] (taken from router state) for a token. When a
/// token is granted the request is forwarded to `next`; otherwise the chain
/// is short-circuited with [`RateLimitExceeded`].
pub async fn rate_limit_middleware(
    axum::extract::State(limiter): axum::extract::State<RateLimiter>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, RateLimitExceeded> {
    let is_mutating = is_mutating_method(request.method());
    let key = get_rate_limit_key(&request);

    if limiter.check(&key, is_mutating).await {
        Ok(next.run(request).await)
    } else {
        Err(RateLimitExceeded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh bucket serves exactly `max_tokens` requests back-to-back.
    ///
    /// Nothing here awaits, so this is a plain synchronous test rather than
    /// one that spins up a Tokio runtime.
    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_consume(1.0));
        }
        assert!(!bucket.try_consume(1.0));
    }

    /// The limiter allows `burst_size` requests for a key, then rejects.
    #[tokio::test]
    async fn test_rate_limiter() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
            mutating_requests_per_minute: 30,
        };
        let limiter = RateLimiter::new(config);
        for _ in 0..5 {
            assert!(limiter.check("test-key", false).await);
        }
        assert!(!limiter.check("test-key", false).await);
    }

    /// Exhausting the read bucket of a key must not affect its write bucket.
    #[tokio::test]
    async fn test_read_and_write_buckets_are_independent() {
        // Low refill rates keep the test insensitive to scheduling delays.
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 1,
            burst_size: 2,
            mutating_requests_per_minute: 1,
        });
        assert!(limiter.check("test-key", false).await);
        assert!(limiter.check("test-key", false).await);
        assert!(!limiter.check("test-key", false).await);

        assert!(limiter.check("test-key", true).await);
        assert!(limiter.check("test-key", true).await);
        assert!(!limiter.check("test-key", true).await);
    }

    /// `cleanup` evicts buckets at least `max_age` old; an evicted key starts
    /// over with a full bucket.
    #[tokio::test]
    async fn test_cleanup_evicts_idle_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 1,
            burst_size: 1,
            mutating_requests_per_minute: 1,
        });
        assert!(limiter.check("test-key", false).await);
        assert!(!limiter.check("test-key", false).await);

        // No age is strictly below zero, so every bucket is removed.
        limiter.cleanup(Duration::ZERO).await;
        assert!(limiter.check("test-key", false).await);
    }
}