use std::{num::NonZeroU32, sync::Arc};
use axum::{
body::Body,
extract::Request,
http::{HeaderValue, Response, StatusCode},
middleware::{from_fn, Next},
};
use dashmap::DashMap;
use governor::{
clock::DefaultClock,
middleware::NoOpMiddleware,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use serde_json::json;
/// A single-bucket (`NotKeyed`) governor rate limiter with in-memory state,
/// the default clock and no middleware; `attach` creates one per agent id.
type Limiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
/// Wraps `router` in a middleware that rate-limits requests per agent.
///
/// Agents are identified by the `X-Reposix-Agent` request header. Requests
/// without it, or whose value `HeaderValue::to_str` rejects, share the
/// `"anonymous"` bucket. Each agent gets its own limiter allowing `rps`
/// requests per second (`rps == 0` is treated as `1`). Over-quota requests
/// never reach the inner service; they are answered with the 429 response
/// built by `rate_limited_response`.
///
/// NOTE(review): the agent header is client-controlled and limiters are never
/// evicted, so the map grows with every distinct agent id ever seen. If this
/// is exposed to untrusted clients, consider governor's keyed limiter with a
/// periodic `retain_recent()` instead.
pub fn attach<S>(router: axum::Router<S>, rps: u32) -> axum::Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let quota = Quota::per_second(NonZeroU32::new(rps.max(1)).expect("rps.max(1) >= 1"));
    let map: Arc<DashMap<String, Arc<Limiter>>> = Arc::new(DashMap::new());
    router.layer(from_fn(move |req: Request, next: Next| {
        let map = Arc::clone(&map);
        async move {
            // Scoped so the borrow of `req` (and any DashMap guard) ends before
            // `req` is moved into `next.run` and before the `.await`.
            let limiter = {
                let agent_id = req
                    .headers()
                    .get("X-Reposix-Agent")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("anonymous");
                match map.get(agent_id) {
                    // Fast path: read lock only, no allocation for known agents.
                    Some(existing) => Arc::clone(existing.value()),
                    // Slow path: `entry` holds the shard's write lock across the
                    // lookup-and-insert, so concurrent first requests from the
                    // same agent share one limiter instead of each inserting
                    // (and overwriting) a fresh, full bucket.
                    None => Arc::clone(
                        map.entry(agent_id.to_owned())
                            .or_insert_with(|| Arc::new(RateLimiter::direct(quota)))
                            .value(),
                    ),
                }
            };
            match limiter.check() {
                Ok(()) => next.run(req).await,
                Err(_) => rate_limited_response(),
            }
        }
    }))
}
/// Builds the `429 Too Many Requests` response sent to an agent that is over quota.
///
/// The body is a JSON object with `error: "rate_limited"` and `retry_after_secs: 1`,
/// served as `application/json`. A `Retry-After: 1` header asks the client to wait
/// one second before retrying.
fn rate_limited_response() -> Response<Body> {
    let payload = json!({
        "error": "rate_limited",
        "retry_after_secs": 1,
    });
    // Every status and header here is a static, valid value, so the builder cannot fail.
    Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("content-type", "application/json")
        .header("Retry-After", HeaderValue::from_static("1"))
        .body(Body::from(payload.to_string()))
        .expect("static response")
}
#[cfg(test)]
mod tests {
    use super::attach;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Handler reached only by requests the limiter lets through.
    async fn always_ok() -> StatusCode {
        StatusCode::NO_CONTENT
    }

    /// A router with a single `/z` route, limited to 1 request/second per agent.
    fn limited_app() -> Router {
        attach(Router::new().route("/z", get(always_ok)), 1)
    }

    /// A `GET /z` request tagged with the given agent id.
    fn request_as(agent: &str) -> Request<Body> {
        Request::builder()
            .uri("/z")
            .header("X-Reposix-Agent", agent)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn rate_limit_rps_1_denies_second_call() {
        let app = limited_app();

        let first = app.clone().oneshot(request_as("burst")).await.unwrap();
        assert_eq!(first.status(), 204);

        let second = app.oneshot(request_as("burst")).await.unwrap();
        assert_eq!(second.status(), 429);

        let retry_after = second
            .headers()
            .get("Retry-After")
            .expect("429 response carries Retry-After")
            .to_str()
            .unwrap();
        assert_eq!(retry_after, "1");
    }

    #[tokio::test]
    async fn rate_limit_is_per_agent() {
        let app = limited_app();

        let status_for = |agent: &'static str| {
            let svc = app.clone();
            async move { svc.oneshot(request_as(agent)).await.unwrap().status() }
        };

        assert_eq!(status_for("a").await, 204);
        assert_eq!(status_for("a").await, 429);
        assert_eq!(status_for("b").await, 204);
    }
}