use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use axum::{
body::Body,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use tracing::warn;
use super::{config::RateLimitConfig, dispatch::RateLimiter, key::is_private_or_loopback};
/// Rejection value produced when a request exceeds one of the rate limits.
///
/// Its `IntoResponse` implementation turns it into an HTTP
/// `429 Too Many Requests` response.
#[derive(Debug)]
pub struct RateLimitExceeded {
/// Number of seconds the client is asked to wait before retrying; it is sent
/// in the `Retry-After` header and repeated in the JSON error message.
pub retry_after_secs: u32,
}
/// Renders the rejection as an HTTP `429 Too Many Requests` response.
///
/// The body is a JSON document with a single-entry `errors` array whose
/// `message` tells the client how long to wait. The same number of seconds is
/// sent in the `Retry-After` header, and `Content-Type` is `application/json`.
impl IntoResponse for RateLimitExceeded {
    fn into_response(self) -> Response {
        let Self { retry_after_secs } = self;

        // Use the singular "second" only for exactly one second.
        let plural_suffix = match retry_after_secs {
            1 => "",
            _ => "s",
        };
        let payload = format!(
            r#"{{"errors":[{{"message":"Rate limit exceeded. Please retry after {retry} second{s}."}}]}}"#,
            retry = retry_after_secs,
            s = plural_suffix
        );

        // The header value must outlive the borrowed `&str` placed in the array.
        let retry_after_header = retry_after_secs.to_string();
        let headers = [
            ("Content-Type", "application/json"),
            ("Retry-After", retry_after_header.as_str()),
        ];

        (StatusCode::TOO_MANY_REQUESTS, headers, payload).into_response()
    }
}
/// One-shot flag used by `extract_real_ip`: it is flipped to `true` the first
/// time the "server appears to be behind a reverse proxy" warning is logged,
/// so that warning is emitted at most once per process.
static PROXY_WARNING_LOGGED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Parses a single forwarded-header entry into an IP address.
///
/// Accepts a bare IPv4/IPv6 address as well as an `ip:port` / `[ipv6]:port`
/// pair (some proxies append the source port). Surrounding whitespace is
/// ignored. Anything else — empty strings, placeholders such as `unknown`,
/// arbitrary client-supplied text — yields `None`.
fn parse_forwarded_ip(raw: &str) -> Option<IpAddr> {
    let candidate = raw.trim();
    candidate
        .parse::<IpAddr>()
        .ok()
        .or_else(|| candidate.parse::<SocketAddr>().ok().map(|peer| peer.ip()))
}

/// Determines the client IP string used as the rate-limit key for a request.
///
/// # Parameters
/// * `req` – the incoming request; only its headers are read.
/// * `trust_proxy` – whether `X-Real-IP` / `X-Forwarded-For` may be consulted.
/// * `trusted_cidrs` – when non-empty, the proxy headers are honoured only if
///   the direct peer (`addr`) lies inside one of these networks.
/// * `addr` – socket address of the direct TCP peer.
///
/// # Returns
/// With `trust_proxy` enabled (and the peer trusted), the first of these that
/// parses as an IP address: the `X-Real-IP` value, then the leftmost valid
/// entry of `X-Forwarded-For`. Otherwise the direct peer address. The result is
/// always the canonical textual form of an IP address: header values that are
/// not valid addresses are ignored rather than being used verbatim as a key,
/// which keeps the key bounded in size and makes different spellings of the
/// same address share one bucket.
///
/// # Notes
/// * When `trust_proxy` is `true` and `trusted_cidrs` is empty, the headers are
///   honoured for every peer, so any client can choose its own key.
/// * NOTE(review): the leftmost `X-Forwarded-For` entry is supplied by the
///   client unless the fronting proxy overwrites the header — confirm the
///   proxy configuration.
/// * When `trust_proxy` is `false` and the peer is loopback/private, a one-time
///   warning is logged (guarded by `PROXY_WARNING_LOGGED`).
pub(super) fn extract_real_ip(
    req: &Request<Body>,
    trust_proxy: bool,
    trusted_cidrs: &[ipnet::IpNet],
    addr: &SocketAddr,
) -> String {
    use std::sync::atomic::Ordering;

    let direct: IpAddr = addr.ip();

    if trust_proxy {
        // An explicit allow-list restricts which peers may speak for a client.
        if !trusted_cidrs.is_empty() && !trusted_cidrs.iter().any(|cidr| cidr.contains(&direct)) {
            tracing::debug!(
                %direct,
                "Connection not from a trusted proxy CIDR; ignoring X-Forwarded-For"
            );
            return direct.to_string();
        }

        let headers = req.headers();
        let forwarded = headers
            .get("x-real-ip")
            .and_then(|value| value.to_str().ok())
            .and_then(parse_forwarded_ip)
            .or_else(|| {
                headers
                    .get("x-forwarded-for")
                    .and_then(|value| value.to_str().ok())
                    // Leftmost entry that is actually an IP address.
                    .and_then(|xff| xff.split(',').find_map(parse_forwarded_ip))
            });
        if let Some(client_ip) = forwarded {
            return client_ip.to_string();
        }
    } else if is_private_or_loopback(direct)
        // Cheap read first so the steady state never performs a write; the swap
        // then guarantees only one caller wins the right to log.
        && !PROXY_WARNING_LOGGED.load(Ordering::Relaxed)
        && !PROXY_WARNING_LOGGED.swap(true, Ordering::Relaxed)
    {
        warn!(
            peer_ip = %direct,
            "Rate limiter: peer address is loopback/RFC-1918 — server appears to be \
             behind a reverse proxy. All requests will share a single rate-limit bucket \
             unless you set `trust_proxy_headers = true` in [security.rate_limiting]."
        );
    }

    // No usable proxy header (or headers not trusted): key on the direct peer.
    direct.to_string()
}
/// Extracts the `sub` claim from a bearer JWT carried in an `Authorization`
/// header value.
///
/// The value must have the form `<scheme> <token>` where `<scheme>` is
/// `Bearer`, compared case-insensitively because HTTP authentication scheme
/// names are case-insensitive (RFC 9110 §11.1). The token's second
/// dot-separated segment is base64url-decoded (no padding) and parsed as JSON.
///
/// Returns `Some(sub)` when the payload contains a string `sub` claim, and
/// `None` for any other scheme, a malformed token, undecodable payload,
/// non-JSON payload, or a missing/non-string `sub`.
///
/// The token signature is **not** verified here: the result is an
/// unauthenticated hint that must not be used for authorization decisions.
pub(super) fn extract_jwt_subject(authorization: &str) -> Option<String> {
    use base64::Engine as _;

    let (scheme, token) = authorization.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }

    // A JWT is `header.payload.signature`; only the payload is inspected.
    let payload_b64 = token.split('.').nth(1)?;
    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64).ok()?;
    let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
    claims.get("sub")?.as_str().map(String::from)
}
/// Axum middleware that enforces the per-path, per-user and per-IP rate limits.
///
/// Processing order:
/// 1. Obtain the shared [`RateLimiter`] from the request extensions. If none is
///    present, a limiter built from `RateLimitConfig::default()` is created for
///    this request.
///    NOTE(review): that fallback limiter is not stored anywhere, so it
///    presumably carries no state between requests — confirm the extension is
///    always installed.
/// 2. Derive the client IP (`extract_real_ip`), the request path, the JWT
///    subject from `Authorization` (`extract_jwt_subject`) and the optional
///    `X-Tenant-ID` header.
/// 3. Check the per-path limit, keyed by path and IP.
/// 4. Check the per-user limit when a JWT subject was found, otherwise the
///    per-IP limit; the tenant id is passed to either check.
/// 5. Run the inner service and attach `X-RateLimit-Limit` and
///    `X-RateLimit-Remaining` to its response.
///
/// NOTE(review): the JWT subject is read without signature verification in this
/// file, and a request carrying a subject skips the per-IP check — verify that
/// tokens are authenticated elsewhere, otherwise a client could choose its own
/// bucket.
///
/// # Errors
/// Returns [`RateLimitExceeded`] (rendered as HTTP 429) when the per-path check
/// or the per-user/per-IP check reports the request as not allowed.
// The flow is a linear sequence of guard clauses; splitting it would not make it clearer.
#[allow(clippy::cognitive_complexity)]
pub async fn rate_limit_middleware(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    req: Request<Body>,
    next: Next,
) -> Result<Response, RateLimitExceeded> {
    let limiter = match req.extensions().get::<Arc<RateLimiter>>() {
        Some(shared) => Arc::clone(shared),
        None => Arc::new(RateLimiter::new(RateLimitConfig::default())),
    };

    // Everything needed from the request is copied out up front, because the
    // request itself is handed to `next` further down.
    let ip = extract_real_ip(
        &req,
        limiter.config().trust_proxy_headers,
        &limiter.config().trusted_proxy_cidrs,
        &addr,
    );
    let path = req.uri().path().to_owned();
    let user_id = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok())
        .and_then(extract_jwt_subject);
    let tenant_id = req
        .headers()
        .get("X-Tenant-ID")
        .and_then(|raw| raw.to_str().ok())
        .map(String::from);

    // Stage 1: per-path limit.
    let path_check = limiter.check_path_limit(&path, &ip).await;
    if !path_check.allowed {
        warn!(ip = %ip, path = %path, "Per-path rate limit exceeded");
        return Err(RateLimitExceeded {
            retry_after_secs: path_check.retry_after_secs,
        });
    }

    // Stage 2: per-user limit when a subject is known, per-IP limit otherwise.
    let quota = match &user_id {
        Some(uid) => limiter.check_user_limit(uid, tenant_id.as_deref()).await,
        None => limiter.check_ip_limit(&ip, tenant_id.as_deref()).await,
    };
    if !quota.allowed {
        match &user_id {
            Some(uid) => warn!(user_id = %uid, "Per-user rate limit exceeded"),
            None => warn!(ip = %ip, "IP rate limit exceeded"),
        }
        return Err(RateLimitExceeded {
            retry_after_secs: quota.retry_after_secs,
        });
    }

    // The header carries a whole number; the narrowing cast is intentional.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let remaining = quota.remaining as u32;

    let mut response = next.run(req).await;

    // Advertise the limit that was actually applied to this request.
    let limit = match &user_id {
        Some(_) => limiter.config().rps_per_user,
        None => limiter.config().rps_per_ip,
    };

    let response_headers = response.headers_mut();
    if let Ok(limit_value) = limit.to_string().parse() {
        response_headers.insert("X-RateLimit-Limit", limit_value);
    }
    if let Ok(remaining_value) = remaining.to_string().parse() {
        response_headers.insert("X-RateLimit-Remaining", remaining_value);
    }

    Ok(response)
}