use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use axum::extract::{ConnectInfo, Request};
use axum::http::{HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use governor::clock::{Clock, DefaultClock};
use governor::state::keyed::DashMapStateStore;
use governor::{Quota, RateLimiter};
use serde_json::json;
/// Governor rate limiter keyed by [`IpAddr`], with per-key state held in a
/// [`DashMapStateStore`] and time read from [`DefaultClock`].
type KeyedLimiter = RateLimiter<IpAddr, DashMapStateStore<IpAddr>, DefaultClock>;
/// Rate-limiting state that [`rate_limit`] looks up in the request extensions.
///
/// The limiter sits behind an [`Arc`], so clones of this value share one limiter.
#[derive(Clone)]
pub struct RateLimitState {
/// Limiter keyed by client IP address; shared between all clones of this state.
limiter: Arc<KeyedLimiter>,
/// The requests-per-minute value the limiter was built with; reported to
/// clients in the `x-ratelimit-limit` response header.
burst: u32,
}
/// Creates a [`RateLimitState`] whose limiter is built from
/// `Quota::per_minute(requests_per_minute)`.
///
/// The same number is stored alongside the limiter so the middleware can
/// report it in the `x-ratelimit-limit` header.
///
/// # Panics
///
/// Panics if `requests_per_minute` is `0`, because the quota requires a
/// non-zero value.
pub fn per_minute(requests_per_minute: u32) -> RateLimitState {
    let max_burst = NonZeroU32::new(requests_per_minute).expect("burst must be > 0");
    let keyed = RateLimiter::keyed(Quota::per_minute(max_burst));

    RateLimitState {
        limiter: Arc::new(keyed),
        burst: requests_per_minute,
    }
}
/// Axum middleware that applies a per-client-IP request quota.
///
/// The quota comes from a [`RateLimitState`] stored in the request
/// extensions. If no state is present the request is passed to `next`
/// unchanged and no rate-limit headers are added.
///
/// # Responses
///
/// * Within the quota: the inner response, with `x-ratelimit-limit` set to the
///   configured requests-per-minute value.
/// * Over the quota: `429 Too Many Requests` with a JSON error body, a
///   `retry-after` header and the same `x-ratelimit-limit` header. The inner
///   service is not called.
///
/// The advertised retry delay is the limiter's wait time rounded **up** to
/// whole seconds (and never less than one second), so a client that waits
/// exactly that long is not sent back too early.
pub async fn rate_limit(req: Request, next: Next) -> Response {
    // Without configured state there is nothing to enforce; behave as a no-op layer.
    let Some(state) = req.extensions().get::<RateLimitState>().cloned() else {
        return next.run(req).await;
    };
    let ip = extract_client_ip(&req);

    let mut resp = match state.limiter.check_key(&ip) {
        Ok(_) => next.run(req).await,
        Err(not_until) => {
            let wait = not_until.wait_time_from(DefaultClock::default().now());
            // `as_secs` truncates, so add one second whenever a fractional part
            // remains; otherwise e.g. 29.9s would be advertised as 29s.
            let retry_secs = wait
                .as_secs()
                .saturating_add(u64::from(wait.subsec_nanos() > 0))
                .max(1);
            let body = json!({
                "error": {
                    "code": "RATE_LIMIT_EXCEEDED",
                    "message": "Too many requests, please try again later",
                    "retry_after_secs": retry_secs,
                }
            });
            let mut rejected =
                (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
            rejected
                .headers_mut()
                .insert("retry-after", HeaderValue::from(retry_secs));
            rejected
        }
    };

    // Both outcomes report the configured limit.
    resp.headers_mut()
        .insert("x-ratelimit-limit", HeaderValue::from(state.burst));
    resp
}
fn extract_client_ip(req: &Request) -> IpAddr {
if let Some(forwarded) = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
&& let Some(first) = forwarded.split(',').next()
&& let Ok(ip) = first.trim().parse::<IpAddr>()
{
return ip;
}
if let Some(real_ip) = req.headers().get("x-real-ip").and_then(|v| v.to_str().ok())
&& let Ok(ip) = real_ip.trim().parse::<IpAddr>()
{
return ip;
}
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip())
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
}
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::middleware as axum_mw;
    use axum::routing::get;
    use axum::{Extension, Router};
    use http_body_util::BodyExt;
    use serde_json::Value as JsonValue;
    use tower::ServiceExt;

    use super::*;

    /// Minimal handler standing in for the protected service.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with `/test` behind the rate-limit middleware and the
    /// given limiter state installed as an extension.
    fn test_app(limiter: RateLimitState) -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(axum_mw::from_fn(rate_limit))
            .layer(Extension(limiter))
    }

    /// Builds an empty-bodied request to `/test` carrying a single header.
    fn request_with_header(name: &str, value: &str) -> Request<Body> {
        Request::builder()
            .uri("/test")
            .header(name, value)
            .body(Body::empty())
            .unwrap()
    }

    /// The default request: a client identified as `1.2.3.4` via `x-forwarded-for`.
    fn test_request() -> Request<Body> {
        request_with_header("x-forwarded-for", "1.2.3.4")
    }

    /// Sends `req` through a fresh router that shares `limiter`.
    ///
    /// `oneshot` consumes the router, so each call builds a new one; the
    /// limiter state is shared because clones of `RateLimitState` share it.
    async fn send(limiter: &RateLimitState, req: Request<Body>) -> axum::response::Response {
        test_app(limiter.clone()).oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn allows_requests_within_limit() {
        let limiter = per_minute(5);

        let response = send(&limiter, test_request()).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("x-ratelimit-limit"));
    }

    #[tokio::test]
    async fn rejects_when_limit_exceeded() {
        let limiter = per_minute(2);

        // Use up the two allowed requests first.
        for _ in 0..2 {
            let _ = send(&limiter, test_request()).await;
        }
        let rejected = send(&limiter, test_request()).await;

        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);
        let bytes = rejected.into_body().collect().await.unwrap().to_bytes();
        let payload: JsonValue = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(payload["error"]["code"], "RATE_LIMIT_EXCEEDED");
    }

    #[tokio::test]
    async fn includes_retry_after_header() {
        let limiter = per_minute(1);

        let _ = send(&limiter, test_request()).await;
        let rejected = send(&limiter, test_request()).await;

        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(rejected.headers().contains_key("retry-after"));
    }

    #[tokio::test]
    async fn different_ips_have_separate_limits() {
        let limiter = per_minute(1);

        // With a limit of one, the second request only passes if it is keyed separately.
        for client in ["10.0.0.1", "10.0.0.2"] {
            let response = send(&limiter, request_with_header("x-forwarded-for", client)).await;
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[tokio::test]
    async fn extracts_ip_from_x_real_ip() {
        let limiter = per_minute(1);

        let first = send(&limiter, request_with_header("x-real-ip", "192.168.1.1")).await;
        assert_eq!(first.status(), StatusCode::OK);

        // Same `x-real-ip` again: rejected only if the header was used as the key.
        let second = send(&limiter, request_with_header("x-real-ip", "192.168.1.1")).await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn general_limiter_allows_more_requests() {
        let limiter = per_minute(60);

        for _ in 0..10 {
            let response = send(&limiter, test_request()).await;
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    #[test]
    fn extract_ip_x_forwarded_for_first_ip() {
        let req = request_with_header("x-forwarded-for", "1.2.3.4, 5.6.7.8");

        assert_eq!(
            extract_client_ip(&req),
            IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))
        );
    }

    #[test]
    fn extract_ip_fallback_to_unspecified() {
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();

        assert_eq!(extract_client_ip(&req), IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    }
}