use axum::{
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
/// Token-bucket state backing a single rate limit.
///
/// The bucket stores at most `capacity` tokens and earns `rate` tokens per
/// second of elapsed time. Every successful [`TokenBucket::try_acquire`]
/// spends exactly one token.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens currently available; fractional while the bucket is refilling.
    tokens: f64,
    /// Upper bound that `tokens` is clamped to on every refill.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    rate: f64,
    /// Instant at which tokens were last credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// `capacity` is the maximum number of tokens the bucket can hold and
    /// `rate` is the number of tokens restored per second.
    pub fn new(capacity: f64, rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            rate,
            last_refill: Instant::now(),
        }
    }

    /// Tries to spend one token, refilling the bucket first.
    ///
    /// # Errors
    ///
    /// When less than one whole token is available, nothing is spent and the
    /// error carries the number of seconds needed, at the configured `rate`,
    /// to earn the missing fraction of a token.
    pub fn try_acquire(&mut self) -> Result<(), f64> {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            Ok(())
        } else {
            // Only the shortfall up to one whole token has to be waited for.
            let deficit = 1.0 - self.tokens;
            let retry_after = deficit / self.rate;
            Err(retry_after)
        }
    }

    /// Credits tokens for the time elapsed since the previous refill, capped
    /// at `capacity`, and records the current instant as the new baseline.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.rate).min(self.capacity);
        self.last_refill = now;
    }
}
/// Cloneable handle to one shared [`TokenBucket`].
///
/// The bucket sits behind an `Arc<Mutex<_>>`, so every clone of the handle
/// refers to the same underlying bucket.
#[derive(Clone)]
pub struct RateLimiter(pub Arc<Mutex<TokenBucket>>);

impl RateLimiter {
    /// Builds a limiter around a new bucket created with
    /// `TokenBucket::new(capacity, rate_per_second)`.
    pub fn new(capacity: f64, rate_per_second: f64) -> Self {
        let bucket = TokenBucket::new(capacity, rate_per_second);
        let shared = Arc::new(Mutex::new(bucket));
        Self(shared)
    }
}
/// Axum middleware that throttles requests through a shared [`RateLimiter`].
///
/// The limiter is read from an optional `Extension`. When none is installed
/// the request is forwarded unchanged. Otherwise one token is requested from
/// the bucket:
///
/// * on success the request continues to `next`;
/// * on failure a `429 Too Many Requests` response with a JSON error body is
///   returned, carrying a `Retry-After` header with the wait rounded up to
///   whole seconds.
pub async fn rate_limit_middleware(
    limiter: Option<axum::extract::Extension<RateLimiter>>,
    request: Request,
    next: Next,
) -> Response {
    let Some(axum::extract::Extension(limiter)) = limiter else {
        return next.run(request).await;
    };

    // The guard is a temporary of this statement, so the lock is released
    // here and is never held while the downstream handler is awaited.
    let outcome = limiter.0.lock().await.try_acquire();

    let retry_after = match outcome {
        Ok(()) => return next.run(request).await,
        Err(wait) => wait,
    };

    // Float-to-integer `as` saturates (NaN becomes 0), so this cannot panic.
    let retry_secs = retry_after.ceil() as u64;
    let body = serde_json::json!({
        "error": {
            "message": "Rate limit exceeded",
            "type": "rate_limit_error",
        }
    });
    let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
    // An integer converts to a header value infallibly; no string round trip
    // or ignored parse error is needed.
    resp.headers_mut().insert(
        axum::http::header::RETRY_AFTER,
        axum::http::HeaderValue::from(retry_secs),
    );
    resp
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bucket of capacity 5 grants five acquisitions and refuses the sixth.
    #[test]
    fn test_bucket_allows_within_capacity() {
        let mut limiter = TokenBucket::new(5.0, 1.0);
        let mut granted = 0;
        while granted < 5 {
            assert!(limiter.try_acquire().is_ok());
            granted += 1;
        }
        assert!(limiter.try_acquire().is_err());
    }

    /// An emptied bucket grants again once enough time has passed.
    #[test]
    fn test_bucket_refills() {
        // At 1000 tokens per second the 10 ms pause below earns a full token.
        let mut limiter = TokenBucket::new(1.0, 1000.0);
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());

        let pause = std::time::Duration::from_millis(10);
        std::thread::sleep(pause);

        assert!(limiter.try_acquire().is_ok());
    }

    /// The wait reported by a refused acquisition is strictly positive.
    #[test]
    fn test_retry_after_is_positive() {
        let mut limiter = TokenBucket::new(1.0, 1.0);
        // Spend the only token; the outcome of this call is irrelevant here.
        let _ = limiter.try_acquire();
        let wait = limiter.try_acquire().unwrap_err();
        assert!(wait > 0.0, "retry_after should be positive");
    }

    /// A request passes through the middleware while tokens are available.
    #[tokio::test]
    async fn test_rate_limit_middleware_allows() {
        use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
        use tower::ServiceExt;

        let shared_limiter = RateLimiter::new(10.0, 10.0);
        let router = Router::new().route("/test", get(|| async { "ok" }));
        let router = router
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(shared_limiter));

        let request = HttpRequest::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}