axess_core/middleware/ratelimit.rs
1//! Composable tower layer for token-bucket rate limiting.
2//!
3//! Provides defense-in-depth rate limiting that complements any external API
4//! gateway. Buckets are keyed per IP, user, tenant, or arbitrary header value.
5//!
6//! # Key extractors
7//!
8//! | Extractor | Source | Use case |
9//! |-----------|--------|----------|
10//! | `KeyExtractor::PeerIp` | `SocketAddr` from `ConnectInfo` | **Default.** Safe for direct connections. |
//! | `KeyExtractor::ForwardedIp` | `X-Real-IP` header, then first entry of `X-Forwarded-For` | Behind a **trusted** reverse proxy only. |
//! | `KeyExtractor::UserId` | `RateLimitUserId` request extension, then `x-user-id` header | Per-user limits (set after authentication). |
//! | `KeyExtractor::TenantId` | `RateLimitTenantId` request extension, then `x-tenant-id` header | Per-tenant limits. |
14//! | `KeyExtractor::LoginIdentifier` | `RateLimitLoginIdentifier` request extension | **Required for login routes** to mitigate per-username lockout DoS. |
15//! | `KeyExtractor::Header(name)` | Arbitrary header | Custom keys (e.g. API key). |
16//!
17//! # Login routes: required pairing
18//!
19//! Per-IP rate limiting alone (`KeyExtractor::PeerIp`) does NOT defend
20//! against a per-username lockout DoS: an attacker rotating IPs (or a
21//! botnet) can issue 5 wrong-password POSTs per known username and lock
22//! the legitimate user out. Always pair `PeerIp` with
23//! [`KeyExtractor::LoginIdentifier`](crate::middleware::ratelimit::KeyExtractor::LoginIdentifier) on `/login`-class routes:
24//!
25//! ```rust,ignore
26//! use axess_core::middleware::ratelimit::{KeyExtractor, RateLimitConfig, RateLimitLayer, RateLimitLoginIdentifier};
27//! use std::time::Duration;
28//!
29//! // The application parses the login form, then injects the identifier
30//! // into request extensions BEFORE the rate-limit layer runs.
31//! let per_username = RateLimitLayer::new(
32//! RateLimitConfig::builder()
33//! .max_requests(10)
34//! .window(Duration::from_secs(60))
35//! .key(KeyExtractor::LoginIdentifier)
36//! .build(),
37//! );
38//! let per_ip = RateLimitLayer::new(
39//! RateLimitConfig::builder()
40//! .max_requests(60)
41//! .window(Duration::from_secs(60))
42//! .key(KeyExtractor::PeerIp)
43//! .build(),
44//! );
45//! let login_route = axum::Router::new()
46//! .route("/login", axum::routing::post(login_handler))
47//! .layer(per_username)
48//! .layer(per_ip);
49//! ```
50//!
51//! The identifier is normalised to lowercase before keying so `Alice`
52//! and `alice` share a bucket.
53//!
54//! # Layered rate limiting
55//!
56//! In production, layer per-IP (volumetric), per-user (brute-force), and
57//! an external gateway (distributed). This library provides only the
58//! in-process buckets; distributed rate limiting belongs at the gateway
59//! with a shared store (Valkey / Redis).
60
61use axum::{
62 body::Body,
63 http::{Request, Response, StatusCode},
64};
65use dashmap::DashMap;
66use std::{
67 future::Future,
68 pin::Pin,
69 sync::{
70 Arc,
71 atomic::{AtomicU64, Ordering},
72 },
73 task::{Context, Poll},
74 time::{Duration, Instant},
75};
76use tower::{Layer, Service};
77
78// ── Configuration ────────────────────────────────────────────────────────────
79
/// Selects which property of a request identifies its rate-limit bucket.
///
/// Every variant resolves to a string key; keys longer than the module's
/// key-length cap are truncated before use.
#[derive(Clone, Debug)]
pub enum KeyExtractor {
    /// Key on the client IP reported by a reverse proxy.
    ///
    /// `X-Real-IP` is consulted first; if it is absent, the first
    /// comma-separated entry of `X-Forwarded-For` is used. When neither
    /// yields a usable value the peer address is used instead (see
    /// [`PeerIp`](KeyExtractor::PeerIp)).
    ///
    /// Only safe behind a reverse proxy (NGINX, Envoy, ALB, Cloudflare)
    /// that populates these headers from the real peer address and strips
    /// whatever the client sent. Without such a proxy a client can forge
    /// the headers and sidestep the limit; use `PeerIp` for
    /// direct-to-client deployments.
    ForwardedIp,
    /// Key on the TCP peer address taken from axum's `ConnectInfo`.
    ///
    /// The server must be started with
    /// `Router::into_make_service_with_connect_info::<SocketAddr>()`.
    /// If no peer address is available, every request lands in one shared
    /// `"unknown"` bucket: fail-closed, i.e. stricter rather than weaker
    /// limiting.
    PeerIp,
    /// Key on the authenticated user: the `RateLimitUserId` request
    /// extension if present, otherwise the `x-user-id` header, otherwise a
    /// shared anonymous bucket.
    UserId,
    /// Key on the tenant: the `RateLimitTenantId` request extension if
    /// present, otherwise the `x-tenant-id` header, otherwise a shared
    /// anonymous bucket.
    TenantId,
    /// Key on the **login identifier** (the username/email posted to a
    /// login route), read from the [`RateLimitLoginIdentifier`] request
    /// extension that the application inserts after parsing the form.
    ///
    /// # Why it exists
    ///
    /// `PeerIp` throttles the *source* of requests; `LoginIdentifier`
    /// throttles the *target*. With source-only limiting, an attacker who
    /// rotates IPs (or drives a botnet) can submit a handful of wrong
    /// passwords per known username and lock the real user out: a
    /// per-account denial of service that needs no stolen credentials.
    ///
    /// Pair this with `PeerIp` on login routes; the module-level docs show
    /// the wiring.
    ///
    /// When the extension is missing, all such requests share one
    /// anonymous bucket, so a misconfigured route fails closed instead of
    /// escaping the limit.
    LoginIdentifier,
    /// Key on the value of the named request header. Requests without a
    /// readable value for that header are keyed by peer IP instead.
    Header(String),
}
127
128/// Rate-limit configuration for a single layer instance.
129#[derive(Clone, Debug)]
130pub struct RateLimitConfig {
131 /// Maximum number of requests allowed within the window.
132 pub max_requests: u32,
133 /// Duration of the sliding window.
134 pub window: Duration,
135 /// Strategy for deriving the bucket key from each request.
136 pub key_extractor: KeyExtractor,
137 /// When `true`, 429 responses include a `Retry-After` header (seconds).
138 pub retry_after: bool,
139}
140
141/// Builder for [`RateLimitConfig`].
142pub struct RateLimitConfigBuilder {
143 max_requests: u32,
144 window: Duration,
145 key_extractor: KeyExtractor,
146 retry_after: bool,
147}
148
149impl RateLimitConfig {
150 /// Start building a new configuration.
151 pub fn builder() -> RateLimitConfigBuilder {
152 RateLimitConfigBuilder {
153 max_requests: 100,
154 window: Duration::from_secs(60),
155 key_extractor: KeyExtractor::PeerIp,
156 retry_after: true,
157 }
158 }
159}
160
161impl RateLimitConfigBuilder {
162 /// Maximum requests per window (default: 100).
163 pub fn max_requests(mut self, n: u32) -> Self {
164 self.max_requests = n;
165 self
166 }
167
168 /// Window duration (default: 60 s).
169 pub fn window(mut self, d: Duration) -> Self {
170 self.window = d;
171 self
172 }
173
174 /// Bucket key strategy (default: [`KeyExtractor::PeerIp`]).
175 pub fn key(mut self, k: KeyExtractor) -> Self {
176 self.key_extractor = k;
177 self
178 }
179
180 /// Include `Retry-After` header in 429 responses (default: `true`).
181 pub fn retry_after(mut self, enabled: bool) -> Self {
182 self.retry_after = enabled;
183 self
184 }
185
186 /// Consume the builder and produce a [`RateLimitConfig`].
187 ///
188 /// # Panics
189 ///
190 /// Panics if `max_requests` is zero or `window` is zero.
191 pub fn build(self) -> RateLimitConfig {
192 assert!(
193 self.max_requests > 0,
194 "RateLimitConfig: max_requests must be > 0"
195 );
196 assert!(
197 !self.window.is_zero(),
198 "RateLimitConfig: window must be > 0"
199 );
200 // Warn on suspicious configurations that are unlikely to be intentional.
201 if should_warn_very_low_max_requests(self.max_requests) {
202 tracing::warn!(
203 max_requests = self.max_requests,
204 "RateLimitConfig: very low max_requests; consider at least 5 to avoid blocking legitimate users"
205 );
206 }
207 if should_warn_very_long_window(self.window) {
208 tracing::warn!(
209 window_secs = self.window.as_secs(),
210 "RateLimitConfig: window exceeds 1 hour; long windows increase memory usage per bucket"
211 );
212 }
213 if matches!(self.key_extractor, KeyExtractor::ForwardedIp) {
214 tracing::warn!(
215 "RateLimitConfig: using ForwardedIp key extractor. \
216 This reads X-Forwarded-For / X-Real-IP headers which are \
217 client-spoofable unless set by a trusted reverse proxy. \
218 Ensure your proxy strips client-supplied forwarded headers \
219 before adding its own. For direct-to-client deployments, \
220 use KeyExtractor::PeerIp instead."
221 );
222 }
223 RateLimitConfig {
224 max_requests: self.max_requests,
225 window: self.window,
226 key_extractor: self.key_extractor,
227 retry_after: self.retry_after,
228 }
229 }
230}
231
232// ── Token bucket ─────────────────────────────────────────────────────────────
233
/// Per-key counter: how many requests are still allowed in the current
/// window, and when that window began.
#[derive(Debug)]
struct TokenBucket {
    /// Requests still permitted before the window ends.
    remaining: u32,
    /// Start of the window the `remaining` count belongs to.
    window_start: Instant,
}

impl TokenBucket {
    /// Creates a full bucket (`max` tokens) whose window opens at `now`.
    fn new(max: u32, now: Instant) -> Self {
        Self {
            window_start: now,
            remaining: max,
        }
    }
}
251
252/// Shared bucket store. [`DashMap`] gives lock-free concurrent reads/writes.
253#[derive(Clone)]
254struct BucketStore {
255 buckets: Arc<DashMap<String, TokenBucket>>,
256 max_requests: u32,
257 window: Duration,
258 /// Request counter for deterministic eviction scheduling.
259 request_count: Arc<AtomicU64>,
260}
261
/// Outcome of asking a bucket for one token.
enum Acquire {
    /// The request may proceed. `remaining` is the number of tokens left
    /// afterwards, reported in the `X-RateLimit-Remaining` header.
    Allowed { remaining: u32 },
    /// The bucket is empty. `retry_after_secs` is the number of seconds
    /// until its window resets.
    Limited { retry_after_secs: u64 },
}
269
270impl BucketStore {
271 fn new(max_requests: u32, window: Duration) -> Self {
272 Self {
273 buckets: Arc::new(DashMap::new()),
274 max_requests,
275 window,
276 request_count: Arc::new(AtomicU64::new(0)),
277 }
278 }
279
280 fn try_acquire(&self, key: &str) -> Acquire {
281 self.try_acquire_at(key, Instant::now())
282 }
283
284 /// Deterministic variant accepting an explicit `now` for testing.
285 fn try_acquire_at(&self, key: &str, now: Instant) -> Acquire {
286 let mut entry = self
287 .buckets
288 .entry(key.to_owned())
289 .or_insert_with(|| TokenBucket::new(self.max_requests, now));
290
291 let bucket = entry.value_mut();
292
293 // Reset window if expired.
294 if now.duration_since(bucket.window_start) >= self.window {
295 bucket.remaining = self.max_requests;
296 bucket.window_start = now;
297 }
298
299 if bucket.remaining > 0 {
300 bucket.remaining -= 1;
301 Acquire::Allowed {
302 remaining: bucket.remaining,
303 }
304 } else {
305 let elapsed = now.duration_since(bucket.window_start);
306 let retry_after = self.window.saturating_sub(elapsed).as_secs().max(1);
307 Acquire::Limited {
308 retry_after_secs: retry_after,
309 }
310 }
311 }
312
313 /// Remove buckets whose window has expired. Called lazily to avoid
314 /// accumulating stale entries over long uptimes.
315 fn evict_expired(&self) {
316 self.evict_expired_at(Instant::now());
317 }
318
319 /// Deterministic variant accepting an explicit `now` for testing.
320 fn evict_expired_at(&self, now: Instant) {
321 self.buckets
322 .retain(|_, bucket| now.duration_since(bucket.window_start) < self.window);
323 }
324}
325
326// ── Key extraction ───────────────────────────────────────────────────────────
327
328/// Request extension carrying a user ID (set by upstream authn middleware).
329///
330/// Wraps the validated [`UserId`](axess_identity::UserId) newtype rather
331/// than a raw `String`, so an upstream extractor cannot accidentally write
332/// an unvalidated identifier into the rate-limit key space.
333#[derive(Clone, Debug)]
334pub struct RateLimitUserId(pub axess_identity::UserId);
335
336/// Request extension carrying a tenant ID (set by upstream authn middleware).
337///
338/// Wraps the validated [`TenantId`](axess_identity::TenantId) newtype.
339#[derive(Clone, Debug)]
340pub struct RateLimitTenantId(pub axess_identity::TenantId);
341
/// Request extension holding the **login identifier** (username/email)
/// used for per-target rate limiting on login routes.
///
/// The application parses the login form and inserts this extension before
/// the rate-limit layer sees the request. The value is normalised when it
/// is constructed (trimmed and lowercased), so changing case or padding
/// does not land a request in a different bucket.
///
/// See [`KeyExtractor::LoginIdentifier`] and the module-level example.
#[derive(Clone, Debug)]
pub struct RateLimitLoginIdentifier(String);

impl RateLimitLoginIdentifier {
    /// Normalises `identifier` into a bucket key.
    ///
    /// Surrounding whitespace is stripped and the remainder is lowercased
    /// (Unicode-aware). Returns `None` when nothing is left after trimming;
    /// callers should then leave the extension unset so the request falls
    /// into the anonymous bucket.
    pub fn new(identifier: impl AsRef<str>) -> Option<Self> {
        match identifier.as_ref().trim() {
            "" => None,
            cleaned => Some(Self(cleaned.to_lowercase())),
        }
    }

    /// Returns the normalised key used for bucket lookup.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
372
/// Bucket key shared by every request that carries no identity for an
/// identity-scoped extractor (`UserId`, `TenantId`, `LoginIdentifier`).
///
/// Funnelling such requests into a single bucket means that omitting the
/// identity cannot be used to dodge the per-identity limit: the anonymous
/// traffic is rate-limited collectively under this key.
const ANONYMOUS_BUCKET: &str = "__anonymous__";
379
/// Upper bound, in bytes, on the length of a bucket key.
///
/// Keys derived from headers or identities are cut down to this size
/// before they are stored in the `DashMap`, so a client sending an
/// enormous `X-Forwarded-For` / `X-User-Id` / custom header value cannot
/// inflate the memory held per bucket. Genuine values (IP addresses,
/// UUIDs, identifiers) are far shorter.
const MAX_KEY_LEN: usize = 256;
386
/// Returns `true` when `max_requests` is small enough (below 5) that the
/// builder should log its "very low max_requests" warning.
fn should_warn_very_low_max_requests(max_requests: u32) -> bool {
    matches!(max_requests, 0..=4)
}
391
/// Returns `true` when `window` is strictly longer than one hour, i.e. the
/// builder should log its "window exceeds 1 hour" warning.
fn should_warn_very_long_window(window: std::time::Duration) -> bool {
    const ONE_HOUR: std::time::Duration = std::time::Duration::from_secs(3600);
    window > ONE_HOUR
}
396
397fn truncate_key(mut key: String) -> String {
398 if key.len() > MAX_KEY_LEN {
399 // Truncate at a UTF-8 boundary to keep the resulting string valid.
400 let mut cut = MAX_KEY_LEN;
401 while !key.is_char_boundary(cut) {
402 cut -= 1;
403 }
404 key.truncate(cut);
405 }
406 key
407}
408
409fn extract_key(req: &Request<Body>, extractor: &KeyExtractor) -> String {
410 let raw = match extractor {
411 KeyExtractor::ForwardedIp => extract_forwarded_ip(req),
412 KeyExtractor::PeerIp => extract_peer_ip(req),
413 KeyExtractor::UserId => req
414 .extensions()
415 .get::<RateLimitUserId>()
416 .map(|u| u.0.to_string())
417 .or_else(|| header_str(req, "x-user-id"))
418 .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
419 KeyExtractor::TenantId => req
420 .extensions()
421 .get::<RateLimitTenantId>()
422 .map(|t| t.0.to_string())
423 .or_else(|| header_str(req, "x-tenant-id"))
424 .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
425 KeyExtractor::LoginIdentifier => req
426 .extensions()
427 .get::<RateLimitLoginIdentifier>()
428 .map(|i| i.as_str().to_owned())
429 .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
430 KeyExtractor::Header(name) => header_str(req, name).unwrap_or_else(|| extract_peer_ip(req)),
431 };
432 truncate_key(raw)
433}
434
435/// Extract client IP from proxy headers (`X-Real-IP`, `X-Forwarded-For`).
436///
437/// Only use behind a trusted reverse proxy that sets these headers from the
438/// real peer address and strips client-supplied values.
439fn extract_forwarded_ip(req: &Request<Body>) -> String {
440 let headers = req.headers();
441 headers
442 .get("x-real-ip")
443 .or_else(|| headers.get("x-forwarded-for"))
444 .and_then(|v| v.to_str().ok())
445 .and_then(|s| s.split(',').next())
446 .map(|s| s.trim().to_owned())
447 .unwrap_or_else(|| extract_peer_ip(req))
448}
449
450/// Extract client IP from the TCP peer address stored in request extensions.
451///
452/// Looks for axum's `ConnectInfo<SocketAddr>` (populated when the server
453/// is mounted via `Router::into_make_service_with_connect_info::<SocketAddr>()`),
454/// then falls back to a bare `SocketAddr` (which some integrations insert
455/// directly, e.g. test harnesses that wire the extension by hand).
456///
457/// Falls back to a shared bucket (`"unknown"`) if no peer address is
458/// available. This is fail-closed: all unknown-origin requests share one
459/// bucket, meaning stricter (not weaker) rate limiting.
460fn extract_peer_ip(req: &Request<Body>) -> String {
461 let ext = req.extensions();
462 let addr = ext
463 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
464 .map(|ci| ci.0)
465 .or_else(|| ext.get::<std::net::SocketAddr>().copied());
466 addr.map(|addr| addr.ip().to_string()).unwrap_or_else(|| {
467 use std::sync::Once;
468 static WARN: Once = Once::new();
469 WARN.call_once(|| {
470 tracing::warn!(
471 "PeerIp rate limiting: no SocketAddr in request extensions; \
472 all requests will share a single bucket. Use \
473 Router::into_make_service_with_connect_info::<SocketAddr>() \
474 or switch to KeyExtractor::ForwardedIp behind a trusted proxy."
475 );
476 });
477 "unknown".to_owned()
478 })
479}
480
481fn header_str(req: &Request<Body>, name: &str) -> Option<String> {
482 req.headers()
483 .get(name)
484 .and_then(|v| v.to_str().ok())
485 .map(|s| s.to_owned())
486}
487
488// ── Tower Layer / Service ────────────────────────────────────────────────────
489
490/// Tower layer that adds token-bucket rate limiting.
491#[derive(Clone)]
492pub struct RateLimitLayer {
493 store: BucketStore,
494 key_extractor: KeyExtractor,
495 retry_after: bool,
496 metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
497}
498
499impl RateLimitLayer {
500 /// Create a new rate-limit layer from the given configuration.
501 pub fn new(config: RateLimitConfig) -> Self {
502 Self {
503 store: BucketStore::new(config.max_requests, config.window),
504 key_extractor: config.key_extractor,
505 retry_after: config.retry_after,
506 metrics: None,
507 }
508 }
509
510 /// Attach a metrics hook for rate-limit observability.
511 pub fn with_metrics(mut self, metrics: impl crate::metrics::AuthnMetrics) -> Self {
512 self.metrics = Some(Arc::new(metrics));
513 self
514 }
515}
516
517impl<S> Layer<S> for RateLimitLayer {
518 type Service = RateLimitService<S>;
519
520 fn layer(&self, inner: S) -> Self::Service {
521 RateLimitService {
522 inner,
523 store: self.store.clone(),
524 key_extractor: self.key_extractor.clone(),
525 retry_after: self.retry_after,
526 metrics: self.metrics.clone(),
527 }
528 }
529}
530
531/// Tower service that enforces per-key rate limits before forwarding requests.
532#[derive(Clone)]
533pub struct RateLimitService<S> {
534 inner: S,
535 store: BucketStore,
536 key_extractor: KeyExtractor,
537 retry_after: bool,
538 metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
539}
540
541impl<S, ResBody> Service<Request<Body>> for RateLimitService<S>
542where
543 S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
544 S::Future: Send + 'static,
545 S::Error: Send + 'static,
546 ResBody: Default + Send + 'static,
547{
548 type Response = Response<ResBody>;
549 type Error = S::Error;
550 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
551
552 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
553 self.inner.poll_ready(cx)
554 }
555
556 fn call(&mut self, req: Request<Body>) -> Self::Future {
557 let key = extract_key(&req, &self.key_extractor);
558 let retry_after_enabled = self.retry_after;
559
560 // Lazy eviction. Two triggers:
561 // (1) every 128 requests; bounded steady-state cleanup that
562 // keeps the map small under normal load and tightens the
563 // window an attacker has to inflate it before cleanup fires
564 // (the previous 1024-request interval gave a much larger
565 // gap to inject distinct keys);
566 // (2) any time the bucket count exceeds the soft cap, force an
567 // immediate eviction sweep; this is the brake on
568 // attacker-driven growth where unique keys are sent at
569 // high rate.
570 // Combined with the per-key `MAX_KEY_LEN` truncation in
571 // `extract_key`, total memory usage stays bounded even under a
572 // unique-key flood: O(MAX_BUCKETS * (MAX_KEY_LEN + bucket-size)).
573 const EVICT_INTERVAL: u64 = 128;
574 const SOFT_BUCKET_CAP: usize = 64 * 1024;
575 let count = self.store.request_count.fetch_add(1, Ordering::Relaxed);
576 if (count.is_multiple_of(EVICT_INTERVAL) || self.store.buckets.len() > SOFT_BUCKET_CAP)
577 && !self.store.buckets.is_empty()
578 {
579 self.store.evict_expired();
580 }
581
582 let metrics = self.metrics.clone();
583 match self.store.try_acquire(&key) {
584 Acquire::Allowed { remaining } => {
585 if let Some(ref m) = metrics {
586 m.rate_limit_allowed();
587 }
588 // Clone inner *before* the async block; required by tower's contract.
589 let mut inner = self.inner.clone();
590 std::mem::swap(&mut inner, &mut self.inner);
591 Box::pin(async move {
592 let mut resp = inner.call(req).await?;
593 if let Ok(val) = axum::http::HeaderValue::from_str(&remaining.to_string()) {
594 resp.headers_mut().insert("x-ratelimit-remaining", val);
595 }
596 Ok(resp)
597 })
598 }
599 Acquire::Limited { retry_after_secs } => {
600 if let Some(ref m) = metrics {
601 m.rate_limit_rejected();
602 }
603 tracing::debug!(
604 key = %key,
605 retry_after = retry_after_secs,
606 "rate limit exceeded"
607 );
608 Box::pin(async move {
609 let mut response = Response::new(ResBody::default());
610 *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
611 if retry_after_enabled
612 && let Ok(val) =
613 axum::http::HeaderValue::from_str(&retry_after_secs.to_string())
614 {
615 response.headers_mut().insert("retry-after", val);
616 }
617 Ok(response)
618 })
619 }
620 }
621 }
622}
623
624// ── Tests ────────────────────────────────────────────────────────────────────
625
626#[cfg(test)]
627mod tests;