use axum::{
body::Body,
http::{Request, Response, StatusCode},
};
use dashmap::DashMap;
use std::{
future::Future,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
task::{Context, Poll},
time::{Duration, Instant},
};
use tower::{Layer, Service};
/// Strategy used to derive the rate-limit bucket key from an incoming request.
///
/// All requests that produce the same key share one bucket, so the extractor
/// decides *who* is being limited.
#[derive(Debug, Clone)]
pub enum KeyExtractor {
    /// Client IP reported in `X-Real-IP` / `X-Forwarded-For`, falling back to the
    /// peer address. These headers are client-controlled unless a trusted proxy
    /// overwrites them.
    ForwardedIp,
    /// IP of the connected peer, read from the `ConnectInfo<SocketAddr>` or
    /// `SocketAddr` request extension.
    PeerIp,
    /// User id from the [`RateLimitUserId`] extension, then the `x-user-id`
    /// header, then a shared anonymous bucket.
    UserId,
    /// Tenant id from the [`RateLimitTenantId`] extension, then the `x-tenant-id`
    /// header, then a shared anonymous bucket.
    TenantId,
    /// Normalized identifier from the [`RateLimitLoginIdentifier`] extension,
    /// or a shared anonymous bucket when absent.
    LoginIdentifier,
    /// Value of the named request header, falling back to the peer IP.
    Header(String),
}
/// Rate-limit settings consumed by [`RateLimitLayer::new`].
///
/// Prefer constructing this through [`RateLimitConfig::builder`], whose `build`
/// step rejects a zero request budget and a zero-length window.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests admitted per key within one window.
    pub max_requests: u32,
    /// Length of the fixed window after which a key's budget is refilled.
    pub window: Duration,
    /// How the bucket key is derived from each request.
    pub key_extractor: KeyExtractor,
    /// Whether `429` responses carry a `Retry-After` header.
    pub retry_after: bool,
}

/// Fluent builder for [`RateLimitConfig`]; obtain one with [`RateLimitConfig::builder`].
pub struct RateLimitConfigBuilder {
    max_requests: u32,
    window: Duration,
    key_extractor: KeyExtractor,
    retry_after: bool,
}

impl RateLimitConfig {
    /// Starts a builder pre-populated with the defaults: 100 requests per
    /// 60-second window, keyed by peer IP, with `Retry-After` enabled.
    pub fn builder() -> RateLimitConfigBuilder {
        const DEFAULT_MAX_REQUESTS: u32 = 100;
        const DEFAULT_WINDOW_SECS: u64 = 60;

        RateLimitConfigBuilder {
            retry_after: true,
            key_extractor: KeyExtractor::PeerIp,
            window: Duration::from_secs(DEFAULT_WINDOW_SECS),
            max_requests: DEFAULT_MAX_REQUESTS,
        }
    }
}
impl RateLimitConfigBuilder {
    /// Sets how many requests each key may make per window.
    pub fn max_requests(self, n: u32) -> Self {
        Self { max_requests: n, ..self }
    }

    /// Sets the length of the fixed rate-limit window.
    pub fn window(self, d: Duration) -> Self {
        Self { window: d, ..self }
    }

    /// Selects how the bucket key is derived from each request.
    pub fn key(self, k: KeyExtractor) -> Self {
        Self { key_extractor: k, ..self }
    }

    /// Enables or disables the `Retry-After` header on rejected responses.
    pub fn retry_after(self, enabled: bool) -> Self {
        Self { retry_after: enabled, ..self }
    }

    /// Validates the settings and produces a [`RateLimitConfig`].
    ///
    /// Non-fatal warnings are logged through `tracing` for a very low request
    /// budget, a window longer than one hour, and the spoofable
    /// [`KeyExtractor::ForwardedIp`] extractor.
    ///
    /// # Panics
    ///
    /// Panics if `max_requests` is zero or `window` is zero.
    pub fn build(self) -> RateLimitConfig {
        let Self {
            max_requests,
            window,
            key_extractor,
            retry_after,
        } = self;

        assert!(max_requests > 0, "RateLimitConfig: max_requests must be > 0");
        assert!(!window.is_zero(), "RateLimitConfig: window must be > 0");

        Self::log_advisories(max_requests, window, &key_extractor);

        RateLimitConfig {
            max_requests,
            window,
            key_extractor,
            retry_after,
        }
    }

    /// Emits advisory warnings about risky-but-valid settings; never alters them.
    fn log_advisories(max_requests: u32, window: Duration, key_extractor: &KeyExtractor) {
        if should_warn_very_low_max_requests(max_requests) {
            tracing::warn!(
                max_requests = max_requests,
                "RateLimitConfig: very low max_requests; consider at least 5 to avoid blocking legitimate users"
            );
        }
        if should_warn_very_long_window(window) {
            tracing::warn!(
                window_secs = window.as_secs(),
                "RateLimitConfig: window exceeds 1 hour; long windows increase memory usage per bucket"
            );
        }
        if let KeyExtractor::ForwardedIp = key_extractor {
            tracing::warn!(
                "RateLimitConfig: using ForwardedIp key extractor. \
                 This reads X-Forwarded-For / X-Real-IP headers which are \
                 client-spoofable unless set by a trusted reverse proxy. \
                 Ensure your proxy strips client-supplied forwarded headers \
                 before adding its own. For direct-to-client deployments, \
                 use KeyExtractor::PeerIp instead."
            );
        }
    }
}
/// Per-key fixed-window counter: a budget of requests plus the start of the
/// window that budget belongs to.
#[derive(Debug)]
struct TokenBucket {
    /// Requests still admissible in the current window.
    remaining: u32,
    /// Moment the current window began.
    window_start: Instant,
}

impl TokenBucket {
    /// Creates a bucket holding the full `max` budget, with its window opening at `now`.
    fn new(max: u32, now: Instant) -> Self {
        let (remaining, window_start) = (max, now);
        Self {
            remaining,
            window_start,
        }
    }
}
/// Table of per-key buckets together with the limits they enforce.
///
/// Cloning is cheap: the map and the request counter sit behind `Arc`, so every
/// clone observes and mutates the same state.
#[derive(Clone)]
struct BucketStore {
    /// Running count of requests seen; used to schedule periodic eviction.
    request_count: Arc<AtomicU64>,
    /// Per-key state, keyed by the extracted (and length-capped) request key.
    buckets: Arc<DashMap<String, TokenBucket>>,
    /// Window length.
    window: Duration,
    /// Budget granted to each key per window.
    max_requests: u32,
}

/// Outcome of charging one request against a bucket.
enum Acquire {
    /// Admitted; `remaining` further requests fit in the current window.
    Allowed { remaining: u32 },
    /// Rejected; the client is told to wait `retry_after_secs` seconds.
    Limited { retry_after_secs: u64 },
}
impl BucketStore {
    /// Creates an empty store enforcing `max_requests` per `window` for every key.
    fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            buckets: Arc::new(DashMap::new()),
            max_requests,
            window,
            request_count: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Charges one request against `key`'s budget at the current time.
    fn try_acquire(&self, key: &str) -> Acquire {
        self.try_acquire_at(key, Instant::now())
    }

    /// Charges one request against `key`'s budget as of `now`.
    ///
    /// A bucket is created on first use. Once `window` has elapsed since the
    /// bucket's window began, the budget is refilled and the window restarts at
    /// `now`. When the budget is exhausted, `retry_after_secs` is the time left
    /// in the current window rounded **up** to whole seconds (minimum 1). A
    /// client honouring it therefore never retries before the reset.
    fn try_acquire_at(&self, key: &str, now: Instant) -> Acquire {
        // DashMap's entry guard keeps the key's shard locked for this whole
        // read-modify-write, so two concurrent callers cannot both take the last slot.
        let mut entry = self
            .buckets
            .entry(key.to_owned())
            .or_insert_with(|| TokenBucket::new(self.max_requests, now));
        let bucket = entry.value_mut();

        if now.duration_since(bucket.window_start) >= self.window {
            bucket.remaining = self.max_requests;
            bucket.window_start = now;
        }

        if bucket.remaining > 0 {
            bucket.remaining -= 1;
            Acquire::Allowed {
                remaining: bucket.remaining,
            }
        } else {
            // Recompute after a possible reset above: with a zero budget the
            // window may have just restarted at `now`.
            let elapsed = now.duration_since(bucket.window_start);
            let until_reset = self.window.saturating_sub(elapsed);
            Acquire::Limited {
                retry_after_secs: Self::ceil_secs(until_reset).max(1),
            }
        }
    }

    /// Removes buckets whose window has fully elapsed at the current time.
    fn evict_expired(&self) {
        self.evict_expired_at(Instant::now());
    }

    /// Removes every bucket whose window has fully elapsed as of `now`.
    ///
    /// Such a bucket would be refilled on its next use anyway, so dropping it
    /// loses no rate-limit state.
    fn evict_expired_at(&self, now: Instant) {
        self.buckets
            .retain(|_, bucket| now.duration_since(bucket.window_start) < self.window);
    }

    /// Converts `d` to whole seconds, rounding any fractional part up (saturating).
    fn ceil_secs(d: Duration) -> u64 {
        d.as_secs().saturating_add(u64::from(d.subsec_nanos() > 0))
    }
}
/// Request extension carrying the user id that [`KeyExtractor::UserId`] keys on.
///
/// When present in the request extensions it takes precedence over the
/// `x-user-id` header.
#[derive(Debug, Clone)]
pub struct RateLimitUserId(pub axess_identity::UserId);

/// Request extension carrying the tenant id that [`KeyExtractor::TenantId`] keys on.
///
/// When present in the request extensions it takes precedence over the
/// `x-tenant-id` header.
#[derive(Debug, Clone)]
pub struct RateLimitTenantId(pub axess_identity::TenantId);
/// Normalized login identifier that [`KeyExtractor::LoginIdentifier`] keys on.
///
/// Normalization trims surrounding whitespace and lowercases the value, so
/// `" Alice@Example.com "` and `"alice@example.com"` share one bucket.
#[derive(Debug, Clone)]
pub struct RateLimitLoginIdentifier(String);

impl RateLimitLoginIdentifier {
    /// Normalizes `identifier`, returning `None` if nothing remains after trimming.
    pub fn new(identifier: impl AsRef<str>) -> Option<Self> {
        let normalized = identifier.as_ref().trim();
        (!normalized.is_empty()).then(|| Self(normalized.to_lowercase()))
    }

    /// Returns the normalized identifier.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Shared key used when an identity-based extractor finds no identity on the request.
const ANONYMOUS_BUCKET: &str = "__anonymous__";
/// Maximum length, in bytes, of a bucket key; longer keys are cut by `truncate_key`.
const MAX_KEY_LEN: usize = 256;

/// Returns `true` when `max_requests` is below the recommended floor of 5.
fn should_warn_very_low_max_requests(max_requests: u32) -> bool {
    const RECOMMENDED_MIN: u32 = 5;
    RECOMMENDED_MIN > max_requests
}

/// Returns `true` when `window` is strictly longer than one hour.
fn should_warn_very_long_window(window: std::time::Duration) -> bool {
    const ONE_HOUR: std::time::Duration = std::time::Duration::from_secs(60 * 60);
    ONE_HOUR < window
}

/// Caps `key` at [`MAX_KEY_LEN`] bytes, cutting at the nearest preceding UTF-8
/// character boundary so the result is always valid UTF-8.
fn truncate_key(mut key: String) -> String {
    if key.len() > MAX_KEY_LEN {
        // Byte index 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=MAX_KEY_LEN)
            .rev()
            .find(|&idx| key.is_char_boundary(idx))
            .unwrap_or(0);
        key.truncate(cut);
    }
    key
}
/// Derives the bucket key for `req` according to `extractor`, capped at
/// [`MAX_KEY_LEN`] bytes.
///
/// The identity-based extractors (`UserId`, `TenantId`, `LoginIdentifier`) fall
/// back to [`ANONYMOUS_BUCKET`] when no identity is available. All such requests
/// then share a single bucket.
///
/// SECURITY NOTE(review): the `x-user-id` / `x-tenant-id` fallbacks and
/// `KeyExtractor::Header` read raw request headers. Unless an upstream proxy
/// strips or overwrites them, a client can vary them per request and get a fresh
/// bucket each time. Confirm the deployment guarantees this.
fn extract_key(req: &Request<Body>, extractor: &KeyExtractor) -> String {
    let extension_then_header = |from_extension: Option<String>, header: &str| {
        from_extension
            .or_else(|| header_str(req, header))
            .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned())
    };
    let extensions = req.extensions();

    let raw = match extractor {
        KeyExtractor::ForwardedIp => extract_forwarded_ip(req),
        KeyExtractor::PeerIp => extract_peer_ip(req),
        KeyExtractor::UserId => extension_then_header(
            extensions
                .get::<RateLimitUserId>()
                .map(|user| user.0.to_string()),
            "x-user-id",
        ),
        KeyExtractor::TenantId => extension_then_header(
            extensions
                .get::<RateLimitTenantId>()
                .map(|tenant| tenant.0.to_string()),
            "x-tenant-id",
        ),
        KeyExtractor::LoginIdentifier => extensions
            .get::<RateLimitLoginIdentifier>()
            .map_or_else(
                || ANONYMOUS_BUCKET.to_owned(),
                |login| login.as_str().to_owned(),
            ),
        KeyExtractor::Header(name) => match header_str(req, name) {
            Some(value) => value,
            None => extract_peer_ip(req),
        },
    };

    truncate_key(raw)
}
/// Returns the client IP reported by a reverse proxy, or the peer IP if no
/// forwarded header is usable.
///
/// `X-Real-IP` is consulted first, then the first (left-most) entry of
/// `X-Forwarded-For`. A header is skipped, and the next source is tried, when:
/// - its value is not visible ASCII (`HeaderValue::to_str` fails), or
/// - its first entry is blank.
///
/// This stops a malformed or empty header from collapsing unrelated clients into a
/// single `""` bucket. It also ensures an unreadable `X-Real-IP` does not hide a
/// valid `X-Forwarded-For`.
///
/// These headers are client-controlled unless a trusted proxy overwrites them.
fn extract_forwarded_ip(req: &Request<Body>) -> String {
    let headers = req.headers();
    ["x-real-ip", "x-forwarded-for"]
        .into_iter()
        .find_map(|name| {
            let first = headers.get(name)?.to_str().ok()?.split(',').next()?.trim();
            (!first.is_empty()).then(|| first.to_owned())
        })
        .unwrap_or_else(|| extract_peer_ip(req))
}
/// Returns the connected peer's IP address as a string, or `"unknown"` when the
/// request carries no socket address.
///
/// The address comes from the `ConnectInfo<SocketAddr>` extension, or failing
/// that from a bare `SocketAddr` extension. If neither is present, a warning is
/// logged once, because every such request then lands in the same
/// `"unknown"` bucket.
fn extract_peer_ip(req: &Request<Body>) -> String {
    use std::net::SocketAddr;
    use std::sync::Once;

    static MISSING_ADDR_WARNING: Once = Once::new();

    let extensions = req.extensions();
    let from_connect_info = extensions
        .get::<axum::extract::ConnectInfo<SocketAddr>>()
        .map(|info| info.0);

    match from_connect_info.or_else(|| extensions.get::<SocketAddr>().copied()) {
        Some(peer) => peer.ip().to_string(),
        None => {
            MISSING_ADDR_WARNING.call_once(|| {
                tracing::warn!(
                    "PeerIp rate limiting: no SocketAddr in request extensions; \
                     all requests will share a single bucket. Use \
                     Router::into_make_service_with_connect_info::<SocketAddr>() \
                     or switch to KeyExtractor::ForwardedIp behind a trusted proxy."
                );
            });
            "unknown".to_owned()
        }
    }
}
/// Returns the value of header `name` as an owned string.
///
/// Yields `None` when the header is absent or its value is not visible ASCII
/// (`HeaderValue::to_str` fails). If the header is repeated, only the first
/// value is returned.
fn header_str(req: &Request<Body>, name: &str) -> Option<String> {
    let value = req.headers().get(name)?;
    value.to_str().ok().map(str::to_owned)
}
#[derive(Clone)]
pub struct RateLimitLayer {
store: BucketStore,
key_extractor: KeyExtractor,
retry_after: bool,
metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
}
impl RateLimitLayer {
pub fn new(config: RateLimitConfig) -> Self {
Self {
store: BucketStore::new(config.max_requests, config.window),
key_extractor: config.key_extractor,
retry_after: config.retry_after,
metrics: None,
}
}
pub fn with_metrics(mut self, metrics: impl crate::metrics::AuthnMetrics) -> Self {
self.metrics = Some(Arc::new(metrics));
self
}
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
store: self.store.clone(),
key_extractor: self.key_extractor.clone(),
retry_after: self.retry_after,
metrics: self.metrics.clone(),
}
}
}
/// Service produced by [`RateLimitLayer`].
///
/// It admits or rejects each request before delegating admitted ones to the
/// wrapped service.
#[derive(Clone)]
pub struct RateLimitService<S> {
    /// Bucket table shared with the originating layer and every sibling service.
    store: BucketStore,
    /// How the bucket key is derived from each request.
    key_extractor: KeyExtractor,
    /// Whether `429` responses carry a `Retry-After` header.
    retry_after: bool,
    /// Optional sink notified of allowed / rejected requests.
    metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
    /// Wrapped service that handles admitted requests.
    inner: S,
}
impl<S, ResBody> Service<Request<Body>> for RateLimitService<S>
where
    S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ResBody: Default + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the inner service; rate limiting applies no backpressure.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Charges `req` against its key's bucket.
    ///
    /// Admitted requests are forwarded to the inner service, and the response gets
    /// an `x-ratelimit-remaining` header. Rejected requests never reach the inner
    /// service. They are answered with `429 Too Many Requests`, plus `Retry-After`
    /// when enabled.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Sweep expired buckets every EVICT_INTERVAL requests, or on every request
        // while more than SOFT_BUCKET_CAP buckets are stored.
        const EVICT_INTERVAL: u64 = 128;
        const SOFT_BUCKET_CAP: usize = 64 * 1024;

        let key = extract_key(&req, &self.key_extractor);

        let count = self.store.request_count.fetch_add(1, Ordering::Relaxed);
        // NOTE(review): while the map stays above SOFT_BUCKET_CAP with unexpired
        // buckets (e.g. a client spraying distinct keys), each request runs a
        // full O(n) `retain` sweep. Consider rate-limiting the sweep itself.
        if (count.is_multiple_of(EVICT_INTERVAL) || self.store.buckets.len() > SOFT_BUCKET_CAP)
            && !self.store.buckets.is_empty()
        {
            self.store.evict_expired();
        }

        match self.store.try_acquire(&key) {
            Acquire::Allowed { remaining } => {
                if let Some(metrics) = &self.metrics {
                    metrics.rate_limit_allowed();
                }
                // `poll_ready` readied `self.inner`, so that instance must serve this
                // request; a fresh clone takes its place for the next call.
                let replacement = self.inner.clone();
                let mut inner = std::mem::replace(&mut self.inner, replacement);
                Box::pin(async move {
                    let mut resp = inner.call(req).await?;
                    // Integer header values are always valid; no fallible parse needed.
                    resp.headers_mut().insert(
                        "x-ratelimit-remaining",
                        axum::http::HeaderValue::from(remaining),
                    );
                    Ok(resp)
                })
            }
            Acquire::Limited { retry_after_secs } => {
                if let Some(metrics) = &self.metrics {
                    metrics.rate_limit_rejected();
                }
                tracing::debug!(
                    key = %key,
                    retry_after = retry_after_secs,
                    "rate limit exceeded"
                );
                let retry_after_enabled = self.retry_after;
                Box::pin(async move {
                    let mut response = Response::new(ResBody::default());
                    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
                    if retry_after_enabled {
                        response.headers_mut().insert(
                            axum::http::header::RETRY_AFTER,
                            axum::http::HeaderValue::from(retry_after_secs),
                        );
                    }
                    Ok(response)
                })
            }
        }
    }
}
#[cfg(test)]
mod tests;