use axum::extract::{ConnectInfo, Request};
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Json, Response};
use dashmap::DashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A single client's token bucket.
///
/// The bucket holds up to `capacity` tokens. Every allowed request spends one token, and
/// tokens are replenished continuously at `refill_per_ms` tokens per elapsed millisecond
/// of monotonic (`Instant`) time.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: f64,
    /// Tokens added per elapsed millisecond.
    refill_per_ms: f64,
    /// Currently available tokens; may be fractional between refills.
    tokens: f64,
    /// Monotonic timestamp of the last refill computation.
    last_refill: Instant,
    /// Caller-supplied millisecond timestamp of the most recent `try_consume` call.
    last_seen_ms: u64,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// `now` seeds the monotonic refill clock and `now_ms` seeds the last-seen timestamp.
    pub fn new(capacity: u32, refill_per_ms: f64, now: Instant, now_ms: u64) -> Self {
        let full = f64::from(capacity);
        Self {
            capacity: full,
            tokens: full,
            refill_per_ms,
            last_refill: now,
            last_seen_ms: now_ms,
        }
    }

    /// Refills for the time elapsed since the last refill, then tries to spend one token.
    ///
    /// `now_ms` is recorded as the last-seen time whether or not the request is allowed.
    /// Returns `true` when a whole token was available and has been consumed.
    pub fn try_consume(&mut self, now: Instant, now_ms: u64) -> bool {
        self.refill(now);
        self.last_seen_ms = now_ms;

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Adds `elapsed_ms * refill_per_ms` tokens (capped at `capacity`) and moves
    /// `last_refill` up to `now`. Does nothing when `now` is not after `last_refill`.
    fn refill(&mut self, now: Instant) {
        // `saturating_duration_since` yields zero instead of panicking if `now` is earlier.
        let elapsed = now.saturating_duration_since(self.last_refill);
        let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
        if elapsed_ms > 0.0 {
            let replenished = self.tokens + elapsed_ms * self.refill_per_ms;
            self.tokens = replenished.min(self.capacity);
            self.last_refill = now;
        }
    }
}
/// Upper bound on the configurable requests-per-minute rate.
///
/// 60 000 requests per minute is one token per millisecond; larger values are clamped by
/// [`RateLimiter::new`].
pub const MAX_RPM: u32 = 60_000;

/// Per-client-IP token-bucket rate limiter.
///
/// Each IP gets its own [`TokenBucket`] holding up to `capacity` tokens and refilled at
/// `refill_per_ms` tokens per millisecond. Buckets are created lazily on the first
/// [`RateLimiter::check`] for an IP and removed by [`RateLimiter::evict_stale`].
pub struct RateLimiter {
    /// One bucket per client IP.
    buckets: DashMap<IpAddr, TokenBucket>,
    /// Burst size (maximum tokens per bucket); always at least 1.
    capacity: u32,
    /// Refill rate in tokens per millisecond, derived from `effective_rpm`.
    refill_per_ms: f64,
    /// Configured requests per minute after clamping to [`MAX_RPM`].
    effective_rpm: u32,
}

impl RateLimiter {
    /// Creates a limiter allowing `rpm` requests per minute per IP, with bursts of up to
    /// `burst` requests.
    ///
    /// An `rpm` above [`MAX_RPM`] is clamped to it and a warning is logged. A `burst` of 0
    /// is raised to 1.
    ///
    /// NOTE(review): `rpm == 0` gives a refill rate of 0. A client then gets only its
    /// initial burst until its bucket is evicted. Confirm this is intended.
    pub fn new(rpm: u32, burst: u32) -> Self {
        if rpm > MAX_RPM {
            tracing::warn!(
                rpm,
                max_rpm = MAX_RPM,
                "rate_limit_per_minute exceeds {MAX_RPM}; clamped to {MAX_RPM} (1 ms minimum interval)"
            );
        }
        let effective_rpm = rpm.min(MAX_RPM);
        let refill_per_ms = effective_rpm as f64 / 60_000.0;
        Self {
            buckets: DashMap::new(),
            capacity: burst.max(1),
            refill_per_ms,
            effective_rpm,
        }
    }

    /// Nominal number of milliseconds between refilled tokens at the effective rate.
    ///
    /// Never less than 1. A rate of 0 is treated as 1 request per minute here, which
    /// yields 60 000.
    pub fn interval_ms(&self) -> u64 {
        (60_000u64 / self.effective_rpm.max(1) as u64).max(1)
    }

    /// Tries to take one token from `ip`'s bucket, creating a full bucket the first time
    /// `ip` is seen.
    ///
    /// Returns `true` if the request is allowed.
    pub fn check(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let now_ms = unix_ms();
        let mut entry = self
            .buckets
            .entry(ip)
            .or_insert_with(|| TokenBucket::new(self.capacity, self.refill_per_ms, now, now_ms));
        entry.try_consume(now, now_ms)
    }

    /// Removes buckets whose most recent request is older than `older_than`.
    ///
    /// Staleness is measured against `unix_ms()` (wall-clock `SystemTime`), so a system
    /// clock adjustment can make eviction happen early or late.
    ///
    /// Durations too large to express as `u64` milliseconds saturate instead of wrapping.
    /// An enormous `older_than` therefore evicts nothing, rather than potentially
    /// evicting everything.
    pub fn evict_stale(&self, older_than: Duration) {
        // `as u64` on the u128 from `as_millis` would silently truncate; saturate instead.
        let max_age_ms = u64::try_from(older_than.as_millis()).unwrap_or(u64::MAX);
        let cutoff = unix_ms().saturating_sub(max_age_ms);
        self.buckets
            .retain(|_, bucket| bucket.last_seen_ms >= cutoff);
    }

    /// Number of tracked buckets (test-only helper).
    #[cfg(test)]
    // Only exposed to tests, so an `is_empty` counterpart is not needed.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.buckets.len()
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn unix_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
pub fn extract_client_ip(req: &Request) -> Option<IpAddr> {
let headers = req.headers();
if let Some(value) = headers.get("x-forwarded-for")
&& let Ok(s) = value.to_str()
{
let first = s.split(',').next().unwrap_or("").trim();
if let Ok(ip) = first.parse::<IpAddr>() {
return Some(ip);
}
}
if let Some(value) = headers.get("x-real-ip")
&& let Ok(s) = value.to_str()
&& let Ok(ip) = s.trim().parse::<IpAddr>()
{
return Some(ip);
}
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip())
}
/// Axum middleware that enforces `limiter` per client IP.
///
/// The client IP comes from [`extract_client_ip`]. Requests whose IP cannot be determined
/// all share the bucket for `0.0.0.0`. Allowed requests are passed on to `next`.
///
/// Rejected requests get `429 Too Many Requests`, a `Retry-After: 60` header, and the JSON
/// body `{"error": "Too many requests", "code": "rate_limited"}`.
///
/// NOTE(review): `Retry-After` is fixed at 60 seconds regardless of the configured rate
/// (compare `RateLimiter::interval_ms`).
pub async fn rate_limit_middleware(
    limiter: Arc<RateLimiter>,
    req: Request,
    next: Next,
) -> Response {
    let client_ip =
        extract_client_ip(&req).unwrap_or_else(|| IpAddr::from(Ipv4Addr::UNSPECIFIED));

    if !limiter.check(client_ip) {
        tracing::debug!(client_ip = %client_ip, "rate limit rejected request");
        let status = StatusCode::TOO_MANY_REQUESTS;
        let headers = [(axum::http::header::RETRY_AFTER, "60")];
        let body = Json(serde_json::json!({
            "error": "Too many requests",
            "code": "rate_limited",
        }));
        return (status, headers, body).into_response();
    }

    next.run(req).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{HeaderValue, Request as HttpRequest};

    #[test]
    fn test_token_bucket_allows_within_capacity() {
        let now = Instant::now();
        // Zero refill rate: only the initial burst is available.
        let mut bucket = TokenBucket::new(5, 0.0, now, unix_ms());
        for i in 0..5 {
            assert!(bucket.try_consume(now, unix_ms()), "call {i} must succeed");
        }
        assert!(
            !bucket.try_consume(now, unix_ms()),
            "6th call must be rate-limited"
        );
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let start = Instant::now();
        let mut bucket = TokenBucket::new(2, 1.0, start, unix_ms());
        assert!(bucket.try_consume(start, unix_ms()));
        assert!(bucket.try_consume(start, unix_ms()));
        assert!(
            !bucket.try_consume(start, unix_ms()),
            "should be drained after 2 consumes"
        );
        // Time is injected, so advancing the `Instant` by hand avoids a real sleep.
        let later = start + Duration::from_millis(3);
        assert!(
            bucket.try_consume(later, unix_ms()),
            "should refill after 3 ms"
        );
    }

    #[test]
    fn test_rate_limiter_per_ip_isolation() {
        let limiter = RateLimiter::new(1, 1);
        let a: IpAddr = "10.0.0.1".parse().unwrap();
        let b: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(limiter.check(a), "A first call allowed");
        assert!(
            limiter.check(b),
            "B first call allowed (independent bucket)"
        );
        assert!(!limiter.check(a), "A second call rate-limited");
        assert!(!limiter.check(b), "B second call rate-limited");
    }

    #[test]
    fn test_rate_limiter_refill_formula_matches_v1_06() {
        for &rpm in &[1u32, 10, 30, 60, 600, 60_000] {
            let limiter = RateLimiter::new(rpm, 1);
            let ip: IpAddr = "10.0.0.3".parse().unwrap();
            assert!(limiter.check(ip), "rpm={rpm}: initial burst allowed");
            assert!(
                !limiter.check(ip),
                "rpm={rpm}: second immediate call blocked"
            );
            let mut guard = limiter.buckets.get_mut(&ip).expect("bucket exists");
            let interval_ms = (60_000u64 / rpm as u64).max(1);
            let later = guard.last_refill + Duration::from_millis(interval_ms);
            assert!(
                guard.try_consume(later, unix_ms()),
                "rpm={rpm}: 1 token must refill after {interval_ms} ms",
            );
        }
    }

    #[test]
    fn test_rate_limiter_interval_ms() {
        assert_eq!(RateLimiter::new(60, 1).interval_ms(), 1_000);
        assert_eq!(RateLimiter::new(MAX_RPM, 1).interval_ms(), 1);
        // Rates above MAX_RPM are clamped, so the interval never drops below 1 ms.
        assert_eq!(RateLimiter::new(MAX_RPM * 2, 1).interval_ms(), 1);
        // interval_ms treats a zero rate as 1 request per minute.
        assert_eq!(RateLimiter::new(0, 1).interval_ms(), 60_000);
    }

    #[test]
    fn test_extract_ip_prefers_forwarded_for() {
        let mut req = HttpRequest::builder()
            .uri("/v1/models")
            .body(Body::empty())
            .unwrap();
        req.headers_mut().insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.42 , 10.0.0.1"),
        );
        req.headers_mut()
            .insert("x-real-ip", HeaderValue::from_static("198.51.100.7"));
        req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            12345,
        )));
        let got = extract_client_ip(&req).expect("XFF must be parsed");
        assert_eq!(got, "203.0.113.42".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn test_extract_ip_falls_back_to_connect_info() {
        let mut req = HttpRequest::builder()
            .uri("/v1/models")
            .body(Body::empty())
            .unwrap();
        req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            55555,
        )));
        let got = extract_client_ip(&req).expect("ConnectInfo fallback");
        assert_eq!(got, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
    }

    #[test]
    fn test_extract_ip_uses_real_ip_when_forwarded_for_garbage() {
        let mut req = HttpRequest::builder()
            .uri("/v1/models")
            .body(Body::empty())
            .unwrap();
        req.headers_mut()
            .insert("x-forwarded-for", HeaderValue::from_static("not-an-ip"));
        req.headers_mut()
            .insert("x-real-ip", HeaderValue::from_static("198.51.100.7"));
        let got = extract_client_ip(&req).expect("X-Real-IP fallback");
        assert_eq!(got, "198.51.100.7".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn test_eviction_removes_stale() {
        let limiter = RateLimiter::new(60, 1);
        let fresh: IpAddr = "10.0.0.4".parse().unwrap();
        let stale: IpAddr = "10.0.0.5".parse().unwrap();
        assert!(limiter.check(fresh));
        assert!(limiter.check(stale));
        {
            // Backdate the stale bucket by 10 minutes. The scope drops the DashMap guard
            // before `evict_stale` needs the shard lock.
            let mut guard = limiter.buckets.get_mut(&stale).expect("stale bucket");
            guard.last_seen_ms = unix_ms().saturating_sub(10 * 60_000);
        }
        limiter.evict_stale(Duration::from_secs(60));
        assert_eq!(limiter.len(), 1, "stale bucket should be evicted");
        assert!(
            limiter.buckets.contains_key(&fresh),
            "fresh bucket must survive eviction"
        );
    }

    #[tokio::test]
    async fn test_rate_limit_middleware_blocks_over_limit() {
        use axum::Router;
        use axum::routing::get;

        let limiter = Arc::new(RateLimiter::new(60, 1));
        let app = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let limiter = limiter.clone();
                async move { rate_limit_middleware(limiter, req, next).await }
            }));

        // `bind` returns an already-listening socket. Client connections queue in the
        // accept backlog until the spawned server accepts them, so no startup sleep is
        // needed.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });

        let client = reqwest::Client::new();
        let r1 = client
            .get(format!("http://127.0.0.1:{port}/test"))
            .send()
            .await
            .unwrap();
        assert_eq!(r1.status(), 200, "first request within burst must succeed");

        let r2 = client
            .get(format!("http://127.0.0.1:{port}/test"))
            .send()
            .await
            .unwrap();
        assert_eq!(
            r2.status(),
            429,
            "second request over burst must be rate-limited"
        );
        assert_eq!(
            r2.headers()
                .get(axum::http::header::RETRY_AFTER)
                .and_then(|v| v.to_str().ok()),
            Some("60"),
            "429 response must include Retry-After: 60"
        );
        let body_text = r2.text().await.unwrap();
        let body: serde_json::Value = serde_json::from_str(&body_text).unwrap();
        assert_eq!(
            body["code"], "rate_limited",
            "error code must be 'rate_limited'"
        );
    }
}