use axum::{
Json,
extract::{ConnectInfo, Request},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde_json::json;
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Instant,
};
use tokio::sync::Mutex;
use crate::config::RateLimitConfig;
/// Token bucket for a single client.
///
/// Tokens accrue continuously at `rate` per second up to `capacity`; every
/// admitted request spends exactly one token.
struct TokenBucket {
    /// Tokens currently available (fractional while the bucket refills).
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
    /// Refill speed in tokens per second. [`TokenBucket::new`] clamps
    /// negative and NaN values to zero.
    rate: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: u32,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// A negative or NaN `rate` is treated as `0.0` (the bucket never
    /// refills) so that a bad value can neither drain tokens over time nor
    /// poison the arithmetic in [`TokenBucket::try_consume`].
    fn new(rate: f64, capacity: u32) -> Self {
        Self {
            tokens: f64::from(capacity),
            last_refill: Instant::now(),
            // `f64::max` returns the non-NaN operand, so NaN also becomes 0.0.
            rate: rate.max(0.0),
            capacity,
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to spend one token.
    ///
    /// Returns `None` when a token was spent (the request is allowed).
    /// Otherwise returns `Some(seconds)`: the time, rounded up to whole
    /// seconds and never less than 1, until a full token will be available.
    /// A bucket that cannot refill reports `Some(u64::MAX)`.
    fn try_consume(&mut self) -> Option<u64> {
        let now = Instant::now();
        // Written so that a NaN rate fails the test and is treated like zero.
        let refills = self.rate > 0.0;

        if refills {
            let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
            self.tokens = (self.tokens + elapsed * self.rate).min(f64::from(self.capacity));
        }
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return None;
        }

        if !refills {
            // No token will ever become available; say so explicitly instead
            // of relying on a division by zero.
            return Some(u64::MAX);
        }

        let wait_secs = ((1.0 - self.tokens) / self.rate).ceil();
        // Float-to-integer `as` casts saturate, so a huge wait becomes
        // `u64::MAX` rather than wrapping. The lower bound keeps the caller
        // from ever being told to retry after zero seconds.
        Some((wait_secs as u64).max(1))
    }
}
pub struct RateLimiter {
buckets: Mutex<HashMap<IpAddr, TokenBucket>>,
config: RateLimitConfig,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
buckets: Mutex::new(HashMap::new()),
config,
}
}
fn is_path_exempt(&self, path: &str) -> bool {
self.config.exempt_paths.iter().any(|exempt| {
path == exempt || path.starts_with(exempt)
})
}
fn is_ip_exempt(&self, addr: &SocketAddr) -> bool {
self.config.exempt_ips.contains(&addr.ip())
}
pub async fn check(&self, path: &str, addr: SocketAddr) -> Option<u64> {
if self.is_path_exempt(path) {
return None;
}
if self.is_ip_exempt(&addr) {
return None;
}
let mut buckets = self.buckets.lock().await;
let bucket = buckets.entry(addr.ip()).or_insert_with(|| {
TokenBucket::new(
self.config.requests_per_second as f64,
self.config.burst_size,
)
});
bucket.try_consume()
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request,
next: axum::middleware::Next,
) -> Response {
match limiter.check(req.uri().path(), addr).await {
None => next.run(req).await,
Some(retry_after) => {
let error = json!({
"error": {
"code": "rate_limited",
"message": "Too many requests",
"details": {
"retry_after_seconds": retry_after
}
}
});
(StatusCode::TOO_MANY_REQUESTS, Json(error)).into_response()
}
}
}