use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use axum::{
http::{header, HeaderValue, Request, Response, StatusCode},
response::IntoResponse,
};
use tower::{Layer, Service};
use crate::limiter::{LimitResult, Limiter};
/// A single rate-limiting rule: which bucket a request belongs to and how
/// many requests that bucket may admit per window.
pub struct ThrottleRule {
    /// Maps the request head to the limiter bucket key for this rule.
    pub key_fn: Arc<dyn Fn(&axum::http::request::Parts) -> String + Send + Sync>,
    /// Maximum number of requests allowed per window.
    pub requests: u64,
    /// Window length in seconds.
    pub window_secs: u64,
}

// `Debug` cannot be derived because `key_fn` is an opaque closure; report the
// numeric limits and mark the struct as non-exhaustive instead.
impl std::fmt::Debug for ThrottleRule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThrottleRule")
            .field("requests", &self.requests)
            .field("window_secs", &self.window_secs)
            .finish_non_exhaustive()
    }
}
impl ThrottleRule {
pub fn global(key: impl Into<String>, requests: u64, window_secs: u64) -> Self {
let key: Arc<str> = key.into().into();
Self {
key_fn: Arc::new(move |_| key.as_ref().to_string()),
requests,
window_secs,
}
}
pub fn by_ip(prefix: impl Into<String>, requests: u64, window_secs: u64) -> Self {
let prefix: Arc<str> = prefix.into().into();
Self {
key_fn: Arc::new(move |parts| {
let ip = parts
.headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split(',').next())
.map(|s| s.trim().to_string())
.unwrap_or_else(|| "unknown".to_string());
format!("{}:ip:{}", prefix, ip)
}),
requests,
window_secs,
}
}
}
/// Tower layer that wraps a service with a list of [`ThrottleRule`]s checked
/// against a [`Limiter`].
///
/// Both fields are `Arc`s, so cloning the layer shares the limiter and rules
/// rather than copying them.
#[derive(Clone)]
pub struct ThrottleLayer {
    // Limiter handed (via `Arc::clone`) to every service this layer builds.
    limiter: Arc<Limiter>,
    // Rules evaluated for each request by the produced services.
    rules: Arc<Vec<ThrottleRule>>,
}
impl ThrottleLayer {
    /// Builds a layer from an explicit limiter and rule set.
    ///
    /// The limiter and rules are moved behind `Arc`s so that every service
    /// produced by [`Layer::layer`] shares them.
    pub fn new(limiter: Limiter, rules: Vec<ThrottleRule>) -> Self {
        let limiter = Arc::new(limiter);
        let rules = Arc::new(rules);
        Self { limiter, rules }
    }

    /// Single-rule layer using the crate-wide limiter: all requests share the
    /// bucket named `key` (see [`ThrottleRule::global`]).
    pub fn global(key: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        let limiter = crate::global_limiter().clone();
        let rule = ThrottleRule::global(key, requests, window_secs);
        Self::new(limiter, vec![rule])
    }

    /// Single-rule layer using the crate-wide limiter, bucketed per client IP
    /// under `prefix` (see [`ThrottleRule::by_ip`]).
    pub fn by_ip(prefix: impl Into<String>, requests: u64, window_secs: u64) -> Self {
        let limiter = crate::global_limiter().clone();
        let rule = ThrottleRule::by_ip(prefix, requests, window_secs);
        Self::new(limiter, vec![rule])
    }
}
impl<S> Layer<S> for ThrottleLayer {
    type Service = ThrottleService<S>;

    /// Wraps `inner` in a [`ThrottleService`] that shares (refcount bump, no
    /// copy) this layer's limiter and rule list.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { limiter, rules } = self;
        ThrottleService {
            limiter: Arc::clone(limiter),
            rules: Arc::clone(rules),
            inner,
        }
    }
}
/// Service produced by [`ThrottleLayer`]: checks the configured rules and
/// either answers `429` itself or forwards the request to `inner`.
#[derive(Clone)]
pub struct ThrottleService<S> {
    // Wrapped service that handles requests that pass every rule.
    inner: S,
    // Limiter shared with the originating layer.
    limiter: Arc<Limiter>,
    // Rules evaluated in order for each request.
    rules: Arc<Vec<ThrottleRule>>,
}
/// Rate-limits each request against every rule before forwarding it.
///
/// Rules are checked in order. The first rule that returns
/// [`LimitResult::Exceeded`] short-circuits with a `429` response, and the
/// inner service is not called. When every rule allows the request, the inner
/// response gets `x-ratelimit-limit`, `x-ratelimit-remaining` and
/// `x-ratelimit-reset` from the rule with the fewest remaining requests.
///
/// NOTE(review): if `check()` consumes quota (presumably — confirm in
/// `crate::limiter`), rules evaluated before a rejecting rule have already
/// been charged for this request.
impl<S, B> Service<Request<B>> for ThrottleService<S>
where
    S: Service<Request<B>, Response = Response<axum::body::Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    B: Send + 'static,
{
    type Response = Response<axum::body::Body>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, S::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let limiter = Arc::clone(&self.limiter);
        let rules = Arc::clone(&self.rules);
        // `poll_ready` was driven on `self.inner`, so that instance must serve
        // this request. Leave a fresh clone in `self` (which the next caller
        // will poll) and move the ready instance into the future.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        Box::pin(async move {
            let (parts, body) = req.into_parts();

            // (remaining, reset_epoch, limit) of the rule closest to exhaustion.
            let mut tightest: Option<(u64, u64, u64)> = None;
            for rule in rules.iter() {
                let key = (rule.key_fn)(&parts);
                match limiter
                    .for_key(key)
                    .requests(rule.requests)
                    .per(std::time::Duration::from_secs(rule.window_secs))
                    .check()
                {
                    LimitResult::Exceeded { retry_after_secs } => {
                        return Ok(
                            rate_limit_exceeded(retry_after_secs, rule.requests).into_response(),
                        );
                    }
                    LimitResult::Allowed {
                        remaining,
                        reset_epoch,
                    } => {
                        // On ties the earlier rule is kept.
                        let is_tighter = match tightest {
                            Some((best_remaining, _, _)) => remaining < best_remaining,
                            None => true,
                        };
                        if is_tighter {
                            tightest = Some((remaining, reset_epoch, rule.requests));
                        }
                    }
                }
            }

            let mut resp = inner.call(Request::from_parts(parts, body)).await?;
            if let Some((remaining, reset_epoch, limit)) = tightest {
                let headers = resp.headers_mut();
                set_header(headers, "x-ratelimit-limit", &limit.to_string());
                set_header(headers, "x-ratelimit-remaining", &remaining.to_string());
                set_header(headers, "x-ratelimit-reset", &reset_epoch.to_string());
            }
            Ok(resp)
        })
    }
}
fn rate_limit_exceeded(retry_after_secs: u64, limit: u64) -> impl IntoResponse {
let body = axum::Json(serde_json::json!({
"error": "too_many_requests",
"message": "rate limit exceeded",
"retry_after": retry_after_secs,
}));
let mut resp = (StatusCode::TOO_MANY_REQUESTS, body).into_response();
let headers = resp.headers_mut();
set_header(
headers,
header::RETRY_AFTER.as_str(),
&retry_after_secs.to_string(),
);
set_header(headers, "x-ratelimit-limit", &limit.to_string());
set_header(headers, "x-ratelimit-remaining", "0");
resp
}
/// Best-effort header write: parses `name` and `value` and replaces any
/// existing entry with that name. If either fails to parse, `headers` is left
/// unchanged and no error is reported.
fn set_header(headers: &mut axum::http::HeaderMap, name: &str, value: &str) {
    let parsed_name = name.parse::<header::HeaderName>().ok();
    let parsed_value = HeaderValue::from_str(value).ok();
    match parsed_name.zip(parsed_value) {
        Some((key, val)) => {
            headers.insert(key, val);
        }
        None => {}
    }
}